Deep Q Network (DQN) Model


import torch
from torch import nn

from labml_helpers.module import Module

Dueling Network ⚔️ Model for $Q$ Values

We are using a dueling network to calculate Q-values. Intuition behind dueling network architecture is that in most states the action doesn't matter, and in some states the action is significant. Dueling network allows this to be represented very well.

So we create two networks for $V$ and $A$ and get $Q$ from them:
$$Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$$
We share the initial layers of the $V$ and $A$ networks.
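As a minimal sketch of this aggregation on dummy tensors (the values here are made up for illustration; the real model below computes $V$ and $A$ from the observation):

import torch

v = torch.tensor([[1.5]])                    # V(s), shape (batch, 1)
a = torch.tensor([[0.2, -0.1, 0.5, -0.2]])   # A(s, a), shape (batch, actions)

# Center the advantages so they average to zero across actions,
# then add the state value to get Q(s, a)
a_centered = a - a.mean(dim=-1, keepdim=True)
q = v + a_centered
print(q)  # roughly [[1.6, 1.3, 1.9, 1.2]]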

class Model(Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(

The first convolution layer takes a $84\times84$ frame and produces a $20\times20$ frame

            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

The second convolution layer takes a $20\times20$ frame and produces a $9\times9$ frame

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

The third convolution layer takes a $9\times9$ frame and produces a $7\times7$ frame (the frame sizes are checked in the short sketch after this block)

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
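As an aside, the frame sizes above follow from the usual convolution arithmetic $\lfloor (n - k) / s \rfloor + 1$ for an $n \times n$ input with kernel $k$, stride $s$ and no padding; a quick standalone check (a sketch, not part of the model):

def conv_out(n: int, k: int, s: int) -> int:
    # Output size of an unpadded convolution
    return (n - k) // s + 1

assert conv_out(84, 8, 4) == 20  # first layer
assert conv_out(20, 4, 2) == 9   # second layer
assert conv_out(9, 3, 1) == 7    # third layer, flattened to 7 * 7 * 64 = 3136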

A fully connected layer takes the flattened frame from the third convolution layer, and outputs $512$ features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

This head gives the state value $V$

        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )

This head gives the action value $A$

        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

    def forward(self, obs: torch.Tensor):

Convolution

        h = self.conv(obs)

Reshape for linear layers

        h = h.reshape((-1, 7 * 7 * 64))

Linear layer

        h = self.activation(self.lin(h))

Action value $A(s, a)$

        action_value = self.action_value(h)

State value $V(s)$

        state_value = self.state_value(h)

Center the advantages: $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$

        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

$Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$

        q = state_value + action_score_centered
        return q
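A minimal usage sketch (the batch size and the random observations are hypothetical; the model expects a batch of 4 stacked $84\times84$ frames per observation):

model = Model()
obs = torch.rand(8, 4, 84, 84)  # hypothetical batch of 8 stacked-frame observations
q = model(obs)
assert q.shape == (8, 4)        # one Q-value per action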