Skip to content

Commit

Permalink
[Logging]: implement MLFlow logging integration (pytorch#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayanht authored Sep 21, 2022
1 parent 24e5e0e commit 2b9fbe1
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .circleci/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ dependencies:
- tensorboard
- wandb
- dm_control
- mlflow
- av
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_stable/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ dependencies:
- tensorboard
- wandb
- dm_control
- mlflow
- av
9 changes: 9 additions & 0 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
import pathlib
import uuid
from datetime import datetime

Expand Down Expand Up @@ -100,6 +102,13 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="ddpg_logging", exp_name=exp_name)
elif cfg.logger == "mlflow":
from torchrl.trainers.loggers.mlflow import MLFlowLogger

logger = MLFlowLogger(
tracking_uri=pathlib.Path(os.path.abspath("ddpg_logging")).as_uri(),
exp_name=exp_name,
)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
8 changes: 8 additions & 0 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
import pathlib
import uuid
from datetime import datetime

Expand Down Expand Up @@ -90,7 +92,13 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="dqn_logging", exp_name=exp_name)
elif cfg.logger == "mlflow":
from torchrl.trainers.loggers.mlflow import MLFlowLogger

logger = MLFlowLogger(
tracking_uri=pathlib.Path(os.path.abspath("dqn_logging")).as_uri(),
exp_name=exp_name,
)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
9 changes: 9 additions & 0 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
import pathlib
import uuid
from datetime import datetime

Expand Down Expand Up @@ -86,6 +88,13 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="ppo_logging", exp_name=exp_name)
elif cfg.logger == "mlflow":
from torchrl.trainers.loggers.mlflow import MLFlowLogger

logger = MLFlowLogger(
tracking_uri=pathlib.Path(os.path.abspath("ppo_logging")).as_uri(),
exp_name=exp_name,
)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
9 changes: 9 additions & 0 deletions examples/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
import pathlib
import uuid
from datetime import datetime

Expand Down Expand Up @@ -101,6 +103,13 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="redq_logging", exp_name=exp_name)
elif cfg.logger == "mlflow":
from torchrl.trainers.loggers.mlflow import MLFlowLogger

logger = MLFlowLogger(
tracking_uri=pathlib.Path(os.path.abspath("redq_logging")).as_uri(),
exp_name=exp_name,
)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
9 changes: 9 additions & 0 deletions examples/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
import pathlib
import uuid
from datetime import datetime

Expand Down Expand Up @@ -101,6 +103,13 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="sac_logging", exp_name=exp_name)
elif cfg.logger == "mlflow":
from torchrl.trainers.loggers.mlflow import MLFlowLogger

logger = MLFlowLogger(
tracking_uri=pathlib.Path(os.path.abspath("sac_logging")).as_uri(),
exp_name=exp_name,
)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
73 changes: 73 additions & 0 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import argparse
import os
import os.path
import pathlib
import tempfile
from time import sleep

import pytest
import torch
import torchvision
from torchrl.trainers.loggers.csv import CSVLogger
from torchrl.trainers.loggers.mlflow import MLFlowLogger, _has_mlflow
from torchrl.trainers.loggers.tensorboard import TensorboardLogger, _has_tb
from torchrl.trainers.loggers.wandb import WandbLogger, _has_wandb

Expand Down Expand Up @@ -218,6 +222,75 @@ def test_log_video(self):
del logger


@pytest.fixture
def mlflow_fixture():
    """Yield an ``(MLFlowLogger, MlflowClient)`` pair bound to a throwaway store.

    The logger tracks into a temporary directory (passed as a ``file://`` URI)
    that is removed once the test finishes. The active mlflow run is ended
    during teardown, before the directory is cleaned up.
    """
    torch.manual_seed(0)
    import mlflow

    with tempfile.TemporaryDirectory() as tmp_root:
        tracking_uri = pathlib.Path(tmp_root).as_uri()
        mlflow_logger = MLFlowLogger(exp_name="ramala", tracking_uri=tracking_uri)
        mlflow_client = mlflow.MlflowClient()
        yield mlflow_logger, mlflow_client
        # Teardown: close the run opened by the logger.
        mlflow.end_run()


@pytest.mark.skipif(not _has_mlflow, reason="MLFlow not installed")
class TestMLFlowLogger:
    """Tests for ``MLFlowLogger`` using the ``mlflow_fixture`` logger/client pair."""

    @pytest.mark.parametrize("steps", [None, [1, 10, 11]])
    def test_log_scalar(self, steps, mlflow_fixture):
        """Log three scalars and check them against the stored metric history."""
        import mlflow

        logger, client = mlflow_fixture
        values = torch.rand(3)
        for i in range(3):
            scalar_name = "foo"
            scalar_value = values[i].item()
            logger.log_scalar(
                value=scalar_value,
                name=scalar_name,
                step=steps[i] if steps else None,
            )
        run_id = mlflow.active_run().info.run_id
        for i, metric in enumerate(client.get_metric_history(run_id, "foo")):
            assert metric.key == "foo"
            # The test expects metrics logged without a step to be stored at step 0.
            assert metric.step == (steps[i] if steps else 0)
            assert metric.value == values[i].item()

    @pytest.mark.parametrize("steps", [None, [1, 10, 11]])
    def test_log_video(self, steps, mlflow_fixture):
        """Log three videos and check the downloaded artifacts decode correctly."""
        import mlflow

        logger, client = mlflow_fixture
        videos = torch.cat(
            (torch.full((3, 64, 3, 32, 32), 255), torch.zeros(3, 64, 3, 32, 32)),
            dim=1,
        )
        fps = 6
        for i in range(3):
            logger.log_video(
                name="test_video",
                video=videos[i],
                fps=fps,
                step=steps[i] if steps else None,
            )
        run_id = mlflow.active_run().info.run_id
        with tempfile.TemporaryDirectory() as artifacts_dir:
            videos_dir = client.download_artifacts(run_id, "videos", artifacts_dir)
            # os.listdir returns entries in arbitrary order; sort so that the
            # zero-padded step suffixes line up with the order of logging.
            for i, video_name in enumerate(sorted(os.listdir(videos_dir))):
                video_path = os.path.join(videos_dir, video_name)
                loaded_video, _, _ = torchvision.io.read_video(
                    video_path, pts_unit="sec", output_format="TCHW"
                )
                if steps:
                    assert torch.allclose(loaded_video.int(), videos[i].int(), rtol=0.1)
                else:
                    # Without a step every call uses the same file name, so the
                    # test compares the single artifact with the last video.
                    assert torch.allclose(
                        loaded_video.int(), videos[-1].int(), rtol=0.1
                    )


