Atari wrapper with multi-processing

9import multiprocessing
10import multiprocessing.connection
12import cv2
13import gym
14import numpy as np

Game environment

This is a wrapper for the OpenAI Gym game environment. It does a few things:

1. Apply the same action on four frames and take the pixel-wise maximum of the last two frames.
2. Convert observation frames to grayscale and scale them to (84, 84).
3. Stack the frames of the last four actions.
4. Add episode information (total reward and length of the episode) for monitoring.
5. Restrict an episode to a single life (the game has 5 lives; we reset after every life lost).

Observation format

The observation is a tensor of size (4, 84, 84): four frames (images of the game screen) stacked along the first axis, i.e., each channel is a frame.

class Game:
    """
    Wrapper around the OpenAI Gym ``BreakoutNoFrameskip-v4`` environment.

    * Repeats each action for four frames and keeps the pixel-wise maximum
      of the last two of them.
    * Converts frames to grayscale and resizes them to 84x84.
    * Stacks the last four processed frames into a ``(4, 84, 84)`` observation.
    * Tracks the total reward and length of the current episode.
    * Ends an episode as soon as a single life is lost.

    ``reset()`` must be called before the first ``step()``.
    """

    def __init__(self, seed: int):
        # create environment
        self.env = gym.make('BreakoutNoFrameskip-v4')
        self.env.seed(seed)

        # tensor for a stack of 4 frames
        self.obs_4 = np.zeros((4, 84, 84))

        # buffer to keep the maximum of last 2 frames
        self.obs_2_max = np.zeros((2, 84, 84))

        # keep track of the episode rewards
        self.rewards = []
        # and number of lives left (set properly by reset())
        self.lives = 0

    def step(self, action):
        """
        Execute ``action`` for 4 time steps.

        Returns a tuple ``(observation, reward, done, episode_info)``:

        * ``observation``: stacked 4 frames (this frame and frames for the
          last 3 actions)
        * ``reward``: total reward while the action was executed
        * ``done``: whether the episode finished (a life was lost, or the
          underlying environment reported ``done``)
        * ``episode_info``: ``{"reward": ..., "length": ...}`` if the
          episode finished, otherwise ``None``

        When the episode finishes the game is reset, and the returned
        observation is the first observation of the new episode.
        """
        reward = 0.
        done = False

        # run for 4 steps
        for i in range(4):
            # execute the action in the OpenAI Gym environment
            obs, r, done, info = self.env.step(action)

            # only the last two frames feed the max-pooled observation
            if i >= 2:
                self.obs_2_max[i % 2] = self._process_obs(obs)

            reward += r

            # Stop once the environment itself is done: stepping a finished
            # environment is invalid, and a later step could overwrite
            # ``done`` with False and lose the end of the episode.
            if done:
                break

            # get number of lives left
            lives = self.env.unwrapped.ale.lives()
            # end the episode (and reset) if a life is lost
            if lives < self.lives:
                done = True
                break

        # maintain rewards for each step
        self.rewards.append(reward)

        if done:
            # episode is over: record its information and start a new one
            episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}
            self.reset()
        else:
            episode_info = None

            # get the max of last two frames
            obs = self.obs_2_max.max(axis=0)

            # push it to the stack of 4 frames
            self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)
            self.obs_4[-1] = obs

        return self.obs_4, reward, done, episode_info

    def reset(self):
        """
        Reset the environment.

        Clears the episode rewards, fills the 4-frame stack with the first
        observation, records the number of lives, and returns the stack.
        """
        # reset OpenAI Gym environment
        obs = self.env.reset()

        # reset caches
        obs = self._process_obs(obs)
        # Allocate a fresh stack (same float64 dtype as before) so that an
        # observation array returned by an earlier call is not overwritten
        # in place.
        self.obs_4 = np.empty((4, 84, 84))
        self.obs_4[:] = obs
        self.rewards = []

        self.lives = self.env.unwrapped.ale.lives()

        return self.obs_4

    @staticmethod
    def _process_obs(obs):
        """
        Convert a game frame to grayscale and rescale it to 84x84.
        """
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
        return obs

Worker Process

Each worker process runs this function.

def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """
    Each worker process runs this function.

    Creates a ``Game`` and serves ``(cmd, data)`` requests received on
    ``remote``:

    * ``"step"``: reply with ``game.step(data)``
    * ``"reset"``: reply with ``game.reset()``
    * ``"close"``: close the connection and return

    Raises ``NotImplementedError`` for any other command.
    """
    # create game
    game = Game(seed)

    try:
        # wait for instructions from the connection and execute them
        while True:
            cmd, data = remote.recv()
            if cmd == "step":
                remote.send(game.step(data))
            elif cmd == "reset":
                remote.send(game.reset())
            elif cmd == "close":
                break
            else:
                raise NotImplementedError(f"Unknown command: {cmd!r}")
    finally:
        # release the connection on normal shutdown and on errors alike
        remote.close()

Creates a new worker and runs it in a separate process.

class Worker:
    """
    Creates a new worker and runs it in a separate process.

    ``self.child`` is the connection this (parent) process uses to send
    commands to, and receive results from, the worker process.
    """

    def __init__(self, seed):
        self.child, worker_end = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(worker_end, seed))
        self.process.start()
        # The worker now owns its end of the pipe. Close our copy so the
        # worker is the only holder and closure/EOF is detected promptly.
        worker_end.close()