CLIP Text Embedder

This is used to get prompt embeddings for Stable Diffusion. It uses the CLIP model from HuggingFace Transformers.

from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel

CLIP Text Embedder

class CLIPTextEmbedder(nn.Module):
  • version is the HuggingFace model name to load the weights from
  • device is the device to put the token ids on (the module itself should be moved to the same device)
  • max_length is the maximum length of the tokenized prompt (CLIP's context length is 77 tokens)
    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        super().__init__()

Load the tokenizer

        self.tokenizer = CLIPTokenizer.from_pretrained(version)

Load the CLIP transformer and put it in evaluation mode (.eval() disables dropout; it does not freeze the weights)

        self.transformer = CLIPTextModel.from_pretrained(version).eval()

        self.device = device
        self.max_length = max_length
  • prompts are the list of prompts to embed
    def forward(self, prompts: List[str]):

Tokenize the prompts

        batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
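        # Prompts shorter than `max_length` are padded out to 77 tokens and
        # longer prompts are truncated, so every encoding has the same length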

Get the token ids and move them to the device

        tokens = batch_encoding["input_ids"].to(self.device)
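        # `tokens` is an integer tensor of shape `[batch_size, max_length]`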

Get the CLIP embeddings (the transformer's last hidden state)

        return self.transformer(input_ids=tokens).last_hidden_state
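
A minimal usage sketch, assuming a CUDA device is available and the default openai/clip-vit-large-patch14 weights, whose text encoder has a hidden size of 768; the prompt string is only an example:

import torch

clip_embedder = CLIPTextEmbedder().to("cuda:0")

with torch.no_grad():
    embeddings = clip_embedder(["a photograph of an astronaut riding a horse"])

# One embedding per token position: `[1, 77, 768]` for ViT-L/14
print(embeddings.shape)

Stable Diffusion obtains the unconditional embedding for classifier-free guidance the same way, by embedding the empty prompt "".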