サイクル GAN

これは、サイクルコンシステントな敵対的ネットワークを使用したペアリングされていない画像から画像への翻訳という論文のPyTorch実装/チュートリアルです

エリック・リンダーノレン/ピトーチ・ガンからコードの一部を取り出しました。他のGANバリエーションもチェックしたい場合にとても良いリソースです。

サイクルGANは画像から画像への変換を行います。特定の分布の画像を別のクラスAとBの画像に変換するようにモデルを訓練します。特定の分布の画像は、特定のスタイルや性質の画像などである可能性があります。モデルにはAとBの画像をペアにする必要はありません。各クラスの画像のセットで十分です。これは、たとえば夏から冬に、絵のスタイルを写真に、馬をシマウマに変えるなど、画像スタイルの変更、照明の変更、パターンの変更などにとても効果的です。

Cycle GAN は 2 つのジェネレータモデルと 2 つのディスクリミネーターモデルをトレーニングします。一方のジェネレータはイメージを A から B に、もう 1 つのジェネレータは B から A に変換します。ディスクリミネータは、生成されたイメージが本物に見えるかどうかをテストします

このファイルには、モデルコードとトレーニングコードが含まれています。Google Colabノートブックもあります

Open In Colab

35import itertools
36import random
37import zipfile
38from typing import Tuple
39
40import torch
41import torch.nn as nn
42import torchvision.transforms as transforms
43from PIL import Image
44from torch.utils.data import DataLoader, Dataset
45from torchvision.transforms import InterpolationMode
46from torchvision.utils import make_grid
47
48from labml import lab, tracker, experiment, monit
49from labml.configs import BaseConfigs
50from labml.utils.download import download_file
51from labml.utils.pytorch import get_modules
52from labml_helpers.device import DeviceConfigs
53from labml_helpers.module import Module

ジェネレータは残留ネットワークです。

56class GeneratorResNet(Module):
61    def __init__(self, input_channels: int, n_residual_blocks: int):
62        super().__init__()

この最初のブロックは畳み込みを実行し、画像を特徴マップにマッピングします。パディングがになっているため、出力フィーチャマップの高さと幅は同じです。エッジの画質が良くなるため、反射パディングが使われています

inplace=True in ReLU はメモリを少し節約できます。

70        out_features = 64
71        layers = [
72            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
73            nn.InstanceNorm2d(out_features),
74            nn.ReLU(inplace=True),
75        ]
76        in_features = out_features

ストライドが 2 の 2 つのコンボリューションでダウンサンプリングします。

80        for _ in range(2):
81            out_features *= 2
82            layers += [
83                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
84                nn.InstanceNorm2d(out_features),
85                nn.ReLU(inplace=True),
86            ]
87            in_features = out_features

n_residual_blocks これをやり遂げます。このモジュールは以下に定義されています。

91        for _ in range(n_residual_blocks):
92            layers += [ResidualBlock(out_features)]

次に、生成された特徴マップは、元の画像の高さと幅に一致するようにアップサンプリングされます。

96        for _ in range(2):
97            out_features //= 2
98            layers += [
99                nn.Upsample(scale_factor=2),
100                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
101                nn.InstanceNorm2d(out_features),
102                nn.ReLU(inplace=True),
103            ]
104            in_features = out_features

最後に、特徴マップを RGB 画像にマッピングします。

107        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

レイヤーを含むシーケンシャルモジュールを作成

110        self.layers = nn.Sequential(*layers)

ウェイトを次のように初期化

113        self.apply(weights_init_normal)
115    def forward(self, x):
116        return self.layers(x)

これは、畳み込み層が 2 つある残差ブロックです。

119class ResidualBlock(Module):
124    def __init__(self, in_features: int):
125        super().__init__()
126        self.block = nn.Sequential(
127            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
128            nn.InstanceNorm2d(in_features),
129            nn.ReLU(inplace=True),
130            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
131            nn.InstanceNorm2d(in_features),
132            nn.ReLU(inplace=True),
133        )
135    def forward(self, x: torch.Tensor):
136        return x + self.block(x)

これがディスクリミネーターです。

