PPO Experiment with Atari Breakout

This experiment trains a Proximal Policy Optimization (PPO) agent to play the Atari Breakout game with OpenAI Gym. It runs the game environments in multiple processes to sample efficiently.


from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Model

class Model(Module):
    def __init__(self):
        super().__init__()

The first convolution layer takes an 84x84 frame and produces a 20x20 frame

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

The second convolution layer takes a 20x20 frame and produces a 9x9 frame

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

The third convolution layer takes a 9x9 frame and produces a 7x7 frame

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

A fully connected layer to get logits for $\pi$

        self.pi_logits = nn.Linear(in_features=512, out_features=4)

A fully connected layer to get the value function

        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()

    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
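As a quick check of the frame sizes quoted above, the standard no-padding convolution output-size formula, output = floor((input - kernel) / stride) + 1, reproduces the 84 → 20 → 9 → 7 progression. This is a standalone sketch, not part of the experiment code:

def conv_out(size: int, kernel: int, stride: int) -> int:
    # output size of a convolution with no padding
    return (size - kernel) // stride + 1

size = conv_out(84, 8, 4)    # 20
size = conv_out(size, 4, 2)  # 9
size = conv_out(size, 3, 1)  # 7
assert size * size * 64 == 7 * 7 * 64  # matches in_features of self.lin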

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

Trainer

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):

Configurations

number of updates

        self.updates = updates

number of epochs to train the model with sampled data

        self.epochs = epochs

number of worker processes

        self.n_workers = n_workers

number of steps to run on each process for a single update

        self.worker_steps = worker_steps

number of mini batches

        self.batches = batches

total number of samples for a single update

        self.batch_size = self.n_workers * self.worker_steps

size of a mini batch

        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
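With the default configuration used in main below (8 workers, 128 steps, 4 mini batches) this works out as follows. A worked example, not extra experiment code:

n_workers, worker_steps, batches = 8, 128, 4
batch_size = n_workers * worker_steps    # 1024 samples per update
mini_batch_size = batch_size // batches  # 256 samples per mini batch
assert batch_size % batches == 0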

Value loss coefficient

        self.value_loss_coef = value_loss_coef

Entropy bonus coefficient

        self.entropy_bonus_coef = entropy_bonus_coef

Clipping range

        self.clip_range = clip_range

Learning rate

        self.learning_rate = learning_rate

Initialize

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
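The workers communicate over pipes with simple (command, data) messages. Below is a rough sketch of the loop a child process would run to handle the commands used in this file; the actual implementation lives in labml_nn.rl.game, and game here is a hypothetical environment wrapper:

def worker_process_sketch(remote, game):
    # illustrative only: handle the "reset", "step" and "close" commands sent from the trainer
    while True:
        cmd, data = remote.recv()
        if cmd == "reset":
            remote.send(game.reset())      # initial observation
        elif cmd == "step":
            remote.send(game.step(data))   # (obs, reward, done, info)
        elif cmd == "close":
            remote.close()
            break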

model

        self.model = Model().to(device)

optimizer

        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$

        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)

PPO Loss

        self.ppo_loss = ClippedPPOLoss()

Value Loss

        self.value_loss = ClippedValueFunctionLoss()

Sample data with current policy

    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():

sample worker_steps from each worker

            for t in range(self.worker_steps):

self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action

                obs[:, t] = self.obs

sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers

                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):

get results after executing the actions

                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

collect episode info, which is available if an episode finished; this includes total reward and length of the episode - look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Get the value of the state after the final step

            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

calculate advantages

        advantages = self.gae(done, rewards, values)
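A minimal sketch of what this call computes, assuming the standard generalized advantage estimation recurrence; the actual implementation is in labml_nn.rl.ppo.gae:

def gae_sketch(done: np.ndarray, rewards: np.ndarray, values: np.ndarray,
               gamma: float = 0.99, lambda_: float = 0.95) -> np.ndarray:
    # illustrative only: done and rewards are [n_workers, worker_steps],
    # values is [n_workers, worker_steps + 1]
    n_workers, worker_steps = rewards.shape
    advantages = np.zeros((n_workers, worker_steps), dtype=np.float32)
    last_advantage = 0.
    # iterate backwards through time
    for t in reversed(range(worker_steps)):
        mask = 1.0 - done[:, t].astype(np.float32)
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[:, t] + gamma * values[:, t + 1] * mask - values[:, t]
        # A_t = delta_t + gamma * lambda * A_{t+1}
        last_advantage = delta + gamma * lambda_ * mask * last_advantage
        advantages[:, t] = last_advantage
    return advantages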

        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

samples are currently in [workers, time_step] tables; we should flatten them for training

        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
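For example, with 8 workers and 128 steps the observations go from shape [8, 128, 4, 84, 84] to [1024, 4, 84, 84]. A shape-only illustration, not part of the experiment code:

example = np.zeros((8, 128, 4, 84, 84), dtype=np.uint8)
flat = example.reshape(example.shape[0] * example.shape[1], *example.shape[2:])
assert flat.shape == (1024, 4, 84, 84)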

Train the model based on samples

    def train(self, samples: Dict[str, torch.Tensor]):

It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve this.

        for _ in range(self.epochs()):

shuffle for each epoch

            indexes = torch.randperm(self.batch_size)

for each mini batch

            for start in range(0, self.batch_size, self.mini_batch_size):

get mini batch

                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

train

                loss = self._calc_loss(mini_batch)

Set learning rate

                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

                self.optimizer.zero_grad()

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

                self.optimizer.step()

Normalize advantage function

    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

Calculate total loss

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

$R_t$ returns sampled from $\pi_{\theta_{OLD}}$, computed as the sampled values plus the advantages

        sampled_return = samples['values'] + samples['advantages']

$\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$, where $\hat{A_t}$ is the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sampling function in the Trainer class above for the calculation of $\hat{A_t}$.

        sampled_normalized_advantage = self._normalize(samples['advantages'])

Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state $s_t$

        pi, value = self.model(samples['obs'])

$a_t$ are actions sampled from $\pi_{\theta_{OLD}}$

        log_pi = pi.log_prob(samples['actions'])

Calculate policy loss

        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
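ClippedPPOLoss is implemented in labml_nn.rl.ppo; a minimal sketch of the clipped surrogate objective it computes (omitting the clip-fraction statistic it also tracks) would look like this:

def clipped_ppo_loss_sketch(log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                            advantage: torch.Tensor, clip: float) -> torch.Tensor:
    # ratio between the new policy and the policy used for sampling
    ratio = torch.exp(log_pi - sampled_log_pi)
    clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
    # take the pessimistic (minimum) surrogate and negate it to get a loss
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()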

Calculate Entropy Bonus

        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate value function loss

        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
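ClippedValueFunctionLoss is also implemented in labml_nn.rl.ppo; a minimal sketch, assuming the usual formulation that clips the new value prediction around the sampled (old) value:

def clipped_value_loss_sketch(value: torch.Tensor, sampled_value: torch.Tensor,
                              sampled_return: torch.Tensor, clip: float) -> torch.Tensor:
    # keep the new value prediction within `clip` of the old value
    clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
    # take the worse (maximum) of the clipped and unclipped squared errors
    vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
    return 0.5 * vf_loss.mean()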

        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

for monitoring

        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

Add to tracker

        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

Run training loop

    def run_training_loop(self):

last 100 episode information

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):

sample with current policy

            samples = self.sample()

train the model

            self.train(samples)

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment

    experiment.create(name='ppo')

Configurations

    configs = {

Number of updates

        'updates': 10000,

⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.

        'epochs': IntDynamicHyperParam(8),

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 128,

Number of mini batches

        'batches': 4,

⚙️ Value loss coefficient. You can change this while the experiment is running.

        'value_loss_coef': FloatDynamicHyperParam(0.5),

⚙️ Entropy bonus coefficient. You can change this while the experiment is running.

        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),

⚙️ Clip range. You can change this while the experiment is running.

        'clip_range': FloatDynamicHyperParam(0.1),

⚙️ Learning rate. You can change this while the experiment is running.

        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

Initialize the trainer

    m = Trainer(**configs)

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()

Run it

if __name__ == "__main__":
    main()