Commit

fixed bugs
ahmedhajyahmed committed Mar 19, 2023
2 parents a32aa5c + 6b475d1 commit 8d12885
Showing 6 changed files with 47 additions and 29 deletions.
13 changes: 7 additions & 6 deletions debugger/checkers/rl_checkers/reward_check.py
@@ -11,6 +11,7 @@ class RewardsCheck(DebuggerInterface):
This class performs checks on the accumulated reward.
For more details on the specific checks performed, refer to the `run()` function.
"""

def __init__(self):
"""
Initializes the following parameters:
@@ -99,7 +100,7 @@ def run(self, reward, max_total_steps, max_reward) -> None:
return
n_rewards = len(self._episodes_rewards)
if self.check_period() and (
-                n_rewards >= self.config.window_size * self.config.start
+            n_rewards >= self.config.window_size * self.config.start
):
stds = []
stds_nor = []
@@ -116,13 +117,13 @@ def run(self, reward, max_total_steps, max_reward) -> None:
stds += [reward_std]
stds_nor += [reward_std_nor]

-            self.wandb_metrics = {"reward_stds": stds}

stds = torch.tensor(stds).float()
stds_nor = torch.tensor(stds_nor).float()

+            self.wandb_metrics = {"reward_stds": stds}

if (self.step_num < max_total_steps * self.config.exploration_perc) and (
-                    not self.config.fluctuation.disabled
+                not self.config.fluctuation.disabled
):
cof = get_data_slope(stds)
fluctuations = estimate_fluctuation_rmse(cof, stds)
@@ -134,7 +135,7 @@ def run(self, reward, max_total_steps, max_reward) -> None:
)

if self.step_num > max_total_steps * (self.config.exploitation_perc) and (
-                    not self.config.monotonicity.disabled
+                not self.config.monotonicity.disabled
):
cof = get_data_slope(stds_nor)
self.check_reward_monotonicity(cof, max_reward)
@@ -159,7 +160,7 @@ def check_reward_monotonicity(self, cof, max_reward):
else:
stagnated_reward = torch.mean(self._episodes_rewards)
if stagnated_reward < max_reward * (
-                    1 - self.config.monotonicity.reward_stagnation_tolerance
+                1 - self.config.monotonicity.reward_stagnation_tolerance
):
self.error_msg.append(
self.main_msgs["stagnated_reward"].format(
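
With the `self.wandb_metrics` assignment now placed after the `torch.tensor(...)` conversion, the logged reward standard deviations carry an `ndim` attribute, which the per-element wandb plotting added in `debugger_factory.py` below relies on. A minimal sketch of the difference, using made-up numbers rather than real reward data:

import torch

stds = [0.12, 0.08, 0.05]            # made-up reward stds, not real data
stds_t = torch.tensor(stds).float()  # the tensor form stored in wandb_metrics

print(hasattr(stds, "ndim"))    # False: a plain list has no ndim attribute
print(stds_t.ndim)              # 1: one entry per evaluated reward window
print(torch.tensor(0.05).ndim)  # 0: a scalar metric, plotted as a single point
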
36 changes: 22 additions & 14 deletions debugger/debugger_factory.py
@@ -12,7 +12,7 @@
class DebuggerFactory:
def __init__(self):
self.logger = settings.set_logger()
-        self.wandb_logger = settings.set_wandb_logger()
+        self.wandb_logger = None
self.debuggers = dict()
self.observed_params = {}
self.observed_params_update_nums = dict()
@@ -95,8 +95,8 @@ def is_final_step_of_ep(self):
Returns (bool): returns True if the step is the last one in an episode, and False otherwise.
"""
if self.observed_params["done"] or (
-                (self.step_num > 0)
-                and ((self.step_num % self.observed_params["max_steps_per_episode"]) == 0)
+            (self.step_num > 0)
+            and ((self.step_num % self.observed_params["max_steps_per_episode"]) == 0)
):
return True
return False
@@ -199,6 +199,9 @@ def set_config(self, config_path=None):
Args:
config_path (str): The path to the configuration dict
"""
+        if not (config_path is None):
+            self.wandb_logger = settings.set_wandb_logger(config_path)

if config_path is None:
config_path = load_default_config()

@@ -260,21 +263,26 @@ def turn_on(self):
self.training = True

def set_custom_wandb_logger(
-            self,
-            project,
-            name,
-            dir=None,
-            mode=None,
-            id=None,
-            resume=None,
-            start_method=None,
-            **kwargs,
+        self,
+        project,
+        name,
+        dir=None,
+        mode=None,
+        id=None,
+        resume=None,
+        start_method=None,
+        **kwargs,
):
self.wandb_logger.custom_wandb_logger(
project, name, dir, mode, id, resume, start_method, **kwargs
)

def plot_wandb(self, debugger):
-        if debugger.wandb_metrics:
-            self.wandb_logger.plot(debugger.wandb_metrics)
+        if debugger.wandb_metrics and (not (self.wandb_logger is None)):
+            for (key, values) in debugger.wandb_metrics.items():
+                if values.ndim == 0:
+                    self.wandb_logger.plot({key: values})
+                else:
+                    for value in values:
+                        self.wandb_logger.plot({key: value})
debugger.wandb_metrics = {}
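
The rewritten `plot_wandb` does nothing when no wandb logger was configured, sends a 0-dimensional (scalar) tensor as a single point, and unrolls higher-dimensional tensors one element at a time. A self-contained sketch of that dispatch, with a stub standing in for `WandbLogger` (the stub class and the `plot_metrics` function are illustrative names, not part of the repository):

import torch

class StubLogger:
    """Stand-in for WandbLogger: records what would be sent to wandb."""
    def __init__(self):
        self.calls = []

    def plot(self, metrics):
        self.calls.append(metrics)

def plot_metrics(wandb_logger, wandb_metrics):
    # Mirrors the new plot_wandb branching: no-op without a logger,
    # scalars pass through directly, tensors are logged element by element.
    if wandb_metrics and wandb_logger is not None:
        for key, values in wandb_metrics.items():
            if values.ndim == 0:
                wandb_logger.plot({key: values})
            else:
                for value in values:
                    wandb_logger.plot({key: value})

logger = StubLogger()
plot_metrics(logger, {"reward_stds": torch.tensor([0.4, 0.2, 0.1])})
plot_metrics(logger, {"entropy": torch.tensor(0.9)})
plot_metrics(None, {"ignored": torch.tensor(1.0)})  # no logger configured: nothing is sent
print(len(logger.calls))  # 4: three reward_stds points plus one entropy point
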
3 changes: 2 additions & 1 deletion debugger/utils/config/default_debugger.yml
@@ -31,7 +31,8 @@ debugger:
steps_done
]
check_type:
-  wand_logger:
+  wandb_logger:
+    disable: False
project: DRLDebugger
name: default_name

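
The section is renamed from `wand_logger` to `wandb_logger` and gains a `disable` flag, which `set_wandb_logger` in `settings.py` below wraps in `bool(...)`. A small sketch of how the flag parses, assuming only PyYAML; the flag should be a YAML boolean, since a quoted string such as "False" is still truthy:

import yaml

flag = yaml.safe_load("disable: False")["disable"]
print(flag, bool(flag))  # False False -> wandb logging stays enabled

quoted = yaml.safe_load('disable: "False"')["disable"]
print(quoted, bool(quoted))  # False True -> a quoted string would disable wandb by accident
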
19 changes: 11 additions & 8 deletions debugger/utils/settings.py
@@ -83,16 +83,19 @@ def set_logger():
return file_logger(log_fpath, "logger")


-def set_wandb_logger(config_path=None):
-    if config_path is None:
-        config_path = load_default_config()

+def set_wandb_logger(config_path):
with open(config_path) as f:
config = yaml.safe_load(f)
-    return WandbLogger(
-        project=config["debugger"]["wand_logger"]["project"],
-        name=config["debugger"]["wand_logger"]["name"],
-    )
+    if "wandb_logger" in config["debugger"].keys():
+        if bool(config["debugger"]["wandb_logger"]["disable"]):
+            return None
+        return WandbLogger(
+            project=config["debugger"]["wandb_logger"]["project"],
+            name=config["debugger"]["wandb_logger"]["name"],
+        )
+    return None




def react(logger, messages, fail_on=False):
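
`set_wandb_logger` now requires an explicit config path and returns `None` whenever the `wandb_logger` section is absent or disabled, which is why `DebuggerFactory` above only calls it once a config path is known and guards `plot_wandb` against a `None` logger. A small sketch of the same decision logic against an in-memory config (`resolve_wandb_section` is a made-up helper name used for illustration):

import yaml

def resolve_wandb_section(config_text):
    # Mirrors the checks in the new set_wandb_logger: return the
    # (project, name) pair when wandb logging is enabled, else None.
    config = yaml.safe_load(config_text)
    section = config["debugger"].get("wandb_logger")
    if section is None or bool(section["disable"]):
        return None
    return section["project"], section["name"]

enabled = """
debugger:
  wandb_logger:
    disable: False
    project: DRLDebugger
    name: default_name
"""
print(resolve_wandb_section(enabled))         # ('DRLDebugger', 'default_name')
print(resolve_wandb_section("debugger: {}"))  # None: no wandb_logger section
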
5 changes: 5 additions & 0 deletions examples/dqn/debugger.yml
@@ -30,3 +30,8 @@ debugger:
# period: 100
# - name: ValueFunction
# period: 10

+  wandb_logger:
+    disable: False
+    project: DRLDebugger
+    name: default_name
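
With the block above, the DQN example opts into wandb logging through its own config instead of inheriting it from the package defaults. A hedged usage sketch; the import path and config path are assumptions based on the file layout in this commit, not taken from the example's code:

from debugger.debugger_factory import DebuggerFactory

debugger = DebuggerFactory()
# set_config now also creates the wandb logger when a config path is given,
# unless the config's wandb_logger section is missing or has disable: True.
debugger.set_config(config_path="examples/dqn/debugger.yml")
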
Empty file.
