スタイルガン 2 モデルトレーニング

これはStyleGAN 2モデルのトレーニングコードです

Generated Images

これらは、約 80K ステップのトレーニング後に生成された画像です。

私たちの実装は、最小限のStyleGAN 2モデルトレーニングコードです。実装をシンプルに保つため、単一の GPU トレーニングのみがサポートされています。なんとか縮小して、トレーニングループを含めて 500 行未満のコードに抑えることができました

DDP (分散データ並列) とマルチ GPU トレーニングがなければ、大きな解像度 (128 以上) でモデルをトレーニングすることはできません。

fp16とDDPを使ったトレーニングコードが必要な場合は、lucidrains/stylegan2-pytorchを見てください。

これをCeleba-HQデータセットでトレーニングしました。ダウンロードの説明は、fast.ai のこのディスカッションにありますdata/stylegan 画像をフォルダーに保存します

31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader

データセット

これにより、トレーニングデータセットが読み込まれ、指定された画像サイズにリサイズされます。

49class Dataset(torch.utils.data.Dataset):
  • path 画像を含むフォルダーへのパス
  • image_size 画像のサイズ
56    def __init__(self, path: str, image_size: int):
61        super().__init__()

jpg すべてのファイルのパスを取得

64        self.paths = [p for p in Path(path).glob(f'**/*.jpg')]

トランスフォーメーション

67        self.transform = torchvision.transforms.Compose([

画像のサイズを変更

69            torchvision.transforms.Resize(image_size),

PyTorch テンソルに変換

71            torchvision.transforms.ToTensor(),
72        ])

画像数

74    def __len__(self):
76        return len(self.paths)

index -番目の画像を取得

78    def __getitem__(self, index):
80        path = self.paths[index]
81        img = Image.open(path)
82        return self.transform(img)

コンフィギュレーション

85class Configs(BaseConfigs):

モデルをトレーニングするデバイス。DeviceConfigs 使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します

93    device: torch.device = DeviceConfigs()
96    discriminator: Discriminator
98    generator: Generator
100    mapping_network: MappingNetwork

ディスクリミネーターとジェネレータの損失関数ワッサーシュタインロスを使います

104    discriminator_loss: DiscriminatorLoss
105    generator_loss: GeneratorLoss

オプティマイザー

108    generator_optimizer: torch.optim.Adam
109    discriminator_optimizer: torch.optim.Adam
110    mapping_network_optimizer: torch.optim.Adam
113    gradient_penalty = GradientPenalty()

グラデーションペナルティ係数

115    gradient_penalty_coefficient: float = 10.
118    path_length_penalty: PathLengthPenalty

データローダー

121    loader: Iterator

バッチサイズ

124    batch_size: int = 32

との次元

126    d_latent: int = 512

画像の高さ/幅

128    image_size: int = 32

マッピングネットワークのレイヤー数

130    mapping_network_layers: int = 8

ジェネレーターとディスクリミネーターの学習率

132    learning_rate: float = 1e-3

マッピングネットワークの学習率 (他よりも低い)

134    mapping_network_learning_rate: float = 1e-5

グラデーションを蓄積するステップの数。これを使って有効なバッチサイズを増やしてください。

136    gradient_accumulate_steps: int = 1

そしてアダムオプティマイザーの場合

138    adam_betas: Tuple[float, float] = (0.0, 0.99)

スタイルが混在する確率

140    style_mixing_prob: float = 0.9

トレーニングステップの総数

143    training_steps: int = 150_000

ジェネレータ内のブロック数 (画像の解像度に基づいて計算)

146    n_gen_blocks: int

レイジー・レギュラライゼーション

この論文では、正則化損失を計算する代わりに、正規化項をたまに計算する遅延正則化を提案しています。これにより、トレーニングの効率が大幅に向上します。

グラデーションペナルティを計算する間隔

154    lazy_gradient_penalty_interval: int = 4

パス長ペナルティ計算間隔

156    lazy_path_penalty_interval: int = 32

トレーニングの初期段階では、経路長のペナルティの計算を省略

158    lazy_path_penalty_after: int = 5_000

生成された画像をログに記録する頻度

161    log_generated_interval: int = 500

モデルチェックポイントを保存する頻度

163    save_checkpoint_interval: int = 2_000

