import torch
from torch import nn
We use a dueling network to calculate Q-values. The intuition behind the dueling network architecture is that in most states the action doesn't matter, while in some states the action is significant. The dueling network lets this be represented efficiently.
So we create two heads: one for the state value $V(s)$ and one for the action advantage $A(s, a)$, and combine them to get $Q(s, a) = V(s) + \big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')\big)$. The advantage is mean-centered over actions so the decomposition is identifiable. We share the initial layers of the $V$ and $A$ networks.
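The combination step can be sketched in isolation; `state_value` and `action_value` below are stand-in tensors for the outputs of the two heads:

```python
import torch

# Stand-in head outputs for a batch of 2 states and 4 actions
state_value = torch.tensor([[1.0], [2.0]])            # V(s), shape (2, 1)
action_value = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                             [0.0, 0.0, 0.0, 4.0]])   # A(s, a), shape (2, 4)

# Center the advantages so their mean over actions is zero
action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

# Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a'))
q = state_value + action_score_centered
print(q)
```

Because the advantages are centered, the mean of each row of `q` equals that state's $V(s)$, so the state-value and advantage heads cannot trade a constant offset between themselves.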
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
The first convolution layer takes an $84 \times 84$ frame and produces a $20 \times 20$ frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
The second convolution layer takes a $20 \times 20$ frame and produces a $9 \times 9$ frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
The third convolution layer takes a $9 \times 9$ frame and produces a $7 \times 7$ frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
A fully connected layer takes the flattened frame from the third convolution layer, and outputs $512$ features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()
This head gives the state value $V(s)$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
This head gives the action advantage $A(s, a)$ for each of the $4$ actions
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )
    def forward(self, obs: torch.Tensor):
Convolution
        h = self.conv(obs)
Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))
Linear layer
        h = self.activation(self.lin(h))
Get $A(s, a)$
        action_value = self.action_value(h)
Get $V(s)$
        state_value = self.state_value(h)
Center the advantages so that their mean over actions is zero: $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
$Q(s, a) = V(s) + \big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')\big)$
        q = state_value + action_score_centered

        return q
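The $7 \times 7 \times 64$ flattened size used in `lin` and `forward` can be verified by running the convolution stack alone, assuming the standard $84 \times 84$ Atari frame input; a minimal sketch:

```python
import torch
from torch import nn

# Same convolution stack as in the model above
conv = nn.Sequential(
    nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
    nn.ReLU(),
)

# Per layer: out = (in - kernel) // stride + 1
# 84 -> (84-8)//4+1 = 20 -> (20-4)//2+1 = 9 -> (9-3)//1+1 = 7
h = conv(torch.zeros((1, 4, 84, 84)))
print(h.shape)  # torch.Size([1, 64, 7, 7]), which flattens to 7 * 7 * 64
```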