Skip to content

Commit

Permalink
added vectorized environments
Browse files Browse the repository at this point in the history
  • Loading branch information
Gurvan committed Oct 27, 2019
1 parent 5a55615 commit c00c3cc
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 5 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Versions should comply with PEP440. For a discussion on single-sourcing
# the version across setup.py and the project code, see
# https://packaging.python.org/en/latest/single_source_version.html
version='0.1.0',
version='0.1.1',

description='Super Smash Bros. Melee Gym Environment',
long_description=long_description,
Expand Down Expand Up @@ -50,7 +50,7 @@
# your project is installed. For an analysis of "install_requires" vs pip's
# requirements files see:
# https://packaging.python.org/en/latest/requirements.html
install_requires=['attrs', 'pyzmq'],
install_requires=['attrs', 'pyzmq', 'cloudpickle'],

# If there are data files included in your packages that need to be
# installed, specify them here. If using Python 2.6 or less, then these
Expand Down
2 changes: 1 addition & 1 deletion ssbm_gym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import os
path = os.path.dirname(__file__)
from ssbm_gym.ssbm_env import SSBMEnv
from ssbm_gym.ssbm_env import SSBMEnv, EnvVec
175 changes: 173 additions & 2 deletions ssbm_gym/ssbm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,176 @@ def isDying(player):
return player.action_state <= 0xA


# Custom reward
#r += self.obs.players[self.pid].y / 100.0 / 60.0
### Vectorizing


import multiprocessing
import cloudpickle
import pickle


def make_env(env_class, frame_limit, options):
    """Return a zero-argument factory that builds one environment.

    Nothing is instantiated here: the returned callable captures the
    arguments and calls ``env_class(frame_limit=..., options=...)`` only
    when it is invoked.

    :param env_class: callable used to construct the environment.
    :param frame_limit: forwarded as the ``frame_limit`` keyword.
    :param options: forwarded as the ``options`` keyword.
    :return: a callable taking no arguments and returning a new environment.
    """
    def _create():
        return env_class(frame_limit=frame_limit, options=options)

    return _create