アクティベーションをログに記録するためのトレーニングモード状態

166    mode: ModeState

モデルレイヤーの出力をログに記録するかどうか

168    log_layer_outputs: bool = False

これをCeleba-HQデータセットでトレーニングしました。ダウンロードの説明は、fast.ai のこのディスカッションにありますdata/stylegan 画像をフォルダーに保存します。

175    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

[初期化]

177    def init(self):

データセットの作成

182        dataset = Dataset(self.dataset_path, self.image_size)

データローダーの作成

184        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
185                                                 shuffle=True, drop_last=True, pin_memory=True)
187        self.loader = cycle_dataloader(dataloader)

画像解像度の

190        log_resolution = int(math.log2(self.image_size))

ディスクリミネーターとジェネレーターの作成

193        self.discriminator = Discriminator(log_resolution).to(self.device)
194        self.generator = Generator(log_resolution, self.d_latent).to(self.device)

スタイル入力とノイズ入力を作成するためのジェネレータブロックの数を取得

196        self.n_gen_blocks = self.generator.n_blocks

マッピングネットワークの作成

198        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)

パス長のペナルティロスを作成

200        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

モニターレイヤー出力へのモデルフックの追加

203        if self.log_layer_outputs:
204            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
205            hook_model_outputs(self.mode, self.generator, 'generator')
206            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')

ディスクリミネーターとジェネレータの損失

209        self.discriminator_loss = DiscriminatorLoss().to(self.device)
210        self.generator_loss = GeneratorLoss().to(self.device)

オプティマイザーの作成

213        self.discriminator_optimizer = torch.optim.Adam(
214            self.discriminator.parameters(),
215            lr=self.learning_rate, betas=self.adam_betas
216        )
217        self.generator_optimizer = torch.optim.Adam(
218            self.generator.parameters(),
219            lr=self.learning_rate, betas=self.adam_betas
220        )
221        self.mapping_network_optimizer = torch.optim.Adam(
222            self.mapping_network.parameters(),
223            lr=self.mapping_network_learning_rate, betas=self.adam_betas
224        )

トラッカー構成を設定

227        tracker.set_image("generated", True)

[サンプル]

これはランダムにサンプリングされ、マッピングネットワークから取得されます。

また、スタイルミキシングを適用して、 2つの潜在変数とを生成し、対応するおよびを取得することもあります。次に、クロスオーバーポイントをランダムにサンプリングし、クロスオーバーポイントの前のジェネレーターブロックとクロスオーバーポイント後のブロックに適用します

229    def get_w(self, batch_size: int):

ミックススタイル

243        if torch.rand(()).item() < self.style_mixing_prob:

ランダムクロスオーバーポイント

245            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)

サンプルと

247            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
248            z1 = torch.randn(batch_size, self.d_latent).to(self.device)

取得して

250            w1 = self.mapping_network(z1)
251            w2 = self.mapping_network(z2)

ジェネレータブロックを拡張して連結する

253            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
254            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
255            return torch.cat((w1, w2), dim=0)

混合せずに

257        else:

サンプルと

259            z = torch.randn(batch_size, self.d_latent).to(self.device)

取得して

261            w = self.mapping_network(z)

ジェネレータブロック用に拡張

263            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
265    def get_noise(self, batch_size: int):

ノイズを保存するリスト

272        noise = []

ノイズ分解能は次から始まります

274        resolution = 4

ジェネレータブロックごとにノイズを生成

277        for i in range(self.n_gen_blocks):

最初のブロックには畳み込みが 1 つしかありません

279            if i == 0:
280                n1 = None

ノイズを生成して最初のコンボリューション層の後に追加します

282            else:
283                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

ノイズを生成して 2 番目のコンボリューション層の後に追加します

285            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

ノイズテンソルをリストに追加

288            noise.append((n1, n2))

次のブロックには解像度があります

291            resolution *= 2

リターン・ノイズ・テンソル

294        return noise

画像を生成

これにより、ジェネレータを使用して画像が生成されます

296    def generate_images(self, batch_size: int):

取得

304        w = self.get_w(batch_size)

ノイズが出る

306        noise = self.get_noise(batch_size)

画像を生成

309        images = self.generator(w, noise)

画像を返すと

312        return images, w

トレーニングステップ

314    def step(self, idx: int):

