近接ポリシー最適化-PPO

これは近接ポリシー最適化(PPO)のPyTorch実装です

PPOは強化学習のポリシーグラデーション法です。シンプルなポリシーグラデーションメソッドでは、サンプル (またはサンプルセット) ごとに 1 回のグラデーション更新を行います。1つのサンプルに対して複数のグラデーションステップを実行すると、ポリシーの偏差が大きすぎて不適切なポリシーになるため、問題が発生します。PPO では、ポリシーをデータのサンプリングに使用したポリシーに近い状態に保つことで、サンプルごとに複数のグラデーション更新を行うことができます。更新されたポリシーがデータのサンプリングに使用されたポリシーに合わない場合は、グラデーションフローをクリッピングして更新します

これを使った実験はこちらからご覧いただけます。この実験では、一般化アドバンテージ推定を使用しています

Open In Colab

28import torch
29
30from labml_helpers.module import Module
31from labml_nn.rl.ppo.gae import GAE

PPO ロス

PPO 更新ルールは次の方法で導き出されます。

ここで、が報酬、がポリシー、がポリシーからサンプリングされた軌跡、そしてその間の割引係数で、ポリシーの報酬を最大化したいと考えています。

だから、

割引後の将来の状態分布を定義し、

次に、

からの重要度サンプリング

そうすると、似たようなものだと仮定します。この仮定によって生じる誤差は、との間の KL の相違によって決まります。制約付きポリシー最適化はその証拠です。まだ読んでないよ。

34class ClippedPPOLoss(Module):
136    def __init__(self):
137        super().__init__()
139    def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
140                advantage: torch.Tensor, clip: float) -> torch.Tensor:

比率これは報酬とは異なります

143        ratio = torch.exp(log_pi - sampled_log_pi)

ポリシー比率のクリッピング

比率は 1 に近づくようにクリッピングされます。比率がとの間でない場合にのみ勾配が傾くように最小化しています。これにより、との間の KL の相違が抑えられます。大きな偏差があると、ポリシーのパフォーマンスが低下し、不適切なポリシーからサンプリングしているためにポリシーのパフォーマンスが低下し、回復しない場合があります。

正規化されたアドバンテージを使用すると、ポリシー勾配推定量に偏りが生じますが、分散は大幅に減少します。

172        clipped_ratio = ratio.clamp(min=1.0 - clip,
173                                    max=1.0 + clip)
174        policy_reward = torch.min(ratio * advantage,
175                                  clipped_ratio * advantage)
176
177        self.clip_fraction = (abs((ratio - 1.0)) > clip).to(torch.float).mean()
178
179        return -policy_reward.mean()

クリッピングバリュー関数の損失

同様に、値関数の更新もクリップします。

クリッピングにより、値関数が大きくずれないようにします。

182class ClippedValueFunctionLoss(Module):
204    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
205        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
206        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
207        return 0.5 * vf_loss.mean()