This implements the paper *Prioritized Experience Replay* using a binary segment tree.

```
import random

import numpy as np
```

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the temporal difference error (TD error), $\delta$.

We sample transition $i$ with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, where $\alpha$ is a hyperparameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case, and $p_i$ is the priority of transition $i$.

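We use proportional prioritization, $p_i = |\delta_i| + \epsilon$, where $\delta_i$ is the temporal difference error for transition $i$ and $\epsilon$ is a small positive constant that keeps every priority above zero.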
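As a minimal NumPy sketch of the two formulas above, the hypothetical helper `sampling_probabilities` turns a batch of TD errors into priorities $p_i = |\delta_i| + \epsilon$ and then into sampling probabilities $P(i)$. The default `eps=1e-6` is an assumption for illustration, not a value taken from this implementation.

```python
import numpy as np


def sampling_probabilities(td_errors, alpha, eps=1e-6):
    """Hypothetical helper: proportional prioritization followed by P(i)."""
    # Proportional priority p_i = |delta_i| + eps keeps every priority positive
    priorities = np.abs(np.asarray(td_errors, dtype=np.float64)) + eps
    # Raise to alpha; alpha = 0 makes every p_i^alpha equal to 1
    scaled = priorities ** alpha
    # Normalize so the probabilities sum to 1
    return scaled / scaled.sum()


td = [0.5, -2.0, 0.0, 1.0]
print(sampling_probabilities(td, alpha=0.6))
print(sampling_probabilities(td, alpha=0.0))  # alpha = 0 gives the uniform distribution
```

Larger $|\delta_i|$ gives a larger $P(i)$, and lowering $\alpha$ flattens the distribution towards uniform.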

We correct the bias introduced by prioritized replay using importance-sampling (IS) weights $w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$ in the loss function. This fully compensates for the non-uniform sampling when $\beta = 1$. We normalize the weights by $\frac{1}{\max_i w_i}$ for stability. Unbiased updates matter most near convergence at the end of training, so we increase $\beta$ towards $1$ over the course of training.
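A small sketch of the importance-sampling correction: given sampling probabilities `probs` for all $N$ stored transitions, the hypothetical function `importance_weights` computes $w_i$ for a sampled batch and divides by $\max_i w_i$, which belongs to the least likely transition. The linear schedule in `beta_by_step` and its `beta_start=0.4` default are assumptions for illustration; the text above only says that $\beta$ is increased during training.

```python
import numpy as np


def importance_weights(probs, batch_idx, beta):
    """Hypothetical helper: normalized IS weights for the sampled indexes."""
    probs = np.asarray(probs, dtype=np.float64)
    n = len(probs)
    # w_i = (1/N * 1/P(i))^beta = (N * P(i))^(-beta)
    weights = (n * probs[batch_idx]) ** (-beta)
    # The largest weight belongs to the least likely transition
    max_weight = (n * probs.min()) ** (-beta)
    return weights / max_weight


def beta_by_step(step, total_steps, beta_start=0.4):
    """Assumed linear schedule from beta_start up to 1.0."""
    fraction = min(step / total_steps, 1.0)
    return beta_start + fraction * (1.0 - beta_start)


probs = np.array([0.1, 0.2, 0.3, 0.4])
print(importance_weights(probs, [0, 3], beta=beta_by_step(0, 1000)))
```

After normalization every weight lies in $(0, 1]$, so the correction can only scale gradient updates down.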

We use a binary segment tree to efficiently calculate $\sum_k p_k^\alpha$ and its prefix sums, which are needed to sample. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. A min-heap could also be used for this. A binary segment tree lets us update these values and sample in $O(\log n)$ time, which is far more efficient than the naive $O(n)$ approach.
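For comparison, a naive $O(n)$ approach recomputes a cumulative sum over all priorities and searches it. The hypothetical `naive_sample` below draws a uniform number in $[0, \sum_k p_k^\alpha)$ and locates it with `np.searchsorted`; the segment tree replaces the full cumulative-sum rebuild with a single walk from the root to a leaf.

```python
import random

import numpy as np


def naive_sample(priority_alpha, batch_size):
    """O(n) per call: rebuild the cumulative sum, then search it."""
    cumulative = np.cumsum(priority_alpha)  # O(n) work on every call
    total = cumulative[-1]
    indexes = []
    for _ in range(batch_size):
        # Uniform draw in [0, total); the first cumulative value above it marks the sample
        u = random.random() * total
        indexes.append(int(np.searchsorted(cumulative, u, side='right')))
    return indexes


print(naive_sample([1.0, 2.0, 3.0, 4.0], batch_size=5))
```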

This is how a binary segment tree works for sum; it works the same way for minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j$-th node of the $i$-th row in the binary tree. The two children of node $b_{i,j}$ are $b_{i+1,2j-1}$ and $b_{i+1,2j}$.

The leaf nodes on row $D = \lceil 1 + \log_2 N \rceil$ hold the values of $x$. Every node keeps the sum of its two child nodes: the root node keeps the sum of the entire array of values, the left and right children of the root keep the sums of the first and second halves of the array, respectively, and so on.

$b_{i,j} = \sum_{k=(j-1) \cdot 2^{D-i}+1}^{j \cdot 2^{D-i}} x_k$
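To see the formula at work, the sketch below builds each row by summing adjacent pairs of the row beneath it, starting from the leaves, and checks every node against the direct sum over its range of $x$. It assumes $N$ is a power of $2$ and uses 1-indexed rows to mirror the notation; `build_rows` is a hypothetical helper.

```python
import math


def build_rows(x):
    """Return ({row: [b_{row,1}, ...]}, D) with row D holding the leaves."""
    n = len(x)
    depth = math.ceil(1 + math.log2(n))  # D = ceil(1 + log2 N)
    rows = {depth: list(x)}
    for i in range(depth - 1, 0, -1):
        below = rows[i + 1]
        # b_{i,j} = b_{i+1,2j-1} + b_{i+1,2j}, i.e. sums of adjacent pairs
        rows[i] = [below[2 * j] + below[2 * j + 1] for j in range(len(below) // 2)]
    return rows, depth


x = [3, 1, 4, 1, 5, 9, 2, 6]
rows, D = build_rows(x)
for i, row in rows.items():
    span = 2 ** (D - i)  # each node in row i covers 2^(D-i) leaves
    for j in range(1, len(row) + 1):
        # x_k for k = (j-1)*span+1 .. j*span is this 0-indexed Python slice
        assert row[j - 1] == sum(x[(j - 1) * span: j * span])
print(rows[1])  # [31], the root holds the sum of all values
```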

The number of nodes in row $i$ is $N_i = \left\lceil \frac{N}{2^{D-i}} \right\rceil$. When $N$ is a power of $2$, this is one more than the total number of nodes in all rows above $i$. So we can use a single array $a$ to store the tree, where $b_{i,j} \rightarrow a_{N_i + j - 1}$.

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i+1}$, so $a_i = a_{2i} + a_{2i+1}$.

This way of maintaining binary trees is very easy to program. *Note that we are indexing starting from 1*.

We use the same structure to compute the minimum.
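A minimal sketch of the flat-array layout, assuming $N$ is a power of $2$: slot `0` is unused so that the root sits at index `1`, the leaves occupy indexes `N` to `2N - 1`, and each parent is filled from `a[2i]` and `a[2i + 1]`. The same builder produces the minimum tree by swapping the combine function and using `inf` as the empty value, matching how `priority_min` is initialized below; `build_flat_tree` is a hypothetical name.

```python
def build_flat_tree(values, combine, identity):
    """1-indexed segment tree in a flat list; index 0 is padding."""
    n = len(values)
    assert n & (n - 1) == 0, "this sketch assumes a power-of-2 length"
    tree = [identity] * (2 * n)
    tree[n:] = values  # leaves a_N .. a_{2N-1}
    # Fill internal nodes bottom-up: a_i = combine(a_{2i}, a_{2i+1})
    for i in range(n - 1, 0, -1):
        tree[i] = combine(tree[2 * i], tree[2 * i + 1])
    return tree


values = [3, 1, 4, 1]
sum_tree = build_flat_tree(values, lambda a, b: a + b, 0)
min_tree = build_flat_tree(values, min, float('inf'))
print(sum_tree)  # [0, 9, 4, 5, 3, 1, 4, 1]
print(min_tree)  # [inf, 1, 1, 1, 3, 1, 4, 1]
```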

`class ReplayBuffer:`

`def __init__(self, capacity, alpha):`

We use a power of $2$ for the capacity because it simplifies the code and debugging.
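The class does not round `capacity` itself, so one option is to round a requested size up to a power of $2$ before constructing the buffer. The helper below is a hypothetical addition, not part of this implementation; it uses `int.bit_length` to find the next power of $2$.

```python
def next_power_of_two(n):
    """Smallest power of 2 that is >= n (hypothetical helper, n >= 1)."""
    if n < 1:
        raise ValueError("capacity must be positive")
    # (n - 1).bit_length() is the number of bits needed to represent n - 1
    return 1 << (n - 1).bit_length()


print(next_power_of_two(1000))  # 1024
print(next_power_of_two(1024))  # 1024
```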

`self.capacity = capacity`

$\alpha$, the prioritization exponent

`self.alpha = alpha`

Maintain binary segment trees to compute the sum and the minimum of priorities over a range

```
self.priority_sum = [0 for _ in range(2 * self.capacity)]
self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
```

Current max priority, $p$, to be assigned to new transitions

`self.max_priority = 1.`

Arrays for the buffer

```
self.data = {
    'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
    'action': np.zeros(shape=capacity, dtype=np.int32),
    'reward': np.zeros(shape=capacity, dtype=np.float32),
    'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
    'done': np.zeros(shape=capacity, dtype=bool)
}
```

We use cyclic buffers to store data, and `next_idx` keeps the index of the next empty slot.

`self.next_idx = 0`

Size of the buffer

`self.size = 0`

`def add(self, obs, action, reward, next_obs, done):`

Get the next available slot

`idx = self.next_idx`

Store the transition in the buffer

```
self.data['obs'][idx] = obs
self.data['action'][idx] = action
self.data['reward'][idx] = reward
self.data['next_obs'][idx] = next_obs
self.data['done'][idx] = done
```

Increment next available slot

`self.next_idx = (idx + 1) % self.capacity`

Calculate the size

`self.size = min(self.capacity, self.size + 1)`

$p_i^\alpha$; new samples get `max_priority`

`priority_alpha = self.max_priority ** self.alpha`

Update the two segment trees for sum and minimum

```
self._set_priority_min(idx, priority_alpha)
self._set_priority_sum(idx, priority_alpha)
```

`def _set_priority_min(self, idx, priority_alpha):`

Leaf of the binary tree

```
idx += self.capacity
self.priority_min[idx] = priority_alpha
```

Update the tree by traversing up through the ancestors, continuing until we reach the root.

`while idx >= 2:`

Get the index of the parent node

`idx //= 2`

Value of the parent node is the minimum of its two children

`self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])`

`def _set_priority_sum(self, idx, priority):`

Leaf of the binary tree

`idx += self.capacity`

Set the priority at the leaf

`self.priority_sum[idx] = priority`

Update the tree by traversing up through the ancestors, continuing until we reach the root.

`while idx >= 2:`

Get the index of the parent node

`idx //= 2`

Value of the parent node is the sum of its two children

`self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]`
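One way to sanity-check these two update loops is to compare the roots against a brute-force sum and minimum after random updates. The sketch below re-implements the same ancestor walk as a standalone function (hypothetical name `set_value`, same indexing as the class) so that it runs on its own; empty slots stay at `0` in the sum tree and `inf` in the minimum tree, as in `__init__`.

```python
import random


def set_value(tree, capacity, idx, value, combine):
    """Set leaf idx and recompute its ancestors, like _set_priority_sum/_min."""
    idx += capacity
    tree[idx] = value
    while idx >= 2:
        idx //= 2
        tree[idx] = combine(tree[2 * idx], tree[2 * idx + 1])


capacity = 8
sum_tree = [0.0] * (2 * capacity)
min_tree = [float('inf')] * (2 * capacity)
leaves = [None] * capacity  # None marks an empty slot

random.seed(0)
for _ in range(100):
    i = random.randrange(capacity)
    v = random.uniform(0.1, 5.0)
    leaves[i] = v
    set_value(sum_tree, capacity, i, v, lambda a, b: a + b)
    set_value(min_tree, capacity, i, v, min)
    filled = [x for x in leaves if x is not None]
    # The roots must match a brute-force O(n) scan over the filled slots
    assert abs(sum_tree[1] - sum(filled)) < 1e-9
    assert min_tree[1] == min(filled)
print("roots match the brute-force sum and minimum")
```

Each update touches only the $\log_2(\text{capacity})$ ancestors of one leaf, which is where the $O(\log n)$ cost comes from.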

`def _sum(self):`

The root node keeps the sum of all values

`return self.priority_sum[1]`

`def _min(self):`

The root node keeps the minimum of all values

`return self.priority_min[1]`

`def find_prefix_sum_idx(self, prefix_sum):`

Start from the root

```
idx = 1
while idx < self.capacity:
```

If the sum of the left branch is greater than the required prefix sum

`if self.priority_sum[idx * 2] > prefix_sum:`

Go to the left branch of the tree

```
    idx = 2 * idx
else:
```

Otherwise, go to the right branch and subtract the left branch's sum from the required prefix sum

```
    prefix_sum -= self.priority_sum[idx * 2]
    idx = 2 * idx + 1
```

We are at a leaf node. Subtract the capacity from the index in the tree to get the index of the actual value

`return idx - self.capacity`
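To make the descent concrete, here is the same loop as a standalone function (hypothetical signature taking the tree and capacity explicitly) run on the sum tree for leaf values `[1, 2, 3, 4]` with capacity `4`. With 0-based leaf index $i$ and leaf values $x_0, \dots, x_3$, leaf $i$ is returned for every prefix sum in $[\sum_{k<i} x_k, \sum_{k \le i} x_k)$, so a uniform draw in $[0, \text{total})$ selects leaf $i$ with probability proportional to its value.

```python
def find_prefix_sum_idx(tree, capacity, prefix_sum):
    """Walk from the root to the leaf whose cumulative range contains prefix_sum."""
    idx = 1
    while idx < capacity:
        if tree[2 * idx] > prefix_sum:
            # Target lies inside the left subtree
            idx = 2 * idx
        else:
            # Skip the left subtree's mass and continue on the right
            prefix_sum -= tree[2 * idx]
            idx = 2 * idx + 1
    return idx - capacity


# Sum tree for leaves [1, 2, 3, 4]: root 10, children 3 and 7
tree = [0, 10, 3, 7, 1, 2, 3, 4]
for p in [0.5, 1.0, 2.9, 3.0, 5.9, 6.0, 9.9]:
    print(p, find_prefix_sum_idx(tree, 4, p))
```

The leaf ranges here are $[0, 1)$, $[1, 3)$, $[3, 6)$ and $[6, 10)$, matching the widths $1, 2, 3, 4$.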

`def sample(self, batch_size, beta):`

Initialize samples

```
samples = {
    'weights': np.zeros(shape=batch_size, dtype=np.float32),
    'indexes': np.zeros(shape=batch_size, dtype=np.int32)
}
```

Get sample indexes

```
for i in range(batch_size):
    p = random.random() * self._sum()
    idx = self.find_prefix_sum_idx(p)
    samples['indexes'][i] = idx
```

$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$

`prob_min = self._min() / self._sum()`

$\max_i w_i = \left(\frac{1}{N} \cdot \frac{1}{\min_i P(i)}\right)^\beta$

```
max_weight = (prob_min * self.size) ** (-beta)

for i in range(batch_size):
    idx = samples['indexes'][i]
```

$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$

`prob = self.priority_sum[idx + self.capacity] / self._sum()`

$w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$

`weight = (prob * self.size) ** (-beta)`

Normalize by $\frac{1}{\max_i w_i}$, which also cancels out the $\frac{1}{N}$ term

`samples['weights'][i] = weight / max_weight`

Get samples data

```
for k, v in self.data.items():
    samples[k] = v[samples['indexes']]

return samples
```
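The per-sample weight loop in `sample` can also be written with NumPy indexing. The sketch below is a hypothetical vectorized variant: it reads $p_i^\alpha$ directly from the leaves of a flat sum tree (`priority_sum[capacity + idx]`, as `sample` does) and takes the minimum $p_i^\alpha$ as an argument, which in the class would come from `_min()`.

```python
import numpy as np


def batch_weights(priority_sum, capacity, indexes, min_priority_alpha, size, beta):
    """Hypothetical vectorized version of the weight loop in sample()."""
    tree = np.asarray(priority_sum, dtype=np.float64)
    total = tree[1]  # root of the sum tree
    # P(i) for each sampled index, read from the leaves
    prob = tree[capacity + np.asarray(indexes)] / total
    prob_min = min_priority_alpha / total
    max_weight = (prob_min * size) ** (-beta)
    # w_i / max_i w_i; the 1/N factors cancel in the ratio
    return ((prob * size) ** (-beta) / max_weight).astype(np.float32)


# Sum tree for p^alpha values [1, 2, 3, 4] with capacity 4
tree = [0, 10, 3, 7, 1, 2, 3, 4]
print(batch_weights(tree, 4, [0, 3], min_priority_alpha=1, size=4, beta=1.0))
```

With $\beta = 1$ the least likely transition gets weight $1$ and a transition four times as likely gets $\frac{1}{4}$.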

`def update_priorities(self, indexes, priorities):`

`for idx, priority in zip(indexes, priorities):`

Set current max priority

`self.max_priority = max(self.max_priority, priority)`

Calculate $p_i^\alpha$

`priority_alpha = priority ** self.alpha`

Update the trees

```
self._set_priority_min(idx, priority_alpha)
self._set_priority_sum(idx, priority_alpha)
```

`def is_full(self):`

`return self.capacity == self.size`