Refactor Agent class (pytorch#101)
vmoens authored May 7, 2022
1 parent f0d7e1c commit 61601b9
Showing 51 changed files with 1,586 additions and 764 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -466,6 +466,7 @@ jobs:
<<: *binary_common
macos:
xcode: "12.0"

resource_class: large
steps:
- checkout
8 changes: 7 additions & 1 deletion .circleci/unittest/linux/scripts/run_test.sh
@@ -5,6 +5,12 @@ set -e
eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

if [[ $OSTYPE == 'darwin'* ]]; then
PRIVATE_MUJOCO_GL=egl
else
PRIVATE_MUJOCO_GL=glfw
fi

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
root_dir="$(git rev-parse --show-toplevel)"
@@ -13,4 +19,4 @@ export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210
export DISPLAY=unix:0.0
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/project/.mujoco/mujoco210/bin
#MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20
MUJOCO_GL=glfw pytest -v --durations 20
MUJOCO_GL=$PRIVATE_MUJOCO_GL pytest -v --durations 20
4 changes: 2 additions & 2 deletions README.md
@@ -22,7 +22,7 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f

On the high-level end, torchrl provides:
- multiprocess [data collectors](torchrl/collectors/collectors.py);
- a generic [agent class](torchrl/agents/agents.py);
- a generic [agent class](torchrl/trainers/trainers.py);
- efficient and generic [replay buffers](torchrl/data/replay_buffers/replay_buffers.py);
- [TensorDict](torchrl/data/tensordict/tensordict.py), a convenient data structure to pass data from one object to another without friction;
- An associated [`TDModule` class](torchrl/modules/td_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
@@ -31,7 +31,7 @@ On the high-level end, torchrl provides:
- various tools for distributed learning (e.g. [memory mapped tensors](torchrl/data/tensordict/memmap.py));
- various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/td_module/actors.py));
- [exploration wrappers](torchrl/modules/td_module/exploration.py);
- various [recipes](torchrl/agents/helpers/models.py) to build models that correspond to the environment being deployed.
- various [recipes](torchrl/trainers/helpers/models.py) to build models that correspond to the environment being deployed.

A series of [examples](examples/) are provided with an illustrative purpose:
- [DQN (and add-ons up to Rainbow)](examples/dqn/dqn.py)
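
The renaming in this commit is largely mechanical across the example scripts below: the torchrl.agents helper namespace becomes torchrl.trainers, and make_agent / parser_agent_args become make_trainer / parser_trainer_args. A quick before/after sketch of the import change (old path per the README link removed above, new path per the diffs that follow):

# Pre-refactor helper imports (removed by this commit), kept as a comment for reference:
# from torchrl.agents.helpers.agents import make_agent, parser_agent_args

# Post-refactor equivalents, as imported by ddpg.py / dqn.py / ppo.py below:
from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args
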
24 changes: 12 additions & 12 deletions examples/ddpg/ddpg.py
@@ -16,30 +16,30 @@
_configargparse = False
import torch.cuda
from torch.utils.tensorboard import SummaryWriter
from torchrl.agents.helpers.agents import make_agent, parser_agent_args
from torchrl.agents.helpers.collectors import (
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.trainers.helpers.collectors import (
make_collector_offpolicy,
parser_collector_args_offpolicy,
)
from torchrl.agents.helpers.envs import (
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
get_stats_random_rollout,
parallel_env_constructor,
parser_env_args,
transformed_env_constructor,
)
from torchrl.agents.helpers.losses import make_ddpg_loss, parser_loss_args
from torchrl.agents.helpers.models import (
from torchrl.trainers.helpers.losses import make_ddpg_loss, parser_loss_args
from torchrl.trainers.helpers.models import (
make_ddpg_actor,
parser_model_args_continuous,
)
from torchrl.agents.helpers.recorder import parser_recorder_args
from torchrl.agents.helpers.replay_buffer import (
from torchrl.trainers.helpers.recorder import parser_recorder_args
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
parser_replay_args,
)
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args


def make_args():
Expand All @@ -52,7 +52,7 @@ def make_args():
is_config_file=True,
help="config file path",
)
parser_agent_args(parser)
parser_trainer_args(parser)
parser_collector_args_offpolicy(parser)
parser_env_args(parser)
parser_loss_args(parser, algorithm="DDPG")
@@ -152,7 +152,7 @@ def make_args():
if isinstance(t, RewardScaling):
t.scale.fill_(1.0)

agent = make_agent(
trainer = make_trainer(
collector,
loss_module,
recorder,
@@ -163,4 +163,4 @@ def make_args():
args,
)

agent.train()
trainer.train()
2 changes: 1 addition & 1 deletion examples/dqn/configs/pong.txt
@@ -1,6 +1,6 @@
frames_per_batch=500
frame_skip=4
optim_steps_per_collection=125
optim_steps_per_batch=125
env_library=gym
env_name=ALE/Pong-v5
noops=30
23 changes: 23 additions & 0 deletions examples/dqn/configs/pong_smoketest.txt
@@ -0,0 +1,23 @@
frames_per_batch=32
frame_skip=4
optim_steps_per_batch=3
env_library=gym
env_name=ALE/Pong-v5
noops=30
max_frames_per_traj=-1
exp_name=pong
record_interval=23
batch_size=32
async_collection
distributional
prb
multi_step
annealing_frames=500
total_frames=500
record_frames=30
normalize_rewards_online
from_pixels
record_video
num_workers=4
env_per_collector=2
init_random_frames=7
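
These smoke-test configs use the same key=value / bare-flag format as the other example configs and are read through configargparse by the example scripts. A hypothetical minimal sketch of how such a file maps onto an argument parser; the "--config" option name, the chosen flags and their defaults are assumptions for illustration, while the real parsers are assembled by the parser_*_args helpers imported in dqn.py below:

import configargparse

# Assumed parser layout; the real one is built by parser_trainer_args,
# parser_collector_args_offpolicy, etc. (see dqn.py below).
parser = configargparse.ArgumentParser()
parser.add_argument("--config", is_config_file=True, help="config file path")
parser.add_argument("--frames_per_batch", type=int, default=1000)
parser.add_argument("--optim_steps_per_batch", type=int, default=500)
parser.add_argument("--async_collection", action="store_true")

# A line like "frames_per_batch=32" in the config behaves like the flag below,
# and a bare line like "async_collection" toggles the store_true option.
args = parser.parse_args(
    ["--frames_per_batch", "32", "--optim_steps_per_batch", "3", "--async_collection"]
)
print(args.optim_steps_per_batch)  # -> 3
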
24 changes: 12 additions & 12 deletions examples/dqn/dqn.py
@@ -16,30 +16,30 @@
_configargparse = False
import torch.cuda
from torch.utils.tensorboard import SummaryWriter
from torchrl.agents.helpers.agents import make_agent, parser_agent_args
from torchrl.agents.helpers.collectors import (
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.modules import EGreedyWrapper
from torchrl.trainers.helpers.collectors import (
make_collector_offpolicy,
parser_collector_args_offpolicy,
)
from torchrl.agents.helpers.envs import (
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
get_stats_random_rollout,
parallel_env_constructor,
parser_env_args,
transformed_env_constructor,
)
from torchrl.agents.helpers.losses import make_dqn_loss, parser_loss_args
from torchrl.agents.helpers.models import (
from torchrl.trainers.helpers.losses import make_dqn_loss, parser_loss_args
from torchrl.trainers.helpers.models import (
make_dqn_actor,
parser_model_args_discrete,
)
from torchrl.agents.helpers.recorder import parser_recorder_args
from torchrl.agents.helpers.replay_buffer import (
from torchrl.trainers.helpers.recorder import parser_recorder_args
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
parser_replay_args,
)
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.modules import EGreedyWrapper
from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args


def make_args():
Expand All @@ -52,7 +52,7 @@ def make_args():
is_config_file=True,
help="config file path",
)
parser_agent_args(parser)
parser_trainer_args(parser)
parser_collector_args_offpolicy(parser)
parser_env_args(parser)
parser_loss_args(parser, algorithm="DQN")
@@ -138,7 +138,7 @@ def make_args():
if isinstance(t, RewardScaling):
t.scale.fill_(1.0)

agent = make_agent(
trainer = make_trainer(
collector,
loss_module,
recorder,
@@ -149,4 +149,4 @@ def make_args():
args,
)

agent.train()
trainer.train()
7 changes: 5 additions & 2 deletions examples/ppo/configs/cheetah.txt
@@ -1,12 +1,15 @@
env_name=cheetah
env_task=run
env_library=dm_control
optim_steps_per_collection=10
optim_steps_per_batch=10
lamda=0.95
normalize_rewards_online
record_video
max_frames_per_traj=1000
record_interval=200
lr=3e-4
tanh_loc
init_with_lag
entropy_factor=0.1
clip_norm=1000.0
frames_per_batch=3200
frame_skip=4
2 changes: 1 addition & 1 deletion examples/ppo/configs/cheetah_pixels.txt
@@ -1,7 +1,7 @@
env_name=cheetah
env_task=run
env_library=dm_control
optim_steps_per_collection=10
optim_steps_per_batch=10
lamda=0.95
normalize_rewards_online
record_video
15 changes: 15 additions & 0 deletions examples/ppo/configs/cheetah_smoketest.txt
@@ -0,0 +1,15 @@
env_name=cheetah
env_task=run
env_library=dm_control
optim_steps_per_batch=3
lamda=0.95
normalize_rewards_online
record_video
max_frames_per_traj=100
record_interval=4
lr=3e-4
tanh_loc
init_with_lag
frame_skip=4
num_workers=4
env_per_collector=2
3 changes: 1 addition & 2 deletions examples/ppo/configs/humanoid.txt
@@ -1,7 +1,7 @@
env_name=humanoid
env_task=walk
env_library=dm_control
optim_steps_per_collection=10
optim_steps_per_batch=10
lamda=0.95
normalize_rewards_online
record_video
@@ -11,4 +11,3 @@ lr=3e-4
entropy_factor=1e-4
frame_skip=4
tanh_loc
init_with_lag
22 changes: 12 additions & 10 deletions examples/ppo/ppo.py
@@ -16,25 +16,25 @@
_configargparse = False
import torch.cuda
from torch.utils.tensorboard import SummaryWriter
from torchrl.agents.helpers.agents import make_agent, parser_agent_args
from torchrl.agents.helpers.collectors import (
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.trainers.helpers.collectors import (
make_collector_onpolicy,
parser_collector_args_onpolicy,
)
from torchrl.agents.helpers.envs import (
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
get_stats_random_rollout,
parallel_env_constructor,
parser_env_args,
transformed_env_constructor,
)
from torchrl.agents.helpers.losses import make_ppo_loss, parser_loss_args_ppo
from torchrl.agents.helpers.models import (
from torchrl.trainers.helpers.losses import make_ppo_loss, parser_loss_args_ppo
from torchrl.trainers.helpers.models import (
make_ppo_model,
parser_model_args_continuous,
)
from torchrl.agents.helpers.recorder import parser_recorder_args
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.trainers.helpers.recorder import parser_recorder_args
from torchrl.trainers.helpers.trainers import make_trainer, parser_trainer_args


def make_args():
@@ -47,7 +47,7 @@ def make_args():
is_config_file=True,
help="config file path",
)
parser_agent_args(parser)
parser_trainer_args(parser)
parser_collector_args_onpolicy(parser)
parser_env_args(parser)
parser_loss_args_ppo(parser)
@@ -126,8 +126,10 @@ def make_args():
if isinstance(t, RewardScaling):
t.scale.fill_(1.0)

agent = make_agent(
trainer = make_trainer(
collector, loss_module, recorder, None, actor_model, None, writer, args
)
if args.loss == "kl":
trainer.register_op("pre_optim_steps", loss_module.reset)

agent.train()
trainer.train()
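
The refactored trainer also exposes the hook mechanism used just above: register_op attaches a callback to a named event, and ppo.py uses it to reset the KL loss module before each optimisation pass. A minimal sketch of the pattern, assuming the collector, loss module, models and writer are built exactly as in the script above; the hook-point name "pre_optim_steps" and the make_trainer argument order come from this diff, while the callback body is illustrative:

from torchrl.trainers.helpers.trainers import make_trainer  # new import path

def before_optim():
    # Illustrative callback; ppo.py registers loss_module.reset here so the
    # KL penalty is re-initialised before every optimisation pass.
    print("about to run the optimisation steps for this batch")

# Wiring, assuming collector, loss_module, recorder, actor_model, writer and
# args exist as in ppo.py above:
# trainer = make_trainer(
#     collector, loss_module, recorder, None, actor_model, None, writer, args
# )
# trainer.register_op("pre_optim_steps", before_optim)
# trainer.train()
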
15 changes: 15 additions & 0 deletions examples/redq/configs/cheetah.txt
@@ -0,0 +1,15 @@
env_name=cheetah
env_task=run
env_library=dm_control
async_collection
record_video
frame_skip=4
frames_per_batch=320
optim_steps_per_batch=3200
prb
multi_step
exp_name=humanoid_stats
tanh_loc
num_workers=8
env_per_collector=8
total_frames=5000000
6 changes: 5 additions & 1 deletion examples/redq/configs/humanoid.txt
@@ -3,9 +3,13 @@ env_task=walk
env_library=dm_control
async_collection
record_video
normalize_rewards_online
frame_skip=2
frames_per_batch=320
optim_steps_per_batch=3200
prb
multi_step
exp_name=humanoid_stats
tanh_loc
num_workers=8
env_per_collector=8
total_frames=5000000