11import os
12import random
13from pathlib import Path
14
15import PIL
16import numpy as np
17import torch
18from PIL import Image
19
20from labml import monit
21from labml.logger import inspect
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
24from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
25from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
28def set_seed(seed: int):
32 random.seed(seed)
33 np.random.seed(seed)
34 torch.manual_seed(seed)
35 torch.cuda.manual_seed_all(seed)
38def load_model(path: Path = None) -> LatentDiffusion:
オートエンコーダを初期化
44 with monit.section('Initialize autoencoder'):
45 encoder = Encoder(z_channels=4,
46 in_channels=3,
47 channels=128,
48 channel_multipliers=[1, 2, 4, 4],
49 n_resnet_blocks=2)
50
51 decoder = Decoder(out_channels=3,
52 z_channels=4,
53 channels=128,
54 channel_multipliers=[1, 2, 4, 4],
55 n_resnet_blocks=2)
56
57 autoencoder = Autoencoder(emb_channels=4,
58 encoder=encoder,
59 decoder=decoder,
60 z_channels=4)
CLIP テキストエンベダーを初期化
63 with monit.section('Initialize CLIP Embedder'):
64 clip_text_embedder = CLIPTextEmbedder()
U-Net を初期化します
67 with monit.section('Initialize U-Net'):
68 unet_model = UNetModel(in_channels=4,
69 out_channels=4,
70 channels=320,
71 attention_levels=[0, 1, 2],
72 n_res_blocks=2,
73 channel_multipliers=[1, 2, 4, 4],
74 n_heads=8,
75 tf_layers=1,
76 d_cond=768)
潜在拡散モデルを初期化
79 with monit.section('Initialize Latent Diffusion model'):
80 model = LatentDiffusion(linear_start=0.00085,
81 linear_end=0.0120,
82 n_steps=1000,
83 latent_scaling_factor=0.18215,
84
85 autoencoder=autoencoder,
86 clip_embedder=clip_text_embedder,
87 unet_model=unet_model)
チェックポイントをロード
90 with monit.section(f"Loading model from {path}"):
91 checkpoint = torch.load(path, map_location="cpu")
モデルステートの設定
94 with monit.section('Load state'):
95 missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)
デバッグ出力
98 inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
99 _expand=True)
102 model.eval()
103 return model
106def load_img(path: str):
[イメージを開く]
115 image = Image.open(path).convert("RGB")
画像サイズを取得
117 w, h = image.size
32 の倍数にリサイズ
119 w = w - w % 32
120 h = h - h % 32
121 image = image.resize((w, h), resample=PIL.Image.LANCZOS)
numpy に変換して for にマップする [-1, 1]
[0, 255]
123 image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
シェイプに転置 [batch_size, channels, height, width]
125 image = image[None].transpose(0, 3, 1, 2)
トーチに変換
127 return torch.from_numpy(image)
images
形状の画像を含むテンソルです [batch_size, channels, height, width]
dest_path
画像を保存するフォルダーですprefix
ファイル名に追加するプレフィックスですimg_format
は画像形式です130def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
保存先フォルダーの作成
141 os.makedirs(dest_path, exist_ok=True)
[0, 1]
画像をスペースにマップしてクリップする
144 images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
[batch_size, height, width, channels]
numpyへの転置とnumpyへの変換
146 images = images.cpu().permute(0, 2, 3, 1).numpy()
画像を保存
149 for i, img in enumerate(images):
150 img = Image.fromarray((255. * img).astype(np.uint8))
151 img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)