import os
import random
from pathlib import Path
from typing import Optional

import PIL
import numpy as np
import torch
from PIL import Image

from labml import monit
from labml.logger import inspect
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
def set_seed(seed: int):
    """
    ### Set random seeds

    Seeds Python's `random` module, NumPy's global generator, and PyTorch
    (the default generator plus every CUDA device) with the same value.

    :param seed: is the seed value passed to each generator
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def load_model(path: Optional[Path] = None) -> LatentDiffusion:
    """
    ### Load `LatentDiffusion` model

    Builds the autoencoder, the CLIP text embedder and the U-Net, wraps them
    in a `LatentDiffusion` model, loads the weights stored under the
    `"state_dict"` key of the checkpoint at `path`, and returns the model
    in evaluation mode.

    :param path: is the path of the checkpoint file to load
    :return: the `LatentDiffusion` model with the checkpoint weights applied
    :raises ValueError: if `path` is `None`
    """
    # Fail before the (expensive) model construction instead of deep inside `torch.load`
    if path is None:
        raise ValueError("load_model() requires the path of a checkpoint file, got None")

    # Initialize the autoencoder
    with monit.section('Initialize autoencoder'):
        encoder = Encoder(z_channels=4,
                          in_channels=3,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        decoder = Decoder(out_channels=3,
                          z_channels=4,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        autoencoder = Autoencoder(emb_channels=4,
                                  encoder=encoder,
                                  decoder=decoder,
                                  z_channels=4)

    # Initialize the CLIP text embedder
    with monit.section('Initialize CLIP Embedder'):
        clip_text_embedder = CLIPTextEmbedder()

    # Initialize the U-Net
    with monit.section('Initialize U-Net'):
        unet_model = UNetModel(in_channels=4,
                               out_channels=4,
                               channels=320,
                               attention_levels=[0, 1, 2],
                               n_res_blocks=2,
                               channel_multipliers=[1, 2, 4, 4],
                               n_heads=8,
                               tf_layers=1,
                               d_cond=768)

    # Initialize the Latent Diffusion model
    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(linear_start=0.00085,
                                linear_end=0.0120,
                                n_steps=1000,
                                latent_scaling_factor=0.18215,

                                autoencoder=autoencoder,
                                clip_embedder=clip_text_embedder,
                                unet_model=unet_model)

    # Load the checkpoint onto the CPU.
    # NOTE: `torch.load` unpickles the file, which can execute arbitrary code;
    # only load checkpoints that come from a trusted source.
    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")

    # Set model state; `strict=False` tolerates key mismatches and reports them instead of raising
    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)

    # Debugging output
    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
            _expand=True)

    # Switch to evaluation mode before handing the model back
    model.eval()
    return model
def load_img(path: str) -> torch.Tensor:
    """
    ### Load an image

    This loads an image from a file and returns a PyTorch tensor.

    :param path: is the path of the image
    :return: a `float32` tensor of shape `[1, 3, height, width]` with values in `[-1, 1]`,
        where `height` and `width` are the image dimensions rounded down to a multiple of 32
    :raises ValueError: if the image is smaller than 32 pixels along either side
    """
    # Open the image; the context manager closes the underlying file once the RGB copy exists
    with Image.open(path) as source:
        image = source.convert("RGB")
    # Get image size
    w, h = image.size
    # Round each side down to a multiple of 32
    w = w - w % 32
    h = h - h % 32
    # A side shorter than 32 pixels would round down to zero, which cannot be resized to
    if w == 0 or h == 0:
        raise ValueError(f"Image {path} is smaller than 32 pixels along one side: size={image.size}")
    image = image.resize((w, h), resample=Image.LANCZOS)
    # Convert to numpy and map `[0, 255]` to `[-1, 1]`
    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
    # Add a batch dimension and transpose to shape `[batch_size, channels, height, width]`
    image = image[None].transpose(0, 3, 1, 2)
    # Convert to torch
    return torch.from_numpy(image)
def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
    """
    ### Save images

    Writes each image of the batch to `dest_path` as
    `{prefix}{index:05}.{img_format}`, where `index` is its position in the batch.

    :param images: is the tensor with images of shape `[batch_size, channels, height, width]`
    :param dest_path: is the folder to save images in
    :param prefix: is the prefix to add to file names
    :param img_format: is the image format
    """
    # Create the destination folder
    os.makedirs(dest_path, exist_ok=True)

    # Map images from `[-1, 1]` to `[0, 1]` space and clip
    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    # Transpose to `[batch_size, height, width, channels]` and convert to numpy;
    # `detach` lets tensors that still track gradients be converted as well
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()

    # Save images
    for i, pixels in enumerate(images):
        # Scale to `[0, 255]`; the cast to `uint8` truncates the fractional part
        picture = Image.fromarray((255. * pixels).astype(np.uint8))
        picture.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)