StyleGAN 2 Model Training

This is the training code for the StyleGAN 2 model.

Generated Images

These are $64 \times 64$ images generated after training for about 80K steps.

Our implementation is minimalistic StyleGAN 2 training code. Only single-GPU training is supported, to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.

Without DDP (distributed data parallel) and multi-GPU training, it is not practical to train the model at large resolutions (128+). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.

We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan2 folder.

import math
from pathlib import Path
from typing import Iterator, Tuple

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.train_valid import ModeState, hook_model_outputs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.utils import cycle_dataloader

Dataset

This loads the training images and resizes them to the given image size.

class Dataset(torch.utils.data.Dataset):
  • path: path to the folder containing the images
  • image_size: size of the image
    def __init__(self, path: str, image_size: int):
        super().__init__()

Get the paths of all jpg files, searching sub-folders recursively

        self.paths = list(Path(path).glob('**/*.jpg'))

Transformation

        self.transform = torchvision.transforms.Compose([

Resize the image. An integer size is matched by the smaller edge, so square images come out as image_size by image_size

            torchvision.transforms.Resize(image_size),

Convert to a PyTorch tensor with values in $[0, 1]$

            torchvision.transforms.ToTensor(),
        ])

Number of images

    def __len__(self):
        return len(self.paths)

Get the index-th image

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
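The `'**/*.jpg'` pattern used by Dataset searches the given folder and all of its sub-folders. Below is a small standalone sketch of just that path-collection step using only the standard library; the helper name find_jpgs is hypothetical and not part of this implementation, and the sorting is added here only to make the order deterministic.

```python
import tempfile
from pathlib import Path
from typing import List


def find_jpgs(root: str) -> List[Path]:
    """Hypothetical helper: collect every .jpg file under root, recursively."""
    # '**' matches root itself and every nested sub-folder.
    # Matching is typically case-sensitive on POSIX systems such as Linux,
    # so files ending in '.JPG' would be skipped there.
    return sorted(Path(root).glob('**/*.jpg'))


if __name__ == '__main__':
    # Build a tiny throwaway folder tree and list the images in it
    with tempfile.TemporaryDirectory() as root:
        (Path(root) / 'sub').mkdir()
        for name in ['a.jpg', 'sub/b.jpg', 'notes.txt']:
            (Path(root) / name).touch()
        print([p.name for p in find_jpgs(root)])
```

Only the file paths are collected up front; each image is opened lazily when the data loader asks for it.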

Configurations

class Configs(BaseConfigs):

Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.

    device: torch.device = DeviceConfigs()

StyleGAN 2 discriminator, generator and mapping network

    discriminator: Discriminator
    generator: Generator
    mapping_network: MappingNetwork

Discriminator and generator loss functions. We use Wasserstein loss.

    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss
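As a rough illustration of what Wasserstein loss means here, the following pure-Python sketch computes the plain WGAN objective on lists of discriminator scores. It returns the real and fake parts separately, matching how step() unpacks `real_loss, fake_loss`. This is an assumption about the general form only: the actual DiscriminatorLoss and GeneratorLoss in labml_nn.gan.wasserstein may differ, for example by clipping the scores.

```python
from typing import List, Tuple


def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


def discriminator_loss(real_scores: List[float],
                       fake_scores: List[float]) -> Tuple[float, float]:
    """Plain WGAN critic loss, returned as (real part, fake part).

    Minimizing it pushes scores for real images up and
    scores for generated images down.
    """
    return -mean(real_scores), mean(fake_scores)


def generator_loss(fake_scores: List[float]) -> float:
    """Plain WGAN generator loss: raise the critic's scores on generated images."""
    return -mean(fake_scores)


if __name__ == '__main__':
    real_loss, fake_loss = discriminator_loss([1.0, 3.0], [0.0, -2.0])
    print(real_loss + fake_loss, generator_loss([0.0, -2.0]))
```

Both losses depend on the same critic scores for generated images, but with opposite signs, which is what sets up the adversarial game.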

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

Gradient penalty regularization loss

    gradient_penalty = GradientPenalty()

Gradient penalty coefficient $\gamma$

    gradient_penalty_coefficient: float = 10.
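Putting the pieces of step() together, the discriminator loss on a step where the gradient penalty is computed can be written as follows, with $k$ = lazy_gradient_penalty_interval. This assumes GradientPenalty returns the batch mean of the squared $L_2$ norm of the gradient of $D$ with respect to the real images (the $R_1$ regularizer used by StyleGAN 2), and that $\mathcal{L}_{\text{real}}$ and $\mathcal{L}_{\text{fake}}$ are the two terms returned by DiscriminatorLoss.

$$
\mathcal{L}_D = \mathcal{L}_{\text{real}} + \mathcal{L}_{\text{fake}} + k \cdot \frac{\gamma}{2} \, \mathbb{E}_{x \sim p_{\text{data}}}\left[ \lVert \nabla_x D(x) \rVert_2^2 \right]
$$

On the other $k - 1$ steps only $\mathcal{L}_{\text{real}} + \mathcal{L}_{\text{fake}}$ is used, so averaged over $k$ steps the penalty carries roughly its usual weight $\frac{\gamma}{2}$. With the defaults $\gamma = 10$ and $k = 4$, the multiplier on a penalty step is $0.5 \times 10 \times 4 = 20$.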
    path_length_penalty: PathLengthPenalty

Data loader

    loader: Iterator

Batch size

    batch_size: int = 32

Dimensionality of $z$ and $w$

    d_latent: int = 512

Height/width of the image

    image_size: int = 32

Number of layers in the mapping network

    mapping_network_layers: int = 8

Generator and discriminator learning rate

    learning_rate: float = 1e-3

Mapping network learning rate ($100 \times$ lower than the others)

    mapping_network_learning_rate: float = 1e-5

Number of steps to accumulate gradients over. Use this to increase the effective batch size.

    gradient_accumulate_steps: int = 1

$\beta_1$ and $\beta_2$ for the Adam optimizer

    adam_betas: Tuple[float, float] = (0.0, 0.99)
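With $\beta_1 = 0$, Adam keeps no momentum: the first-moment estimate is simply the current gradient, and only the second-moment average (with $\beta_2 = 0.99$) smooths the step size. A minimal pure-Python sketch of the textbook Adam update for a single scalar parameter shows this; the function adam_step is hypothetical and illustrates the standard update rule, not torch.optim.Adam's internals.

```python
import math
from typing import Tuple


def adam_step(param: float, grad: float, m: float, v: float, t: int,
              lr: float = 1e-3, betas: Tuple[float, float] = (0.0, 0.99),
              eps: float = 1e-8) -> Tuple[float, float, float]:
    """One textbook Adam update for a scalar; t is the 1-based step count."""
    beta1, beta2 = betas
    # First moment: with beta1 = 0 this is exactly the current gradient
    m = beta1 * m + (1 - beta1) * grad
    # Second moment: running average of squared gradients
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialized averages
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


if __name__ == '__main__':
    p, m, v = adam_step(1.0, 2.0, m=0.0, v=0.0, t=1)
    print(p, m, v)
```

On the first step the bias-corrected ratio $\hat{m} / \sqrt{\hat{v}}$ is about $\pm 1$, so the parameter moves by roughly the learning rate regardless of the gradient's scale.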

Probability of mixing styles

    style_mixing_prob: float = 0.9

Total number of training steps

    training_steps: int = 150_000

Number of blocks in the generator (calculated from the image resolution)

    n_gen_blocks: int

Lazy regularization

Instead of calculating the regularization losses at every training step, the StyleGAN 2 paper proposes lazy regularization, where the regularization terms are calculated only once every few steps. This makes training considerably more efficient.

The interval at which to compute the gradient penalty

    lazy_gradient_penalty_interval: int = 4

Path length penalty calculation interval

    lazy_path_penalty_interval: int = 32

Skip calculating the path length penalty during the initial phase of training

    lazy_path_penalty_after: int = 5_000
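The conditions that step() uses to decide when each penalty runs can be pulled out into a small pure-Python sketch (the function name regularization_flags is hypothetical; the defaults mirror the config values above). Counting over the first 10,000 steps shows the saving: the gradient penalty is computed on 2,500 steps and the path length penalty on only 156.

```python
from typing import Tuple


def regularization_flags(idx: int,
                         gp_interval: int = 4,
                         plp_interval: int = 32,
                         plp_after: int = 5_000) -> Tuple[bool, bool]:
    """Return (compute_gradient_penalty, compute_path_length_penalty) for step idx.

    idx is 0-based, as in step().
    """
    # Gradient penalty on every gp_interval-th step
    compute_gp = (idx + 1) % gp_interval == 0
    # Path length penalty only after the warm-up phase, every plp_interval-th step
    compute_plp = idx > plp_after and (idx + 1) % plp_interval == 0
    return compute_gp, compute_plp


if __name__ == '__main__':
    flags = [regularization_flags(i) for i in range(10_000)]
    print(sum(gp for gp, _ in flags), sum(plp for _, plp in flags))
```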

How often to log generated images

    log_generated_interval: int = 500

How often to save model checkpoints

    save_checkpoint_interval: int = 2_000

Training mode state for logging activations

    mode: ModeState

Whether to log model layer outputs

    log_layer_outputs: bool = False

Path to the training images: the stylegan2 folder inside the labml data directory

    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

Initialize

    def init(self):

Create dataset

        dataset = Dataset(self.dataset_path, self.image_size)

Create data loader

        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=32,
                                                 shuffle=True, drop_last=True, pin_memory=True)

Continuous cyclic loader

        self.loader = cycle_dataloader(dataloader)

$\log_2$ of image resolution

        log_resolution = int(math.log2(self.image_size))

Create discriminator and generator

        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)

Get number of generator blocks for creating style and noise inputs

        self.n_gen_blocks = self.generator.n_blocks

Create mapping network

        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)

Create path length penalty loss

        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

Add model hooks to monitor layer outputs

        if self.log_layer_outputs:
            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
            hook_model_outputs(self.mode, self.generator, 'generator')
            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')

Discriminator and generator losses

        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

Create optimizers

        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

Set tracker configurations

        tracker.set_image("generated", True)
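init() wraps the DataLoader in cycle_dataloader so that step() can call next(self.loader) indefinitely. Here is a minimal sketch of such a helper, assuming it simply restarts iteration whenever the loader is exhausted; it is a reimplementation for illustration, not necessarily the labml_nn.utils code.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def cycle_dataloader(data_loader: Iterable[T]) -> Iterator[T]:
    """Yield batches forever, starting a new pass over data_loader each time it runs out."""
    while True:
        # A fresh `for` calls iter(data_loader) again, so a DataLoader with
        # shuffle=True produces a new random order on every pass.
        # Note: an empty loader would make this loop spin without yielding.
        for batch in data_loader:
            yield batch


if __name__ == '__main__':
    batches = cycle_dataloader([1, 2, 3])
    print([next(batches) for _ in range(7)])
```

Unlike itertools.cycle, which saves a copy of the items from the first pass and replays them, this re-iterates the loader, so reshuffling happens every epoch.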

Sample $w$

This samples $z$ randomly and gets $w$ from the mapping network.

We also sometimes apply style mixing, where we generate two latent variables $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$. Then we randomly sample a cross-over point and apply $w_1$ to the generator blocks before the cross-over point and $w_2$ to the blocks from the cross-over point onward.

    def get_w(self, batch_size: int):

Mix styles with probability style_mixing_prob

        if torch.rand(()).item() < self.style_mixing_prob:

Random cross-over point

            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)

Sample $z_1$ and $z_2$

            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)

Get $w_1$ and $w_2$

            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)

Expand $w_1$ and $w_2$ for the generator blocks and concatenate

            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)

Without mixing

        else:

Sample $z$

            z = torch.randn(batch_size, self.d_latent).to(self.device)

Get $w$

            w = self.mapping_network(z)

Expand $w$ for the generator blocks

            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
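The tensor returned by get_w has shape [n_gen_blocks, batch_size, d_latent], one $w$ per generator block. The following pure-Python sketch (the helper style_assignment is hypothetical) shows which latent each block receives for a given cross-over point, mirroring the concatenation along the block dimension.

```python
from typing import List


def style_assignment(cross_over_point: int, n_gen_blocks: int) -> List[str]:
    """Label each generator block with the latent it is conditioned on."""
    # Blocks [0, cross_over_point) get w1; the remaining blocks get w2,
    # the same order as torch.cat((w1, w2), dim=0) in get_w
    return ['w1'] * cross_over_point + ['w2'] * (n_gen_blocks - cross_over_point)


if __name__ == '__main__':
    print(style_assignment(2, 5))
    print(style_assignment(0, 5))
```

Because torch.rand returns values in $[0, 1)$, cross_over_point ranges over $0, \dots, n_{\text{gen\_blocks}} - 1$: $w_2$ always drives at least the last block, and a cross-over point of 0 yields a sample conditioned on $w_2$ alone.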

Generate noise

This generates noise for each generator block.

    def get_noise(self, batch_size: int):

List to store noise

        noise = []

Noise resolution starts from $4$

        resolution = 4

Generate noise for each generator block

        for i in range(self.n_gen_blocks):

The first block has only one $3 \times 3$ convolution

            if i == 0:
                n1 = None

Generate noise to add after the first convolution layer

            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

Generate noise to add after the second convolution layer

            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

Add noise tensors to the list

            noise.append((n1, n2))

Next block has $2 \times$ resolution

            resolution *= 2

Return noise tensors

        return noise
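The resolution doubles from $4$ with each block, so the last block's noise matches the output image only if there are $\log_2(\text{image\_size}) - 1$ blocks. Assuming that is how Generator sets n_blocks (it is consistent with this loop, but the Generator code is not shown here), the following pure-Python sketch lists the noise shapes get_noise would produce; the helper noise_shapes is hypothetical.

```python
import math
from typing import List, Optional, Tuple

Shape = Tuple[int, int, int, int]


def noise_shapes(image_size: int, batch_size: int) -> List[Tuple[Optional[Shape], Shape]]:
    """Shapes of the (n1, n2) noise pairs for every generator block."""
    # Assumption: one block per resolution 4, 8, ..., image_size
    n_gen_blocks = int(math.log2(image_size)) - 1
    shapes = []
    resolution = 4
    for i in range(n_gen_blocks):
        # The first block has a single convolution, hence no n1
        n1 = None if i == 0 else (batch_size, 1, resolution, resolution)
        n2 = (batch_size, 1, resolution, resolution)
        shapes.append((n1, n2))
        resolution *= 2
    return shapes


if __name__ == '__main__':
    for pair in noise_shapes(64, batch_size=2):
        print(pair)
```

For the $64 \times 64$ run configured in main(), this gives five blocks at resolutions 4, 8, 16, 32 and 64.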

Generate images

This generates images using the generator.

    def generate_images(self, batch_size: int):

Get $w$

        w = self.get_w(batch_size)

Get noise

        noise = self.get_noise(batch_size)

Generate images

        images = self.generator(w, noise)

Return images and $w$

        return images, w

Training Step

    def step(self, idx: int):

Train the discriminator

        with monit.section('Discriminator'):

Reset gradients

            self.discriminator_optimizer.zero_grad()

Accumulate gradients for gradient_accumulate_steps

            for i in range(self.gradient_accumulate_steps):

Update mode. Set whether to log activations

                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):

Sample images from generator

                    generated_images, _ = self.generate_images(self.batch_size)

Discriminator classification for generated images. The images are detached so that no gradients flow into the generator.

                    fake_output = self.discriminator(generated_images.detach())

Get real images from the data loader

                    real_images = next(self.loader).to(self.device)

We need to calculate gradients w.r.t. real images for the gradient penalty

                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        real_images.requires_grad_()

Discriminator classification for real images

                    real_output = self.discriminator(real_images)

Get discriminator loss

                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                    disc_loss = real_loss + fake_loss

Add gradient penalty

                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:

Calculate and log gradient penalty

                        gp = self.gradient_penalty(real_images, real_output)
                        tracker.add('loss.gp', gp)

Multiply by the coefficient and add the gradient penalty. It is also scaled by the lazy interval, since it is computed only once every lazy_gradient_penalty_interval steps.

                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

Compute gradients

                    disc_loss.backward()

Log discriminator loss

                    tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:

Log discriminator model parameters occasionally

                tracker.add('discriminator', self.discriminator)

Clip gradients for stabilization

            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)

Take optimizer step

            self.discriminator_optimizer.step()
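clip_grad_norm_ rescales all of a model's gradients together based on their combined $L_2$ norm. Here is a rough pure-Python sketch of that rule, treating each parameter's gradient as a list of floats; the helper clip_grad_norm is hypothetical, and its epsilon value is an assumption.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float = 1.0) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    # One norm over every gradient entry of every parameter
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Small epsilon (assumed value) avoids division by zero
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for grad in grads:
            for j in range(len(grad)):
                grad[j] *= clip_coef
    return total_norm


if __name__ == '__main__':
    grads = [[3.0], [4.0]]
    print(clip_grad_norm(grads), grads)
```

Because every entry is multiplied by the same factor, clipping shrinks the update without changing its direction. In this training loop the generator and mapping network are clipped in separate calls, so each gets its own norm budget of 1.0.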

Train the generator

        with monit.section('Generator'):

Reset gradients

            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

Accumulate gradients for gradient_accumulate_steps

            for i in range(self.gradient_accumulate_steps):

Sample images from generator

                generated_images, w = self.generate_images(self.batch_size)

Discriminator classification for generated images

                fake_output = self.discriminator(generated_images)

Get generator loss

                gen_loss = self.generator_loss(fake_output)

Add path length penalty

                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:

Calculate path length penalty

                    plp = self.path_length_penalty(w, generated_images)

Ignore if NaN

                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

Calculate gradients

                gen_loss.backward()

Log generator loss

                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:

Log generator and mapping network parameters occasionally

                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

Clip gradients for stabilization

            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

Take optimizer step

            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()
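The path length penalty compares each sample's path length $\lVert J_w^\top y \rVert_2$ with a target $a$ that tracks the long-running average of those lengths. Assuming the 0.99 passed to PathLengthPenalty in init() is the decay rate of that exponential moving average, here is a pure-Python sketch of the bookkeeping, given per-sample path lengths that would normally come from autograd. The class name PathLengthEMA is hypothetical, and the Adam-style bias correction is an assumption about the implementation detail.

```python
from typing import List


class PathLengthEMA:
    """Hypothetical sketch: penalize deviation of path lengths from their moving average."""

    def __init__(self, beta: float = 0.99):
        self.beta = beta        # decay rate of the moving average (assumed meaning of 0.99)
        self.exp_sum_a = 0.0    # un-normalized exponential moving sum
        self.steps = 0          # number of updates so far

    def penalty(self, norms: List[float]) -> float:
        # norms: per-sample path lengths ||J_w^T y||_2 for one batch
        mean_norm = sum(norms) / len(norms)
        if self.steps > 0:
            # Bias-corrected target a, since exp_sum_a started at zero
            a = self.exp_sum_a / (1 - self.beta ** self.steps)
            loss = sum((n - a) ** 2 for n in norms) / len(norms)
        else:
            # No target exists yet on the very first call
            loss = 0.0
        # Fold this batch's mean path length into the moving average
        self.exp_sum_a = self.beta * self.exp_sum_a + (1 - self.beta) * mean_norm
        self.steps += 1
        return loss


if __name__ == '__main__':
    ema = PathLengthEMA(beta=0.99)
    for norms in ([2.0, 2.0], [2.0, 4.0], [3.0, 3.0]):
        print(ema.penalty(norms))
```

The penalty is zero when every path length equals the running average, which encourages a fixed-size step in $w$ to produce a similarly sized change in the image regardless of where it is taken.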

Log generated images

        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))

Save model checkpoints

        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

Flush tracker

        tracker.save()

Train model

    def train(self):

Loop for training_steps

        for i in monit.loop(self.training_steps):

Take a training step

            self.step(i)
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()

Train StyleGAN2

def main():

Create an experiment

    experiment.create(name='stylegan2')

Create configurations object

    configs = Configs()

Set configurations and override some

    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

Initialize

    configs.init()

Set models for saving and loading

    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

Start the experiment

    with experiment.start():

Run the training loop

        configs.train()


if __name__ == '__main__':
    main()