This implements the paper Prioritized Experience Replay, using a binary segment tree.
import random

import numpy as np
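Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the temporal difference error (TD error), $\delta$.
We sample transition $i$ with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, where $\alpha$ is a hyper-parameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. $p_i$ is the priority of transition $i$.
We use proportional prioritization, $p_i = |\delta_i| + \epsilon$, where $\delta_i$ is the temporal difference error for transition $i$ and $\epsilon$ is a small positive constant that keeps every priority non-zero.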
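To make the role of the exponent concrete, here is a minimal sketch of the sampling distribution described above. The helper name `sampling_probabilities` and the example priorities are assumptions made for illustration; they are not part of the buffer implementation.

```python
import numpy as np


def sampling_probabilities(priorities, alpha):
    """Return P(i) = p_i^alpha / sum_k p_k^alpha for an array of priorities."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()


# Hypothetical priorities p_i for four transitions
priorities = [0.5, 2.0, 0.1, 1.0]

# alpha = 0 ignores the priorities, so every transition is equally likely
uniform = sampling_probabilities(priorities, alpha=0.0)
# alpha = 1 samples in direct proportion to the priorities
proportional = sampling_probabilities(priorities, alpha=1.0)
# An intermediate alpha lies between the two extremes
softened = sampling_probabilities(priorities, alpha=0.6)
```

With these values the highest-priority transition is sampled more often than under uniform sampling, but less often than under fully proportional sampling, when an intermediate exponent is used.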
We correct the bias introduced by prioritized replay using importance-sampling (IS) weights $w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$ in the loss function, where $N$ is the number of stored transitions. This fully compensates when $\beta = 1$. We normalize weights by $\frac{1}{\max_i w_i}$ for stability. Unbiased updates matter most near convergence at the end of training, so we increase $\beta$ towards the end of training.
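The weighting and the annealing of the exponent can be sketched as follows. Both function names, the linear shape of the schedule, and the starting value of 0.4 are assumptions chosen for illustration; the text above only says that the exponent is increased towards the end of training.

```python
import numpy as np


def importance_weights(probs, beta):
    """w_i = (1/N * 1/P(i))^beta, divided by max_i w_i so the largest weight is 1."""
    probs = np.asarray(probs, dtype=np.float64)
    n = len(probs)
    weights = (n * probs) ** (-beta)
    return weights / weights.max()


def linear_beta(step, total_steps, beta_start=0.4, beta_end=1.0):
    """Hypothetical schedule: move beta linearly from beta_start to beta_end."""
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return beta_start + fraction * (beta_end - beta_start)


# Hypothetical sampling probabilities P(i) for four transitions
probs = np.array([0.1, 0.2, 0.3, 0.4])

# At the end of training beta reaches 1 and the correction is complete
final_weights = importance_weights(probs, beta=linear_beta(1000, 1000))
# With beta = 0 there is no correction and every weight is 1
no_correction = importance_weights(probs, beta=0.0)
```

Transitions that are sampled more often receive smaller weights, so their larger sampling frequency does not dominate the expected gradient.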
We use a binary segment tree to efficiently calculate $\sum_k^i p_k^\alpha$, the cumulative (unnormalized) probability, which is needed to sample. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. We could also use a min-heap for this. A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$ time, which is far more efficient than the naive $\mathcal{O}(n)$ approach.
This is how a binary segment tree works for sum; it is similar for minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j^{\text{th}}$ node of the $i^{\text{th}}$ row in the binary tree, with rows numbered from $1$ and nodes within a row numbered from $0$. That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j+1}$.
The leaf nodes on row $D = \left\lceil 1 + \log_2 N \right\rceil$ will have the values of $x$. Every node keeps the sum of its two child nodes. That is, the root node keeps the sum of the entire array of values. The left and right children of the root node keep the sum of the first half of the array and the sum of the second half of the array, respectively. And so on.
The number of nodes in row $i$ is $N_i = 2^{i-1}$. This is one more than the total number of nodes in all rows above $i$. So we can use a single array $a$ to store the tree, where $b_{i,j} \rightarrow a_{N_i + j}$.
Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i+1}$. That is, $a_i = a_{2i} + a_{2i+1}$.
This way of maintaining binary trees is very easy to program. Note that we are indexing starting from 1.
We use the same structure to compute the minimum.
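The array layout described above can be sketched on its own, outside the buffer. The function names `build_sum_tree` and `update` are assumptions for this example, and the number of values is assumed to be a power of 2. Index 0 of the array is unused because the tree is indexed from 1.

```python
def build_sum_tree(values):
    """Build a 1-indexed sum tree in a flat list; len(values) must be a power of 2."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "number of values must be a power of 2"
    tree = [0.0] * (2 * n)
    # Leaves occupy indexes n .. 2n - 1
    tree[n:] = list(values)
    # Fill internal nodes bottom-up: each node is the sum of its two children
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def update(tree, index, value):
    """Set leaf `index` to `value` and refresh the sums of all its ancestors."""
    n = len(tree) // 2
    i = index + n
    tree[i] = value
    while i >= 2:
        i //= 2
        tree[i] = tree[2 * i] + tree[2 * i + 1]


tree = build_sum_tree([1.0, 2.0, 3.0, 4.0])
initial = list(tree)  # [0.0, 10.0, 3.0, 7.0, 1.0, 2.0, 3.0, 4.0]
update(tree, 2, 5.0)  # only the leaf, its parent and the root change
```

An update touches one node per row, which is where the logarithmic cost comes from. Replacing the addition with `min` and the initial value with infinity gives the tree for the minimum.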
class ReplayBuffer:
    def __init__(self, capacity, alpha):
We use a power of $2$ for capacity because it simplifies the code and debugging
        self.capacity = capacity
        self.alpha = alpha
Maintain binary segment trees to take the sum and find the minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
Current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.
Arrays for buffer
        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=np.bool_)
        }
We use a cyclic buffer to store data, and next_idx keeps the index of the next slot to write to (the oldest entry once the buffer is full)
        self.next_idx = 0
Size of the buffer
        self.size = 0
    def add(self, obs, action, reward, next_obs, done):
Get next available slot
        idx = self.next_idx
Store in the queue
        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done
Increment next available slot, wrapping around at capacity
        self.next_idx = (idx + 1) % self.capacity
Calculate the size
        self.size = min(self.capacity, self.size + 1)
$p_i^\alpha$, new samples get max_priority
        priority_alpha = self.max_priority ** self.alpha
Update the two segment trees for sum and minimum
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)
    def _set_priority_min(self, idx, priority_alpha):
Leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha
Update tree, by traversing along ancestors. Continue until the root of the tree.
        while idx >= 2:
Get the index of the parent node
            idx //= 2
Value of the parent node is the minimum of its two children
            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
    def _set_priority_sum(self, idx, priority):
Leaf of the binary tree
        idx += self.capacity
Set the priority at the leaf
        self.priority_sum[idx] = priority
Update tree, by traversing along ancestors. Continue until the root of the tree.
        while idx >= 2:
Get the index of the parent node
            idx //= 2
Value of the parent node is the sum of its two children
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
    def _sum(self):
The root node keeps the sum of all values
        return self.priority_sum[1]
    def _min(self):
The root node keeps the minimum of all values
        return self.priority_min[1]
    def find_prefix_sum_idx(self, prefix_sum):
Start from the root
        idx = 1
        while idx < self.capacity:
If the sum of the left branch is higher than the required sum
            if self.priority_sum[idx * 2] > prefix_sum:
Go to the left branch of the tree
                idx = 2 * idx
            else:
Otherwise go to the right branch and subtract the sum of the left branch from the required sum
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1
We are at the leaf node. Subtract the capacity from the index in the tree to get the index of the actual value
        return idx - self.capacity
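The descent can be traced on a small tree. This is a standalone sketch that mirrors the method above on a plain list; the free function and the example tree are assumptions for illustration. The leaves hold the values 1, 2, 3 and 4, so the cumulative ranges are [0, 1), [1, 3), [3, 6) and [6, 10).

```python
def find_prefix_sum_idx(tree, prefix_sum):
    """Return the leaf whose cumulative range contains prefix_sum (1-indexed flat tree)."""
    n = len(tree) // 2
    idx = 1
    while idx < n:
        left = tree[2 * idx]
        if left > prefix_sum:
            # The target lies inside the left subtree
            idx = 2 * idx
        else:
            # Skip the left subtree and search the remainder in the right one
            prefix_sum -= left
            idx = 2 * idx + 1
    return idx - n


# Sum tree for the values [1, 2, 3, 4]; index 0 is unused
tree = [0.0, 10.0, 3.0, 7.0, 1.0, 2.0, 3.0, 4.0]

found = [find_prefix_sum_idx(tree, p) for p in (0.5, 1.0, 5.9, 6.0)]
```

Here `found` is `[0, 1, 2, 3]`. A uniform draw from [0, 10) therefore selects each leaf with probability proportional to its value.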
    def sample(self, batch_size, beta):
Initialize samples
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }
Get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx
$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
        prob_min = self._min() / self._sum()
$\max_i w_i = \left(\frac{1}{N} \cdot \frac{1}{\min_i P(i)}\right)^\beta$
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
            prob = self.priority_sum[idx + self.capacity] / self._sum()
$w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$
            weight = (prob * self.size) ** (-beta)
Normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight
Get samples data
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples
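The cancellation of the buffer size in the normalized weight can be written out explicitly. This is a short derivation from the two expressions computed in the method above, with $N$ standing for `self.size`:

```latex
\frac{w_i}{\max_j w_j}
  = \frac{\left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^{\beta}}
         {\left(\frac{1}{N} \cdot \frac{1}{\min_j P(j)}\right)^{\beta}}
  = \left(\frac{\min_j P(j)}{P(i)}\right)^{\beta}
  = \left(\frac{\min_j p_j^{\alpha}}{p_i^{\alpha}}\right)^{\beta}
```

The normalized weight therefore depends only on the ratio of the smallest stored priority to the sampled one, and it always lies in $(0, 1]$.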
    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):
Set current max priority
            self.max_priority = max(self.max_priority, priority)
Calculate $p_i^\alpha$
            priority_alpha = priority ** self.alpha
Update the trees
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)
    def is_full(self):
        return self.capacity == self.size
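The priorities passed to `update_priorities` would typically be derived from the TD errors of the sampled batch. A minimal sketch, assuming proportional prioritization with a small constant offset; the helper name `priorities_from_td_errors` and the offset value are assumptions, not part of the buffer:

```python
import numpy as np


def priorities_from_td_errors(td_errors, eps=1e-6):
    """Hypothetical helper: p_i = |delta_i| + eps, so no priority is exactly zero."""
    return np.abs(np.asarray(td_errors, dtype=np.float64)) + eps


# Hypothetical TD errors for a batch of three sampled transitions
td_errors = np.array([0.3, -1.2, 0.0])
new_priorities = priorities_from_td_errors(td_errors)

# With a buffer and a sampled batch at hand, the call would look like:
# buffer.update_priorities(samples['indexes'], new_priorities)
```

Keeping every priority strictly positive matters here: a zero leaf in the sum tree can never be selected by the prefix-sum search, and a zero minimum would make the weight normalization divide by zero.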