diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index e603caeda02..08e543eae03 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -116,6 +116,7 @@ def value_of( """ # Input is a pass through type, no resolution needed: return early + # print(value, type(value), recursive) v = _resolve_value(value) if v is not NotImplemented: return v diff --git a/cirq-core/cirq/value/duration.py b/cirq-core/cirq/value/duration.py index e9e6a1d84b5..fe9b1229019 100644 --- a/cirq-core/cirq/value/duration.py +++ b/cirq-core/cirq/value/duration.py @@ -13,14 +13,14 @@ # limitations under the License. """A typed time delta that supports picosecond accuracy.""" -from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union, List import datetime import sympy import numpy as np from cirq import protocols -from cirq._compat import proper_repr +from cirq._compat import proper_repr, cached_method from cirq._doc import document if TYPE_CHECKING: @@ -79,48 +79,53 @@ def __init__( >>> print(cirq.Duration(micros=1.5 * sympy.Symbol('t'))) (1500.0*t) ns """ + self._time_vals = [0, 0, 0, 0] + self._multipliers = [1, 1000, 1000_000, 1000_000_000] if value is not None and value != 0: if isinstance(value, datetime.timedelta): # timedelta has microsecond resolution. - micros += int(value / datetime.timedelta(microseconds=1)) + self._time_vals[2] = int(value / datetime.timedelta(microseconds=1)) elif isinstance(value, Duration): - picos += value._picos + self._time_vals = value._time_vals else: raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.') - - val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000 - self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val + input_vals = [picos, nanos, micros, millis] + self._time_vals = _add_time_vals(self._time_vals, input_vals) def _is_parameterized_(self) -> bool: - return protocols.is_parameterized(self._picos) + return protocols.is_parameterized(self._time_vals) def _parameter_names_(self) -> AbstractSet[str]: - return protocols.parameter_names(self._picos) + return protocols.parameter_names(self._time_vals) def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Duration': - return Duration(picos=protocols.resolve_parameters(self._picos, resolver, recursive)) + return _duration_from_time_vals( + protocols.resolve_parameters(self._time_vals, resolver, recursive) + ) + @cached_method def total_picos(self) -> _NUMERIC_OUTPUT_TYPE: """Returns the number of picoseconds that the duration spans.""" - return self._picos + val = sum(a * b for a, b in zip(self._time_vals, self._multipliers)) + return float(val) if isinstance(val, np.number) else val def total_nanos(self) -> _NUMERIC_OUTPUT_TYPE: """Returns the number of nanoseconds that the duration spans.""" - return self._picos / 1000 + return self.total_picos() / 1000 def total_micros(self) -> _NUMERIC_OUTPUT_TYPE: """Returns the number of microseconds that the duration spans.""" - return self._picos / 1000_000 + return self.total_picos() / 1000_000 def total_millis(self) -> _NUMERIC_OUTPUT_TYPE: """Returns the number of milliseconds that the duration spans.""" - return self._picos / 1000_000_000 + return self.total_picos() / 1000_000_000 def __add__(self, other) -> 'Duration': other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return Duration(picos=self._picos + other._picos) + return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals)) def __radd__(self, other) -> 'Duration': return self.__add__(other) @@ -129,29 +134,36 @@ def __sub__(self, other) -> 'Duration': other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return Duration(picos=self._picos - other._picos) + return _duration_from_time_vals( + _add_time_vals(self._time_vals, [-x for x in other._time_vals]) + ) def __rsub__(self, other) -> 'Duration': other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return Duration(picos=other._picos - self._picos) + return _duration_from_time_vals( + _add_time_vals(other._time_vals, [-x for x in self._time_vals]) + ) def __mul__(self, other) -> 'Duration': if not isinstance(other, (int, float, sympy.Expr)): return NotImplemented - return Duration(picos=self._picos * other) + if other == 0: + return _duration_from_time_vals([0] * 4) + return _duration_from_time_vals([x * other for x in self._time_vals]) def __rmul__(self, other) -> 'Duration': return self.__mul__(other) def __truediv__(self, other) -> Union['Duration', float]: if isinstance(other, (int, float, sympy.Expr)): - return Duration(picos=self._picos / other) + new_time_vals = [x / other for x in self._time_vals] + return _duration_from_time_vals(new_time_vals) other_duration = _attempt_duration_like_to_duration(other) if other_duration is not None: - return self._picos / other_duration._picos + return self.total_picos() / other_duration.total_picos() return NotImplemented @@ -159,56 +171,57 @@ def __eq__(self, other): other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return self._picos == other._picos + return self.total_picos() == other.total_picos() def __ne__(self, other): other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return self._picos != other._picos + return self.total_picos() != other.total_picos() def __gt__(self, other): other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return self._picos > other._picos + return self.total_picos() > other.total_picos() def __lt__(self, other): other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return self._picos < other._picos + return self.total_picos() < other.total_picos() def __ge__(self, other): other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return self._picos >= other._picos + return self.total_picos() >= other.total_picos() def __le__(self, other): other = _attempt_duration_like_to_duration(other) if other is None: return NotImplemented - return self._picos <= other._picos + return self.total_picos() <= other.total_picos() def __bool__(self): - return bool(self._picos) + return bool(self.total_picos()) def __hash__(self): - if isinstance(self._picos, (int, float)) and self._picos % 1000000 == 0: - return hash(datetime.timedelta(microseconds=self._picos / 1000000)) - return hash((Duration, self._picos)) + if isinstance(self.total_picos(), (int, float)) and self.total_picos() % 1000000 == 0: + return hash(datetime.timedelta(microseconds=self.total_picos() / 1000000)) + return hash((Duration, self.total_picos())) def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]: + picos = self.total_picos() if ( - isinstance(self._picos, sympy.Mul) - and len(self._picos.args) == 2 - and isinstance(self._picos.args[0], (sympy.Integer, sympy.Float)) + isinstance(picos, sympy.Mul) + and len(picos.args) == 2 + and isinstance(picos.args[0], (sympy.Integer, sympy.Float)) ): - scale = self._picos.args[0] - rest = self._picos.args[1] + scale = picos.args[0] + rest = picos.args[1] else: - scale = self._picos + scale = picos rest = 1 if scale % 1000_000_000 == 0: @@ -234,7 +247,7 @@ def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]: return amount * rest, unit, suffix def __str__(self) -> str: - if self._picos == 0: + if self.total_picos() == 0: return 'Duration(0)' amount, _, suffix = self._decompose_into_amount_unit_suffix() if not isinstance(amount, (int, float, sympy.Symbol)): @@ -257,3 +270,19 @@ def _attempt_duration_like_to_duration(value: Any) -> Optional[Duration]: if isinstance(value, (int, float)) and value == 0: return Duration() return None + + +def _add_time_vals(val1: List[_NUMERIC_INPUT_TYPE], val2: List[_NUMERIC_INPUT_TYPE]): + ret = [] + for i in range(4): + if val1[i] and val2[i]: + ret.append(val1[i] + val2[i]) + else: + ret.append(val1[i] or val2[i]) + return ret + + +def _duration_from_time_vals(time_vals: List[_NUMERIC_INPUT_TYPE]): + ret = Duration() + ret._time_vals = time_vals + return ret diff --git a/cirq-core/cirq/value/duration_test.py b/cirq-core/cirq/value/duration_test.py index 52dd80c4c86..2e99f28ba80 100644 --- a/cirq-core/cirq/value/duration_test.py +++ b/cirq-core/cirq/value/duration_test.py @@ -168,9 +168,11 @@ def test_sub(): def test_mul(): assert Duration(picos=2) * 3 == Duration(picos=6) assert 4 * Duration(picos=3) == Duration(picos=12) + assert 0 * Duration(picos=10) == Duration() t = sympy.Symbol('t') assert t * Duration(picos=3) == Duration(picos=3 * t) + assert 0 * Duration(picos=t) == Duration(picos=0) with pytest.raises(TypeError): _ = Duration() * Duration()