CLIP 文本嵌入器

这用于获取提示嵌入以实现稳定的扩散。它使用 HuggingFace 变形金刚 CLIP 模型。

14from typing import List
15
16from torch import nn
17from transformers import CLIPTokenizer, CLIPTextModel

CLIP 文本嵌入器

20class CLIPTextEmbedder(nn.Module):
  • version 是模型版本
  • device 是设备吗
  • max_length 是标记化提示的最大长度
25    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
31        super().__init__()

加载代币生成器

33        self.tokenizer = CLIPTokenizer.from_pretrained(version)

加载 CLIP 变压器

35        self.transformer = CLIPTextModel.from_pretrained(version).eval()
36
37        self.device = device
38        self.max_length = max_length
  • prompts 是要嵌入的提示列表
40    def forward(self, prompts: List[str]):

对提示进行标记化

45        batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
46                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")

获取代币 ID

48        tokens = batch_encoding["input_ids"].to(self.device)

获取 CLIP 嵌入内容

50        return self.transformer(input_ids=tokens).last_hidden_state