この実験では、OpenAI Gymでプロキシマルポリシー最適化(PPO)エージェントのAtariブレイクアウトゲームをトレーニングします。ゲーム環境を複数のプロセスで実行して効率的にサンプリングします。
15from typing import Dict
16
17import numpy as np
18import torch
19from torch import nn
20from torch import optim
21from torch.distributions import Categorical
22
23from labml import monit, tracker, logger, experiment
24from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
25from labml_helpers.module import Module
26from labml_nn.rl.game import Worker
27from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
28from labml_nn.rl.ppo.gae import GAE
デバイスを選択
31if torch.cuda.is_available():
32 device = torch.device("cuda:0")
33else:
34 device = torch.device("cpu")
37class Model(Module):
42 def __init__(self):
43 super().__init__()
最初の畳み込み層は 84 x 84 フレームで、20 x 20 フレームを生成します。
47 self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
2 番目の畳み込み層は 20x20 フレームで、9x9 フレームを生成します。
51 self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
3 番目の畳み込み層は 9x9 フレームで 7x7 フレームを生成します。
55 self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
完全結合層は、3 番目の畳み込み層から平坦化されたフレームを取り出し、512 個の特徴を出力します。
60 self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
ロジットを取得するための完全接続レイヤー
63 self.pi_logits = nn.Linear(in_features=512, out_features=4)
バリュー関数を得るための完全連結レイヤー
66 self.value = nn.Linear(in_features=512, out_features=1)
69 self.activation = nn.ReLU()
71 def forward(self, obs: torch.Tensor):
72 h = self.activation(self.conv1(obs))
73 h = self.activation(self.conv2(h))
74 h = self.activation(self.conv3(h))
75 h = h.reshape((-1, 7 * 7 * 64))
76
77 h = self.activation(self.lin(h))
78
79 pi = Categorical(logits=self.pi_logits(h))
80 value = self.value(h).reshape(-1)
81
82 return pi, value
[0, 255]
観測値をからにスケーリング [0, 1]
85def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
87 return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
90class Trainer:
95 def __init__(self, *,
96 updates: int, epochs: IntDynamicHyperParam,
97 n_workers: int, worker_steps: int, batches: int,
98 value_loss_coef: FloatDynamicHyperParam,
99 entropy_bonus_coef: FloatDynamicHyperParam,
100 clip_range: FloatDynamicHyperParam,
101 learning_rate: FloatDynamicHyperParam,
102 ):
更新回数
106 self.updates = updates
サンプルデータを使用してモデルをトレーニングするエポックの数
108 self.epochs = epochs
ワーカープロセスの数
110 self.n_workers = n_workers
1 回の更新で各プロセスで実行するステップの数
112 self.worker_steps = worker_steps
ミニバッチ数
114 self.batches = batches
1 回の更新でのサンプルの総数
116 self.batch_size = self.n_workers * self.worker_steps
ミニバッチのサイズ
118 self.mini_batch_size = self.batch_size // self.batches
119 assert (self.batch_size % self.batches == 0)
価値損失係数
122 self.value_loss_coef = value_loss_coef
エントロピーボーナス係数
124 self.entropy_bonus_coef = entropy_bonus_coef
クリッピング範囲
127 self.clip_range = clip_range
学習率
129 self.learning_rate = learning_rate
ワーカーを作成
134 self.workers = [Worker(47 + i) for i in range(self.n_workers)]
観測用のテンソルを初期化
137 self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
138 for worker in self.workers:
139 worker.child.send(("reset", None))
140 for i, worker in enumerate(self.workers):
141 self.obs[i] = worker.child.recv()
モデル
144 self.model = Model().to(device)
オプティマイザー
147 self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)
GATE (および付き)
150 self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
PPO ロス
153 self.ppo_loss = ClippedPPOLoss()
価値損失
156 self.value_loss = ClippedValueFunctionLoss()
158 def sample(self) -> Dict[str, torch.Tensor]:
163 rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
164 actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
165 done = np.zeros((self.n_workers, self.worker_steps), dtype=np.bool)
166 obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
167 log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
168 values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)
169
170 with torch.no_grad():
worker_steps
各労働者からのサンプル
172 for t in range(self.worker_steps):
self.obs
各ワーカーからの最後の観測値を追跡します。これは、モデルが次のアクションをサンプリングするための入力です
175 obs[:, t] = self.obs
各ワーカーのサンプルアクション。これはサイズの配列を返します n_workers
178 pi, v = self.model(obs_to_torch(self.obs))
179 values[:, t] = v.cpu().numpy()
180 a = pi.sample()
181 actions[:, t] = a.cpu().numpy()
182 log_pis[:, t] = pi.log_prob(a).cpu().numpy()
各ワーカーでサンプルアクションを実行
185 for w, worker in enumerate(self.workers):
186 worker.child.send(("step", actions[w, t]))
187
188 for w, worker in enumerate(self.workers):
アクションを実行した後に結果を取得
190 self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()
エピソードの情報を集めましょう。Game
エピソードが終了したときに入手できます。これには報酬総額やエピソードの長さが含まれます。仕組みを確認してみましょう。
195 if info:
196 tracker.add('reward', info['reward'])
197 tracker.add('length', info['length'])
最後のステップの後に値を取得
200 _, v = self.model(obs_to_torch(self.obs))
201 values[:, self.worker_steps] = v.cpu().numpy()
利点を計算
204 advantages = self.gae(done, rewards, values)
207 samples = {
208 'obs': obs,
209 'actions': actions,
210 'values': values[:, :-1],
211 'log_pis': log_pis,
212 'advantages': advantages
213 }
[workers, time_step]
サンプルは現在テーブルにあるので、トレーニング用に平らにする必要があります
217 samples_flat = {}
218 for k, v in samples.items():
219 v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
220 if k == 'obs':
221 samples_flat[k] = obs_to_torch(v)
222 else:
223 samples_flat[k] = torch.tensor(v, device=device)
224
225 return samples_flat
227 def train(self, samples: Dict[str, torch.Tensor]):
エポック数が多いほど学習は速くなりますが、少し不安定になります。つまり、エピソードの平均報酬は時間の経過とともに単調に増加しません。クリッピング範囲を狭くすることで解決する可能性があります。
237 for _ in range(self.epochs()):
各エポックのシャッフル
239 indexes = torch.randperm(self.batch_size)
各ミニバッチ用
242 for start in range(0, self.batch_size, self.mini_batch_size):
ミニバッチを入手
244 end = start + self.mini_batch_size
245 mini_batch_indexes = indexes[start: end]
246 mini_batch = {}
247 for k, v in samples.items():
248 mini_batch[k] = v[mini_batch_indexes]
列車
251 loss = self._calc_loss(mini_batch)
学習率を設定
254 for pg in self.optimizer.param_groups:
255 pg['lr'] = self.learning_rate()
以前に計算したグラデーションをゼロにします
257 self.optimizer.zero_grad()
勾配の計算
259 loss.backward()
クリップグラデーション
261 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
グラデーションに基づいてパラメータを更新
263 self.optimizer.step()
265 @staticmethod
266 def _normalize(adv: torch.Tensor):
268 return (adv - adv.mean()) / (adv.std() + 1e-8)
270 def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
からサンプリングされたリターン
276 sampled_return = samples['values'] + samples['advantages']
、利点はどこからサンプリングされているのか。の計算については、下記のメインクラスのサンプリング関数を参照してください。
282 sampled_normalized_advantage = self._normalize(samples['advantages'])
サンプリングされた観測値はモデルに入力され、取得されます。観測値は状態として扱います
286 pi, value = self.model(samples['obs'])
アクションは以下からサンプリングされます
289 log_pi = pi.log_prob(samples['actions'])
保険契約損失の計算
292 policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
298 entropy_bonus = pi.entropy()
299 entropy_bonus = entropy_bonus.mean()
値関数損失の計算
302 value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
307 loss = (policy_loss
308 + self.value_loss_coef() * value_loss
309 - self.entropy_bonus_coef() * entropy_bonus)
監視用
312 approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
トラッカーに追加
315 tracker.add({'policy_reward': -policy_loss,
316 'value_loss': value_loss,
317 'entropy_bonus': entropy_bonus,
318 'kl_div': approx_kl_divergence,
319 'clip_fraction': self.ppo_loss.clip_fraction})
320
321 return loss
323 def run_training_loop(self):
最後の 100 話の情報
329 tracker.set_queue('reward', 100, True)
330 tracker.set_queue('length', 100, True)
331
332 for update in monit.loop(self.updates):
現行ポリシーのサンプル
334 samples = self.sample()
モデルのトレーニング
337 self.train(samples)
追跡指標を保存します。
340 tracker.save()
画面に定期的に新しい行を追加してください
342 if (update + 1) % 1_000 == 0:
343 logger.log()
345 def destroy(self):
350 for worker in self.workers:
351 worker.child.send(("close", None))
354def main():
実験を作成
356 experiment.create(name='ppo')
コンフィギュレーション
358 configs = {
更新回数
360 'updates': 10000,
⚙️ サンプルデータを使用してモデルをトレーニングするエポックの数。これは実験の実行中に変更できます。
363 'epochs': IntDynamicHyperParam(8),
ワーカープロセスの数
365 'n_workers': 8,
1 回の更新で各プロセスで実行するステップの数
367 'worker_steps': 128,
ミニバッチ数
369 'batches': 4,
⚙️ 価値損失係数。これは実験の実行中に変更できます。
372 'value_loss_coef': FloatDynamicHyperParam(0.5),
⚙️ エントロピーボーナス係数。これは実験の実行中に変更できます。
375 'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
⚙️ クリップレンジ。
377 'clip_range': FloatDynamicHyperParam(0.1),
テストの実行中にこれを変更できます。⚙️ 学習率。
380 'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
381 }
382
383 experiment.configs(configs)
トレーナーを初期化
386 m = Trainer(**configs)
実験の実行と監視
389 with experiment.start():
390 m.run_training_loop()
労働者を止めろ
392 m.destroy()
396if __name__ == "__main__":
397 main()