Skip to content

Commit

Permalink
[Feature]: Refactored logging to be able to support other loggers eas…
Browse files Browse the repository at this point in the history
…ily (pytorch#270)
  • Loading branch information
nicolas-dufour authored Jul 15, 2022
1 parent e5bea04 commit 806733f
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 72 deletions.
10 changes: 5 additions & 5 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@

@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: "DictConfig"):
from torch.utils.tensorboard import SummaryWriter
from torchrl.trainers.loggers import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -89,7 +89,7 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
writer = SummaryWriter(f"ddpg_logging/{exp_name}")
logger = TensorboardLogger(f"ddpg_logging/{exp_name}")
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down Expand Up @@ -161,7 +161,7 @@ def main(cfg: "DictConfig"):
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
writer=writer,
logger=logger,
use_env_creator=False,
)()

Expand Down Expand Up @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"):
target_net_updater,
actor_model_explore,
replay_buffer,
writer,
logger,
cfg,
)

Expand All @@ -217,7 +217,7 @@ def select_keys(batch):
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (writer.log_dir, trainer._log_dict, trainer.state_dict())
return (logger.log_dir, trainer._log_dict, trainer.state_dict())


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: "DictConfig"):

from torch.utils.tensorboard import SummaryWriter
from torchrl.trainers.loggers import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -80,7 +80,7 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
writer = SummaryWriter(f"dqn_logging/{exp_name}")
logger = TensorboardLogger(f"dqn_logging/{exp_name}")
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down Expand Up @@ -133,7 +133,7 @@ def main(cfg: "DictConfig"):
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
writer=writer,
logger=logger,
)()

# remove video recorder from recorder to have matching state_dict keys
Expand Down Expand Up @@ -165,7 +165,7 @@ def main(cfg: "DictConfig"):
target_net_updater,
model,
replay_buffer,
writer,
logger,
cfg,
)

Expand All @@ -187,7 +187,7 @@ def select_keys(batch):
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (writer.log_dir, trainer._log_dict, trainer.state_dict())
return (logger.log_dir, trainer._log_dict, trainer.state_dict())


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: "DictConfig"):
from torch.utils.tensorboard import SummaryWriter
from torchrl.trainers.loggers import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
writer = SummaryWriter(f"ppo_logging/{exp_name}")
logger = TensorboardLogger(f"ppo_logging/{exp_name}")
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down Expand Up @@ -131,7 +131,7 @@ def main(cfg: "DictConfig"):
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
writer=writer,
logger=logger,
use_env_creator=False,
)()

Expand Down Expand Up @@ -165,7 +165,7 @@ def main(cfg: "DictConfig"):
None,
actor_model,
None,
writer,
logger,
cfg,
)
if cfg.loss == "kl":
Expand All @@ -175,7 +175,7 @@ def main(cfg: "DictConfig"):
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (writer.log_dir, trainer._log_dict, trainer.state_dict())
return (logger.log_dir, trainer._log_dict, trainer.state_dict())


if __name__ == "__main__":
Expand Down
12 changes: 7 additions & 5 deletions examples/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@

@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: "DictConfig"):
from torch.utils.tensorboard import SummaryWriter # avoid loading on each process
from torchrl.trainers.loggers import (
TensorboardLogger,
) # avoid loading on each process

cfg = correct_for_frame_skip(cfg)

Expand All @@ -90,7 +92,7 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
writer = SummaryWriter(f"redq_logging/{exp_name}")
logger = TensorboardLogger(f"redq_logging/{exp_name}")
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down Expand Up @@ -161,7 +163,7 @@ def main(cfg: "DictConfig"):
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
writer=writer,
logger=logger,
use_env_creator=False,
)()

Expand Down Expand Up @@ -195,7 +197,7 @@ def main(cfg: "DictConfig"):
target_net_updater,
actor_model_explore,
replay_buffer,
writer,
logger,
cfg,
)

Expand All @@ -217,7 +219,7 @@ def select_keys(batch):
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (writer.log_dir, trainer._log_dict, trainer.state_dict())
return (logger.log_dir, trainer._log_dict, trainer.state_dict())


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions examples/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

@hydra.main(version_base=None, config_path=None, config_name="config")
def main(cfg: "DictConfig"):
from torch.utils.tensorboard import SummaryWriter
from torchrl.trainers.loggers import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -90,7 +90,7 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
writer = SummaryWriter(f"sac_logging/{exp_name}")
logger = TensorboardLogger(f"sac_logging/{exp_name}")
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down Expand Up @@ -158,7 +158,7 @@ def main(cfg: "DictConfig"):
video_tag=video_tag,
norm_obs_only=True,
stats=stats,
writer=writer,
logger=logger,
)()

