24from typing import Tuple
25
26import torch
27from torch import nn
28
29from labml import tracker
30from labml_helpers.module import Module
31from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

モデルのトレーニング

最適なアクションバリュー関数を見つけたい。

ターゲットネットワーク 🎯

安定性を向上させるために、以前のエクスペリエンスからランダムにサンプリングされるエクスペリエンスのリプレイを使用しています。また、別のパラメータセットを持つQネットワークを使用してターゲットを計算します。定期的に更新されます。これは、深層強化学習によるヒューマンレベル制御の論文によるものです

したがって、損失関数は、

ダブルラーニング

上の計算の max 演算子は、最適なアクションの選択と値の評価の両方に同じネットワークを使用します。つまりの取得元と値の取得元という二重Qラーニングを使用しています

そして、損失関数は次のようになります。

34class QFuncLoss(Module):
102    def __init__(self, gamma: float):
103        super().__init__()
104        self.gamma = gamma
105        self.huber_loss = nn.SmoothL1Loss(reduction='none')
  • q -
  • action -
  • double_q -
  • target_q -
  • done -アクションを実行した後にゲームが終了したかどうか
  • reward -
  • weights -経験豊かなリプレイを優先して抽出したサンプルの重み
107    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
108                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
109                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

121        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
122        tracker.add('q_sampled_action', q_sampled_action)

グラデーションはグラデーションを伝播してはいけません

130        with torch.no_grad():

州で最高のアクションを

134            best_next_action = torch.argmax(double_q, -1)

ターゲットネットワークから q 値を取得して、状態での最適なアクションを実現する

140            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

目的の Q 値を計算します。ゲームが終了したら(1 - done) 、を掛けて次のステートのQ値をゼロにします

151            q_update = reward + self.gamma * best_next_q_value * (1 - done)
152            tracker.add('q_update', q_update)

時間差エラーはリプレイバッファ内のサンプルの重み付けに使用されます

155            td_error = q_sampled_action - q_update
156            tracker.add('td_error', td_error)

外れ値の影響を受けにくいので、平均二乗誤差損失の代わりにフーバー損失を使用します

160        losses = self.huber_loss(q_sampled_action, q_update)

加重平均を取得

162        loss = torch.mean(weights * losses)
163        tracker.add('loss', loss)
164
165        return td_error, loss