139class Discriminator(Module):
144    def __init__(self, input_shape: Tuple[int, int, int]):
145        super().__init__()
146        channels, height, width = input_shape

ディスクリミネーターの出力は、画像の各領域が実在するか生成されたものかを問わず、確率のマップでもあります。

150        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
151
152        self.layers = nn.Sequential(

これらのブロックはそれぞれ、高さと幅を 2 分の 1 に縮小します。

154            DiscriminatorBlock(channels, 64, normalize=False),
155            DiscriminatorBlock(64, 128),
156            DiscriminatorBlock(128, 256),
157            DiscriminatorBlock(256, 512),

出力の高さと幅をカーネルと同じに保つため、上部と左側にゼロパッドがあります

160            nn.ZeroPad2d((1, 0, 1, 0)),
161            nn.Conv2d(512, 1, kernel_size=4, padding=1)
162        )

ウェイトを次のように初期化

165        self.apply(weights_init_normal)
167    def forward(self, img):
168        return self.layers(img)

これはディスクリミネーターブロックモジュールです。畳み込み、オプションの正規化、およびリークの多い ReLU を行います。

入力フィーチャマップの高さと幅が半分に縮小されます。

171class DiscriminatorBlock(Module):
179    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
180        super().__init__()
181        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
182        if normalize:
183            layers.append(nn.InstanceNorm2d(out_filters))
184        layers.append(nn.LeakyReLU(0.2, inplace=True))
185        self.layers = nn.Sequential(*layers)
187    def forward(self, x: torch.Tensor):
188        return self.layers(x)

畳み込み層の重みを次のように初期化

191def weights_init_normal(m):
195    classname = m.__class__.__name__
196    if classname.find("Conv") != -1:
197        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

画像を読み込み、グレースケールの場合は RGB に変更します。

200def load_image(path: str):
204    image = Image.open(path)
205    if image.mode != 'RGB':
206        image = Image.new("RGB", image.size).paste(image)
207
208    return image

画像を読み込むデータセット

211class ImageDataset(Dataset):

データセットのダウンロードとデータの抽出

216    @staticmethod
217    def download(dataset_name: str):

URL

222        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

ダウンロードフォルダー

224        root = lab.get_data_path() / 'cycle_gan'
225        if not root.exists():
226            root.mkdir(parents=True)

ダウンロード先

228        archive = root / f'{dataset_name}.zip'

ダウンロードファイル (通常は 100 MB まで)

230        download_file(url, archive)

アーカイブを抽出

232        with zipfile.ZipFile(archive, 'r') as f:
233            f.extractall(root)

データセットの初期化

  • dataset_name はデータセットの名前
  • transforms_ 画像変換のセットです
  • mode train またはのどちらかです test
235    def __init__(self, dataset_name: str, transforms_, mode: str):

データセットパス

244        root = lab.get_data_path() / 'cycle_gan' / dataset_name

見つからない場合はダウンロード

246        if not root.exists():
247            self.download(dataset_name)

画像変換

250        self.transform = transforms.Compose(transforms_)

画像パスを取得

253        path_a = root / f'{mode}A'
254        path_b = root / f'{mode}B'
255        self.files_a = sorted(str(f) for f in path_a.iterdir())
256        self.files_b = sorted(str(f) for f in path_b.iterdir())
258    def __getitem__(self, index):

2 つの画像を返します。これらのペアはまとめてバッチ処理され、トレーニング中のペアとは異なります。だから、いつも同じペアを与え続けても大丈夫です。

262        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
263                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
265    def __len__(self):

データセットの画像数

267        return max(len(self.files_a), len(self.files_b))

リプレイバッファ

リプレイバッファはディスクリミネーターのトレーニングに使用されます。生成された画像は再生バッファに追加され、そこからサンプリングされます

再生バッファは、新しく追加された画像を、の確率で返します。それ以外の場合は、古い生成イメージを送信し、古いイメージを新しく生成されたイメージに置き換えます

これはモデルの振動を減らすためです。

270class ReplayBuffer:
284    def __init__(self, max_size: int = 50):
285        self.max_size = max_size
286        self.data = []

