Skip to content

Commit

Permalink
Skip ParallelEnv if only one env is created in helpers (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 17, 2022
1 parent b7d91f4 commit ab0b00c
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 11 deletions.
10 changes: 8 additions & 2 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import uuid
from datetime import datetime

from torchrl.envs import ParallelEnv

try:
import configargparse as argparse

Expand Down Expand Up @@ -144,8 +146,12 @@ def main(args):
recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
else:
recorder_rm = recorder
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
if isinstance(create_env_fn, ParallelEnv):
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
else:
recorder_rm.load_state_dict(create_env_fn.state_dict())

# reset reward scaling
for t in recorder.transform:
if isinstance(t, RewardScaling):
Expand Down
9 changes: 7 additions & 2 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import uuid
from datetime import datetime

from torchrl.envs import ParallelEnv

try:
import configargparse as argparse

Expand Down Expand Up @@ -130,8 +132,11 @@ def main(args):
else:
recorder_rm = recorder

recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
if isinstance(create_env_fn, ParallelEnv):
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
else:
recorder_rm.load_state_dict(create_env_fn.state_dict())
# reset reward scaling
for t in recorder.transform:
if isinstance(t, RewardScaling):
Expand Down
9 changes: 7 additions & 2 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import uuid
from datetime import datetime

from torchrl.envs import ParallelEnv

try:
import configargparse as argparse

Expand Down Expand Up @@ -117,8 +119,11 @@ def main(args):
else:
recorder_rm = recorder

recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
if isinstance(create_env_fn, ParallelEnv):
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
else:
recorder_rm.load_state_dict(create_env_fn.state_dict())

# reset reward scaling
for t in recorder.transform:
Expand Down
9 changes: 7 additions & 2 deletions examples/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import uuid
from datetime import datetime

from torchrl.envs import ParallelEnv

try:
import configargparse as argparse

Expand Down Expand Up @@ -146,8 +148,11 @@ def main(args):
else:
recorder_rm = recorder

recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
if isinstance(create_env_fn, ParallelEnv):
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
else:
recorder_rm.load_state_dict(create_env_fn.state_dict())

# reset reward scaling
for t in recorder.transform:
Expand Down
9 changes: 7 additions & 2 deletions examples/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import uuid
from datetime import datetime

from torchrl.envs import ParallelEnv

try:
import configargparse as argparse

Expand Down Expand Up @@ -145,8 +147,11 @@ def main(args):
else:
recorder_rm = recorder

recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
if isinstance(create_env_fn, ParallelEnv):
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
else:
recorder_rm.load_state_dict(create_env_fn.state_dict())

# reset reward scaling, as it was just overwritten by state_dict load
for t in recorder.transform:
Expand Down
8 changes: 7 additions & 1 deletion torchrl/trainers/helpers/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,19 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
return make_transformed_env


def parallel_env_constructor(args: Namespace, **kwargs) -> EnvCreator:
def parallel_env_constructor(
args: Namespace, **kwargs
) -> Union[ParallelEnv, EnvCreator]:
"""Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
Args:
args (argparse.Namespace): script arguments originating from the parser built with parser_env_args
kwargs: keyword arguments for the `transformed_env_constructor` method.
"""
if args.env_per_collector == 1:
kwargs.update({"args": args, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(**kwargs)
return make_transformed_env
kwargs.update({"args": args, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(**kwargs)
env = ParallelEnv(
Expand Down

0 comments on commit ab0b00c

Please sign in to comment.