This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game from OpenAI Gym. It runs the game environments in multiple worker processes to sample efficiently; a minimal sketch of such a worker is shown after the imports below.
15from typing import Dict
16
17import numpy as np
18import torch
19from torch import nn
20from torch import optim
21from torch.distributions import Categorical
22
23from labml import monit, tracker, logger, experiment
24from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
25from labml_helpers.module import Module
26from labml_nn.rl.game import Worker
27from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
28from labml_nn.rl.ppo.gae import GAE
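The `Worker` imported above runs each game environment in its own process and talks to the trainer over a pipe. The real implementation in `labml_nn.rl.game` also does Atari preprocessing (grayscale, 84x84 resize, a stack of 4 frames); the simplified sketch below, with a hypothetical environment id and the classic Gym API, only illustrates the `reset`/`step`/`close` message protocol used later in this file.

import multiprocessing

import gym


def worker_process(remote, seed: int):
    # Create and seed the environment inside the child process.
    # Classic Gym API assumed: `step` returns `(obs, reward, done, info)`.
    env = gym.make('BreakoutNoFrameskip-v4')
    env.seed(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            remote.send(env.step(data))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'close':
            remote.close()
            break


class Worker:
    def __init__(self, seed: int):
        # The trainer keeps `self.child`; the other end goes to the child process
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()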
Select device
31if torch.cuda.is_available():
32 device = torch.device("cuda:0")
33else:
34 device = torch.device("cpu")
37class Model(Module):
42 def __init__(self):
43 super().__init__()
The first convolution layer takes an 84x84 frame and produces a 20x20 frame
47 self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
The second convolution layer takes a 20x20 frame and produces a 9x9 frame
51 self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
The third convolution layer takes a 9x9 frame and produces a 7x7 frame
55 self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
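These frame sizes follow from the usual convolution arithmetic with no padding, $\text{out} = \lfloor (\text{in} - \text{kernel}) / \text{stride} \rfloor + 1$: $\lfloor (84 - 8)/4 \rfloor + 1 = 20$, $\lfloor (20 - 4)/2 \rfloor + 1 = 9$, and $\lfloor (9 - 3)/1 \rfloor + 1 = 7$.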
A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features
60 self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
A fully connected layer to get logits for the policy $\pi$
63 self.pi_logits = nn.Linear(in_features=512, out_features=4)
A fully connected layer to get the value function
66 self.value = nn.Linear(in_features=512, out_features=1)
69 self.activation = nn.ReLU()
71 def forward(self, obs: torch.Tensor):
72 h = self.activation(self.conv1(obs))
73 h = self.activation(self.conv2(h))
74 h = self.activation(self.conv3(h))
75 h = h.reshape((-1, 7 * 7 * 64))
76
77 h = self.activation(self.lin(h))
78
79 pi = Categorical(logits=self.pi_logits(h))
80 value = self.value(h).reshape(-1)
81
82 return pi, value
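As a quick sanity check of the shapes above (a hypothetical snippet, not part of the experiment), a small batch of stacked frames can be pushed through the model:

# Hypothetical shape check: a batch of 2 observations, each 4 stacked 84x84 frames
model = Model()
pi, value = model(torch.zeros(2, 4, 84, 84))
assert pi.logits.shape == (2, 4)  # logits over the 4 Breakout actions
assert value.shape == (2,)        # one value estimate per observation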
Scale observations from [0, 255] to [0, 1]
85def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
87 return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
90class Trainer:
95 def __init__(self, *,
96 updates: int, epochs: IntDynamicHyperParam,
97 n_workers: int, worker_steps: int, batches: int,
98 value_loss_coef: FloatDynamicHyperParam,
99 entropy_bonus_coef: FloatDynamicHyperParam,
100 clip_range: FloatDynamicHyperParam,
101 learning_rate: FloatDynamicHyperParam,
102 ):
number of updates
106 self.updates = updates
number of epochs to train the model with sampled data
108 self.epochs = epochs
number of worker processes
110 self.n_workers = n_workers
number of steps to run on each process for a single update
112 self.worker_steps = worker_steps
number of mini batches
114 self.batches = batches
total number of samples for a single update
116 self.batch_size = self.n_workers * self.worker_steps
size of a mini batch
118 self.mini_batch_size = self.batch_size // self.batches
119 assert (self.batch_size % self.batches == 0)
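With the default configuration in `main()` below (8 workers, 128 steps per worker, 4 mini batches) this gives $\text{batch\_size} = 8 \times 128 = 1024$ samples per update and a mini batch size of $1024 / 4 = 256$.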
Value loss coefficient
122 self.value_loss_coef = value_loss_coef
Entropy bonus coefficient
124 self.entropy_bonus_coef = entropy_bonus_coef
Clipping range
127 self.clip_range = clip_range
Learning rate
129 self.learning_rate = learning_rate
create workers
134 self.workers = [Worker(47 + i) for i in range(self.n_workers)]
initialize tensors for observations
137 self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
138 for worker in self.workers:
139 worker.child.send(("reset", None))
140 for i, worker in enumerate(self.workers):
141 self.obs[i] = worker.child.recv()
model
144 self.model = Model().to(device)
optimizer
147 self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)
GAE with $\gamma = 0.99$ and $\lambda = 0.95$
150 self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
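For reference, the truncated generalized advantage estimate over a rollout of length $T$ is
$$\hat{A}_t = \delta_t + (\gamma\lambda)\,\delta_{t+1} + \cdots + (\gamma\lambda)^{T-t-1}\,\delta_{T-1}, \qquad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t),$$
with terms cut off at episode boundaries using the `done` flags; here $\gamma = 0.99$ and $\lambda = 0.95$.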
PPO Loss
153 self.ppo_loss = ClippedPPOLoss()
Value Loss
156 self.value_loss = ClippedValueFunctionLoss()
158 def sample(self) -> Dict[str, torch.Tensor]:
163 rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
164 actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
165 done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
166 obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
167 log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
168 values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)
169
170 with torch.no_grad():
sample `worker_steps` steps from each worker
172 for t in range(self.worker_steps):
`self.obs` keeps track of the last observation from each worker, which is the input for the model to sample the next action
175 obs[:, t] = self.obs
sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size `n_workers`
178 pi, v = self.model(obs_to_torch(self.obs))
179 values[:, t] = v.cpu().numpy()
180 a = pi.sample()
181 actions[:, t] = a.cpu().numpy()
182 log_pis[:, t] = pi.log_prob(a).cpu().numpy()
run sampled actions on each worker
185 for w, worker in enumerate(self.workers):
186 worker.child.send(("step", actions[w, t]))
187
188 for w, worker in enumerate(self.workers):
get results after executing the actions
190 self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()
collect episode info, which is available if an episode finished; this includes total reward and length of the episode. Look at `Game` to see how it works.
195 if info:
196 tracker.add('reward', info['reward'])
197 tracker.add('length', info['length'])
Get the value of $s_T$, the observation after the final step; GAE uses it to bootstrap the advantage estimates
200 _, v = self.model(obs_to_torch(self.obs))
201 values[:, self.worker_steps] = v.cpu().numpy()
calculate advantages
204 advantages = self.gae(done, rewards, values)
207 samples = {
208 'obs': obs,
209 'actions': actions,
210 'values': values[:, :-1],
211 'log_pis': log_pis,
212 'advantages': advantages
213 }
samples are currently in a `[workers, time_step]` table; we flatten them for training
217 samples_flat = {}
218 for k, v in samples.items():
219 v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
220 if k == 'obs':
221 samples_flat[k] = obs_to_torch(v)
222 else:
223 samples_flat[k] = torch.tensor(v, device=device)
224
225 return samples_flat
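A worked example of the flattening above, using the default `n_workers = 8` and `worker_steps = 128` from `main()` below:

import numpy as np

# (workers, time_step, ...) -> (workers * time_step, ...)
obs = np.zeros((8, 128, 4, 84, 84), dtype=np.uint8)
flat = obs.reshape(obs.shape[0] * obs.shape[1], *obs.shape[2:])
assert flat.shape == (8 * 128, 4, 84, 84)  # 1024 samples per update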
227 def train(self, samples: Dict[str, torch.Tensor]):
It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not increase monotonically over time. Reducing the clipping range might solve this.
237 for _ in range(self.epochs()):
shuffle for each epoch
239 indexes = torch.randperm(self.batch_size)
for each mini batch
242 for start in range(0, self.batch_size, self.mini_batch_size):
get mini batch
244 end = start + self.mini_batch_size
245 mini_batch_indexes = indexes[start: end]
246 mini_batch = {}
247 for k, v in samples.items():
248 mini_batch[k] = v[mini_batch_indexes]
train
251 loss = self._calc_loss(mini_batch)
Set learning rate
254 for pg in self.optimizer.param_groups:
255 pg['lr'] = self.learning_rate()
Zero out the previously calculated gradients
257 self.optimizer.zero_grad()
Calculate gradients
259 loss.backward()
Clip gradients
261 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update parameters based on gradients
263 self.optimizer.step()
265 @staticmethod
266 def _normalize(adv: torch.Tensor):
268 return (adv - adv.mean()) / (adv.std() + 1e-8)
270 def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
$R_t$, the returns sampled from $\pi_{\theta_{OLD}}$, computed as $R_t = V_{\theta_{OLD}}(s_t) + \hat{A}_t$
276 sampled_return = samples['values'] + samples['advantages']
$\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$, where $\hat{A}_t$ are the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sampling function in the `Trainer` class above for the calculation of $\hat{A}_t$.
282 sampled_normalized_advantage = self._normalize(samples['advantages'])
Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as the state
286 pi, value = self.model(samples['obs'])
$\log \pi_\theta(a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$
289 log_pi = pi.log_prob(samples['actions'])
Calculate policy loss
292 policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
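`ClippedPPOLoss` implements the clipped surrogate objective from the PPO paper,
$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\bar{A}_t,\ \mathrm{clip}\big(r_t(\theta),\, 1 - \epsilon,\, 1 + \epsilon\big)\,\bar{A}_t\big)\Big], \qquad r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)},$$
where $\epsilon$ is the clip range. The returned loss is the negative of this objective, which is why the tracker below logs `-policy_loss` as `policy_reward`. The next two lines compute the entropy bonus $\mathcal{H}\big[\pi_\theta\big](s_t)$, averaged over the mini batch.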
298 entropy_bonus = pi.entropy()
299 entropy_bonus = entropy_bonus.mean()
Calculate value function loss
302 value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
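The clipped value function loss keeps the new value estimates close to those recorded at sampling time; in its standard form,
$$\mathcal{L}^{VF}(\theta) = \frac{1}{2}\,\mathbb{E}_t\Big[\max\Big(\big(V_\theta(s_t) - R_t\big)^2,\ \big(\mathrm{clip}\big(V_\theta(s_t),\, V_{\theta_{OLD}}(s_t) - \epsilon,\, V_{\theta_{OLD}}(s_t) + \epsilon\big) - R_t\big)^2\Big)\Big].$$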
307 loss = (policy_loss
308 + self.value_loss_coef() * value_loss
309 - self.entropy_bonus_coef() * entropy_bonus)
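Since `policy_loss` is the negative of the clipped surrogate objective (see above), minimizing this combined loss is equivalent to maximizing the PPO objective $\mathcal{L}^{CLIP}(\theta) - c_1\,\mathcal{L}^{VF}(\theta) + c_2\,\mathcal{H}\big[\pi_\theta\big](s_t)$, with $c_1$ = `value_loss_coef` and $c_2$ = `entropy_bonus_coef`.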
Approximate KL divergence, $\frac{1}{2}\,\mathbb{E}\big[(\log \pi_{\theta_{OLD}}(a_t|s_t) - \log \pi_\theta(a_t|s_t))^2\big]$, for monitoring
312 approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
Add to tracker
315 tracker.add({'policy_reward': -policy_loss,
316 'value_loss': value_loss,
317 'entropy_bonus': entropy_bonus,
318 'kl_div': approx_kl_divergence,
319 'clip_fraction': self.ppo_loss.clip_fraction})
320
321 return loss
323 def run_training_loop(self):
information (reward and length) from the last 100 episodes
329 tracker.set_queue('reward', 100, True)
330 tracker.set_queue('length', 100, True)
331
332 for update in monit.loop(self.updates):
sample with current policy
334 samples = self.sample()
train the model
337 self.train(samples)
Save tracked indicators.
340 tracker.save()
Add a new line to the screen periodically
342 if (update + 1) % 1_000 == 0:
343 logger.log()
345 def destroy(self):
350 for worker in self.workers:
351 worker.child.send(("close", None))
354def main():
Create the experiment
356 experiment.create(name='ppo')
Configurations
358 configs = {
Number of updates
360 'updates': 10000,
⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.
363 'epochs': IntDynamicHyperParam(8),
Number of worker processes
365 'n_workers': 8,
Number of steps to run on each process for a single update
367 'worker_steps': 128,
Number of mini batches
369 'batches': 4,
⚙️ Value loss coefficient. You can change this while the experiment is running.
372 'value_loss_coef': FloatDynamicHyperParam(0.5),
⚙️ Entropy bonus coefficient. You can change this while the experiment is running.
375 'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
⚙️ Clip range. You can change this while the experiment is running.
377 'clip_range': FloatDynamicHyperParam(0.1),
⚙️ Learning rate. You can change this while the experiment is running.
380 'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
381 }
382
383 experiment.configs(configs)
Initialize the trainer
386 m = Trainer(**configs)
Run and monitor the experiment
389 with experiment.start():
390 m.run_training_loop()
Stop the workers
392 m.destroy()
396if __name__ == "__main__":
397 main()