diff --git a/test/test_libs.py b/test/test_libs.py index 42138b4ad9b..6ccbf2788a9 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2143,7 +2143,6 @@ def test_vmas_seeding(self, scenario_name): @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) def test_vmas_batch_size_error(self, scenario_name, batch_size): num_envs = 12 - n_agents = 2 if len(batch_size) > 1: with pytest.raises( TypeError, @@ -2152,7 +2151,6 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size): _ = VmasEnv( scenario=scenario_name, num_envs=num_envs, - n_agents=n_agents, batch_size=batch_size, ) elif len(batch_size) == 1 and batch_size != (num_envs,): @@ -2163,14 +2161,12 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size): _ = VmasEnv( scenario=scenario_name, num_envs=num_envs, - n_agents=n_agents, batch_size=batch_size, ) else: env = VmasEnv( scenario=scenario_name, num_envs=num_envs, - n_agents=n_agents, batch_size=batch_size, ) env.close() @@ -2252,19 +2248,11 @@ def test_vmas_spec_rollout( env.close() @pytest.mark.parametrize("num_envs", [1, 20]) - @pytest.mark.parametrize("n_agents", [1, 5]) @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) - def test_vmas_repr(self, scenario_name, num_envs, n_agents): - if ( - n_agents == 1 - and scenario_name == "balance" - or scenario_name == "simple_adversary" - ): - return + def test_vmas_repr(self, scenario_name, num_envs): env = VmasEnv( scenario=scenario_name, num_envs=num_envs, - n_agents=n_agents, ) assert str(env) == ( f"{VmasEnv.__name__}(num_envs={num_envs}, n_agents={env.n_agents},"