11import argparse
12from pathlib import Path
13from typing import Optional
14
15import torch
16
17from labml import lab, monit
18from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
19from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
20from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
21from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, load_img, set_seed
class InPaint:
    """
    Inpaint images with a latent diffusion model and a DDIM sampler.

    The original image is encoded into the latent space, noised up to a step
    determined by ``strength``, and then reconstructed by the sampler while the
    area selected by ``mask`` is preserved from the original latent.
    """

    # Latent diffusion model loaded from the checkpoint in ``__init__``
    model: LatentDiffusion
    # Sampler used to noise (``q_sample``) and inpaint (``paint``) latents
    sampler: DiffusionSampler

    def __init__(self, *, checkpoint_path: Path,
                 ddim_steps: int = 50,
                 ddim_eta: float = 0.0):
        """
        :param checkpoint_path: is the path of the checkpoint
        :param ddim_steps: is the number of sampling steps
        :param ddim_eta: is the DDIM sampling constant
        """
        self.ddim_steps = ddim_steps

        # Load the model from the checkpoint
        self.model = load_model(checkpoint_path)
        # Get device
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        # Move the model to device
        self.model.to(self.device)

        # Initialize DDIM sampler
        self.sampler = DDIMSampler(self.model,
                                   n_steps=ddim_steps,
                                   ddim_eta=ddim_eta)

    @torch.no_grad()
    def __call__(self, *,
                 dest_path: str,
                 orig_img: str,
                 strength: float,
                 batch_size: int = 3,
                 prompt: str,
                 uncond_scale: float = 5.0,
                 mask: Optional[torch.Tensor] = None,
                 ):
        """
        Inpaint ``orig_img`` and save ``batch_size`` results to ``dest_path``.

        :param dest_path: is the path to store the generated images
        :param orig_img: is the image to transform
        :param strength: specifies how much of the original image should not be
            preserved; must be in ``[0.0, 1.0]``
        :param batch_size: is the number of images to generate in a batch
        :param prompt: is the prompt to generate images with
        :param uncond_scale: is the unconditional guidance scale ``s``. It is used as
            ``eps = eps(x, empty) + s * (eps(x, cond) - eps(x, empty))``.
            When it is exactly ``1.0`` no empty-prompt embedding is computed.
        :param mask: selects the area to preserve from the original (``1`` = keep).
            If ``None``, a mask that preserves the bottom half of the image is used.
            NOTE(review): the default mask is built in latent space (same shape as the
            encoded image), so a user-supplied mask presumably must be broadcastable
            to the latent shape rather than the pixel shape; confirm against
            ``sampler.paint``.
        :raises ValueError: if ``strength`` is outside ``[0.0, 1.0]``
        """
        # Validate before any image loading/encoding, and with an explicit
        # exception: an `assert` would be stripped under `python -O`.
        if not 0. <= strength <= 1.:
            raise ValueError('can only work with strength in [0.0, 1.0]')

        # Make a batch of prompts
        prompts = batch_size * [prompt]
        # Load image
        orig_image = load_img(orig_img).to(self.device)
        # Encode the image in the latent space and make `batch_size` copies of it
        orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)

        # If `mask` is not provided, we set a sample mask to preserve the bottom half of the image
        if mask is None:
            mask = torch.zeros_like(orig, device=self.device)
            mask[:, :, mask.shape[2] // 2:, :] = 1.
        else:
            mask = mask.to(self.device)

        # Noise used to diffuse the original image (also passed on to `paint`)
        orig_noise = torch.randn(orig.shape, device=self.device)

        # Get the number of steps to diffuse the original
        # NOTE(review): strength == 1.0 yields t_index == ddim_steps; confirm that the
        # sampler's `q_sample`/`paint` accept that index (it may expect < ddim_steps).
        t_index = int(strength * self.ddim_steps)

        # AMP auto casting. `torch.cuda.amp.autocast()` is deprecated and merely
        # warns and disables itself without CUDA, so enable CUDA autocast only
        # when actually running on CUDA (same numerics as before on both devices).
        with torch.autocast(device_type="cuda", enabled=self.device.type == "cuda"):
            # If the guidance scale is not 1, get the embeddings for empty prompts (no conditioning)
            if uncond_scale != 1.0:
                un_cond = self.model.get_text_conditioning(batch_size * [""])
            else:
                un_cond = None
            # Get the prompt embeddings
            cond = self.model.get_text_conditioning(prompts)
            # Add noise to the original image
            x = self.sampler.q_sample(orig, t_index, noise=orig_noise)
            # Reconstruct from the noisy image, while preserving the masked area
            x = self.sampler.paint(x, cond, t_index,
                                   orig=orig,
                                   mask=mask,
                                   orig_noise=orig_noise,
                                   uncond_scale=uncond_scale,
                                   uncond_cond=un_cond)
            # Decode the image from the autoencoder
            images = self.model.autoencoder_decode(x)

        # Save images
        save_images(images, dest_path, 'paint_')
def main():
    """
    Command-line entry point for inpainting.

    Parses the arguments, calls ``set_seed(42)``, builds an ``InPaint`` from the
    ``stable-diffusion/sd-v1-4.ckpt`` checkpoint under the lab data path, and
    writes the generated images to ``outputs``.
    """
    parser = argparse.ArgumentParser()

    # No `nargs="?"`: a bare `--prompt` without a value is rejected by argparse
    # instead of silently producing `None`.
    parser.add_argument(
        "--prompt",
        type=str,
        default="a painting of a cute monkey playing guitar",
        help="the prompt to render"
    )

    # The input image is mandatory. Previously omitting it left `orig_img` as
    # `None`, which was only passed to `load_img` after the checkpoint had
    # already been loaded.
    parser.add_argument(
        "--orig-img",
        type=str,
        required=True,
        help="path to the input image"
    )

    # `--batch-size` matches the hyphenated style of `--orig-img`;
    # `--batch_size` is kept as an alias for existing invocations.
    parser.add_argument("--batch-size", "--batch_size", dest="batch_size",
                        type=int, default=4, help="batch size")
    parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")

    parser.add_argument("--scale", type=float, default=5.0,
                        help="unconditional guidance scale: "
                             "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")

    parser.add_argument("--strength", type=float, default=0.75,
                        help="strength for noise: "
                             " 1.0 corresponds to full destruction of information in init image")

    opt = parser.parse_args()
    # Reject an out-of-range strength as a normal CLI usage error, before the
    # checkpoint is loaded.
    if not 0.0 <= opt.strength <= 1.0:
        parser.error("--strength must be in [0.0, 1.0]")

    set_seed(42)

    in_paint = InPaint(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
                       ddim_steps=opt.steps)

    with monit.section('Generate'):
        in_paint(dest_path='outputs',
                 orig_img=opt.orig_img,
                 strength=opt.strength,
                 batch_size=opt.batch_size,
                 prompt=opt.prompt,
                 uncond_scale=opt.scale)
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()