15import random
16
17import numpy as np
エクスペリエンスの優先リプレイでは、重要なトランジションをより頻繁にサンプリングします。遷移は時差異誤差 (td エラー) によって優先順位付けされます
。遷移を確率でサンプリングします。ここで、は、どの程度優先順位を付けるかを決定するハイパーパラメータで、同じケースに対応しています。が優先事項です。
時間的な差を変化させる場合は、比例的な優先順位付けを行います。
優先リプレイによって生じるバイアスは、損失関数の重要度サンプリング(IS)の重みを使用して修正します。これにより、次の場合は完全に補正されます。安定性を考慮して重量を正規化しています。トレーニング終了時のコンバージェンスには、偏りのない性格が最も重要です。したがって、トレーニングの終わりに近づくにつれて増加します。
バイナリセグメントツリーを使用して、サンプリングに必要な累積確率を効率的に計算します。また、必要なバイナリセグメントツリーを使用して検索します。これにはミニヒープを使うこともできます。バイナリセグメントツリーでは、これらを時間内に計算できます。これは、単純なアプローチよりもはるかに効率的です
。これがバイナリセグメントツリーの合計の仕組みで、最小値でも同様です。表現したい値のリストを見てみましょう。をバイナリツリーの行のノードとします。つまり、ノードとの 2 つの子ノードです。
行のリーフノードの値はになります。すべてのノードは 2 つの子ノードの合計を保持します。つまり、ルートノードは値の配列全体の合計を保持します。ルートノードの左と右の子は、それぞれ配列の前半の合計と配列の後半の合計を保持します。などなど...
行のノード数。これは上のすべての行のノードの合計と同じです。つまり、1 つの配列でツリーを格納できます。ここで、
その場合、の子ノードはとです。つまり、
バイナリツリーを管理するこの方法は、プログラムするのがとても簡単です。インデックスは 1 から始まっていることに注意してください
。同じ構造を使用して最小値を計算します。
20class ReplayBuffer:
90 def __init__(self, capacity, alpha):
キャパシティについては、コードやデバッグを簡略化するため、のべき乗を使用しています。
95 self.capacity = capacity
97 self.alpha = alpha
セグメントの二分木を管理して、合計を取って範囲内の最小値を求める
100 self.priority_sum = [0 for _ in range(2 * self.capacity)]
101 self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
現在の最大優先度 (新しいトランジションに割り当てられる)
104 self.max_priority = 1.
バッファ用の配列
107 self.data = {
108 'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
109 'action': np.zeros(shape=capacity, dtype=np.int32),
110 'reward': np.zeros(shape=capacity, dtype=np.float32),
111 'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
112 'done': np.zeros(shape=capacity, dtype=np.bool)
113 }
循環バッファを使用してデータを保存し、next_idx
次の空きスロットのインデックスを保持します
116 self.next_idx = 0
バッファのサイズ
119 self.size = 0
121 def add(self, obs, action, reward, next_obs, done):
次の空き枠を取得
127 idx = self.next_idx
キューに保存
130 self.data['obs'][idx] = obs
131 self.data['action'][idx] = action
132 self.data['reward'][idx] = reward
133 self.data['next_obs'][idx] = next_obs
134 self.data['done'][idx] = done
次に空いているスロットを増やす
137 self.next_idx = (idx + 1) % self.capacity
サイズを計算
139 self.size = min(self.capacity, self.size + 1)
、新しいサンプルを入手 max_priority
142 priority_alpha = self.max_priority ** self.alpha
2 つのセグメントツリーの合計と最小値を更新
144 self._set_priority_min(idx, priority_alpha)
145 self._set_priority_sum(idx, priority_alpha)
147 def _set_priority_min(self, idx, priority_alpha):
バイナリーツリーの葉
153 idx += self.capacity
154 self.priority_min[idx] = priority_alpha
先祖をたどってツリーを更新します。木の根元まで続けてください。
158 while idx >= 2:
親ノードのインデックスを取得
160 idx //= 2
親ノードの値は、2 つの子ノードのうちの最小値です
162 self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
164 def _set_priority_sum(self, idx, priority):
バイナリーツリーの葉
170 idx += self.capacity
リーフに優先順位を設定
172 self.priority_sum[idx] = priority
先祖をたどってツリーを更新します。木の根元まで続けてください。
176 while idx >= 2:
親ノードのインデックスを取得
178 idx //= 2
親ノードの値は、2 つの子の合計です。
180 self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
182 def _sum(self):
ルートノードはすべての値の合計を保持します。
188 return self.priority_sum[1]
190 def _min(self):
ルートノードはすべての値の最小値を保持します
196 return self.priority_min[1]
198 def find_prefix_sum_idx(self, prefix_sum):
ルートから始める
204 idx = 1
205 while idx < self.capacity:
左ブランチの合計が必要な合計よりも大きい場合
207 if self.priority_sum[idx * 2] > prefix_sum:
木の左の枝に移動
209 idx = 2 * idx
210 else:
それ以外の場合は、右の分岐に移動して、左の分岐の合計を必要な合計から減らします。
213 prefix_sum -= self.priority_sum[idx * 2]
214 idx = 2 * idx + 1
私たちはリーフノードにいます。ツリー内のインデックスで容量を引くと、実際の値のインデックスを得ることができます
218 return idx - self.capacity
220 def sample(self, batch_size, beta):
サンプルを初期化
226 samples = {
227 'weights': np.zeros(shape=batch_size, dtype=np.float32),
228 'indexes': np.zeros(shape=batch_size, dtype=np.int32)
229 }
サンプルインデックスを取得
232 for i in range(batch_size):
233 p = random.random() * self._sum()
234 idx = self.find_prefix_sum_idx(p)
235 samples['indexes'][i] = idx
238 prob_min = self._min() / self._sum()
240 max_weight = (prob_min * self.size) ** (-beta)
241
242 for i in range(batch_size):
243 idx = samples['indexes'][i]
245 prob = self.priority_sum[idx + self.capacity] / self._sum()
247 weight = (prob * self.size) ** (-beta)
次の条件で正規化すると、項も相殺されます
250 samples['weights'][i] = weight / max_weight
サンプルデータを取得
253 for k, v in self.data.items():
254 samples[k] = v[samples['indexes']]
255
256 return samples
258 def update_priorities(self, indexes, priorities):
263 for idx, priority in zip(indexes, priorities):
現在の最大優先度を設定
265 self.max_priority = max(self.max_priority, priority)
計算
268 priority_alpha = priority ** self.alpha
ツリーを更新
270 self._set_priority_min(idx, priority_alpha)
271 self._set_priority_sum(idx, priority_alpha)
273 def is_full(self):
277 return self.capacity == self.size