11import argparse
12import os
13from pathlib import Path
14
15import torch
16
17from labml import lab, monit
18from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
19from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
20from labml_nn.diffusion.stable_diffusion.sampler.ddpm import DDPMSampler
21from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, set_seed

テキストから画像へのクラス

24class Txt2Img:
28    model: LatentDiffusion
  • checkpoint_path チェックポイントのパスです
  • sampler_name サンプラーの名前です
  • n_steps はサンプリングステップの数
  • ddim_eta DDIM サンプリング定数です
30    def __init__(self, *,
31                 checkpoint_path: Path,
32                 sampler_name: str,
33                 n_steps: int = 50,
34                 ddim_eta: float = 0.0,
35                 ):
43        self.model = load_model(checkpoint_path)

デバイスを取得

45        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

モデルをデバイスに移動

47        self.model.to(self.device)
50        if sampler_name == 'ddim':
51            self.sampler = DDIMSampler(self.model,
52                                       n_steps=n_steps,
53                                       ddim_eta=ddim_eta)
54        elif sampler_name == 'ddpm':
55            self.sampler = DDPMSampler(self.model)
  • dest_path 生成された画像を保存するパスです
  • batch_size はバッチで生成する画像の数です
  • prompt で画像を生成するプロンプトです
  • h 画像の高さです
  • w 画像の幅です
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
57    @torch.no_grad()
58    def __call__(self, *,
59                 dest_path: str,
60                 batch_size: int = 3,
61                 prompt: str,
62                 h: int = 512, w: int = 512,
63                 uncond_scale: float = 7.5,
64                 ):

画像内のチャンネル数

75        c = 4

画像から潜在空間への解像度の低下

77        f = 8

プロンプトを一括作成

80        prompts = batch_size * [prompt]

AMP オートキャスティング

83        with torch.cuda.amp.autocast():

無条件スケーリングでは、空のプロンプトでは埋め込みは取得されません (条件なし)。

85            if uncond_scale != 1.0:
86                un_cond = self.model.get_text_conditioning(batch_size * [""])
87            else:
88                un_cond = None

プロンプトの埋め込みを入手

90            cond = self.model.get_text_conditioning(prompts)

潜伏空間でサンプルを採取しますx 形が整います [batch_size, c, h / f, w / f]

93            x = self.sampler.sample(cond=cond,
94                                    shape=[batch_size, c, h // f, w // f],
95                                    uncond_scale=uncond_scale,
96                                    uncond_cond=un_cond)
98            images = self.model.autoencoder_decode(x)

画像を保存

101        save_images(images, dest_path, 'txt_')

CLI

104def main():
108    parser = argparse.ArgumentParser()
109
110    parser.add_argument(
111        "--prompt",
112        type=str,
113        nargs="?",
114        default="a painting of a virus monster playing guitar",
115        help="the prompt to render"
116    )
117
118    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
119
120    parser.add_argument(
121        '--sampler',
122        dest='sampler_name',
123        choices=['ddim', 'ddpm'],
124        default='ddim',
125        help=f'Set the sampler.',
126    )
127
128    parser.add_argument("--flash", action='store_true', help="whether to use flash attention")
129
130    parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")
131
132    parser.add_argument("--scale", type=float, default=7.5,
133                        help="unconditional guidance scale: "
134                             "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
135
136    opt = parser.parse_args()
137
138    set_seed(42)

フラッシュアテンションを設定

141    from labml_nn.diffusion.stable_diffusion.model.unet_attention import CrossAttention
142    CrossAttention.use_flash_attention = opt.flash

145    txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
146                      sampler_name=opt.sampler_name,
147                      n_steps=opt.steps)
148
149    with monit.section('Generate'):
150        txt2img(dest_path='outputs',
151                batch_size=opt.batch_size,
152                prompt=opt.prompt,
153                uncond_scale=opt.scale)

157if __name__ == "__main__":
158    main()