Denoising Diffusion Probabilistic Models (DDPM) training

Open In Colab

This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the `data/celebA` folder.

The paper used an exponential moving average (EMA) of the model parameters with a decay of $0.9999$. We have skipped this for simplicity.

20from typing import List
21
22import torchvision
23from PIL import Image
24
25import torch
26import torch.utils.data
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_nn.diffusion.ddpm import DenoiseDiffusion
30from labml_nn.diffusion.ddpm.unet import UNet
31from labml_nn.helpers.device import DeviceConfigs

Configurations

class Configs(BaseConfigs):
    """
    ## Configurations

    Hyper-parameters and components for training a DDPM model.
    Attributes without a default (e.g. `dataset`) are filled in by labml
    from the options registered for them with `@option`.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\epsilon_\theta(x_t, t)$
    eps_model: UNet
    # DDPM algorithm
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Build the model, the DDPM wrapper, the data loader and the optimizer."""
        # Create $\epsilon_\theta(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader.
        # Pinned (page-locked) host memory only speeds up host-to-GPU copies,
        # so only request it when training on a CUDA device.
        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            self.batch_size,
            shuffle=True,
            pin_memory=self.device.type == 'cuda',
        )

        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Run the reverse process for $T$ steps, starting from Gaussian noise,
        and log the resulting `n_samples` images.
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ runs from $T - 1$ down to $0$
                t = self.n_steps - t_ - 1
                # Sample from $p_\theta(x_{t-1}|x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        Run one pass over the data loader, taking one optimizer step per batch.
        """
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        For each epoch: train, then sample and log images.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()

CelebA HQ dataset

class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Loads every `*.jpg` found (recursively) under the `celebA` folder of the
    labml data path, and returns each one as an RGB tensor of shape
    `[3, image_size, image_size]`.
    """

    def __init__(self, image_size: int):
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files, sorted so that dataset indices are deterministic across runs
        self._files = sorted(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor.
        # `Resize` with an int scales the *shorter* side only; the center crop
        # makes non-square images square so they can be batched together
        # (it is a no-op for images that are already square).
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.CenterCrop(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Size of the dataset"""
        return len(self._files)

    def __getitem__(self, index: int):
        """Get an image"""
        # The `with` block closes the underlying file once the image is decoded.
        # `convert('RGB')` forces 3 channels so that grayscale/CMYK JPEGs yield
        # tensors with the same channel count as the rest of the batch.
        with Image.open(self._files[index]) as img:
            return self._transform(img.convert('RGB'))

Create CelebA dataset

@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Create CelebA dataset

    Registered as the `'CelebA'` option for `Configs.dataset`; images are
    resized to the configured `image_size`.
    """
    size = c.image_size
    return CelebADataset(size)

MNIST dataset

class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    The MNIST training split (downloaded into the labml data path if needed),
    resized to `image_size`. Items are image tensors only; labels are dropped.
    """

    def __init__(self, image_size):
        steps = [
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ]
        pipeline = torchvision.transforms.Compose(steps)
        root = str(lab.get_data_path())

        super().__init__(root, train=True, download=True, transform=pipeline)

    def __getitem__(self, item):
        # The parent returns `(image, target)`; only the image is needed for DDPM
        image, _target = super().__getitem__(item)
        return image

Create MNIST dataset

@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Create MNIST dataset

    Registered as the `'MNIST'` option for `Configs.dataset`; images are
    resized to the configured `image_size`.
    """
    size = c.image_size
    return MNISTDataset(size)
def main():
    """Create the experiment, configure it, and run the training loop."""
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    # For MNIST use: 'dataset': 'MNIST', 'image_channels': 1, 'epochs': 5
    overrides = {
        'dataset': 'CelebA',
        'image_channels': 3,
        'epochs': 100,
    }
    experiment.configs(configs, overrides)

    # Initialize
    configs.init()

    # Set models for saving and loading
    models = {'eps_model': configs.eps_model}
    experiment.add_pytorch_models(models)

    # Start and run the training loop
    with experiment.start():
        configs.run()

if __name__ == '__main__':
    # Only train when executed as a script, not when imported
    main()