DQN Experiment with Atari Breakout

This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs multiple game environments on worker processes to sample efficiently.


import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

Trainer

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):

number of workers

        self.n_workers = n_workers

steps sampled on each update

        self.worker_steps = worker_steps

number of training epochs per update

        self.train_epochs = epochs

number of updates

        self.updates = updates

size of mini batch for training

        self.mini_batch_size = mini_batch_size

update the target network every 250 updates

        self.update_target_model = update_target_model

learning rate

        self.learning_rate = learning_rate

exploration as a function of updates

        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
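For intuition, here is a small sketch of how this schedule behaves. It assumes Piecewise interpolates linearly between the given (update, value) points and returns outside_value beyond the last point, and it plugs in updates = 1,000,000 from the configuration below, so the last point sits at 500,000:

from labml_helpers.schedule import Piecewise

schedule = Piecewise([(0, 1.0), (25_000, 0.1), (500_000, 0.01)], outside_value=0.01)
for update in [0, 12_500, 25_000, 500_000, 750_000]:
    # under the linear-interpolation assumption this prints
    # 1.0, 0.55, 0.1, 0.01, 0.01
    print(update, schedule(update))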

$\beta$ for the replay buffer as a function of updates

        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2.

        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
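As a reminder of what $\alpha$ and $\beta$ control in prioritized experience replay: a transition $i$ is sampled with probability proportional to its priority $p_i$ raised to $\alpha$, and the resulting sampling bias is corrected with importance-sampling weights governed by $\beta$,

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \qquad w_i = \left(\frac{1}{N}\,\frac{1}{P(i)}\right)^{\beta}$$

Here $\alpha = 0.6$ is fixed and $\beta$ is annealed from 0.4 to 1 by prioritized_replay_beta above.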

Model for sampling and training

        self.model = Model().to(device)

target model to get $\hat{Q}(s'; \theta^-)$

        self.target_model = Model().to(device)

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

reset the workers

        for worker in self.workers:
            worker.child.send(("reset", None))

get the initial observations

        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

loss function with discount factor $\gamma = 0.99$

        self.loss_func = QFuncLoss(0.99)

optimizer

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

$\epsilon$-greedy Sampling

When sampling actions we use an $\epsilon$-greedy strategy, where we take the greedy action with probability $1 - \epsilon$ and a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient.
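Formally, given Q-values $Q(s, a)$, the sampled action is

$$a = \begin{cases} \operatorname*{arg\,max}_{a'} Q(s, a') & \text{with probability } 1 - \epsilon \\ \text{a uniformly random action} & \text{with probability } \epsilon \end{cases}$$

which is what _sample_action implements below.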

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):

Sampling doesn't need gradients

        with torch.no_grad():

Sample the action with the highest Q-value. This is the greedy action.

            greedy_action = torch.argmax(q_value, dim=-1)

Uniformly sample an action

            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

Whether to choose the greedy action or the random action

            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

Pick the action based on is_choose_rand

            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

Sample data

    def sample(self, exploration_coefficient: float):

This doesn't need gradients

        with torch.no_grad():

Sample worker_steps

            for t in range(self.worker_steps):

Get Q-values for the current observation

                q_value = self.model(obs_to_torch(self.obs))

Sample actions

                actions = self._sample_action(q_value, exploration_coefficient)

Run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

Collect information from each worker

                for w, worker in enumerate(self.workers):

Get results after executing the actions

                    next_obs, reward, done, info = worker.child.recv()

Add transition to replay buffer

                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

Update episode information. Episode info is available when an episode finishes; it includes the total reward and the length of the episode. Look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

update current observation

                    self.obs[w] = next_obs

Train the model

    def train(self, beta: float):
        for _ in range(self.train_epochs):

Sample from the prioritized replay buffer

            samples = self.replay_buffer.sample(self.mini_batch_size, beta)

Get the predicted Q-value

            q_value = self.model(obs_to_torch(samples['obs']))

Get the Q-values of the next state for Double Q-learning. Gradients shouldn't propagate through these

            with torch.no_grad():

Get $Q(s'; \theta)$ from the online model

                double_q_value = self.model(obs_to_torch(samples['next_obs']))

Get $\hat{Q}(s'; \theta^-)$ from the target model

                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
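These two sets of Q-values are combined into the Double Q-learning target: the online model $\theta$ selects the best next action while the target model $\theta^-$ evaluates it. In the standard formulation, for a non-terminal transition the target is

$$y = r + \gamma\, \hat{Q}\Big(s',\, \operatorname*{arg\,max}_{a'} Q(s', a'; \theta);\, \theta^-\Big)$$

and the TD error is $\delta = y - Q(s, a; \theta)$; QFuncLoss also applies the importance-sampling weights from the replay buffer when computing the loss.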

Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.

            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

Calculate priorities for replay buffer

            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6

Update replay buffer priorities

            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

Set learning rate

            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
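learning_rate is a FloatDynamicHyperParam, so it can be adjusted while the experiment is running (from the labml app); calling it returns the current value, which is why the learning rate is re-read on every training step. A minimal sketch of the pattern, reusing the constructor arguments from main() below (a default value followed by the allowed range):

from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam

lr = FloatDynamicHyperParam(1e-4, (0, 1e-3))  # default 1e-4, adjustable within [0, 1e-3]
print(lr())  # reads the current value; stays 1e-4 until changed externally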

Zero out the previously calculated gradients

            self.optimizer.zero_grad()

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

            self.optimizer.step()

Run training loop

    def run_training_loop(self):

Last 100 episodes' information

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

Copy to target network initially

        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):

$\epsilon$, exploration fraction

            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)

$\beta$ for prioritized replay

            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

Sample with current policy

            self.sample(exploration)

Start training after the buffer is full

            if self.replay_buffer.is_full():

Train the model

                self.train(beta)

Periodically update target network

                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment

    experiment.create(name='dqn')

Configurations

    configs = {

Number of updates

        'updates': 1_000_000,

Number of epochs to train the model with sampled data.

        'epochs': 8,

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 4,

Mini batch size

        'mini_batch_size': 32,

Target model updating interval

        'update_target_model': 250,

Learning rate.

        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }

Configurations

    experiment.configs(configs)

Initialize the trainer

    m = Trainer(**configs)

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()

Run it

if __name__ == "__main__":
    main()