画像の追加/取得

288    def push_and_pop(self, data: torch.Tensor):
290        data = data.detach()
291        res = []
292        for element in data:
293            if len(self.data) < self.max_size:
294                self.data.append(element)
295                res.append(element)
296            else:
297                if random.uniform(0, 1) > 0.5:
298                    i = random.randint(0, self.max_size - 1)
299                    res.append(self.data[i].clone())
300                    self.data[i] = element
301                else:
302                    res.append(element)
303        return torch.stack(res)

コンフィギュレーション

306class Configs(BaseConfigs):

DeviceConfigs 利用可能な場合は GPU を選択します

310    device: torch.device = DeviceConfigs()

ハイパーパラメータ

313    epochs: int = 200
314    dataset_name: str = 'monet2photo'
315    batch_size: int = 1
316
317    data_loader_workers = 8
318
319    learning_rate = 0.0002
320    adam_betas = (0.5, 0.999)
321    decay_start = 100

この論文では、負の対数確率よりも安定性が高いことがわかっているため、最小二乗損失の代わりに最小二乗損失を使用することを提案しています。

325    gan_loss = torch.nn.MSELoss()

L1ロスはサイクルロスとアイデンティティロスに使用されます

328    cycle_loss = torch.nn.L1Loss()
329    identity_loss = torch.nn.L1Loss()

画像サイズ

332    img_height = 256
333    img_width = 256
334    img_channels = 3

ジェネレータ内の残留ブロック数

337    n_residual_blocks = 9

損失係数

340    cyclic_loss_coefficient = 10.0
341    identity_loss_coefficient = 5.
342
343    sample_interval = 500

モデル

346    generator_xy: GeneratorResNet
347    generator_yx: GeneratorResNet
348    discriminator_x: Discriminator
349    discriminator_y: Discriminator

オプティマイザー

352    generator_optimizer: torch.optim.Adam
353    discriminator_optimizer: torch.optim.Adam

学習料金スケジュール

356    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
357    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

データローダー

360    dataloader: DataLoader
361    valid_dataloader: DataLoader

テストセットからサンプルを生成して保存する

363    def sample_images(self, n: int):
365        batch = next(iter(self.valid_dataloader))
366        self.generator_xy.eval()
367        self.generator_yx.eval()
368        with torch.no_grad():
369            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
370            gen_y = self.generator_xy(data_x)
371            gen_x = self.generator_yx(data_y)

X 軸に沿って画像を配置

374            data_x = make_grid(data_x, nrow=5, normalize=True)
375            data_y = make_grid(data_y, nrow=5, normalize=True)
376            gen_x = make_grid(gen_x, nrow=5, normalize=True)
377            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Y 軸に沿って画像を配置

380            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

サンプルを表示

383        plot_image(image_grid)

モデルとデータローダーの初期化

385    def initialize(self):
389        input_shape = (self.img_channels, self.img_height, self.img_width)

モデルを作成

392        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
393        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
394        self.discriminator_x = Discriminator(input_shape).to(self.device)
395        self.discriminator_y = Discriminator(input_shape).to(self.device)

オプティマイザーの作成

398        self.generator_optimizer = torch.optim.Adam(
399            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
400            lr=self.learning_rate, betas=self.adam_betas)
401        self.discriminator_optimizer = torch.optim.Adam(
402            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
403            lr=self.learning_rate, betas=self.adam_betas)

学習率スケジュールを作成します。学習率は、decay_start エポックまでは横ばいから始まり、トレーニングの終了時には直線的に低下します

408        decay_epochs = self.epochs - self.decay_start
409        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
410            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
411        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
412            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)

画像変換

415        transforms_ = [
416            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
417            transforms.RandomCrop((self.img_height, self.img_width)),
418            transforms.RandomHorizontalFlip(),
419            transforms.ToTensor(),
420            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
421        ]

トレーニングデータローダー

424        self.dataloader = DataLoader(
425            ImageDataset(self.dataset_name, transforms_, 'train'),
426            batch_size=self.batch_size,
427            shuffle=True,
428            num_workers=self.data_loader_workers,
429        )

