Deep Q Networks (DQN)

This is a PyTorch implementation of the paper Playing Atari with Deep Reinforcement Learning, along with Dueling Network, Prioritized Replay and Double Q Network.

Here are the experiment and model implementations.

from typing import Tuple
import torch
from torch import nn
from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

Train the model

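We want to find the optimal action-value function,

$$Q^*(s,a) = \max_\pi \mathbb{E}\Big[r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \dots \,\Big|\, s_t = s, a_t = a, \pi\Big]$$

which satisfies the Bellman optimality equation

$$Q^*(s,a) = \mathbb{E}_{s' \sim \varepsilon}\Big[r + \gamma \max_{a'} Q^*(s', a') \,\Big|\, s, a\Big]$$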
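To make the Bellman optimality equation concrete, here is a minimal sketch that solves it by repeated backups on a small table. The two-state, two-action deterministic MDP (its rewards, transitions and the name `q_value_iteration`) is hypothetical and only for illustration; DQN approximates the same fixed point with a neural network instead of a table.

```python
# Hypothetical toy MDP (not from the source), with deterministic transitions.
# transitions[s][a] = (reward, next_state, done)
transitions = {
    0: {0: (0.0, 1, False), 1: (1.0, 0, True)},
    1: {0: (2.0, 0, True), 1: (0.0, 0, False)},
}


def q_value_iteration(transitions, gamma=0.9, iters=200):
    """Apply the Bellman optimality backup to a Q table repeatedly."""
    q = {s: {a: 0.0 for a in acts} for s, acts in transitions.items()}
    for _ in range(iters):
        new_q = {}
        for s, acts in transitions.items():
            new_q[s] = {}
            for a, (r, s_next, done) in acts.items():
                # r + gamma * max_a' Q(s', a'); no bootstrapping past a terminal step
                bootstrap = 0.0 if done else max(q[s_next].values())
                new_q[s][a] = r + gamma * bootstrap
        q = new_q
    return q


q_star = q_value_iteration(transitions)
```

In state 0 the greedy action is 0: it waits one step to collect the reward of 2 from state 1, worth $0.9 \times 2 = 1.8$, which beats the immediate reward of 1.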

Target network 🎯

In order to improve stability, we use experience replay, which samples uniformly at random from previous experience, $U(D)$. We also use a Q network with a separate set of parameters $\color{orange}{\theta_i^{-}}$ to calculate the target. $\color{orange}{\theta_i^{-}}$ is updated periodically. This follows the paper Human-Level Control Through Deep Reinforcement Learning.
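The two stabilizers above can be sketched without any deep learning library. The class `UniformReplay` and the helper `maybe_sync_target` are hypothetical names for this illustration; they are not the `ReplayBuffer` imported above (which is a prioritized buffer), and the parameters are a plain dict standing in for network weights.

```python
import copy
import random
from collections import deque


class UniformReplay:
    """Hypothetical fixed-size buffer; sampling is uniform, i.e. U(D)."""

    def __init__(self, capacity: int):
        # A deque with maxlen drops the oldest transition once full
        self.data = deque(maxlen=capacity)

    def add(self, s, a, r, s_next, done):
        self.data.append((s, a, r, s_next, done))

    def sample(self, batch_size: int):
        # Uniform sampling without replacement breaks the correlation
        # between consecutive environment steps
        return random.sample(list(self.data), batch_size)


def maybe_sync_target(step, online_params, target_params, sync_every):
    """Copy theta_i into theta_i^- every `sync_every` steps; otherwise keep theta_i^- frozen."""
    if step % sync_every == 0:
        return copy.deepcopy(online_params)
    return target_params


buffer = UniformReplay(capacity=3)
for t in range(5):
    buffer.add(s=t, a=0, r=0.0, s_next=t + 1, done=False)
batch = buffer.sample(2)
```

With PyTorch modules, the copy in `maybe_sync_target` would typically be `target_net.load_state_dict(online_net.state_dict())`.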

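So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)}\Big[\Big(r + \gamma \max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) - Q(s,a;\theta_i)\Big)^2\Big]$$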

Double $Q$-Learning

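The max operator in the above calculation uses the same network both to select the best action and to evaluate its value. That is,

$$\max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) = \color{orange}{Q}\Big(s', \operatorname*{argmax}_{a'} \color{orange}{Q}(s', a'; \color{orange}{\theta_i^{-}}); \color{orange}{\theta_i^{-}}\Big)$$

Because the same noisy estimates both pick and score the action, this tends to overestimate values. We use double Q-learning, where the $\operatorname{argmax}$ is taken from $\color{cyan}{\theta_i}$ and the value is taken from $\color{orange}{\theta_i^{-}}$.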
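The upward bias can be seen numerically. In this simplified sketch (plain noisy estimates, not the DQN networks; the function name `overestimation_demo` is hypothetical), every true action value is 0 and each estimate adds standard Gaussian noise. Taking the max of one set of estimates averages about 1.5 for 10 actions, while selecting with one set and evaluating with an independent set averages close to the true value of 0.

```python
import random
import statistics


def overestimation_demo(n_actions=10, trials=20000, seed=0):
    """True action values are all 0; estimates are true value + N(0, 1) noise."""
    rng = random.Random(seed)
    single, double = [], []
    for _ in range(trials):
        # Two independent noisy estimates of the same all-zero action values
        est_a = [rng.gauss(0.0, 1.0) for _ in range(n_actions)]
        est_b = [rng.gauss(0.0, 1.0) for _ in range(n_actions)]
        # Single estimator: the same estimates select and evaluate the action
        single.append(max(est_a))
        # Double estimator: est_a selects the action, est_b evaluates it
        best = max(range(n_actions), key=lambda i: est_a[i])
        double.append(est_b[best])
    return statistics.mean(single), statistics.mean(double)


single_mean, double_mean = overestimation_demo()
```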

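And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)}\Big[\Big(r + \gamma \color{orange}{Q}\Big(s', \operatorname*{argmax}_{a'} \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}\Big) - Q(s,a;\theta_i)\Big)^2\Big]$$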
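The difference between the two targets can be checked on a toy batch. This sketch uses plain lists; the function names and numbers are hypothetical, and `(1.0 - d)` masks the bootstrap term for terminal transitions in the same way as the loss code below.

```python
def dqn_target(reward, done, target_q_next, gamma):
    """r + gamma * max_a' Q(s', a'; theta^-): the target net selects and evaluates."""
    return [r + gamma * max(q) * (1.0 - d)
            for r, d, q in zip(reward, done, target_q_next)]


def double_dqn_target(reward, done, online_q_next, target_q_next, gamma):
    """r + gamma * Q(s', argmax_a' Q(s', a'; theta); theta^-)."""
    out = []
    for r, d, q_on, q_tg in zip(reward, done, online_q_next, target_q_next):
        best = max(range(len(q_on)), key=q_on.__getitem__)  # select with theta_i
        out.append(r + gamma * q_tg[best] * (1.0 - d))      # evaluate with theta_i^-
    return out


# Toy batch: two transitions, three actions; the second transition is terminal
reward = [1.0, 0.5]
done = [0.0, 1.0]
online_q_next = [[0.2, 0.9, 0.1], [0.0, 0.0, 3.0]]
target_q_next = [[0.5, 0.4, 2.0], [1.0, 1.0, 1.0]]
gamma = 0.5

plain = dqn_target(reward, done, target_q_next, gamma)
double = double_dqn_target(reward, done, online_q_next, target_q_next, gamma)
```

For the first transition the online network picks action 1, whose target-network value (0.4) is lower than the target network's own maximum (2.0), so the double target is smaller. For the terminal transition both targets reduce to the reward.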

class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        # Discount factor
        self.gamma = gamma
        # Per-sample Huber loss; weighting and averaging happen in __call__
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
  • q - $Q(s;\theta_i)$
  • action - $a$
  • double_q - $\color{cyan}{Q}(s';\color{cyan}{\theta_i})$
  • target_q - $\color{orange}{Q}(s';\color{orange}{\theta_i^{-}})$
  • done - whether the game ended after taking the action
  • reward - $r$
  • weights - weights of the samples from prioritized experience replay
    def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

Get $Q(s,a;\theta_i)$, the Q-value of the action that was actually taken

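        # Pick the Q-value at index `action` along the last (action) dimension
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)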

The target is treated as a constant, so gradients should not propagate through it.

        with torch.no_grad():

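Get the best action at state $s'$ according to the online network, $\operatorname*{argmax}_{a'} \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i})$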

            best_next_action = torch.argmax(double_q, -1)

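Get the Q-value from the target network for the best action at state $s'$, $\color{orange}{Q}\Big(s', \operatorname*{argmax}_{a'} \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}\Big)$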

            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

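Calculate the desired Q-value, $r + \gamma \color{orange}{Q}\Big(s', \operatorname*{argmax}_{a'} \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}\Big)$. We multiply by (1 - done) to zero out the next-state Q-value if the game ended, since there is no next state to bootstrap from.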

            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)

The temporal difference error $\delta$ is used to prioritize samples in the replay buffer.

            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)
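As an aside, here is a sketch of how TD errors are commonly turned into sampling probabilities and importance-sampling weights in proportional prioritized replay (Schaul et al., Prioritized Experience Replay): $p_i = |\delta_i| + \epsilon$, $P(i) = p_i^\alpha / \sum_k p_k^\alpha$ and $w_i = (N \cdot P(i))^{-\beta} / \max_j w_j$. The function name `priority_weights` and the defaults `alpha=0.6`, `beta=0.4`, `eps=1e-6` are illustrative assumptions; whether the imported `ReplayBuffer` uses exactly these formulas is not stated here.

```python
def priority_weights(td_errors, alpha=0.6, beta=0.4, eps=1e-6):
    """Hypothetical helper: proportional priorities and normalized IS weights."""
    n = len(td_errors)
    # Priority grows with |TD error|; eps keeps zero-error samples reachable
    priorities = [(abs(d) + eps) ** alpha for d in td_errors]
    total = sum(priorities)
    probs = [p / total for p in priorities]
    # Importance-sampling weights correct for non-uniform sampling
    raw = [(n * p) ** (-beta) for p in probs]
    max_w = max(raw)
    # Normalizing by the largest weight keeps every weight in (0, 1]
    weights = [w / max_w for w in raw]
    return probs, weights


probs, weights = priority_weights([0.1, -2.0, 0.5, 0.0])
```

The sample with the largest |TD error| is drawn most often but receives the smallest weight, which is the role of the `weights` argument in the loss.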

We use the Huber loss instead of the mean squared error loss because it is less sensitive to outliers.

        losses = self.huber_loss(q_sampled_action, q_update)
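`nn.SmoothL1Loss` with its default `beta=1.0` computes, element-wise, $0.5x^2$ when $|x| < 1$ and $|x| - 0.5$ otherwise, where $x$ is the TD error. This pure-Python sketch mirrors that formula (it is not PyTorch's implementation) and compares gradients: the squared error's gradient grows with the error, while the Huber gradient is capped at magnitude 1.

```python
def smooth_l1(x: float, beta: float = 1.0) -> float:
    """Element-wise SmoothL1; equals Huber with delta=1 when beta=1."""
    ax = abs(x)
    return 0.5 * x * x / beta if ax < beta else ax - 0.5 * beta


def smooth_l1_grad(x: float, beta: float = 1.0) -> float:
    # Linear region has constant slope, so large errors cannot blow up updates
    return x / beta if abs(x) < beta else (1.0 if x > 0 else -1.0)


def squared_error(x: float) -> float:
    return x * x


def squared_error_grad(x: float) -> float:
    return 2.0 * x
```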

Get the weighted mean of the per-sample losses

        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)
        return td_error, loss
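For checking the tensor logic by hand, here is a dependency-free sketch that mirrors the steps of `QFuncLoss.__call__` on plain Python lists. The function name `q_func_loss_reference` and the batch values are hypothetical; it hard-codes the SmoothL1 default `beta=1.0` and, like `torch.mean(weights * losses)`, divides the weighted sum by the batch size.

```python
def q_func_loss_reference(q, action, double_q, target_q, done, reward, weights, gamma):
    """Hypothetical list-based mirror of QFuncLoss; returns (td_errors, loss)."""
    def smooth_l1(x):
        return 0.5 * x * x if abs(x) < 1.0 else abs(x) - 0.5

    td_errors, losses = [], []
    for qs, a, dq, tq, d, r in zip(q, action, double_q, target_q, done, reward):
        q_sampled = qs[a]                                 # q.gather(-1, action)
        best = max(range(len(dq)), key=dq.__getitem__)  # argmax of online net at s'
        q_update = r + gamma * tq[best] * (1.0 - d)      # target value (no gradient)
        td_errors.append(q_sampled - q_update)
        losses.append(smooth_l1(q_sampled - q_update))
    loss = sum(w * l for w, l in zip(weights, losses)) / len(losses)
    return td_errors, loss


td_errors, loss = q_func_loss_reference(
    q=[[1.0, 2.0], [0.5, 0.0]],
    action=[1, 0],
    double_q=[[0.3, 0.1], [0.0, 0.2]],
    target_q=[[1.0, 5.0], [2.0, 4.0]],
    done=[0.0, 1.0],
    reward=[0.5, 1.0],
    weights=[1.0, 0.5],
    gamma=0.5,
)
```

Here the first TD error is 1.0 and the second is -0.5 (its target is just the reward because `done` is 1), so the loss is (1.0 × 0.5 + 0.5 × 0.125) / 2 = 0.28125.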