Utility functions for stable diffusion

11import os
12import random
13from pathlib import Path
15import PIL
16import numpy as np
17import torch
18from PIL import Image
20from labml import monit
21from labml.logger import inspect
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
24from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
25from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

Set random seeds

def set_seed(seed: int):
    """
    ### Set random seeds

    Seeds Python's `random` module, NumPy's global generator, PyTorch's CPU
    generator and the generators of all CUDA devices, in that order.

    * `seed` is the seed value passed to every generator
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def load_model(path: Path = None) -> LatentDiffusion:
    """
    ### Load the `LatentDiffusion` model

    Builds the autoencoder, the CLIP text embedder and the U-Net with the
    hyper-parameters below and wraps them in a `LatentDiffusion` model. It then
    loads the weights from the checkpoint at `path` and switches the model to
    evaluation mode.

    * `path` is the path of the checkpoint file. It is required; the `None`
      default is kept only so the signature stays backward compatible.

    Raises `ValueError` if `path` is `None`.
    """
    # Fail fast: without a checkpoint `torch.load` would fail anyway, but only
    # after all the sub-models below have been built (slow and memory hungry).
    if path is None:
        raise ValueError("load_model() requires the path of a checkpoint file")

    # Initialize the autoencoder
    with monit.section('Initialize autoencoder'):
        encoder = Encoder(z_channels=4,
                          in_channels=3,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        decoder = Decoder(out_channels=3,
                          z_channels=4,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        autoencoder = Autoencoder(emb_channels=4,
                                  encoder=encoder,
                                  decoder=decoder,
                                  z_channels=4)

    # Initialize the CLIP text embedder
    with monit.section('Initialize CLIP Embedder'):
        clip_text_embedder = CLIPTextEmbedder()

    # Initialize the U-Net
    with monit.section('Initialize U-Net'):
        unet_model = UNetModel(in_channels=4,
                               out_channels=4,
                               channels=320,
                               attention_levels=[0, 1, 2],
                               n_res_blocks=2,
                               channel_multipliers=[1, 2, 4, 4],
                               n_heads=8,
                               tf_layers=1,
                               d_cond=768)

    # Initialize the Latent Diffusion model
    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(linear_start=0.00085,
                                linear_end=0.0120,
                                n_steps=1000,
                                latent_scaling_factor=0.18215,

                                autoencoder=autoencoder,
                                clip_embedder=clip_text_embedder,
                                unet_model=unet_model)

    # Load the checkpoint onto the CPU.
    # NOTE(review): `torch.load` unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources.
    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")

    # Set model state. With `strict=False`, mismatched keys do not raise;
    # they are returned and reported below.
    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)

    # Debugging output
    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
            _expand=True)

    # Switch to evaluation mode before returning
    model.eval()
    return model

Load an image

This loads an image from a file and returns it as a PyTorch tensor of shape [1, 3, height, width] with values in [-1, 1].

  • path is the path of the image
def load_img(path: str):
    """
    ### Load an image

    Loads an image file and returns it as a PyTorch tensor.

    * `path` is the path of the image

    The image is converted to RGB, and its width and height are rounded down to
    multiples of $32$ by resizing. The returned tensor has shape
    `[1, 3, height, width]`, dtype `float32`, and values mapped from $[0, 255]$
    to $[-1, 1]$.

    Raises `ValueError` if the image is smaller than $32$ pixels in width or height.
    """
    # Open the image. The `with` block closes the file; `convert` returns a new,
    # fully loaded image that stays usable afterwards.
    with Image.open(path) as img:
        image = img.convert("RGB")

    # Get image size
    orig_w, orig_h = image.size

    # Round both sides down to a multiple of 32
    w = orig_w - orig_w % 32
    h = orig_h - orig_h % 32
    # Rounding a side below 32 px down gives 0, which PIL cannot resize to
    if w == 0 or h == 0:
        raise ValueError(f"Image {path} is {orig_w}x{orig_h} pixels; "
                         f"both sides must be at least 32 pixels")
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)

    # Convert to numpy and map from [0, 255] to [-1, 1]
    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1

    # Add a batch dimension and transpose to shape [batch_size, channels, height, width]
    image = image[None].transpose(0, 3, 1, 2)

    # Convert to torch
    return torch.from_numpy(image)

Save images

  • images is the tensor with images of shape [batch_size, channels, height, width] and values in [-1, 1]
  • dest_path is the folder to save images in
  • prefix is the prefix to add to file names
  • img_format is the image format
def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
    """
    ### Save images

    * `images` is the tensor with images of shape `[batch_size, channels, height, width]`
      and values in $[-1, 1]$ (values outside this range are clipped)
    * `dest_path` is the folder to save images in; it is created if it does not exist
    * `prefix` is the prefix to add to file names
    * `img_format` is the image format; it is also used as the file extension

    Files are named `{prefix}{index:05}.{img_format}`, with `index` counting from 0.
    """
    # Create the destination folder
    os.makedirs(dest_path, exist_ok=True)

    # Detach from the autograd graph and cast to float32. `.numpy()` fails on
    # tensors that require grad and on dtypes NumPy lacks (e.g. bfloat16).
    images = images.detach().float()

    # Map images to [0, 1] space and clip
    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

    # Transpose to [batch_size, height, width, channels] and convert to numpy
    images = images.cpu().permute(0, 2, 3, 1).numpy()

    # Save images
    for i, img in enumerate(images):
        # Round to the nearest 8-bit value instead of truncating, which would bias pixels downward
        img = Image.fromarray(np.rint(255. * img).astype(np.uint8))
        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)