PPO Experiment with Atari Breakout

This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game using OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.

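The Worker class imported below (from labml_nn.rl.game) runs one game environment per child process and talks to the trainer over a pipe. Its internals are not shown in this file; the following is a minimal sketch of that pattern under the same message protocol used later here (reset / step / close), assuming the classic (pre-gymnasium) Gym Atari API. It omits the real Game wrapper's preprocessing (gray-scaling, frame skipping, stacking four 84x84 frames) and the episode-info bookkeeping that the trainer reads from info.

import multiprocessing
import multiprocessing.connection

import gym


def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    # Each worker process owns its own environment instance
    # (the real Game wrapper would add observation preprocessing here)
    env = gym.make('BreakoutNoFrameskip-v4')
    env.seed(seed)

    # Serve commands sent from the trainer until asked to close
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            # classic Gym API: returns (observation, reward, done, info)
            remote.send(env.step(data))
        elif cmd == "reset":
            remote.send(env.reset())
        elif cmd == "close":
            remote.close()
            break


class Worker:
    # Runs one environment in a child process and exposes `child`,
    # the trainer-side end of the pipe
    def __init__(self, seed: int):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()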

from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Model

class Model(Module):
    def __init__(self):
        super().__init__()

The first convolution layer takes an 84x84 frame and produces a 20x20 frame

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

The second convolution layer takes a 20x20 frame and produces a 9x9 frame

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

The third convolution layer takes a 9x9 frame and produces a 7x7 frame

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
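As a quick sanity check on these shapes: with no padding, each convolution outputs (size - kernel) // stride + 1 elements per dimension. A small helper (not part of the original code) traces the chain down to the 7 * 7 * 64 = 3136 flattened features fed to lin:

def conv_out(size: int, kernel: int, stride: int) -> int:
    # Output size of a convolution with no padding
    return (size - kernel) // stride + 1


h = conv_out(84, kernel=8, stride=4)  # 20
h = conv_out(h, kernel=4, stride=2)   # 9
h = conv_out(h, kernel=3, stride=1)   # 7
assert h * h * 64 == 7 * 7 * 64 == 3136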

A fully connected layer to get logits for the policy $\pi$

        self.pi_logits = nn.Linear(in_features=512, out_features=4)

A fully connected layer to get the value function

        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()

    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

Trainer

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):

Configurations

number of updates

        self.updates = updates

number of epochs to train the model with sampled data

        self.epochs = epochs

number of worker processes

        self.n_workers = n_workers

number of steps to run on each process for a single update

        self.worker_steps = worker_steps

number of mini batches

        self.batches = batches

total number of samples for a single update

        self.batch_size = self.n_workers * self.worker_steps

size of a mini batch

        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
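With the defaults used in main() below (8 workers, 128 steps, 4 mini batches), this works out to:

batch_size = 8 * 128               # 1024 samples collected per update
mini_batch_size = batch_size // 4  # 256 samples per mini batch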

Value loss coefficient

        self.value_loss_coef = value_loss_coef

Entropy bonus coefficient

        self.entropy_bonus_coef = entropy_bonus_coef

Clipping range

        self.clip_range = clip_range

Learning rate

        self.learning_rate = learning_rate

Initialize

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

model

        self.model = Model().to(device)

optimizer

        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$

        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
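The GAE helper itself is defined in labml_nn.rl.ppo.gae and not shown here. The standalone function below is a minimal sketch of what it computes, assuming the interface used in sample() below: done and rewards arrays of shape [workers, steps], and values of shape [workers, steps + 1] whose last column is the bootstrap value.

def gae_advantages(done: np.ndarray, rewards: np.ndarray, values: np.ndarray,
                   gamma: float = 0.99, lambda_: float = 0.95) -> np.ndarray:
    n_workers, worker_steps = rewards.shape
    advantages = np.zeros((n_workers, worker_steps), dtype=np.float32)
    last_advantage = np.zeros(n_workers, dtype=np.float32)

    # Accumulate discounted TD errors backwards through time:
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t = delta_t + gamma * lambda * A_{t+1}
    # with both terms masked to zero at episode boundaries
    for t in reversed(range(worker_steps)):
        mask = 1.0 - done[:, t]
        delta = rewards[:, t] + gamma * values[:, t + 1] * mask - values[:, t]
        last_advantage = delta + gamma * lambda_ * mask * last_advantage
        advantages[:, t] = last_advantage

    return advantages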

PPO Loss

        self.ppo_loss = ClippedPPOLoss()

Value Loss

        self.value_loss = ClippedValueFunctionLoss()

Sample data with current policy

    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        # `np.bool` was removed from recent NumPy versions; the builtin `bool` is equivalent here
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():

sample worker_steps from each worker

            for t in range(self.worker_steps):

self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action

                obs[:, t] = self.obs

sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers

                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):

get results after executing the actions

                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

collect episode info, which is available if an episode finished; this includes total reward and length of the episode - look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Get the value of the state after the final step, used to bootstrap the advantage calculation

            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

calculate advantages

        advantages = self.gae(done, rewards, values)

        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

samples are currently in [workers, time_step] tables; we should flatten them for training

        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
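For example, with the default 8 workers and 128 steps, the observation tensor goes from [8, 128, 4, 84, 84] to [1024, 4, 84, 84], while the per-step scalars become flat vectors of length 1024:

obs = np.zeros((8, 128, 4, 84, 84), dtype=np.uint8)
flat = obs.reshape(obs.shape[0] * obs.shape[1], *obs.shape[2:])
assert flat.shape == (1024, 4, 84, 84)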

Train the model based on samples

    def train(self, samples: Dict[str, torch.Tensor]):

It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve it.

        for _ in range(self.epochs()):

shuffle for each epoch

            indexes = torch.randperm(self.batch_size)

for each mini batch

            for start in range(0, self.batch_size, self.mini_batch_size):

get mini batch

                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

train

                loss = self._calc_loss(mini_batch)

Set learning rate

                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

                self.optimizer.zero_grad()

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

                self.optimizer.step()

Normalize advantage function

    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)
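In equation form, with $\mu$ and $\sigma$ the mean and standard deviation over the batch:

$$\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t) + 10^{-8}}$$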

Calculate total loss

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

$\hat{R}_t$, the returns sampled from $\pi_{\theta_{OLD}}$, computed as the sampled value plus the sampled advantage

        sampled_return = samples['values'] + samples['advantages']

$\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$, where $\hat{A}_t$ are the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sample function above for the calculation of $\hat{A}_t$.

        sampled_normalized_advantage = self._normalize(samples['advantages'])

Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state

        pi, value = self.model(samples['obs'])

$\log \pi_\theta(a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$

        log_pi = pi.log_prob(samples['actions'])

Calculate policy loss

        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
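For reference, the clipped surrogate objective from the PPO paper, which ClippedPPOLoss is based on, is

$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t \Big[ \min\big(r_t(\theta)\,\bar{A}_t,\ \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\,\bar{A}_t\big) \Big], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{OLD}}(a_t \mid s_t)}$$

This objective is maximized, so the policy_loss used here is its negation, which is why -policy_loss is tracked as policy_reward below.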

Calculate Entropy Bonus

        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate value function loss

        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)
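That is, the objective being minimized is

$$\mathcal{L} = -\mathcal{L}^{CLIP}(\theta) + c_1\,\mathcal{L}^{VF}(\theta) - c_2\,\mathcal{H}\big[\pi_\theta\big]$$

where $c_1$ is value_loss_coef and $c_2$ is entropy_bonus_coef.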

Approximate KL divergence, for monitoring

        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
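This is the simple quadratic approximation

$$D_{KL}\big(\pi_{\theta_{OLD}} \,\|\, \pi_\theta\big) \approx \frac{1}{2}\,\mathbb{E}_t\Big[\big(\log \pi_{\theta_{OLD}}(a_t \mid s_t) - \log \pi_\theta(a_t \mid s_t)\big)^2\Big]$$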

Add to tracker

        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

Run training loop

    def run_training_loop(self):

Track statistics from the last 100 episodes

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):

sample with current policy

            samples = self.sample()

train the model

            self.train(samples)

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment

    experiment.create(name='ppo')

Configurations

    configs = {

Number of updates

        'updates': 10000,

⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.

        'epochs': IntDynamicHyperParam(8),

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 128,

Number of mini batches

        'batches': 4,

⚙️ Value loss coefficient. You can change this while the experiment is running.

        'value_loss_coef': FloatDynamicHyperParam(0.5),

⚙️ Entropy bonus coefficient. You can change this while the experiment is running.

        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),

⚙️ Clip range. You can change this while the experiment is running.

        'clip_range': FloatDynamicHyperParam(0.1),

⚙️ Learning rate. You can change this while the experiment is running.

        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

Initialize the trainer

    m = Trainer(**configs)

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()

Run it

if __name__ == "__main__":
    main()