from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
We want to find the optimal action-value function,

$$Q^*(s, a) = \max_\pi \mathbb{E} \big[ r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots \mid s_t = s, a_t = a, \pi \big]$$
In order to improve stability we use experience replay, sampling uniformly at random from previous experience $U(D)$. We also use a Q network with a separate set of parameters $\theta_i^-$ to calculate the target; $\theta_i^-$ is updated periodically. This is according to the paper Human Level Control Through Deep Reinforcement Learning.
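A minimal sketch of the periodic target update, assuming a toy online network and a copy interval of 1,000 steps (both illustrative; the actual model and schedule live elsewhere in the experiment):

import copy

from torch import nn

# A toy Q network standing in for the real model
online_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
# The target network starts out as a frozen copy of the online network
target_net = copy.deepcopy(online_net)

for step in range(10_000):
    ...  # sample a batch from the replay buffer and update online_net
    # Periodically copy the online parameters into the target network
    if step % 1_000 == 0:
        target_net.load_state_dict(online_net.state_dict())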
So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \Big[ \big( r + \gamma \max_{a'} Q(s', a'; \theta_i^-) - Q(s, a; \theta_i) \big)^2 \Big]$$
The max operator in the above calculation uses the same network both to select the best action and to evaluate its value. That is,

$$\max_{a'} Q(s', a'; \theta) = Q\big(s', \operatorname{argmax}_{a'} Q(s', a'; \theta); \theta\big)$$

which leads to overestimation of Q values. We use double Q-learning, where the $\operatorname{argmax}$ is taken from $\theta_i$ and the value is taken from $\theta_i^-$.
And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \bigg[ \Big( r + \gamma Q\big(s', \operatorname{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\big) - Q(s, a; \theta_i) \Big)^2 \bigg]$$
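A minimal sketch of computing this double Q-learning target with plain tensors; the numbers and variable names are illustrative only:

import torch

gamma = 0.99
# Next-state Q values for 3 transitions and 2 actions
q_next_online = torch.tensor([[1.0, 2.0], [0.5, 0.2], [3.0, 1.0]])  # from theta_i
q_next_target = torch.tensor([[0.9, 1.8], [0.4, 0.3], [2.5, 1.2]])  # from theta_i^-
reward = torch.tensor([1.0, 0.0, 2.0])

# The argmax comes from the online network ...
best_action = q_next_online.argmax(-1)  # tensor([1, 0, 0])
# ... but the value of that action comes from the target network
value = q_next_target.gather(-1, best_action.unsqueeze(-1)).squeeze(-1)  # tensor([1.8, 0.4, 2.5])
target = reward + gamma * value  # tensor([2.7820, 0.3960, 4.4750])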
class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        # Discount factor for future rewards
        self.gamma = gamma
        # Huber loss with `reduction='none'` so each sample keeps its own loss,
        # which lets us apply per-sample replay weights later
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
q - $Q(s; \theta_i)$, Q values from the online network for the sampled states
action - $a$, the actions that were taken
double_q - $Q(s'; \theta_i)$, next-state Q values from the online network
target_q - $Q(s'; \theta_i^-)$, next-state Q values from the target network
done - whether the game ended after taking the action
reward - $r$, the reward received after taking the action
weights - weights of the samples from prioritized experience replay
    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Q values for the actions that were actually taken
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)
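        # e.g. q = [[0.1, 0.9], [0.7, 0.3]] with action = [1, 0]
        # picks q_sampled_action = [0.9, 0.7]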
Gradients shouldn't propagate through the target Q value computation
        with torch.no_grad():
Get the best action at state $s'$, $\operatorname{argmax}_{a'} Q(s', a'; \theta_i)$
            best_next_action = torch.argmax(double_q, -1)
Get the Q value from the target network for the best action at state $s'$, $Q\big(s', \operatorname{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\big)$
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
Calculate the desired Q value. We multiply by (1 - done) to zero out the next state's Q value if the game ended.
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)
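            # e.g. with gamma = 0.99: reward = 1.0, best_next_q_value = 2.0, done = 1
            # (terminal) gives q_update = 1.0 + 0.99 * 2.0 * (1 - 1) = 1.0, while
            # done = 0 gives q_update = 1.0 + 0.99 * 2.0 = 2.98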
The temporal difference error is used to weigh samples in the replay buffer
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)
We take the Huber loss instead of mean squared error loss because it is less sensitive to outliers
        losses = self.huber_loss(q_sampled_action, q_update)
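        # SmoothL1 is quadratic for errors smaller than 1 and linear beyond, so a
        # single transition with a huge TD error cannot dominate the gradient the
        # way it would under a squared loss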
Get the weighted mean of the per-sample losses
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
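A minimal usage sketch with random tensors standing in for real network outputs; the batch size, action count, and `requires_grad` setup are assumptions for illustration, not part of the module:

import torch

loss_fn = QFuncLoss(gamma=0.99)

batch_size, n_actions = 32, 4
q = torch.randn(batch_size, n_actions, requires_grad=True)  # Q(s; theta_i)
double_q = torch.randn(batch_size, n_actions)                # Q(s'; theta_i)
target_q = torch.randn(batch_size, n_actions)                # Q(s'; theta_i^-)
action = torch.randint(0, n_actions, (batch_size,))
reward = torch.randn(batch_size)
done = torch.zeros(batch_size)    # no terminal transitions in this toy batch
weights = torch.ones(batch_size)  # uniform prioritized-replay weights

td_error, loss = loss_fn(q, action, double_q, target_q, done, reward, weights)
loss.backward()  # gradients flow back into `q` (i.e. the online network)

In real training, the returned td_error would be fed back into the prioritized replay buffer to update sample priorities.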