12import torch
13from torch import nn
14
15from labml_helpers.module import Module
我们正在使用决斗网络来计算 Q 值。决斗网络架构背后的直觉是,在大多数州,行动无关紧要,而在某些州,行动意义重大。决斗网络可以很好地体现这一点。
因此,我们为和创建了两个网络,然后从中获取。我们共享和网络的初始层。
18class Model(Module):
49 def __init__(self):
50 super().__init__()
51 self.conv = nn.Sequential(
第一个卷积层需要一个帧并生成一个帧
54 nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
55 nn.ReLU(),
第二个卷积层获取一个帧并生成一个帧
59 nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
60 nn.ReLU(),
第三个卷积层获取一个帧并生成一个帧
64 nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
65 nn.ReLU(),
66 )
完全连接的图层从第三个卷积图层获取展平的帧,并输出要素
71 self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
72 self.activation = nn.ReLU()
这个头给出了状态值
75 self.state_value = nn.Sequential(
76 nn.Linear(in_features=512, out_features=256),
77 nn.ReLU(),
78 nn.Linear(in_features=256, out_features=1),
79 )
这个头给出了动作值
81 self.action_value = nn.Sequential(
82 nn.Linear(in_features=512, out_features=256),
83 nn.ReLU(),
84 nn.Linear(in_features=256, out_features=4),
85 )
87 def forward(self, obs: torch.Tensor):
卷积
89 h = self.conv(obs)
线性图层的整形
91 h = h.reshape((-1, 7 * 7 * 64))
线性层
94 h = self.activation(self.lin(h))
97 action_value = self.action_value(h)
99 state_value = self.state_value(h)
102 action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
104 q = state_value + action_score_centered
105
106 return q