This is the training code for StyleGAN 2 model.
These are images generated after training for about 80K steps.
Our implementation is a minimalistic StyleGAN 2 model training code. Only single-GPU training is supported, to keep the implementation simple. We kept it under 500 lines of code, including the training loop.
Without DDP (distributed data parallel) and multi-gpu training it will not be possible to train the model for large resolutions (128+). If you want training code with fp16 and DDP take a look at lucidrains/stylegan2-pytorch.
We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the `data/stylegan2` folder.
31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader
class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    Loads all `.jpg` images found (recursively) under a folder and returns
    them as resized PyTorch tensors.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` is the path to the folder containing the images
        * `image_size` is the size to resize images to
        """
        super().__init__()

        # Get the paths of all jpg files (searched recursively)
        self.paths = list(Path(path).glob('**/*.jpg'))

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image
            torchvision.transforms.Resize(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image"""
        path = self.paths[index]
        # Convert to RGB so grayscale/palette/CMYK JPEGs all produce a
        # consistent 3-channel tensor; otherwise `ToTensor` yields tensors
        # with mismatched channel counts and batch collation fails.
        img = Image.open(path).convert('RGB')
        return self.transform(img)
class Configs(BaseConfigs):
    """
    ## Configurations

    Training configuration and training loop for StyleGAN 2.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # StyleGAN 2 discriminator
    discriminator: Discriminator
    # StyleGAN 2 generator
    generator: Generator
    # Mapping network that maps $z \rightarrow w$
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions. We use Wasserstein loss.
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # Gradient penalty regularization loss
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # Path length penalty regularization loss
    path_length_penalty: PathLengthPenalty

    # Data loader (an infinite iterator over batches)
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate (lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on.
    # Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for the Adam optimizers
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated from image resolution)
    n_gen_blocks: int

    # Instead of calculating the regularization losses every step, the paper
    # proposes *lazy regularization* where the regularization terms are
    # calculated only once in a while. This improves training efficiency a lot.
    #
    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState
    # Whether to log model layer outputs
    log_layer_outputs: bool = False

    # Dataset folder. We trained this on CelebA-HQ; save the images inside
    # the `data/stylegan2` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        Create the dataset, models, losses and optimizers.
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous cyclic loader
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of the image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Add model hooks to monitor layer outputs
        if self.log_layer_outputs:
            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
            hook_model_outputs(self.mode, self.generator, 'generator')
            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        Sample $z$ randomly and get $w$ from the mapping network.

        Sometimes (with probability `style_mixing_prob`) we apply *style
        mixing*: generate two latent vectors $z_1$, $z_2$ with corresponding
        $w_1$, $w_2$, sample a random cross-over point and feed $w_1$ to the
        generator blocks before the cross-over point and $w_2$ to the blocks
        after it.
        """
        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            # (NOTE: keep this sampling order — it fixes the RNG stream)
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        Generate noise tensors for each generator block.

        Returns a list of `(n1, n2)` pairs — one per block — where `n1` is
        `None` for the first block (it has only one convolution).
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4 \times 4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ the resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        Generate a batch of images; returns `(images, w)` so that the caller
        can reuse $w$ for the path length penalty.
        """
        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        Run a single training step: train the discriminator, then the
        generator (and mapping network), with lazy regularization.
        """
        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Update `mode`. Set whether to log activation
                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
                    # Sample images from generator
                    generated_images, _ = self.generate_images(self.batch_size)
                    # Discriminator classification for generated images
                    # (detached — we don't back-propagate into the generator here)
                    fake_output = self.discriminator(generated_images.detach())

                    # Get real images from the data loader
                    real_images = next(self.loader).to(self.device)
                    # We need gradients w.r.t. real images for gradient penalty
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        real_images.requires_grad_()
                    # Discriminator classification for real images
                    real_output = self.discriminator(real_images)

                    # Get discriminator loss
                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                    disc_loss = real_loss + fake_loss

                    # Add gradient penalty (lazy regularization)
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        # Calculate and log gradient penalty
                        gp = self.gradient_penalty(real_images, real_output)
                        tracker.add('loss.gp', gp)
                        # Multiply by coefficient; scale by the interval to
                        # compensate for only computing it occasionally
                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                    # Compute gradients
                    disc_loss.backward()

                    # Log discriminator loss
                    tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty (lazy, and only after warm-up)
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if nan
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Training loop
        """
        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new tracker line after logging generated images
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()
def main():
    """
    ### Train StyleGAN 2
    """
    # Create an experiment
    experiment.create(name='stylegan2')
    # Create configurations object
    configs = Configs()

    # Set configurations and override some
    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

    # Start the experiment
    with experiment.start():
        # Run the training loop
        configs.train()


if __name__ == '__main__':
    main()