Commit

enable python3 and pytorch0.3
Matthew Olson committed Mar 15, 2018
1 parent fd2b5bd commit ea10868
Showing 1 changed file with 50 additions and 39 deletions.
89 changes: 50 additions & 39 deletions baby-a3c.py
@@ -1,7 +1,7 @@
# Baby Advantage Actor-Critic | Sam Greydanus | October 2017 | MIT License

from __future__ import print_function
import torch, os, gym, time, glob, argparse
import torch, os, gym, time, glob, argparse, sys
import numpy as np
from scipy.signal import lfilter
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]
@@ -12,24 +12,21 @@
import torch.multiprocessing as mp
os.environ['OMP_NUM_THREADS'] = '1'

parser = argparse.ArgumentParser(description=None)
parser.add_argument('--env', default='Breakout-v0', type=str, help='gym environment')
parser.add_argument('--processes', default=20, type=int, help='number of processes to train with')
parser.add_argument('--render', default=False, type=bool, help='renders the atari environment')
parser.add_argument('--test', default=False, type=bool, help='test mode sets lr=0, chooses most likely actions')
parser.add_argument('--lstm_steps', default=20, type=int, help='steps to train LSTM over')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--seed', default=1, type=int, help='seed random # generators (for reproducibility)')
parser.add_argument('--gamma', default=0.99, type=float, help='discount for gamma-discounted rewards')
parser.add_argument('--tau', default=1.0, type=float, help='discount for generalized advantage estimation')
parser.add_argument('--horizon', default=0.99, type=float, help='horizon for running averages')
args = parser.parse_args()

args.save_dir = '{}/'.format(args.env.lower()) # keep the directory structure simple
if args.render: args.processes = 1 ; args.test = True # render mode -> test mode w one process
if args.test: args.lr = 0 # don't train in render mode
args.num_actions = gym.make(args.env).action_space.n # get the action space of this game
os.makedirs(args.save_dir) if not os.path.exists(args.save_dir) else None # make dir to save models etc.
def get_args():
parser = argparse.ArgumentParser(description=None)
parser.add_argument('--env', default='Breakout-v0', type=str, help='gym environment')
parser.add_argument('--processes', default=20, type=int, help='number of processes to train with')
parser.add_argument('--render', default=False, type=bool, help='renders the atari environment')
parser.add_argument('--test', default=False, type=bool, help='test mode sets lr=0, chooses most likely actions')
parser.add_argument('--lstm_steps', default=20, type=int, help='steps to train LSTM over')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--seed', default=1, type=int, help='seed random # generators (for reproducibility)')
parser.add_argument('--gamma', default=0.99, type=float, help='discount for gamma-discounted rewards')
parser.add_argument('--tau', default=1.0, type=float, help='discount for generalized advantage estimation')
parser.add_argument('--horizon', default=0.99, type=float, help='horizon for running averages')
return parser.parse_args()



discount = lambda x, gamma: lfilter([1],[1,-gamma],x[::-1])[::-1] # discounted rewards one liner
prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.
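Note: moving the argparse setup into get_args() goes hand in hand with the switch to the "spawn" start method in the __main__ block further down. Spawned workers re-import this module, so any code left at module scope (argument parsing, gym.make, building the shared model) would run again in every child. A minimal, self-contained sketch of the pattern, with argparse.Namespace standing in for the real get_args():

import multiprocessing as mp
from argparse import Namespace

def worker(args):
    # the parsed args travel here by pickling, not by re-running parser code
    print("worker sees env =", args.env)

if __name__ == "__main__":
    mp.set_start_method("spawn")          # each child re-imports this file from scratch
    args = Namespace(env="Breakout-v0")   # stand-in for get_args(); parsed once, in the parent
    p = mp.Process(target=worker, args=(args,))
    p.start(); p.join()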
@@ -81,15 +78,7 @@ def step(self, closure=None):
self.state[p]['step'] = self.state[p]['shared_steps'][0] - 1 # there's a "step += 1" later
super(SharedAdam, self).step(closure)

torch.manual_seed(args.seed)
shared_model = NNPolicy(channels=1, num_actions=args.num_actions).share_memory()
shared_optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)

info = {k : torch.DoubleTensor([0]).share_memory_() for k in ['run_epr', 'run_loss', 'episodes', 'frames']}
info['frames'] += shared_model.try_load(args.save_dir)*1e6
if int(info['frames'][0]) == 0: printlog(args,'', end='', mode='w') # clear log file

def train(rank, args, info):
def train(shared_model, shared_optimizer, rank, args, info):
env = gym.make(args.env) # make a local (unshared) environment
env.seed(args.seed + rank) ; torch.manual_seed(args.seed + rank) # seed everything
model = NNPolicy(channels=1, num_actions=args.num_actions) # init a local (unshared) model
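Note: under the old fork-based setup, shared_model, shared_optimizer and info were globals that workers inherited for free; with spawn they must be handed to mp.Process explicitly, hence the extra train() arguments. torch.multiprocessing keeps the underlying storage shared when these objects cross the process boundary. A small stand-alone illustration of that mechanism (not from the repo), using a bare shared tensor:

import torch
import torch.multiprocessing as mp

def work(shared_t):
    shared_t += 1   # in-place update lands in shared memory, so the parent sees it

if __name__ == "__main__":
    mp.set_start_method("spawn")
    t = torch.zeros(1).share_memory_()   # same idea as shared_model / info in the diff
    procs = [mp.Process(target=work, args=(t,)) for _ in range(4)]
    for p in procs: p.start()
    for p in procs: p.join()
    print(t)   # expect 4.0 here, as long as the increments don't collide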
@@ -108,7 +97,7 @@ def train(rank, args, info):
for step in range(args.lstm_steps):
episode_length += 1
value, logit, (hx, cx) = model((Variable(state.view(1,1,80,80)), (hx, cx)))
logp = F.log_softmax(logit)
logp = F.log_softmax(logit, dim=1)

action = logp.max(1)[1].data if args.test else torch.exp(logp).multinomial().data[0]
state, reward, done, _ = env.step(action.numpy()[0])
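Note: the new dim=1 argument pins the softmax to the action axis of the (batch, num_actions) logits; newer PyTorch versions warn when the dimension is left implicit. A quick stand-alone check (logit values are arbitrary):

import torch
import torch.nn.functional as F

logit = torch.zeros(1, 4)            # shape (batch=1, num_actions=4)
logp = F.log_softmax(logit, dim=1)   # normalize across the action dimension
print(logp.exp().sum(dim=1))         # tensor([1.]) -- each row is a proper distribution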
@@ -143,7 +132,7 @@ def train(rank, args, info):
next_value = Variable(torch.zeros(1,1)) if done else model((Variable(state.unsqueeze(0)), (hx, cx)))[0]
values.append(Variable(next_value.data))

loss = cost_func(torch.cat(values), torch.cat(logps), torch.cat(actions), np.asarray(rewards))
loss = cost_func(args, torch.cat(values), torch.cat(logps), torch.cat(actions), np.asarray(rewards))
eploss += loss.data[0]
shared_optimizer.zero_grad() ; loss.backward()
torch.nn.utils.clip_grad_norm(model.parameters(), 40)
@@ -152,27 +141,49 @@
if shared_param.grad is None: shared_param._grad = param.grad # sync gradients with shared model
shared_optimizer.step()

def cost_func(values, logps, actions, rewards):
def cost_func(args, values, logps, actions, rewards):
np_values = values.view(-1).data.numpy()

# generalized advantage estimation (a policy gradient method)
delta_t = np.asarray(rewards) + args.gamma * np_values[1:] - np_values[:-1]
gae = discount(delta_t, args.gamma * args.tau)
logpys = logps.gather(1, Variable(actions).view(-1,1))
policy_loss = -(logpys.view(-1) * Variable(torch.Tensor(gae))).sum()
policy_loss = -(logpys.view(-1) * Variable(torch.Tensor(gae.copy()))).sum()

# l2 loss over value estimator
rewards[-1] += args.gamma * np_values[-1]
discounted_r = discount(np.asarray(rewards), args.gamma)
discounted_r = Variable(torch.Tensor(discounted_r))
discounted_r = Variable(torch.Tensor(discounted_r.copy()))
value_loss = .5 * (discounted_r - values[:-1,0]).pow(2).sum()

entropy_loss = -(-logps * torch.exp(logps)).sum() # encourage lower entropy
return policy_loss + 0.5 * value_loss + 0.01 * entropy_loss
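Note: cost_func relies on the discount one-liner defined near the top of the file; the lfilter call is just a vectorized form of the backward GAE recursion A_t = delta_t + gamma*tau*A_{t+1}. A short check of that equivalence on a made-up delta sequence:

import numpy as np
from scipy.signal import lfilter

discount = lambda x, gamma: lfilter([1], [1, -gamma], x[::-1])[::-1]   # same one-liner as above

delta = np.array([1.0, 0.5, -0.2])
gamma, tau = 0.99, 1.0
ref, running = np.zeros_like(delta), 0.0
for t in reversed(range(len(delta))):   # A_t = delta_t + gamma*tau*A_{t+1}
    running = delta[t] + gamma * tau * running
    ref[t] = running
print(np.allclose(discount(delta, gamma * tau), ref))   # True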

processes = []
for rank in range(args.processes):
p = mp.Process(target=train, args=(rank, args, info))
p.start() ; processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
if sys.version_info[0] > 2:
mp.set_start_method("spawn") #this must not be in global scope
elif sys.platform == 'linux' or sys.platform == 'linux2':
raise "Must be using Python 3 with linux!" #or else you get a deadlock in conv2d

args = get_args()

args.save_dir = '{}/'.format(args.env.lower()) # keep the directory structure simple
if args.render: args.processes = 1 ; args.test = True # render mode -> test mode w one process
if args.test: args.lr = 0 # don't train in render mode
args.num_actions = gym.make(args.env).action_space.n # get the action space of this game
os.makedirs(args.save_dir) if not os.path.exists(args.save_dir) else None # make dir to save models etc.

torch.manual_seed(args.seed)
shared_model = NNPolicy(channels=1, num_actions=args.num_actions).share_memory()
shared_optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)

info = {k : torch.DoubleTensor([0]).share_memory_() for k in ['run_epr', 'run_loss', 'episodes', 'frames']}
info['frames'] += shared_model.try_load(args.save_dir)*1e6
if int(info['frames'][0]) == 0: printlog(args,'', end='', mode='w') # clear log file

processes = []
for rank in range(args.processes):
p = mp.Process(target=train, args=(shared_model, shared_optimizer, rank, args, info))
p.start() ; processes.append(p)
for p in processes:
p.join()
