diff --git a/test/test_specs.py b/test/test_specs.py index 83a688c7ed5..6b779811f1d 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3729,6 +3729,7 @@ def test_device_ordinal(): moneh=moneh, mdisc=mdisc, shape=(-1, 1, 2), + device=device, ) assert spec.device == torch.device("cuda:0")