Skip to content

Commit

Permalink
[Feature] Non-functional objectives (PPO, A2C, Reinforce) (pytorch#1804)
Browse files (browse the repository at this point in the history)
  • Loading branch information
vmoens authored Jan 23, 2024
1 parent 6769fee commit 5b67dd3
Show file tree
Hide file tree
Showing 16 changed files with 433 additions and 157 deletions.
6 changes: 3 additions & 3 deletions benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def test_a2c_speed(
actor(td.clone())
critic(td.clone())

loss = A2CLoss(actor=actor, critic=critic)
loss = A2CLoss(actor_network=actor, critic_network=critic)
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)
Expand Down Expand Up @@ -605,7 +605,7 @@ def test_ppo_speed(
actor(td.clone())
critic(td.clone())

loss = ClipPPOLoss(actor=actor, critic=critic)
loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)
Expand Down Expand Up @@ -662,7 +662,7 @@ def test_reinforce_speed(
actor(td.clone())
critic(td.clone())

loss = ReinforceLoss(actor=actor, critic=critic)
loss = ReinforceLoss(actor_network=actor, critic_network=critic)
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)
Expand Down
4 changes: 2 additions & 2 deletions examples/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821
average_gae=True,
)
loss_module = A2CLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
Expand Down
4 changes: 2 additions & 2 deletions examples/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821
average_gae=False,
)
loss_module = A2CLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
)
loss_module = ClipPPOLoss(
actor=policy_module,
critic=value_module,
critic_network=value_module,
advantage_key="advantage",
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
Expand Down
4 changes: 2 additions & 2 deletions examples/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def main(cfg: "DictConfig"): # noqa: F821
average_adv=False,
)
loss_module = A2CLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
Expand Down
4 changes: 2 additions & 2 deletions examples/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def main(cfg: "DictConfig"): # noqa: F821
average_adv=False,
)
loss_module = A2CLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
Expand Down
4 changes: 2 additions & 2 deletions examples/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def main(cfg: "DictConfig"): # noqa: F821
average_adv=False,
)
loss_module = A2CLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
Expand Down
6 changes: 3 additions & 3 deletions examples/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def train(cfg: "DictConfig"): # noqa: F821

# Loss
loss_module = ClipPPOLoss(
actor=policy,
critic=value_module,
actor_network=policy,
critic_network=value_module,
clip_epsilon=cfg.loss.clip_epsilon,
entropy_coef=cfg.loss.entropy_eps,
normalize_advantage=False,
Expand Down Expand Up @@ -174,7 +174,7 @@ def train(cfg: "DictConfig"): # noqa: F821
with torch.no_grad():
loss_module.value_estimator(
tensordict_data,
params=loss_module.critic_params,
params=loss_module.critic_network_params,
target_params=loss_module.target_critic_params,
)
current_frames = tensordict_data.numel()
Expand Down
4 changes: 2 additions & 2 deletions examples/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821
average_gae=False,
)
loss_module = ClipPPOLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
clip_epsilon=cfg.loss.clip_epsilon,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
Expand Down
4 changes: 2 additions & 2 deletions examples/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821
)

loss_module = ClipPPOLoss(
actor=actor,
critic=critic,
actor_network=actor,
critic_network=critic,
clip_epsilon=cfg.loss.clip_epsilon,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
Expand Down
49 changes: 30 additions & 19 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5820,7 +5820,10 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est):
@pytest.mark.parametrize("functional", [True, False])
def test_ppo(
self, loss_class, device, gradient_mode, advantage, td_est, functional
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

Expand Down Expand Up @@ -5850,7 +5853,7 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est):
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional)
if advantage is not None:
advantage(td)
else:
Expand Down Expand Up @@ -6328,7 +6331,7 @@ def test_ppo_notensordict(
)
value = self._create_mock_value(observation_key=observation_key)

loss = loss_class(actor=actor, critic=value)
loss = loss_class(actor_network=actor, critic_network=value)
loss.set_keys(
action=action_key,
reward=reward_key,
Expand Down Expand Up @@ -6537,7 +6540,8 @@ def _create_seq_mock_data_a2c(
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_a2c(self, device, gradient_mode, advantage, td_est):
@pytest.mark.parametrize("functional", (True, False))
def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)

Expand Down Expand Up @@ -6567,7 +6571,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)

