This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the `data/celebA`
folder.
The paper used an exponential moving average (EMA) of the model weights with a decay of $0.9999$. We have skipped this for simplicity.
20from typing import List
21
22import torchvision
23from PIL import Image
24
25import torch
26import torch.utils.data
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_nn.diffusion.ddpm import DenoiseDiffusion
30from labml_nn.diffusion.ddpm.unet import UNet
31from labml_nn.helpers.device import DeviceConfigs
class Configs(BaseConfigs):
    """
    ## Configurations

    Hyper-parameters, model components, data pipeline and the train/sample loops
    for the DDPM experiment.
    """

    # Device to train the model on. `DeviceConfigs`
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\epsilon_\theta(x_t, t)$
    eps_model: UNet
    # DDPM algorithm
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        Build the model, the diffusion wrapper, the data loader and the optimizer.
        """
        # Create $\epsilon_\theta(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Start from Gaussian noise and apply `diffusion.p_sample` for every time step,
        from $t = T - 1$ down to $0$, then log the resulting batch.
        """
        # Switch to eval mode so that any train-only layers in the U-Net
        # (e.g. dropout, if it has any) do not perturb sampling; restore afterwards.
        was_training = self.eps_model.training
        self.eps_model.eval()
        try:
            with torch.no_grad():
                # $x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
                x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                                device=self.device)

                # Remove noise for $T$ steps
                for t_ in monit.iterate('Sample', self.n_steps):
                    # $t$ runs from $T - 1$ down to $0$
                    t = self.n_steps - t_ - 1
                    # Sample from $p_\theta(x_{t-1}|x_t)$
                    x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

                # Log samples
                tracker.save('sample', x)
        finally:
            self.eps_model.train(was_training)

    def train(self):
        """
        ### Train for one epoch

        One Adam step on `diffusion.loss` per batch of `data_loader`.
        """
        # Make sure train-mode behaviour is active for the optimisation steps
        self.eps_model.train()

        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        For each of `epochs` iterations: train one epoch, then sample and log images.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Every `*.jpg` found (recursively) under the `celebA` folder of the labml data
    path, returned as an RGB tensor of size `image_size` x `image_size`.
    """

    def __init__(self, image_size: int):
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files, sorted so that dataset indices are stable across runs
        self._files = sorted(folder.glob('**/*.jpg'))
        # Fail early with a clear message instead of a cryptic sampler error
        # from the DataLoader later on.
        if not self._files:
            raise FileNotFoundError(f'No .jpg images found under {folder}')

        # Transformations to resize the image, crop it to a square and convert to tensor.
        # `Resize(int)` only scales the shorter side, so `CenterCrop` guarantees
        # equal-sized square tensors that can be batched together.
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.CenterCrop(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image
        """
        # `with` closes the underlying file handle; `convert('RGB')` forces
        # three channels even for grayscale/CMYK JPEGs.
        with Image.open(self._files[index]) as img:
            return self._transform(img.convert('RGB'))
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Build the CelebA dataset at the configured image size.
    """
    size = c.image_size
    return CelebADataset(size)
class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    MNIST training split (downloaded into the labml data path if missing),
    resized to `image_size` and converted to tensors. Items are images only.
    """

    def __init__(self, image_size):
        resize_then_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])
        data_root = str(lab.get_data_path())

        super().__init__(data_root, train=True, download=True, transform=resize_then_tensor)

    def __getitem__(self, item):
        # Keep only the first element of the parent's sample (the image)
        sample = super().__getitem__(item)
        return sample[0]
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Build the MNIST dataset at the configured image size.
    """
    size = c.image_size
    return MNISTDataset(size)
def main():
    """
    Create the experiment, apply configuration overrides, and run training.
    """
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    configs = Configs()

    # Overrides for the defaults; the trailing comments show the MNIST alternatives.
    overrides = {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    }
    experiment.configs(configs, overrides)

    # Initialize
    configs.init()

    # Register the model so the experiment can save and load it
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()