用于 Stable Diffusion（稳定扩散）的实用函数

11import os
12import random
13from pathlib import Path
14
15import PIL
16import numpy as np
17import torch
18from PIL import Image
19
20from labml import monit
21from labml.logger import inspect
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
24from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
25from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

设置随机种子

def set_seed(seed: int):
    """
    ### Set random seeds

    Applies the same `seed`, in this order, to Python's `random` module,
    NumPy's global generator, PyTorch's default generator
    (`torch.manual_seed`) and PyTorch's CUDA generators
    (`torch.cuda.manual_seed_all`).

    :param seed: is the integer seed value
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def load_model(path: Path = None) -> LatentDiffusion:
    """
    ### Load a `LatentDiffusion` model

    Builds the autoencoder, the CLIP text embedder and the U-Net with fixed
    hyper-parameters, wraps them in a `LatentDiffusion` model, and then loads
    the weights stored under the `"state_dict"` key of the checkpoint at `path`.

    :param path: is the checkpoint file handed to `torch.load`
    :return: the model, switched to evaluation mode with `model.eval()`
    """
    # NOTE(review): `path` defaults to `None` but is passed straight to
    # `torch.load`, so callers presumably always have to supply it - confirm.

    # Initialize the autoencoder (encoder first, then decoder)
    with monit.section('Initialize autoencoder'):
        image_encoder = Encoder(
            z_channels=4,
            in_channels=3,
            channels=128,
            channel_multipliers=[1, 2, 4, 4],
            n_resnet_blocks=2,
        )
        image_decoder = Decoder(
            out_channels=3,
            z_channels=4,
            channels=128,
            channel_multipliers=[1, 2, 4, 4],
            n_resnet_blocks=2,
        )
        auto_encoder = Autoencoder(
            emb_channels=4,
            encoder=image_encoder,
            decoder=image_decoder,
            z_channels=4,
        )

    # Initialize the CLIP text embedder
    with monit.section('Initialize CLIP Embedder'):
        text_embedder = CLIPTextEmbedder()

    # Initialize the U-Net
    with monit.section('Initialize U-Net'):
        unet = UNetModel(
            in_channels=4,
            out_channels=4,
            channels=320,
            attention_levels=[0, 1, 2],
            n_res_blocks=2,
            channel_multipliers=[1, 2, 4, 4],
            n_heads=8,
            tf_layers=1,
            d_cond=768,
        )

    # Initialize the latent diffusion model from the three components above
    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(
            linear_start=0.00085,
            linear_end=0.0120,
            n_steps=1000,
            latent_scaling_factor=0.18215,
            autoencoder=auto_encoder,
            clip_embedder=text_embedder,
            unet_model=unet,
        )

    # Load the checkpoint onto the CPU
    # NOTE(review): `torch.load` unpickles the file; only load checkpoints from
    # a trusted source.
    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")

    # Set the model state; `strict=False` is passed and the two returned key
    # lists are kept for the debugging output below
    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)

    # Debugging output: the training step (or `-1` when the checkpoint has no
    # `global_step` entry) and the two key lists
    inspect(
        global_step=checkpoint.get('global_step', -1),
        missing_keys=missing_keys,
        extra_keys=extra_keys,
        _expand=True,
    )

    model.eval()
    return model

加载图片

这将从文件加载图像并返回 PyTorch 张量。

  • path 是图像的路径
def load_img(path: str) -> torch.Tensor:
    """
    ### Load an image

    This loads an image from a file and returns a PyTorch tensor.

    :param path: is the path of the image
    :return: a `float32` tensor of shape `[1, 3, height, width]` with values
        in `[-1, 1]`, where `height` and `width` are the image dimensions
        rounded down to multiples of 32
    :raises ValueError: if the image is narrower or shorter than 32 pixels
    """
    # Open the image and convert it to three-channel RGB
    image = Image.open(path).convert("RGB")

    # Get the image size
    orig_width, orig_height = image.size

    # Round both dimensions down to a multiple of 32
    width = orig_width - orig_width % 32
    height = orig_height - orig_height % 32

    # A side shorter than 32 pixels rounds down to zero, which `resize` cannot
    # handle; fail here with a message that names the actual cause
    if width == 0 or height == 0:
        raise ValueError(
            f"Image '{path}' is {orig_width}x{orig_height}; "
            f"both dimensions must be at least 32 pixels"
        )

    image = image.resize((width, height), resample=PIL.Image.LANCZOS)

    # Convert to numpy and map pixel values from [0, 255] to [-1, 1]
    pixels = np.array(image).astype(np.float32) * (2. / 255.0) - 1

    # Add a batch dimension and transpose
    # from [batch_size, height, width, channels]
    # to [batch_size, channels, height, width]
    pixels = pixels[None].transpose(0, 3, 1, 2)

    # Convert to torch
    return torch.from_numpy(pixels)

保存图像

  • images 是形状为 [batch_size, channels, height, width] 的图像张量
  • dest_path 是保存图像的文件夹
  • prefix 是添加到文件名的前缀
  • img_format 是图像格式
def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
    """
    ### Save images

    Each image in the batch is written to
    `dest_path/{prefix}{index:05}.{img_format}`, where `index` is its position
    in the batch starting from zero.

    :param images: is the tensor with images of shape `[batch_size, channels, height, width]`;
        values are expected in `[-1, 1]` and anything outside is clipped
    :param dest_path: is the folder to save images in (created if missing)
    :param prefix: is the prefix to add to file names
    :param img_format: is the image format, used both as the file extension
        and as the `format` passed to `Image.save`
    """
    # Create the destination folder
    os.makedirs(dest_path, exist_ok=True)

    # Map images to `[0, 1]` space and clip.
    # `detach()` drops any autograd history so the `.numpy()` call below also
    # works for tensors that require gradients.
    images = torch.clamp((images.detach() + 1.0) / 2.0, min=0.0, max=1.0)

    # Move to the CPU, transpose to `[batch_size, height, width, channels]`
    # and convert to numpy
    images = images.cpu().permute(0, 2, 3, 1).numpy()

    # Save the images
    for i, img in enumerate(images):
        # Scale to `[0, 255]`; `astype(np.uint8)` truncates the fractional part
        img = Image.fromarray((255. * img).astype(np.uint8))
        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)