[Feature] Logger #1858

Merged Jan 31, 2024. 49 commits (changes shown from all commits).

Commits (all by vmoens; dated Jan 31, 2024 unless noted):
012f9c0  init (Jan 19, 2024)
3485d6d  Merge remote-tracking branch 'origin/main' into remove-deprecs
b1344e6  Merge remote-tracking branch 'origin/main' into remove-deprecs
4839c09  amend
95ac71d  amend
067f4da  Merge remote-tracking branch 'origin/main' into remove-deprecs
e31e41b  amend
d7d2621  amend
f5b195d  amend
96d3a18  amend
4b41b5c  amend
755b7f4  amend
4b8c89b  amend
2fc17c2  amend
ba8dada  amend
3120f22  amend
ff27094  amend
6bdb2c4  amend
5dbc588  init
f30e02a  amend
b1c69b1  amend
ab07abe  amend
c7e8278  amend
f984105  amend
deb8b2e  amend
d0efa38  Merge remote-tracking branch 'origin/remove-deprecs' into logger
62b1dc8  amend
bd498ab  amend
b35c26a  Merge branch 'remove-deprecs' into logger
bf4a0d9  amend
1903d10  Merge branch 'remove-deprecs' into logger
e4bdde2  amend
ba63298  Merge remote-tracking branch 'origin/main' into logger
9c36712  amend
fdc4557  amend
c8d6441  amend
f32ce83  amend
f68dd4f  amend
2c2c9fb  amend
31b866a  amend
656e75b  amend
22cd51b  empty
23171f7  amend
707747e  amend
f44fe53  amend
2c737d6  amend
2797de8  amend
03c201c  amend
4b746f6  amend
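This PR replaces scattered uses of Python's stdlib `logging` across CI helpers, benchmarks, tests, and examples with a single shared logger, `torchrl_logger`, imported from `torchrl._utils`, and folds in a few related cleanups (the `space.minimum`/`space.maximum` to `space.low`/`space.high` spec rename, and a `policy[1].step()` fix in the bandits example). The logger's own definition is not part of the diff shown here; below is a minimal sketch of what such a shared logger could look like, assuming a plain stdlib backend. The environment-variable name and format string are illustrative, not TorchRL's actual choices.

```python
# Hypothetical sketch of a shared library logger; the real definition in
# torchrl/_utils.py is not shown in this diff. Names below are illustrative.
import logging
import os

logger = logging.getLogger("torchrl")
logger.setLevel(os.environ.get("TORCHRL_LOG_LEVEL", "INFO"))

if not logger.handlers:  # avoid stacking duplicate handlers on re-import
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
    )
    logger.addHandler(_handler)
logger.propagate = False  # keep torchrl records out of the root logger

# Call sites across the repo then use:
#     from torchrl._utils import logger as torchrl_logger
#     torchrl_logger.info("message")
```

Centralizing the logger gives one place to configure level, handlers, and format, and lets individual scripts drop their own `logging.basicConfig` calls, as the first benchmark diff below shows.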
3 changes: 1 addition & 2 deletions .github/unittest/helpers/coverage_run_parallel.py
@@ -11,7 +11,6 @@
nevertheless. It writes temporary coverage config files on the fly and
invokes coverage with proper arguments
"""
import logging
import os
import shlex
import subprocess
@@ -45,7 +44,7 @@ def write_config(config_path: Path, argv: List[str]) -> None:

def main(argv: List[str]) -> int:
if len(argv) < 1:
logging.info(
print( # noqa
"Usage: 'python coverage_run_parallel.py <command> [command arguments]'"
)
sys.exit(1)
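Note that the usage-message path above switches from `logging.info` to a bare `print` rather than to `torchrl_logger`. The likely rationale (an inference, not stated in the diff): a helper invoked with no arguments should emit its usage string unconditionally, before any logging is configured, and the `# noqa` suppresses what is presumably a lint rule this PR enforces against raw `print`/`logging` calls.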
13 changes: 5 additions & 8 deletions benchmarks/benchmark_batched_envs.py
@@ -15,11 +15,8 @@

"""

import logging

logging.basicConfig(level=logging.ERROR)
logging.captureWarnings(True)
import pandas as pd
from torchrl._utils import logger as torchrl_logger

pd.set_option("display.max_columns", 100)
pd.set_option("display.width", 1000)
@@ -68,8 +65,8 @@ def run_env(env):
devices.append("cuda")
for device in devices:
for num_workers in [1, 4, 16]:
logging.info(f"With num_workers={num_workers}, {device}")
logging.info("Multithreaded...")
torchrl_logger.info(f"With num_workers={num_workers}, {device}")
torchrl_logger.info("Multithreaded...")
env_multithreaded = create_multithreaded(num_workers, device)
res_multithreaded = Timer(
stmt="run_env(env)",
@@ -78,7 +75,7 @@
)
time_multithreaded = res_multithreaded.blocked_autorange().mean

logging.info("Serial...")
torchrl_logger.info("Serial...")
env_serial = create_serial(num_workers, device)
res_serial = Timer(
stmt="run_env(env)",
@@ -87,7 +84,7 @@
)
time_serial = res_serial.blocked_autorange().mean

logging.info("Parallel...")
torchrl_logger.info("Parallel...")
env_parallel = create_parallel(num_workers, device)
res_parallel = Timer(
stmt="run_env(env)",
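The benchmark's measurement loop relies on `torch.utils.benchmark.Timer`. A self-contained sketch of that pattern, with a toy workload standing in for the TorchRL environments:

```python
# Minimal sketch of the Timer pattern used above; `env` and `run_env` are
# toy stand-ins, not TorchRL objects.
import torch
from torch.utils.benchmark import Timer


def run_env(env):
    for _ in range(100):
        env @ env  # placeholder workload


env = torch.eye(64)
res = Timer(
    stmt="run_env(env)",
    globals={"env": env, "run_env": run_env},
).blocked_autorange()
print(f"mean time per run: {res.mean:.6f} s")  # Measurement.mean is seconds
```

`blocked_autorange()` picks the number of repeats automatically and returns a `Measurement` whose `.mean` is the per-run time in seconds, which is what the script logs for the multithreaded, serial, and parallel variants.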
4 changes: 2 additions & 2 deletions benchmarks/conftest.py
@@ -2,13 +2,13 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import time
import warnings
from collections import defaultdict

import pytest
from torchrl._utils import logger as torchrl_logger

CALL_TIMES = defaultdict(lambda: 0.0)

@@ -32,7 +32,7 @@ def pytest_sessionfinish(maxprint=50):
out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
if i == maxprint - 1:
break
logging.info(out_str)
torchrl_logger.info(out_str)


@pytest.fixture(autouse=True)
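The autouse fixture that feeds `CALL_TIMES` is truncated above; a hypothetical reconstruction of that bookkeeping (the fixture name and exact accounting are assumptions):

```python
# Hypothetical reconstruction; the real fixture body is truncated in the diff.
@pytest.fixture(autouse=True)
def _time_calls(request):
    start = time.time()
    yield  # run the test
    CALL_TIMES[request.node.name] += time.time() - start
```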
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.


import logging
import os
import pickle

@@ -23,6 +22,7 @@
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune import register_env
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.vmas import VmasEnv
from vmas import Wrapper
@@ -165,11 +165,11 @@ def run_comparison_torchrl_rllib(
evaluation = {}
for framework in ["TorchRL", "RLlib"]:
if framework not in evaluation.keys():
logging.info(f"\nFramework {framework}")
torchrl_logger.info(f"\nFramework {framework}")
vmas_times = []
for n_envs in list_n_envs:
n_envs = int(n_envs)
logging.info(f"Running {n_envs} environments")
torchrl_logger.info(f"Running {n_envs} environments")
if framework == "TorchRL":
vmas_times.append(
(n_envs * n_steps)
@@ -190,7 +190,7 @@
device=device,
)
)
logging.info(f"fps {vmas_times[-1]}s")
torchrl_logger.info(f"fps {vmas_times[-1]}s")
evaluation[framework] = vmas_times

store_pickled_evaluation(name=figure_name_pkl, evaluation=evaluation)
12 changes: 6 additions & 6 deletions benchmarks/storage/benchmark_sample_latency_over_rpc.py
@@ -14,7 +14,6 @@
This code is based on examples/distributed/distributed_replay_buffer.py.
"""
import argparse
import logging
import os
import pickle
import sys
@@ -25,6 +24,7 @@
import torch
import torch.distributed.rpc as rpc
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import (
@@ -106,10 +106,10 @@ def _create_replay_buffer(self) -> rpc.RRef:
buffer_rref = rpc.remote(
replay_buffer_info, ReplayBufferNode, args=(1000000,)
)
logging.info(f"Connected to replay buffer {replay_buffer_info}")
torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}")
return buffer_rref
except Exception:
logging.info("Failed to connect to replay buffer")
torchrl_logger.info("Failed to connect to replay buffer")
time.sleep(RETRY_DELAY_SECS)


@@ -144,7 +144,7 @@ def __init__(self, capacity: int):
rank = args.rank
storage_type = args.storage

logging.info(f"Rank: {rank}; Storage: {storage_type}")
torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}")

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
@@ -167,7 +167,7 @@ def __init__(self, capacity: int):
if i == 0:
continue
results.append(result)
logging.info(i, results[-1])
torchrl_logger.info(f"{i}, {results[-1]}")

with open(
f'./benchmark_{datetime.now().strftime("%d-%m-%Y%H:%M:%S")};batch_size={BATCH_SIZE};tensor_size={TENSOR_SIZE};repeat={REPEATS};storage={storage_type}.pkl',
Expand All @@ -176,7 +176,7 @@ def __init__(self, capacity: int):
pickle.dump(results, f)

tensor_results = torch.tensor(results)
logging.info(f"Mean: {torch.mean(tensor_results)}")
torchrl_logger.info(f"Mean: {torch.mean(tensor_results)}")
breakpoint()
elif rank == 1:
# rank 1 is the replay buffer
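Besides the mechanical renames, one call in this file was genuinely broken before: stdlib `logging.info(i, results[-1])` treats its first argument as the message and the rest as %-style format arguments, so a non-string message with no placeholders fails during formatting (logging prints a "Logging error" traceback and drops the record). The fix formats first and logs a single string:

```python
import logging

logging.basicConfig(level=logging.INFO)
logging.info(3, 0.25)            # wrong: 0.25 is consumed as a %-format arg
logging.info("%s, %s", 3, 0.25)  # correct stdlib lazy-formatting style
logging.info(f"{3}, {0.25}")     # the eager f-string style this PR adopts
```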
5 changes: 2 additions & 3 deletions examples/a2c/a2c_atari.py
@@ -2,9 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging

import hydra
from torchrl._utils import logger as torchrl_logger


@hydra.main(config_path=".", config_name="config_atari", version_base="1.1")
@@ -220,7 +219,7 @@ def main(cfg: "DictConfig"): # noqa: F821

end_time = time.time()
execution_time = end_time - start_time
logging.info(f"Training took {execution_time:.2f} seconds to finish")
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
5 changes: 2 additions & 3 deletions examples/a2c/a2c_mujoco.py
@@ -2,9 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging

import hydra
from torchrl._utils import logger as torchrl_logger


@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1")
@@ -205,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821

end_time = time.time()
execution_time = end_time - start_time
logging.info(f"Training took {execution_time:.2f} seconds to finish")
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions examples/a2c/utils_atari.py
@@ -98,8 +98,8 @@ def make_ppo_modules_pixels(proof_environment):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
"min": proof_environment.action_spec.space.low,
"max": proof_environment.action_spec.space.high,
}

# Define input keys
4 changes: 2 additions & 2 deletions examples/a2c/utils_mujoco.py
@@ -51,8 +51,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
"min": proof_environment.action_spec.space.low,
"max": proof_environment.action_spec.space.high,
"tanh_loc": False,
}

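Both utils files also pick up a TorchRL spec-API rename: a bounded spec's `space` now exposes its bounds as `low`/`high` instead of the deprecated `minimum`/`maximum`. A quick illustration, assuming a Gym backend with Pendulum available:

```python
# Assumes gym/gymnasium with Pendulum-v1 installed.
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
space = env.action_spec.space
print(space.low, space.high)  # formerly space.minimum / space.maximum
```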
2 changes: 1 addition & 1 deletion examples/bandits/dqn.py
@@ -122,4 +122,4 @@
f"training reward {data['next', 'reward'].sum() / env.numel() : 4.4f}, "
f"loss {loss_val: 4.4f} (init: {init_loss: 4.4f})"
)
policy.step()
policy[1].step()
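The one-line bandits fix moves `step()` from the policy as a whole to `policy[1]`. The policy here is presumably a `TensorDictSequential` whose second module is the exploration module, and only that module exposes `step()` for epsilon annealing. A schematic sketch; the module choices below are assumptions, not the example's actual code:

```python
# Schematic sketch; examples/bandits/dqn.py may build its policy differently.
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import DiscreteTensorSpec
from torchrl.modules import EGreedyModule

actor = TensorDictModule(
    torch.nn.LazyLinear(2), in_keys=["observation"], out_keys=["action"]
)
exploration = EGreedyModule(spec=DiscreteTensorSpec(2), annealing_num_steps=1000)
policy = TensorDictSequential(actor, exploration)

# TensorDictSequential does not forward step(); epsilon annealing lives on
# the exploration module, hence policy[1].step() rather than policy.step().
policy[1].step()
```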
4 changes: 2 additions & 2 deletions examples/cql/cql_offline.py
@@ -9,13 +9,13 @@
The helper functions are coded in the utils.py associated with this script.

"""
import logging
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

@@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)

pbar.close()
logging.info(f"Training time: {time.time() - start_time}")
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions examples/cql/cql_online.py
@@ -11,14 +11,14 @@
The helper functions are coded in the utils.py associated with this script.

"""
import logging
import time

import hydra
import numpy as np
import torch
import tqdm
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

@@ -211,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
logging.info(f"Training took {execution_time:.2f} seconds to finish")
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")

collector.shutdown()

4 changes: 2 additions & 2 deletions examples/cql/discrete_cql_online.py
@@ -10,14 +10,14 @@

The helper functions are coded in the utils.py associated with this script.
"""
import logging
import time

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.envs.utils import ExplorationType, set_exploration_type

@@ -196,7 +196,7 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
logging.info(f"Training took {execution_time:.2f} seconds to finish")
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions examples/ddpg/ddpg.py
@@ -10,7 +10,6 @@

The helper functions are coded in the utils.py associated with this script.
"""
import logging
import time

import hydra
@@ -19,6 +18,7 @@
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
logging.info(f"Training took {execution_time:.2f} seconds to finish")
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions examples/decision_transformer/dt.py
@@ -6,13 +6,13 @@
This is a self-contained example of an offline Decision Transformer training script.
The helper functions are coded in the utils.py associated with this script.
"""
import logging
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -79,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

logging.info(" ***Pretraining*** ")
torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
@@ -116,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)

pbar.close()
logging.info(f"Training time: {time.time() - start_time}")
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":