from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel


class CLIPTextEmbedder(nn.Module):
    """
    Embed text prompts with a pretrained CLIP text encoder.

    Prompts are tokenized to a fixed length and run through a Hugging Face
    ``CLIPTextModel``; ``forward`` returns the model's ``last_hidden_state``.
    """

    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        """
        :param version: checkpoint name or path handed to ``from_pretrained``
            for both the tokenizer and the text model
        :param device: device that token ids are moved to in ``forward``
        :param max_length: number of tokens every prompt is truncated/padded to
        """
        super().__init__()
        # Tokenizer and encoder are loaded from the same ``version``.
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        encoder = CLIPTextModel.from_pretrained(version)
        # ``eval()`` puts the encoder in evaluation mode and returns it.
        self.transformer = encoder.eval()

        self.max_length = max_length
        self.device = device

    def forward(self, prompts: List[str]):
        """
        Encode a batch of text prompts.

        :param prompts: the prompts to embed
        :return: ``last_hidden_state`` of the CLIP text model for the
            tokenized prompts
        """
        # Truncate or pad every prompt to exactly ``self.max_length`` tokens.
        encoded = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # NOTE(review): ids are sent to ``self.device`` (fixed at construction),
        # which assumes the module itself was moved to that same device — confirm.
        input_ids = encoded["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=input_ids)
        return outputs.last_hidden_state