降噪扩散概率模型 (DDPM) 训练

Open In Colab

这将基于 CeleBA HQ 数据集训练基于 DDPM 的模型。你可以在 fast.ai 的讨论中找到下载说明。将图像保存在data/celebA 文件夹中

该论文使用了该模型的指数移动平均线,其衰减量为。为简单起见,我们跳过了这个。

20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet

配置

34class Configs(BaseConfigs):

用于训练模型的设备。DeviceConfigs 选择可用的 CUDA 设备或默认为 CPU。

41    device: torch.device = DeviceConfigs()

U-Net 模型用于

44    eps_model: UNet
46    diffusion: DenoiseDiffusion

图像中的通道数。对于 RGB。

49    image_channels: int = 3

图像大小

51    image_size: int = 32

初始特征图中的频道数量

53    n_channels: int = 64

每种分辨率下的通道编号列表。频道的数量是channel_multipliers[i] * n_channels

56    channel_multipliers: List[int] = [1, 2, 2, 4]

指示是否在每个分辨率下使用注意力的布尔值列表

58    is_attention: List[int] = [False, False, False, True]

时间步数

61    n_steps: int = 1_000

批量大小

63    batch_size: int = 64

要生成的样本数

65    n_samples: int = 16

学习率

67    learning_rate: float = 2e-5

训练周期的数量

70    epochs: int = 1_000

数据集

73    dataset: torch.utils.data.Dataset

数据加载器

75    data_loader: torch.utils.data.DataLoader

Adam 优化器

78    optimizer: torch.optim.Adam
80    def init(self):

创建模型

82        self.eps_model = UNet(
83            image_channels=self.image_channels,
84            n_channels=self.n_channels,
85            ch_mults=self.channel_multipliers,
86            is_attn=self.is_attention,
87        ).to(self.device)

创建 DDPM 类

90        self.diffusion = DenoiseDiffusion(
91            eps_model=self.eps_model,
92            n_steps=self.n_steps,
93            device=self.device,
94        )

创建数据加载器

97        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)

创建优化器

99        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

图像日志记录

102        tracker.set_image("sample", True)

样本图片

104    def sample(self):
108        with torch.no_grad():

110            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111                            device=self.device)

消除台阶噪音

114            for t_ in monit.iterate('Sample', self.n_steps):

116                t = self.n_steps - t_ - 1

样本来自

118                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

日志样本

121            tracker.save('sample', x)

火车

123    def train(self):

遍历数据集

129        for data in monit.iterate('Train', self.data_loader):

递增全局步长

131            tracker.add_global_step()

将数据移动到设备

133            data = data.to(self.device)

将渐变设为零

136            self.optimizer.zero_grad()

计算损失

138            loss = self.diffusion.loss(data)

计算梯度

140            loss.backward()

采取优化步骤

142            self.optimizer.step()

追踪损失

144            tracker.save('loss', loss)

训练循环

146    def run(self):
150        for _ in monit.loop(self.epochs):

训练模型

152            self.train()

对一些图像进行采样

154            self.sample()

控制台中的新行

156            tracker.new_line()

保存模型

158            experiment.save_checkpoint()

CeleBA HQ 数据集

161class CelebADataset(torch.utils.data.Dataset):
166    def __init__(self, image_size: int):
167        super().__init__()

CeleBA 图片文件夹

170        folder = lab.get_data_path() / 'celebA'

文件清单

172        self._files = [p for p in folder.glob(f'**/*.jpg')]

用于调整图像大小并转换为张量的转换

175        self._transform = torchvision.transforms.Compose([
176            torchvision.transforms.Resize(image_size),
177            torchvision.transforms.ToTensor(),
178        ])

数据集的大小

180    def __len__(self):
184        return len(self._files)

获取一张图片

186    def __getitem__(self, index: int):
190        img = Image.open(self._files[index])
191        return self._transform(img)

创建 CeleBA 数据集

194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):
199    return CelebADataset(c.image_size)

MNIST 数据集

202class MNISTDataset(torchvision.datasets.MNIST):
207    def __init__(self, image_size):
208        transform = torchvision.transforms.Compose([
209            torchvision.transforms.Resize(image_size),
210            torchvision.transforms.ToTensor(),
211        ])
212
213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
215    def __getitem__(self, item):
216        return super().__getitem__(item)[0]

创建 MNIST 数据集

219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):
224    return MNISTDataset(c.image_size)
227def main():

创建实验

229    experiment.create(name='diffuse', writers={'screen', 'labml'})

创建配置

232    configs = Configs()

设置配置。您可以通过在字典中传递值来覆盖默认值。

235    experiment.configs(configs, {
236        'dataset': 'CelebA',  # 'MNIST'
237        'image_channels': 3,  # 1,
238        'epochs': 100,  # 5,
239    })

初始化

242    configs.init()

设置用于保存和加载的模型

245    experiment.add_pytorch_models({'eps_model': configs.eps_model})

启动并运行训练循环

248    with experiment.start():
249        configs.run()

253if __name__ == '__main__':
254    main()