diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index ca587014653..35dd4c780fd 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -482,7 +482,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: index = torch.where(zero_weight, index - 1, index) if (index < 0).any(): raise RuntimeError("Failed to find a suitable index") - zero_weight = torch.as_tensor(self._sum_tree[index]) + weight = torch.as_tensor(self._sum_tree[index]) zero_weight = weight == 0 # Importance sampling weight formula: