recorder_log_keys arg: enable logging of arbitrary keys from test rollouts (pytorch#212)
vmoens authored Jun 21, 2022
1 parent 0a94e45 commit fc31ab6
Showing 4 changed files with 43 additions and 5 deletions.
11 changes: 11 additions & 0 deletions torchrl/__init__.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import abc
+import collections
 import time
 from warnings import warn

@@ -98,3 +99,13 @@ def seed_generator(seed):
     rng = np.random.default_rng(seed)
     seed = int.from_bytes(rng.bytes(8), "big")
     return seed % max_seed_val
+
+
+class KeyDependentDefaultDict(collections.defaultdict):
+    def __init__(self, fun=lambda x: x):
+        self.fun = fun
+        super().__init__()
+
+    def __missing__(self, key):
+        value = self.fun(key)
+        return value
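
Note: KeyDependentDefaultDict computes a default from the missing key itself (the identity by default) and, as written, returns it without caching it in the dict. A minimal usage sketch, mirroring how the Recorder uses it further below:

    from torchrl import KeyDependentDefaultDict

    d = KeyDependentDefaultDict()  # default fun is the identity: fun(key) == key
    d["reward"] = "r_evaluation"
    print(d["reward"])    # "r_evaluation" (explicit mapping)
    print(d["solved"])    # "solved" (computed from the key)
    print("solved" in d)  # False, since __missing__ does not store the value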
8 changes: 8 additions & 0 deletions torchrl/trainers/helpers/recorder.py
@@ -38,6 +38,14 @@ def parser_recorder_args(parser: ArgumentParser) -> ArgumentParser:
         help="experiment name. Used for logging directory. "
         "A date and uuid will be joined to account for multiple experiments with the same name.",
     )
+    parser.add_argument(
+        "--recorder_log_keys",
+        "--recorder-log-keys",
+        nargs="+",
+        type=str,
+        default=["reward"],
+        help="Keys to log in the recorder.",
+    )
     parser.add_argument(
         "--record_interval",
         "--record-interval",
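
A usage sketch of the new flag (the script name is hypothetical; any script whose parser goes through parser_recorder_args would accept it):

    python examples/train.py --recorder_log_keys reward solved

With nargs="+", one or more key names can be listed; they arrive in args.recorder_log_keys as a list of strings, defaulting to ["reward"].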
3 changes: 2 additions & 1 deletion torchrl/trainers/helpers/trainers.py
@@ -215,6 +215,7 @@ def make_trainer(
             policy_exploration=policy_exploration,
             recorder=recorder,
             record_interval=args.record_interval,
+            log_keys=args.recorder_log_keys,
         )
         trainer.register_op(
             "post_steps_log",
@@ -228,7 +229,7 @@
             record_interval=args.record_interval,
             exploration_mode="random",
             suffix="exploration",
-            out_key="r_evaluation_exploration",
+            out_keys={"reward": "r_evaluation_exploration"},
         )
         trainer.register_op(
             "post_steps_log",
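
For reference: only the first Recorder above receives the user-selected log_keys; the exploration Recorder keeps the default log_keys=["reward"] and merely remaps its output name, so its logged dict is expected to look like (value illustrative):

    {"r_evaluation_exploration": <mean reward / frame_skip>, "log_pbar": False}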
26 changes: 22 additions & 4 deletions torchrl/trainers/trainers.py
@@ -15,6 +15,8 @@
 import torch.nn
 from torch import nn, optim

+from torchrl import KeyDependentDefaultDict
+
 try:
     from tqdm import tqdm

@@ -824,7 +826,8 @@ def __init__(
         policy_exploration: TensorDictModule,
         recorder: _EnvClass,
         exploration_mode: str = "mode",
-        out_key: str = "r_evaluation",
+        log_keys: Optional[List[str]] = None,
+        out_keys: Optional[Dict[str, str]] = None,
         suffix: Optional[str] = None,
         log_pbar: bool = False,
     ) -> None:
@@ -836,7 +839,13 @@
         self._count = 0
         self.record_interval = record_interval
         self.exploration_mode = exploration_mode
-        self.out_key = out_key
+        if log_keys is None:
+            log_keys = ["reward"]
+        if out_keys is None:
+            out_keys = KeyDependentDefaultDict()
+            out_keys["reward"] = "r_evaluation"
+        self.log_keys = log_keys
+        self.out_keys = out_keys
         self.suffix = suffix
         self.log_pbar = log_pbar
@@ -859,10 +868,19 @@ def __call__(self, batch: _TensorDict) -> Dict:
             if isinstance(self.policy_exploration, torch.nn.Module):
                 self.policy_exploration.train()
             self.recorder.train()
-            reward = td_record.get("reward").mean() / self.frame_skip
             self.recorder.transform.dump(suffix=self.suffix)
-            out = {self.out_key: reward, "log_pbar": self.log_pbar}
+
+            out = dict()
+            for key in self.log_keys:
+                value = td_record.get(key).float().mean()
+                if key == "reward":
+                    value = value / self.frame_skip
+                if key == "solved":
+                    value = value.any().float()
+                out[self.out_keys[key]] = value
+            out["log_pbar"] = self.log_pbar
         self._count += 1
         self.recorder.close()
         return out
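
A self-contained sketch of the new logging loop, with hypothetical tensors standing in for the rollout TensorDict and for the attributes set in __init__:

    import torch

    # Hypothetical stand-ins for the rollout results and Recorder attributes.
    td_record = {"reward": torch.tensor([1.0, 2.0, 3.0]),
                 "solved": torch.tensor([0.0, 0.0, 1.0])}
    log_keys = ["reward", "solved"]
    out_keys = {"reward": "r_evaluation", "solved": "solved"}
    frame_skip = 4

    out = dict()
    for key in log_keys:
        value = td_record[key].float().mean()
        if key == "reward":
            value = value / frame_skip   # reward is normalized by frame skip
        if key == "solved":
            value = value.any().float()  # 1.0 if solved at any point, else 0.0
        out[out_keys[key]] = value
    print(out)  # {'r_evaluation': tensor(0.5000), 'solved': tensor(1.)}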


