生成对抗网络 (GAN)

这是生成对抗网络的实现。

生成@@

生成与数据分布相匹配的样本,而鉴别器则给出来自数据而不是来自数据的概率

我们在具有值功能的双人最小最大游戏中同时进行训练

是数据的概率分布,而概率分布则设置为高斯噪声。

这个文件定义了损失函数。这是一个 MNIST 示例,其中包含两个用于生成器和鉴别器的多层感知器。

34import torch
35import torch.nn as nn
36import torch.utils.data
37import torch.utils.data
38
39from labml_helpers.module import Module

鉴别器丢失

鉴别器应该在梯度上

是微型批次大小,用于索引微型批次中的样本。是来自的样本也是来自的样本

42class DiscriminatorLogitsLoss(Module):
57    def __init__(self, smoothing: float = 0.2):
58        super().__init__()

我们使用 PyTorch 二进制交叉熵损失,也就是说,标签在哪里,预测在哪里。注意负号。我们使用等于 for fro m 的标签和等于 f or from 的标签然后按这些总和降序与上面的梯度上升相同。

BCEWithLogitsLoss 结合了 softmax 和二进制交叉熵损失。

69        self.loss_true = nn.BCEWithLogitsLoss()
70        self.loss_false = nn.BCEWithLogitsLoss()

我们使用标签平滑,因为它在某些情况下效果更好

73        self.smoothing = smoothing

标签注册为缓冲区,并将持久性设置为False

76        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
77        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)

logits_true 是 logits 来自logits_false logits 来自

79    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
84        if len(logits_true) > len(self.labels_true):
85            self.register_buffer("labels_true",
86                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
87        if len(logits_false) > len(self.labels_false):
88            self.register_buffer("labels_false",
89                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
90
91        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
92                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))

发电机损失

发电机应该下降到梯度上,

95class GeneratorLogitsLoss(Module):
105    def __init__(self, smoothing: float = 0.2):
106        super().__init__()
107        self.loss_true = nn.BCEWithLogitsLoss()
108        self.smoothing = smoothing

我们使用等于 f or fro m 的标签,然后在此损失上降序与上面梯度上的降序相同。

112        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
114    def forward(self, logits: torch.Tensor):
115        if len(logits) > len(self.fake_labels):
116            self.register_buffer("fake_labels",
117                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
118
119        return self.loss_true(logits, self.fake_labels[:len(logits)])

创建经过平滑处理的标注

122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
126    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)