14from typing import List
15
16from torch import nn
17from transformers import CLIPTokenizer, CLIPTextModel20class CLIPTextEmbedder(nn.Module):version
モデルバージョンですdevice
デバイスですmax_length
トークン化されたプロンプトの最大長です25 def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):31 super().__init__()トークナイザーをロード
33 self.tokenizer = CLIPTokenizer.from_pretrained(version)CLIP トランスをロードします
35 self.transformer = CLIPTextModel.from_pretrained(version).eval()
36
37 self.device = device
38 self.max_length = max_lengthprompts
埋め込むプロンプトのリストです40 def forward(self, prompts: List[str]):プロンプトをトークン化
45 batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
46 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")トークン ID を取得
48 tokens = batch_encoding["input_ids"].to(self.device)CLIP 埋め込みを取得
50 return self.transformer(input_ids=tokens).last_hidden_state