Single-player game PPO algorithm and example. Also adds Atari game
AWS ParallelCluster user committed Jul 23, 2022
1 parent b5e0bf6 commit a2b8f73
Showing 5 changed files with 771 additions and 0 deletions.
220 changes: 220 additions & 0 deletions open_spiel/python/examples/ppo_example.py
@@ -0,0 +1,220 @@
# Note: code adapted (with permission) from https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py and https://github.com/vwxyzjn/ppo-implementation-details/blob/main/ppo_atari.py

import argparse
import collections
import logging
import os
import random
import sys
import time
from datetime import datetime
from distutils.util import strtobool

import numpy as np
import pandas as pd
import pyspiel
import torch
from open_spiel.python.pytorch.ppo import PPO, PPOAtariAgent, PPOAgent
from open_spiel.python.rl_agent import StepOutput
from open_spiel.python.rl_environment import Environment, ChanceEventSampler
from open_spiel.python.vector_env import SyncVectorEnv
from torch.utils.tensorboard import SummaryWriter

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--game-name", type=str, default="atari",
        help="the id of the OpenSpiel game")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--total-timesteps", type=int, default=10_000_000,
        help="total timesteps of the experiments")
    parser.add_argument("--eval-every", type=int, default=10,
        help="evaluate the policy every N updates")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")

    # Atari specific arguments
    parser.add_argument("--gym-id", type=str, default="BreakoutNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--num-envs", type=int, default=8,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Use GAE for advantage computation")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    return args
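# Illustrative sketch (not part of this commit): with the defaults above,
# batch_size = 8 envs * 128 steps = 1024 transitions per update, split into
# 4 minibatches of 256 samples for each of the 4 update epochs.
_EXAMPLE_NUM_ENVS, _EXAMPLE_NUM_STEPS, _EXAMPLE_NUM_MINIBATCHES = 8, 128, 4
_EXAMPLE_BATCH_SIZE = _EXAMPLE_NUM_ENVS * _EXAMPLE_NUM_STEPS                # 1024
_EXAMPLE_MINIBATCH_SIZE = _EXAMPLE_BATCH_SIZE // _EXAMPLE_NUM_MINIBATCHES   # 256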

def setUpLogging():
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    root.addHandler(handler)

def make_single_atari_env(gym_id, seed, idx, capture_video, run_name, use_episodic_life_env=True):
    def gen_env():
        game = pyspiel.load_game('atari', {
            'gym_id': gym_id,
            'seed': seed,
            'idx': idx,
            'capture_video': capture_video,
            'run_name': run_name,
            'use_episodic_life_env': use_episodic_life_env
        })
        return Environment(game, chance_event_sampler=ChanceEventSampler(seed=seed))
    return gen_env

def make_single_env(game_name, seed):
    def gen_env():
        game = pyspiel.load_game(game_name)
        return Environment(game, chance_event_sampler=ChanceEventSampler(seed=seed))
    return gen_env
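# Illustrative sketch (not part of this commit): each factory above defers game
# construction so a per-environment seed can be baked in before the call; a batch
# of independently seeded environments can then be handed to SyncVectorEnv.
# "kuhn_poker" is just a stand-in game id here.
def _example_vector_env(num_envs=4, seed=1):
    return SyncVectorEnv(
        [make_single_env("kuhn_poker", seed + i)() for i in range(num_envs)]
    )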


def main():
    setUpLogging()
    args = parse_args()

    if args.game_name == 'atari':
        import open_spiel.python.games.atari

    current_day = datetime.now().strftime('%d')
    current_month_text = datetime.now().strftime('%h')
    run_name = f"{args.game_name}__{args.gym_id}__"
    if args.game_name == 'atari':
        run_name += f'{args.exp_name}__'
    run_name += f"{args.seed}__{current_month_text}__{current_day}__{int(time.time())}"

    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    logging.info(f"Using device: {device}")

    if args.game_name == 'atari':
        envs = SyncVectorEnv(
            [make_single_atari_env(args.gym_id, args.seed + i, i, False, run_name)() for i in range(args.num_envs)]
        )
        agent_fn = PPOAtariAgent
    else:
        envs = SyncVectorEnv(
            [make_single_env(args.game_name, args.seed + i)() for i in range(args.num_envs)]
        )
        agent_fn = PPOAgent

    game = envs.envs[0]._game
    info_state_shape = tuple(np.array(envs.observation_spec()["info_state"]).flatten())
    num_updates = args.total_timesteps // args.batch_size
    agent = PPO(
        input_shape=info_state_shape,
        num_actions=game.num_distinct_actions(),
        num_players=game.num_players(),
        player_id=0,
        num_envs=args.num_envs,
        steps_per_batch=args.num_steps,
        num_minibatches=args.num_minibatches,
        update_epochs=args.update_epochs,
        learning_rate=args.learning_rate,
        num_annealing_updates=num_updates,
        gae=args.gae,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        normalize_advantages=args.norm_adv,
        clip_coef=args.clip_coef,
        clip_vloss=args.clip_vloss,
        entropy_coef=args.ent_coef,
        value_coef=args.vf_coef,
        max_grad_norm=args.max_grad_norm,
        target_kl=args.target_kl,
        device=device,
        writer=writer,
        agent_fn=agent_fn,
    )
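    # Illustrative note (not part of this commit): with the default arguments,
    # num_updates = total_timesteps // batch_size = 10_000_000 // 1024 = 9765;
    # the same value is passed as num_annealing_updates, presumably so that any
    # learning-rate annealing inside PPO spans the whole run.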

    N_REWARD_WINDOW = 50
    recent_rewards = collections.deque(maxlen=N_REWARD_WINDOW)
    time_step = envs.reset()
    for update in range(1, num_updates + 1):
        for step in range(0, args.num_steps):
            agent_output = agent.step(time_step)
            time_step, reward, done, unreset_time_steps = envs.step(agent_output, reset_if_done=True)

            if args.game_name == 'atari':
                # stable_baselines3.common.atari_wrappers.EpisodicLifeEnv reports episode ends at the
                # LIFE level rather than the GAME level, so only count the returns of finished
                # episodes as recorded by RecordEpisodeStatistics.
                for ts in unreset_time_steps:
                    info = ts.observations.get('info')
                    if info and 'episode' in info:
                        real_reward = info['episode']['r']
                        writer.add_scalar('charts/player_0_training_returns', real_reward, agent.total_steps_done)
                        recent_rewards.append(real_reward)
            else:
                for ts in unreset_time_steps:
                    if ts.last():
                        real_reward = ts.rewards[0]
                        writer.add_scalar('charts/player_0_training_returns', real_reward, agent.total_steps_done)
                        recent_rewards.append(real_reward)

            agent.post_step(reward, done)

        agent.learn(time_step)

        if update % args.eval_every == 0:
            logging.info("-" * 80)
            logging.info("Step %s", agent.total_steps_done)
            logging.info(f"Summary of past {N_REWARD_WINDOW} rewards\n %s", pd.Series(recent_rewards).describe())

    writer.close()
    logging.info("All done. Have a pleasant day :)")


if __name__ == "__main__":
    main()
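Not part of the diff: given the argument parser above, the example is presumably run as `python ppo_example.py --gym-id BreakoutNoFrameskip-v4` for Atari (the default `--game-name`), or with `--game-name` set to another registered OpenSpiel game, which uses the generic `PPOAgent` instead of `PPOAtariAgent`.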
162 changes: 162 additions & 0 deletions open_spiel/python/games/atari.py
@@ -0,0 +1,162 @@
import gym
import numpy as np
import pyspiel

from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv
)

### NOTE: We include this wrapper by hand because the default wrapper threw errors (see modified lines).
class NoopResetEnv(gym.Wrapper):
    """
    Sample initial states by taking a random number of no-ops on reset.
    No-op is assumed to be action 0.
    :param env: the environment to wrap
    :param noop_max: the maximum value of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30):
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs) -> np.ndarray:
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            #### MODIFIED LINES ###
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
            ### END MODIFIED LINES ###
        assert noops > 0
        obs = np.zeros(0)
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs
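# Illustrative sketch (not part of this commit): the MODIFIED LINES above call
# `np_random.integers`, presumably because newer gym versions expose a numpy
# Generator that no longer has the `randint` method used by the stock wrapper.
# A version-tolerant helper could hedge between the two APIs:
def _sample_noops(np_random, noop_max):
    sampler = getattr(np_random, "integers", None) or np_random.randint
    return int(sampler(1, noop_max + 1))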

_NUM_PLAYERS = 1
_GAME_TYPE = pyspiel.GameType(
    short_name="atari",
    long_name="atari",
    dynamics=pyspiel.GameType.Dynamics.SEQUENTIAL,
    chance_mode=pyspiel.GameType.ChanceMode.SAMPLED_STOCHASTIC,
    information=pyspiel.GameType.Information.PERFECT_INFORMATION,
    utility=pyspiel.GameType.Utility.ZERO_SUM,
    reward_model=pyspiel.GameType.RewardModel.REWARDS,
    max_num_players=_NUM_PLAYERS,
    min_num_players=_NUM_PLAYERS,
    provides_information_state_string=False,
    provides_information_state_tensor=True,
    provides_observation_string=False,
    provides_observation_tensor=False,
    parameter_specification={
        "gym_id": 'ALE/Breakout-v5',
        "seed": 1,
        "idx": 0,
        "capture_video": False,
        'run_name': 'default',
        'use_episodic_life_env': True
    })
_GAME_INFO = pyspiel.GameInfo(
    num_distinct_actions=4,
    max_chance_outcomes=0,
    num_players=_NUM_PLAYERS,
    min_utility=-1.0,
    max_utility=1.0,
    utility_sum=0.0,
    max_game_length=2000)

class AtariGame(pyspiel.Game):

    def __init__(self, params=None):
        params = params or dict()
        super().__init__(_GAME_TYPE, _GAME_INFO, params)
        self.gym_id = params.get('gym_id', 'BreakoutNoFrameskip-v4')
        self.seed = params.get('seed', 1)
        self.idx = params.get('idx', 0)
        self.capture_video = params.get('capture_video', False)
        self.run_name = params.get('run_name', 'default')
        self.use_episodic_life_env = params.get('use_episodic_life_env', True)

        env = gym.make(self.gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if self.capture_video and self.idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{self.run_name}")

        # Wrappers are a bit specialized to Breakout right now - different games may want different wrappers.
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        if self.use_episodic_life_env:
            env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(self.seed)
        env.action_space.seed(self.seed)
        env.observation_space.seed(self.seed)
        self.env = env

    def new_initial_state(self):
        """Returns a state corresponding to the start of a game."""
        return AtariState(self)

    def information_state_tensor_size(self):
        return AtariState(self).information_state_tensor(0).shape

class AtariState(pyspiel.State):
    """A python version of the Atari Game state."""

    def __init__(self, game):
        """Constructor; should only be called by Game.new_initial_state."""
        super().__init__(game)
        self._is_terminal = False
        self.tracked_rewards = 0
        self.env = game.env
        self.observation = self.env.reset()
        self.last_reward = None
        self.last_info = dict()

    def current_player(self):
        """Returns id of the next player to move, or TERMINAL if game is over."""
        return pyspiel.PlayerId.TERMINAL if self._is_terminal else 0

    def _legal_actions(self, player):
        """Returns a list of legal actions, sorted in ascending order."""
        return list(range(self.env.action_space.n))

    def _apply_action(self, action):
        """Applies the specified action to the state."""
        observation, reward, done, info = self.env.step(action)
        self.last_info = info
        self.last_reward = reward
        self.tracked_rewards += reward
        if done:
            self._is_terminal = True
        self.observation = observation  # Store this for later

    def information_state_tensor(self, player_id):
        return self.observation

    def _action_to_string(self, player, action):
        return self.env.get_action_meanings()[action]

    def is_terminal(self):
        """Returns True if the game is over."""
        return self._is_terminal

    def rewards(self):
        return [self.last_reward]

    def returns(self):
        """Total reward for each player over the course of the game so far."""
        return [self.tracked_rewards]

    def __str__(self):
        """String for debug purposes. No particular semantics are required."""
        return "DEBUG"

# Register the game with the OpenSpiel library
pyspiel.register_game(_GAME_TYPE, AtariGame)
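
Once registered, the wrapped environment can be driven through the usual OpenSpiel state API. A minimal sketch (not part of the commit), assuming the parameter names declared in `_GAME_TYPE` above and a working ALE/gym installation:

    import numpy as np
    import pyspiel

    import open_spiel.python.games.atari  # noqa: F401  (runs the registration above)

    game = pyspiel.load_game("atari", {"gym_id": "BreakoutNoFrameskip-v4", "seed": 1})
    state = game.new_initial_state()
    while not state.is_terminal():
        action = np.random.choice(state.legal_actions())
        state.apply_action(action)
    print("Episode return:", state.returns()[0])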
