CLIP テキストエンベダー

これを使うと、高速に埋め込むことができ、安定した拡散が得られます。ハギングフェイストランスフォーマーCLIPモデルを使用しています

14from typing import List
15
16from torch import nn
17from transformers import CLIPTokenizer, CLIPTextModel

CLIP テキストエンベダー

20class CLIPTextEmbedder(nn.Module):
  • version モデルバージョンです
  • device デバイスです
  • max_length トークン化されたプロンプトの最大長です
25    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
31        super().__init__()

トークナイザーをロード

33        self.tokenizer = CLIPTokenizer.from_pretrained(version)

CLIP トランスをロードします

35        self.transformer = CLIPTextModel.from_pretrained(version).eval()
36
37        self.device = device
38        self.max_length = max_length
  • prompts 埋め込むプロンプトのリストです
40    def forward(self, prompts: List[str]):

プロンプトをトークン化

45        batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
46                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")

トークン ID を取得

48        tokens = batch_encoding["input_ids"].to(self.device)

CLIP 埋め込みを取得

50        return self.transformer(input_ids=tokens).last_hidden_state