广义优势估计 (GAE)

这是论文广义优势估计PyTorch 实现。

你可以在这里找到一个使用它的实验。

15import numpy as np
18class GAE:
19    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
20        self.lambda_ = lambda_
21        self.gamma = gamma
22        self.worker_steps = worker_steps
23        self.n_workers = n_workers

计算优势

是高偏差,低方差,而无偏差,高方差。

我们采用加权平均值来平衡偏差和方差。这称为广义优势估计。我们设置,这给出了干净的计算

25    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:

优势表

59        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
60        last_advantage = 0

63        last_value = values[:, -1]
64
65        for t in reversed(range(self.worker_steps)):

如果剧集在步骤之后完成,请掩盖

67            mask = 1.0 - done[:, t]
68            last_value = last_value * mask
69            last_advantage = last_advantage * mask

71            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

74            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

请注意,我们正在按相反的顺序收集。我最初的代码被追加到一个列表中,后来我忘记反转它了。我花了大约 4 到 5 个小时才发现 bug。在初始运行期间,该模型的性能略有改善,这可能是因为样本相似。

83            advantages[:, t] = last_advantage
84
85            last_value = values[:, t]
86
87        return advantages