# Check error is raised when actions require grads
td["action"].requires_grad = True
Expand Down Expand Up @@ -6629,7 +6633,9 @@ def test_a2c_state_dict(self, device, gradient_mode):
def test_a2c_separate_losses(self, separate_losses):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = A2CLoss(actor=actor, critic=critic, separate_losses=separate_losses)
loss_fn = A2CLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
)

# Check error is raised when actions require grads
td["action"].requires_grad = True
Expand Down Expand Up @@ -6966,7 +6972,6 @@ def test_a2c_notensordict(
class TestReinforce(LossModuleTestBase):
seed = 0

@pytest.mark.parametrize("delay_value", [True, False])
@pytest.mark.parametrize("gradient_mode", [True, False])
@pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
@pytest.mark.parametrize(
Expand All @@ -6979,7 +6984,12 @@ class TestReinforce(LossModuleTestBase):
None,
],
)
def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est):
@pytest.mark.parametrize(
"delay_value,functional", [[False, True], [False, False], [True, True]]
)
def test_reinforce_value_net(
self, advantage, gradient_mode, delay_value, td_est, functional
):
n_obs = 3
n_act = 5
batch = 4
Expand Down Expand Up @@ -7023,8 +7033,9 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est

loss_fn = ReinforceLoss(
actor_net,
critic=value_net,
critic_network=value_net,
delay_value=delay_value,
functional=functional,
)

td = TensorDict(
Expand All @@ -7049,7 +7060,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
if advantage is not None:
params = TensorDict.from_module(value_net)
if delay_value:
target_params = loss_fn.target_critic_params
target_params = loss_fn.target_critic_network_params
else:
target_params = None
advantage(td, params=params, target_params=target_params)
Expand Down Expand Up @@ -7108,7 +7119,7 @@ def test_reinforce_tensordict_keys(self, td_est):

loss_fn = ReinforceLoss(
actor_net,
critic=value_net,
critic_network=value_net,
)

default_keys = {
Expand All @@ -7133,7 +7144,7 @@ def test_reinforce_tensordict_keys(self, td_est):

loss_fn = ReinforceLoss(
actor_net,
critic=value_net,
critic_network=value_net,
)

key_mapping = {
Expand Down Expand Up @@ -7207,14 +7218,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = ReinforceLoss(
actor=actor, critic=critic, separate_losses=separate_losses
actor_network=actor, critic_network=critic, separate_losses=separate_losses
)

loss = loss_fn(td)

assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.critic_params.values(True, True)
for p in loss_fn.critic_network_params.values(True, True)
)
assert all(
(p.grad is None) or (p.grad == 0).all()
Expand All @@ -7234,14 +7245,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
for p in loss_fn.actor_network_params.values(True, True)
)
common_layers = itertools.islice(
loss_fn.critic_params.values(True, True),
loss_fn.critic_network_params.values(True, True),
common_layers_no,
)
assert all(
(p.grad is None) or (p.grad == 0).all() for p in common_layers
)
critic_layers = itertools.islice(
loss_fn.critic_params.values(True, True),
loss_fn.critic_network_params.values(True, True),
common_layers_no,
None,
)
Expand All @@ -7250,7 +7261,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
)
else:
common_layers = itertools.islice(
loss_fn.critic_params.values(True, True),
loss_fn.critic_network_params.values(True, True),
common_layers_no,
)
assert not any(
Expand All @@ -7266,7 +7277,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
)
assert not any(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.critic_params.values(True, True)
for p in loss_fn.critic_network_params.values(True, True)
)

else:
Expand Down Expand Up @@ -7297,7 +7308,7 @@ def test_reinforce_notensordict(
in_keys=["loc", "scale"],
spec=UnboundedContinuousTensorSpec(n_act),
)
loss = ReinforceLoss(actor=actor_net, critic=value_net)
loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net)
loss.set_keys(
reward=reward_key,
done=done_key,
Expand Down
Loading

0 comments on commit 5b67dd3

Please sign in to comment.