# remove video recorder from recorder to have matching state_dict keys
Expand Down Expand Up @@ -191,7 +191,7 @@ def main(cfg: "DictConfig"):
target_net_updater,
actor_model_explore,
replay_buffer,
writer,
logger,
cfg,
)

Expand All @@ -213,7 +213,7 @@ def select_keys(batch):
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (writer.log_dir, trainer._log_dict, trainer.state_dict())
return (logger.log_dir, trainer._log_dict, trainer.state_dict())


if __name__ == "__main__":
Expand Down
9 changes: 5 additions & 4 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

try:
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter
from torchrl.trainers.loggers import TensorboardLogger

_has_tb = True
except ImportError:
Expand Down Expand Up @@ -224,7 +224,8 @@ def test_subsampler():
@pytest.mark.skipif(not _has_tb, reason="No tensorboard library")
def test_recorder():
with tempfile.TemporaryDirectory() as folder:
writer = SummaryWriter(log_dir=folder)
print(folder)
logger = TensorboardLogger(exp_name=folder)
args = Namespace()
args.env_name = "ALE/Pong-v5"
args.env_task = ""
Expand All @@ -249,7 +250,7 @@ def test_recorder():
video_tag="tmp",
norm_obs_only=True,
stats={"loc": 0, "scale": 1},
writer=writer,
logger=logger,
)()

recorder = Recorder(
Expand All @@ -261,7 +262,7 @@ def test_recorder():
)

for _ in range(N):
recorder(None)
out = recorder(None)

for (dirpath, dirnames, filenames) in walk(folder):
break
Expand Down
25 changes: 11 additions & 14 deletions torchrl/record/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from torchrl.data.tensordict.tensordict import _TensorDict
from torchrl.envs.transforms import ObservationTransform, Transform
from torchrl.trainers.loggers import Logger

__all__ = ["VideoRecorder", "TensorDictRecorder"]

Expand All @@ -23,12 +24,12 @@ class VideoRecorder(ObservationTransform):
"""
Video Recorder transform.
Will record a series of observations from an environment and write them
to a TensorBoard SummaryWriter object when needed.
to a Logger object when needed.
Args:
writer (SummaryWriter): a tb.SummaryWriter instance where the video
logger (Logger): a Logger instance where the video
should be written.
tag (str): the video tag in the writer.
tag (str): the video tag in the logger.
keys_in (Sequence[str], optional): keys to be read to produce the video.
Default is `"next_pixels"`.
skip (int): frame interval in the output video.
Expand All @@ -41,7 +42,7 @@ class VideoRecorder(ObservationTransform):

def __init__(
self,
writer: "SummaryWriter",
logger: Logger,
tag: str,
keys_in: Optional[Sequence[str]] = None,
skip: int = 2,
Expand All @@ -58,7 +59,7 @@ def __init__(
self.video_kwargs = video_kwargs
self.iter = 0
self.skip = skip
self.writer = writer
self.logger = logger
self.tag = tag
self.count = 0
self.center_crop = center_crop
Expand All @@ -68,10 +69,6 @@ def __init__(
"Could not load center_crop from torchvision. Make sure torchvision is installed."
)
self.obs = []
try:
import moviepy # noqa
except ImportError:
raise Exception("moviepy not found, VideoRecorder cannot be created")

def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
if not (observation.shape[-1] == 3 or observation.ndimension() == 2):
Expand Down Expand Up @@ -109,7 +106,7 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
return observation

def dump(self, suffix: Optional[str] = None) -> None:
"""Writes the video to the self.writer attribute.
"""Writes the video to the self.logger attribute.
Args:
suffix (str, optional): a suffix for the video to be recorded
Expand All @@ -120,10 +117,10 @@ def dump(self, suffix: Optional[str] = None) -> None:
tag = "_".join([self.tag, suffix])
obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
del self.obs
self.writer.add_video(
tag=tag,
vid_tensor=obs,
global_step=self.iter,
self.logger.log_video(
name=tag,
video=obs,
step=self.iter,
**self.video_kwargs,
)
del obs
Expand Down
1 change: 1 addition & 0 deletions torchrl/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
# LICENSE file in the root directory of this source tree.

from .trainers import *
from .loggers import *
Loading

0 comments on commit 806733f

Please sign in to comment.