优先体验重播会更频繁地采样重要的过渡。过渡的优先级由时差误差(td 错误)确定。
我们使用概率对过渡进行抽样,其中是确定使用多少优先级的超参数,对应于统一大小写。是当务之急。
我们使用比例优先级,其中是过渡的时间差异。
我们使用损失函数中的重要性采样(IS)权重来纠正优先重播引入的偏差。这完全补偿了.为了稳定起见,我们将权重归一化。对于训练结束时的趋同,公正的本质最为重要。因此,我们在训练快要结束时增加了。
我们使用二叉段树来有效地计算采样所需的累积概率。我们还使用二叉段树来查找,这是必需的。我们也可以为此使用最小堆。二叉段树可以让我们及时计算它们,这比天真的方法效率要高得多。
这就是二叉段树求和的方式;它与最小值相似。让我们来看我们要表示的值的列表。让我们成为二叉树中行的节点。这是节点的两个子节点是和。
行上的叶节点的值将为。每个节点都保留两个子节点的总和。也就是说,根节点保留整个数组值的总和。根节点的左侧和右侧子节点分别保留数组前半部分和后半部分的总和。依此类推...
行中的节点数,这等于以上所有行中的节点总和。因此,我们可以使用单个数组来存储树,其中,
然后是和的子节点。也就是说,
这种维护二叉树的方法很容易编程。请注意,我们从 1 开始索引。
我们使用相同的结构来计算最小值。
20class ReplayBuffer:90    def __init__(self, capacity, alpha):我们使用 power f or capacity,因为它简化了代码和调试
95        self.capacity = capacity97        self.alpha = alpha维护分段二叉树以求和并找出一定范围内的最小值
100        self.priority_sum = [0 for _ in range(2 * self.capacity)]
101        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]要分配给新过渡的当前最大优先级
104        self.max_priority = 1.缓冲区数组
107        self.data = {
108            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
109            'action': np.zeros(shape=capacity, dtype=np.int32),
110            'reward': np.zeros(shape=capacity, dtype=np.float32),
111            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
112            'done': np.zeros(shape=capacity, dtype=np.bool)
113        }我们使用循环缓冲区来存储数据,并next_idx
保留下一个空槽的索引
116        self.next_idx = 0缓冲区的大小
119        self.size = 0121    def add(self, obs, action, reward, next_obs, done):获取下一个可用插槽
127        idx = self.next_idx存储在队列中
130        self.data['obs'][idx] = obs
131        self.data['action'][idx] = action
132        self.data['reward'][idx] = reward
133        self.data['next_obs'][idx] = next_obs
134        self.data['done'][idx] = done增加下一个可用插槽
137        self.next_idx = (idx + 1) % self.capacity计算大小
139        self.size = min(self.capacity, self.size + 1),新样品得到max_priority
142        priority_alpha = self.max_priority ** self.alpha更新总和和和最小值的两个区段树
144        self._set_priority_min(idx, priority_alpha)
145        self._set_priority_sum(idx, priority_alpha)147    def _set_priority_min(self, idx, priority_alpha):二叉树的叶子
153        idx += self.capacity
154        self.priority_min[idx] = priority_alpha通过沿着祖先遍历来更新树。继续直到树的根部。
158        while idx >= 2:获取父节点的索引
160            idx //= 2父节点的值是其两个子节点中的最小值
162            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])164    def _set_priority_sum(self, idx, priority):二叉树的叶子
170        idx += self.capacity在叶子上设置优先级
172        self.priority_sum[idx] = priority通过沿着祖先遍历来更新树。继续直到树的根部。
176        while idx >= 2:获取父节点的索引
178            idx //= 2父节点的值是其两个子节点的总和
180            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]182    def _sum(self):根节点保留所有值的总和
188        return self.priority_sum[1]190    def _min(self):根节点保留所有值中的最小值
196        return self.priority_min[1]198    def find_prefix_sum_idx(self, prefix_sum):从根开始
204        idx = 1
205        while idx < self.capacity:如果左侧分支的总和高于要求的总和
207            if self.priority_sum[idx * 2] > prefix_sum:转到树的左侧分支
209                idx = 2 * idx
210            else:否则,转到右分支并从所需的总和中减少左分支的总和
213                prefix_sum -= self.priority_sum[idx * 2]
214                idx = 2 * idx + 1我们在叶节点。通过树中的索引减去容量,得出实际值的索引
218        return idx - self.capacity220    def sample(self, batch_size, beta):初始化样本
226        samples = {
227            'weights': np.zeros(shape=batch_size, dtype=np.float32),
228            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
229        }获取样本索引
232        for i in range(batch_size):
233            p = random.random() * self._sum()
234            idx = self.find_prefix_sum_idx(p)
235            samples['indexes'][i] = idx238        prob_min = self._min() / self._sum()240        max_weight = (prob_min * self.size) ** (-beta)
241
242        for i in range(batch_size):
243            idx = samples['indexes'][i]245            prob = self.priority_sum[idx + self.capacity] / self._sum()247            weight = (prob * self.size) ** (-beta)标准化依据,这也会取消该期限
250            samples['weights'][i] = weight / max_weight获取样本数据
253        for k, v in self.data.items():
254            samples[k] = v[samples['indexes']]
255
256        return samples258    def update_priorities(self, indexes, priorities):263        for idx, priority in zip(indexes, priorities):设置当前最大优先级
265            self.max_priority = max(self.max_priority, priority)计算
268            priority_alpha = priority ** self.alpha更新树木
270            self._set_priority_min(idx, priority_alpha)
271            self._set_priority_sum(idx, priority_alpha)273    def is_full(self):277        return self.capacity == self.size