これは、サイクルコンシステントな敵対的ネットワークを使用したペアリングされていない画像から画像への翻訳という論文のPyTorch実装/チュートリアルです。
エリック・リンダーノレン/ピトーチ・ガンからコードの一部を取り出しました。他のGANバリエーションもチェックしたい場合にとても良いリソースです。
サイクルGANは画像から画像への変換を行います。特定の分布の画像を別のクラスAとBの画像に変換するようにモデルを訓練します。特定の分布の画像は、特定のスタイルや性質の画像などである可能性があります。モデルにはAとBの画像をペアにする必要はありません。各クラスの画像のセットで十分です。これは、たとえば夏から冬に、絵のスタイルを写真に、馬をシマウマに変えるなど、画像スタイルの変更、照明の変更、パターンの変更などにとても効果的です。
Cycle GAN は 2 つのジェネレータモデルと 2 つのディスクリミネーターモデルをトレーニングします。一方のジェネレータはイメージを A から B に、もう 1 つのジェネレータは B から A に変換します。ディスクリミネータは、生成されたイメージが本物に見えるかどうかをテストします
。このファイルには、モデルコードとトレーニングコードが含まれています。Google Colabノートブックもあります
。35import itertools
36import random
37import zipfile
38from typing import Tuple
39
40import torch
41import torch.nn as nn
42import torchvision.transforms as transforms
43from PIL import Image
44from torch.utils.data import DataLoader, Dataset
45from torchvision.transforms import InterpolationMode
46from torchvision.utils import make_grid
47
48from labml import lab, tracker, experiment, monit
49from labml.configs import BaseConfigs
50from labml.utils.download import download_file
51from labml.utils.pytorch import get_modules
52from labml_helpers.device import DeviceConfigs
53from labml_helpers.module import Module
ジェネレータは残留ネットワークです。
56class GeneratorResNet(Module):
61 def __init__(self, input_channels: int, n_residual_blocks: int):
62 super().__init__()
この最初のブロックは畳み込みを実行し、画像を特徴マップにマッピングします。パディングがになっているため、出力フィーチャマップの高さと幅は同じです。エッジの画質が良くなるため、反射パディングが使われています
。inplace=True
in ReLU
はメモリを少し節約できます。
70 out_features = 64
71 layers = [
72 nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
73 nn.InstanceNorm2d(out_features),
74 nn.ReLU(inplace=True),
75 ]
76 in_features = out_features
ストライドが 2 の 2 つのコンボリューションでダウンサンプリングします。
80 for _ in range(2):
81 out_features *= 2
82 layers += [
83 nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
84 nn.InstanceNorm2d(out_features),
85 nn.ReLU(inplace=True),
86 ]
87 in_features = out_features
n_residual_blocks
これをやり遂げます。このモジュールは以下に定義されています。
91 for _ in range(n_residual_blocks):
92 layers += [ResidualBlock(out_features)]
次に、生成された特徴マップは、元の画像の高さと幅に一致するようにアップサンプリングされます。
96 for _ in range(2):
97 out_features //= 2
98 layers += [
99 nn.Upsample(scale_factor=2),
100 nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
101 nn.InstanceNorm2d(out_features),
102 nn.ReLU(inplace=True),
103 ]
104 in_features = out_features
最後に、特徴マップを RGB 画像にマッピングします。
107 layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
レイヤーを含むシーケンシャルモジュールを作成
110 self.layers = nn.Sequential(*layers)
ウェイトを次のように初期化
113 self.apply(weights_init_normal)
115 def forward(self, x):
116 return self.layers(x)
これは、畳み込み層が 2 つある残差ブロックです。
119class ResidualBlock(Module):
124 def __init__(self, in_features: int):
125 super().__init__()
126 self.block = nn.Sequential(
127 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
128 nn.InstanceNorm2d(in_features),
129 nn.ReLU(inplace=True),
130 nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
131 nn.InstanceNorm2d(in_features),
132 nn.ReLU(inplace=True),
133 )
135 def forward(self, x: torch.Tensor):
136 return x + self.block(x)
これがディスクリミネーターです。
139class Discriminator(Module):
144 def __init__(self, input_shape: Tuple[int, int, int]):
145 super().__init__()
146 channels, height, width = input_shape
ディスクリミネーターの出力は、画像の各領域が実在するか生成されたものかを問わず、確率のマップでもあります。
150 self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
151
152 self.layers = nn.Sequential(
これらのブロックはそれぞれ、高さと幅を 2 分の 1 に縮小します。
154 DiscriminatorBlock(channels, 64, normalize=False),
155 DiscriminatorBlock(64, 128),
156 DiscriminatorBlock(128, 256),
157 DiscriminatorBlock(256, 512),
出力の高さと幅をカーネルと同じに保つため、上部と左側にゼロパッドがあります
160 nn.ZeroPad2d((1, 0, 1, 0)),
161 nn.Conv2d(512, 1, kernel_size=4, padding=1)
162 )
ウェイトを次のように初期化
165 self.apply(weights_init_normal)
167 def forward(self, img):
168 return self.layers(img)
171class DiscriminatorBlock(Module):
179 def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
180 super().__init__()
181 layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
182 if normalize:
183 layers.append(nn.InstanceNorm2d(out_filters))
184 layers.append(nn.LeakyReLU(0.2, inplace=True))
185 self.layers = nn.Sequential(*layers)
187 def forward(self, x: torch.Tensor):
188 return self.layers(x)
畳み込み層の重みを次のように初期化
191def weights_init_normal(m):
195 classname = m.__class__.__name__
196 if classname.find("Conv") != -1:
197 torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
画像を読み込み、グレースケールの場合は RGB に変更します。
200def load_image(path: str):
204 image = Image.open(path)
205 if image.mode != 'RGB':
206 image = Image.new("RGB", image.size).paste(image)
207
208 return image
211class ImageDataset(Dataset):
216 @staticmethod
217 def download(dataset_name: str):
URL
222 url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
ダウンロードフォルダー
224 root = lab.get_data_path() / 'cycle_gan'
225 if not root.exists():
226 root.mkdir(parents=True)
ダウンロード先
228 archive = root / f'{dataset_name}.zip'
ダウンロードファイル (通常は 100 MB まで)
230 download_file(url, archive)
アーカイブを抽出
232 with zipfile.ZipFile(archive, 'r') as f:
233 f.extractall(root)
235 def __init__(self, dataset_name: str, transforms_, mode: str):
データセットパス
244 root = lab.get_data_path() / 'cycle_gan' / dataset_name
見つからない場合はダウンロード
246 if not root.exists():
247 self.download(dataset_name)
画像変換
250 self.transform = transforms.Compose(transforms_)
画像パスを取得
253 path_a = root / f'{mode}A'
254 path_b = root / f'{mode}B'
255 self.files_a = sorted(str(f) for f in path_a.iterdir())
256 self.files_b = sorted(str(f) for f in path_b.iterdir())
258 def __getitem__(self, index):
2 つの画像を返します。これらのペアはまとめてバッチ処理され、トレーニング中のペアとは異なります。だから、いつも同じペアを与え続けても大丈夫です。
262 return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
263 "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
265 def __len__(self):
データセットの画像数
267 return max(len(self.files_a), len(self.files_b))
リプレイバッファはディスクリミネーターのトレーニングに使用されます。生成された画像は再生バッファに追加され、そこからサンプリングされます
。再生バッファは、新しく追加された画像を、の確率で返します。それ以外の場合は、古い生成イメージを送信し、古いイメージを新しく生成されたイメージに置き換えます
。これはモデルの振動を減らすためです。
270class ReplayBuffer:
284 def __init__(self, max_size: int = 50):
285 self.max_size = max_size
286 self.data = []
画像の追加/取得
288 def push_and_pop(self, data: torch.Tensor):
290 data = data.detach()
291 res = []
292 for element in data:
293 if len(self.data) < self.max_size:
294 self.data.append(element)
295 res.append(element)
296 else:
297 if random.uniform(0, 1) > 0.5:
298 i = random.randint(0, self.max_size - 1)
299 res.append(self.data[i].clone())
300 self.data[i] = element
301 else:
302 res.append(element)
303 return torch.stack(res)
306class Configs(BaseConfigs):
DeviceConfigs
利用可能な場合は GPU を選択します
310 device: torch.device = DeviceConfigs()
ハイパーパラメータ
313 epochs: int = 200
314 dataset_name: str = 'monet2photo'
315 batch_size: int = 1
316
317 data_loader_workers = 8
318
319 learning_rate = 0.0002
320 adam_betas = (0.5, 0.999)
321 decay_start = 100
この論文では、負の対数確率よりも安定性が高いことがわかっているため、最小二乗損失の代わりに最小二乗損失を使用することを提案しています。
325 gan_loss = torch.nn.MSELoss()
L1ロスはサイクルロスとアイデンティティロスに使用されます
328 cycle_loss = torch.nn.L1Loss()
329 identity_loss = torch.nn.L1Loss()
画像サイズ
332 img_height = 256
333 img_width = 256
334 img_channels = 3
ジェネレータ内の残留ブロック数
337 n_residual_blocks = 9
損失係数
340 cyclic_loss_coefficient = 10.0
341 identity_loss_coefficient = 5.
342
343 sample_interval = 500
モデル
346 generator_xy: GeneratorResNet
347 generator_yx: GeneratorResNet
348 discriminator_x: Discriminator
349 discriminator_y: Discriminator
オプティマイザー
352 generator_optimizer: torch.optim.Adam
353 discriminator_optimizer: torch.optim.Adam
学習料金スケジュール
356 generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
357 discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
データローダー
360 dataloader: DataLoader
361 valid_dataloader: DataLoader
テストセットからサンプルを生成して保存する
363 def sample_images(self, n: int):
365 batch = next(iter(self.valid_dataloader))
366 self.generator_xy.eval()
367 self.generator_yx.eval()
368 with torch.no_grad():
369 data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
370 gen_y = self.generator_xy(data_x)
371 gen_x = self.generator_yx(data_y)
X 軸に沿って画像を配置
374 data_x = make_grid(data_x, nrow=5, normalize=True)
375 data_y = make_grid(data_y, nrow=5, normalize=True)
376 gen_x = make_grid(gen_x, nrow=5, normalize=True)
377 gen_y = make_grid(gen_y, nrow=5, normalize=True)
Y 軸に沿って画像を配置
380 image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
サンプルを表示
383 plot_image(image_grid)
385 def initialize(self):
389 input_shape = (self.img_channels, self.img_height, self.img_width)
モデルを作成
392 self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
393 self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
394 self.discriminator_x = Discriminator(input_shape).to(self.device)
395 self.discriminator_y = Discriminator(input_shape).to(self.device)
オプティマイザーの作成
398 self.generator_optimizer = torch.optim.Adam(
399 itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
400 lr=self.learning_rate, betas=self.adam_betas)
401 self.discriminator_optimizer = torch.optim.Adam(
402 itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
403 lr=self.learning_rate, betas=self.adam_betas)
学習率スケジュールを作成します。学習率は、decay_start
エポックまでは横ばいから始まり、トレーニングの終了時には直線的に低下します
408 decay_epochs = self.epochs - self.decay_start
409 self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
410 self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
411 self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
412 self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
画像変換
415 transforms_ = [
416 transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
417 transforms.RandomCrop((self.img_height, self.img_width)),
418 transforms.RandomHorizontalFlip(),
419 transforms.ToTensor(),
420 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
421 ]
トレーニングデータローダー
424 self.dataloader = DataLoader(
425 ImageDataset(self.dataset_name, transforms_, 'train'),
426 batch_size=self.batch_size,
427 shuffle=True,
428 num_workers=self.data_loader_workers,
429 )
検証データローダー
432 self.valid_dataloader = DataLoader(
433 ImageDataset(self.dataset_name, transforms_, "test"),
434 batch_size=5,
435 shuffle=True,
436 num_workers=self.data_loader_workers,
437 )
私たちは次のことを解決することを目指しています。
ここで、画像を変換元、画像を変換元、画像が宇宙からのものかどうかをテストし、画像が宇宙からのものかどうかをテストし、
は元の GAN 論文で生成される敵対的損失です。
は周期的損失で、そこで似たような存在になること、そしてそれに似ているように努めることです。基本的に、2つのジェネレーター(変換)を連続して適用すると、元の画像が返されるはずです。これがこの論文の主な貢献です。ジェネレーターをトレーニングして、元の画像と同様の他の分布の画像を生成します。この損失がないと、の配布による何かが生成される可能性があります。今度は、のディストリビューションから何かを生成する必要がありますが、それでものプロパティを持っているので、次のようなものを再生成できます
。アイデンティティの喪失です。これは、入力と出力の間の色構成を維持するようにマッピングを促すために使用されました。
解くには、ディスクリミネーターとディスクリミネーターを勾配で上げる必要があります。
これは、負の対数確率損失を基にしたものです。
学習を安定させるために、負の対数確率目標を最小二乗損失(ディスクリミネーターの最小二乗誤差)に置き換えました。実際の画像には1を、生成された画像には0というラベルを付けました。だから勾配を降りたいのですが
ジェネレーターにも最小二乗法を使います。ジェネレータは勾配に沿って下降するはずですが、
generator_xy
generator_yx
私たちは目的と目的で使用します。discriminator_x
discriminator_y
私たちは目的と目的で使用します。
439 def run(self):
生成されたサンプルを保存するためのリプレイバッファ
541 gen_x_buffer = ReplayBuffer()
542 gen_y_buffer = ReplayBuffer()
エポックをループスルーする
545 for epoch in monit.loop(self.epochs):
データセットのループ処理
547 for i, batch in monit.enum('Train', self.dataloader):
画像をデバイスに移動
549 data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
真のラベルは次と等しい
552 true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
553 device=self.device, requires_grad=False)
等しい偽ラベル
555 false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
556 device=self.device, requires_grad=False)
発電機を訓練しなさい。これにより、生成された画像が返されます。
560 gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
訓練用ディスクリミネーター
563 self.optimize_discriminator(data_x, data_y,
564 gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
565 true_labels, false_labels)
トレーニング統計を保存してグローバルステップカウンタを増やす
568 tracker.save()
569 tracker.add_global_step(max(len(data_x), len(data_y)))
画像を一定間隔で保存
572 batches_done = epoch * len(self.dataloader) + i
573 if batches_done % self.sample_interval == 0:
画像のサンプリング時にモデルを保存する
575 experiment.save_checkpoint()
サンプル画像
577 self.sample_images(batches_done)
学習率の更新
580 self.generator_lr_scheduler.step()
581 self.discriminator_lr_scheduler.step()
ニューライン
583 tracker.new_line()
585 def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
トレーニングモードに変更
591 self.generator_xy.train()
592 self.generator_yx.train()
アイデンティティの喪失
597 loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
598 self.identity_loss(self.generator_xy(data_y), data_y))
画像を生成し、
601 gen_y = self.generator_xy(data_x)
602 gen_x = self.generator_yx(data_y)
GAN ロス
607 loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
608 self.gan_loss(self.discriminator_x(gen_x), true_labels))
サイクルロス
615 loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
616 self.cycle_loss(self.generator_xy(gen_x), data_y))
総損失
619 loss_generator = (loss_gan +
620 self.cyclic_loss_coefficient * loss_cycle +
621 self.identity_loss_coefficient * loss_identity)
オプティマイザーで一歩踏み出そう
624 self.generator_optimizer.zero_grad()
625 loss_generator.backward()
626 self.generator_optimizer.step()
ログロス
629 tracker.add({'loss.generator': loss_generator,
630 'loss.generator.cycle': loss_cycle,
631 'loss.generator.gan': loss_gan,
632 'loss.generator.identity': loss_identity})
生成された画像を返す
635 return gen_x, gen_y
637 def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
638 gen_x: torch.Tensor, gen_y: torch.Tensor,
639 true_labels: torch.Tensor, false_labels: torch.Tensor):
652 loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
653 self.gan_loss(self.discriminator_x(gen_x), false_labels) +
654 self.gan_loss(self.discriminator_y(data_y), true_labels) +
655 self.gan_loss(self.discriminator_y(gen_y), false_labels))
オプティマイザーで一歩踏み出そう
658 self.discriminator_optimizer.zero_grad()
659 loss_discriminator.backward()
660 self.discriminator_optimizer.step()
ログロス
663 tracker.add({'loss.discriminator': loss_discriminator})
666def train():
構成の作成
671 conf = Configs()
テストを作成
673 experiment.create(name='cycle_gan')
構成を計算します。conf.run
計算とそれに必要なその他すべての構成を行います。
676 experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
677 conf.initialize()
保存および読み込み用にモデルを登録します。get_modules
nn.Modules
in の辞書が表示されますconf
。モデルのカスタム辞書を指定することもできます。
682 experiment.add_pytorch_models(get_modules(conf))
実験を開始して見る
684 with experiment.start():
トレーニングを実行
686 conf.run()
689def plot_image(img: torch.Tensor):
693 from matplotlib import pyplot as plt
テンソルを CPU に移動
696 img = img.cpu()
正規化用の画像の最小値と最大値を取得
698 img_min, img_max = img.min(), img.max()
寸法の順序をHWCに変更する必要があります。
702 img = img.permute(1, 2, 0)
[イメージを表示]
704 plt.imshow(img)
軸はいらない
706 plt.axis('off')
ディスプレイ
708 plt.show()
711def evaluate():
トレーニングランからランUUIDを設定
716 trained_run_uuid = 'f73c1164184711eb9190b74249275441'
コンフィグオブジェクトの作成
718 conf = Configs()
実験を作成
720 experiment.create(name='cycle_gan_inference')
トレーニング用に設定されたハイパーパラメータをロード
722 conf_dict = experiment.load_configs(trained_run_uuid)
構成を計算します。'generator_xy', 'generator_yx'
ジェネレータとその依存関係のみをロードするようにジェネレータを指定します。device
やのようなコンフィグは、img_channels
generator_xy
generator_yx
やで必要になるので計算されます
dataset_name
他のパラメータが必要な場合は、ここで指定してください。何も指定しない場合、データローダーを含むすべての構成が計算されます。設定とその依存関係の計算は、呼び出すときに行われます experiment.start
731 experiment.configs(conf, conf_dict)
732 conf.initialize()
保存および読み込み用にモデルを登録します。get_modules
nn.Modules
in の辞書が表示されますconf
。モデルのカスタム辞書を指定することもできます。
737 experiment.add_pytorch_models(get_modules(conf))
どのランからロードするかを指定します。呼び出すと実際に読み込みが行われます experiment.start
740 experiment.load(trained_run_uuid)
実験を始める
743 with experiment.start():
画像変換
745 transforms_ = [
746 transforms.ToTensor(),
747 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
748 ]
独自のデータをロードします。ここでテストセットを試してみます。ヨセミテの写真を試してみましたが、見栄えが最高です。dataset_name
の呼び出しで計算対象として指定した場合は使用できます conf.dataset_name
experiment.configs
754 dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
データセットから画像を取得
756 x_image = dataset[10]['x']
画像を表示する
758 plot_image(x_image)
評価モード
761 conf.generator_xy.eval()
762 conf.generator_yx.eval()
グラデーションはいらない
765 with torch.no_grad():
バッチディメンションを追加し、使用するデバイスに移動します
767 data = x_image.unsqueeze(0).to(conf.device)
768 generated_y = conf.generator_xy(data)
生成された画像を表示します。
771 plot_image(generated_y[0].cpu())
772
773
774if __name__ == '__main__':
775 train()
評価 ()