11import argparse
12from pathlib import Path
13
14import torch
15
16from labml import lab, monit
17from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
18from labml_nn.diffusion.stable_diffusion.util import load_model, load_img, save_images, set_seed

画像から画像へのクラス

21class Img2Img:
  • checkpoint_path チェックポイントのパスです
  • ddim_steps はサンプリングステップの数
  • ddim_eta DDIM サンプリング定数です
26    def __init__(self, *, checkpoint_path: Path,
27                 ddim_steps: int = 50,
28                 ddim_eta: float = 0.0):
34        self.ddim_steps = ddim_steps
37        self.model = load_model(checkpoint_path)

デバイスを取得

39        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

モデルをデバイスに移動

41        self.model.to(self.device)
44        self.sampler = DDIMSampler(self.model,
45                                   n_steps=ddim_steps,
46                                   ddim_eta=ddim_eta)
  • dest_path 生成された画像を保存するパスです
  • orig_img 変換する画像です
  • strength 元の画像のどの程度保存しないかを指定します
  • batch_size はバッチで生成する画像の数です
  • prompt で画像を生成するプロンプトです
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
48    @torch.no_grad()
49    def __call__(self, *,
50                 dest_path: str,
51                 orig_img: str,
52                 strength: float,
53                 batch_size: int = 3,
54                 prompt: str,
55                 uncond_scale: float = 5.0,
56                 ):

プロンプトを一括作成

67        prompts = batch_size * [prompt]

画像を読み込む

69        orig_image = load_img(orig_img).to(self.device)

潜在空間に画像をエンコードし、batch_size そのコピーを作成します

71        orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)

オリジナルを拡散させるまでのステップ数を求める

74        assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
75        t_index = int(strength * self.ddim_steps)

AMP オートキャスティング

78        with torch.cuda.amp.autocast():

無条件スケーリングでは、空のプロンプトでは埋め込みは取得されません (条件なし)。

80            if uncond_scale != 1.0:
81                un_cond = self.model.get_text_conditioning(batch_size * [""])
82            else:
83                un_cond = None

プロンプトの埋め込みを入手

85            cond = self.model.get_text_conditioning(prompts)

元の画像にノイズを追加

87            x = self.sampler.q_sample(orig, t_index)

ノイズの多い画像からの再構築

89            x = self.sampler.paint(x, cond, t_index,
90                                   uncond_scale=uncond_scale,
91                                   uncond_cond=un_cond)
93            images = self.model.autoencoder_decode(x)

画像を保存

96        save_images(images, dest_path, 'img_')

CLI

99def main():
103    parser = argparse.ArgumentParser()
104
105    parser.add_argument(
106        "--prompt",
107        type=str,
108        nargs="?",
109        default="a painting of a cute monkey playing guitar",
110        help="the prompt to render"
111    )
112
113    parser.add_argument(
114        "--orig-img",
115        type=str,
116        nargs="?",
117        help="path to the input image"
118    )
119
120    parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
121    parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps")
122
123    parser.add_argument("--scale", type=float, default=5.0,
124                        help="unconditional guidance scale: "
125                             "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
126
127    parser.add_argument("--strength", type=float, default=0.75,
128                        help="strength for noise: "
129                             " 1.0 corresponds to full destruction of information in init image")
130
131    opt = parser.parse_args()
132    set_seed(42)
133
134    img2img = Img2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
135                      ddim_steps=opt.steps)
136
137    with monit.section('Generate'):
138        img2img(
139            dest_path='outputs',
140            orig_img=opt.orig_img,
141            strength=opt.strength,
142            batch_size=opt.batch_size,
143            prompt=opt.prompt,
144            uncond_scale=opt.scale)

148if __name__ == "__main__":
149    main()