ディープQネットワーク (DQN) モデル

Open In Colab

12import torch
13from torch import nn
14
15from labml_helpers.module import Module

デュエルネットワーク ⚔️ 価値モデル

Q値の計算にはデュエルネットワークを使用しています。デュエルネットワークアーキテクチャの背後にある直感は、ほとんどの州ではアクションは重要ではなく、一部の州ではアクションが重要であるということです。デュエルネットワークでは、これを非常によく表現できます

そこで、とからの 2 つのネットワークを作成して、その 2 つのネットワークから取得します。とネットワークの初期レイヤーを共有します。

18class Model(Module):
49    def __init__(self):
50        super().__init__()
51        self.conv = nn.Sequential(

最初の畳み込み層はフレームを取り、フレームを生成します。

54            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
55            nn.ReLU(),

2 番目の畳み込み層は、フレームを取得してフレームを生成します。

59            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
60            nn.ReLU(),

3 番目の畳み込み層は、フレームを取得してフレームを生成します。

64            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
65            nn.ReLU(),
66        )

完全に接続されたレイヤーは、3 番目のコンボリューションレイヤーからフラット化されたフレームを取り出し、フィーチャを出力します。

71        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
72        self.activation = nn.ReLU()

このヘッドは状態値を与えます

75        self.state_value = nn.Sequential(
76            nn.Linear(in_features=512, out_features=256),
77            nn.ReLU(),
78            nn.Linear(in_features=256, out_features=1),
79        )

このヘッドはアクション値を与えます

81        self.action_value = nn.Sequential(
82            nn.Linear(in_features=512, out_features=256),
83            nn.ReLU(),
84            nn.Linear(in_features=256, out_features=4),
85        )
87    def forward(self, obs: torch.Tensor):

コンボリューション

89        h = self.conv(obs)

線形レイヤーの形状を変更

91        h = h.reshape((-1, 7 * 7 * 64))

リニアレイヤー

94        h = self.activation(self.lin(h))

97        action_value = self.action_value(h)

99        state_value = self.state_value(h)

102        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

104        q = state_value + action_score_centered
105
106        return q