This is a PyTorch implementation of Proximal Policy Optimization (PPO).

PPO is a policy gradient method for reinforcement learning. Simple policy gradient methods do a single gradient update per sample (or batch of samples). Taking multiple gradient steps on the same sample causes problems because the policy deviates too far, producing a bad policy. PPO lets us do multiple gradient updates per sample by trying to keep the policy close to the policy that was used to sample the data. It does so by clipping gradient flow if the updated policy is not close to the policy used to sample the data.

You can find an experiment that uses it here. The experiment uses Generalized Advantage Estimation.

```
import torch

from labml_helpers.module import Module
from labml_nn.rl.ppo.gae import GAE
```

Here's how the PPO update rule is derived.

We want to maximize the policy reward $\max_\theta J(\pi_\theta)$, where $J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^\infty \gamma^t r_t\right]$, $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled from the policy, and $\gamma$ is the discount factor in $[0, 1]$.
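For a single sampled trajectory, the quantity inside the expectation is the discounted return. The following is a minimal framework-free sketch (the function name `discounted_return` is hypothetical, not part of this implementation), assuming a finite episode so rewards after termination are zero. $J(\pi_\theta)$ would be estimated by averaging this value over trajectories sampled from $\pi_\theta$.

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """Return sum_t gamma**t * r_t for one finite trajectory.

    Hypothetical helper for illustration; rewards after termination are taken as 0.
    """
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must be in [0, 1]")
    total = 0.0
    # Accumulate backwards: G_t = r_t + gamma * G_{t+1}, with G_T = 0
    for r in reversed(rewards):
        total = r + gamma * total
    return total


print(discounted_return([1.0, 1.0, 1.0], 0.5))  # 1 + 0.5 + 0.25 = 1.75
```

Accumulating from the end avoids computing `gamma ** t` explicitly and gives every intermediate $G_t$ for free if you need them.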

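$$
\begin{aligned}
\mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)\Big]
&= \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^\infty \gamma^t \big(Q^{\pi_{OLD}}(s_t, a_t) - V^{\pi_{OLD}}(s_t)\big)\Big] \\
&= \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^\infty \gamma^t \big(r_t + \gamma V^{\pi_{OLD}}(s_{t+1}) - V^{\pi_{OLD}}(s_t)\big)\Big] \\
&= \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^\infty \gamma^t r_t\Big] - \mathbb{E}_{\tau \sim \pi_\theta}\Big[V^{\pi_{OLD}}(s_0)\Big] \\
&= J(\pi_\theta) - J(\pi_{\theta_{OLD}})
\end{aligned}
$$

The second step uses $Q^{\pi_{OLD}}(s_t, a_t) = \mathbb{E}\big[r_t + \gamma V^{\pi_{OLD}}(s_{t+1})\big]$. The third step telescopes, because $\sum_t \gamma^t \big(\gamma V^{\pi_{OLD}}(s_{t+1}) - V^{\pi_{OLD}}(s_t)\big) = -V^{\pi_{OLD}}(s_0)$. The last step holds because $s_0$ is drawn from the same start-state distribution under both policies, so $\mathbb{E}\big[V^{\pi_{OLD}}(s_0)\big] = J(\pi_{\theta_{OLD}})$.

So, $\arg\max_\theta J(\pi_\theta) = \arg\max_\theta \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)\Big]$, since $J(\pi_{\theta_{OLD}})$ does not depend on $\theta$.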
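The telescoping identity $\sum_t \gamma^t \big(r_t + \gamma V(s_{t+1}) - V(s_t)\big) = \sum_t \gamma^t r_t - V(s_0)$ holds exactly for any fixed trajectory and any value function $V$. Only replacing $Q$ by $r_t + \gamma V(s_{t+1})$ needs the expectation. The following sketch checks the identity numerically under stated assumptions: a finite episode ending in a terminal state with value 0, and arbitrary random numbers standing in for $V^{\pi_{OLD}}$. The helper name `telescoping_sides` is hypothetical.

```python
import math
import random


def telescoping_sides(rewards, values, gamma):
    """Return (lhs, rhs) of the telescoping identity for one finite trajectory.

    lhs = sum_t gamma^t * (r_t + gamma * V(s_{t+1}) - V(s_t))
    rhs = sum_t gamma^t * r_t - V(s_0)
    values[t] stands in for V(s_t); the state after the last step is terminal, so V = 0.
    """
    v = list(values) + [0.0]  # append V(s_T) = 0 for the terminal state
    T = len(rewards)
    lhs = sum(gamma**t * (rewards[t] + gamma * v[t + 1] - v[t]) for t in range(T))
    rhs = sum(gamma**t * rewards[t] for t in range(T)) - v[0]
    return lhs, rhs


rng = random.Random(0)
rewards = [rng.uniform(-1, 1) for _ in range(20)]
values = [rng.uniform(-5, 5) for _ in range(20)]  # arbitrary, not a learned critic
lhs, rhs = telescoping_sides(rewards, values, gamma=0.99)
print(math.isclose(lhs, rhs, abs_tol=1e-9))  # True
```

Note that the discount $\gamma$ multiplying $V(s_{t+1})$ is what makes the intermediate value terms cancel. Without it, the sum does not telescope.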

Define the discounted future state distribution, $d^\pi(s) = (1-\gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s \mid \pi)$.
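Because the weights $(1-\gamma)\gamma^t$ sum to 1, $d^\pi$ is a proper probability distribution over states. The following sketch makes some assumptions. It uses a small finite MDP whose dynamics under a fixed policy $\pi$ are already folded into a state-to-state transition matrix `P` with start distribution `mu`, and both are made up for this example. It also truncates the infinite sum at a finite horizon. The function name is hypothetical.

```python
from typing import List


def discounted_state_distribution(
    P: List[List[float]], mu: List[float], gamma: float, horizon: int = 1000
) -> List[float]:
    """d(s) = (1 - gamma) * sum_{t < horizon} gamma^t * P(s_t = s | pi).

    P[i][j] is the probability of moving from state i to state j under the policy;
    mu is the start-state distribution. Truncation drops mass gamma**horizon.
    """
    if not 0.0 <= gamma < 1.0:
        raise ValueError("gamma must be in [0, 1) for d to be normalized")
    n = len(mu)
    d = [0.0] * n
    p_t = list(mu)        # P(s_t = s | pi), starting at t = 0
    weight = 1.0 - gamma  # (1 - gamma) * gamma**t
    for _ in range(horizon):
        for s in range(n):
            d[s] += weight * p_t[s]
        # One step of the Markov chain: p_{t+1}[j] = sum_i p_t[i] * P[i][j]
        p_t = [sum(p_t[i] * P[i][j] for i in range(n)) for j in range(n)]
        weight *= gamma
    return d


P = [[0.9, 0.1], [0.5, 0.5]]
mu = [1.0, 0.0]
d = discounted_state_distribution(P, mu, gamma=0.9)
print(sum(d))  # very close to 1
```

For this toy chain, the exact answer is $d = (55/64,\ 9/64)$, the solution of $d = (1-\gamma)\mu + \gamma\, d P$. With $\gamma = 0$, $d^\pi$ reduces to the start distribution.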

Then,

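$$
\begin{aligned}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)\Big] \\
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta}\Big[A^{\pi_{OLD}}(s, a)\Big]
\end{aligned}
$$

Importance sampling $a$ from $\pi_{\theta_{OLD}}$,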

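$$
\begin{aligned}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta}\Big[A^{\pi_{OLD}}(s, a)\Big] \\
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_{\theta_{OLD}}}\Big[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{OLD}}(s, a)\Big]
\end{aligned}
$$

Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar. The error this assumption introduces into $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$ is bounded in terms of the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$. Constrained Policy Optimization shows the proof of this. I haven't read it.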
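Here is a small numeric check of the importance-sampling step for a single state with discrete actions. The probabilities and advantages below are made up, and `pi_old` must be non-zero wherever `pi` is non-zero. The exact reweighted expectation matches the on-policy one. A Monte Carlo estimate that only draws actions from $\pi_{\theta_{OLD}}$ converges to the same value. The function names are hypothetical.

```python
import random

pi = [0.7, 0.2, 0.1]      # pi_theta(a|s), the policy being optimized
pi_old = [0.4, 0.4, 0.2]  # pi_theta_OLD(a|s), the policy that generated the data
adv = [1.0, -0.5, 2.0]    # A^{pi_OLD}(s, a), made-up values


def on_policy_expectation(pi, adv):
    """E_{a ~ pi}[A(s, a)]"""
    return sum(p * a for p, a in zip(pi, adv))


def importance_sampled_expectation(pi, pi_old, adv):
    """E_{a ~ pi_old}[pi(a|s) / pi_old(a|s) * A(s, a)], computed exactly."""
    return sum(q * (p / q) * a for p, q, a in zip(pi, pi_old, adv))


def monte_carlo_estimate(pi, pi_old, adv, n, rng):
    """Same expectation, estimated from n actions sampled from pi_old only."""
    actions = rng.choices(range(len(pi_old)), weights=pi_old, k=n)
    return sum(pi[a] / pi_old[a] * adv[a] for a in actions) / n


print(on_policy_expectation(pi, adv), importance_sampled_expectation(pi, pi_old, adv))
print(monte_carlo_estimate(pi, pi_old, adv, 200_000, random.Random(0)))
```

Both exact values are (up to floating-point rounding) $0.8$, and the sampled estimate lands close to it. This is what lets PPO reuse data collected with $\pi_{\theta_{OLD}}$ while optimizing $\pi_\theta$.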

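$$
\begin{aligned}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_{\theta_{OLD}}}\Big[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{OLD}}(s, a)\Big] \\
&\approx \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_{\theta_{OLD}}}, a \sim \pi_{\theta_{OLD}}}\Big[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{OLD}}(s, a)\Big] \\
&= \frac{1}{1-\gamma} \mathcal{L}^{CPI}
\end{aligned}
$$

```
class ClippedPPOLoss(Module):
    def __init__(self):
        super().__init__()

    def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                advantage: torch.Tensor, clip: float) -> torch.Tensor:
```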

Ratio $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{OLD}}(a_t \mid s_t)}$; *this is different from rewards* $r_t$.

```
        # Probability ratio r_t(theta), computed from log-probabilities for numerical stability
        ratio = torch.exp(log_pi - sampled_log_pi)
```

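The ratio is clipped to stay close to 1, giving the objective

$$
\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\bar{A}_t,\ \mathrm{clip}\big(r_t(\theta), 1-\epsilon, 1+\epsilon\big)\bar{A}_t\big)\Big]
$$

where $\bar{A}_t$ is the normalized advantage defined next. Taking the minimum makes the objective a pessimistic bound. When the ratio leaves $[1-\epsilon, 1+\epsilon]$ in the direction the advantage rewards, the clipped term is selected and its gradient is zero, so there is no incentive to move $\pi_\theta$ further from $\pi_{\theta_{OLD}}$. When the ratio leaves the range in the unfavorable direction, the unclipped term is kept and its gradient pulls $\pi_\theta$ back towards $\pi_{\theta_{OLD}}$. This keeps the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$ approximately constrained. Large deviations can cause performance collapse, where the policy performance drops and doesn't recover because we are sampling from a bad policy.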
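To see the effect of the minimum, here is a per-sample, framework-free sketch of the clipped objective. It is a plain-Python restatement, not the tensor implementation in this file, and the function names are hypothetical. A central finite difference estimates the slope with respect to the ratio. The slope is zero once the ratio has left the clip range in the direction the advantage rewards, and non-zero when it has left in the other direction.

```python
import math


def ratio_from_log_probs(log_pi: float, sampled_log_pi: float) -> float:
    """r_t(theta) = pi_theta / pi_theta_OLD, computed from log-probabilities."""
    return math.exp(log_pi - sampled_log_pi)


def clipped_surrogate(ratio: float, advantage: float, clip: float) -> float:
    """min(r * A, clip(r, 1 - eps, 1 + eps) * A) for a single sample."""
    clipped_ratio = min(max(ratio, 1.0 - clip), 1.0 + clip)
    return min(ratio * advantage, clipped_ratio * advantage)


def slope(ratio: float, advantage: float, clip: float, h: float = 1e-6) -> float:
    """Central finite-difference estimate of d(surrogate)/d(ratio)."""
    return (clipped_surrogate(ratio + h, advantage, clip)
            - clipped_surrogate(ratio - h, advantage, clip)) / (2 * h)


eps = 0.2
r = ratio_from_log_probs(math.log(0.6), math.log(0.4))  # 0.6 / 0.4, about 1.5
for ratio, adv in [(r, 1.0), (0.5, 1.0), (r, -1.0), (0.5, -1.0)]:
    print(f"ratio={ratio:.2f} A={adv:+.0f} "
          f"objective={clipped_surrogate(ratio, adv, eps):+.2f} "
          f"slope={slope(ratio, adv, eps):+.2f}")
```

With $\epsilon = 0.2$, a ratio of 1.5 with a positive advantage and a ratio of 0.5 with a negative advantage both give zero slope. The other two cases keep a slope of $\pm 1$, which moves the ratio back towards 1 under gradient ascent.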

Using the normalized advantage $\bar{A}_t = \dfrac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$ introduces a bias into the policy gradient estimator, but it reduces variance a lot.
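Here is a sketch of the normalization, assuming it is applied over the whole batch of advantage estimates before computing the loss. The small `eps` guarding against division by zero is an assumption, and the population standard deviation is used here. The function name is hypothetical.

```python
import statistics
from typing import List, Sequence


def normalize_advantages(advantages: Sequence[float], eps: float = 1e-8) -> List[float]:
    """(A - mean(A)) / (std(A) + eps) over a batch of advantage estimates."""
    mu = statistics.fmean(advantages)
    sigma = statistics.pstdev(advantages)
    return [(a - mu) / (sigma + eps) for a in advantages]


normalized = normalize_advantages([1.0, 2.0, 3.0, 4.0])
print(statistics.fmean(normalized), statistics.pstdev(normalized))  # about 0 and about 1
```

Normalization preserves the sign ordering of the advantages within a batch, but it changes their scale and offset relative to the true advantages. That change is the source of the bias.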

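```
        # Clip the ratio to [1 - clip, 1 + clip]
        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        # Pessimistic (element-wise minimum) surrogate objective
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        # Fraction of samples whose ratio fell outside the clip range (for logging)
        self.clip_fraction = ((ratio - 1.0).abs() > clip).to(torch.float).mean()

        # Negate: optimizers minimize, but we want to maximize the surrogate
        return -policy_reward.mean()
```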

Similarly, we clip the value function update.

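$$
\begin{aligned}
V^{\pi_\theta}_{CLIP}(s_t) &= \hat{V}_t + \mathrm{clip}\big(V^{\pi_\theta}(s_t) - \hat{V}_t, -\epsilon, +\epsilon\big) \\
\mathcal{L}^{VF}(\theta) &= \frac{1}{2} \mathbb{E}\Big[\max\Big(\big(V^{\pi_\theta}(s_t) - R_t\big)^2, \big(V^{\pi_\theta}_{CLIP}(s_t) - R_t\big)^2\Big)\Big]
\end{aligned}
$$

Here $\hat{V}_t$ is the value estimate recorded when the data was sampled, and $R_t$ is the sampled return. Clipping makes sure the value function $V_\theta$ doesn't deviate significantly from $V_{\theta_{OLD}}$.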
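Here is a per-sample, framework-free sketch of $\mathcal{L}^{VF}$ that restates the tensor code in plain Python. The function name is hypothetical. In the worked case, the new estimate has moved from $\hat{V}_t = 1$ to $2$, past the clip range of $0.2$. The clipped prediction $1.2$ has the larger error against $R_t = 3$, so the maximum selects it and the loss no longer depends on the new value.

```python
from typing import Sequence


def clipped_value_loss(values: Sequence[float], sampled_values: Sequence[float],
                       sampled_returns: Sequence[float], clip: float) -> float:
    """0.5 * mean(max((V - R)^2, (V_clip - R)^2)), V_clip = V_old + clip(V - V_old, -clip, clip)."""
    losses = []
    for v, v_old, ret in zip(values, sampled_values, sampled_returns):
        # Keep the prediction within [-clip, +clip] of the value recorded at sampling time
        v_clip = v_old + min(max(v - v_old, -clip), clip)
        # Pessimistic choice: the larger of the two squared errors
        losses.append(max((v - ret) ** 2, (v_clip - ret) ** 2))
    return 0.5 * sum(losses) / len(losses)


print(clipped_value_loss([2.0], [1.0], [3.0], clip=0.2))  # 0.5 * (1.2 - 3)^2, about 1.62
```

When the new estimate stays inside the clip range, `v_clip` equals `v` and the loss reduces to the ordinary half squared error.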

```
class ClippedValueFunctionLoss(Module):
    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor,
                sampled_return: torch.Tensor, clip: float) -> torch.Tensor:
        # Keep the new value estimate within [-clip, +clip] of the value recorded at sampling time
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        # Element-wise maximum of the unclipped and clipped squared errors (pessimistic bound)
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
```