def EnvVec(env_class, num_envs, frame_limit=1e12, options=None):
    """Create a ``SubprocVecEnv`` running ``num_envs`` copies of ``env_class``.

    :param env_class: callable used to construct each environment; it is
        called as ``env_class(frame_limit=..., options=...)``.
    :param num_envs: number of environments (one subprocess each).
    :param frame_limit: forwarded to every environment.
    :param options: either a single dict shared by every environment, or a
        non-empty list of at most ``num_envs`` option sets.  A shorter list
        is repeated cyclically until every environment has an entry.
        ``None`` (the default) means an empty dict.
    :return: the ``SubprocVecEnv`` wrapping the environments.
    :raises ValueError: if ``options`` is an empty list or has more entries
        than ``num_envs``.
    :raises TypeError: if ``options`` is neither a dict nor a list.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if options is None:
        options = {}

    if isinstance(options, dict):
        per_env_options = [options] * num_envs
    elif isinstance(options, list):
        # Explicit raises instead of asserts: asserts vanish under ``-O``.
        if not options:
            raise ValueError("options list must not be empty")
        if len(options) > num_envs:
            raise ValueError(
                "options list has %d entries but only %d environments were requested"
                % (len(options), num_envs)
            )
        # Cycle through the provided option sets to fill all environments.
        per_env_options = [options[i % len(options)] for i in range(num_envs)]
    else:
        raise TypeError(
            "options must be a dict or a list, got %s" % type(options).__name__
        )

    return SubprocVecEnv([
        make_env(env_class=env_class, frame_limit=frame_limit, options=env_options)
        for env_options in per_env_options
    ])



class CloudpickleWrapper(object):
    """Hold a value so that pickling it goes through ``cloudpickle``.

    When an instance is pickled, its state is the ``cloudpickle`` byte
    string of ``var``; when unpickled, that byte string is decoded with
    ``pickle.loads`` and stored back on ``var``.
    """

    def __init__(self, var):
        # The wrapped payload, exposed as a plain attribute.
        self.var = var

    def __getstate__(self):
        """Return the wrapped value serialized by ``cloudpickle``."""
        serialized = cloudpickle.dumps(self.var)
        return serialized

    def __setstate__(self, obs):
        """Restore ``var`` from the serialized bytes in ``obs``."""
        # NOTE(review): pickle.loads must only ever see trusted data.
        restored = pickle.loads(obs)
        self.var = restored


def _worker(remote, parent_remote, env_fn_wrapper):
parent_remote.close()
env = env_fn_wrapper.var()
while True:
try:
cmd, data = remote.recv()
if cmd == 'step':
observation, reward, done, info = env.step(data)
if done:
# save final observation where user can get it, then reset
info['terminal_observation'] = observation
remote.send((observation, reward, done, info))
elif cmd == 'reset':
observation = env.reset()
remote.send(observation)
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces':
remote.send((env.observation_space, env.action_space))
elif cmd == 'env_method':
method = getattr(env, data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == 'get_attr':
remote.send(getattr(env, data))
elif cmd == 'set_attr':
remote.send(setattr(env, data[0], data[1]))
else:
raise NotImplementedError
except EOFError:
break


class SubprocVecEnv():
    """Run several environments in parallel, each in its own subprocess.

    Every environment lives in a daemon process executing ``_worker`` and is
    driven through one end of a ``multiprocessing`` pipe.  Results are
    returned per environment, in the order of ``env_fns``.

    :param env_fns: list of zero-argument callables, each building one
        environment inside its subprocess.
    :param start_method: multiprocessing start method name.  When ``None``,
        ``'forkserver'`` is used if available, otherwise ``'spawn'``.
    """

    def __init__(self, env_fns, start_method=None):
        self.num_envs = len(env_fns)
        self.waiting = False
        self.closed = False

        if start_method is None:
            # 'fork' is deliberately not the default here; 'forkserver' is
            # preferred, with 'spawn' as the fallback.  Both need the calling
            # script's entry point to be wrapped in
            # `if __name__ == "__main__":`.
            forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods()
            start_method = 'forkserver' if forkserver_available else 'spawn'
        ctx = multiprocessing.get_context(start_method)

        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.num_envs)])
        self.processes = []
        for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
            args = (work_remote, remote, CloudpickleWrapper(env_fn))
            # daemon=True: if the main process crashes, we should not cause things to hang
            process = ctx.Process(target=_worker, args=args, daemon=True)
            process.start()
            self.processes.append(process)
            # The parent keeps only its own end of each pipe.
            work_remote.close()

        # The spaces are read from the first environment only.
        self.remotes[0].send(('get_spaces', None))
        self.observation_space, self.action_space = self.remotes[0].recv()

    def step(self, actions):
        """Step every environment and wait for all results.

        :param actions: one action per environment, in order.
        :return: ``(obs, rews, dones, infos)``, each a tuple with one entry
            per environment.
        """
        self.step_async(actions)
        return self.step_wait()

    def step_async(self, actions):
        """Send one ``'step'`` command per environment without waiting."""
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """Collect the results of a previous ``step_async`` call."""
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return obs, rews, dones, infos

    def reset(self):
        """Reset every environment and return the list of observations."""
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        return obs

    def close(self):
        """Ask every worker to stop and wait for the processes to exit.

        Calling it again after a successful close does nothing.
        """
        if self.closed:
            return
        if self.waiting:
            # Drain pending step results so the pipes are in a clean state.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for process in self.processes:
            process.join()
        self.closed = True

    def get_attr(self, attr_name, indices=None):
        """Return the attribute ``attr_name`` from the selected environments.

        :param attr_name: name of the attribute to read.
        :param indices: ``None`` (all), an int, or an iterable of ints.
        :return: list of attribute values, one per selected environment.
        """
        target_remotes = self._get_target_remotes(indices)
        for remote in target_remotes:
            remote.send(('get_attr', attr_name))
        return [remote.recv() for remote in target_remotes]

    def set_attr(self, attr_name, value, indices=None):
        """Set the attribute ``attr_name`` to ``value`` in the selected environments.

        :param attr_name: name of the attribute to assign.
        :param value: value to assign.
        :param indices: ``None`` (all), an int, or an iterable of ints.
        """
        target_remotes = self._get_target_remotes(indices)
        for remote in target_remotes:
            remote.send(('set_attr', (attr_name, value)))
        # Wait for each worker's acknowledgement.
        for remote in target_remotes:
            remote.recv()

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Call the method ``method_name`` on the selected environments.

        :param method_name: name of the environment method to call.
        :param method_args: positional arguments forwarded to the method.
        :param indices: ``None`` (all), an int, or an iterable of ints.
        :param method_kwargs: keyword arguments forwarded to the method.
        :return: list of return values, one per selected environment.
        """
        target_remotes = self._get_target_remotes(indices)
        for remote in target_remotes:
            remote.send(('env_method', (method_name, method_args, method_kwargs)))
        return [remote.recv() for remote in target_remotes]

    def _get_indices(self, indices):
        """Normalize ``indices`` into an iterable of environment indices.

        This helper was referenced but never defined, which made
        ``get_attr``, ``set_attr`` and ``env_method`` fail with
        ``AttributeError``.

        :param indices: ``None`` (all environments), a single int, or an
            iterable of ints.
        :return: an iterable of ints.
        """
        if indices is None:
            return range(self.num_envs)
        if isinstance(indices, int):
            return [indices]
        return indices

    def _get_target_remotes(self, indices):
        """
        Get the connection object needed to communicate with the wanted
        envs that are in subprocesses.
        :param indices: (None,int,Iterable) refers to indices of envs.
        :return: ([multiprocessing.Connection]) Connection object to communicate between processes.
        """
        indices = self._get_indices(indices)
        return [self.remotes[i] for i in indices]

39 changes: 39 additions & 0 deletions test_env_vectorized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from ssbm_gym import SSBMEnv, EnvVec
import atexit
import platform
import random
import time

# Environment options handed to EnvVec in the main guard of this script.
options = {
    'render': True,
    'player1': 'ai',
    'player2': 'cpu',
    'char1': 'fox',
    'char2': 'falco',
    'cpu2': 7,
    'stage': 'battlefield',
}

# Vectorized envs not supported on Windows
# if platform.system() == 'Windows':
#     options['windows'] = True

# Number of parallel environments.
num_workers = 4


# Required for vectorized envs
if __name__ == "__main__":
    # The main guard is required for vectorized envs.
    env = EnvVec(SSBMEnv, num_workers, options=options)
    obs = env.reset()
    # Make sure the worker processes are asked to stop when the script exits.
    atexit.register(env.close)

    t = time.time()
    for i in range(1000):
        # One random discrete action per worker.
        action = [random.randint(0, env.action_space.n - 1) for _ in range(num_workers)]
        obs, reward, done, infos = env.step(action)
        try:
            print("FPS:", round(1/(time.time() - t)))
            print([o.players[0].x for o in obs])
        except Exception:
            # Printing is best-effort (e.g. the elapsed time can be zero).
            # Catching Exception rather than using a bare except keeps
            # Ctrl-C (KeyboardInterrupt) and SystemExit working.
            pass
        t = time.time()

0 comments on commit c00c3cc

Please sign in to comment.