This is a PyTorch implementation of the paper Playing Atari with Deep Reinforcement Learning, along with Dueling Network, Prioritized Replay and Double Q Network.
Here is the experiment and model implementation.
from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
We want to find the optimal action-value function,

$$Q^*(s, a) = \max_\pi \mathbb{E} \big[ r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots \mid s_t = s, a_t = a, \pi \big]$$

In order to improve stability we use experience replay, which randomly samples from previous experience $U(D)$. We also use a Q network with a separate set of parameters $\theta_i^-$ to calculate the target; $\theta_i^-$ is updated periodically. This is according to the paper Human Level Control Through Deep Reinforcement Learning.

So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \Big[ \big( r + \gamma \max_{a'} Q(s', a'; \theta_i^-) - Q(s, a; \theta_i) \big)^2 \Big]$$

The max operator in the above calculation uses the same network both to select the best action and to evaluate its value,

$$\max_{a'} Q(s', a'; \theta) = Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta); \theta\big)$$

which tends to overestimate action values. We therefore use double Q-learning, where the $\operatorname{argmax}$ is taken from $\theta_i$ and the value is taken from $\theta_i^-$.

And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \bigg[ \Big( r + \gamma Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^- \big) - Q(s, a; \theta_i) \Big)^2 \bigg]$$
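As a rough sketch of the difference between the two targets (not part of the implementation below; the tensor names and shapes are made up for illustration):

```python
import torch

# Made-up tensors: Q values of the next states from the online network and the
# target network, each of shape [batch_size, n_actions].
q_next_online = torch.randn(32, 4)
q_next_target = torch.randn(32, 4)

# Vanilla DQN target value: the target network both selects and evaluates the action.
vanilla_value = q_next_target.max(dim=-1).values

# Double Q-learning: the online network selects the action,
# the target network evaluates it.
best_action = q_next_online.argmax(dim=-1)
double_value = q_next_target.gather(-1, best_action.unsqueeze(-1)).squeeze(-1)
```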
class QFuncLoss(nn.Module):
    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
- `q` - $Q(s; \theta_i)$, Q values of the sampled states from the online network
- `action` - $a$, the actions that were taken
- `double_q` - $Q(s'; \theta_i)$, Q values of the next states from the online network
- `target_q` - $Q(s'; \theta_i^-)$, Q values of the next states from the target network
- `done` - whether the game ended after taking the action
- `reward` - $r$, the reward received
- `weights` - weights of the samples from prioritized experience replay

    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
We take the $Q$ value of the action that was taken, $Q(s, a; \theta_i)$

        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)
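As a side note, `gather` here simply picks the Q value of the action that was actually taken in each sampled transition. A toy example with made-up numbers:

```python
import torch

# Q values for 3 sampled states and 2 actions (made-up numbers)
q = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]])
# Actions that were taken in those states
action = torch.tensor([1, 0, 1])

# Picks q[i, action[i]] for each sample -> tensor([2., 3., 6.])
q_sampled_action = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
```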
Gradients shouldn't propagate through the target $Q$ values
        with torch.no_grad():
Get the best action at state $s'$
            best_next_action = torch.argmax(double_q, -1)
Get the $Q$ value from the target network for the best action at state $s'$
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
Calculate the desired Q value. We multiply by (1 - done)
to zero out the next state Q values if the game ended.
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)
The temporal difference error is used to weigh samples in the replay buffer (prioritized replay uses its absolute value as the priority)
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)
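As a sketch of how the returned TD errors typically feed back into prioritized replay (the exact `ReplayBuffer` API is not shown here; $\alpha$ and $\epsilon$ are illustrative values, not taken from the implementation above):

```python
import torch

# Made-up TD errors for three samples
td_error = torch.tensor([0.5, -2.0, 0.0])

# Typical prioritized-replay priority: (|delta| + eps) ** alpha,
# where eps keeps priorities strictly positive.
alpha, eps = 0.6, 1e-6
priorities = (td_error.abs() + eps) ** alpha
```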
We take Huber loss instead of mean squared error loss because it is less sensitive to outliers
        losses = self.huber_loss(q_sampled_action, q_update)
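To see why, here is a tiny comparison (with made-up numbers) of the per-sample Huber (smooth L1) loss against squared error on an outlier; the Huber loss grows linearly for large errors instead of quadratically:

```python
import torch
from torch import nn

pred = torch.tensor([0.0, 0.0])
target = torch.tensor([1.0, 10.0])

# SmoothL1Loss: 0.5 * x^2 for |x| < 1, and |x| - 0.5 otherwise
huber = nn.SmoothL1Loss(reduction='none')(pred, target)  # tensor([0.5000, 9.5000])
mse = nn.MSELoss(reduction='none')(pred, target)         # tensor([  1., 100.])
```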
Get the weighted mean of the losses, using the sample weights from prioritized replay
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
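For reference, a rough usage sketch with random tensors; the batch size, action count and $\gamma$ are arbitrary, and in the actual code this loss is wired up inside the experiment, which the tracker calls report to:

```python
import torch

batch_size, n_actions = 32, 4
loss_func = QFuncLoss(gamma=0.99)

q = torch.randn(batch_size, n_actions, requires_grad=True)  # online network, current states
double_q = torch.randn(batch_size, n_actions)                # online network, next states
target_q = torch.randn(batch_size, n_actions)                # target network, next states
action = torch.randint(0, n_actions, (batch_size,))
reward = torch.randn(batch_size)
done = torch.zeros(batch_size)                               # 1.0 where the episode ended
weights = torch.ones(batch_size)                             # importance-sampling weights

td_error, loss = loss_func(q, action, double_q, target_q, done, reward, weights)
loss.backward()
```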