ディスクリミネーターのトレーニング

320        with monit.section('Discriminator'):

グラデーションをリセット

322            self.discriminator_optimizer.zero_grad()

の勾配を累積 gradient_accumulate_steps

325            for i in range(self.gradient_accumulate_steps):

[更新] mode 。アクティベーションをログに記録するかどうかを設定

327                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):

ジェネレータからのサンプル画像

329                    generated_images, _ = self.generate_images(self.batch_size)

生成された画像のディスクリミネーター分類

331                    fake_output = self.discriminator(generated_images.detach())

データローダーから実際の画像を取得

334                    real_images = next(self.loader).to(self.device)

勾配ペナルティのためには、実際の画像に対して勾配を計算する必要があります

336                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
337                        real_images.requires_grad_()

実画像のディスクリミネーター分類

339                    real_output = self.discriminator(real_images)

ディスクリミネーター損失を取得

342                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
343                    disc_loss = real_loss + fake_loss

グラデーションペナルティを追加

346                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:

グラデーションペナルティの計算と記録

348                        gp = self.gradient_penalty(real_images, real_output)
349                        tracker.add('loss.gp', gp)

係数を掛けて勾配ペナルティを加える

351                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

勾配の計算

354                    disc_loss.backward()

ログディスクリミネーター損失

357                    tracker.add('loss.discriminator', disc_loss)
358
359            if (idx + 1) % self.log_generated_interval == 0:

ディスクリミネーターモデルのパラメーターを時々ログに記録する

361                tracker.add('discriminator', self.discriminator)

安定化のためのクリップグラデーション

364            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)

最適化の一歩を踏み出す

366            self.discriminator_optimizer.step()

ジェネレータをトレーニング

369        with monit.section('Generator'):

グラデーションをリセット

371            self.generator_optimizer.zero_grad()
372            self.mapping_network_optimizer.zero_grad()

の勾配を累積 gradient_accumulate_steps

375            for i in range(self.gradient_accumulate_steps):

ジェネレータからのサンプル画像

377                generated_images, w = self.generate_images(self.batch_size)

生成された画像のディスクリミネーター分類

379                fake_output = self.discriminator(generated_images)

ジェネレータロスを取得

382                gen_loss = self.generator_loss(fake_output)

パス長ペナルティを追加

385                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:

経路長ペナルティの計算

387                    plp = self.path_length_penalty(w, generated_images)

次の場合は無視 nan

389                    if not torch.isnan(plp):
390                        tracker.add('loss.plp', plp)
391                        gen_loss = gen_loss + plp

勾配の計算

394                gen_loss.backward()

ログジェネレータの損失

397                tracker.add('loss.generator', gen_loss)
398
399            if (idx + 1) % self.log_generated_interval == 0:

ディスクリミネーターモデルのパラメーターを時々ログに記録する

401                tracker.add('generator', self.generator)
402                tracker.add('mapping_network', self.mapping_network)

安定化のためのクリップグラデーション

405            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
406            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

最適化の一歩を踏み出す

409            self.generator_optimizer.step()
410            self.mapping_network_optimizer.step()

生成された画像をログに記録する

413        if (idx + 1) % self.log_generated_interval == 0:
414            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))

モデルチェックポイントの保存

416        if (idx + 1) % self.save_checkpoint_interval == 0:
417            experiment.save_checkpoint()

フラッシュトラッカー

420        tracker.save()

鉄道模型

422    def train(self):

ループ用 training_steps

428        for i in monit.loop(self.training_steps):

トレーニングの一歩を踏み出す

430            self.step(i)

432            if (i + 1) % self.log_generated_interval == 0:
433                tracker.new_line()

トレインスタイルガン2

436def main():

テストを作成

442    experiment.create(name='stylegan2')

設定オブジェクトの作成

444    configs = Configs()

構成を設定し、一部をオーバーライドする

447    experiment.configs(configs, {
448        'device.cuda_device': 0,
449        'image_size': 64,
450        'log_generated_interval': 200
451    })

[初期化]

454    configs.init()

保存および読み込み用のモデルを設定する

456    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
457                                  generator=configs.generator,
458                                  discriminator=configs.discriminator)

実験を始める

461    with experiment.start():

トレーニングループを実行する

463        configs.train()

467if __name__ == '__main__':
468    main()