This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple worker processes to sample efficiently.
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
Scale observations from [0, 255] to [0, 1]
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
number of workers
        self.n_workers = n_workers
steps sampled on each update
        self.worker_steps = worker_steps
number of training epochs to run per update
        self.train_epochs = epochs
number of updates
        self.updates = updates
size of mini batch for training
        self.mini_batch_size = mini_batch_size
update the target network every 250 updates
        self.update_target_model = update_target_model
learning rate
        self.learning_rate = learning_rate
$\epsilon$, the exploration coefficient, as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
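For example, assuming Piecewise interpolates linearly between the listed (update, value) points, the exploration coefficient would be $1.0 + \frac{12{,}500}{25{,}000}(0.1 - 1.0) = 0.55$ at update 12,500, reach 0.1 at update 25,000, decay further to 0.01 by updates / 2, and then stay at the outside_value of 0.01.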
$\beta$ for the prioritized replay buffer, as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
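As a reminder of how $\alpha$ and $\beta$ enter prioritized experience replay (this follows the standard formulation of Schaul et al.; see the ReplayBuffer implementation for the exact details): a transition with priority $p_i$ is sampled with probability
$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
and the resulting sampling bias is corrected with importance-sampling weights
$$w_i = \big(N \cdot P(i)\big)^{-\beta}$$
(typically normalized by their maximum). So $\alpha = 0.6$ controls how strongly priorities skew sampling, while $\beta$ anneals from 0.4 to 1 so that the correction becomes exact towards the end of training.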
Model for sampling and training
        self.model = Model().to(device)
target model to get $Q(s'; \theta^-)$
        self.target_model = Model().to(device)
create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
reset the workers
        for worker in self.workers:
            worker.child.send(("reset", None))
get the initial observations
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
loss function
        self.loss_func = QFuncLoss(0.99)
optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
When sampling actions we use an $\epsilon$-greedy strategy, where we take a greedy action with probability $1 - \epsilon$ and take a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient.
    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
Sampling doesn't need gradients
        with torch.no_grad():
Sample the action with the highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
Pick the action based on is_choose_rand
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
    def sample(self, exploration_coefficient: float):
This doesn't need gradients
        with torch.no_grad():
Sample worker_steps
            for t in range(self.worker_steps):
Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)
Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))
Collect information from each worker
                for w, worker in enumerate(self.workers):
Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()
Add transition to replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
Update episode information: collect episode info, which is available when an episode finishes; this includes the total reward and the length of the episode. Look at Game to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
update current observation
                    self.obs[w] = next_obs
    def train(self, beta: float):
        for _ in range(self.train_epochs):
Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
Get the predicted Q-value
            q_value = self.model(obs_to_torch(samples['obs']))
Get the Q-values of the next state for Double Q-learning. Gradients shouldn't propagate for these.
            with torch.no_grad():
Get $Q(s'; \theta)$ from the model, used to pick the action
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
Get $Q(s'; \theta^-)$ from the target model, used to evaluate that action
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
Compute the Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
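For reference, a minimal sketch of what this loss computes, assuming the standard Double DQN target with a Huber loss (see QFuncLoss for the exact form): with online parameters $\theta$, target parameters $\theta^-$, and $\gamma = 0.99$,
$$y_i = r_i + \gamma (1 - d_i)\, Q\big(s'_i,\ \arg\max_{a'} Q(s'_i, a'; \theta);\ \theta^-\big)$$
$$\delta_i = y_i - Q(s_i, a_i; \theta), \qquad \mathcal{L}(\theta) = \frac{1}{N}\sum_i w_i\, \mathcal{H}(\delta_i)$$
where $d_i$ is the episode-termination flag, $w_i$ are the importance-sampling weights from the replay buffer, and $\mathcal{H}$ is the Huber loss.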
Calculate priorities for the replay buffer
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
Set learning rate
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
Zero out the previously calculated gradients
            self.optimizer.zero_grad()
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update parameters based on gradients
            self.optimizer.step()
    def run_training_loop(self):
Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)
Copy to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
$\epsilon$, the exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
$\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)
Sample with the current policy
            self.sample(exploration)
Start training after the buffer is full
            if self.replay_buffer.is_full():
Train the model
                self.train(beta)
Periodically update the target network
            if update % self.update_target_model == 0:
                self.target_model.load_state_dict(self.model.state_dict())
Save tracked indicators.
            tracker.save()
Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
Create the experiment
    experiment.create(name='dqn')
Configurations
    configs = {
Number of updates
        'updates': 1_000_000,
Number of epochs to train the model with sampled data.
        'epochs': 8,
Number of worker processes
        'n_workers': 8,
Number of steps to run on each process for a single update
        'worker_steps': 4,
Mini batch size
        'mini_batch_size': 32,
Target model updating interval
        'update_target_model': 250,
Learning rate.
        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }
Configurations
    experiment.configs(configs)
Initialize the trainer
    m = Trainer(**configs)
Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
Stop the workers
    m.destroy()
if __name__ == "__main__":
    main()