优先体验重播会更频繁地采样重要的过渡。过渡的优先级由时差误差(td 错误)确定。
我们使用概率对过渡进行抽样,其中是确定使用多少优先级的超参数,对应于统一大小写。是当务之急。
我们使用比例优先级,其中是过渡的时间差异。
我们使用损失函数中的重要性采样(IS)权重来纠正优先重播引入的偏差。这完全补偿了.为了稳定起见,我们将权重归一化。对于训练结束时的趋同,公正的本质最为重要。因此,我们在训练快要结束时增加了。
我们使用二叉段树来有效地计算采样所需的累积概率。我们还使用二叉段树来查找,这是必需的。我们也可以为此使用最小堆。二叉段树可以让我们及时计算它们,这比天真的方法效率要高得多。
这就是二叉段树求和的方式;它与最小值相似。让我们来看我们要表示的值的列表。让我们成为二叉树中行的节点。这是节点的两个子节点是和。
行上的叶节点的值将为。每个节点都保留两个子节点的总和。也就是说,根节点保留整个数组值的总和。根节点的左侧和右侧子节点分别保留数组前半部分和后半部分的总和。依此类推...
行中的节点数,这等于以上所有行中的节点总和。因此,我们可以使用单个数组来存储树,其中,
然后是和的子节点。也就是说,
这种维护二叉树的方法很容易编程。请注意,我们从 1 开始索引。
我们使用相同的结构来计算最小值。
20class ReplayBuffer:
90 def __init__(self, capacity, alpha):
我们使用 power f or capacity,因为它简化了代码和调试
95 self.capacity = capacity
97 self.alpha = alpha
维护分段二叉树以求和并找出一定范围内的最小值
100 self.priority_sum = [0 for _ in range(2 * self.capacity)]
101 self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
要分配给新过渡的当前最大优先级
104 self.max_priority = 1.
缓冲区数组
107 self.data = {
108 'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
109 'action': np.zeros(shape=capacity, dtype=np.int32),
110 'reward': np.zeros(shape=capacity, dtype=np.float32),
111 'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
112 'done': np.zeros(shape=capacity, dtype=np.bool)
113 }
我们使用循环缓冲区来存储数据,并next_idx
保留下一个空槽的索引
116 self.next_idx = 0
缓冲区的大小
119 self.size = 0
121 def add(self, obs, action, reward, next_obs, done):
获取下一个可用插槽
127 idx = self.next_idx
存储在队列中
130 self.data['obs'][idx] = obs
131 self.data['action'][idx] = action
132 self.data['reward'][idx] = reward
133 self.data['next_obs'][idx] = next_obs
134 self.data['done'][idx] = done
增加下一个可用插槽
137 self.next_idx = (idx + 1) % self.capacity
计算大小
139 self.size = min(self.capacity, self.size + 1)
,新样品得到max_priority
142 priority_alpha = self.max_priority ** self.alpha
更新总和和和最小值的两个区段树
144 self._set_priority_min(idx, priority_alpha)
145 self._set_priority_sum(idx, priority_alpha)
147 def _set_priority_min(self, idx, priority_alpha):
二叉树的叶子
153 idx += self.capacity
154 self.priority_min[idx] = priority_alpha
通过沿着祖先遍历来更新树。继续直到树的根部。
158 while idx >= 2:
获取父节点的索引
160 idx //= 2
父节点的值是其两个子节点中的最小值
162 self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
164 def _set_priority_sum(self, idx, priority):
二叉树的叶子
170 idx += self.capacity
在叶子上设置优先级
172 self.priority_sum[idx] = priority
通过沿着祖先遍历来更新树。继续直到树的根部。
176 while idx >= 2:
获取父节点的索引
178 idx //= 2
父节点的值是其两个子节点的总和
180 self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
182 def _sum(self):
根节点保留所有值的总和
188 return self.priority_sum[1]
190 def _min(self):
根节点保留所有值中的最小值
196 return self.priority_min[1]
198 def find_prefix_sum_idx(self, prefix_sum):
从根开始
204 idx = 1
205 while idx < self.capacity:
如果左侧分支的总和高于要求的总和
207 if self.priority_sum[idx * 2] > prefix_sum:
转到树的左侧分支
209 idx = 2 * idx
210 else:
否则,转到右分支并从所需的总和中减少左分支的总和
213 prefix_sum -= self.priority_sum[idx * 2]
214 idx = 2 * idx + 1
我们在叶节点。通过树中的索引减去容量,得出实际值的索引
218 return idx - self.capacity
220 def sample(self, batch_size, beta):
初始化样本
226 samples = {
227 'weights': np.zeros(shape=batch_size, dtype=np.float32),
228 'indexes': np.zeros(shape=batch_size, dtype=np.int32)
229 }
获取样本索引
232 for i in range(batch_size):
233 p = random.random() * self._sum()
234 idx = self.find_prefix_sum_idx(p)
235 samples['indexes'][i] = idx
238 prob_min = self._min() / self._sum()
240 max_weight = (prob_min * self.size) ** (-beta)
241
242 for i in range(batch_size):
243 idx = samples['indexes'][i]
245 prob = self.priority_sum[idx + self.capacity] / self._sum()
247 weight = (prob * self.size) ** (-beta)
标准化依据,这也会取消该期限
250 samples['weights'][i] = weight / max_weight
获取样本数据
253 for k, v in self.data.items():
254 samples[k] = v[samples['indexes']]
255
256 return samples
258 def update_priorities(self, indexes, priorities):
263 for idx, priority in zip(indexes, priorities):
设置当前最大优先级
265 self.max_priority = max(self.max_priority, priority)
计算
268 priority_alpha = priority ** self.alpha
更新树木
270 self._set_priority_min(idx, priority_alpha)
271 self._set_priority_sum(idx, priority_alpha)
273 def is_full(self):
277 return self.capacity == self.size