24from typing import Tuple
25
26import torch
27from torch import nn
28
29from labml import tracker
30from labml_helpers.module import Module
31from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
我们想找到最佳的动作值函数。
为了提高稳定性,我们使用经验回放,从以前的经验中随机抽样。我们还使用具有一组单独参数的 Q 网络来计算目标。定期更新。这是根据论文《通过深度强化学习进行人体水平控制》得出的。
所以损失函数是,
上述计算中的最大值运算符使用相同的网络来选择最佳动作和评估值。也就是说,我们使用双重Q-L earning,其中取自值,取自值。
损失函数变成,
34class QFuncLoss(Module):
102 def __init__(self, gamma: float):
103 super().__init__()
104 self.gamma = gamma
105 self.huber_loss = nn.SmoothL1Loss(reduction='none')
107 def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
108 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
109 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
121 q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
122 tracker.add('q_sampled_action', q_sampled_action)
渐变不应传播渐变
130 with torch.no_grad():
在州内采取最佳行动
134 best_next_action = torch.argmax(double_q, -1)
从目标网络获取 q 值,以便在州内采取最佳行动
140 best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
151 q_update = reward + self.gamma * best_next_q_value * (1 - done)
152 tracker.add('q_update', q_update)
时差误差用于称量重放缓冲区中的样本
155 td_error = q_sampled_action - q_update
156 tracker.add('td_error', td_error)
获取加权均值
162 loss = torch.mean(weights * losses)
163 tracker.add('loss', loss)
164
165 return td_error, loss