StyleGAN 2

This is a PyTorch implementation of the paper Analyzing and Improving the Image Quality of StyleGAN, which introduces StyleGAN 2. StyleGAN 2 is an improvement over StyleGAN from the paper A Style-Based Generator Architecture for Generative Adversarial Networks, and StyleGAN in turn is based on Progressive GAN from the paper Progressive Growing of GANs for Improved Quality, Stability, and Variation. All three papers are from the same authors at NVIDIA AI.

Our implementation is minimalistic StyleGAN 2 training code. Only single-GPU training is supported, to keep the implementation simple. The whole thing, including the training loop, is under 500 lines of code.

🏃 Here's the training code: experiment.py.

Generated Images

These are images generated after training for about 80K steps.

We'll first introduce the three papers at a high level.

Generative Adversarial Networks

Generative adversarial networks have two components: the generator and the discriminator. The generator network takes a random latent vector $z$ and tries to generate a realistic image. The discriminator network tries to differentiate real images from generated images. When we train the two networks together, the generator starts generating images indistinguishable from real images.
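As a minimal sketch (illustrative, not the exact objectives in experiment.py), the non-saturating logistic losses commonly paired with StyleGAN-family models look like this:

import torch
import torch.nn.functional as F

def discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Push the discriminator's score up on real images and down on generated ones
    return F.softplus(-d_real).mean() + F.softplus(d_fake).mean()

def generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
    # The generator tries to make the discriminator score its samples as real
    return F.softplus(-d_fake).mean()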

Progressive GAN

Progressive GAN generates high-resolution images (of size $1024 \times 1024$). It does so by progressively increasing the image size. First, it trains a network that produces a $4 \times 4$ image, then $8 \times 8$, then a $16 \times 16$ image, and so on, up to the desired image resolution.

At each resolution, the generator network produces an image in latent space, which is converted into RGB with a $1 \times 1$ convolution. When we progress from a lower resolution to a higher resolution (say from $4 \times 4$ to $8 \times 8$) we scale the latent image by $2\times$ and add a new block (two $3 \times 3$ convolution layers) and a new $1 \times 1$ layer to get RGB. The transition is done smoothly by adding a residual connection to the $2\times$ scaled RGB image. The weight of this residual connection is slowly reduced, to let the new block take over.
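A sketch of the smooth transition (illustrative; rgb_low is the previous resolution's RGB output, rgb_new is the new block's RGB output at the higher resolution, and alpha ramps from 0 to 1 during the transition):

import torch
import torch.nn.functional as F

def fade_in(rgb_low: torch.Tensor, rgb_new: torch.Tensor, alpha: float) -> torch.Tensor:
    # Up-scale the previous resolution's image (interpolation mode is illustrative)
    up = F.interpolate(rgb_low, scale_factor=2, mode='bilinear', align_corners=False)
    # Blend: as alpha grows, the new block takes over
    return (1 - alpha) * up + alpha * rgb_new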

The discriminator is a mirror image of the generator network. The progressive growth of the discriminator is done similarly.

progressive_gan.svg

$2\times$ and $0.5\times$ denote feature map resolution scaling up and down. $4\times4$, $8\times8$, ... denote the feature map resolution at the generator or discriminator block. Each discriminator and generator block consists of 2 convolution layers with leaky ReLU activations.

They use minibatch standard deviation to increase variation and equalized learning rate, both of which we discuss below in the implementation. They also use pixel-wise normalization, where the feature vector at each pixel is normalized. They apply this to all the convolution layer outputs (except the RGB outputs).
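Pixel-wise normalization is simple enough to sketch directly (a minimal version; it is not part of this StyleGAN 2 implementation):

import torch
from torch import nn

class PixelNorm(nn.Module):
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # Normalize the feature vector at each pixel (across channels) to unit length
        return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + self.eps)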

StyleGAN

StyleGAN improves the generator of Progressive GAN, keeping the discriminator architecture the same.

Mapping Network

It maps the random latent vector ($z \in \mathcal{Z}$) into a different latent space ($w \in \mathcal{W}$), with an 8-layer neural network. This gives an intermediate latent space $\mathcal{W}$ where the factors of variation are more linear (disentangled).

AdaIN

Then $w$ is transformed into two vectors (styles) per layer $i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$, and used for scaling and shifting (biasing) in each layer with the $\operatorname{AdaIN}$ operator (normalize and scale):

$$\operatorname{AdaIN}(x_i, y_i) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
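A sketch of the $\operatorname{AdaIN}$ operation (illustrative; y_s and y_b are the per-layer style vectors of shape [batch_size, channels]):

import torch

def adain(x: torch.Tensor, y_s: torch.Tensor, y_b: torch.Tensor, eps: float = 1e-8):
    # Instance normalization: normalize each channel over its spatial dimensions
    mu = x.mean(dim=(2, 3), keepdim=True)
    sigma = x.std(dim=(2, 3), keepdim=True)
    x = (x - mu) / (sigma + eps)
    # Scale and shift by the style
    return y_s[:, :, None, None] * x + y_b[:, :, None, None]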

Style Mixing

To prevent the generator from assuming adjacent styles are correlated, they randomly use different styles for different blocks. That is, they sample two latent vectors $z_1$ and $z_2$, compute the corresponding $w_1$ and $w_2$, and use $w_1$-based styles for some blocks and $w_2$-based styles for the remaining blocks, picked randomly.
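A sketch of style mixing (illustrative; mapping stands for a mapping network like the MappingNetwork below, and n_blocks must be at least 2):

import torch

def mix_styles(mapping, z1: torch.Tensor, z2: torch.Tensor, n_blocks: int):
    # Map both latent vectors to the intermediate latent space
    w1, w2 = mapping(z1), mapping(z2)
    # Pick a random crossover block: blocks before it use w1, the rest use w2
    crossover = int(torch.randint(1, n_blocks, ()))
    return torch.stack([w1 if i < crossover else w2 for i in range(n_blocks)])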

Stochastic Variation

Noise is made available to each block, which helps the generator create more realistic images. Noise is scaled per channel by a learned weight.
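A sketch of the noise injection (illustrative; weight is a learned per-channel scale of shape [channels]; note that the StyleBlock in this implementation uses a single scalar scale instead):

import torch

def add_noise(x: torch.Tensor, noise: torch.Tensor, weight: torch.Tensor):
    # x: [batch_size, channels, height, width], noise: [batch_size, 1, height, width]
    return x + weight[None, :, None, None] * noise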

Bilinear Up and Down Sampling

All the up and down-sampling operations are accompanied by bilinear smoothing.

style_gan.svg

$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (noise is a single channel). StyleGAN also uses progressive growing like Progressive GAN.

StyleGAN 2

StyleGAN 2 changes both the generator and the discriminator of StyleGAN.

Weight Modulation and Demodulation

They remove the $\operatorname{AdaIN}$ operator and replace it with a weight modulation and demodulation step. This is supposed to fix the droplet artifacts present in generated images, which are caused by the normalization in the $\operatorname{AdaIN}$ operator. The style vector per layer $i$ is calculated from $w$ as $s_i = f_{A_i}(w)$.

Then the convolution weights $w$ are modulated as follows ($w$ from here on refers to the weights, not the intermediate latent space; we stick to the paper's notation):

$$w'_{i,j,k} = s_i \cdot w_{i,j,k}$$

Then it's demodulated by normalizing:

$$w''_{i,j,k} = \frac{w'_{i,j,k}}{\sqrt{\sum_{i,k} \left(w'_{i,j,k}\right)^2 + \epsilon}}$$

where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.

Path Length Regularization

Path length regularization encourages a fixed-size step in $\mathcal{W}$ to result in a non-zero, fixed-magnitude change in the generated image.

No Progressive Growing

StyleGAN 2 uses residual connections (with down-sampling) in the discriminator, and skip connections with up-sampling in the generator (the RGB outputs from each block are summed; there are no residual connections in the feature maps). They show with experiments that the contribution of low-resolution layers is higher at the beginning of training, after which the high-resolution layers take over.

import math
from typing import Tuple, Optional, List

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

Mapping Network


This is an MLP with 8 linear layers. The mapping network maps the latent vector $z \in \mathcal{Z}$ to an intermediate latent space $w \in \mathcal{W}$. The $\mathcal{W}$ space will be disentangled from the image space, where the factors of variation become more linear.

class MappingNetwork(nn.Module):
  • features is the number of features in $z$ and $w$
  • n_layers is the number of layers in the mapping network.
    def __init__(self, features: int, n_layers: int):
        super().__init__()

Create the MLP

        layers = []
        for i in range(n_layers):
            layers.append(EqualizedLinear(features, features))

Leaky ReLU

            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor):

Normalize $z$

        z = F.normalize(z, dim=1)

Map $z$ to $w$

        return self.net(z)

StyleGAN 2 Generator

Generator

$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (noise is a single channel). toRGB also has a style modulation, which is not shown in the diagram to keep it simple.

The generator starts with a learned constant. Then it has a series of blocks. The feature map resolution is doubled at each block. Each block outputs an RGB image, and these are scaled up and summed to get the final RGB image.

class Generator(nn.Module):
  • log_resolution is the $\log_2$ of the image resolution
  • d_latent is the dimensionality of $w$
  • n_features is the number of features in the convolution layer at the highest resolution (final block)
  • max_features is the maximum number of features in any generator block
    def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
        super().__init__()

Calculate the number of features for each block

Something like [512, 512, 256, 128, 64, 32]

        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]

Number of generator blocks

        self.n_blocks = len(features)

Trainable constant

        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))

First style block for $4 \times 4$ resolution and a layer to get RGB

        self.style_block = StyleBlock(d_latent, features[0], features[0])
        self.to_rgb = ToRGB(d_latent, features[0])

Generator blocks

        blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
        self.blocks = nn.ModuleList(blocks)

$2\times$ up-sampling layer. The feature map is up-sampled at each block.

        self.up_sample = UpSample()

  • w is $w$. In order to mix styles (use different $w$ for different layers), we provide a separate $w$ for each generator block. It has shape [n_blocks, batch_size, d_latent].
  • input_noise is the noise for each block. It's a list of pairs of noise tensors, because each block (except the initial one) has two noise inputs after each convolution layer (see the diagram).

    def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):

Get batch size

        batch_size = w.shape[1]

Expand the learned constant to match batch size

        x = self.initial_constant.expand(batch_size, -1, -1, -1)

The first style block

        x = self.style_block(x, w[0], input_noise[0][1])

Get the first RGB image

        rgb = self.to_rgb(x, w[0])

Evaluate the rest of the blocks

        for i in range(1, self.n_blocks):

Up-sample the feature map

            x = self.up_sample(x)

Run it through the generator block

            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])

Up-sample the RGB image and add the RGB output from the block

            rgb = self.up_sample(rgb) + rgb_new

Return the final RGB image

        return rgb
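Here's a minimal usage sketch (illustrative sizes; log_resolution = 5 gives $32 \times 32$ images with 4 generator blocks):

mapping = MappingNetwork(features=512, n_layers=8)
generator = Generator(log_resolution=5, d_latent=512)

batch_size = 4
z = torch.randn(batch_size, 512)
# Use the same w for every block (no style mixing)
w = mapping(z)[None, :, :].expand(generator.n_blocks, -1, -1)

# One noise tensor per convolution; the first block only uses the second one
input_noise = []
resolution = 4
for i in range(generator.n_blocks):
    n1 = None if i == 0 else torch.randn(batch_size, 1, resolution, resolution)
    n2 = torch.randn(batch_size, 1, resolution, resolution)
    input_noise.append((n1, n2))
    resolution *= 2

images = generator(w, input_noise)  # [batch_size, 3, 32, 32]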

Generator Block

Generator block

$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (noise is a single channel). toRGB also has a style modulation, which is not shown in the diagram to keep it simple.

The generator block consists of two style blocks ($3 \times 3$ convolutions with style modulation) and an RGB output.

class GeneratorBlock(nn.Module):
  • d_latent is the dimensionality of $w$
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        super().__init__()

First style block changes the feature map size to out_features

        self.style_block1 = StyleBlock(d_latent, in_features, out_features)

Second style block

        self.style_block2 = StyleBlock(d_latent, out_features, out_features)

toRGB layer

        self.to_rgb = ToRGB(d_latent, out_features)

  • x is the input feature map of shape [batch_size, in_features, height, width]
  • w is $w$ with shape [batch_size, d_latent]
  • noise is a tuple of two noise tensors of shape [batch_size, 1, height, width]

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):

First style block with first noise tensor. The output is of shape [batch_size, out_features, height, width]

        x = self.style_block1(x, w, noise[0])

Second style block with second noise tensor. The output is of shape [batch_size, out_features, height, width]

        x = self.style_block2(x, w, noise[1])

Get RGB image

        rgb = self.to_rgb(x, w)

Return feature map and rgb image

        return x, rgb

Style Block

Style block

$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (noise is a single channel).

Style block has a weight modulation convolution layer.

class StyleBlock(nn.Module):
  • d_latent is the dimensionality of $w$
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        super().__init__()

Get the style vector $s$ from $w$ (denoted by $A$ in the diagram) with an equalized learning-rate linear layer

        self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)

Weight modulated convolution layer

        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)

Noise scale

        self.scale_noise = nn.Parameter(torch.zeros(1))

Bias

        self.bias = nn.Parameter(torch.zeros(out_features))

Activation function

        self.activation = nn.LeakyReLU(0.2, True)

  • x is the input feature map of shape [batch_size, in_features, height, width]
  • w is $w$ with shape [batch_size, d_latent]
  • noise is a tensor of shape [batch_size, 1, height, width]

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):

Get style vector

        s = self.to_style(w)

Weight modulated convolution

        x = self.conv(x, s)

Scale and add noise

        if noise is not None:
            x = x + self.scale_noise[None, :, None, None] * noise

Add bias and evaluate activation function

        return self.activation(x + self.bias[None, :, None, None])

To RGB


$A$ denotes a linear layer.

Generates an RGB image from a feature map using a $1 \times 1$ convolution.

class ToRGB(nn.Module):
  • d_latent is the dimensionality of $w$
  • features is the number of features in the feature map
    def __init__(self, d_latent: int, features: int):
        super().__init__()

Get the style vector $s$ from $w$ (denoted by $A$ in the diagram) with an equalized learning-rate linear layer

        self.to_style = EqualizedLinear(d_latent, features, bias=1.0)

Weight modulated convolution layer without demodulation

        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)

Bias

        self.bias = nn.Parameter(torch.zeros(3))

Activation function

        self.activation = nn.LeakyReLU(0.2, True)

  • x is the input feature map of shape [batch_size, in_features, height, width]
  • w is $w$ with shape [batch_size, d_latent]

    def forward(self, x: torch.Tensor, w: torch.Tensor):

Get style vector

        style = self.to_style(w)

Weight modulated convolution

        x = self.conv(x, style)

Add bias and evaluate activation function

        return self.activation(x + self.bias[None, :, None, None])

Convolution with Weight Modulation and Demodulation

This layer scales the convolution weights by the style vector and demodulates them by normalizing.

class Conv2dWeightModulate(nn.Module):
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
  • kernel_size is the size of the convolution kernel
  • demodulate is a flag indicating whether to normalize the weights by their standard deviation
  • eps is the $\epsilon$ for normalizing
    def __init__(self, in_features: int, out_features: int, kernel_size: int,
                 demodulate: bool = True, eps: float = 1e-8):
        super().__init__()

Number of output features

        self.out_features = out_features

Whether to normalize weights

        self.demodulate = demodulate

Padding size

        self.padding = (kernel_size - 1) // 2

        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])

        self.eps = eps

  • x is the input feature map of shape [batch_size, in_features, height, width]
  • s is the style-based scaling tensor of shape [batch_size, in_features]

    def forward(self, x: torch.Tensor, s: torch.Tensor):

Get batch size, height and width

        b, _, h, w = x.shape

Reshape the scales

        s = s[:, None, :, None, None]
        weights = self.weight()[None, :, :, :, :]

Modulate: $w'_{i,j,k} = s_i \cdot w_{i,j,k}$, where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index. The result has shape [batch_size, out_features, in_features, kernel_size, kernel_size].

        weights = weights * s

Demodulate

        if self.demodulate:
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * sigma_inv

Reshape x

        x = x.reshape(1, -1, h, w)

Reshape weights

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.out_features, *ws)

Use grouped convolution to efficiently calculate the convolution with sample-wise kernels, i.e. we have a different kernel (weights) for each sample in the batch

        x = F.conv2d(x, weights, padding=self.padding, groups=b)

Reshape x to [batch_size, out_features, height, width] and return

        return x.reshape(-1, self.out_features, h, w)
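To see what the grouped convolution is doing, here's an illustrative check that it matches an explicit per-sample loop:

b, c_in, c_out, k, h, w = 2, 4, 8, 3, 16, 16
x = torch.randn(b, c_in, h, w)
weights = torch.randn(b, c_out, c_in, k, k)  # a different kernel for each sample

# Grouped convolution: fold the batch into the channel dimension and set groups=b
out = F.conv2d(x.reshape(1, b * c_in, h, w),
               weights.reshape(b * c_out, c_in, k, k),
               padding=1, groups=b).reshape(b, c_out, h, w)

# The same computation as an explicit loop over samples
expected = torch.stack([F.conv2d(x[i:i + 1], weights[i], padding=1)[0] for i in range(b)])
assert torch.allclose(out, expected, atol=1e-4)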

StyleGAN 2 Discriminator

Discriminator

The discriminator first transforms the image to a feature map of the same resolution and then runs it through a series of blocks with residual connections. The resolution is down-sampled by $2\times$ at each block while doubling the number of features.

class Discriminator(nn.Module):
  • log_resolution is the $\log_2$ of the image resolution
  • n_features is the number of features in the convolution layer at the highest resolution (first block)
  • max_features is the maximum number of features in any discriminator block
    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
        super().__init__()

Layer to convert RGB image to a feature map with n_features number of features.

        self.from_rgb = nn.Sequential(
            EqualizedConv2d(3, n_features, 1),
            nn.LeakyReLU(0.2, True),
        )

Calculate the number of features for each block.

Something like [64, 128, 256, 512, 512, 512].

        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
        n_blocks = len(features) - 1

Discriminator blocks

        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)

        self.std_dev = MiniBatchStdDev()

Number of features after adding the standard deviations map

        final_features = features[-1] + 1

Final $3 \times 3$ convolution layer

        self.conv = EqualizedConv2d(final_features, final_features, 3)

Final linear layer to get the classification

        self.final = EqualizedLinear(2 * 2 * final_features, 1)

  • x is the input image of shape [batch_size, 3, height, width]

    def forward(self, x: torch.Tensor):

Try to normalize the image (this is totally optional, but sped up the early training a little)

        x = x - 0.5

Convert from RGB

        x = self.from_rgb(x)

Run through the discriminator blocks

        x = self.blocks(x)

Calculate and append mini-batch standard deviation

        x = self.std_dev(x)

$3 \times 3$ convolution

        x = self.conv(x)

Flatten

        x = x.reshape(x.shape[0], -1)

Return the classification score

        return self.final(x)

Discriminator Block

Discriminator block

Discriminator block consists of two $3 \times 3$ convolutions with a residual connection.

class DiscriminatorBlock(nn.Module):
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
    def __init__(self, in_features: int, out_features: int):
        super().__init__()

Down-sampling and $1 \times 1$ convolution layer for the residual connection

        self.residual = nn.Sequential(DownSample(),
                                      EqualizedConv2d(in_features, out_features, kernel_size=1))

Two $3 \times 3$ convolutions

        self.block = nn.Sequential(
            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        )

Down-sampling layer

        self.down_sample = DownSample()

Scaling factor after adding the residual

        self.scale = 1 / math.sqrt(2)

    def forward(self, x):

Get the residual connection

        residual = self.residual(x)

Convolutions

        x = self.block(x)

Down-sample

        x = self.down_sample(x)

Add the residual and scale

        return (x + residual) * self.scale

Mini-batch Standard Deviation

Mini-batch standard deviation calculates the standard deviation across a mini-batch (or a subgroup within the mini-batch) for each feature in the feature map. Then it takes the mean of all the standard deviations and appends it to the feature map as one extra feature.

class MiniBatchStdDev(nn.Module):
  • group_size is the number of samples to calculate the standard deviation across.
    def __init__(self, group_size: int = 4):
        super().__init__()
        self.group_size = group_size
  • x is the feature map
    def forward(self, x: torch.Tensor):

Check if the batch size is divisible by the group size

        assert x.shape[0] % self.group_size == 0

Split the samples into groups of group_size and flatten the feature map to a single dimension, since we want to calculate the standard deviation for each feature.

        grouped = x.view(self.group_size, -1)

Calculate the standard deviation for each feature among group_size samples

        std = torch.sqrt(grouped.var(dim=0) + 1e-8)

Get the mean standard deviation

        std = std.mean().view(1, 1, 1, 1)

Expand the standard deviation to append to the feature map

        b, _, h, w = x.shape
        std = std.expand(b, -1, h, w)

Append (concatenate) the standard deviations to the feature map

        return torch.cat([x, std], dim=1)
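A quick shape check (illustrative):

std_dev = MiniBatchStdDev(group_size=4)
x = torch.randn(8, 256, 4, 4)              # a batch of 8 splits into two groups of 4
assert std_dev(x).shape == (8, 257, 4, 4)  # 256 features plus one std-dev map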

Down-sample

The down-sample operation smoothens each feature channel and scales it down by $2\times$ using bilinear interpolation. This is based on the paper Making Convolutional Networks Shift-Invariant Again.

class DownSample(nn.Module):
    def __init__(self):
        super().__init__()

Smoothing layer

        self.smooth = Smooth()

    def forward(self, x: torch.Tensor):

Smoothing or blurring

        x = self.smooth(x)

Scale down to half the resolution

        return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)

Up-sample

The up-sample operation scales the image up by $2\times$ and smoothens each feature channel. This is based on the paper Making Convolutional Networks Shift-Invariant Again.

class UpSample(nn.Module):
    def __init__(self):
        super().__init__()

Up-sampling layer

        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

Smoothing layer

        self.smooth = Smooth()

    def forward(self, x: torch.Tensor):

Up-sample and smoothen

        return self.smooth(self.up_sample(x))

Smoothing Layer

This layer blurs each channel.

class Smooth(nn.Module):
    def __init__(self):
        super().__init__()

Blurring kernel

        kernel = [[1, 2, 1],
                  [2, 4, 2],
                  [1, 2, 1]]

Convert the kernel to a PyTorch tensor

        kernel = torch.tensor([[kernel]], dtype=torch.float)

Normalize the kernel

        kernel /= kernel.sum()

Save kernel as a fixed parameter (no gradient updates)

        self.kernel = nn.Parameter(kernel, requires_grad=False)

Padding layer

        self.pad = nn.ReplicationPad2d(1)

    def forward(self, x: torch.Tensor):

Get shape of the input feature map

        b, c, h, w = x.shape

Reshape for smoothening

        x = x.view(-1, 1, h, w)

Add padding

        x = self.pad(x)

Smoothen (blur) with the kernel

        x = F.conv2d(x, self.kernel)

Reshape and return

        return x.view(b, c, h, w)

Learning-rate Equalized Linear Layer

This uses learning-rate equalized weights for a linear layer.

class EqualizedLinear(nn.Module):
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
  • bias is the bias initialization constant
    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
        super().__init__()

        self.weight = EqualizedWeight([out_features, in_features])

Bias

        self.bias = nn.Parameter(torch.ones(out_features) * bias)

    def forward(self, x: torch.Tensor):

Linear transformation

        return F.linear(x, self.weight(), bias=self.bias)

Learning-rate Equalized 2D Convolution Layer

This uses learning-rate equalized weights for a convolution layer.

class EqualizedConv2d(nn.Module):
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
  • kernel_size is the size of the convolution kernel
  • padding is the padding to be added on both sides of each spatial dimension
    def __init__(self, in_features: int, out_features: int,
                 kernel_size: int, padding: int = 0):
        super().__init__()

Padding size

        self.padding = padding

        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])

Bias

        self.bias = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor):

Convolution

        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)

Learning-rate Equalized Weights Parameter

This is based on equalized learning rate introduced in the Progressive GAN paper. Instead of initializing weights at $\mathcal{N}(0, c)$ they initialize weights to $\mathcal{N}(0, 1)$ and then multiply them by $c$ when using them:

$$w_i = c \hat{w}_i$$

The gradients on the stored parameters $\hat{w}$ get multiplied by $c$, but this doesn't have an effect since optimizers such as Adam normalize them by a running mean of the squared gradients.

The optimizer updates on $\hat{w}$ are proportionate to the learning rate $\lambda$. But the effective weights $w$ get updated proportionately to $c \lambda$. Without equalized learning rate, the effective weights would get updated proportionately to just $\lambda$.

So we are effectively scaling the learning rate by $c$ for these weight parameters.

class EqualizedWeight(nn.Module):
  • shape is the shape of the weight parameter
    def __init__(self, shape: List[int]):
        super().__init__()

He initialization constant $c$

        self.c = 1 / math.sqrt(np.prod(shape[1:]))

Initialize the weights with $\mathcal{N}(0, 1)$

        self.weight = nn.Parameter(torch.randn(shape))

Weight multiplication coefficient

    def forward(self):

Multiply the weights by $c$ and return

        return self.weight * self.c

Gradient Penalty

This is the $R_1$ regularization penalty from the paper Which Training Methods for GANs do actually Converge?

$$R_1(\psi) = \frac{\gamma}{2} \mathbb{E}_{x \sim P_\mathcal{D}} \Big[ \Vert \nabla_x D_\psi(x) \Vert^2 \Big]$$

That is, we try to reduce the L2 norm of gradients of the discriminator with respect to images, for real images ($x \sim P_\mathcal{D}$).

class GradientPenalty(nn.Module):
  • x is $x \sim P_\mathcal{D}$
  • d is $D_\psi(x)$
    def forward(self, x: torch.Tensor, d: torch.Tensor):

Get batch size

        batch_size = x.shape[0]

Calculate gradients of $D_\psi(x)$ with respect to $x$. grad_outputs is set to ones since we want the gradients of $D_\psi(x)$, and we need to create and retain the graph since we have to compute gradients with respect to the weights on this loss.

        gradients, *_ = torch.autograd.grad(outputs=d,
                                            inputs=x,
                                            grad_outputs=d.new_ones(d.shape),
                                            create_graph=True)

Reshape gradients to calculate the norm

        gradients = gradients.reshape(batch_size, -1)

Calculate the norm

        norm = gradients.norm(2, dim=-1)

Return the loss

        return torch.mean(norm ** 2)
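A usage sketch (illustrative; the sizes and the coefficient 10 are made up). The penalty needs gradients with respect to the images, so the real batch must have requires_grad set before the discriminator pass. Note that the paper applies these regularizers lazily, only once every several steps.

discriminator = Discriminator(log_resolution=5)
gradient_penalty = GradientPenalty()

x_real = torch.randn(4, 3, 32, 32).requires_grad_()  # stand-in for a batch of real images
d_real = discriminator(x_real)

# Add the weighted penalty to the main discriminator loss
loss = F.softplus(-d_real).mean() + 10 * gradient_penalty(x_real, d_real)
loss.backward()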

Path Length Penalty

This regularization encourages a fixed-size step in $\mathcal{W}$ to result in a fixed-magnitude change in the image.

$$\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, \mathbf{I})} \Big( \Vert \mathbf{J}^\top_{w} y \Vert_2 - a \Big)^2$$

where $\mathbf{J}_w = \frac{\partial g}{\partial w}$ is the Jacobian of the generator $g$, $w$ are sampled from the mapping network $f(z)$, and $y$ are images with noise $\mathcal{N}(0, \mathbf{I})$.

$a$ is the exponential moving average of $\Vert \mathbf{J}^\top_{w} y \Vert_2$ as the training progresses.

$\mathbf{J}^\top_{w} y$ is calculated without explicitly computing the Jacobian, using $$\mathbf{J}^\top_{w} y = \nabla_w \big( g(w) \cdot y \big)$$

class PathLengthPenalty(nn.Module):
  • beta is the constant $\beta$ used to calculate the exponential moving average $a$
    def __init__(self, beta: float):
        super().__init__()

        self.beta = beta

Number of steps calculated, $N$

        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)

Exponential sum of $a_i = \Vert \mathbf{J}^\top_{w} y \Vert_2$, where $a_i$ is its value at the $i$-th step of training

        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)

  • w is the batch of $w$ of shape [n_blocks, batch_size, d_latent]
  • x are the generated images of shape [batch_size, 3, height, width]

    def forward(self, w: torch.Tensor, x: torch.Tensor):

Get the device

        device = x.device

Get number of pixels

        image_size = x.shape[2] * x.shape[3]

Calculate $y \sim \mathcal{N}(0, \mathbf{I})$

        y = torch.randn(x.shape, device=device)

Calculate $g(w) \cdot y$ and normalize by the square root of the image size. This scaling is not mentioned in the paper, but was present in their implementation.

        output = (x * y).sum() / math.sqrt(image_size)

Calculate gradients to get $\mathbf{J}^\top_{w} y$

        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)

Calculate L2-norm of $\mathbf{J}^\top_{w} y$

        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()

Regularize after the first step

        if self.steps > 0:

Calculate $a$, correcting the bias of the exponential sum by dividing by $1 - \beta^N$

            a = self.exp_sum_a / (1 - self.beta ** self.steps)

Calculate the penalty $\big(\Vert \mathbf{J}^\top_{w} y \Vert_2 - a\big)^2$

            loss = torch.mean((norm - a) ** 2)
        else:

Return a dummy loss if we can't calculate $a$

            loss = norm.new_tensor(0)

Calculate the mean of $\Vert \mathbf{J}^\top_{w} y \Vert_2$

        mean = norm.mean().detach()

Update exponential sum

        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)

Increment $N$

        self.steps.add_(1.)

Return the penalty

        return loss
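Continuing the illustrative generator sketch from earlier, the penalty is computed on the same $w$ that produced the images. Note that the first call only records statistics and returns a dummy zero; afterwards, the result can be added to the generator loss (lazily, every few steps, as in the paper):

path_length_penalty = PathLengthPenalty(beta=0.99)  # beta=0.99 is an illustrative decay

w = mapping(torch.randn(batch_size, 512))[None, :, :].expand(generator.n_blocks, -1, -1)
images = generator(w, input_noise)
plp_loss = path_length_penalty(w, images)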