Proximal Policy Optimization - PPO

This is a PyTorch implementation of Proximal Policy Optimization - PPO.

PPO is a policy gradient method for reinforcement learning. Simple policy gradient methods do a single gradient update per sample (or a set of samples). Doing multiple gradient steps for a single sample causes problems because the policy deviates too much, producing a bad policy. PPO lets us do multiple gradient updates per sample by trying to keep the policy close to the policy that was used to sample data. It does so by clipping gradient flow if the updated policy is not close to the policy used to sample the data.
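To make "multiple gradient updates per sample" concrete, here is a minimal sketch that takes several optimizer steps on one fixed batch. Everything in it is a made-up toy and not part of the original implementation: the state-free 3-action policy parameterized by `logits`, the actions, the advantages, the learning rate and the number of steps.

```python
import torch

# Hypothetical toy policy: 3 actions, no state, parameterized by its logits.
logits = torch.zeros(3, requires_grad=True)
optimizer = torch.optim.SGD([logits], lr=0.5)
clip = 0.2

# One batch of "sampled" data with made-up actions and advantages.
actions = torch.tensor([0, 1, 2, 0, 1, 2])
advantage = torch.tensor([1.0, -1.0, 0.5, 1.0, -1.0, 0.5])
with torch.no_grad():
    # Log-probabilities under the policy that collected the data.
    sampled_log_pi = torch.log_softmax(logits, dim=-1)[actions]

losses = []
for step in range(8):  # several gradient steps on the SAME batch
    log_pi = torch.log_softmax(logits, dim=-1)[actions]
    ratio = torch.exp(log_pi - sampled_log_pi)
    clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
    # Clipped surrogate objective, negated so that we can minimize it.
    loss = -torch.min(ratio * advantage, clipped_ratio * advantage).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses)
```

On the first step every ratio is exactly 1, so the loss is just the negated mean advantage. In later steps, a sample stops contributing gradient once its ratio has left the clip range in the direction its advantage favors.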

You can find an experiment that uses it here. The experiment uses Generalized Advantage Estimation.


import torch
from labml_nn.rl.ppo.gae import GAE
from torch import nn

PPO Loss

Here's how the PPO update rule is derived.

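We want to maximize the policy reward

$$\max_\theta J(\pi_\theta) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t r_t \Biggr]$$

where $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled from the policy, and $\gamma$ is the discount factor in $[0, 1]$.

Let $A^{\pi_{OLD}}$, $Q^{\pi_{OLD}}$ and $V^{\pi_{OLD}}$ be the advantage, action-value and value functions of the policy $\pi_{\theta_{OLD}}$ that was used to sample the data. Expanding the advantage and telescoping the value terms gives

\begin{align}
\mathbb{E}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t) \Biggr]
&= \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t \Bigl( Q^{\pi_{OLD}}(s_t, a_t) - V^{\pi_{OLD}}(s_t) \Bigr) \Biggr] \\
&= \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t \Bigl( r_t + \gamma V^{\pi_{OLD}}(s_{t+1}) - V^{\pi_{OLD}}(s_t) \Bigr) \Biggr] \\
&= \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t r_t \Biggr] - \mathbb{E}_{\tau \sim \pi_\theta} \Bigl[ V^{\pi_{OLD}}(s_0) \Bigr] \\
&= J(\pi_\theta) - J(\pi_{\theta_{OLD}})
\end{align}

So, since $J(\pi_{\theta_{OLD}})$ does not depend on $\theta$,

$$\arg\max_\theta J(\pi_\theta) = \arg\max_\theta \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t) \Biggr]$$

Define discounted-future state distribution,

$$d^\pi(s) = (1 - \gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s | \pi)$$

Then,

\begin{align}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t) \Biggr] \\
&= \frac{1}{1 - \gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta} \Bigl[ A^{\pi_{OLD}}(s, a) \Bigr]
\end{align}

Importance sampling $a$ from $\pi_{\theta_{OLD}}$,

\begin{align}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1 - \gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta} \Bigl[ A^{\pi_{OLD}}(s, a) \Bigr] \\
&= \frac{1}{1 - \gamma} \mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_{\theta_{OLD}}} \Biggl[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{OLD}}(s, a) \Biggr]
\end{align}

Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar. The error we introduce to $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$ by this assumption is bounded by the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$. Constrained Policy Optimization shows the proof of this. I haven't read it.

\begin{align}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&\approx \frac{1}{1 - \gamma} \mathbb{E}_{s \sim d^{\pi_{\theta_{OLD}}}, a \sim \pi_{\theta_{OLD}}} \Biggl[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{OLD}}(s, a) \Biggr] \\
&= \frac{1}{1 - \gamma} \mathcal{L}^{CPI}
\end{align}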
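The importance sampling step can be checked by hand on a tiny discrete example. The sketch below uses made-up probabilities and advantages for a single state with three actions (none of these numbers come from the original). It computes the expected advantage once with actions drawn from the new policy, and once with actions drawn from the old policy and reweighted by the probability ratio.

```python
# Hypothetical single-state, 3-action example; all numbers are made up.
pi_old = [0.5, 0.3, 0.2]      # probabilities under the sampling policy
pi_new = [0.4, 0.4, 0.2]      # probabilities under the updated policy
advantage = [1.0, -2.0, 0.5]  # advantage of each action

# Expected advantage with actions drawn from the new policy.
direct = sum(p * a for p, a in zip(pi_new, advantage))

# Same expectation with actions drawn from the old policy,
# each term reweighted by the ratio pi_new / pi_old.
reweighted = sum(q * (p / q) * a
                 for p, q, a in zip(pi_new, pi_old, advantage))

print(direct, reweighted)
```

The two sums agree up to floating point rounding, which is what lets us evaluate the objective for $\pi_\theta$ using only actions sampled from $\pi_{\theta_{OLD}}$.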

class ClippedPPOLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                advantage: torch.Tensor, clip: float) -> torch.Tensor:

ratio $r_t(\theta) = \frac{\pi_\theta (a_t|s_t)}{\pi_{\theta_{OLD}} (a_t|s_t)}$; this is different from rewards $r_t$.

        ratio = torch.exp(log_pi - sampled_log_pi)
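The source does not say why the ratio is computed from log-probabilities. One likely benefit, stated here as our assumption, is numerical robustness. The sketch below uses made-up, very small log-probabilities to compare the log-space form with a naive division of probabilities.

```python
import math
import torch

# Made-up log-probabilities of one very unlikely action.
log_pi = torch.tensor(-800.0)
sampled_log_pi = torch.tensor(-800.5)

# Log-space form: subtract first, exponentiate once.
ratio = torch.exp(log_pi - sampled_log_pi)

# Naive form: both probabilities underflow to 0, so the division is 0 / 0.
naive_ratio = torch.exp(log_pi) / torch.exp(sampled_log_pi)

print(ratio.item(), naive_ratio.item(), math.exp(0.5))
```

The log-space ratio matches `math.exp(0.5)`, while the naive division is not a number.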

Clipping the policy ratio

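$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_{a_t, s_t \sim \pi_{\theta_{OLD}}} \Bigl[ \min \Bigl( r_t(\theta) \bar{A}_t, \; \mathrm{clip}\bigl(r_t(\theta), 1 - \epsilon, 1 + \epsilon\bigr) \bar{A}_t \Bigr) \Bigr]$$

The ratio $r_t(\theta)$ between the new and old policy probabilities is clipped to $[1 - \epsilon, 1 + \epsilon]$, so it stays close to 1; $\epsilon$ is `clip` in the code and $\bar{A}_t$ is the normalized advantage. We take the minimum so that the gradient will only pull $\pi_\theta$ towards $\pi_{\theta_{OLD}}$ if the ratio is not between $1 - \epsilon$ and $1 + \epsilon$. This keeps the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$ constrained. Large deviation can cause performance collapse, where the policy performance drops and doesn't recover because we are sampling from a bad policy.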

Using the normalized advantage introduces a bias to the policy gradient estimator, but it reduces variance a lot.
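A common way to normalize advantages is to subtract the batch mean and divide by the batch standard deviation. The sketch below assumes that form. The helper name `normalize_advantage` and the `eps` term are hypothetical and not taken from the original code, which receives `advantage` as an argument and does not show where normalization happens.

```python
import torch


def normalize_advantage(advantage: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: zero-mean, unit-variance advantages over a batch."""
    # eps guards against division by zero when all advantages are equal.
    return (advantage - advantage.mean()) / (advantage.std() + eps)


# Made-up advantages for illustration.
advantage = torch.tensor([2.0, -1.0, 0.5, 3.5])
normalized = normalize_advantage(advantage)
print(normalized)
```

Because the mean and standard deviation are computed from the same batch as the gradient estimate, each sample's weight depends on the other samples in the batch.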

        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        self.clip_fraction = ((ratio - 1.0).abs() > clip).to(torch.float).mean()

        return -policy_reward.mean()
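To see which samples still receive gradient after clipping, the following sketch applies the same `clamp` and `min` operations to four made-up (ratio, advantage) pairs and inspects the gradient with respect to the ratio. The numbers are illustrative only. Here `ratio` is a leaf tensor so that its gradient can be read directly, whereas in the loss above it is computed from the log-probabilities.

```python
import torch

clip = 0.2
# Made-up cases: ratio above or below the clip range, with either sign of advantage.
ratio = torch.tensor([1.5, 1.5, 0.5, 0.5], requires_grad=True)
advantage = torch.tensor([1.0, -1.0, 1.0, -1.0])

clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
policy_reward = torch.min(ratio * advantage, clipped_ratio * advantage)
policy_reward.sum().backward()

print(policy_reward.tolist())  # per-sample surrogate objective
print(ratio.grad.tolist())     # gradient of the objective w.r.t. each ratio
```

The gradient is zero in the first and last cases, where the ratio has already moved past the clip range in the direction its advantage favors. In the two middle cases the ratio has moved the wrong way, so the unclipped term is the minimum and the gradient pulls the ratio back towards 1.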

Clipped Value Function Loss

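Similarly, we clip the value function update.

$$V^{\pi_\theta}_{CLIP}(s_t) = \hat{V}_t + \mathrm{clip}\bigl(V^{\pi_\theta}(s_t) - \hat{V}_t, -\epsilon, +\epsilon\bigr)$$

$$\mathcal{L}^{VF}(\theta) = \frac{1}{2} \mathbb{E} \Bigl[ \max \Bigl( \bigl(V^{\pi_\theta}(s_t) - R_t\bigr)^2, \bigl(V^{\pi_\theta}_{CLIP}(s_t) - R_t\bigr)^2 \Bigr) \Bigr]$$

where $\hat{V}_t$ is the value estimate recorded when the data was sampled (`sampled_value`) and $R_t$ is the sampled return (`sampled_return`).

Clipping makes sure the value function $V_\theta$ doesn't deviate significantly from $V_{\theta_{OLD}}$.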

class ClippedValueFunctionLoss(nn.Module):
    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor,
                sampled_return: torch.Tensor, clip: float):
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
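As a worked example of the clipped value loss, the sketch below repeats the same three operations on made-up values and checks where the gradient flows. The numbers are chosen only for illustration, and `value` is a leaf tensor here so that its gradient can be inspected.

```python
import torch

clip = 0.25
# Made-up values: the old estimate is 1.0 for every sample.
sampled_value = torch.tensor([1.0, 1.0, 1.0])
value = torch.tensor([1.5, 1.5, 1.125], requires_grad=True)
sampled_return = torch.tensor([2.0, 1.0, 2.0])

clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
vf_loss = torch.max((value - sampled_return) ** 2,
                    (clipped_value - sampled_return) ** 2)
loss = 0.5 * vf_loss.mean()
loss.backward()

print(vf_loss.tolist())     # per-sample squared error after the max
print(value.grad.tolist())  # gradient of the loss w.r.t. each value
```

In the first sample the new value has moved more than `clip` towards the return, so the clipped term is larger and its gradient is zero. In the second sample the value moved away from the return, the unclipped term is larger, and the gradient pulls it back. In the third sample the value is within `clip` of the old estimate, so both terms are equal and clipping has no effect.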