StyleGAN 2 Model Training

This is the training code for StyleGAN 2 model.

Generated Images

These are images generated after training for about 80K steps.

Our implementation is a minimalistic StyleGAN 2 training script. Only single-GPU training is supported, to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.

Without DDP (distributed data parallel) and multi-GPU training, it will not be practical to train the model at large resolutions (128+). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.

We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan2 folder (the default dataset path used by the configurations).

31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torchvision
36from PIL import Image
37
38import torch
39import torch.utils.data
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
43from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
44from labml_nn.helpers.device import DeviceConfigs
45from labml_nn.helpers.trainer import ModeState
46from labml_nn.utils import cycle_dataloader

Dataset

This loads the training dataset and resizes the images to the given image size.

class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    Loads the training images found under a folder and resizes them to a
    square `image_size` x `image_size` tensor.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images
        * `image_size` size of the image
        """
        super().__init__()

        # Get the paths of all `jpg` files (searched recursively).
        # Sorted so that the index -> file mapping is deterministic across runs.
        self.paths = sorted(Path(path).glob('**/*.jpg'))

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the shorter side of the image to `image_size`
            torchvision.transforms.Resize(image_size),
            # Crop the center so non-square images still yield
            # `image_size` x `image_size` tensors that can be batched;
            # this is a no-op for square images
            torchvision.transforms.CenterCrop(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image as a 3-channel tensor"""
        path = self.paths[index]
        # Use a context manager so the underlying file handle is released,
        # and convert to RGB so grayscale / RGBA / CMYK files still give
        # 3 channels like the rest of the dataset
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

Configurations

class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # StyleGAN2 discriminator
    discriminator: Discriminator
    # StyleGAN2 generator
    generator: Generator
    # Mapping network
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use Wasserstein loss.
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # Gradient penalty regularization loss
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # Path length penalty
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image (must be a power of 2)
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on.
    # Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    #
    # Instead of calculating the regularization losses at every step, the paper
    # proposes lazy regularization, where the regularization terms are calculated
    # only once in a while. This improves the training efficiency a lot.

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState

    # We trained this on the CelebA-HQ dataset. Save the images inside the
    # `data/stylegan2` folder (this default path).
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize

        Creates the data loader, the models, the losses and the optimizers.

        Raises `ValueError` if `image_size` is not a power of 2, since the
        models are built for a resolution of exactly $2^{\\lfloor \\log_2 \\text{image\\_size} \\rfloor}$.
        """
        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))
        # Fail fast: otherwise the models would be built for a smaller
        # resolution than the dataset images and crash later with a shape error
        if 2 ** log_resolution != self.image_size:
            raise ValueError(f'image_size must be a power of 2, got {self.image_size}')

        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous cyclic loader
        self.loader = cycle_dataloader(dataloader)

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and gets $w$ from the mapping network.

        We also apply style mixing sometimes, where we generate two latent
        variables $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to the
        generator blocks before the cross-over point and $w_2$ to the blocks after.

        Returns the per-block $w$ values stacked along a new leading
        dimension of size `n_gen_blocks`.
        """

        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point in `[0, n_gen_blocks)`
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$ directly on the training device
            # (like the noise in `get_noise`), avoiding a host-to-device copy
            z2 = torch.randn(batch_size, self.d_latent, device=self.device)
            z1 = torch.randn(batch_size, self.d_latent, device=self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent, device=self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        This generates noise for each generator block.

        Returns a list with one `(n1, n2)` tuple per generator block, where
        `n1` is `None` for the first block (it has only one convolution).
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        This generates images using the generator.
        Returns the generated images and the $w$ used to generate them.
        """

        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step

        * `idx` is the index of the current training step
        """

        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, _ = self.generate_images(self.batch_size)
                # Discriminator classification for generated images.
                # Detached so this loss does not backpropagate into the generator.
                fake_output = self.discriminator(generated_images.detach())

                # Get real images from the data loader
                real_images = next(self.loader).to(self.device)
                # We need to calculate gradients w.r.t. real images for gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()
                # Discriminator classification for real images
                real_output = self.discriminator(real_images)

                # Get discriminator loss
                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

                # Add gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    # Calculate and log gradient penalty
                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)
                    # Multiply by coefficient and add gradient penalty.
                    # Multiplying by the interval compensates for the penalty
                    # only being applied once every `lazy_gradient_penalty_interval` steps.
                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                # Compute gradients
                disc_loss.backward()

                # Log discriminator loss
                tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images (6 generated followed by 3 real ones).
        # Detached so the logged tensors do not keep the autograd graph alive.
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6].detach(),
                                                real_images[:3].detach()], dim=0))

        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            # Save checkpoint (presumably the models registered with
            # `experiment.add_pytorch_models` in `main`). Some `labml`
            # releases may not expose `save_checkpoint`; skip quietly in that
            # case rather than crash a long training run.
            save_checkpoint = getattr(experiment, 'save_checkpoint', None)
            if save_checkpoint is not None:
                save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model
        """

        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new console line at each image-logging interval
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()

Train StyleGAN2

def main():
    """
    ### Train StyleGAN2

    Creates the experiment, applies configuration overrides, initializes the
    models, registers them for saving/loading and runs the training loop.
    """
    # Create an experiment
    experiment.create(name='stylegan2')

    # Create configurations object and override some of the defaults
    configs = Configs()
    overrides = {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200,
    }
    experiment.configs(configs, overrides)

    # Initialize
    configs.init()

    # Set models for saving and loading
    models = dict(
        mapping_network=configs.mapping_network,
        generator=configs.generator,
        discriminator=configs.discriminator,
    )
    experiment.add_pytorch_models(**models)

    # Start the experiment and run the training loop inside it
    with experiment.start():
        configs.train()


if __name__ == '__main__':
    main()