この実験では、ディープQネットワーク(DQN)にOpenAI Gymでアタリブレイクアウトゲームをプレイするようにトレーニングします。ゲーム環境を複数のプロセスで実行して効率的にサンプリングします。
15import numpy as np
16import torch
17
18from labml import tracker, experiment, logger, monit
19from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
20from labml_helpers.schedule import Piecewise
21from labml_nn.rl.dqn import QFuncLoss
22from labml_nn.rl.dqn.model import Model
23from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
24from labml_nn.rl.game import Worker
デバイスを選択
27if torch.cuda.is_available():
28 device = torch.device("cuda:0")
29else:
30 device = torch.device("cpu")
[0, 255]
観測値をからにスケーリング [0, 1]
33def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
35 return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
38class Trainer:
43 def __init__(self, *,
44 updates: int, epochs: int,
45 n_workers: int, worker_steps: int, mini_batch_size: int,
46 update_target_model: int,
47 learning_rate: FloatDynamicHyperParam,
48 ):
労働者の数
50 self.n_workers = n_workers
更新のたびにサンプリングされるステップ
52 self.worker_steps = worker_steps
トレーニングの反復回数
54 self.train_epochs = epochs
更新回数
57 self.updates = updates
トレーニング用ミニバッチのサイズ
59 self.mini_batch_size = mini_batch_size
250 回の更新ごとにターゲットネットワークを更新
62 self.update_target_model = update_target_model
学習率
65 self.learning_rate = learning_rate
更新機能としての探索
68 self.exploration_coefficient = Piecewise(
69 [
70 (0, 1.0),
71 (25_000, 0.1),
72 (self.updates / 2, 0.01)
73 ], outside_value=0.01)
更新機能としての再生バッファ用
76 self.prioritized_replay_beta = Piecewise(
77 [
78 (0, 0.4),
79 (self.updates, 1)
80 ], outside_value=1)
リプレイバッファは.再生バッファの容量は 2 の累乗でなければなりません
。83 self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
サンプリングとトレーニング用のモデル
86 self.model = Model().to(device)
取得する対象モデル
88 self.target_model = Model().to(device)
ワーカーを作成
91 self.workers = [Worker(47 + i) for i in range(self.n_workers)]
観測用のテンソルを初期化
94 self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
ワーカーをリセット
97 for worker in self.workers:
98 worker.child.send(("reset", None))
初期観測値を取得
101 for i, worker in enumerate(self.workers):
102 self.obs[i] = worker.child.recv()
損失関数
105 self.loss_func = QFuncLoss(0.99)
オプティマイザー
107 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
アクションをサンプリングするときは、-greedy ストラテジーを使用します。つまり、確率のある貪欲なアクションを実行し、確率のあるランダムなアクションを実行します。と呼びますexploration_coefficient
。
109 def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
サンプリングにはグラデーションは必要ありません
119 with torch.no_grad():
Q値が最も高いアクションをサンプリングします。これは貪欲な行動です
。121 greedy_action = torch.argmax(q_value, dim=-1)
サンプルとアクションを均一に
123 random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
欲張りアクションとランダムアクションのどちらを選ぶか
125 is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
以下に基づいてアクションを選択してください is_choose_rand
127 return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
129 def sample(self, exploration_coefficient: float):
これにはグラデーションは必要ありません
133 with torch.no_grad():
[サンプル] worker_steps
135 for t in range(self.worker_steps):
現在の観測値の Q_value を取得
137 q_value = self.model(obs_to_torch(self.obs))
サンプルアクション
139 actions = self._sample_action(q_value, exploration_coefficient)
各ワーカーでサンプルアクションを実行
142 for w, worker in enumerate(self.workers):
143 worker.child.send(("step", actions[w]))
各作業者から情報を収集する
146 for w, worker in enumerate(self.workers):
アクションを実行した後に結果を取得
148 next_obs, reward, done, info = worker.child.recv()
再生バッファにトランジションを追加
151 self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
エピソード情報を更新します。エピソードが終了した場合に利用できるエピソード情報を収集します。これには、合計報酬とエピソードの長さが含まれます。仕組みを確認してみてください。Game
157 if info:
158 tracker.add('reward', info['reward'])
159 tracker.add('length', info['length'])
現在の観測値を更新
162 self.obs[w] = next_obs
164 def train(self, beta: float):
168 for _ in range(self.train_epochs):
プライオリティ・リプレイ・バッファからのサンプル
170 samples = self.replay_buffer.sample(self.mini_batch_size, beta)
予測された Q 値の取得
172 q_value = self.model(obs_to_torch(samples['obs']))
二重Q学習の次の状態のQ値を取得します。これらの場合、グラデーションは伝播しないはずです
176 with torch.no_grad():
取得
178 double_q_value = self.model(obs_to_torch(samples['next_obs']))
取得
180 target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
時差 (TD) 誤差、および損失を計算します。
183 td_errors, loss = self.loss_func(q_value,
184 q_value.new_tensor(samples['action']),
185 double_q_value, target_q_value,
186 q_value.new_tensor(samples['done']),
187 q_value.new_tensor(samples['reward']),
188 q_value.new_tensor(samples['weights']))
再生バッファの優先度を計算
191 new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
リプレイバッファの優先順位を更新
193 self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
学習率を設定
196 for pg in self.optimizer.param_groups:
197 pg['lr'] = self.learning_rate()
以前に計算したグラデーションをゼロにします
199 self.optimizer.zero_grad()
勾配の計算
201 loss.backward()
クリップグラデーション
203 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
グラデーションに基づいてパラメータを更新
205 self.optimizer.step()
207 def run_training_loop(self):
最新100話の情報
213 tracker.set_queue('reward', 100, True)
214 tracker.set_queue('length', 100, True)
最初にターゲットネットワークにコピー
217 self.target_model.load_state_dict(self.model.state_dict())
218
219 for update in monit.loop(self.updates):
、探査フラクション
221 exploration = self.exploration_coefficient(update)
222 tracker.add('exploration', exploration)
優先再生用
224 beta = self.prioritized_replay_beta(update)
225 tracker.add('beta', beta)
現在のポリシーを含むサンプル
228 self.sample(exploration)
バッファーがいっぱいになったらトレーニングを開始する
231 if self.replay_buffer.is_full():
モデルのトレーニング
233 self.train(beta)
ターゲットネットワークを定期的に更新
236 if update % self.update_target_model == 0:
237 self.target_model.load_state_dict(self.model.state_dict())
追跡指標を保存します。
240 tracker.save()
画面に定期的に新しい行を追加してください
242 if (update + 1) % 1_000 == 0:
243 logger.log()
245 def destroy(self):
250 for worker in self.workers:
251 worker.child.send(("close", None))
254def main():
実験を作成
256 experiment.create(name='dqn')
コンフィギュレーション
259 configs = {
更新回数
261 'updates': 1_000_000,
サンプルデータを使用してモデルをトレーニングするエポックの数。
263 'epochs': 8,
ワーカープロセスの数
265 'n_workers': 8,
1 回の更新で各プロセスで実行するステップの数
267 'worker_steps': 4,
ミニバッチサイズ
269 'mini_batch_size': 32,
対象モデルの更新間隔
271 'update_target_model': 250,
学習率。
273 'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
274 }
コンフィギュレーション
277 experiment.configs(configs)
トレーナーを初期化
280 m = Trainer(**configs)
実験の実行と監視
282 with experiment.start():
283 m.run_training_loop()
労働者を止めろ
285 m.destroy()
289if __name__ == "__main__":
290 main()