11import argparse
12import os
13from pathlib import Path
14
15import torch
16
17from labml import lab, monit
18from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
19from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
20from labml_nn.diffusion.stable_diffusion.sampler.ddpm import DDPMSampler
21from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, set_seed
24class Txt2Img:
28 model: LatentDiffusion
checkpoint_path
チェックポイントのパスですsampler_name
サンプラーの名前ですn_steps
はサンプリングステップの数ddim_eta
DDIM サンプリング定数です 30 def __init__(self, *,
31 checkpoint_path: Path,
32 sampler_name: str,
33 n_steps: int = 50,
34 ddim_eta: float = 0.0,
35 ):
デバイスを取得
45 self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
モデルをデバイスに移動
47 self.model.to(self.device)
50 if sampler_name == 'ddim':
51 self.sampler = DDIMSampler(self.model,
52 n_steps=n_steps,
53 ddim_eta=ddim_eta)
54 elif sampler_name == 'ddpm':
55 self.sampler = DDPMSampler(self.model)
dest_path
生成された画像を保存するパスですbatch_size
はバッチで生成する画像の数ですprompt
で画像を生成するプロンプトですh
画像の高さですw
画像の幅ですuncond_scale
無条件ガイダンススケールです これは次の用途に使用されます 57 @torch.no_grad()
58 def __call__(self, *,
59 dest_path: str,
60 batch_size: int = 3,
61 prompt: str,
62 h: int = 512, w: int = 512,
63 uncond_scale: float = 7.5,
64 ):
画像内のチャンネル数
75 c = 4
画像から潜在空間への解像度の低下
77 f = 8
プロンプトを一括作成
80 prompts = batch_size * [prompt]
AMP オートキャスティング
83 with torch.cuda.amp.autocast():
無条件スケーリングでは、空のプロンプトでは埋め込みは取得されません (条件なし)。
85 if uncond_scale != 1.0:
86 un_cond = self.model.get_text_conditioning(batch_size * [""])
87 else:
88 un_cond = None
プロンプトの埋め込みを入手
90 cond = self.model.get_text_conditioning(prompts)
潜伏空間でサンプルを採取します。x
形が整います [batch_size, c, h / f, w / f]
93 x = self.sampler.sample(cond=cond,
94 shape=[batch_size, c, h // f, w // f],
95 uncond_scale=uncond_scale,
96 uncond_cond=un_cond)
98 images = self.model.autoencoder_decode(x)
画像を保存
101 save_images(images, dest_path, 'txt_')
104def main():
108 parser = argparse.ArgumentParser()
109
110 parser.add_argument(
111 "--prompt",
112 type=str,
113 nargs="?",
114 default="a painting of a virus monster playing guitar",
115 help="the prompt to render"
116 )
117
118 parser.add_argument("--batch_size", type=int, default=4, help="batch size")
119
120 parser.add_argument(
121 '--sampler',
122 dest='sampler_name',
123 choices=['ddim', 'ddpm'],
124 default='ddim',
125 help=f'Set the sampler.',
126 )
127
128 parser.add_argument("--flash", action='store_true', help="whether to use flash attention")
129
130 parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")
131
132 parser.add_argument("--scale", type=float, default=7.5,
133 help="unconditional guidance scale: "
134 "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
135
136 opt = parser.parse_args()
137
138 set_seed(42)
フラッシュアテンションを設定
141 from labml_nn.diffusion.stable_diffusion.model.unet_attention import CrossAttention
142 CrossAttention.use_flash_attention = opt.flash
145 txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
146 sampler_name=opt.sampler_name,
147 n_steps=opt.steps)
148
149 with monit.section('Generate'):
150 txt2img(dest_path='outputs',
151 batch_size=opt.batch_size,
152 prompt=opt.prompt,
153 uncond_scale=opt.scale)
157if __name__ == "__main__":
158 main()