Feature/2003 lint (#32)
* fix lint
chris-chris authored Mar 4, 2020
1 parent 5b70343 commit c8074a1
Showing 2 changed files with 96 additions and 110 deletions.
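
Most of the diff applies the same few PEP 8 style fixes throughout: redundant parentheses around if conditions are dropped, duplicate imports are removed, and continuation lines of multi-line calls get a uniform hanging indent. As a standalone illustration (a hypothetical helper, not code from either file; the rounding precision is an assumption), the pattern looks like this:

import random


def pick_learning_rate(lr):
    # Post-lint style: "if lr == 0:" instead of the old "if (lr == 0):".
    if lr == 0:
        lr = random.uniform(0.00001, 0.001)
    # Illustrative stand-in for the lr_round value used in the log directory names.
    return round(lr, 8)


print("random lr : %s" % pick_learning_rate(0))
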
5 changes: 3 additions & 2 deletions train_defeat_zerglings.py
@@ -6,9 +6,9 @@
from baselines import deepq
from pysc2.env import sc2_env
from pysc2.lib import actions
from baselines.logger import Logger, TensorBoardOutputFormat, HumanOutputFormat

from defeat_zerglings import dqfd
from baselines.logger import Logger, TensorBoardOutputFormat, HumanOutputFormat

_MOVE_SCREEN = actions.FUNCTIONS.Move_screen.id
_SELECT_ARMY = actions.FUNCTIONS.select_army.id
@@ -101,8 +101,9 @@ def main():
)
act.save("defeat_zerglings.pkl")


def deepq_callback(locals, globals):
#pprint.pprint(locals)

global max_mean_reward, last_filename
if('done' in locals and locals['done'] == True):
if('mean_100ep_reward' in locals
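
train_defeat_zerglings.py only regroups the baselines logger import; how those names are used is visible in train_mineral_shards.py below. For reference, a minimal sketch of that logging setup (directory name and logged key are illustrative; assumes the OpenAI baselines Logger API with logkv/dumpkvs):

import sys
from baselines.logger import Logger, TensorBoardOutputFormat, HumanOutputFormat

# Write metrics as TensorBoard event files, as the "tensorboard" branch below does...
Logger.DEFAULT = Logger.CURRENT = Logger(
    dir=None, output_formats=[TensorBoardOutputFormat("tensorboard/demo")])

# ...or print them to the console instead, as the "stdout" branch does:
# Logger.DEFAULT = Logger.CURRENT = Logger(
#     dir=None, output_formats=[HumanOutputFormat(sys.stdout)])

Logger.CURRENT.logkv("mean_100ep_reward", 0.0)
Logger.CURRENT.dumpkvs()
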
201 changes: 93 additions & 108 deletions train_mineral_shards.py
@@ -1,27 +1,19 @@
import sys
import os

import datetime
import random
from absl import flags
from baselines import deepq
from baselines_legacy import mlp, cnn_to_mlp

from pysc2.env import sc2_env
from pysc2.lib import actions
import os

import deepq_mineral_shards
import datetime
from baselines_legacy import cnn_to_mlp
from baselines.logger import Logger, TensorBoardOutputFormat, HumanOutputFormat

from common.vec_env.subproc_vec_env import SubprocVecEnv
from a2c.policies import CnnPolicy
from a2c import a2c
from baselines.logger import Logger, TensorBoardOutputFormat, HumanOutputFormat

import random

import deepq_mineral_4way

import threading
import time
import deepq_mineral_shards

_MOVE_SCREEN = actions.FUNCTIONS.Move_screen.id
_SELECT_ARMY = actions.FUNCTIONS.select_army.id
@@ -50,8 +42,6 @@
max_mean_reward = 0
last_filename = ""

start_time = datetime.datetime.now().strftime("%m%d%H%M")


def main():
FLAGS(sys.argv)
@@ -64,7 +54,7 @@ def main():
print("num_agents : %s" % FLAGS.num_agents)
print("lr : %s" % FLAGS.lr)

if (FLAGS.lr == 0):
if FLAGS.lr == 0:
FLAGS.lr = random.uniform(0.00001, 0.001)

print("random lr : %s" % FLAGS.lr)
@@ -73,35 +63,36 @@ def main():

logdir = "tensorboard"

if (FLAGS.algorithm == "deepq-4way"):
if FLAGS.algorithm == "deepq-4way":
logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
elif (FLAGS.algorithm == "deepq"):
FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
elif FLAGS.algorithm == "deepq":
logdir = "tensorboard/mineral/%s/%s_%s_prio%s_duel%s_lr%s/%s" % (
FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
elif (FLAGS.algorithm == "a2c"):
FLAGS.algorithm, FLAGS.timesteps, FLAGS.exploration_fraction,
FLAGS.prioritized, FLAGS.dueling, lr_round, start_time)
elif FLAGS.algorithm == "a2c":
logdir = "tensorboard/mineral/%s/%s_n%s_s%s_nsteps%s/lr%s/%s" % (
FLAGS.algorithm, FLAGS.timesteps,
FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts,
FLAGS.nsteps, lr_round, start_time)
FLAGS.algorithm, FLAGS.timesteps,
FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts,
FLAGS.nsteps, lr_round, start_time)

if (FLAGS.log == "tensorboard"):
if FLAGS.log == "tensorboard":
Logger.DEFAULT \
= Logger.CURRENT \
= Logger(dir=None,
output_formats=[TensorBoardOutputFormat(logdir)])

elif (FLAGS.log == "stdout"):
elif FLAGS.log == "stdout":
Logger.DEFAULT \
= Logger.CURRENT \
= Logger(dir=None,
output_formats=[HumanOutputFormat(sys.stdout)])

if (FLAGS.algorithm == "deepq"):
if FLAGS.algorithm == "deepq":

AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(feature_dimensions=sc2_env.Dimensions(screen=16, minimap=16))
AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
feature_dimensions=sc2_env.Dimensions(screen=16, minimap=16))
# temp solution - sc2_env.Agent(sc2_env.Race.terran) might be too restricting
# We need this change because sc2 now requires specifying players.
with sc2_env.SC2Env(
@@ -112,94 +103,92 @@ def main():
agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

model = cnn_to_mlp(
convs=[(16, 8, 4), (32, 4, 2)], hiddens=[256], dueling=True)
convs=[(16, 8, 4), (32, 4, 2)], hiddens=[256], dueling=True)

acts = deepq_mineral_shards.learn(
env,
q_func=model,
num_actions=16,
lr=FLAGS.lr,
max_timesteps=FLAGS.timesteps,
buffer_size=10000,
exploration_fraction=FLAGS.exploration_fraction,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=True,
callback=deepq_callback)
env,
q_func=model,
num_actions=16,
lr=FLAGS.lr,
max_timesteps=FLAGS.timesteps,
buffer_size=10000,
exploration_fraction=FLAGS.exploration_fraction,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=True,
callback=deepq_callback)
acts[0].save("mineral_shards_x.pkl")
acts[1].save("mineral_shards_y.pkl")

elif (FLAGS.algorithm == "deepq-4way"):
elif FLAGS.algorithm == "deepq-4way":

AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32))
AGENT_INTERFACE_FORMAT = sc2_env.AgentInterfaceFormat(
feature_dimensions=sc2_env.Dimensions(screen=32, minimap=32))
with sc2_env.SC2Env(
map_name="CollectMineralShards",
step_mul=step_mul,
agent_interface_format=AGENT_INTERFACE_FORMAT,
visualize=True) as env:

model = cnn_to_mlp(
convs=[(16, 8, 4), (32, 4, 2)], hiddens=[256], dueling=True)
convs=[(16, 8, 4), (32, 4, 2)], hiddens=[256], dueling=True)

act = deepq_mineral_4way.learn(
env,
q_func=model,
num_actions=4,
lr=FLAGS.lr,
max_timesteps=FLAGS.timesteps,
buffer_size=10000,
exploration_fraction=FLAGS.exploration_fraction,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=True,
callback=deepq_4way_callback)
env,
q_func=model,
num_actions=4,
lr=FLAGS.lr,
max_timesteps=FLAGS.timesteps,
buffer_size=10000,
exploration_fraction=FLAGS.exploration_fraction,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=True,
callback=deepq_4way_callback)

act.save("mineral_shards.pkl")

elif (FLAGS.algorithm == "a2c"):
elif FLAGS.algorithm == "a2c":

num_timesteps = int(40e6)

num_timesteps //= 4

seed = 0

env = SubprocVecEnv(FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts, FLAGS.map)
env = SubprocVecEnv(FLAGS.num_agents + FLAGS.num_scripts, FLAGS.num_scripts,
FLAGS.map)

policy_fn = CnnPolicy
a2c.learn(
policy_fn,
env,
seed,
total_timesteps=num_timesteps,
nprocs=FLAGS.num_agents + FLAGS.num_scripts,
nscripts=FLAGS.num_scripts,
ent_coef=0.5,
nsteps=FLAGS.nsteps,
max_grad_norm=0.01,
callback=a2c_callback)


from pysc2.env import environment
import numpy as np
policy_fn,
env,
seed,
total_timesteps=num_timesteps,
nprocs=FLAGS.num_agents + FLAGS.num_scripts,
nscripts=FLAGS.num_scripts,
ent_coef=0.5,
nsteps=FLAGS.nsteps,
max_grad_norm=0.01,
callback=a2c_callback)


def deepq_callback(locals, globals):
#pprint.pprint(locals)

global max_mean_reward, last_filename
if ('done' in locals and locals['done'] == True):
if ('mean_100ep_reward' in locals and locals['num_episodes'] >= 10
and locals['mean_100ep_reward'] > max_mean_reward):
if 'done' in locals and locals['done'] == True:
if 'mean_100ep_reward' in locals and locals['num_episodes'] >= 10\
and locals['mean_100ep_reward'] > max_mean_reward:
print("mean_100ep_reward : %s max_mean_reward : %s" %
(locals['mean_100ep_reward'], max_mean_reward))

if (not os.path.exists(os.path.join(PROJ_DIR, 'models/deepq/'))):
if not os.path.exists(os.path.join(PROJ_DIR, 'models/deepq/')):
try:
os.mkdir(os.path.join(PROJ_DIR, 'models/'))
except Exception as e:
@@ -209,7 +198,7 @@ def deepq_callback(locals, globals):
except Exception as e:
print(str(e))

if (last_filename != ""):
if last_filename != "":
os.remove(last_filename)
print("delete last model file : %s" % last_filename)

@@ -218,28 +207,28 @@ def deepq_callback(locals, globals):
act_y = deepq_mineral_shards.ActWrapper(locals['act_y'])

filename = os.path.join(
PROJ_DIR,
'models/deepq/mineral_x_%s.pkl' % locals['mean_100ep_reward'])
PROJ_DIR,
'models/deepq/mineral_x_%s.pkl' % locals['mean_100ep_reward'])
act_x.save(filename)
filename = os.path.join(
PROJ_DIR,
'models/deepq/mineral_y_%s.pkl' % locals['mean_100ep_reward'])
PROJ_DIR,
'models/deepq/mineral_y_%s.pkl' % locals['mean_100ep_reward'])
act_y.save(filename)
print("save best mean_100ep_reward model to %s" % filename)
last_filename = filename


def deepq_4way_callback(locals, globals):
#pprint.pprint(locals)

global max_mean_reward, last_filename
if ('done' in locals and locals['done'] == True):
if ('mean_100ep_reward' in locals and locals['num_episodes'] >= 10
and locals['mean_100ep_reward'] > max_mean_reward):
if 'done' in locals and locals['done'] == True:
if 'mean_100ep_reward' in locals and locals['num_episodes'] >= 10\
and locals['mean_100ep_reward'] > max_mean_reward:
print("mean_100ep_reward : %s max_mean_reward : %s" %
(locals['mean_100ep_reward'], max_mean_reward))

if (not os.path.exists(
os.path.join(PROJ_DIR, 'models/deepq-4way/'))):
if not os.path.exists(
os.path.join(PROJ_DIR, 'models/deepq-4way/')):
try:
os.mkdir(os.path.join(PROJ_DIR, 'models/'))
except Exception as e:
@@ -249,36 +238,32 @@ def deepq_4way_callback(locals, globals):
except Exception as e:
print(str(e))

if (last_filename != ""):
if last_filename != "":
os.remove(last_filename)
print("delete last model file : %s" % last_filename)

max_mean_reward = locals['mean_100ep_reward']
act = deepq_mineral_4way.ActWrapper(locals['act'])
#act_y = deepq_mineral_shards.ActWrapper(locals['act_y'])
# act_y = deepq_mineral_shards.ActWrapper(locals['act_y'])

filename = os.path.join(PROJ_DIR,
'models/deepq-4way/mineral_%s.pkl' %
locals['mean_100ep_reward'])
act.save(filename)
# filename = os.path.join(
# PROJ_DIR,
# 'models/deepq/mineral_y_%s.pkl' % locals['mean_100ep_reward'])
# act_y.save(filename)

print("save best mean_100ep_reward model to %s" % filename)
last_filename = filename


def a2c_callback(locals, globals):
global max_mean_reward, last_filename
#pprint.pprint(locals)

if ('mean_100ep_reward' in locals and locals['num_episodes'] >= 10
and locals['mean_100ep_reward'] > max_mean_reward):
if 'mean_100ep_reward' in locals and locals['num_episodes'] >= 10\
and locals['mean_100ep_reward'] > max_mean_reward:
print("mean_100ep_reward : %s max_mean_reward : %s" %
(locals['mean_100ep_reward'], max_mean_reward))

if (not os.path.exists(os.path.join(PROJ_DIR, 'models/a2c/'))):
if not os.path.exists(os.path.join(PROJ_DIR, 'models/a2c/')):
try:
os.mkdir(os.path.join(PROJ_DIR, 'models/'))
except Exception as e:
@@ -288,16 +273,16 @@ def a2c_callback(locals, globals):
except Exception as e:
print(str(e))

if (last_filename != ""):
if last_filename != "":
os.remove(last_filename)
print("delete last model file : %s" % last_filename)

max_mean_reward = locals['mean_100ep_reward']
model = locals['model']

filename = os.path.join(
PROJ_DIR,
'models/a2c/mineral_%s.pkl' % locals['mean_100ep_reward'])
PROJ_DIR,
'models/a2c/mineral_%s.pkl' % locals['mean_100ep_reward'])
model.save(filename)
print("save best mean_100ep_reward model to %s" % filename)
last_filename = filename
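
The "sc2 now requires specifying players" comment above refers to newer pysc2 releases where SC2Env takes an explicit players list. A minimal sketch of creating one of the environments above with that argument spelled out (assumes a pysc2 version with the players parameter; map and screen dimensions match the deepq branch, other values are illustrative):

import sys
from absl import flags
from pysc2.env import sc2_env

flags.FLAGS(sys.argv)  # pysc2 expects absl flags to be parsed

# Single Terran agent -- the "temp solution" mentioned in the comment above.
agent_format = sc2_env.AgentInterfaceFormat(
    feature_dimensions=sc2_env.Dimensions(screen=16, minimap=16))

with sc2_env.SC2Env(
        map_name="CollectMineralShards",
        players=[sc2_env.Agent(sc2_env.Race.terran)],
        step_mul=8,
        agent_interface_format=agent_format) as env:
    timesteps = env.reset()
    print(timesteps[0].observation["feature_screen"].shape)
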
