Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI] Resolve DMC and mujoco pinned versions #2396

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
amend
  • Loading branch information
vmoens committed Aug 13, 2024
commit 6bc10470e9587a61303b25f7a8177b0c57de6f35
9 changes: 3 additions & 6 deletions sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ def main(cfg: "DictConfig"): # noqa: F821
device = "cpu"
device = torch.device(device)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create env
train_env, eval_env = make_environment(
cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env
Expand Down Expand Up @@ -107,9 +107,6 @@ def main(cfg: "DictConfig"): # noqa: F821

q_loss = q_loss + cql_loss

alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]

# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]
Expand Down
13 changes: 7 additions & 6 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10423,17 +10423,18 @@ def test_transform_no_env(self, batch):
reason="EndOfLifeTransform can only be tested when Gym is present.",
)
class TestEndOfLife(TransformBase):
pytest.mark.filterwarnings("ignore:The base_env is not a gym env")

def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
def make():
with set_gym_backend("gymnasium"):
return GymEnv(BREAKOUT_VERSIONED())

with pytest.warns(UserWarning, match="The base_env is not a gym env"):
with pytest.raises(AttributeError):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform()
)
check_env_specs(env)
with pytest.raises(AttributeError):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform()
)
check_env_specs(env)

def test_trans_serial_env_check(self):
def make():
Expand Down
Loading