Skip to content

Commit

Permalink
[Tests] Fix VMAS tests (pytorch#2287)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Jul 10, 2024
1 parent d0fa836 commit dcd332d
Showing 1 changed file with 1 addition and 13 deletions.
14 changes: 1 addition & 13 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,7 +2143,6 @@ def test_vmas_seeding(self, scenario_name):
@pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs)
def test_vmas_batch_size_error(self, scenario_name, batch_size):
num_envs = 12
n_agents = 2
if len(batch_size) > 1:
with pytest.raises(
TypeError,
Expand All @@ -2152,7 +2151,6 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size):
_ = VmasEnv(
scenario=scenario_name,
num_envs=num_envs,
n_agents=n_agents,
batch_size=batch_size,
)
elif len(batch_size) == 1 and batch_size != (num_envs,):
Expand All @@ -2163,14 +2161,12 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size):
_ = VmasEnv(
scenario=scenario_name,
num_envs=num_envs,
n_agents=n_agents,
batch_size=batch_size,
)
else:
env = VmasEnv(
scenario=scenario_name,
num_envs=num_envs,
n_agents=n_agents,
batch_size=batch_size,
)
env.close()
Expand Down Expand Up @@ -2252,19 +2248,11 @@ def test_vmas_spec_rollout(
env.close()

@pytest.mark.parametrize("num_envs", [1, 20])
@pytest.mark.parametrize("n_agents", [1, 5])
@pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs)
def test_vmas_repr(self, scenario_name, num_envs, n_agents):
if (
n_agents == 1
and scenario_name == "balance"
or scenario_name == "simple_adversary"
):
return
def test_vmas_repr(self, scenario_name, num_envs):
env = VmasEnv(
scenario=scenario_name,
num_envs=num_envs,
n_agents=n_agents,
)
assert str(env) == (
f"{VmasEnv.__name__}(num_envs={num_envs}, n_agents={env.n_agents},"
Expand Down

0 comments on commit dcd332d

Please sign in to comment.