深度 Q 网络 (DQN) 模型

Open In Colab

12import torch
13from torch import nn
14
15from labml_helpers.module import Module

决斗网络 ⚔️ 价值观模型

我们正在使用决斗网络来计算 Q 值。决斗网络架构背后的直觉是,在大多数州,行动无关紧要,而在某些州,行动意义重大。决斗网络可以很好地体现这一点。

因此,我们为和创建了两个网络,然后从中获取。我们共享网络的初始层。

18class Model(Module):
49    def __init__(self):
50        super().__init__()
51        self.conv = nn.Sequential(

第一个卷积层需要一个帧并生成一个

54            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
55            nn.ReLU(),

第二个卷积层获取一个帧并生成一个

59            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
60            nn.ReLU(),

第三个卷积层获取一个帧并生成一个

64            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
65            nn.ReLU(),
66        )

完全连接的图层从第三个卷积图层获取展平的帧,并输出要素

71        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
72        self.activation = nn.ReLU()

这个头给出了状态值

75        self.state_value = nn.Sequential(
76            nn.Linear(in_features=512, out_features=256),
77            nn.ReLU(),
78            nn.Linear(in_features=256, out_features=1),
79        )

这个头给出了动作值

81        self.action_value = nn.Sequential(
82            nn.Linear(in_features=512, out_features=256),
83            nn.ReLU(),
84            nn.Linear(in_features=256, out_features=4),
85        )
87    def forward(self, obs: torch.Tensor):

卷积

89        h = self.conv(obs)

线性图层的整形

91        h = h.reshape((-1, 7 * 7 * 64))

线性层

94        h = self.activation(self.lin(h))

97        action_value = self.action_value(h)

99        state_value = self.state_value(h)

102        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

104        q = state_value + action_score_centered
105
106        return q