12import torch
13from torch import nn
14
15from labml_helpers.module import Module
Q値の計算にはデュエルネットワークを使用しています。デュエルネットワークアーキテクチャの背後にある直感は、ほとんどの州ではアクションは重要ではなく、一部の州ではアクションが重要であるということです。デュエルネットワークでは、これを非常によく表現できます
。そこで、とからの 2 つのネットワークを作成して、その 2 つのネットワークから取得します。とネットワークの初期レイヤーを共有します。
18class Model(Module):
49 def __init__(self):
50 super().__init__()
51 self.conv = nn.Sequential(
最初の畳み込み層はフレームを取り、フレームを生成します。
54 nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
55 nn.ReLU(),
2 番目の畳み込み層は、フレームを取得してフレームを生成します。
59 nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
60 nn.ReLU(),
3 番目の畳み込み層は、フレームを取得してフレームを生成します。
64 nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
65 nn.ReLU(),
66 )
完全に接続されたレイヤーは、3 番目のコンボリューションレイヤーからフラット化されたフレームを取り出し、フィーチャを出力します。
71 self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
72 self.activation = nn.ReLU()
このヘッドは状態値を与えます
75 self.state_value = nn.Sequential(
76 nn.Linear(in_features=512, out_features=256),
77 nn.ReLU(),
78 nn.Linear(in_features=256, out_features=1),
79 )
このヘッドはアクション値を与えます
81 self.action_value = nn.Sequential(
82 nn.Linear(in_features=512, out_features=256),
83 nn.ReLU(),
84 nn.Linear(in_features=256, out_features=4),
85 )
87 def forward(self, obs: torch.Tensor):
コンボリューション
89 h = self.conv(obs)
線形レイヤーの形状を変更
91 h = h.reshape((-1, 7 * 7 * 64))
リニアレイヤー
94 h = self.activation(self.lin(h))
97 action_value = self.action_value(h)
99 state_value = self.state_value(h)
102 action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
104 q = state_value + action_score_centered
105
106 return q