这是生成对抗网络的实现。
生成@@器生成与数据分布相匹配的样本,而鉴别器则给出来自数据而不是来自数据的概率。
我们在具有值功能的双人最小最大游戏中同时进行训练。
是数据的概率分布,而概率分布则设置为高斯噪声。
这个文件定义了损失函数。这是一个 MNIST 示例,其中包含两个用于生成器和鉴别器的多层感知器。
34import torch
35import torch.nn as nn
36import torch.utils.data
37import torch.utils.data
38
39from labml_helpers.module import Module
42class DiscriminatorLogitsLoss(Module):
57 def __init__(self, smoothing: float = 0.2):
58 super().__init__()
我们使用 PyTorch 二进制交叉熵损失,也就是说,标签在哪里,预测在哪里。注意负号。我们使用等于 for fro m 的标签和等于 f or from 的标签然后按这些总和降序与上面的梯度上升相同。
BCEWithLogitsLoss
结合了 softmax 和二进制交叉熵损失。
69 self.loss_true = nn.BCEWithLogitsLoss()
70 self.loss_false = nn.BCEWithLogitsLoss()
我们使用标签平滑,因为它在某些情况下效果更好
73 self.smoothing = smoothing
标签注册为缓冲区,并将持久性设置为False
。
76 self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
77 self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
logits_true
是 logits 来自,logits_false
logits 来自
79 def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
84 if len(logits_true) > len(self.labels_true):
85 self.register_buffer("labels_true",
86 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
87 if len(logits_false) > len(self.labels_false):
88 self.register_buffer("labels_false",
89 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
90
91 return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
92 self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
95class GeneratorLogitsLoss(Module):
105 def __init__(self, smoothing: float = 0.2):
106 super().__init__()
107 self.loss_true = nn.BCEWithLogitsLoss()
108 self.smoothing = smoothing
我们使用等于 f or fro m 的标签,然后在此损失上降序与上面梯度上的降序相同。
112 self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
114 def forward(self, logits: torch.Tensor):
115 if len(logits) > len(self.fake_labels):
116 self.register_buffer("fake_labels",
117 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
118
119 return self.loss_true(logits, self.fake_labels[:len(logits)])
创建经过平滑处理的标注
122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
126 return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)