This is used to get prompt embeddings for stable diffusion. It uses the Hugging Face Transformers CLIP model.
from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel
class CLIPTextEmbedder(nn.Module):
version is the model version
device is the device
max_length is the max length of the tokenized prompt

    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        super().__init__()
Load the tokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
Load the CLIP transformer
        self.transformer = CLIPTextModel.from_pretrained(version).eval()

        self.device = device
        self.max_length = max_length
prompts are the list of prompts to embed

    def forward(self, prompts: List[str]):
Tokenize the prompts
        batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
Get token ids
        tokens = batch_encoding["input_ids"].to(self.device)
Get CLIP embeddings
        return self.transformer(input_ids=tokens).last_hidden_state
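
As a quick usage sketch (not part of the original module; the prompts, device choice, and printed shape are illustrative assumptions), the embedder can be constructed with its defaults and called with a list of prompts. For the default openai/clip-vit-large-patch14 version, max_length is 77 and the text hidden size is 768, so the result has shape (batch_size, 77, 768).

import torch

# Run on CPU for this sketch; the default version is openai/clip-vit-large-patch14
embedder = CLIPTextEmbedder(device="cpu")

with torch.no_grad():
    # Each prompt is padded or truncated to max_length = 77 tokens
    embeddings = embedder([
        "a photograph of an astronaut riding a horse",
        "a watercolor painting of a fox",
    ])

# Hidden size of the CLIP ViT-L/14 text model is 768
print(embeddings.shape)  # torch.Size([2, 77, 768])

These per-token embeddings are what stable diffusion uses as the conditioning input for the prompt.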