This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.
The paper used an exponential moving average (EMA) of the model parameters with a decay of 0.9999. We have skipped this for simplicity.
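If you do want EMA, here is a minimal sketch of how such an average could be maintained, using a hypothetical EMA helper that is not part of this implementation; update() would be called after each optimizer step, and the averaged weights would be used for sampling and evaluation.

import copy

import torch


class EMA:
    """Hypothetical helper that keeps an exponential moving average of model parameters."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        # Frozen copy of the model that accumulates the averaged weights
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # ema_param <- decay * ema_param + (1 - decay) * param
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)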
from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet


class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
U-Net model for predicting the noise ε_θ(x_t, t)
    eps_model: UNet
DDPM algorithm
    diffusion: DenoiseDiffusion
Number of channels in the image; 3 for RGB
    image_channels: int = 3
Image size
    image_size: int = 32
Number of channels in the initial feature map
    n_channels: int = 64
The list of channel numbers at each resolution. The number of channels is channel_multipliers[i] * n_channels
    channel_multipliers: List[int] = [1, 2, 2, 4]
The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]
Number of time steps
    n_steps: int = 1_000
Batch size
    batch_size: int = 64
Number of samples to generate
    n_samples: int = 16
Learning rate
    learning_rate: float = 2e-5
Number of training epochs
    epochs: int = 1_000
Dataset
    dataset: torch.utils.data.Dataset
Dataloader
    data_loader: torch.utils.data.DataLoader
Adam optimizer
    optimizer: torch.optim.Adam
    def init(self):
Create model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)
Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )
Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)
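For reference, the DenoiseDiffusion instance created in init models the DDPM forward (noising) process, which gradually adds Gaussian noise to an image over $T$ steps according to a variance schedule $\beta_1, \dots, \beta_T$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1 - \beta_t}\, x_{t-1},\ \beta_t \mathbf{I}\big), \qquad q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) \mathbf{I}\big),$$

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$.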
    def sample(self):
        with torch.no_grad():
Sample x_T from the prior, x_T ~ N(0, I)
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)
Remove noise for T steps
            for t_ in monit.iterate('Sample', self.n_steps):
                t = self.n_steps - t_ - 1
Sample from p_θ(x_{t-1} | x_t)
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
Log samples
            tracker.save('sample', x)
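Each call to p_sample performs one reverse (denoising) step of DDPM:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t) \right) + \sigma_t z, \qquad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$$

with $\sigma_t^2 = \beta_t$. In the paper's Algorithm 2, $z$ is set to zero at the final step.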
    def train(self):
Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
Increment global step
            tracker.add_global_step()
Move data to device
            data = data.to(self.device)
Make the gradients zero
            self.optimizer.zero_grad()
Calculate loss
            loss = self.diffusion.loss(data)
Compute gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Track the loss
            tracker.save('loss', loss)
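The loss computed by diffusion.loss is the simplified DDPM objective: sample a step $t$ and noise $\epsilon$, noise the image with the forward process, and regress the network's noise prediction against the true noise,

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \Big[ \big\| \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\ t\big) \big\|^2 \Big], \qquad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).$$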
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()
Sample some images
            self.sample()
New line in the console
            tracker.new_line()
Save the model
            experiment.save_checkpoint()
CelebA HQ dataset
class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, image_size: int):
        super().__init__()
CelebA images folder
        folder = lab.get_data_path() / 'celebA'
List of files
        self._files = list(folder.glob('**/*.jpg'))
Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])
Size of the dataset
    def __len__(self):
        return len(self._files)
Get an image
    def __getitem__(self, index: int):
        img = Image.open(self._files[index])
        return self._transform(img)
Create CelebA dataset
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    return CelebADataset(c.image_size)
MNIST dataset wrapper that returns only the images and discards the labels
class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        return super().__getitem__(item)[0]
Create MNIST dataset
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    return MNISTDataset(c.image_size)
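To train on MNIST instead of CelebA, switch the configuration overrides in main() below to the commented alternatives, for example:

experiment.configs(configs, {
    'dataset': 'MNIST',  # use the MNIST dataset option above
    'image_channels': 1,  # MNIST images are grayscale
    'epochs': 5,
})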
def main():
Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})
Create configurations
    configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })
Initialize
    configs.init()
Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
Start and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()