from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE
Select the device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class Model(Module):
    def __init__(self):
        super().__init__()
The first convolution layer takes an 84x84 frame and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
The second convolution layer takes a 20x20 frame and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
The third convolution layer takes a 9x9 frame and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)
A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()
    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
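The frame sizes above follow the usual convolution arithmetic $\lfloor (n - k)/s \rfloor + 1$: 84 → 20 → 9 → 7. A quick shape sanity check (a hypothetical snippet, not part of the trainer) could look like this:

# Hypothetical sanity check: run a dummy batch of two 4-frame observations through the model.
_check_model = Model().to(device)
_dummy_obs = torch.zeros((2, 4, 84, 84), device=device)
_pi, _value = _check_model(_dummy_obs)
assert _pi.logits.shape == (2, 4)  # logits over the 4 actions
assert _value.shape == (2,)        # one value estimate per observation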
Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
Number of updates
        self.updates = updates
Number of epochs to train the model with sampled data
        self.epochs = epochs
Number of worker processes
        self.n_workers = n_workers
Number of steps to run on each process for a single update
        self.worker_steps = worker_steps
Number of mini batches
        self.batches = batches
Total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
Size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
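With the defaults used in `main()` below (`n_workers=8`, `worker_steps=128`, `batches=4`), that is 8 × 128 = 1024 samples per update, split into mini batches of 1024 / 4 = 256; the assert guarantees the split is exact.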
Value loss coefficient
        self.value_loss_coef = value_loss_coef
Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef
Clipping range
        self.clip_range = clip_range
Learning rate
        self.learning_rate = learning_rate

Create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

Initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

Model
        self.model = Model().to(device)

Optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
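For reference, generalized advantage estimation combines one-step TD residuals with exponentially decaying weights (a standard formulation, restated here rather than taken from `GAE` itself):

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad \hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}$$

with $\gamma = 0.99$ and $\lambda = 0.95$ as passed above, and the recursion cut off wherever `done` is set during sampling.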
PPO loss
        self.ppo_loss = ClippedPPOLoss()
Value loss
        self.value_loss = ClippedValueFunctionLoss()
    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
Sample `worker_steps` from each worker
            for t in range(self.worker_steps):
`self.obs` keeps track of the last observation from each worker, which is the input for the model to sample the next action
                obs[:, t] = self.obs
Sample actions from the policy for each worker; this returns arrays of size `n_workers`
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()
Run the sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):
Get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()
Collect episode info, which is available when an episode finishes; this includes the total reward and episode length. Look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
Get the value estimate after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

Calculate advantages
        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }
Samples are currently in `[workers, time_step]` tables; we flatten them for training

        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
    def train(self, samples: Dict[str, torch.Tensor]):
It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not increase monotonically over time. Reducing the clipping range might solve this.

        for _ in range(self.epochs()):
Shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

For each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
Get the mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

Train
                loss = self._calc_loss(mini_batch)

Set the learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
Zero out the previously calculated gradients
                self.optimizer.zero_grad()
Calculate gradients
                loss.backward()
Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update parameters based on gradients
                self.optimizer.step()
    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)
    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
Returns sampled from the old policy
        sampled_return = samples['values'] + samples['advantages']
        sampled_normalized_advantage = self._normalize(samples['advantages'])
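For reference, the two quantities above are (standard PPO bookkeeping, restated here):

$$\hat{R}_t = V_{\theta_{OLD}}(s_t) + \hat{A}_t, \qquad \bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t) + 10^{-8}}$$

where $\hat{A}_t$ are the GAE advantages and $V_{\theta_{OLD}}$ the value estimates, both collected in `sample()` under the old policy.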
Sampled observations are fed into the model to get the policy distribution and the value estimate; we treat observations as states
        pi, value = self.model(samples['obs'])

Log-probabilities under the current policy of the actions that were sampled from the old policy
        log_pi = pi.log_prob(samples['actions'])

Calculate the policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
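`ClippedPPOLoss` computes the usual clipped surrogate objective, restated here for reference (see `labml_nn.rl.ppo` for the actual implementation):

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{OLD}}(a_t \mid s_t)}, \qquad L^{CLIP}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\bar{A}_t,\ \operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\bar{A}_t\big)\Big]$$

where $\epsilon$ is `clip_range()` and the ratio is computed from `log_pi` and `samples['log_pis']`. The returned `policy_loss` is the negated objective, which is why the tracker below logs `'policy_reward': -policy_loss`.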
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate the value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
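`ClippedValueFunctionLoss` clips how far the new value estimate may move from the old one; up to a constant factor it has the standard form (again restated for reference):

$$V^{clipped}(s_t) = V_{\theta_{OLD}}(s_t) + \operatorname{clip}\big(V_\theta(s_t) - V_{\theta_{OLD}}(s_t),\, -\epsilon,\, \epsilon\big), \qquad L^{VF}(\theta) = \mathbb{E}_t\Big[\max\big((V_\theta(s_t) - \hat{R}_t)^2,\ (V^{clipped}(s_t) - \hat{R}_t)^2\big)\Big]$$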
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)
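Putting the pieces together, the quantity being minimized is the familiar PPO objective

$$\mathcal{L} = -L^{CLIP} + c_1\, L^{VF} - c_2\, \mathcal{H}\big[\pi_\theta\big]$$

with $c_1$ = `value_loss_coef()` and $c_2$ = `entropy_bonus_coef()`, both adjustable at run time as dynamic hyperparameters.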
Approximate KL divergence between the old and new policies, for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss
    def run_training_loop(self):
Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
Sample with the current policy
            samples = self.sample()

Train the model
            self.train(samples)

Save tracked metrics
            tracker.save()
Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
Create the experiment
    experiment.create(name='ppo')
Configurations
    configs = {
Number of updates
        'updates': 10000,
⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.
        'epochs': IntDynamicHyperParam(8),
Number of worker processes
        'n_workers': 8,
Number of steps to run on each process for a single update
        'worker_steps': 128,
Number of mini batches
        'batches': 4,
⚙️ Value loss coefficient. You can change this while the experiment is running.
        'value_loss_coef': FloatDynamicHyperParam(0.5),
⚙️ Entropy bonus coefficient. You can change this while the experiment is running.
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
⚙️ Clip range. You can change this while the experiment is running.
        'clip_range': FloatDynamicHyperParam(0.1),
⚙️ Learning rate. You can change this while the experiment is running.
        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)
Initialize the trainer
    m = Trainer(**configs)

Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
Stop the workers
    m.destroy()
if __name__ == "__main__":
    main()