Skip to content

Commit

Permalink
tested openai agents
Browse files: browse the repository at this point in the history
  • Loading branch information
Cheng-Xue committed Nov 28, 2021
1 parent f9130c4 commit 468fd57
Showing 13 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -15,4 +15,4 @@ tasks/generated_tasks/

sciencebirdsgames/Linux/

-**/__pycache__/
+*.pyc
2 changes: 1 addition & 1 deletion sciencebirdsagents/HeuristicAgents/CollectionAgent.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import numpy as np

from SBAgent import SBAgent
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
from Utils.LevelSelection import LevelSelectionSchema
from StateReader.SymbolicStateDevReader import SymbolicStateDevReader
from Utils.point2D import Point2D
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from typing import List

from HeuristicAgents.CollectionAgent import CollectionAgent
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
import os
import csv
import cv2
2 changes: 1 addition & 1 deletion sciencebirdsagents/LearningAgents/RLDiscreteAgent.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from LearningAgents.LearningAgent import LearningAgent
from Utils.LevelSelection import LevelSelectionSchema
from LearningAgents.Memory import ReplayMemory
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
from torch.utils.tensorboard import SummaryWriter
from LearningAgents.RLNetwork.MultiHeadRelationalModule import MultiHeadRelationalModuleImage
from einops import rearrange, reduce
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from LearningAgents.LearningAgent import LearningAgent
from Utils.LevelSelection import LevelSelectionSchema
from LearningAgents.Memory import ReplayMemory
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
from torch.utils.tensorboard import SummaryWriter
from einops import rearrange, reduce

2 changes: 1 addition & 1 deletion sciencebirdsagents/LearningAgents/SACAgent.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@

from LearningAgents.LearningAgent import LearningAgent
from LearningAgents.Memory import ReplayMemory
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
from Utils.LevelSelection import LevelSelectionSchema


6 changes: 3 additions & 3 deletions sciencebirdsagents/OpenAI_StableBaseline_Train.py
Original file line number Diff line number Diff line change
@@ -13,8 +13,8 @@
from torch.utils.tensorboard import SummaryWriter

from LearningAgents.RLNetwork.OpenAICustomCNN import OpenAICustomCNN
-from SBEnviornment.SBEnvironmentWrapperOpenAI import SBEnvironmentWrapperOpenAI
-from SBEnviornment.Server import Server
+from SBEnvironment.SBEnvironmentWrapperOpenAI import SBEnvironmentWrapperOpenAI
+from SBEnvironment.Server import Server
from Utils.Config import config
from Utils.Parameters import Parameters
from Utils.utils import make_env, sample_levels_with_at_least_num_agents
@@ -94,7 +94,7 @@ def _on_step(self) -> bool:
parser.add_argument('--mode',
type=str,
default='within_template') # propose three modes, 'train1testrest', 'trainhalftesthalf', 'trainresttestone'
-parser.add_argument('--level_path', type=str, default='fourth generation')
+parser.add_argument('--level_path', type=str, default='fifth_generation')
parser.add_argument('--game_version', type=str, default='Linux')

args = parser.parse_args()
Original file line number Diff line number Diff line change
@@ -327,7 +327,7 @@ def __degToShot(self, deg):


if __name__ == '__main__':
-from SBEnviornment.Server import Server
+from SBEnvironment.Server import Server
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from LearningAgents.RLNetwork.OpenAICustomCNN import OpenAICustomCNN
2 changes: 1 addition & 1 deletion sciencebirdsagents/TestAgentOfflineWithinCapability.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from LearningAgents.LearningAgentThread import MultiThreadTrajCollection
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
from Utils.Config import config
from Utils.LevelSelection import LevelSelectionSchema
from Utils.Parameters import Parameters
2 changes: 1 addition & 1 deletion sciencebirdsagents/TestAgentOfflineWithinTemplate.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from Utils.LevelSelection import LevelSelectionSchema

from LearningAgents.LearningAgentThread import MultiThreadTrajCollection
-from SBEnviornment.SBEnvironmentWrapper import SBEnvironmentWrapper
+from SBEnvironment.SBEnvironmentWrapper import SBEnvironmentWrapper
from Utils.Config import config
from Utils.Parameters import Parameters

2 changes: 1 addition & 1 deletion sciencebirdsagents/TrainAndTestOpenAIStableBaselines.sh
Original file line number Diff line number Diff line change
@@ -8,4 +8,4 @@ fi
for val in $templates; do
echo running $val
python OpenAI_StableBaseline_Train.py --template $val --mode $mode --game_version Linux --level_path "fifth_generation"
-done
+done
2 changes: 1 addition & 1 deletion sciencebirdsagents/TrainLearningAgent.py
Original file line number Diff line number Diff line change
@@ -113,7 +113,7 @@ def sample_levels(training_level_set, num_agents, agent_idx, **kwargs):
parser.add_argument('--template', metavar='N', type=str)
parser.add_argument('--mode',
type=str) # propose three modes, 'train1testrest', 'trainhalftesthalf', 'trainresttestone', 'benchmark'
-parser.add_argument('--level_path', type=str, default='fourth generation')
+parser.add_argument('--level_path', type=str, default='fifth_generation')
parser.add_argument('--game_version', type=str, default='Linux')
parser.add_argument('--if_save_local', type=str2bool, default=True)
parser.add_argument('--resume', type=str2bool, default=False)
2 changes: 1 addition & 1 deletion sciencebirdsagents/Utils/utils.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import math
import random

-from SBEnviornment.SBEnvironmentWrapperOpenAI import SBEnvironmentWrapperOpenAI
+from SBEnvironment.SBEnvironmentWrapperOpenAI import SBEnvironmentWrapperOpenAI


def str2bool(v):

0 comments on commit 468fd57

Please sign in to comment.