if __name__ == "__main__":
    # Forward any command-line arguments argparse does not recognise to pytest.
    cli_parser = argparse.ArgumentParser()
    _, passthrough = cli_parser.parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst"]
    pytest_argv.extend(passthrough)
    pytest.main(pytest_argv)
125 changes: 125 additions & 0 deletions torchrl/trainers/loggers/mlflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import warnings
from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional

import torchvision
from torch import Tensor

from .common import Logger

_has_mlflow = False
try:
import mlflow

_has_mlflow = True
except ImportError:
warnings.warn("mlflow could not be imported")
_has_omgaconf = False
try:
from omegaconf import OmegaConf

_has_omgaconf = True
except ImportError:
warnings.warn(
"OmegaConf could not be imported. Cannot log hydra configs without OmegaConf"
)


class MLFlowLogger(Logger):
    """Wrapper for the mlflow logger.

    Args:
        exp_name (str): The name of the experiment.
        tracking_uri (str): A tracking URI to a datastore that supports MLFlow
            or a local directory. It is also used as the experiment's
            artifact location.
        tags (dict, optional): Tags forwarded to ``mlflow.create_experiment``.
            Defaults to None.
        **kwargs: Extra keyword arguments; accepted but not used by this logger.

    Raises:
        ImportError: If mlflow is not installed.
    """

    def __init__(
        self,
        exp_name: str,
        tracking_uri: str,
        tags: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Fail early with a clear message instead of a NameError on ``mlflow``.
        if not _has_mlflow:
            raise ImportError("MLFlow is not installed")
        self._mlflow_kwargs = {
            "name": exp_name,
            "artifact_location": tracking_uri,
            "tags": tags,
        }
        mlflow.set_tracking_uri(tracking_uri)
        super().__init__(exp_name=exp_name, log_dir=tracking_uri)
        self.video_log_counter = 0

    def _create_experiment(self) -> "mlflow.ActiveRun":
        """Create an mlflow experiment and start a run in it.

        The experiment id is stored on ``self.id``.

        Returns:
            mlflow.ActiveRun: The run started in the new experiment.

        Raises:
            ImportError: If mlflow is not installed.
        """
        if not _has_mlflow:
            raise ImportError("MLFlow is not installed")
        self.id = mlflow.create_experiment(**self._mlflow_kwargs)
        return mlflow.start_run(experiment_id=self.id)

    def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
        """Log a scalar value to mlflow.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged.
                Defaults to None.
        """
        mlflow.set_experiment(experiment_id=self.id)
        mlflow.log_metric(key=name, value=value, step=step)

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log a video to mlflow as an ``.mp4`` artifact under ``videos``.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged, expected to be in
                (T, C, H, W) format for consistency with other loggers. A 5-D
                input is reduced to its last element along the first dimension.
            **kwargs: Other keyword arguments. By construction, log_video
                supports 'step' (integer indicating the step index) and
                'fps' (default: 6).

        Raises:
            ValueError: If the video does not have 3 color channels.
        """
        mlflow.set_experiment(experiment_id=self.id)
        if video.ndim == 5:
            video = video[-1]  # N T C H W -> T C H W
        video = video.permute(0, 2, 3, 1)  # T C H W -> T H W C
        if video.size(dim=-1) != 3:
            raise ValueError(
                "The MLFlow logger only supports videos with 3 color channels."
            )
        self.video_log_counter += 1
        fps = kwargs.pop("fps", 6)
        step = kwargs.pop("step", None)
        # Compare against None so that step 0 still gets its own suffix.
        if step is not None:
            video_name = f"{name}_step_{step:04}.mp4"
        else:
            video_name = f"{name}.mp4"
        with TemporaryDirectory() as temp_dir:
            # write_video opens the target itself, so pass the path directly
            # rather than holding a second handle open on the same file.
            video_path = os.path.join(temp_dir, video_name)
            torchvision.io.write_video(
                filename=video_path, video_array=video, fps=fps
            )
            mlflow.log_artifact(video_path, "videos")

    def log_hparams(self, cfg: "DictConfig") -> None:  # noqa: F821
        """Log the hyperparameters of the experiment.

        Args:
            cfg (DictConfig): The configuration of the experiment. Anything
                that is not a dict is converted with
                ``OmegaConf.to_container`` when OmegaConf is available.
        """
        mlflow.set_experiment(experiment_id=self.id)
        if not isinstance(cfg, dict) and _has_omgaconf:
            cfg = OmegaConf.to_container(cfg, resolve=True)
        mlflow.log_params(cfg)

    def __repr__(self) -> str:
        # NOTE(review): ``self.experiment`` is not assigned in this class;
        # presumably the Logger base class sets it -- confirm.
        return f"MLFlowLogger(experiment={self.experiment!r})"

0 comments on commit 2b9fbe1

Please sign in to comment.