ジェネレーティブ・アドバーサリアル・ネットワーク (GAN)

これはジェネレーティブ・アドバーサリアル・ネットワークの実装です

ジェネレーターはデータの分布に一致するサンプルを生成し、ディスクリミネーターはデータから得られる確率ではなく、データから得られる確率を返します。

バリュー機能を備えた2人用のミニマックスゲームで同時にトレーニングします。

はデータ全体の確率分布での確率分布はガウスノイズに設定されます。

このファイルは損失関数を定義します。これは、ジェネレーターとディスクリミネーターに2つの多層パーセプトロンを使ったMNISTの例です

34import torch
35import torch.nn as nn
36import torch.utils.data
37import torch.utils.data
38
39from labml_helpers.module import Module

ディスクリミネーターロス

ディスクリミネーターは勾配の上を向いているはずですが

はミニバッチサイズで、ミニバッチ内のサンプルのインデックスに使用されます。からのサンプルであり、からのサンプルです。

42class DiscriminatorLogitsLoss(Module):
57    def __init__(self, smoothing: float = 0.2):
58        super().__init__()

PyTorchのバイナリクロスエントロピー損失を使います。つまりラベルはどこで予測はどこですか。マイナス記号に注意してください。for from と同じラベルと for from に等しいラベルを使用します。これらの合計で降順になると、上記の勾配で昇順になるのと同じになります

BCEWithLogitsLoss ソフトマックスとバイナリクロスエントロピー損失を組み合わせたものです。

69        self.loss_true = nn.BCEWithLogitsLoss()
70        self.loss_false = nn.BCEWithLogitsLoss()

ラベルスムージングを使用するのは、場合によってはうまくいくと思われるためです。

73        self.smoothing = smoothing

ラベルはバッファリングされて登録され、パーシスタンスはに設定されます。False

76        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
77        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)

logits_true logits_false 元のロジットと元のロジット

79    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
84        if len(logits_true) > len(self.labels_true):
85            self.register_buffer("labels_true",
86                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
87        if len(logits_false) > len(self.labels_false):
88            self.register_buffer("labels_false",
89                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
90
91        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
92                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))

発電機損失

ジェネレータは勾配に沿って下降するはずですが

95class GeneratorLogitsLoss(Module):
105    def __init__(self, smoothing: float = 0.2):
106        super().__init__()
107        self.loss_true = nn.BCEWithLogitsLoss()
108        self.smoothing = smoothing

for と等しいラベルを使います。この損失で降順を降順すると、上の勾配で降順になるのと同じになります。

112        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
114    def forward(self, logits: torch.Tensor):
115        if len(logits) > len(self.fake_labels):
116            self.register_buffer("fake_labels",
117                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
118
119        return self.loss_true(logits, self.fake_labels[:len(logits)])

なめらかなラベルを作成

122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
126    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)