# 深度 Q 网络 (DQN)

from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

## 训练模型

### 双重学习

class QFuncLoss(Module):
    """
    ## Train the model

    Computes the temporal-difference loss for a Deep Q-Network with
    **double Q-learning**: the online network picks the best next action
    and the target network evaluates it, which reduces the
    over-estimation bias of vanilla DQN.
    """

    def __init__(self, gamma: float):
        """
        * `gamma` is the discount factor applied to the next-state value.
        """
        super().__init__()
        self.gamma = gamma
        # Element-wise Huber loss (`reduction='none'`) so that each sample
        # can be weighted individually by the prioritized-replay weights.
        self.huber_loss = nn.SmoothL1Loss(reduction='none')

    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        * `q` - Q-values from the online network for the sampled states
        * `action` - actions that were taken
        * `double_q` - Q-values used to select the best next action (double learning)
        * `target_q` - Q-values from the target network for the next states
        * `done` - whether the game ended after taking the action
        * `reward` - rewards received
        * `weights` - sample weights from prioritized experience replay

        Returns `(td_error, loss)`: the per-sample TD error (for updating
        replay priorities) and the scalar weighted loss.
        """
        # Q-value of the action that was actually taken
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

        # Gradients should not propagate through the target computation
        with torch.no_grad():
            # Pick the best action in the next state (double Q-learning:
            # selection uses `double_q`, evaluation uses `target_q`)
            best_next_action = torch.argmax(double_q, -1)

            # Get the target network's Q-value for that best action
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

            # Bellman target; multiply by (1 - done) to zero out the
            # next-state value when the episode has ended
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)

            # Temporal-difference error, used to re-weight samples
            # in the replay buffer
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

        # Huber loss instead of mean squared error because it is
        # less sensitive to outliers
        losses = self.huber_loss(q_sampled_action, q_update)

        # Weighted mean using the prioritized-replay importance weights
        loss = torch.mean(weights * losses)
        return td_error, loss