優先体験リプレイバッファ

これは、バイナリのセグメントツリーを使用して、紙の優先順位付けされたエクスペリエンスのリプレイを実装します

Open In Colab

15import random
16
17import numpy as np

エクスペリエンスの優先リプレイ用バッファ

エクスペリエンスの優先リプレイでは、重要なトランジションをより頻繁にサンプリングします。遷移は時差異誤差 (td エラー) によって優先順位付けされます

遷移を確率でサンプリングします。ここで、は、どの程度優先順位を付けるかを決定するハイパーパラメータで、同じケースに対応しています。が優先事項です。

時間的な差を変化させる場合は、比例的な優先順位付けを行います

優先リプレイによって生じるバイアスは、損失関数の重要度サンプリング(IS)の重みを使用して修正します。これにより、次の場合は完全に補正されます。安定性を考慮して重量を正規化しています。トレーニング終了時のコンバージェンスには、偏りのない性格が最も重要です。したがって、トレーニングの終わりに近づくにつれて増加します。

バイナリセグメントツリー

バイナリセグメントツリーを使用して、サンプリングに必要な累積確率を効率的に計算します。また、必要なバイナリセグメントツリーを使用して検索します。これにはミニヒープを使うこともできます。バイナリセグメントツリーでは、これらを時間内に計算できます。これは、単純なアプローチよりもはるかに効率的です

これがバイナリセグメントツリーの合計の仕組みで、最小値でも同様です。表現したい値のリストを見てみましょう。をバイナリツリーの行のノードとします。つまり、ノードとの 2 つの子ノードです。

行のリーフノードの値はになります。すべてのノードは 2 つの子ノードの合計を保持します。つまり、ルートノードは値の配列全体の合計を保持します。ルートノードの左と右の子は、それぞれ配列の前半の合計と配列の後半の合計を保持します。などなど...

行のノード数。これは上のすべての行のノードの合計と同じです。つまり、1 つの配列でツリーを格納できます。ここで、

その場合、の子ノードはとです。つまり、

バイナリツリーを管理するこの方法は、プログラムするのがとても簡単です。インデックスは 1 から始まっていることに注意してください

同じ構造を使用して最小値を計算します。

20class ReplayBuffer:

[初期化]

90    def __init__(self, capacity, alpha):

キャパシティについては、コードやデバッグを簡略化するため、のべき乗を使用しています。

95        self.capacity = capacity

97        self.alpha = alpha

セグメントの二分木を管理して、合計を取って範囲内の最小値を求める

100        self.priority_sum = [0 for _ in range(2 * self.capacity)]
101        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

現在の最大優先度 (新しいトランジションに割り当てられる)

104        self.max_priority = 1.

バッファ用の配列

107        self.data = {
108            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
109            'action': np.zeros(shape=capacity, dtype=np.int32),
110            'reward': np.zeros(shape=capacity, dtype=np.float32),
111            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
112            'done': np.zeros(shape=capacity, dtype=np.bool)
113        }

循環バッファを使用してデータを保存し、next_idx 次の空きスロットのインデックスを保持します

116        self.next_idx = 0

バッファのサイズ

119        self.size = 0

サンプルをキューに追加

121    def add(self, obs, action, reward, next_obs, done):

次の空き枠を取得

127        idx = self.next_idx

キューに保存

130        self.data['obs'][idx] = obs
131        self.data['action'][idx] = action
132        self.data['reward'][idx] = reward
133        self.data['next_obs'][idx] = next_obs
134        self.data['done'][idx] = done

次に空いているスロットを増やす

137        self.next_idx = (idx + 1) % self.capacity

サイズを計算

139        self.size = min(self.capacity, self.size + 1)

、新しいサンプルを入手 max_priority

142        priority_alpha = self.max_priority ** self.alpha

2 つのセグメントツリーの合計と最小値を更新

144        self._set_priority_min(idx, priority_alpha)
145        self._set_priority_sum(idx, priority_alpha)

バイナリセグメントツリーの優先度を最小に設定

147    def _set_priority_min(self, idx, priority_alpha):

バイナリーツリーの葉

153        idx += self.capacity
154        self.priority_min[idx] = priority_alpha

先祖をたどってツリーを更新します。木の根元まで続けてください。

158        while idx >= 2:

親ノードのインデックスを取得

160            idx //= 2

親ノードの値は、2 つの子ノードのうちの最小値です

162            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

合計のバイナリセグメントツリーの優先順位を設定

164    def _set_priority_sum(self, idx, priority):

バイナリーツリーの葉

170        idx += self.capacity

リーフに優先順位を設定

172        self.priority_sum[idx] = priority

先祖をたどってツリーを更新します。木の根元まで続けてください。

176        while idx >= 2:

親ノードのインデックスを取得

178            idx //= 2

親ノードの値は、2 つの子の合計です。

180            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

182    def _sum(self):

ルートノードはすべての値の合計を保持します。

188        return self.priority_sum[1]

190    def _min(self):

ルートノードはすべての値の最小値を保持します

196        return self.priority_min[1]

その中で一番大きいものを探してください

198    def find_prefix_sum_idx(self, prefix_sum):

ルートから始める

204        idx = 1
205        while idx < self.capacity:

左ブランチの合計が必要な合計よりも大きい場合

207            if self.priority_sum[idx * 2] > prefix_sum:

木の左の枝に移動

209                idx = 2 * idx
210            else:

それ以外の場合は、右の分岐に移動して、左の分岐の合計を必要な合計から減らします。

213                prefix_sum -= self.priority_sum[idx * 2]
214                idx = 2 * idx + 1

私たちはリーフノードにいます。ツリー内のインデックスで容量を引くと、実際の値のインデックスを得ることができます

218        return idx - self.capacity

バッファからのサンプル

220    def sample(self, batch_size, beta):

サンプルを初期化

226        samples = {
227            'weights': np.zeros(shape=batch_size, dtype=np.float32),
228            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
229        }

サンプルインデックスを取得

232        for i in range(batch_size):
233            p = random.random() * self._sum()
234            idx = self.find_prefix_sum_idx(p)
235            samples['indexes'][i] = idx

238        prob_min = self._min() / self._sum()

240        max_weight = (prob_min * self.size) ** (-beta)
241
242        for i in range(batch_size):
243            idx = samples['indexes'][i]

245            prob = self.priority_sum[idx + self.capacity] / self._sum()

247            weight = (prob * self.size) ** (-beta)

次の条件で正規化すると、項も相殺されます

250            samples['weights'][i] = weight / max_weight

サンプルデータを取得

253        for k, v in self.data.items():
254            samples[k] = v[samples['indexes']]
255
256        return samples

優先度を更新

258    def update_priorities(self, indexes, priorities):
263        for idx, priority in zip(indexes, priorities):

現在の最大優先度を設定

265            self.max_priority = max(self.max_priority, priority)

計算

268            priority_alpha = priority ** self.alpha

ツリーを更新

270            self._set_priority_min(idx, priority_alpha)
271            self._set_priority_sum(idx, priority_alpha)

バッファがいっぱいかどうか

273    def is_full(self):
277        return self.capacity == self.size