これらは、約 80K ステップのトレーニング後に生成された画像です。
私たちの実装は、最小限のStyleGAN 2モデルトレーニングコードです。実装をシンプルに保つため、単一の GPU トレーニングのみがサポートされています。なんとか縮小して、トレーニングループを含めて 500 行未満のコードに抑えることができました
。DDP (分散データ並列) とマルチ GPU トレーニングがなければ、大きな解像度 (128 以上) でモデルをトレーニングすることはできません。
fp16とDDPを使ったトレーニングコードが必要な場合は、lucidrains/stylegan2-pytorchを見てください。これをCeleba-HQデータセットでトレーニングしました。ダウンロードの説明は、fast.ai のこのディスカッションにあります。data/stylegan
画像をフォルダーに保存します。
31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader
49class Dataset(torch.utils.data.Dataset):
path
画像を含むフォルダーへのパスimage_size
画像のサイズ56 def __init__(self, path: str, image_size: int):
61 super().__init__()
jpg
すべてのファイルのパスを取得
64 self.paths = [p for p in Path(path).glob(f'**/*.jpg')]
トランスフォーメーション
67 self.transform = torchvision.transforms.Compose([
画像のサイズを変更
69 torchvision.transforms.Resize(image_size),
PyTorch テンソルに変換
71 torchvision.transforms.ToTensor(),
72 ])
画像数
74 def __len__(self):
76 return len(self.paths)
index
-番目の画像を取得
78 def __getitem__(self, index):
80 path = self.paths[index]
81 img = Image.open(path)
82 return self.transform(img)
85class Configs(BaseConfigs):
モデルをトレーニングするデバイス。DeviceConfigs
使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します
93 device: torch.device = DeviceConfigs()
96 discriminator: Discriminator
98 generator: Generator
100 mapping_network: MappingNetwork
ディスクリミネーターとジェネレータの損失関数ワッサーシュタインロスを使います
104 discriminator_loss: DiscriminatorLoss
105 generator_loss: GeneratorLoss
オプティマイザー
108 generator_optimizer: torch.optim.Adam
109 discriminator_optimizer: torch.optim.Adam
110 mapping_network_optimizer: torch.optim.Adam
113 gradient_penalty = GradientPenalty()
グラデーションペナルティ係数
115 gradient_penalty_coefficient: float = 10.
データローダー
121 loader: Iterator
バッチサイズ
124 batch_size: int = 32
との次元
126 d_latent: int = 512
画像の高さ/幅
128 image_size: int = 32
マッピングネットワークのレイヤー数
130 mapping_network_layers: int = 8
ジェネレーターとディスクリミネーターの学習率
132 learning_rate: float = 1e-3
マッピングネットワークの学習率 (他よりも低い)
134 mapping_network_learning_rate: float = 1e-5
グラデーションを蓄積するステップの数。これを使って有効なバッチサイズを増やしてください。
136 gradient_accumulate_steps: int = 1
そしてアダムオプティマイザーの場合
138 adam_betas: Tuple[float, float] = (0.0, 0.99)
スタイルが混在する確率
140 style_mixing_prob: float = 0.9
トレーニングステップの総数
143 training_steps: int = 150_000
ジェネレータ内のブロック数 (画像の解像度に基づいて計算)
146 n_gen_blocks: int
グラデーションペナルティを計算する間隔
154 lazy_gradient_penalty_interval: int = 4
パス長ペナルティ計算間隔
156 lazy_path_penalty_interval: int = 32
トレーニングの初期段階では、経路長のペナルティの計算を省略
158 lazy_path_penalty_after: int = 5_000
生成された画像をログに記録する頻度
161 log_generated_interval: int = 500
モデルチェックポイントを保存する頻度
163 save_checkpoint_interval: int = 2_000
アクティベーションをログに記録するためのトレーニングモード状態
166 mode: ModeState
モデルレイヤーの出力をログに記録するかどうか
168 log_layer_outputs: bool = False
これをCeleba-HQデータセットでトレーニングしました。ダウンロードの説明は、fast.ai のこのディスカッションにあります。data/stylegan
画像をフォルダーに保存します。
175 dataset_path: str = str(lab.get_data_path() / 'stylegan2')
177 def init(self):
データセットの作成
182 dataset = Dataset(self.dataset_path, self.image_size)
データローダーの作成
184 dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
185 shuffle=True, drop_last=True, pin_memory=True)
187 self.loader = cycle_dataloader(dataloader)
画像解像度の
190 log_resolution = int(math.log2(self.image_size))
ディスクリミネーターとジェネレーターの作成
193 self.discriminator = Discriminator(log_resolution).to(self.device)
194 self.generator = Generator(log_resolution, self.d_latent).to(self.device)
スタイル入力とノイズ入力を作成するためのジェネレータブロックの数を取得
196 self.n_gen_blocks = self.generator.n_blocks
マッピングネットワークの作成
198 self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
パス長のペナルティロスを作成
200 self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
モニターレイヤー出力へのモデルフックの追加
203 if self.log_layer_outputs:
204 hook_model_outputs(self.mode, self.discriminator, 'discriminator')
205 hook_model_outputs(self.mode, self.generator, 'generator')
206 hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')
ディスクリミネーターとジェネレータの損失
209 self.discriminator_loss = DiscriminatorLoss().to(self.device)
210 self.generator_loss = GeneratorLoss().to(self.device)
オプティマイザーの作成
213 self.discriminator_optimizer = torch.optim.Adam(
214 self.discriminator.parameters(),
215 lr=self.learning_rate, betas=self.adam_betas
216 )
217 self.generator_optimizer = torch.optim.Adam(
218 self.generator.parameters(),
219 lr=self.learning_rate, betas=self.adam_betas
220 )
221 self.mapping_network_optimizer = torch.optim.Adam(
222 self.mapping_network.parameters(),
223 lr=self.mapping_network_learning_rate, betas=self.adam_betas
224 )
トラッカー構成を設定
227 tracker.set_image("generated", True)
これはランダムにサンプリングされ、マッピングネットワークから取得されます。
また、スタイルミキシングを適用して、 2つの潜在変数とを生成し、対応するおよびを取得することもあります。次に、クロスオーバーポイントをランダムにサンプリングし、クロスオーバーポイントの前のジェネレーターブロックとクロスオーバーポイント後のブロックに適用します
。229 def get_w(self, batch_size: int):
ミックススタイル
243 if torch.rand(()).item() < self.style_mixing_prob:
ランダムクロスオーバーポイント
245 cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
サンプルと
247 z2 = torch.randn(batch_size, self.d_latent).to(self.device)
248 z1 = torch.randn(batch_size, self.d_latent).to(self.device)
取得して
250 w1 = self.mapping_network(z1)
251 w2 = self.mapping_network(z2)
ジェネレータブロックを拡張して連結する
253 w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
254 w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
255 return torch.cat((w1, w2), dim=0)
混合せずに
257 else:
サンプルと
259 z = torch.randn(batch_size, self.d_latent).to(self.device)
取得して
261 w = self.mapping_network(z)
ジェネレータブロック用に拡張
263 return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
265 def get_noise(self, batch_size: int):
ノイズを保存するリスト
272 noise = []
ノイズ分解能は次から始まります
274 resolution = 4
ジェネレータブロックごとにノイズを生成
277 for i in range(self.n_gen_blocks):
最初のブロックには畳み込みが 1 つしかありません
279 if i == 0:
280 n1 = None
ノイズを生成して最初のコンボリューション層の後に追加します
282 else:
283 n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
ノイズを生成して 2 番目のコンボリューション層の後に追加します
285 n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
ノイズテンソルをリストに追加
288 noise.append((n1, n2))
次のブロックには解像度があります
291 resolution *= 2
リターン・ノイズ・テンソル
294 return noise
296 def generate_images(self, batch_size: int):
取得
304 w = self.get_w(batch_size)
ノイズが出る
306 noise = self.get_noise(batch_size)
画像を生成
309 images = self.generator(w, noise)
画像を返すと
312 return images, w
314 def step(self, idx: int):
ディスクリミネーターのトレーニング
320 with monit.section('Discriminator'):
グラデーションをリセット
322 self.discriminator_optimizer.zero_grad()
の勾配を累積 gradient_accumulate_steps
325 for i in range(self.gradient_accumulate_steps):
[更新] mode
。アクティベーションをログに記録するかどうかを設定
327 with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
ジェネレータからのサンプル画像
329 generated_images, _ = self.generate_images(self.batch_size)
生成された画像のディスクリミネーター分類
331 fake_output = self.discriminator(generated_images.detach())
データローダーから実際の画像を取得
334 real_images = next(self.loader).to(self.device)
勾配ペナルティのためには、実際の画像に対して勾配を計算する必要があります
336 if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
337 real_images.requires_grad_()
実画像のディスクリミネーター分類
339 real_output = self.discriminator(real_images)
ディスクリミネーター損失を取得
342 real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
343 disc_loss = real_loss + fake_loss
グラデーションペナルティを追加
346 if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
グラデーションペナルティの計算と記録
348 gp = self.gradient_penalty(real_images, real_output)
349 tracker.add('loss.gp', gp)
係数を掛けて勾配ペナルティを加える
351 disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
勾配の計算
354 disc_loss.backward()
ログディスクリミネーター損失
357 tracker.add('loss.discriminator', disc_loss)
358
359 if (idx + 1) % self.log_generated_interval == 0:
ディスクリミネーターモデルのパラメーターを時々ログに記録する
361 tracker.add('discriminator', self.discriminator)
安定化のためのクリップグラデーション
364 torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
最適化の一歩を踏み出す
366 self.discriminator_optimizer.step()
ジェネレータをトレーニング
369 with monit.section('Generator'):
グラデーションをリセット
371 self.generator_optimizer.zero_grad()
372 self.mapping_network_optimizer.zero_grad()
の勾配を累積 gradient_accumulate_steps
375 for i in range(self.gradient_accumulate_steps):
ジェネレータからのサンプル画像
377 generated_images, w = self.generate_images(self.batch_size)
生成された画像のディスクリミネーター分類
379 fake_output = self.discriminator(generated_images)
ジェネレータロスを取得
382 gen_loss = self.generator_loss(fake_output)
パス長ペナルティを追加
385 if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
経路長ペナルティの計算
387 plp = self.path_length_penalty(w, generated_images)
次の場合は無視 nan
389 if not torch.isnan(plp):
390 tracker.add('loss.plp', plp)
391 gen_loss = gen_loss + plp
勾配の計算
394 gen_loss.backward()
ログジェネレータの損失
397 tracker.add('loss.generator', gen_loss)
398
399 if (idx + 1) % self.log_generated_interval == 0:
ディスクリミネーターモデルのパラメーターを時々ログに記録する
401 tracker.add('generator', self.generator)
402 tracker.add('mapping_network', self.mapping_network)
安定化のためのクリップグラデーション
405 torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
406 torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)
最適化の一歩を踏み出す
409 self.generator_optimizer.step()
410 self.mapping_network_optimizer.step()
生成された画像をログに記録する
413 if (idx + 1) % self.log_generated_interval == 0:
414 tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
モデルチェックポイントの保存
416 if (idx + 1) % self.save_checkpoint_interval == 0:
417 experiment.save_checkpoint()
フラッシュトラッカー
420 tracker.save()
422 def train(self):
ループ用 training_steps
428 for i in monit.loop(self.training_steps):
トレーニングの一歩を踏み出す
430 self.step(i)
432 if (i + 1) % self.log_generated_interval == 0:
433 tracker.new_line()
436def main():
テストを作成
442 experiment.create(name='stylegan2')
設定オブジェクトの作成
444 configs = Configs()
構成を設定し、一部をオーバーライドする
447 experiment.configs(configs, {
448 'device.cuda_device': 0,
449 'image_size': 64,
450 'log_generated_interval': 200
451 })
[初期化]
454 configs.init()
保存および読み込み用のモデルを設定する
456 experiment.add_pytorch_models(mapping_network=configs.mapping_network,
457 generator=configs.generator,
458 discriminator=configs.discriminator)
実験を始める
461 with experiment.start():
トレーニングループを実行する
463 configs.train()
467if __name__ == '__main__':
468 main()