run_qhd_maximal_attack.py
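
"""
Experiment script: trains a QHD defender agent in the idsgame-maximal_attack-v3
environment (where the attacker plays a fixed maximal-attack strategy), looping
over candidate learning rates and saving results under ./results/data/maximal_attack/.
"""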
import os
import sys
import time
import gym
from gym_idsgame.agents.training_agents.q_learning.abstract_qhd_agent_config import AbstractQHDAgentConfig
from gym_idsgame.agents.training_agents.q_learning.qhd.qhd import QHDAgent
from gym_idsgame.agents.training_agents.q_learning.qhd.qhd_config import QHDConfig
from experiments.util import util

# def get_script_path():
#     """
#     :return: the script path
#     """
#     return os.path.dirname(os.path.realpath(sys.argv[0]))


# def default_output_dir() -> str:
#     """
#     :return: the default output dir
#     """
#     script_dir = get_script_path()
#     return script_dir


# Program entrypoint
if __name__ == '__main__':
    random_seed = 0
    util.create_artefact_dirs('./', random_seed)
    for lr in [0.00001]:  # , 0.0001, 0.001, 0.01]:
        qhd_config = QHDConfig(input_dim=88,
                               defender_output_dim=88,  # attacker would need 80: 10 attacks (+1 for defender), 8 nodes
                               attacker_output_dim=80,
                               replay_memory_size=10000,
                               batch_size=32,
                               target_network_update_freq=1000,  # TODO: hyperparameter for fine-tuning
                               gpu=False,
                               tensorboard=False,
                               tensorboard_dir="./results/tensorboard/",
                               lr_exp_decay=False,
                               lr_decay_rate=0.9999)
        qhd_agent_config = AbstractQHDAgentConfig(gamma=0.999,
                                                  lr=lr,  # TODO: hyperparameter for fine-tuning
                                                  num_episodes=20001,
                                                  epsilon=1,
                                                  min_epsilon=0.01,
                                                  epsilon_decay=0.95,  # TODO: hyperparameter for fine-tuning
                                                  eval_sleep=0.9,
                                                  eval_frequency=1000,
                                                  eval_episodes=100,
                                                  train_log_frequency=100,
                                                  eval_log_frequency=1,
                                                  render=False,
                                                  eval_render=False,
                                                  video=False,
                                                  video_fps=5,
                                                  video_frequency=101,
                                                  video_dir="./results/videos/",
                                                  gifs=False,
                                                  gif_dir="./results/gifs/",
                                                  save_dir="./results/data/maximal_attack/",
                                                  attacker=False,
                                                  defender=True,
                                                  qhd_config=qhd_config,
                                                  checkpoint_freq=300000)
        # Set up environment
        env_name = "idsgame-maximal_attack-v3"  # alternative: "idsgame-maximal_defense-v3"
        env = gym.make(env_name, save_dir="./results/data/maximal_attack/")
        # Set up agent
        agent = QHDAgent(env, qhd_agent_config, "FINAL_")
        start = time.time()
        agent.train()
        print("*********Time to train*********: ", time.time() - start)
        # TODO: implement downstream handling of these results
        train_result = agent.train_result
        eval_result = agent.eval_result
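
        # Minimal sketch (an addition, not part of the original script): persist
        # the result objects for later analysis. Assumes train_result and
        # eval_result are picklable; the file names below are hypothetical.
        import pickle
        with open("./results/data/maximal_attack/train_result_lr_%s.pkl" % lr, "wb") as f:
            pickle.dump(train_result, f)
        with open("./results/data/maximal_attack/eval_result_lr_%s.pkl" % lr, "wb") as f:
            pickle.dump(eval_result, f)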