import os
import random
from pathlib import Path
from typing import Optional

import PIL
import numpy as np
import torch
from PIL import Image

from labml import monit
from labml.logger import inspect
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
Set random seeds for reproducibility.

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
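This seeds Python, NumPy, and every CUDA device. As an aside (not part of this module), cuDNN can still pick nondeterministic convolution algorithms; for stricter reproducibility the standard PyTorch flags can be set as well:

set_seed(42)
# Optional: trade speed for deterministic cuDNN convolutions
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False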
Load the LatentDiffusion model from a checkpoint.

path is the path of the model checkpoint.

def load_model(path: Optional[Path] = None) -> LatentDiffusion:
Initialize the autoencoder
    with monit.section('Initialize autoencoder'):
        encoder = Encoder(z_channels=4,
                          in_channels=3,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        decoder = Decoder(out_channels=3,
                          z_channels=4,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        autoencoder = Autoencoder(emb_channels=4,
                                  encoder=encoder,
                                  decoder=decoder,
                                  z_channels=4)
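With channel multipliers [1, 2, 4, 4] the encoder has three downsampling stages, an 8x spatial reduction, so a 512x512 RGB image maps to a 4x64x64 latent. An illustrative round trip, separate from load_model (encode() returning a distribution to sample from follows this implementation's Autoencoder API):

img = torch.randn(1, 3, 512, 512)         # dummy image batch in [-1, 1]
with torch.no_grad():
    z = autoencoder.encode(img).sample()  # latent, shape (1, 4, 64, 64)
    rec = autoencoder.decode(z)           # reconstruction, shape (1, 3, 512, 512)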
Initialize the CLIP text embedder
    with monit.section('Initialize CLIP Embedder'):
        clip_text_embedder = CLIPTextEmbedder()
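The embedder turns a batch of prompts into per-token CLIP embeddings that condition the U-Net; with the default ViT-L/14 text model each prompt becomes a [77, 768] tensor, which is why d_cond=768 below. An illustrative call, assuming the embedder takes a list of prompts as in this implementation:

with torch.no_grad():
    cond = clip_text_embedder(["a photograph of an astronaut riding a horse"])
# cond.shape == (1, 77, 768)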
Initialize the U-Net
    with monit.section('Initialize U-Net'):
        unet_model = UNetModel(in_channels=4,
                               out_channels=4,
                               channels=320,
                               attention_levels=[0, 1, 2],
                               n_res_blocks=2,
                               channel_multipliers=[1, 2, 4, 4],
                               n_heads=8,
                               tf_layers=1,
                               d_cond=768)
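The U-Net denoises the 4-channel latents, attending to the text embeddings via cross-attention at the listed levels. A shape-level sketch of one denoising call, separate from load_model (the argument order follows this implementation's forward(x, time_steps, cond); treat it as an assumption):

x = torch.randn(1, 4, 64, 64)     # noisy latent
t = torch.full((1,), 999.)        # diffusion time step per batch element
cond = torch.randn(1, 77, 768)    # text conditioning
with torch.no_grad():
    eps = unet_model(x, t, cond)  # predicted noise, same shape as x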
Initialize the Latent Diffusion model
    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(linear_start=0.00085,
                                linear_end=0.0120,
                                n_steps=1000,
                                latent_scaling_factor=0.18215,
                                autoencoder=autoencoder,
                                clip_embedder=clip_text_embedder,
                                unet_model=unet_model)
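For reference, linear_start and linear_end define Stable Diffusion's "scaled linear" beta schedule: the square roots of beta are spaced linearly across the 1,000 steps. A sketch of how the schedule is derived (an illustration of the formula, not the model's internal code):

beta = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
alpha_bar = torch.cumprod(1. - beta, dim=0)  # cumulative product used by q(x_t | x_0)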
Load the checkpoint
    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")
Set model state
    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)
Debugging output
    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
            _expand=True)
Switch the model to evaluation mode and return it.

    model.eval()
    return model
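A typical usage sketch; the checkpoint path below is hypothetical and should point at downloaded Stable Diffusion v1 weights:

model = load_model(Path('sd-v1-4.ckpt'))  # hypothetical checkpoint path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)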
This loads an image from a file and returns a PyTorch tensor.

path is the path of the image.

def load_img(path: str):
Open Image
    image = Image.open(path).convert("RGB")
Get image size
    w, h = image.size
Resize to a multiple of 32, so the model's repeated downsampling divides evenly
    w = w - w % 32
    h = h - h % 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
Convert to numpy and map from [0, 255] to [-1, 1]

    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
Transpose to shape [batch_size, channels, height, width]
    image = image[None].transpose(0, 3, 1, 2)
Convert to torch
    return torch.from_numpy(image)
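For example (with a hypothetical file name), loading a photo yields a float tensor ready for the autoencoder:

img = load_img('input.jpg')  # hypothetical path
# img.shape == (1, 3, h, w) with h and w rounded down to multiples of 32,
# and values in [-1, 1]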
This saves a batch of images to files.

images is the tensor with images of shape [batch_size, channels, height, width].
dest_path is the folder to save images in.
prefix is the prefix to add to file names.
img_format is the image format.

def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
Create the destination folder
    os.makedirs(dest_path, exist_ok=True)
Map images to [0, 1] space and clip

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
Transpose to [batch_size, height, width, channels] and convert to numpy

    images = images.cpu().permute(0, 2, 3, 1).numpy()
Save images
    for i, img in enumerate(images):
        img = Image.fromarray((255. * img).astype(np.uint8))
        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)
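A round-trip sketch tying these helpers together (file names hypothetical):

img = load_img('input.jpg')  # hypothetical input image
save_images(img, dest_path='outputs', prefix='orig_')
# writes outputs/orig_00000.jpeg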