これにより、CeleBA HQ データセットで DDPM ベースのモデルがトレーニングされます。ダウンロードの説明は、fast.ai のこのディスカッションにあります。data/celebA
画像をフォルダーに保存します。
この論文では、モデルの指数移動平均を減衰させて使用していました。簡略化のため、ここでは省略しています
。20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet
34class Configs(BaseConfigs):
モデルをトレーニングするデバイス。DeviceConfigs
使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します
41 device: torch.device = DeviceConfigs()
用の U-Net モデル
44 eps_model: UNet
46 diffusion: DenoiseDiffusion
画像内のチャンネル数。RGB 用です。
49 image_channels: int = 3
画像サイズ
51 image_size: int = 32
初期機能マップのチャンネル数
53 n_channels: int = 64
各解像度のチャンネル番号のリスト。チャンネル数は channel_multipliers[i] * n_channels
56 channel_multipliers: List[int] = [1, 2, 2, 4]
各解像度で注意を向けるかどうかを示すブーリアンのリスト
58 is_attention: List[int] = [False, False, False, True]
タイムステップ数
61 n_steps: int = 1_000
バッチサイズ
63 batch_size: int = 64
生成するサンプルの数
65 n_samples: int = 16
学習率
67 learning_rate: float = 2e-5
トレーニングエポックの数
70 epochs: int = 1_000
データセット
73 dataset: torch.utils.data.Dataset
データローダー
75 data_loader: torch.utils.data.DataLoader
アダム・オプティマイザー
78 optimizer: torch.optim.Adam
80 def init(self):
モデル作成
82 self.eps_model = UNet(
83 image_channels=self.image_channels,
84 n_channels=self.n_channels,
85 ch_mults=self.channel_multipliers,
86 is_attn=self.is_attention,
87 ).to(self.device)
90 self.diffusion = DenoiseDiffusion(
91 eps_model=self.eps_model,
92 n_steps=self.n_steps,
93 device=self.device,
94 )
データローダーの作成
97 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
オプティマイザーを作成
99 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
画像ロギング
102 tracker.set_image("sample", True)
104 def sample(self):
108 with torch.no_grad():
110 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111 device=self.device)
ステップのノイズ除去
114 for t_ in monit.iterate('Sample', self.n_steps):
116 t = self.n_steps - t_ - 1
からのサンプル
118 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
ログサンプル
121 tracker.save('sample', x)
123 def train(self):
データセットの反復処理
129 for data in monit.iterate('Train', self.data_loader):
グローバルステップをインクリメント
131 tracker.add_global_step()
データをデバイスに移動
133 data = data.to(self.device)
グラデーションをゼロにする
136 self.optimizer.zero_grad()
損失の計算
138 loss = self.diffusion.loss(data)
勾配の計算
140 loss.backward()
最適化の一歩を踏み出す
142 self.optimizer.step()
損失をトラッキング
144 tracker.save('loss', loss)
146 def run(self):
150 for _ in monit.loop(self.epochs):
モデルのトレーニング
152 self.train()
いくつかの画像のサンプル
154 self.sample()
コンソールの新しい行
156 tracker.new_line()
モデルを保存する
158 experiment.save_checkpoint()
161class CelebADataset(torch.utils.data.Dataset):
166 def __init__(self, image_size: int):
167 super().__init__()
セレバ画像フォルダー
170 folder = lab.get_data_path() / 'celebA'
ファイルリスト
172 self._files = [p for p in folder.glob(f'**/*.jpg')]
画像のサイズを変更してテンソルに変換する変換
175 self._transform = torchvision.transforms.Compose([
176 torchvision.transforms.Resize(image_size),
177 torchvision.transforms.ToTensor(),
178 ])
データセットのサイズ
180 def __len__(self):
184 return len(self._files)
画像を取得
186 def __getitem__(self, index: int):
190 img = Image.open(self._files[index])
191 return self._transform(img)
CeleBA データセットの作成
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):
199 return CelebADataset(c.image_size)
202class MNISTDataset(torchvision.datasets.MNIST):
207 def __init__(self, image_size):
208 transform = torchvision.transforms.Compose([
209 torchvision.transforms.Resize(image_size),
210 torchvision.transforms.ToTensor(),
211 ])
212
213 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
215 def __getitem__(self, item):
216 return super().__getitem__(item)[0]
MNIST データセットの作成
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):
224 return MNISTDataset(c.image_size)
227def main():
実験を作成
229 experiment.create(name='diffuse', writers={'screen', 'labml'})
構成の作成
232 configs = Configs()
構成を設定します。ディクショナリに値を渡すことでデフォルトをオーバーライドできます。
235 experiment.configs(configs, {
236 'dataset': 'CelebA', # 'MNIST'
237 'image_channels': 3, # 1,
238 'epochs': 100, # 5,
239 })
[初期化]
242 configs.init()
保存および読み込み用のモデルを設定する
245 experiment.add_pytorch_models({'eps_model': configs.eps_model})
トレーニングループを開始して実行する
248 with experiment.start():
249 configs.run()
253if __name__ == '__main__':
254 main()