From 6b475d17e3701f528d1945919626af43e247a2ac Mon Sep 17 00:00:00 2001
From: rached1997
Date: Sun, 19 Mar 2023 18:04:46 -0400
Subject: [PATCH] Fix wandb logger initialization and metric plotting bugs

---
 debugger/checkers/rl_checkers/reward_check.py | 13 +++----
 debugger/debugger_factory.py                  | 36 +++++++++++--------
 debugger/utils/config/default_debugger.yml    |  3 +-
 debugger/utils/settings.py                    | 19 +++++-----
 examples/dqn/debugger.yml                     | 17 +++++----
 .../run-3ovoht77.wandb                        |  0
 6 files changed, 53 insertions(+), 35 deletions(-)
 create mode 100644 examples/dqn/wandb/run-20230318_114142-3ovoht77/run-3ovoht77.wandb

diff --git a/debugger/checkers/rl_checkers/reward_check.py b/debugger/checkers/rl_checkers/reward_check.py
index 07d0479..2e642a9 100644
--- a/debugger/checkers/rl_checkers/reward_check.py
+++ b/debugger/checkers/rl_checkers/reward_check.py
@@ -11,6 +11,7 @@ class RewardsCheck(DebuggerInterface):
     This class performs checks on the accumulated reward. For more details on
     the specific checks performed, refer to the `run()` function.
     """
+
     def __init__(self):
         """
         Initializes the following parameters:
@@ -99,7 +100,7 @@ def run(self, reward, max_total_steps, max_reward) -> None:
             return
         n_rewards = len(self._episodes_rewards)
         if self.check_period() and (
-                n_rewards >= self.config.window_size * self.config.start
+            n_rewards >= self.config.window_size * self.config.start
         ):
             stds = []
             stds_nor = []
@@ -116,13 +117,13 @@ def run(self, reward, max_total_steps, max_reward) -> None:
             stds += [reward_std]
             stds_nor += [reward_std_nor]
 
-        self.wandb_metrics = {"reward_stds": stds}
-
         stds = torch.tensor(stds).float()
         stds_nor = torch.tensor(stds_nor).float()
 
+        self.wandb_metrics = {"reward_stds": stds}
+
         if (self.step_num < max_total_steps * self.config.exploration_perc) and (
-                not self.config.fluctuation.disabled
+            not self.config.fluctuation.disabled
         ):
             cof = get_data_slope(stds)
             fluctuations = estimate_fluctuation_rmse(cof, stds)
@@ -134,7 +135,7 @@ def run(self, reward, max_total_steps, max_reward) -> None:
             )
 
         if self.step_num > max_total_steps * (self.config.exploitation_perc) and (
-                not self.config.monotonicity.disabled
+            not self.config.monotonicity.disabled
         ):
             cof = get_data_slope(stds_nor)
             self.check_reward_monotonicity(cof, max_reward)
@@ -159,7 +160,7 @@ def check_reward_monotonicity(self, cof, max_reward):
         else:
             stagnated_reward = torch.mean(self._episodes_rewards)
             if stagnated_reward < max_reward * (
-                    1 - self.config.monotonicity.reward_stagnation_tolerance
+                1 - self.config.monotonicity.reward_stagnation_tolerance
             ):
                 self.error_msg.append(
                     self.main_msgs["stagnated_reward"].format(
diff --git a/debugger/debugger_factory.py b/debugger/debugger_factory.py
index e405749..ab8f17e 100644
--- a/debugger/debugger_factory.py
+++ b/debugger/debugger_factory.py
@@ -12,7 +12,7 @@ class DebuggerFactory:
 
     def __init__(self):
         self.logger = settings.set_logger()
-        self.wandb_logger = settings.set_wandb_logger()
+        self.wandb_logger = None
         self.debuggers = dict()
         self.observed_params = {}
         self.observed_params_update_nums = dict()
@@ -95,8 +95,8 @@ def is_final_step_of_ep(self):
         Returns (bool): returns True if the step is the last one in an episode, and False otherwise.
         """
         if self.observed_params["done"] or (
-                (self.step_num > 0)
-                and ((self.step_num % self.observed_params["max_steps_per_episode"]) == 0)
+            (self.step_num > 0)
+            and ((self.step_num % self.observed_params["max_steps_per_episode"]) == 0)
         ):
             return True
         return False
@@ -199,6 +199,9 @@ def set_config(self, config_path=None):
         Args:
             config_path (str): The path to the configuration dict
         """
+        if config_path is not None:
+            self.wandb_logger = settings.set_wandb_logger(config_path)
+
         if config_path is None:
             config_path = load_default_config()
 
@@ -258,21 +261,26 @@ def turn_on(self):
         self.training = True
 
     def set_custom_wandb_logger(
-            self,
-            project,
-            name,
-            dir=None,
-            mode=None,
-            id=None,
-            resume=None,
-            start_method=None,
-            **kwargs,
+        self,
+        project,
+        name,
+        dir=None,
+        mode=None,
+        id=None,
+        resume=None,
+        start_method=None,
+        **kwargs,
     ):
         self.wandb_logger.custom_wandb_logger(
             project, name, dir, mode, id, resume, start_method, **kwargs
         )
 
     def plot_wandb(self, debugger):
-        if debugger.wandb_metrics:
-            self.wandb_logger.plot(debugger.wandb_metrics)
+        if debugger.wandb_metrics and self.wandb_logger is not None:
+            for key, values in debugger.wandb_metrics.items():
+                if values.ndim == 0:
+                    self.wandb_logger.plot({key: values})
+                else:
+                    for value in values:
+                        self.wandb_logger.plot({key: value})
         debugger.wandb_metrics = {}
diff --git a/debugger/utils/config/default_debugger.yml b/debugger/utils/config/default_debugger.yml
index 587ebb3..ba594ec 100644
--- a/debugger/utils/config/default_debugger.yml
+++ b/debugger/utils/config/default_debugger.yml
@@ -31,7 +31,8 @@ debugger:
       steps_done
     ]
   check_type:
 
-  wand_logger:
+  wandb_logger:
+    disable: False
     project: DRLDebugger
     name: default_name
diff --git a/debugger/utils/settings.py b/debugger/utils/settings.py
index 5cef2e4..a0cc262 100644
--- a/debugger/utils/settings.py
+++ b/debugger/utils/settings.py
@@ -83,16 +83,19 @@ def set_logger():
     return file_logger(log_fpath, "logger")
 
 
-def set_wandb_logger(config_path=None):
-    if config_path is None:
-        config_path = load_default_config()
-
+def set_wandb_logger(config_path):
     with open(config_path) as f:
         config = yaml.safe_load(f)
-    return WandbLogger(
-        project=config["debugger"]["wand_logger"]["project"],
-        name=config["debugger"]["wand_logger"]["name"],
-    )
+    if "wandb_logger" in config["debugger"]:
+        if bool(config["debugger"]["wandb_logger"]["disable"]):
+            return None
+        return WandbLogger(
+            project=config["debugger"]["wandb_logger"]["project"],
+            name=config["debugger"]["wandb_logger"]["name"],
+        )
+    return None
+
+
 def react(logger, messages, fail_on=False):
diff --git a/examples/dqn/debugger.yml b/examples/dqn/debugger.yml
index 3a54052..32ebea8 100644
--- a/examples/dqn/debugger.yml
+++ b/examples/dqn/debugger.yml
@@ -7,20 +7,20 @@ debugger:
   check_type:
 #    - name: Activation
 #      period: 10
-    - name: Bias
-      period: 10
+#    - name: Bias
+#      period: 10
 #    - name: Loss
 #      period: 10
 #    - name: Weight
 #      period: 10
 #    - name: ProperFitting
 #    - name: Gradient
-#    - name: Action
-#      period: 10
+    - name: Action
+      period: 10
 #    - name: Agent
 #    - name: ExplorationParameter
-#    - name: Reward
-#      period: 10
+    - name: Reward
+      period: 10
 #    - name: State
 #      period: 10
 #    - name: Step
@@ -29,3 +29,8 @@ debugger:
 #      period: 100
 #    - name: ValueFunction
 #      period: 10
+
+  wandb_logger:
+    disable: False
+    project: DRLDebugger
+    name: default_name
diff --git a/examples/dqn/wandb/run-20230318_114142-3ovoht77/run-3ovoht77.wandb b/examples/dqn/wandb/run-20230318_114142-3ovoht77/run-3ovoht77.wandb
new file mode 100644
index 0000000..e69de29
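
Aside (reviewer sketch, not part of the patch): with this change, wandb logging
becomes opt-in through a wandb_logger block in the debugger config, and
settings.set_wandb_logger returns None when the block is absent or has
disable: True. The standalone sketch below mirrors that gate under stated
assumptions: init_wandb_from_config is a hypothetical helper, and plain
wandb.init stands in for the project's WandbLogger wrapper.

    import yaml
    import wandb  # stand-in for the project's WandbLogger wrapper (assumption)

    def init_wandb_from_config(config_path):
        # Hypothetical helper mirroring the patched settings.set_wandb_logger:
        # a run is only created when the config explicitly opts in.
        with open(config_path) as f:
            config = yaml.safe_load(f)
        section = config["debugger"].get("wandb_logger")
        if section is None or bool(section["disable"]):
            return None  # no wandb_logger block, or disable: True -> no run
        return wandb.init(project=section["project"], name=section["name"])

A config that enables logging therefore looks like the wandb_logger block added
to examples/dqn/debugger.yml above; flipping disable to True turns the run off
without touching any code.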
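
The plot_wandb rewrite matters because a checker can publish a whole series at
once: RewardsCheck stores {"reward_stds": stds} where stds is a 1-D tensor, and
logging the tensor in a single call would not chart the individual points. A
minimal sketch of the rank-based dispatch, assuming plain wandb.log in place of
wandb_logger.plot and an offline run so it executes without a wandb account:

    import torch
    import wandb

    def plot_metrics(wandb_metrics):
        # Sketch of the dispatch in the patched DebuggerFactory.plot_wandb.
        for key, values in wandb_metrics.items():
            if values.ndim == 0:
                # 0-dim tensor: a single scalar data point.
                wandb.log({key: values.item()})
            else:
                # 1-D series such as reward_stds: log one point per element
                # so the chart shows the series rather than one opaque value.
                for value in values:
                    wandb.log({key: value.item()})

    run = wandb.init(project="DRLDebugger", name="default_name", mode="offline")
    plot_metrics({"reward_stds": torch.tensor([0.42, 0.37, 0.31])})
    run.finish()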