
Commit

added support to continue training from checkpoint
FelipeMartins96 committed Aug 21, 2020
1 parent cb0d455 commit e248855
Showing 2 changed files with 58 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ pb/*
 __pycache__
 .idea/
 *.0
+agents/agentPenaltyEnv/rlAdventure2DDPG/runs/
78 changes: 57 additions & 21 deletions agents/agentPenaltyEnv/rlAdventure2DDPG/train.py
@@ -1,6 +1,8 @@
 import gym
 import gym_ssl
 import numpy as np
+import os
+import sys
 
 from torch.utils.tensorboard import SummaryWriter
 import torch.optim as optim
@@ -15,14 +17,15 @@
 
 class AgentDDPG:
 
-    def __init__(self,
+    def __init__(self, name='DDPG',
                  maxEpisodes=10000, maxSteps=200, batchSize=256, replayBufferSize=200000, valueLR=1e-3, policyLR=1e-4,
-                 hiddenDim=256):
+                 hiddenDim=256, nEpisodesPerCheckpoint=10):
         # Training Parameters
         self.batchSize = batchSize
         self.maxSteps = maxSteps
         self.maxEpisodes = maxEpisodes
-        self.episode = 0
+        self.nEpisodesPerCheckpoint = nEpisodesPerCheckpoint
+        self.nEpisodes = 0
 
         # Check if cuda gpu is available, and select it
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -54,8 +57,8 @@ def __init__(self,
         self.replayBuffer = ReplayBuffer(replayBufferSize)
 
         # Tensorboard Init
-        self.path = './runs/'
-        # self._load
+        self.path = './runs/' + name
+        self._load()
         self.writer = SummaryWriter(log_dir=self.path)
 
     def _update(self, batch_size,
Expand Down Expand Up @@ -103,39 +106,72 @@ def _update(self, batch_size,

# Training Loop
def train(self):
while self.episode < self.maxEpisodes:
# TODO pausar treino quando apertar algum botão e salvar estado
while self.nEpisodes < self.maxEpisodes:
state = self.env.reset()
self.ouNoise.reset()
episode_reward = 0
steps_episode = 0
episodeReward = 0
nStepsInEpisode = 0

for step in range(self.maxSteps):
while nStepsInEpisode < self.maxSteps:
action = self.policyNet.get_action(state)
action = self.ouNoise.get_action(action, step)
action = self.ouNoise.get_action(action, nStepsInEpisode)
next_state, reward, done, _ = self.env.step(action)

self.replayBuffer.push(state, action, reward, next_state, done)
if len(self.replayBuffer) > self.batchSize:
self._update(self.batchSize)

state = next_state
episode_reward += reward
episodeReward += reward
nStepsInEpisode += 1

if done:
steps_episode = step
break

self.episode += 1
self.nEpisodes += 1

# rewards.append(episode_reward)
# TODO trocar por lista circular
# rewards.append(episodeReward)

self.writer.add_scalar('Train/Reward', episode_reward, self.episode)
self.writer.add_scalar('Train/Steps', steps_episode, self.episode)
self.writer.add_scalar('Train/Reward', episodeReward, self.nEpisodes)
self.writer.add_scalar('Train/Steps', nStepsInEpisode, self.nEpisodes)

# if (episode % 1000) == 0:
# torch.save({
# 'target_value_net_dict': target_value_net.state_dict(),
# 'target_policy_net_dict': target_policy_net.state_dict(),
# }, './saved_networks')
# TODO arquivo separado a cada x passos
if (self.nEpisodes % self.nEpisodesPerCheckpoint) == 0:
torch.save({
'valueNetDict': self.valueNet.state_dict(),
'policyNetDict': self.targetPolicyNet.state_dict(),
'targetValueNetDict': self.targetValueNet.state_dict(),
'targetPolicyNetDict': self.targetPolicyNet.state_dict(),
'nEpisodes': self.nEpisodes
}, self.path + '/checkpoint')

self.writer.flush()

def _load(self):
# Check if folder exists
if os.path.exists(self.path):
try:
checkpoint = torch.load(self.path + '/checkpoint')
# Load networks parameters checkpoint
self.valueNet.load_state_dict(checkpoint['valueNetDict'])
self.policyNet.load_state_dict(checkpoint['policyNetDict'])
self.targetValueNet.load_state_dict(checkpoint['targetValueNetDict'])
self.targetPolicyNet.load_state_dict(checkpoint['targetPolicyNetDict'])
# Load number of episodes on checkpoint
self.nEpisodes = checkpoint['nEpisodes']
print(self.nEpisodes)
except FileNotFoundError as e:
print(e)
except IOError as e:
print(e)


if __name__ == '__main__':

if len(sys.argv) == 2:
agent = AgentDDPG(name=sys.argv[1])
agent.train()
else:
print("correct usage: python train.py {name}")
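For reference, a minimal, self-contained sketch of the resume pattern this commit introduces: bundle the network state dicts and the episode counter into a single checkpoint file, and restore them whenever a run directory of the same name already exists. The toy nn.Linear model, the 'modelDict' key, and the './runs/demo' path are illustrative stand-ins, not code from this repository.

import os

import torch
import torch.nn as nn

path = './runs/demo'  # hypothetical run directory, analogous to './runs/' + name
model = nn.Linear(4, 2)  # toy stand-in for the value and policy networks
nEpisodes = 0

# Resume: if an earlier run left a checkpoint, restore weights and progress
if os.path.exists(path):
    try:
        checkpoint = torch.load(path + '/checkpoint')
        model.load_state_dict(checkpoint['modelDict'])
        nEpisodes = checkpoint['nEpisodes']
        print(f'resuming from episode {nEpisodes}')
    except (FileNotFoundError, IOError) as e:
        print(e)

# Checkpoint: one file bundles the parameters and the progress counter
nEpisodes += 10  # stands in for a block of training episodes
os.makedirs(path, exist_ok=True)
torch.save({'modelDict': model.state_dict(), 'nEpisodes': nEpisodes},
           path + '/checkpoint')

With train.py itself, rerunning python train.py {name} under the same name therefore continues from the saved episode count instead of starting over.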