検証データローダー

432        self.valid_dataloader = DataLoader(
433            ImageDataset(self.dataset_name, transforms_, "test"),
434            batch_size=5,
435            shuffle=True,
436            num_workers=self.data_loader_workers,
437        )

トレーニング

私たちは次のことを解決することを目指しています。

ここで、画像を変換元画像を変換元、画像が宇宙からのものかどうかをテストし、画像が宇宙からのものかどうかをテストし、

は元の GAN 論文で生成される敵対的損失です。

は周期的損失で、そこで似たような存在になることそしてそれに似ているように努めることです。基本的に、2つのジェネレーター(変換)を連続して適用すると、元の画像が返されるはずです。これがこの論文の主な貢献です。ジェネレーターをトレーニングして、元の画像と同様の他の分布の画像を生成します。この損失がないと、の配布による何かが生成される可能性があります。今度は、のディストリビューションから何かを生成する必要がありますが、それでものプロパティを持っているので次のようなものを再生成できます

アイデンティティの喪失です。これは、入力と出力の間の色構成を維持するようにマッピングを促すために使用されました。

解くにはディスクリミネーターとディスクリミネーターを勾配で上げる必要があります

これは、負の対数確率損失を基にしたものです

学習を安定させるために、負の対数確率目標を最小二乗損失(ディスクリミネーターの最小二乗誤差)に置き換えました。実際の画像には1を、生成された画像には0というラベルを付けました。だから勾配を降りたいのですが

ジェネレーターにも最小二乗法を使います。ジェネレータは勾配に沿って下降するはずですが

generator_xy generator_yx 私たちは目的と目的で使用します。discriminator_x discriminator_y 私たちは目的と目的で使用します。

439    def run(self):

生成されたサンプルを保存するためのリプレイバッファ

541        gen_x_buffer = ReplayBuffer()
542        gen_y_buffer = ReplayBuffer()

エポックをループスルーする

545        for epoch in monit.loop(self.epochs):

データセットのループ処理

547            for i, batch in monit.enum('Train', self.dataloader):

画像をデバイスに移動

549                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

真のラベルは次と等しい

552                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
553                                         device=self.device, requires_grad=False)

等しい偽ラベル

555                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
556                                           device=self.device, requires_grad=False)

発電機を訓練しなさい。これにより、生成された画像が返されます。

560                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

訓練用ディスクリミネーター

563                self.optimize_discriminator(data_x, data_y,
564                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
565                                            true_labels, false_labels)

トレーニング統計を保存してグローバルステップカウンタを増やす

568                tracker.save()
569                tracker.add_global_step(max(len(data_x), len(data_y)))

画像を一定間隔で保存

572                batches_done = epoch * len(self.dataloader) + i
573                if batches_done % self.sample_interval == 0:

画像のサンプリング時にモデルを保存する

575                    experiment.save_checkpoint()

サンプル画像

577                    self.sample_images(batches_done)

学習率の更新

580            self.generator_lr_scheduler.step()
581            self.discriminator_lr_scheduler.step()

ニューライン

583            tracker.new_line()

アイデンティティ損失、ゲイン損失、サイクル損失を考慮してジェネレータを最適化します。

585    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

トレーニングモードに変更

591        self.generator_xy.train()
592        self.generator_yx.train()

アイデンティティの喪失

597        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
598                         self.identity_loss(self.generator_xy(data_y), data_y))

画像を生成し、

601        gen_y = self.generator_xy(data_x)
602        gen_x = self.generator_yx(data_y)

GAN ロス

607        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
608                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

サイクルロス

615        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
616                      self.cycle_loss(self.generator_xy(gen_x), data_y))

総損失

619        loss_generator = (loss_gan +
620                          self.cyclic_loss_coefficient * loss_cycle +
621                          self.identity_loss_coefficient * loss_identity)

オプティマイザーで一歩踏み出そう

624        self.generator_optimizer.zero_grad()
625        loss_generator.backward()
626        self.generator_optimizer.step()

ログロス

