This is the training code for the StyleGAN 2 model.
These are images generated after training for about 80K steps.
Our implementation is a minimalistic StyleGAN 2 model training code. Only single-GPU training is supported to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.
Without DDP (distributed data parallel) and multi-GPU training it will not be possible to train the model for large resolutions (128+). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.
We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan2
folder (this is the folder the training code reads from).
31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torchvision
36from PIL import Image
37
38import torch
39import torch.utils.data
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
43from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
44from labml_nn.helpers.device import DeviceConfigs
45from labml_nn.helpers.trainer import ModeState
46from labml_nn.utils import cycle_dataloader
class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    Loads every `.jpg` image found (recursively) under a folder and converts it
    into a square RGB tensor.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images
        * `image_size` size of the image; every image is resized and center-cropped
          to `image_size` x `image_size`
        """
        super().__init__()

        # Get the paths of all `jpg` files (searched recursively).
        # Sorted so that the index -> file mapping is deterministic across runs.
        self.paths = sorted(Path(path).glob('**/*.jpg'))

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image so that its shorter side is `image_size`
            torchvision.transforms.Resize(image_size),
            # Crop to a square. This is a no-op for images that are already square,
            # but guarantees that every sample has the same spatial shape so that
            # batches can be collated and fed to the square generator/discriminator.
            torchvision.transforms.CenterCrop(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image as a `[3, image_size, image_size]` tensor"""
        path = self.paths[index]
        # Use a context manager so the file handle is released promptly, and force
        # 3-channel RGB so grayscale / RGBA files don't produce tensors with a
        # different channel count (which would break batching).
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))
class Configs(BaseConfigs):
    """
    ## Configurations

    Holds the hyper-parameters and components used to train StyleGAN 2,
    and implements the training step and training loop.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # StyleGAN2 discriminator
    discriminator: Discriminator
    # StyleGAN2 generator
    generator: Generator
    # Mapping network
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use Wasserstein loss.
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # Gradient penalty regularization loss
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # Path length penalty
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & Discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (set from the generator in `init`,
    # so it depends on the image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of calculating the regularization losses every step, the paper
    # proposes lazy regularization where the regularization terms are calculated
    # once in a while. This improves the training efficiency a lot.

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState

    # Dataset path. We trained this on the CelebA-HQ dataset;
    # save the images inside the `stylegan2` folder of the labml data path
    # (`data/stylegan2`), which is what is read here.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize

        Creates the data loader, models, losses and optimizers.
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous cyclic loader
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and gets $w$ from the mapping network.

        We also apply style mixing sometimes, where we generate two latent variables
        $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and $w_2$ to the blocks after.

        Returns the per-block $w$ stacked along a new leading dimension of size
        `n_gen_blocks` (the remaining dimensions are the mapping network's output).
        """

        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point in `[0, n_gen_blocks)`
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        Generates noise for each generator block.

        Returns a list with one `(n1, n2)` pair per block. `n1` is `None` for the
        first block; otherwise each tensor has shape
        `[batch_size, 1, resolution, resolution]`, where the resolution starts at
        $4$ and doubles with every block.
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        Generates a batch of images using the generator.
        Returns the images and the $w$ used to generate them.
        """

        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step

        * `idx` is the zero-based index of the current training step.

        Performs one discriminator update followed by one generator and
        mapping-network update, applying the lazy regularizers on their intervals,
        and periodically logs images/parameters and saves checkpoints.
        """

        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, _ = self.generate_images(self.batch_size)
                # Discriminator classification for generated images.
                # Detached so the discriminator loss does not reach the generator.
                fake_output = self.discriminator(generated_images.detach())

                # Get real images from the data loader
                real_images = next(self.loader).to(self.device)
                # We need to calculate gradients w.r.t. real images for gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()
                # Discriminator classification for real images
                real_output = self.discriminator(real_images)

                # Get discriminator loss
                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

                # Add gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    # Calculate and log gradient penalty
                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)
                    # Multiply by coefficient and add gradient penalty. It is also
                    # scaled by the interval to compensate for computing it lazily.
                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                # Compute gradients
                disc_loss.backward()

                # Log discriminator loss
                tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images (6 generated + 3 real). Detached so the logged
        # tensor does not hold on to the autograd graph.
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0).detach())
        # Save model checkpoints. Previously this was a no-op (`pass`), so
        # `save_checkpoint_interval` had no effect and the models registered with
        # `experiment.add_pytorch_models` were never checkpointed.
        if (idx + 1) % self.save_checkpoint_interval == 0:
            # Save checkpoint
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model

        Runs `training_steps` training steps.
        """

        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new console line whenever images are logged
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()
def main():
    """
    ### Train StyleGAN2

    Creates the experiment, applies configuration overrides, builds the
    components, registers the models for saving/loading and runs training.
    """
    # Register a new experiment
    experiment.create(name='stylegan2')

    # Build the configuration object and apply the overrides for this run
    conf = Configs()
    overrides = {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    }
    experiment.configs(conf, overrides)

    # Construct the data loader, models, losses and optimizers
    conf.init()

    # Register the models so the experiment can save and load them
    experiment.add_pytorch_models(mapping_network=conf.mapping_network,
                                  generator=conf.generator,
                                  discriminator=conf.discriminator)

    # Run the training loop inside the experiment context
    with experiment.start():
        conf.train()


if __name__ == '__main__':
    main()