9import multiprocessing
10import multiprocessing.connection
11
12import cv2
13import gym
14import numpy as np
これはOpenAIジムゲーム環境のラッパーです。ここではいくつかのことを行います。
1。4 つのフレームに同じアクションを適用し、最後のフレーム 2 を取得します。観測フレームをグレーに変換し、(84、84) 3 にスケーリングします。最後の4つのアクションを4フレーム重ねる 4.モニタリング用のエピソード情報 (エピソード全体の報酬総額) を追加 5.エピソードを1つのライフに制限します(ゲームにはライフが5つあり、ライフが1つ増えるたびにリセットされます
)観測値はサイズ (4, 84, 84) のテンソルです。最初の軸に積み重ねられた4つのフレーム(ゲーム画面の画像)です。つまり、各チャンネルはフレームです
。17class Game:
38 def __init__(self, seed: int):
環境を作成
40 self.env = gym.make('BreakoutNoFrameskip-v4')
41 self.env.seed(seed)
4フレームのスタックのテンソル
44 self.obs_4 = np.zeros((4, 84, 84))
最後の 2 フレームまで保存するバッファ
47 self.obs_2_max = np.zeros((2, 84, 84))
エピソードの報酬を把握しておけ
50 self.rewards = []
そして残された命の数
52 self.lives = 0
action
4つのタイムステップを実行し、(観測、報酬、完了、エピソード情報) のタプルを返します。
observation
: 4 つのフレームを積み重ねた (このフレームと最後の 3 アクションのフレーム)reward
: アクション実行中の報酬の合計done
: エピソードが終わったかどうか (命が失われた)episode_info
: エピソード情報 (完了した場合)54 def step(self, action):
66 reward = 0.
67 done = None
4 ステップ実行
70 for i in range(4):
OpenAI ジム環境でアクションを実行する
72 obs, r, done, info = self.env.step(action)
73
74 if i >= 2:
75 self.obs_2_max[i % 2] = self._process_obs(obs)
76
77 reward += r
残りライフ数を取得
80 lives = self.env.unwrapped.ale.lives()
命が失われたらリセット
82 if lives < self.lives:
83 done = True
84 break
各ステップの報酬を維持
87 self.rewards.append(reward)
88
89 if done:
終了したら、エピソードが終了したらエピソード情報を設定し、リセットします
91 episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}
92 self.reset()
93 else:
94 episode_info = None
最後の 2 フレームの最大値を取得
97 obs = self.obs_2_max.max(axis=0)
4フレームのスタックにプッシュ
100 self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)
101 self.obs_4[-1] = obs
102
103 return self.obs_4, reward, done, episode_info
105 def reset(self):
OpenAI ジム環境をリセット
112 obs = self.env.reset()
キャッシュをリセット
115 obs = self._process_obs(obs)
116 for i in range(4):
117 self.obs_4[i] = obs
118 self.rewards = []
119
120 self.lives = self.env.unwrapped.ale.lives()
121
122 return self.obs_4
124 @staticmethod
125 def _process_obs(obs):
130 obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
131 obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
132 return obs
135def worker_process(remote: multiprocessing.connection.Connection, seed: int):
ゲーム作成
143 game = Game(seed)
接続からの指示を待って実行する
146 while True:
147 cmd, data = remote.recv()
148 if cmd == "step":
149 remote.send(game.step(data))
150 elif cmd == "reset":
151 remote.send(game.reset())
152 elif cmd == "close":
153 remote.close()
154 break
155 else:
156 raise NotImplementedError
新しいワーカーを作成し、別のプロセスで実行します。
159class Worker:
164 def __init__(self, seed):
165 self.child, parent = multiprocessing.Pipe()
166 self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
167 self.process.start()