Denoising Diffusion Probabilistic Models (DDPM) training


This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.

The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity; a minimal sketch of how it could be added follows.
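A minimal EMA sketch, not part of this implementation: it keeps a frozen copy of the model and blends in the live weights after each optimizer step. The EMA class and update rule are illustrative assumptions; only the 0.9999 decay comes from the paper.

import copy

import torch


class EMA:
    """Exponential moving average of model weights (illustrative sketch)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        # Frozen copy whose parameters hold the moving average
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # ema_param <- decay * ema_param + (1 - decay) * param
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

One would call ema.update(self.eps_model) after each optimizer step and sample from ema.ema_model for evaluation.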

from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet

Configurations

class Configs(BaseConfigs):

Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.

    device: torch.device = DeviceConfigs()

U-Net model for $\epsilon_\theta(x_t, t)$

    eps_model: UNet

DDPM algorithm

    diffusion: DenoiseDiffusion

Number of channels in the image. $3$ for RGB.

    image_channels: int = 3

Image size

    image_size: int = 32

Number of channels in the initial feature map

    n_channels: int = 64

The list of channel counts at each resolution. The number of channels at resolution $i$ is channel_multipliers[i] * n_channels; for example, with n_channels = 64 the defaults below give 64, 128, 128, and 256 channels.

    channel_multipliers: List[int] = [1, 2, 2, 4]

The list of booleans that indicate whether to use attention at each resolution

    is_attention: List[bool] = [False, False, False, True]

Number of time steps

    n_steps: int = 1_000
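For context, these are the $T = 1{,}000$ steps of the forward noising process. Its closed form (from the DDPM paper) lets training sample any $x_t$ directly from $x_0$:

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) I\big), \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \quad \alpha_s = 1 - \beta_s$$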

Batch size

    batch_size: int = 64

Number of samples to generate

    n_samples: int = 16

Learning rate

    learning_rate: float = 2e-5

Number of training epochs

    epochs: int = 1_000

Dataset

    dataset: torch.utils.data.Dataset

Dataloader

    data_loader: torch.utils.data.DataLoader

Adam optimizer

    optimizer: torch.optim.Adam

    def init(self):

Create model

        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

Create DDPM class

        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

Create dataloader

        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)

Create optimizer

        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

Image logging

        tracker.set_image("sample", True)

Sample images

    def sample(self):
        with torch.no_grad():

Start from pure noise $x_T \sim \mathcal{N}(0, I)$

            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

Remove noise for $T$ steps

            for t_ in monit.iterate('Sample', self.n_steps):

$t$ counts down from $T - 1$ to $0$

                t = self.n_steps - t_ - 1

Sample from $p_\theta(x_{t-1} \mid x_t)$

                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
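For reference, this reverse step follows the DDPM paper (assuming the paper's choice $\sigma_t^2 = \beta_t$; the exact variance used by DenoiseDiffusion may differ):

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t) \Big) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 0, \ z = 0 \text{ at } t = 0$$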

Log samples

            tracker.save('sample', x)

Train

    def train(self):

Iterate through the dataset

        for data in monit.iterate('Train', self.data_loader):

Increment global step

            tracker.add_global_step()

Move data to device

            data = data.to(self.device)

Make the gradients zero

            self.optimizer.zero_grad()

Calculate loss

            loss = self.diffusion.loss(data)
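This minimizes the simplified objective from the DDPM paper: sample a random step $t$ and noise $\epsilon \sim \mathcal{N}(0, I)$, then regress the predicted noise against the true noise (a sketch of the objective, assuming the default DenoiseDiffusion implementation):

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \Big[ \big\| \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\ t\big) \big\|^2 \Big]$$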

Compute gradients

            loss.backward()

Take an optimization step

            self.optimizer.step()

Track the loss

            tracker.save('loss', loss)

Training loop

    def run(self):
        for _ in monit.loop(self.epochs):

Train the model

            self.train()

Sample some images

            self.sample()

New line in the console

            tracker.new_line()

Save the model

            experiment.save_checkpoint()

CelebA HQ dataset

class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, image_size: int):
        super().__init__()

CelebA images folder

        folder = lab.get_data_path() / 'celebA'

List of files

        self._files = list(folder.glob('**/*.jpg'))

Transformations to resize the image and convert to tensor

        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

Size of the dataset

    def __len__(self):
        return len(self._files)

Get an image

    def __getitem__(self, index: int):
        img = Image.open(self._files[index])
        return self._transform(img)
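A quick sanity check of the dataset (hypothetical usage; it assumes the images have been downloaded to data/celebA, and that the sources are square so Resize yields square tensors):

dataset = CelebADataset(image_size=32)
x = dataset[0]
assert x.shape == (3, 32, 32)  # an RGB image resized to 32x32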

Create CelebA dataset

@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    return CelebADataset(c.image_size)

MNIST dataset

class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

Only return the image, discarding the label

    def __getitem__(self, item):
        return super().__getitem__(item)[0]

Create MNIST dataset

@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    return MNISTDataset(c.image_size)


def main():

Create experiment

    experiment.create(name='diffuse', writers={'screen', 'labml'})

Create configurations

    configs = Configs()

Set configurations. You can override the defaults by passing the values in the dictionary.

    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })

Initialize

    configs.init()

Set models for saving and loading

    experiment.add_pytorch_models({'eps_model': configs.eps_model})

Start and run the training loop

    with experiment.start():
        configs.run()

if __name__ == '__main__':
    main()