Add warnings for duplicate usage of action-bounded actor and action scaling method (thu-ml#850)

- Fix the bug discussed in thu-ml#844 in `test_ppo.py`.
- Add a warning to `ActorProb` if both `max_action` and `unbounded=True`
  are passed at model initialization.
- Add a warning to `PGPolicy` and `DDPGPolicy` when they detect duplicate
  usage of an action-bounded actor and the action scaling method.
ChenDRAG authored Apr 23, 2023
1 parent e7c2c37 commit 1423eeb
Showing 21 changed files with 56 additions and 62 deletions.
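For context, the pattern these changes converge on can be sketched as follows. This is a minimal illustration, not code from the commit; the shapes and hidden sizes are made up, and it only mirrors how the examples below call `Net` and `ActorProb`:

```python
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

state_shape, action_shape = (3,), (1,)  # illustrative shapes

net = Net(state_shape, hidden_sizes=[64, 64])
# Preferred after this change: keep the actor unbounded and let the
# policy's action_scaling map raw actions onto the environment bounds.
actor = ActorProb(net, action_shape, unbounded=True)

# Passing a non-default max_action together with unbounded=True now
# emits a warning, and the max_action value is discarded.
actor_warned = ActorProb(net, action_shape, max_action=2.0, unbounded=True)
```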
9 changes: 2 additions & 7 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -94,13 +94,8 @@ def test_sac_bipedal(args=get_args()):
 
     # model
     net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net_a,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net_a, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
 
     net_c1 = Net(
9 changes: 2 additions & 7 deletions examples/box2d/mcc_sac.py
@@ -67,13 +67,8 @@ def test_sac(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
         args.state_shape,
9 changes: 2 additions & 7 deletions examples/inverse/irl_gail.py
@@ -119,13 +119,8 @@ def test_gail(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net_a,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net_a, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     net_c = Net(
         args.state_shape,
         hidden_sizes=args.hidden_sizes,
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_a2c.py
@@ -91,7 +91,6 @@ def test_a2c(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_npg.py
@@ -93,7 +93,6 @@ def test_npg(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_ppo.py
@@ -96,7 +96,6 @@ def test_ppo(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_redq.py
@@ -86,7 +86,6 @@ def test_redq(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_reinforce.py
@@ -88,7 +88,6 @@ def test_reinforce(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_sac.py
@@ -81,7 +81,6 @@ def test_sac(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_trpo.py
@@ -96,7 +96,6 @@ def test_trpo(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
1 change: 0 additions & 1 deletion examples/offline/d4rl_cql.py
@@ -108,7 +108,6 @@ def test_cql():
     actor = ActorProb(
         net_a,
         action_shape=args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True
9 changes: 2 additions & 7 deletions test/continuous/test_npg.py
@@ -82,13 +82,8 @@ def test_npg(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(
             args.state_shape,
5 changes: 2 additions & 3 deletions test/continuous/test_ppo.py
@@ -81,9 +81,8 @@ def test_ppo(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net, args.action_shape, max_action=args.max_action, device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
         device=args.device
1 change: 0 additions & 1 deletion test/continuous/test_redq.py
@@ -82,7 +82,6 @@ def test_redq(args=get_args()):
     actor = ActorProb(
         net,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True
9 changes: 2 additions & 7 deletions test/continuous/test_sac_with_il.py
@@ -79,13 +79,8 @@ def test_sac_with_il(args=get_args()):
     torch.manual_seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
         args.state_shape,
9 changes: 2 additions & 7 deletions test/continuous/test_trpo.py
@@ -85,13 +85,8 @@ def test_trpo(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(
             args.state_shape,
1 change: 0 additions & 1 deletion test/offline/gather_pendulum_data.py
@@ -93,7 +93,6 @@ def gather_data():
     actor = ActorProb(
         net,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
     ).to(args.device)
1 change: 0 additions & 1 deletion test/offline/test_cql.py
@@ -108,7 +108,6 @@ def test_cql(args=get_args()):
     actor = ActorProb(
         net_a,
         action_shape=args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
13 changes: 13 additions & 0 deletions tianshou/policy/modelfree/ddpg.py
@@ -64,6 +64,19 @@ def __init__(
         assert action_bound_method != "tanh", "tanh mapping is not supported " \
             "in policies where the action is used as input of the critic, because " \
             "a raw action in range (-inf, inf) will cause instability in training"
+        try:
+            if actor is not None and action_scaling and \
+                    not np.isclose(actor.max_action, 1.):  # type: ignore
+                import warnings
+                warnings.warn(
+                    "action_scaling and action_bound_method are only intended to "
+                    "deal with an unbounded model action space, but the actor model "
+                    f"bounds the action space with max_action={actor.max_action}. "
+                    "Consider using the unbounded=True option of the actor model, "
+                    "or set action_scaling to False and action_bound_method to \"\"."
+                )
+        except Exception:
+            pass
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
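For illustration, the guard added above can be exercised standalone. This is a hedged sketch, not part of the commit; `check_action_bound_conflict` and `DummyActor` are hypothetical names:

```python
import warnings

import numpy as np


def check_action_bound_conflict(actor, action_scaling: bool) -> None:
    """Mirror of the guard above, extracted out of the policy class."""
    try:
        if actor is not None and action_scaling and \
                not np.isclose(actor.max_action, 1.):
            warnings.warn(
                "action_scaling is only intended for an unbounded action "
                f"space, but the actor bounds actions with max_action="
                f"{actor.max_action}."
            )
    except Exception:
        # Actors without a max_action attribute raise AttributeError here,
        # which is why the original code wraps the check in try/except.
        pass


class DummyActor:
    max_action = 2.0  # pretend this actor squashes actions to [-2, 2]


check_action_bound_conflict(DummyActor(), action_scaling=True)  # warns
check_action_bound_conflict(None, action_scaling=True)          # silent
```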
12 changes: 12 additions & 0 deletions tianshou/policy/modelfree/pg.py
@@ -53,6 +53,18 @@ def __init__(
             **kwargs
         )
         self.actor = model
+        try:
+            if action_scaling and not np.isclose(model.max_action, 1.):  # type: ignore
+                import warnings
+                warnings.warn(
+                    "action_scaling and action_bound_method are only intended to "
+                    "deal with an unbounded model action space, but the actor model "
+                    f"bounds the action space with max_action={model.max_action}. "
+                    "Consider using the unbounded=True option of the actor model, "
+                    "or set action_scaling to False and action_bound_method to \"\"."
+                )
+        except Exception:
+            pass
         self.optim = optim
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
23 changes: 17 additions & 6 deletions tianshou/utils/net/continuous.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
@@ -54,7 +55,7 @@ def __init__(
             hidden_sizes,
             device=self.device
         )
-        self._max = max_action
+        self.max_action = max_action
 
     def forward(
         self,
@@ -64,7 +65,7 @@ def forward(
     ) -> Tuple[torch.Tensor, Any]:
         """Mapping: obs -> logits -> action."""
         logits, hidden = self.preprocess(obs, state)
-        logits = self._max * torch.tanh(self.last(logits))
+        logits = self.max_action * torch.tanh(self.last(logits))
         return logits, hidden
 
 
@@ -178,6 +179,11 @@ def __init__(
         preprocess_net_output_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
+        if unbounded and not np.isclose(max_action, 1.0):
+            warnings.warn(
+                "Note that max_action input will be discarded when unbounded is True."
+            )
+            max_action = 1.0
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
@@ -198,7 +204,7 @@ def __init__(
             )
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
-        self._max = max_action
+        self.max_action = max_action
         self._unbounded = unbounded
 
     def forward(
@@ -211,7 +217,7 @@ def forward(
         logits, hidden = self.preprocess(obs, state)
         mu = self.mu(logits)
         if not self._unbounded:
-            mu = self._max * torch.tanh(mu)
+            mu = self.max_action * torch.tanh(mu)
         if self._c_sigma:
             sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
         else:
@@ -240,6 +246,11 @@ def __init__(
         conditioned_sigma: bool = False,
     ) -> None:
         super().__init__()
+        if unbounded and not np.isclose(max_action, 1.0):
+            warnings.warn(
+                "Note that max_action input will be discarded when unbounded is True."
+            )
+            max_action = 1.0
         self.device = device
         self.nn = nn.LSTM(
             input_size=int(np.prod(state_shape)),
@@ -254,7 +265,7 @@ def __init__(
             self.sigma = nn.Linear(hidden_layer_size, output_dim)
         else:
             self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
-        self._max = max_action
+        self.max_action = max_action
         self._unbounded = unbounded
 
     def forward(
@@ -289,7 +300,7 @@ def forward(
         logits = obs[:, -1]
         mu = self.mu(logits)
         if not self._unbounded:
-            mu = self._max * torch.tanh(mu)
+            mu = self.max_action * torch.tanh(mu)
         if self._c_sigma:
             sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
         else:
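As a usage note, the new `ActorProb` behavior can be checked directly. A minimal sketch with made-up shapes, assuming the imports used throughout these examples:

```python
import warnings

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

net = Net((3,), hidden_sizes=[64, 64])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    actor = ActorProb(net, (1,), max_action=2.0, unbounded=True)

# The warning fired and max_action was reset to 1.0 internally.
assert any("discarded" in str(w.message) for w in caught)
assert actor.max_action == 1.0
```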
