From 58281470c90a25699069814b60b0b306a9dfeb69 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Thu, 25 Apr 2024 18:54:59 +0000 Subject: [PATCH] Fix `__len__` of empty Product sweep to match actual length (#6575) --- cirq-core/cirq/study/sweeps.py | 2 -- cirq-core/cirq/study/sweeps_test.py | 5 +++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/study/sweeps.py b/cirq-core/cirq/study/sweeps.py index f075de48058..3f98798f602 100644 --- a/cirq-core/cirq/study/sweeps.py +++ b/cirq-core/cirq/study/sweeps.py @@ -236,8 +236,6 @@ def keys(self) -> List['cirq.TParamKey']: return sum((factor.keys for factor in self.factors), []) def __len__(self) -> int: - if not self.factors: - return 0 length = 1 for factor in self.factors: length *= len(factor) diff --git a/cirq-core/cirq/study/sweeps_test.py b/cirq-core/cirq/study/sweeps_test.py index aba82eacd23..83d0e0cc201 100644 --- a/cirq-core/cirq/study/sweeps_test.py +++ b/cirq-core/cirq/study/sweeps_test.py @@ -142,6 +142,11 @@ def test_product(): assert _values(sweep, 'b') == [4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7] +def test_empty_product(): + sweep = cirq.Product() + assert len(sweep) == len(list(sweep)) == 1 + + def test_slice_access_error(): sweep = cirq.Points('a', [1, 2, 3]) with pytest.raises(TypeError, match=''):