PPO Experiment with Atari Breakout

This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game using OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.


from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Model

class Model(nn.Module):
    def __init__(self):
        super().__init__()

The first convolution layer takes an 84x84 frame and produces a 20x20 frame

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

The second convolution layer takes a 20x20 frame and produces a 9x9 frame

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

The third convolution layer takes a 9x9 frame and produces a 7x7 frame

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
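The frame sizes above follow from the standard no-padding convolution output formula, output = (input - kernel) / stride + 1. A quick sanity check (illustrative only, not part of the original code):

def conv_out(size: int, kernel: int, stride: int) -> int:
    # output size of a convolution with no padding
    return (size - kernel) // stride + 1

assert conv_out(84, kernel=8, stride=4) == 20  # conv1: 84x84 -> 20x20
assert conv_out(20, kernel=4, stride=2) == 9   # conv2: 20x20 -> 9x9
assert conv_out(9, kernel=3, stride=1) == 7    # conv3: 9x9 -> 7x7
# so the flattened input to the linear layer has 7 * 7 * 64 = 3136 features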

A fully connected layer to get logits for $\pi$

        self.pi_logits = nn.Linear(in_features=512, out_features=4)

A fully connected layer to get value function

        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()

    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
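As a quick illustration of the interface (an illustrative check, not part of the original code), the forward pass maps a batch of stacked frames to a Categorical distribution over Breakout's four actions and one value estimate per sample:

model = Model().to(device)
frames = torch.zeros((2, 4, 84, 84), device=device)  # a batch of two stacked-frame observations
pi, value = model(frames)
assert pi.logits.shape == (2, 4)  # action distribution over the 4 Breakout actions
assert value.shape == (2,)        # one value estimate per observation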

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

Trainer

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):

Configurations

number of updates

        self.updates = updates

number of epochs to train the model with sampled data

        self.epochs = epochs

number of worker processes

        self.n_workers = n_workers

number of steps to run on each process for a single update

        self.worker_steps = worker_steps

number of mini batches

        self.batches = batches

total number of samples for a single update

        self.batch_size = self.n_workers * self.worker_steps

size of a mini batch

        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
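For example, with the default configuration below (8 workers, 128 worker steps, 4 mini batches), batch_size is 8 × 128 = 1024 samples per update and mini_batch_size is 1024 / 4 = 256.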

Value loss coefficient

        self.value_loss_coef = value_loss_coef

Entropy bonus coefficient

        self.entropy_bonus_coef = entropy_bonus_coef

Clipping range

        self.clip_range = clip_range

Learning rate

        self.learning_rate = learning_rate

Initialize

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
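Each Worker from labml_nn.rl.game runs its game in a separate process and exchanges the ("reset", None), ("step", action) and ("close", None) messages used throughout this trainer. The following is a minimal sketch of that pattern, not the library's actual implementation; the Game class here is a hypothetical stand-in for the real Breakout wrapper:

import multiprocessing
import multiprocessing.connection


class Game:
    # Hypothetical stand-in for the Breakout wrapper in labml_nn.rl.game:
    # it returns stacked 4x84x84 uint8 observations and (obs, reward, done, info) on step
    def __init__(self, seed: int):
        self.seed = seed

    def reset(self):
        return np.zeros((4, 84, 84), dtype=np.uint8)

    def step(self, action):
        return np.zeros((4, 84, 84), dtype=np.uint8), 0.0, False, None


def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    # runs in the child process: create the game and serve commands from the trainer
    game = Game(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break


class WorkerSketch:
    # the trainer talks to `self.child`; the other end of the pipe goes to the child process
    def __init__(self, seed: int):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()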

model

        self.model = Model().to(device)

optimizer

        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$

        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
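GAE is the generalized advantage estimator from labml_nn.rl.ppo.gae. As a rough sketch of the computation it performs (standard GAE, written here for illustration rather than copied from the library), it runs a backward recursion over the sampled trajectories:

def gae_sketch(done: np.ndarray, rewards: np.ndarray, values: np.ndarray,
               gamma: float = 0.99, lambda_: float = 0.95) -> np.ndarray:
    # `rewards` and `done` have shape [n_workers, worker_steps]; `values` has one extra time step
    n_workers, worker_steps = rewards.shape
    advantages = np.zeros((n_workers, worker_steps), dtype=np.float32)
    last_advantage = np.zeros(n_workers, dtype=np.float32)
    for t in reversed(range(worker_steps)):
        mask = 1.0 - done[:, t]  # stop bootstrapping across episode boundaries
        delta = rewards[:, t] + gamma * values[:, t + 1] * mask - values[:, t]
        last_advantage = delta + gamma * lambda_ * last_advantage * mask
        advantages[:, t] = last_advantage
    return advantages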

PPO Loss

        self.ppo_loss = ClippedPPOLoss()

Value Loss

        self.value_loss = ClippedValueFunctionLoss()
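ClippedPPOLoss and ClippedValueFunctionLoss are the annotated implementations in labml_nn.rl.ppo of the clipped surrogate objective and the clipped value-function loss from the PPO paper. A condensed sketch of what such losses compute (for illustration; see the library for the actual classes):

def clipped_ppo_loss_sketch(log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                            advantage: torch.Tensor, clip: float) -> torch.Tensor:
    # probability ratio r_t = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t)
    ratio = torch.exp(log_pi - sampled_log_pi)
    clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
    # negate the clipped surrogate so that minimizing the loss maximizes the objective
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()


def clipped_value_loss_sketch(value: torch.Tensor, sampled_value: torch.Tensor,
                              sampled_return: torch.Tensor, clip: float) -> torch.Tensor:
    # keep the new value estimate within `clip` of the value estimate used during sampling
    clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
    return 0.5 * torch.max((value - sampled_return) ** 2,
                           (clipped_value - sampled_return) ** 2).mean()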

Sample data with current policy

    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        # `np.bool` was removed in recent NumPy versions; use the builtin `bool`
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():

sample worker_steps from each worker

            for t in range(self.worker_steps):

self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action

                obs[:, t] = self.obs

sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers

                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):

get results after executing the actions

                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

collect episode info, which is available if an episode finished; this includes total reward and length of the episode - look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Get the value of the state after the final step

            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

calculate advantages

        advantages = self.gae(done, rewards, values)

        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

samples are currently in [workers, time_step] layout; we flatten them for training

        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
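For example, with 8 workers and 128 steps, obs goes from shape [8, 128, 4, 84, 84] to [1024, 4, 84, 84], and the per-step scalars (actions, values, log_pis, advantages) go from [8, 128] to [1024].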

Train the model based on samples

    def train(self, samples: Dict[str, torch.Tensor]):

It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve this.

        for _ in range(self.epochs()):

shuffle for each epoch

            indexes = torch.randperm(self.batch_size)

for each mini batch

            for start in range(0, self.batch_size, self.mini_batch_size):

get mini batch

                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

train

                loss = self._calc_loss(mini_batch)

Set learning rate

                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

                self.optimizer.zero_grad()

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

                self.optimizer.step()

Normalize advantage function

    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

Calculate total loss

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

$R_t$, the returns sampled from $\pi_{\theta_{OLD}}$

        sampled_return = samples['values'] + samples['advantages']

$\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$, where $\hat{A}_t$ are the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sampling function above for the calculation of $\hat{A}_t$.

        sampled_normalized_advantage = self._normalize(samples['advantages'])

Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state

        pi, value = self.model(samples['obs'])

$\log \pi_\theta(a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$

        log_pi = pi.log_prob(samples['actions'])

Calculate policy loss

        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())

Calculate Entropy Bonus

        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate value function loss

        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
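The total loss combines the three terms with the value-loss and entropy-bonus coefficients: $\mathcal{L}(\theta) = \mathcal{L}^{CLIP}(\theta) + c_1 \mathcal{L}^{VF}(\theta) - c_2 \mathcal{L}^{EB}(\theta)$, where $c_1$ is value_loss_coef and $c_2$ is entropy_bonus_coef.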

        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

for monitoring

        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
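This is the commonly used cheap estimate $\frac{1}{2}\,\mathbb{E}\big[(\log \pi_{\theta_{OLD}}(a_t|s_t) - \log \pi_\theta(a_t|s_t))^2\big]$ of the KL divergence between the old and new policies; it is tracked only for monitoring and does not enter the loss.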

Add to tracker

        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

Run training loop

    def run_training_loop(self):

last 100 episode information

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):

sample with current policy

            samples = self.sample()

train the model

            self.train(samples)

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment

    experiment.create(name='ppo')

Configurations

    configs = {

Number of updates

        'updates': 10000,

⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.

        'epochs': IntDynamicHyperParam(8),

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 128,

Number of mini batches

        'batches': 4,

⚙️ Value loss coefficient. You can change this while the experiment is running.

        'value_loss_coef': FloatDynamicHyperParam(0.5),

⚙️ Entropy bonus coefficient. You can change this while the experiment is running.

        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),

⚙️ Clip range.

        'clip_range': FloatDynamicHyperParam(0.1),

⚙️ Learning rate. You can change this while the experiment is running.

        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

Initialize the trainer

    m = Trainer(**configs)

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()

Run it

if __name__ == "__main__":
    main()