Single-player game PPO algorithm and example. Also adds Atari game.
AWS ParallelCluster user committed on Jul 23, 2022
1 parent b5e0bf6 · commit a2b8f73
Showing 5 changed files with 771 additions and 0 deletions.
@@ -0,0 +1,220 @@
# Note: code adapted (with permission) from https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py and https://github.com/vwxyzjn/ppo-implementation-details/blob/main/ppo_atari.py

import argparse
import collections
import logging
import os
import random
import sys
import time
from datetime import datetime
from distutils.util import strtobool

import numpy as np
import pandas as pd
import pyspiel
import torch
from open_spiel.python.pytorch.ppo import PPO, PPOAtariAgent, PPOAgent
from open_spiel.python.rl_agent import StepOutput
from open_spiel.python.rl_environment import Environment, ChanceEventSampler
from open_spiel.python.vector_env import SyncVectorEnv
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--game-name", type=str, default="atari",
        help="the id of the OpenSpiel game")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--total-timesteps", type=int, default=10_000_000,
        help="total timesteps of the experiments")
    parser.add_argument("--eval-every", type=int, default=10,
        help="evaluate the policy every N updates")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")

    # Atari specific arguments
    parser.add_argument("--gym-id", type=str, default="BreakoutNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--num-envs", type=int, default=8,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Use GAE for advantage computation")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    return args


def setUpLogging():
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    root.addHandler(handler)


def make_single_atari_env(gym_id, seed, idx, capture_video, run_name, use_episodic_life_env=True):
    def gen_env():
        game = pyspiel.load_game('atari', {
            'gym_id': gym_id,
            'seed': seed,
            'idx': idx,
            'capture_video': capture_video,
            'run_name': run_name,
            'use_episodic_life_env': use_episodic_life_env
        })
        return Environment(game, chance_event_sampler=ChanceEventSampler(seed=seed))
    return gen_env


def make_single_env(game_name, seed):
    def gen_env():
        game = pyspiel.load_game(game_name)
        return Environment(game, chance_event_sampler=ChanceEventSampler(seed=seed))
    return gen_env


def main():
    setUpLogging()
    args = parse_args()

    if args.game_name == 'atari':
        import open_spiel.python.games.atari

    current_day = datetime.now().strftime('%d')
    current_month_text = datetime.now().strftime('%h')
    run_name = f"{args.game_name}__{args.gym_id}__"
    if args.game_name == 'atari':
        run_name += f'{args.exp_name}__'
    run_name += f"{args.seed}__{current_month_text}__{current_day}__{int(time.time())}"

    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    logging.info(f"Using device: {device}")

    if args.game_name == 'atari':
        envs = SyncVectorEnv(
            [make_single_atari_env(args.gym_id, args.seed + i, i, False, run_name)() for i in range(args.num_envs)]
        )
        agent_fn = PPOAtariAgent
    else:
        envs = SyncVectorEnv(
            [make_single_env(args.game_name, args.seed + i)() for i in range(args.num_envs)]
        )
        agent_fn = PPOAgent

    game = envs.envs[0]._game
    info_state_shape = tuple(np.array(envs.observation_spec()["info_state"]).flatten())
    num_updates = args.total_timesteps // args.batch_size
    agent = PPO(
        input_shape=info_state_shape,
        num_actions=game.num_distinct_actions(),
        num_players=game.num_players(),
        player_id=0,
        num_envs=args.num_envs,
        steps_per_batch=args.num_steps,
        num_minibatches=args.num_minibatches,
        update_epochs=args.update_epochs,
        learning_rate=args.learning_rate,
        num_annealing_updates=num_updates,
        gae=args.gae,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        normalize_advantages=args.norm_adv,
        clip_coef=args.clip_coef,
        clip_vloss=args.clip_vloss,
        entropy_coef=args.ent_coef,
        value_coef=args.vf_coef,
        max_grad_norm=args.max_grad_norm,
        target_kl=args.target_kl,
        device=device,
        writer=writer,
        agent_fn=agent_fn,
    )

    N_REWARD_WINDOW = 50
    recent_rewards = collections.deque(maxlen=N_REWARD_WINDOW)
    time_step = envs.reset()
    for update in range(1, num_updates + 1):
        for step in range(0, args.num_steps):
            agent_output = agent.step(time_step)
            time_step, reward, done, unreset_time_steps = envs.step(agent_output, reset_if_done=True)

            if args.game_name == 'atari':
                # Get around the fact that the stable_baselines3.common.atari_wrappers.EpisodicLifeEnv
                # will modify rewards at the LIFE and not GAME level, by only counting rewards of
                # finished episodes.
                for ts in unreset_time_steps:
                    info = ts.observations.get('info')
                    if info and 'episode' in info:
                        real_reward = info['episode']['r']
                        writer.add_scalar('charts/player_0_training_returns', real_reward, agent.total_steps_done)
                        recent_rewards.append(real_reward)
            else:
                for ts in unreset_time_steps:
                    if ts.last():
                        real_reward = ts.rewards[0]
                        writer.add_scalar('charts/player_0_training_returns', real_reward, agent.total_steps_done)
                        recent_rewards.append(real_reward)

            agent.post_step(reward, done)

        agent.learn(time_step)

        if update % args.eval_every == 0:
            logging.info("-" * 80)
            logging.info("Step %s", agent.total_steps_done)
            logging.info(f"Summary of past {N_REWARD_WINDOW} rewards\n %s", pd.Series(recent_rewards).describe())

    writer.close()
    logging.info("All done. Have a pleasant day :)")


if __name__ == "__main__":
    main()
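
With the default flags defined in parse_args, the rollout bookkeeping in main works out as follows: each PPO update consumes batch_size = num_envs * num_steps = 8 * 128 = 1024 transitions, split into minibatch_size = 1024 // 4 = 256 samples per minibatch, and num_updates = 10_000_000 // 1024 = 9765, which is also the horizon passed to PPO as num_annealing_updates. A quick stand-alone check of that arithmetic (no OpenSpiel needed; the names simply mirror the script's defaults):

# Sanity check of the rollout bookkeeping used by parse_args() and main() above,
# using the script's default hyperparameters.
num_envs, num_steps = 8, 128
num_minibatches = 4
total_timesteps = 10_000_000

batch_size = num_envs * num_steps            # transitions gathered per PPO update
minibatch_size = batch_size // num_minibatches
num_updates = total_timesteps // batch_size  # also the LR-annealing horizon

print(batch_size, minibatch_size, num_updates)  # 1024 256 9765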
@@ -0,0 +1,162 @@
import gym
import numpy as np
import pyspiel

from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv
)


### NOTE: We include this wrapper by hand because the default wrapper threw errors (see modified lines).
class NoopResetEnv(gym.Wrapper):
    """
    Sample initial states by taking random number of no-ops on reset.
    No-op is assumed to be action 0.
    :param env: the environment to wrap
    :param noop_max: the maximum value of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30):
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs) -> np.ndarray:
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            ### MODIFIED LINES ###
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
            ### END MODIFIED LINES ###
        assert noops > 0
        obs = np.zeros(0)
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs


_NUM_PLAYERS = 1
_GAME_TYPE = pyspiel.GameType(
    short_name="atari",
    long_name="atari",
    dynamics=pyspiel.GameType.Dynamics.SEQUENTIAL,
    chance_mode=pyspiel.GameType.ChanceMode.SAMPLED_STOCHASTIC,
    information=pyspiel.GameType.Information.PERFECT_INFORMATION,
    utility=pyspiel.GameType.Utility.ZERO_SUM,
    reward_model=pyspiel.GameType.RewardModel.REWARDS,
    max_num_players=_NUM_PLAYERS,
    min_num_players=_NUM_PLAYERS,
    provides_information_state_string=False,
    provides_information_state_tensor=True,
    provides_observation_string=False,
    provides_observation_tensor=False,
    parameter_specification={
        "gym_id": 'ALE/Breakout-v5',
        "seed": 1,
        "idx": 0,
        "capture_video": False,
        'run_name': 'default',
        'use_episodic_life_env': True
    })
_GAME_INFO = pyspiel.GameInfo(
    num_distinct_actions=4,
    max_chance_outcomes=0,
    num_players=_NUM_PLAYERS,
    min_utility=-1.0,
    max_utility=1.0,
    utility_sum=0.0,
    max_game_length=2000)


class AtariGame(pyspiel.Game):

    def __init__(self, params=None):
        # Guard against params=None so the .get() calls below always work.
        params = params or dict()
        super().__init__(_GAME_TYPE, _GAME_INFO, params)
        self.gym_id = params.get('gym_id', 'BreakoutNoFrameskip-v4')
        self.seed = params.get('seed', 1)
        self.idx = params.get('idx', 0)
        self.capture_video = params.get('capture_video', False)
        self.run_name = params.get('run_name', 'default')
        self.use_episodic_life_env = params.get('use_episodic_life_env', True)

        env = gym.make(self.gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if self.capture_video and self.idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{self.run_name}")

        # Wrappers are a bit specialized right now to Breakout - different games may want different wrappers.
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        if self.use_episodic_life_env:
            env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(self.seed)
        env.action_space.seed(self.seed)
        env.observation_space.seed(self.seed)
        self.env = env

    def new_initial_state(self):
        """Returns a state corresponding to the start of a game."""
        return AtariState(self)

    def information_state_tensor_size(self):
        return AtariState(self).information_state_tensor(0).shape


class AtariState(pyspiel.State):
    """A python version of the Atari Game state."""

    def __init__(self, game):
        """Constructor; should only be called by Game.new_initial_state."""
        super().__init__(game)
        self._is_terminal = False
        self.tracked_rewards = 0
        self.env = game.env
        self.observation = self.env.reset()
        self.last_reward = None
        self.last_info = dict()

    def current_player(self):
        """Returns id of the next player to move, or TERMINAL if game is over."""
        return pyspiel.PlayerId.TERMINAL if self._is_terminal else 0

    def _legal_actions(self, player):
        """Returns a list of legal actions, sorted in ascending order."""
        return list(range(self.env.action_space.n))

    def _apply_action(self, action):
        """Applies the specified action to the state."""
        observation, reward, done, info = self.env.step(action)
        self.last_info = info
        self.last_reward = reward
        self.tracked_rewards += reward
        if done:
            self._is_terminal = True
        self.observation = observation  # Store this for later

    def information_state_tensor(self, player_id):
        return self.observation

    def _action_to_string(self, player, action):
        return self.env.get_action_meanings()[action]

    def is_terminal(self):
        """Returns True if the game is over."""
        return self._is_terminal

    def rewards(self):
        return [self.last_reward]

    def returns(self):
        """Total reward for each player over the course of the game so far."""
        return [self.tracked_rewards]

    def __str__(self):
        """String for debug purposes. No particular semantics are required."""
        return "DEBUG"


# Register the game with the OpenSpiel library
pyspiel.register_game(_GAME_TYPE, AtariGame)
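
Importing this module runs the pyspiel.register_game call above, after which the game can be loaded and driven through OpenSpiel's RL environment wrapper, mirroring make_single_atari_env in the example script. The sketch below is illustrative rather than part of the commit; it assumes gym, ale-py (with the Breakout ROM) and stable_baselines3 are installed, and that the rl_environment TimeStep exposes per-player legal actions under observations["legal_actions"]:

import random

import pyspiel
import open_spiel.python.games.atari  # noqa: F401 -- side effect: registers the 'atari' game
from open_spiel.python.rl_environment import Environment, ChanceEventSampler

# Load the registered game with the same parameters the example script passes.
game = pyspiel.load_game('atari', {
    'gym_id': 'BreakoutNoFrameskip-v4',
    'seed': 1,
    'idx': 0,
    'capture_video': False,
    'run_name': 'smoke_test',      # hypothetical run name, used only for video paths
    'use_episodic_life_env': True,
})
env = Environment(game, chance_event_sampler=ChanceEventSampler(seed=1))

# Play one episode with random legal actions for player 0 (the only player).
time_step = env.reset()
episode_return = 0.0
while not time_step.last():
    action = random.choice(time_step.observations["legal_actions"][0])
    time_step = env.step([action])
    episode_return += time_step.rewards[0]
print("clipped episode return:", episode_return)

Because ClipRewardEnv is part of the wrapper stack, the accumulated value is a clipped return, and with use_episodic_life_env=True the episode ends at a life loss rather than at game over, which is exactly why the training loop above reads the true game score from info['episode']['r'] instead.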