11import os
12import random
13from pathlib import Path
14
15import PIL
16import numpy as np
17import torch
18from PIL import Image
19
20from labml import monit
21from labml.logger import inspect
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
24from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
25from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

ランダムシードを設定

28def set_seed(seed: int):
32    random.seed(seed)
33    np.random.seed(seed)
34    torch.manual_seed(seed)
35    torch.cuda.manual_seed_all(seed)
38def load_model(path: Path = None) -> LatentDiffusion:

オートエンコーダを初期化

44    with monit.section('Initialize autoencoder'):
45        encoder = Encoder(z_channels=4,
46                          in_channels=3,
47                          channels=128,
48                          channel_multipliers=[1, 2, 4, 4],
49                          n_resnet_blocks=2)
50
51        decoder = Decoder(out_channels=3,
52                          z_channels=4,
53                          channels=128,
54                          channel_multipliers=[1, 2, 4, 4],
55                          n_resnet_blocks=2)
56
57        autoencoder = Autoencoder(emb_channels=4,
58                                  encoder=encoder,
59                                  decoder=decoder,
60                                  z_channels=4)

CLIP テキストエンベダーを初期化

63    with monit.section('Initialize CLIP Embedder'):
64        clip_text_embedder = CLIPTextEmbedder()

U-Net を初期化します

67    with monit.section('Initialize U-Net'):
68        unet_model = UNetModel(in_channels=4,
69                               out_channels=4,
70                               channels=320,
71                               attention_levels=[0, 1, 2],
72                               n_res_blocks=2,
73                               channel_multipliers=[1, 2, 4, 4],
74                               n_heads=8,
75                               tf_layers=1,
76                               d_cond=768)

潜在拡散モデルを初期化

79    with monit.section('Initialize Latent Diffusion model'):
80        model = LatentDiffusion(linear_start=0.00085,
81                                linear_end=0.0120,
82                                n_steps=1000,
83                                latent_scaling_factor=0.18215,
84
85                                autoencoder=autoencoder,
86                                clip_embedder=clip_text_embedder,
87                                unet_model=unet_model)

チェックポイントをロード

90    with monit.section(f"Loading model from {path}"):
91        checkpoint = torch.load(path, map_location="cpu")

モデルステートの設定

94    with monit.section('Load state'):
95        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)

デバッグ出力

98    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
99            _expand=True)

102    model.eval()
103    return model

画像を読み込む

これはファイルから画像をロードし、PyTorch テンソルを返します。

  • path 画像のパスです
106def load_img(path: str):

[イメージを開く]

115    image = Image.open(path).convert("RGB")

画像サイズを取得

117    w, h = image.size

32 の倍数にリサイズ

119    w = w - w % 32
120    h = h - h % 32
121    image = image.resize((w, h), resample=PIL.Image.LANCZOS)

numpy に変換して for にマップする [-1, 1] [0, 255]

123    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1

シェイプに転置 [batch_size, channels, height, width]

125    image = image[None].transpose(0, 3, 1, 2)

トーチに変換

127    return torch.from_numpy(image)

画像を保存する

  • images 形状の画像を含むテンソルです [batch_size, channels, height, width]
  • dest_path 画像を保存するフォルダーです
  • prefix ファイル名に追加するプレフィックスです
  • img_format は画像形式です
130def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):

保存先フォルダーの作成

141    os.makedirs(dest_path, exist_ok=True)

[0, 1] 画像をスペースにマップしてクリップする

144    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

[batch_size, height, width, channels] numpyへの転置とnumpyへの変換

146    images = images.cpu().permute(0, 2, 3, 1).numpy()

画像を保存

149    for i, img in enumerate(images):
150        img = Image.fromarray((255. * img).astype(np.uint8))
151        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)