# Prioritized Experience Replay Buffer

import random

import numpy as np

## Buffer for Prioritized Experience Replay

### Binary Segment Tree

class ReplayBuffer:
    """Fixed-capacity replay buffer with priority-proportional sampling.

    Transitions are stored in a circular queue of NumPy arrays. Two binary
    segment trees, stored as flat lists of length ``2 * capacity``, track
    the sum and the minimum of the per-slot values ``priority ** alpha``.
    Node ``1`` is the root, node ``i`` has children ``2 * i`` and
    ``2 * i + 1``, and the leaf for buffer slot ``j`` is node
    ``j + capacity``.
    """

    def __init__(self, capacity, alpha, obs_shape=(4, 84, 84)):
        """Initialize an empty buffer.

        Args:
            capacity: Number of transitions the buffer can hold.
                NOTE(review): the tree layout is simplest to reason about
                when this is a power of two -- confirm against callers.
            alpha: Exponent applied to every priority before it is stored
                in the trees.
            obs_shape: Shape of a single observation. Defaults to
                ``(4, 84, 84)``, the shape that was previously hard-coded.
        """
        self.capacity = capacity
        self.alpha = alpha

        # Segment trees over ``priority ** alpha``: one for the sum (used to
        # draw samples) and one for the minimum (used to normalize weights).
        # Empty leaves contribute 0 to the sum and +inf to the minimum.
        self.priority_sum = [0 for _ in range(2 * capacity)]
        self.priority_min = [float('inf') for _ in range(2 * capacity)]

        # Largest raw priority seen so far; assigned to newly added samples.
        self.max_priority = 1.

        obs_shape = tuple(obs_shape)
        # ``np.bool_`` is used instead of the ``np.bool`` alias, which is
        # missing from NumPy 1.24-1.26.
        self.data = {
            'obs': np.zeros(shape=(capacity, *obs_shape), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, *obs_shape), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=np.bool_),
        }

        # Slot that the next call to ``add`` will write to.
        self.next_idx = 0

        # Number of slots currently holding a transition.
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition, overwriting the oldest slot once full."""
        idx = self.next_idx

        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done

        # Advance the circular write cursor and grow the size up to capacity.
        self.next_idx = (idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

        # New samples get the current maximum priority.
        priority_alpha = self.max_priority ** self.alpha

        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        """Write ``priority_alpha`` to slot ``idx`` in the minimum tree."""
        # Leaf position of the slot; ``int`` also accepts NumPy integers.
        idx = int(idx) + self.capacity
        self.priority_min[idx] = priority_alpha

        # Walk up to the root, recomputing each ancestor from its children.
        while idx >= 2:
            idx //= 2
            self.priority_min[idx] = min(self.priority_min[2 * idx],
                                         self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        """Write ``priority`` to slot ``idx`` in the sum tree."""
        # Leaf position of the slot; ``int`` also accepts NumPy integers.
        idx = int(idx) + self.capacity
        self.priority_sum[idx] = priority

        # Walk up to the root, recomputing each ancestor from its children.
        while idx >= 2:
            idx //= 2
            self.priority_sum[idx] = (self.priority_sum[2 * idx]
                                      + self.priority_sum[2 * idx + 1])

    def _sum(self):
        """Return the sum of all stored values (root of the sum tree)."""
        return self.priority_sum[1]

    def _min(self):
        """Return the smallest stored value (root of the minimum tree)."""
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        """Return the buffer slot selected by ``prefix_sum``.

        Descends the sum tree from the root: when the left child's sum is
        greater than the remaining ``prefix_sum`` the search goes left,
        otherwise the left sum is subtracted and the search goes right.

        Args:
            prefix_sum: A value in ``[0, self._sum())``.

        Returns:
            The index of the buffer slot whose leaf was reached.
        """
        idx = 1
        while idx < self.capacity:
            left_sum = self.priority_sum[idx * 2]
            if left_sum > prefix_sum:
                idx = 2 * idx
            else:
                # Skip the whole left subtree.
                prefix_sum -= left_sum
                idx = 2 * idx + 1

        # Convert the leaf position back to a buffer slot.
        return idx - self.capacity

    def sample(self, batch_size, beta):
        """Draw ``batch_size`` transitions in proportion to their priority.

        The buffer must hold at least one transition; on an empty buffer the
        total priority is zero and the normalization divides by zero.

        Args:
            batch_size: Number of transitions to draw (with replacement).
            beta: Exponent of the importance-sampling weights.

        Returns:
            A dict with ``'weights'`` (float32, each weight divided by the
            largest possible weight), ``'indexes'`` (int32 buffer slots) and
            one entry per key of ``self.data`` holding the sampled rows.
        """
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32),
        }

        # The trees are not modified while sampling, so read the total once.
        total = self._sum()

        for i in range(batch_size):
            p = random.random() * total
            samples['indexes'][i] = self.find_prefix_sum_idx(p)

        # The smallest probability yields the largest weight; it is used to
        # scale every weight by the maximum.
        prob_min = self._min() / total
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = int(samples['indexes'][i])
            prob = self.priority_sum[idx + self.capacity] / total
            weight = (prob * self.size) ** (-beta)
            samples['weights'][i] = weight / max_weight

        # Gather the stored transitions for the sampled slots.
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        """Set new raw priorities for the given buffer slots.

        Args:
            indexes: Buffer slots to update.
            priorities: New raw priorities, one per entry of ``indexes``;
                ``alpha`` is applied here before storing.
        """
        for idx, priority in zip(indexes, priorities):
            # Track the running maximum for samples added later.
            self.max_priority = max(self.max_priority, priority)

            priority_alpha = priority ** self.alpha

            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        """Return whether every slot of the buffer holds a transition."""
        return self.capacity == self.size