Skip to content

Commit

Permalink
[BugFix] Fix flaky rb tests (pytorch#1901)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 12, 2024
1 parent 1bd5ec6 commit 1647fa4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,8 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
def test_storage_dumps_loads(
self, device_data, storage_type, data_type, isinit, tmpdir
):
torch.manual_seed(0)

dir_rb = tmpdir / "rb"
dir_save = tmpdir / "save"
dir_rb.mkdir()
Expand Down Expand Up @@ -716,25 +718,30 @@ class TC:
)
else:
raise NotImplementedError

if storage_type in (LazyMemmapStorage,):
storage = storage_type(max_size=10, scratch_dir=dir_rb)
else:
storage = storage_type(max_size=10)

# We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index
if data_type == "pytree":
storage.set(range(3), tree_map(lambda x: x.cpu(), data))
else:
storage.set(range(3), data.cpu())

storage.dumps(dir_save)
# check we can dump twice
storage.dumps(dir_save)

storage_recover = storage_type(max_size=10)
if isinit:
if data_type == "pytree":
storage_recover.set(range(3), tree_map(lambda x: x.cpu().zero_(), data))
storage_recover.set(
range(3), tree_map(lambda x: x.cpu().clone().zero_(), data)
)
else:
storage_recover.set(range(3), data.cpu().zero_())
storage_recover.set(range(3), data.cpu().clone().zero_())

if data_type in ("tensor", "pytree") and not isinit:
with pytest.raises(
Expand Down

0 comments on commit 1647fa4

Please sign in to comment.