Skip to content

Commit

Permalink
Add support for PPO2 training
Browse files Browse the repository at this point in the history
  • Loading branch information
jakvah committed Jul 30, 2020
1 parent 6e65296 commit 2a23257
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ where the arguments are the same as the ones in the [running the program locally
The following algorithms to train agents are available

- **DQN**. For more information see [the stable-baselines documentation](https://stable-baselines.readthedocs.io/en/master/modules/dqn.html). To specify the DQN algorithm when running the program pass "DQN" or "dqn" as the ``<algorithm>`` argument.
- **PPO2**. For more information see [the stable-baselines documentation](https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html). To specify the PPO2 algorithm when running the program pass "PPO2" or "ppo2" as the ``<algorithm>`` argument.
- **PPO2**. For more information see [the stable-baselines documentation](https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html). To specify the PPO2 algorithm when running the program pass "PPO2" or "ppo2" as the ``<algorithm>`` argument.

**NB:** *Since the results with the PPO2 algorithm during development have been weak, the support for PPO2 is limited. Functionality for displaying trained PPO2 agents has not been developed.*

More algorithms coming soon!

Expand Down
8 changes: 4 additions & 4 deletions gym-drill/agent_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_trained_DQN_model(model_to_load,*,exploration_initial_eps=0.02,learning_
return model

def train_new_PPO2(total_timesteps,save_name):
model = PPO2(LnMlpPolicy, ENV, verbose=1, tensorboard_log=TENSORBOARD_FOLDER_PPO2)
model = PPO2(MlpPolicy, ENV, verbose=1, tensorboard_log=TENSORBOARD_FOLDER_PPO2)
model.learn(total_timesteps=total_timesteps, tb_log_name = "PPO2")
print("Done training with PPO2 algorithm.")
save_model(model,save_name)
Expand All @@ -88,10 +88,10 @@ def train_existing_PPO2(model_to_load,total_timesteps,save_name):
def get_trained_PPO2_model(model_to_load):
load_location = TRAINED_MODEL_FOLDER_DOCKER + model_to_load
try:
model = PPO2.load(load_location, ENV, exploration_initial_eps=exploration_initial_eps, learning_rate= learning_rate, tensorboard_log=TENSORBOARD_FOLDER_DQN)
except FileNotFoundError:
model = PPO2.load(load_location, ENV, tensorboard_log=TENSORBOARD_FOLDER_DQN)
except Exception:
load_location = TRAINED_MODEL_FOLDER_LOCAL + model_to_load
model = PPO2.load(load_location, ENV, exploration_initial_eps=exploration_initial_eps, learning_rate= learning_rate, tensorboard_log=TENSORBOARD_FOLDER_DQN)
model = PPO2.load(load_location, ENV, tensorboard_log=TENSORBOARD_FOLDER_DQN)

return model

Expand Down
2 changes: 1 addition & 1 deletion gym-drill/gym_drill/envs/drill_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DrillEnv(gym.Env):
'video.frames_per_second': 50
}

def __init__(self,startLocation,bitInitialization,*,activate_hazards=False,monte_carlo=True,activate_log=False,load=True):
def __init__(self,startLocation,bitInitialization,*,activate_hazards=True,monte_carlo=False,activate_log=False,load=False):
self.activate_log = activate_log
self.activate_hazards = activate_hazards
self.monte_carlo = monte_carlo
Expand Down
2 changes: 1 addition & 1 deletion gym-drill/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,5 @@ def verify_algorithm(algorithm):
train.display_agent(name)
# show model
elif algorithm == "PPO2":
model = train.get_trained_PPO2_model()
model = train.get_trained_PPO2_model(name)
train.display_agent(model)

0 comments on commit 2a23257

Please sign in to comment.