DQN Experiment with Atari Breakout

This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.

13import numpy as np
14import torch
15
16from labml import tracker, experiment, logger, monit
17from labml_helpers.schedule import Piecewise
18from labml_nn.rl.dqn import QFuncLoss
19from labml_nn.rl.dqn.model import Model
20from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
21from labml_nn.rl.game import Worker

Select device

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """
    Convert an observation array to a `float32` tensor on `device`,
    dividing by 255 so that values in [0, 255] land in [0, 1].
    """
    as_float = torch.tensor(obs, dtype=torch.float32, device=device)
    return as_float / 255.

Trainer

class Trainer:
    """
    ## Trainer

    Holds the hyper-parameters, the sampling/training model, the target model,
    the prioritized replay buffer and the game workers, and implements
    sampling, training and the outer update loop.
    """

    def __init__(self):
        # #### Configurations

        # number of workers
        self.n_workers = 8
        # steps sampled on each update
        self.worker_steps = 4
        # number of training iterations
        self.train_epochs = 8

        # number of updates
        self.updates = 1_000_000
        # size of mini batch for training
        self.mini_batch_size = 32

        # exploration as a function of updates
        exploration_schedule = [
            (0, 1.0),
            (25_000, 0.1),
            (self.updates / 2, 0.01),
        ]
        self.exploration_coefficient = Piecewise(exploration_schedule, outside_value=0.01)

        # update target network every 250 updates
        self.update_target_model = 250

        # $\beta$ for replay buffer as a function of updates
        beta_schedule = [
            (0, 0.4),
            (self.updates, 1),
        ]
        self.prioritized_replay_beta = Piecewise(beta_schedule, outside_value=1)

        # Replay buffer with $\alpha = 0.6$.
        # Capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

        # Model for sampling and training
        self.model = Model().to(device)
        # target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)

        # create workers
        # (each one receives a distinct argument 47, 48, ...; presumably a seed - confirm against `Worker`)
        self.workers = [Worker(value) for value in range(47, 47 + self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        # ask every worker to reset first, then collect the replies,
        # so that no worker has to wait for another one's reply to be read
        for worker in self.workers:
            worker.child.send(("reset", None))
        for slot, worker in zip(self.obs, self.workers):
            # `slot` is a view into `self.obs`, so this writes the row in place
            slot[...] = worker.child.recv()

        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        r"""
        #### $\epsilon$-greedy Sampling

        When sampling actions we use an $\epsilon$-greedy strategy, where we
        take a greedy action with probability $1 - \epsilon$ and
        take a random action with probability $\epsilon$.
        We refer to $\epsilon$ as `exploration_coefficient`.

        Returns the chosen actions as a NumPy array, one entry per row of `q_value`.
        """
        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with highest Q-value. This is the greedy action.
            best = torch.argmax(q_value, dim=-1)
            n_actions = q_value.shape[-1]
            # Uniformly sample an action
            uniform = torch.randint(n_actions, best.shape, device=q_value.device)
            # Whether to choose the greedy action or the random action
            explore = torch.rand(best.shape, device=q_value.device) < exploration_coefficient
            # Pick the action based on `explore`
            chosen = torch.where(explore, uniform, best)
            return chosen.cpu().numpy()

    def sample(self, exploration_coefficient: float):
        """
        ### Sample data

        Run `worker_steps` steps on every worker with the current model and
        add each transition to the replay buffer.
        """
        # This doesn't need gradients
        with torch.no_grad():
            # Sample `worker_steps`
            for _ in range(self.worker_steps):
                # Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)

                # Run sampled actions on each worker
                for action, worker in zip(actions, self.workers):
                    worker.child.send(("step", action))

                # Collect information from each worker
                for idx, worker in enumerate(self.workers):
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # Add transition to replay buffer
                    self.replay_buffer.add(self.obs[idx], actions[idx], reward, next_obs, done)

                    # Update episode information.
                    # Episode info is available if an episode finished;
                    # this includes total reward and length of the episode -
                    # look at `Game` to see how it works.
                    if info:
                        for key in ('reward', 'length'):
                            tracker.add(key, info[key])

                    # update current observation
                    self.obs[idx] = next_obs

    def train(self, beta: float):
        """
        ### Train the model

        Run `train_epochs` gradient steps, each on a mini batch drawn from the
        replay buffer with the given `beta`, and refresh the priorities of the
        sampled transitions.
        """
        for _ in range(self.train_epochs):
            # Sample from priority replay buffer
            batch = self.replay_buffer.sample(self.mini_batch_size, beta)
            # Get the predicted Q-value
            q_value = self.model(obs_to_torch(batch['obs']))

            # Get the Q-values of the next state for Double Q-learning.
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
                double_q_value = self.model(obs_to_torch(batch['next_obs']))
                # Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(obs_to_torch(batch['next_obs']))

            # Builds tensors with the same dtype and device as `q_value`
            as_tensor = q_value.new_tensor
            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(
                q_value,
                as_tensor(batch['action']),
                double_q_value,
                target_q_value,
                as_tensor(batch['done']),
                as_tensor(batch['reward']),
                as_tensor(batch['weights']),
            )

            # Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$
            abs_td = np.abs(td_errors.cpu().numpy())
            new_priorities = abs_td + 1e-6
            # Update replay buffer priorities
            self.replay_buffer.update_priorities(batch['indexes'], new_priorities)

            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            # Update parameters based on gradients
            self.optimizer.step()

    def run_training_loop(self):
        """
        ### Run training loop

        For each of `updates` iterations: sample with the current policy,
        train once the replay buffer is full, periodically copy the model
        into the target network, and save the tracked indicators.
        """
        # Last 100 episode information
        for name in ('reward', 'length'):
            tracker.set_queue(name, 100, True)

        # Copy to target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, exploration fraction
            epsilon = self.exploration_coefficient(update)
            tracker.add('exploration', epsilon)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # Sample with current policy
            self.sample(epsilon)

            # Start training after the buffer is full
            if self.replay_buffer.is_full():
                # Train the model
                self.train(beta)

                # Periodically update target network
                if not update % self.update_target_model:
                    self.target_model.load_state_dict(self.model.state_dict())

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if update % 1_000 == 999:
                logger.log()

    def destroy(self):
        """
        ### Destroy

        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
    """
    Create the `dqn` experiment, run the trainer's training loop inside it,
    and stop the workers afterwards.

    The workers are stopped in a `finally` clause, so they receive the
    "close" message even when the training loop raises or is interrupted;
    the exception itself is still propagated to the caller.
    """
    # Create the experiment
    experiment.create(name='dqn')
    # Initialize the trainer
    m = Trainer()
    try:
        # Run and monitor the experiment
        with experiment.start():
            m.run_training_loop()
    finally:
        # Stop the workers, whether or not training finished normally
        m.destroy()

Run it

# Start the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()