14from typing import List
15
16from torch import nn
17from transformers import CLIPTokenizer, CLIPTextModel
20class CLIPTextEmbedder(nn.Module):
version
モデルバージョンですdevice
デバイスですmax_length
トークン化されたプロンプトの最大長です25 def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
31 super().__init__()
トークナイザーをロード
33 self.tokenizer = CLIPTokenizer.from_pretrained(version)
CLIP トランスをロードします
35 self.transformer = CLIPTextModel.from_pretrained(version).eval()
36
37 self.device = device
38 self.max_length = max_length
prompts
埋め込むプロンプトのリストです40 def forward(self, prompts: List[str]):
プロンプトをトークン化
45 batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
46 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
トークン ID を取得
48 tokens = batch_encoding["input_ids"].to(self.device)
CLIP 埋め込みを取得
50 return self.transformer(input_ids=tokens).last_hidden_state