629        tracker.add({'loss.generator': loss_generator,
630                     'loss.generator.cycle': loss_cycle,
631                     'loss.generator.gan': loss_gan,
632                     'loss.generator.identity': loss_identity})

生成された画像を返す

635        return gen_x, gen_y

ゲイン損失のあるディスクリミネーターを最適化します。

637    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
638                               gen_x: torch.Tensor, gen_y: torch.Tensor,
639                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GAN ロス

652        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
653                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
654                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
655                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

オプティマイザーで一歩踏み出そう

658        self.discriminator_optimizer.zero_grad()
659        loss_discriminator.backward()
660        self.discriminator_optimizer.step()

ログロス

663        tracker.add({'loss.discriminator': loss_discriminator})

トレインサイクル GAN

666def train():

構成の作成

671    conf = Configs()

テストを作成

673    experiment.create(name='cycle_gan')

構成を計算します。conf.run 計算とそれに必要なその他すべての構成を行います。

676    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
677    conf.initialize()

保存および読み込み用にモデルを登録します。get_modules nn.Modules in の辞書が表示されますconf 。モデルのカスタム辞書を指定することもできます。

682    experiment.add_pytorch_models(get_modules(conf))

実験を開始して見る

684    with experiment.start():

トレーニングを実行

686        conf.run()

matplotlib を使用してイメージをプロットする

689def plot_image(img: torch.Tensor):
693    from matplotlib import pyplot as plt

テンソルを CPU に移動

696    img = img.cpu()

正規化用の画像の最小値と最大値を取得

698    img_min, img_max = img.min(), img.max()

画像の値を 0... 1 に拡大/縮小

700    img = (img - img_min) / (img_max - img_min + 1e-5)

寸法の順序をHWCに変更する必要があります。

702    img = img.permute(1, 2, 0)

[イメージを表示]

704    plt.imshow(img)

軸はいらない

706    plt.axis('off')

ディスプレイ

708    plt.show()

トレーニング済みサイクル GAN の評価

711def evaluate():

トレーニングランからランUUIDを設定

716    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

コンフィグオブジェクトの作成

718    conf = Configs()

実験を作成

720    experiment.create(name='cycle_gan_inference')

トレーニング用に設定されたハイパーパラメータをロード

722    conf_dict = experiment.load_configs(trained_run_uuid)

構成を計算します。'generator_xy', 'generator_yx' ジェネレータとその依存関係のみをロードするようにジェネレータを指定します。device やのようなコンフィグは、img_channels generator_xy generator_yx やで必要になるので計算されます

dataset_name 他のパラメータが必要な場合は、ここで指定してください。何も指定しない場合、データローダーを含むすべての構成が計算されます。設定とその依存関係の計算は、呼び出すときに行われます experiment.start

731    experiment.configs(conf, conf_dict)
732    conf.initialize()

保存および読み込み用にモデルを登録します。get_modules nn.Modules in の辞書が表示されますconf 。モデルのカスタム辞書を指定することもできます。

737    experiment.add_pytorch_models(get_modules(conf))

どのランからロードするかを指定します。呼び出すと実際に読み込みが行われます experiment.start

740    experiment.load(trained_run_uuid)

実験を始める

743    with experiment.start():

画像変換

745        transforms_ = [
746            transforms.ToTensor(),
747            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
748        ]

独自のデータをロードします。ここでテストセットを試してみます。ヨセミテの写真を試してみましたが、見栄えが最高です。dataset_name の呼び出しで計算対象として指定した場合は使用できます conf.dataset_name experiment.configs

754        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

データセットから画像を取得

756        x_image = dataset[10]['x']

画像を表示する

758        plot_image(x_image)

評価モード

761        conf.generator_xy.eval()
762        conf.generator_yx.eval()

グラデーションはいらない

765        with torch.no_grad():

バッチディメンションを追加し、使用するデバイスに移動します

767            data = x_image.unsqueeze(0).to(conf.device)
768            generated_y = conf.generator_xy(data)

生成された画像を表示します。

771        plot_image(generated_y[0].cpu())
772
773
774if __name__ == '__main__':
775    train()

評価 ()