Skip to content

Commit

Permalink
[BugFix] Fix OOB sampling in PrioritizedSliceSampler (pytorch#2239)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 20, 2024
1 parent c44a521 commit eb35793
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
3 changes: 2 additions & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def main(cfg: "DictConfig"): # noqa: F821

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
eval_env.close()
if not eval_env.is_closed:
eval_env.close()


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,10 @@ def main(cfg: "DictConfig"): # noqa: F821
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")

collector.shutdown()
eval_env.close()
train_env.close()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()


if __name__ == "__main__":
Expand Down
9 changes: 9 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
index = index.unsqueeze(0)
index.clamp_max_(len(storage) - 1)
weight = torch.as_tensor(self._sum_tree[index])
# get indices where weight is 0
zero_weight = weight == 0
index = index
while zero_weight.any():
index = torch.where(zero_weight, index - 1, index)
if (index < 0).any():
raise RuntimeError("Failed to find a suitable index")
zero_weight = torch.as_tensor(self._sum_tree[index])
zero_weight = weight == 0

# Importance sampling weight formula:
# w_i = (p_i / sum(p) * N) ^ (-beta)
Expand Down

0 comments on commit eb35793

Please sign in to comment.