Prioritized Experience Replay Buffer

This implements the paper Prioritized Experience Replay, using a binary segment tree.


import random
import numpy as np

Buffer for Prioritized Experience Replay

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the temporal-difference error (TD error), $\delta$.

We sample transition $i$ with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, where $\alpha$ is a hyper-parameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. $p_i$ is the priority.
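To make the role of $\alpha$ concrete, here is a small sketch that turns a vector of priorities $p_i$ into sampling probabilities $P(i)$. The function name sampling_probabilities is hypothetical; it only evaluates the formula above with NumPy.

```python
import numpy as np

def sampling_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha (hypothetical helper)."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()

priorities = np.array([1.0, 2.0, 4.0, 8.0])
uniform = sampling_probabilities(priorities, alpha=0.0)       # every entry 0.25
proportional = sampling_probabilities(priorities, alpha=1.0)  # 1/15, 2/15, 4/15, 8/15
partial = sampling_probabilities(priorities, alpha=0.6)       # flatter than proportional
```

Values of $\alpha$ between 0 and 1 interpolate between uniform and fully proportional sampling: the low-priority transitions gain probability mass and the high-priority ones lose some.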

We use proportional prioritization $p_i = |\delta_i| + \epsilon$, where $\delta_i$ is the temporal difference for transition $i$ and $\epsilon$ is a small positive constant.
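A minimal sketch of proportional prioritization, assuming the caller already has the TD errors of a sampled batch. The helper name and the value eps=1e-6 are illustrative assumptions, not taken from this implementation; the class below simply receives the resulting priorities through update_priorities.

```python
import numpy as np

def proportional_priorities(td_errors, eps=1e-6):
    # p_i = |delta_i| + eps; eps (an assumed value) keeps zero-error transitions sampleable
    return np.abs(np.asarray(td_errors, dtype=np.float64)) + eps

td_errors = np.array([0.5, -2.0, 0.0])
priorities = proportional_priorities(td_errors)
```

The sign of $\delta_i$ does not matter, only its magnitude, and the $\epsilon$ term guarantees that a transition with zero error can still be revisited later.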

We correct the bias introduced by prioritized replay using importance-sampling (IS) weights $w_i = \left(\frac{1}{N} \frac{1}{P(i)}\right)^\beta$ in the loss function. This fully compensates when $\beta = 1$. We normalize weights by $\frac{1}{\max_i w_i}$ for stability. Being unbiased matters most near convergence at the end of training, so we increase $\beta$ towards the end of training.
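The weight formula and the $\beta$ annealing can be sketched as follows. Both helper names are hypothetical, and the linear schedule from 0.4 to 1.0 is an assumed example rather than something this implementation prescribes. Note that importance_weights normalizes by the largest weight among the probabilities it is given, whereas the buffer's sample method normalizes by the global maximum derived from the minimum tree.

```python
import numpy as np

def importance_weights(probs, n, beta):
    """w_i = (1/N * 1/P(i))^beta, then divided by max_i w_i (hypothetical helper)."""
    probs = np.asarray(probs, dtype=np.float64)
    weights = (n * probs) ** (-beta)
    return weights / weights.max()

def linear_beta(step, total_steps, beta_start=0.4, beta_end=1.0):
    # Assumed linear schedule: move beta from beta_start to beta_end, then hold it
    frac = min(step / total_steps, 1.0)
    return beta_start + frac * (beta_end - beta_start)

probs = np.array([0.1, 0.2, 0.3, 0.4])
w = importance_weights(probs, n=4, beta=1.0)
```

With $\beta = 1$ the normalized weights are proportional to $1/P(i)$, so the least likely transition gets weight 1 and frequently sampled ones are scaled down.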

Binary Segment Tree

We use a binary segment tree to efficiently calculate $\sum_k^i p_k^\alpha$, the cumulative probability, which is needed to sample. We also use a binary segment tree to find $\min p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. We could also use a min-heap for this. Binary segment trees let us update and query these in $\mathcal{O}(\log n)$ time, which is far more efficient than the naive $\mathcal{O}(n)$ approach.
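For comparison, the naive approach rebuilds the cumulative sums on every draw, which costs $\mathcal{O}(n)$ per sample. The sketch below uses np.cumsum and np.searchsorted; the function names are hypothetical. It applies the same rule as the tree descent: return the first index whose cumulative sum exceeds the target.

```python
import numpy as np

def naive_prefix_sum_idx(priorities_alpha, prefix_sum):
    # O(n): recompute every cumulative sum on each call
    cumulative = np.cumsum(priorities_alpha)
    # First index whose cumulative sum is strictly greater than prefix_sum
    return int(np.searchsorted(cumulative, prefix_sum, side='right'))

def naive_sample(priorities_alpha, rng):
    # Draw a uniform point in [0, total) and locate it; O(n) per sample
    total = float(np.sum(priorities_alpha))
    return naive_prefix_sum_idx(priorities_alpha, rng.random() * total)

rng = np.random.default_rng(0)
idx = naive_sample([1.0, 2.0, 3.0, 4.0], rng)
```

The naive minimum, np.min over all priorities, is likewise $\mathcal{O}(n)$; the segment trees keep both answers at their roots and pay only $\mathcal{O}(\log n)$ per update.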

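This is how a binary segment tree works for sum; it is similar for minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j^{th}$ node of the $i^{th}$ row in the binary tree, counting rows from $1$ at the root and $j$ from $0$ within each row. That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j+1}$.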

The leaf nodes on row $D = \left\lceil 1 + \log_2 N \right\rceil$ will have values of $x$. Every node keeps the sum of its two child nodes. That is, the root node keeps the sum of the entire array of values. The left and right children of the root node keep the sum of the first half of the array and the sum of the second half of the array, respectively. And so on...
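A small sketch that builds the tree row by row, exactly as described: start from the leaves $x$ and repeatedly replace each pair of siblings with their sum until only the root is left. The function name is hypothetical, and it assumes len(values) is a power of 2.

```python
def build_sum_rows(values):
    """Rows of a sum segment tree, root row first (hypothetical helper).

    Assumes len(values) is a power of 2.
    """
    rows = [list(values)]
    while len(rows[0]) > 1:
        below = rows[0]
        # Parent b_{i,j} is the sum of its children b_{i+1,2j} and b_{i+1,2j+1}
        rows.insert(0, [below[2 * j] + below[2 * j + 1] for j in range(len(below) // 2)])
    return rows

rows = build_sum_rows([3, 1, 4, 1, 5, 9, 2, 6])
```

For these 8 values the tree has $D = 1 + \log_2 8 = 4$ rows; the root row is [31] and the next row is [9, 22], the sums of the two halves.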

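The number of nodes in row $i$ is $N_i = 2^{i - 1}$. This is one more than the total number of nodes in all rows above $i$, which is $2^{i-1} - 1$. So we can use a single 1-indexed array $a$ to store the tree, where
$$b_{i,j} \rightarrow a_{N_i + j}$$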

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$. That is,
$$a_i = a_{2i} + a_{2i + 1}$$
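The same tree stored in one flat, 1-indexed array: leaves occupy a[N:2N], every internal node satisfies $a_i = a_{2i} + a_{2i+1}$, and a[0] is unused. Both function names are hypothetical; update_leaf follows the same ancestor walk as _set_priority_sum in the class below.

```python
def build_flat_sum_tree(values):
    """Flat 1-indexed sum tree; a[0] is unused (hypothetical helper)."""
    n = len(values)
    a = [0] * (2 * n)
    a[n:] = values  # leaves
    # Fill internal nodes bottom-up so children are ready before their parent
    for i in range(n - 1, 0, -1):
        a[i] = a[2 * i] + a[2 * i + 1]
    return a

def update_leaf(a, n, j, value):
    # O(log n) point update: set leaf j, then recompute each ancestor up to the root
    i = j + n
    a[i] = value
    while i >= 2:
        i //= 2
        a[i] = a[2 * i] + a[2 * i + 1]

a = build_flat_sum_tree([3, 1, 4, 1, 5, 9, 2, 6])
```

Row $i$ starts at index $2^{i-1}$: the root is a[1], row 2 is a[2:4], row 3 is a[4:8], and the leaves are a[8:16].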

This way of maintaining binary trees is very easy to program. Note that we are indexing starting from 1.

We use the same structure to compute the minimum.
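For the minimum tree, only the combining operation and the padding value change: each node stores the minimum of its children, and leaves that do not hold a transition yet are $+\infty$ (the class initializes priority_min that way), so they never win. The sum tree pads with 0 instead, which adds nothing. The helper below is a hypothetical sketch.

```python
import math

def build_flat_min_tree(values, capacity):
    """Flat 1-indexed min tree with +inf padding (hypothetical helper).

    Assumes capacity is a power of 2 and len(values) <= capacity.
    """
    a = [math.inf] * (2 * capacity)
    # Filled leaves first; the remaining leaves stay at +inf
    a[capacity:capacity + len(values)] = values
    for i in range(capacity - 1, 0, -1):
        a[i] = min(a[2 * i], a[2 * i + 1])
    return a

a = build_flat_min_tree([0.5, 0.2, 0.9], capacity=4)
```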

class ReplayBuffer:


    def __init__(self, capacity, alpha):

We use a power of $2$ for capacity because it simplifies the code and debugging

        self.capacity = capacity

        self.alpha = alpha

Maintain binary segment trees to take the sum and find the minimum over a range

        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

Current max priority, $p$, to be assigned to new transitions

        self.max_priority = 1.

Arrays for buffer

        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            # np.bool was removed in NumPy 1.24; the builtin bool gives the same dtype
            'done': np.zeros(shape=capacity, dtype=bool)
        }

We use cyclic buffers to store data, and next_idx keeps the index of the next empty slot

        self.next_idx = 0

Size of the buffer

        self.size = 0

Add sample to queue

    def add(self, obs, action, reward, next_obs, done):

Get next available slot

        idx = self.next_idx

Store in the queue

        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done

Increment next available slot

        self.next_idx = (idx + 1) % self.capacity

Calculate the size

        self.size = min(self.capacity, self.size + 1)

New samples get max_priority; compute $p_i^\alpha$

        priority_alpha = self.max_priority ** self.alpha

Update the two segment trees for sum and minimum

        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

Set priority in binary segment tree for minimum

    def _set_priority_min(self, idx, priority_alpha):

Leaf of the binary tree

        idx += self.capacity
        self.priority_min[idx] = priority_alpha

Update tree, by traversing along ancestors. Continue until the root of the tree.

        while idx >= 2:

Get the index of the parent node

            idx //= 2

Value of the parent node is the minimum of its two children

            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

Set priority in binary segment tree for sum

    def _set_priority_sum(self, idx, priority):

Leaf of the binary tree

        idx += self.capacity

Set the priority at the leaf

        self.priority_sum[idx] = priority

Update tree, by traversing along ancestors. Continue until the root of the tree.

        while idx >= 2:

Get the index of the parent node

            idx //= 2

Value of the parent node is the sum of its two children

            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):

The root node keeps the sum of all values

        return self.priority_sum[1]

    def _min(self):

The root node keeps the minimum of all values

        return self.priority_min[1]

Find the largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$, where $P$ is prefix_sum

    def find_prefix_sum_idx(self, prefix_sum):

Start from the root

        idx = 1
        while idx < self.capacity:

If the sum of the left branch is higher than required sum

            if self.priority_sum[idx * 2] > prefix_sum:

Go to left branch of the tree

                idx = 2 * idx
            else:

Otherwise go to right branch and reduce the sum of left branch from required sum

                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

We are at a leaf node. Subtract the capacity from the index in the tree to get the index of the actual value

        return idx - self.capacity

Sample from buffer

    def sample(self, batch_size, beta):

Initialize samples

        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }

Get sample indexes

        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx

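Compute $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$

        prob_min = self._min() / self._sum()

Compute $\max_i w_i = \left(\frac{1}{N} \frac{1}{\min_i P(i)}\right)^\beta$

        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]

Compute $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$

            prob = self.priority_sum[idx + self.capacity] / self._sum()

Compute $w_i = \left(\frac{1}{N} \frac{1}{P(i)}\right)^\beta$

            weight = (prob * self.size) ** (-beta)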

Normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term

            samples['weights'][i] = weight / max_weight

Get samples data

        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

Update priorities

    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):

Set current max priority

            self.max_priority = max(self.max_priority, priority)

Compute $p_i^\alpha$

            priority_alpha = priority ** self.alpha

Update the trees

            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

Whether the buffer is full

    def is_full(self):
        return self.capacity == self.size
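As a sanity check of the sampling logic, the sketch below is a stripped-down, standalone sum tree that mirrors _set_priority_sum and find_prefix_sum_idx above (the class name SumTree and its method names are hypothetical, not part of this implementation). It draws many prefix sums uniformly from $[0, \sum_k p_k^\alpha)$ and counts how often each leaf is returned; the empirical frequencies should approach $P(i)$.

```python
import random

class SumTree:
    """Minimal sum segment tree (hypothetical); capacity must be a power of 2."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity)

    def set(self, idx, value):
        # Write the leaf, then recompute every ancestor up to the root
        idx += self.capacity
        self.tree[idx] = value
        while idx >= 2:
            idx //= 2
            self.tree[idx] = self.tree[2 * idx] + self.tree[2 * idx + 1]

    def total(self):
        return self.tree[1]

    def find_prefix_sum_idx(self, prefix_sum):
        # Descend from the root: go left if the left subtree covers prefix_sum,
        # otherwise subtract the left sum and go right
        idx = 1
        while idx < self.capacity:
            if self.tree[2 * idx] > prefix_sum:
                idx = 2 * idx
            else:
                prefix_sum -= self.tree[2 * idx]
                idx = 2 * idx + 1
        return idx - self.capacity

tree = SumTree(4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.set(i, p)

rng = random.Random(0)
n_draws = 100_000
counts = [0] * 4
for _ in range(n_draws):
    counts[tree.find_prefix_sum_idx(rng.random() * tree.total())] += 1
```

With priorities 1, 2, 3 and 4 the expected frequencies are 0.1, 0.2, 0.3 and 0.4, and each draw touches only $\log_2 4 = 2$ internal levels regardless of how the priorities change between draws.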