これはジェネレーティブ・アドバーサリアル・ネットワークの実装です。
ジェネレーターはデータの分布に一致するサンプルを生成し、ディスクリミネーターはデータから得られる確率ではなく、データから得られる確率を返します。
バリュー機能を備えた2人用のミニマックスゲームで同時にトレーニングします。
はデータ全体の確率分布で、の確率分布はガウスノイズに設定されます。
このファイルは損失関数を定義します。これは、ジェネレーターとディスクリミネーターに2つの多層パーセプトロンを使ったMNISTの例です
。34import torch
35import torch.nn as nn
36import torch.utils.data
37import torch.utils.data
38
39from labml_helpers.module import Module
ディスクリミネーターは勾配の上を向いているはずですが
はミニバッチサイズで、ミニバッチ内のサンプルのインデックスに使用されます。からのサンプルであり、からのサンプルです。
42class DiscriminatorLogitsLoss(Module):
57 def __init__(self, smoothing: float = 0.2):
58 super().__init__()
PyTorchのバイナリクロスエントロピー損失を使います。つまり、ラベルはどこで予測はどこですか。マイナス記号に注意してください。for from と同じラベルと for from に等しいラベルを使用します。これらの合計で降順になると、上記の勾配で昇順になるのと同じになります
。BCEWithLogitsLoss
ソフトマックスとバイナリクロスエントロピー損失を組み合わせたものです。
69 self.loss_true = nn.BCEWithLogitsLoss()
70 self.loss_false = nn.BCEWithLogitsLoss()
ラベルスムージングを使用するのは、場合によってはうまくいくと思われるためです。
73 self.smoothing = smoothing
ラベルはバッファリングされて登録され、パーシスタンスはに設定されます。False
76 self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
77 self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
logits_true
logits_false
元のロジットと元のロジット
79 def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
84 if len(logits_true) > len(self.labels_true):
85 self.register_buffer("labels_true",
86 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
87 if len(logits_false) > len(self.labels_false):
88 self.register_buffer("labels_false",
89 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
90
91 return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
92 self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
95class GeneratorLogitsLoss(Module):
105 def __init__(self, smoothing: float = 0.2):
106 super().__init__()
107 self.loss_true = nn.BCEWithLogitsLoss()
108 self.smoothing = smoothing
for と等しいラベルを使います。この損失で降順を降順すると、上の勾配で降順になるのと同じになります。
112 self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
114 def forward(self, logits: torch.Tensor):
115 if len(logits) > len(self.fake_labels):
116 self.register_buffer("fake_labels",
117 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
118
119 return self.loss_true(logits, self.fake_labels[:len(logits)])
なめらかなラベルを作成
122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
126 return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)