ノイズ除去拡散確率モデル (DDPM) トレーニング

Open In Colab

これにより、CeleBA HQ データセットで DDPM ベースのモデルがトレーニングされます。ダウンロードの説明は、fast.ai のこのディスカッションにありますdata/celebA 画像をフォルダーに保存します

この論文では、モデルの指数移動平均を減衰させて使用していました。簡略化のため、ここでは省略しています

20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet

コンフィギュレーション

34class Configs(BaseConfigs):

モデルをトレーニングするデバイス。DeviceConfigs 使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します

41    device: torch.device = DeviceConfigs()

用の U-Net モデル

44    eps_model: UNet
46    diffusion: DenoiseDiffusion

画像内のチャンネル数。RGB 用です。

49    image_channels: int = 3

画像サイズ

51    image_size: int = 32

初期機能マップのチャンネル数

53    n_channels: int = 64

各解像度のチャンネル番号のリスト。チャンネル数は channel_multipliers[i] * n_channels

56    channel_multipliers: List[int] = [1, 2, 2, 4]

各解像度で注意を向けるかどうかを示すブーリアンのリスト

58    is_attention: List[int] = [False, False, False, True]

タイムステップ数

61    n_steps: int = 1_000

バッチサイズ

63    batch_size: int = 64

生成するサンプルの数

65    n_samples: int = 16

学習率

67    learning_rate: float = 2e-5

トレーニングエポックの数

70    epochs: int = 1_000

データセット

73    dataset: torch.utils.data.Dataset

データローダー

75    data_loader: torch.utils.data.DataLoader

アダム・オプティマイザー

78    optimizer: torch.optim.Adam
80    def init(self):

モデル作成

82        self.eps_model = UNet(
83            image_channels=self.image_channels,
84            n_channels=self.n_channels,
85            ch_mults=self.channel_multipliers,
86            is_attn=self.is_attention,
87        ).to(self.device)
90        self.diffusion = DenoiseDiffusion(
91            eps_model=self.eps_model,
92            n_steps=self.n_steps,
93            device=self.device,
94        )

データローダーの作成

97        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)

オプティマイザーを作成

99        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

画像ロギング

102        tracker.set_image("sample", True)

サンプル画像

104    def sample(self):
108        with torch.no_grad():

110            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111                            device=self.device)

ステップのノイズ除去

114            for t_ in monit.iterate('Sample', self.n_steps):

116                t = self.n_steps - t_ - 1

からのサンプル

118                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

ログサンプル

121            tracker.save('sample', x)

列車

123    def train(self):

データセットの反復処理

129        for data in monit.iterate('Train', self.data_loader):

グローバルステップをインクリメント

131            tracker.add_global_step()

データをデバイスに移動

133            data = data.to(self.device)

グラデーションをゼロにする

136            self.optimizer.zero_grad()

損失の計算

138            loss = self.diffusion.loss(data)

勾配の計算

140            loss.backward()

最適化の一歩を踏み出す

142            self.optimizer.step()

損失をトラッキング

144            tracker.save('loss', loss)

トレーニングループ

146    def run(self):
150        for _ in monit.loop(self.epochs):

モデルのトレーニング

152            self.train()

いくつかの画像のサンプル

154            self.sample()

コンソールの新しい行

156            tracker.new_line()

モデルを保存する

158            experiment.save_checkpoint()

CeleBA 本社データセット

161class CelebADataset(torch.utils.data.Dataset):
166    def __init__(self, image_size: int):
167        super().__init__()

セレバ画像フォルダー

170        folder = lab.get_data_path() / 'celebA'

ファイルリスト

172        self._files = [p for p in folder.glob(f'**/*.jpg')]

画像のサイズを変更してテンソルに変換する変換

175        self._transform = torchvision.transforms.Compose([
176            torchvision.transforms.Resize(image_size),
177            torchvision.transforms.ToTensor(),
178        ])

データセットのサイズ

180    def __len__(self):
184        return len(self._files)

画像を取得

186    def __getitem__(self, index: int):
190        img = Image.open(self._files[index])
191        return self._transform(img)

CeleBA データセットの作成

194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):
199    return CelebADataset(c.image_size)

MNIST データセット

202class MNISTDataset(torchvision.datasets.MNIST):
207    def __init__(self, image_size):
208        transform = torchvision.transforms.Compose([
209            torchvision.transforms.Resize(image_size),
210            torchvision.transforms.ToTensor(),
211        ])
212
213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
215    def __getitem__(self, item):
216        return super().__getitem__(item)[0]

MNIST データセットの作成

219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):
224    return MNISTDataset(c.image_size)
227def main():

実験を作成

229    experiment.create(name='diffuse', writers={'screen', 'labml'})

構成の作成

232    configs = Configs()

構成を設定します。ディクショナリに値を渡すことでデフォルトをオーバーライドできます。

235    experiment.configs(configs, {
236        'dataset': 'CelebA',  # 'MNIST'
237        'image_channels': 3,  # 1,
238        'epochs': 100,  # 5,
239    })

[初期化]

242    configs.init()

保存および読み込み用のモデルを設定する

245    experiment.add_pytorch_models({'eps_model': configs.eps_model})

トレーニングループを開始して実行する

248    with experiment.start():
249        configs.run()

253if __name__ == '__main__':
254    main()