DQN Experiment with Atari Breakout

This experiment trains a Deep Q-Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.


import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
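For example (a minimal usage sketch, with shapes matching the observation buffer initialized below), a batch of stacked frames from 8 workers is a uint8 array of shape (8, 4, 84, 84), and obs_to_torch turns it into a float tensor in [0, 1] on the selected device:

frames = np.zeros((8, 4, 84, 84), dtype=np.uint8)  # 8 workers, 4 stacked 84x84 frames
x = obs_to_torch(frames)
assert x.dtype == torch.float32
assert x.shape == (8, 4, 84, 84) and float(x.max()) <= 1.0  # scaled from [0, 255] to [0, 1]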

Trainer

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):

number of workers

        self.n_workers = n_workers

steps sampled on each update

        self.worker_steps = worker_steps

number of training epochs to run for each update

        self.train_epochs = epochs

number of updates

        self.updates = updates

size of mini batch for training

        self.mini_batch_size = mini_batch_size

update the target network every 250 updates

        self.update_target_model = update_target_model

learning rate

        self.learning_rate = learning_rate

exploration coefficient $\epsilon$ as a function of updates

        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
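labml_helpers' Piecewise is used here as a piecewise-linear schedule: it is called with the current update index and interpolates between the listed (update, value) endpoints, returning outside_value past the last one. A minimal sketch of that assumed behavior, using updates = 1,000,000 from the configuration below (so updates / 2 = 500,000):

def piecewise_linear(points, outside_value, x):
    # Linearly interpolate between consecutive (x, y) endpoints;
    # fall back to `outside_value` beyond the last endpoint.
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= x < x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    return outside_value

points = [(0, 1.0), (25_000, 0.1), (500_000, 0.01)]
piecewise_linear(points, 0.01, 0)        # 1.0
piecewise_linear(points, 0.01, 12_500)   # 0.55
piecewise_linear(points, 0.01, 750_000)  # 0.01 (outside the schedule)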

$\beta$ for the prioritized replay buffer as a function of updates

        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2.

        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
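For reference, prioritized experience replay (Schaul et al., 2015) samples transition $i$ with probability proportional to its priority and corrects the resulting bias with importance-sampling weights:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \qquad w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$$

Here $\alpha = 0.6$, and $\beta$ follows the schedule above, annealing from 0.4 to 1 so the correction reaches full strength by the end of training; as in the paper, the weights are typically normalized by their maximum.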

Model for sampling and training

        self.model = Model().to(device)

target model to get $Q(s'; \theta_i^-)$

        self.target_model = Model().to(device)

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

reset the workers

        for worker in self.workers:
            worker.child.send(("reset", None))

get the initial observations

        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

loss function

        self.loss_func = QFuncLoss(0.99)

optimizer

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

$\epsilon$-greedy Sampling

When sampling actions we use an $\epsilon$-greedy strategy: we take the greedy action with probability $1 - \epsilon$ and a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient .
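That is, the sampled action is

$$a_t = \begin{cases} \arg\max_a Q(s_t, a) & \text{with probability } 1 - \epsilon \\ \text{a uniformly random action} & \text{with probability } \epsilon \end{cases}$$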

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):

Sampling doesn't need gradients

        with torch.no_grad():

Sample the action with highest Q-value. This is the greedy action.

            greedy_action = torch.argmax(q_value, dim=-1)

Uniformly sample an action

            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

Whether to choose the greedy action or the random action

            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

Pick the action based on is_choose_rand

            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

Sample data

    def sample(self, exploration_coefficient: float):

This doesn't need gradients

        with torch.no_grad():

Sample worker_steps

            for t in range(self.worker_steps):

Get Q-values for the current observation

                q_value = self.model(obs_to_torch(self.obs))

Sample actions

                actions = self._sample_action(q_value, exploration_coefficient)

Run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

Collect information from each worker

                for w, worker in enumerate(self.workers):

Get results after executing the actions

                    next_obs, reward, done, info = worker.child.recv()

Add transition to replay buffer

                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

Update episode information. Episode info is available only when an episode finishes; it includes the total reward and the length of the episode. Look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

update current observation

                    self.obs[w] = next_obs

Train the model

    def train(self, beta: float):
        for _ in range(self.train_epochs):

Sample from the prioritized replay buffer

            samples = self.replay_buffer.sample(self.mini_batch_size, beta)

Get the predicted Q-value

            q_value = self.model(obs_to_torch(samples['obs']))

Get the Q-values of the next state for Double Q-learning. Gradients shouldn't propagate for these

            with torch.no_grad():

Get $Q(s'; \theta_i)$ from the online model

                double_q_value = self.model(obs_to_torch(samples['next_obs']))

Get $Q(s'; \theta_i^-)$ from the target model

                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

Compute the Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.

            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
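As a rough summary of what QFuncLoss computes (see its annotated implementation for the exact form), with discount factor $\gamma = 0.99$ the Double Q-learning target selects the next action with the online network and evaluates it with the target network:

$$a'^\ast = \arg\max_{a'} Q(s', a'; \theta_i), \qquad y = r + (1 - d)\,\gamma\, Q\big(s', a'^\ast; \theta_i^-\big), \qquad \delta = y - Q(s, a; \theta_i)$$

The loss is then an importance-sampling-weighted error of $\delta$, using the weights drawn from the prioritized replay buffer.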

Calculate priorities for replay buffer

            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6

Update replay buffer priorities

            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

Set learning rate

            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

            self.optimizer.zero_grad()

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

            self.optimizer.step()

Run training loop

    def run_training_loop(self):

Last 100 episode information

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

Copy to target network initially

        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):

$\epsilon$, exploration fraction

            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)

$\beta$ for prioritized replay

            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

Sample with current policy

            self.sample(exploration)

Start training after the buffer is full

            if self.replay_buffer.is_full():

Train the model

                self.train(beta)

Periodically update target network

                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment

    experiment.create(name='dqn')

Configurations

    configs = {

Number of updates

        'updates': 1_000_000,

Number of epochs to train the model with sampled data.

        'epochs': 8,

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 4,

Mini batch size

        'mini_batch_size': 32,

Target model updating interval

        'update_target_model': 250,

Learning rate.

        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }
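With these settings each update samples n_workers * worker_steps = 8 * 4 = 32 transitions, matching the mini-batch size, and the replay buffer holds 2^14 = 16,384 transitions, so it fills after 16,384 / 32 = 512 updates; training begins only once the buffer is full.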

Set the experiment configurations

    experiment.configs(configs)

Initialize the trainer

    m = Trainer(**configs)

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()

Run it

if __name__ == "__main__":
    main()