diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 3d3aade6008..206a94611ef 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -507,6 +507,9 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, Condition, Duration, DURATION_LIKE, @@ -515,6 +518,7 @@ LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, + MeasurementType, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 8ad99b0b002..34daf9e10e3 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -91,7 +91,7 @@ def _create_partial_act_on_args( self, initial_state: Union[int, 'MPSState'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> 'MPSState': """Creates MPSState args for simulating the Circuit. @@ -101,7 +101,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: A mutable object that measurements are recorded into. + classical_data: The shared classical data container for this + simulation. Returns: MPSState args for simulating the Circuit. @@ -115,7 +116,7 @@ def _create_partial_act_on_args( simulation_options=self.simulation_options, grouping=self.grouping, initial_state=initial_state, - log_of_measurement_results=logs, + classical_data=classical_data, ) def _create_step_result( @@ -229,6 +230,7 @@ def __init__( grouping: Optional[Dict['cirq.Qid', int]] = None, initial_state: int = 0, log_of_measurement_results: Dict[str, Any] = None, + classical_data: 'cirq.ClassicalDataStore' = None, ): """Creates and MPSState @@ -242,11 +244,18 @@ def __init__( initial_state: An integer representing the initial state. log_of_measurement_results: A mutable object that measurements are being recorded into. + classical_data: The shared classical data container for this + simulation. Raises: ValueError: If the grouping does not cover the qubits. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) qubit_map = self.qubit_map self.grouping = qubit_map if grouping is None else grouping if self.grouping.keys() != self.qubit_map.keys(): diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 1025f05e944..53b778d9fb3 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -550,10 +550,10 @@ def test_state_act_on_args_initializer(): s = ccq.mps_simulator.MPSState( qubits=(cirq.LineQubit(0),), prng=np.random.RandomState(0), - log_of_measurement_results={'test': 4}, + log_of_measurement_results={'test': [4]}, ) assert s.qubits == (cirq.LineQubit(0),) - assert s.log_of_measurement_results == {'test': 4} + assert s.log_of_measurement_results == {'test': [4]} def test_act_on_gate(): diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 19fc564df2a..8218e19e0e9 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -65,6 +65,7 @@ def _parallel_gate_op(gate, qubits): 'Circuit': cirq.Circuit, 'CircuitOperation': cirq.CircuitOperation, 'ClassicallyControlledOperation': cirq.ClassicallyControlledOperation, + 'ClassicalDataDictionaryStore': cirq.ClassicalDataDictionaryStore, 'CliffordState': cirq.CliffordState, 'CliffordTableau': cirq.CliffordTableau, 'CNotPowGate': cirq.CNotPowGate, @@ -107,6 +108,7 @@ def _parallel_gate_op(gate, qubits): 'MixedUnitaryChannel': cirq.MixedUnitaryChannel, 'MeasurementKey': cirq.MeasurementKey, 'MeasurementGate': cirq.MeasurementGate, + 'MeasurementType': cirq.MeasurementType, '_MeasurementSpec': cirq.work._MeasurementSpec, 'Moment': cirq.Moment, 'MutableDensePauliString': cirq.MutableDensePauliString, diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 3b8d051fb54..10fa65977f2 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -148,7 +148,6 @@ def _circuit_diagram_info_( sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None) if sub_info is None: return NotImplemented # coverage: ignore - control_count = len({k for c in self._conditions for k in c.keys}) wire_symbols = sub_info.wire_symbols + ('^',) * control_count if any(not isinstance(c, value.KeyCondition) for c in self._conditions): @@ -176,7 +175,7 @@ def _json_dict_(self) -> Dict[str, Any]: } def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - if all(c.resolve(args.log_of_measurement_results) for c in self._conditions): + if all(c.resolve(args.classical_data) for c in self._conditions): protocols.act_on(self._sub_operation, args) return True diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 40b8e98f836..e3cf0f7b1c0 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import numpy as np import pytest import sympy from sympy.parsing import sympy_parser @@ -702,6 +704,40 @@ def test_sympy(): assert result.measurements['m_result'][0][0] == (j > i) +def test_sympy_qudits(): + q0 = cirq.LineQid(0, 3) + q1 = cirq.LineQid(1, 5) + q_result = cirq.LineQubit(2) + + class PlusGate(cirq.Gate): + def __init__(self, dimension, increment=1): + self.dimension = dimension + self.increment = increment % dimension + + def _qid_shape_(self): + return (self.dimension,) + + def _unitary_(self): + inc = (self.increment - 1) % self.dimension + 1 + u = np.empty((self.dimension, self.dimension)) + u[inc:] = np.eye(self.dimension)[:-inc] + u[:inc] = np.eye(self.dimension)[-inc:] + return u + + for i in range(15): + digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=(3, 5)) + circuit = cirq.Circuit( + PlusGate(3, digits[0]).on(q0), + PlusGate(5, digits[1]).on(q1), + cirq.measure(q0, q1, key='m'), + cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m % 4 <= 1')), + cirq.measure(q_result, key='m_result'), + ) + + result = cirq.Simulator().run(circuit) + assert result.measurements['m_result'][0][0] == (i % 4 <= 1) + + def test_sympy_path_prefix(): q = cirq.LineQubit(0) op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json new file mode 100644 index 00000000000..d5c51d5839c --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.json @@ -0,0 +1,60 @@ +{ + "cirq_type": "ClassicalDataDictionaryStore", + "measurements": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + [0, 1] + ] + ], + "measured_qubits": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + ] + ], + "channel_measurements": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "c", + "path": [] + }, + 3 + ] + ], + "measurement_types": [ + [ + { + "cirq_type": "MeasurementKey", + "name": "m", + "path": [] + }, + 1 + ], + [ + { + "cirq_type": "MeasurementKey", + "name": "c", + "path": [] + }, + 2 + ] + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr new file mode 100644 index 00000000000..c19b8190bfb --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ClassicalDataDictionaryStore.repr @@ -0,0 +1 @@ +cirq.ClassicalDataDictionaryStore(_measurements={cirq.MeasurementKey('m'): [0, 1]}, _measured_qubits={cirq.MeasurementKey('m'): [cirq.LineQubit(0), cirq.LineQubit(1)]}, _channel_measurements={cirq.MeasurementKey('c'): 3}, _measurement_types={cirq.MeasurementKey('m'): cirq.MeasurementType.MEASUREMENT, cirq.MeasurementKey('c'): cirq.MeasurementType.CHANNEL}) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/MeasurementType.json b/cirq-core/cirq/protocols/json_test_data/MeasurementType.json new file mode 100644 index 00000000000..fd8ef095787 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/MeasurementType.json @@ -0,0 +1 @@ +[1, 2] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr b/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr new file mode 100644 index 00000000000..edeebfddc51 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/MeasurementType.repr @@ -0,0 +1 @@ +[cirq.MeasurementType.MEASUREMENT, cirq.MeasurementType.CHANNEL] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 333540b62ca..20b6fa29695 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -20,9 +20,6 @@ from cirq import value from cirq._doc import doc_private -if TYPE_CHECKING: - import cirq - if TYPE_CHECKING: import cirq diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index f8d1d955b8d..8694067131c 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -31,7 +31,7 @@ import numpy as np -from cirq import protocols, ops +from cirq import ops, protocols, value from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.sim.operation_target import OperationTarget @@ -50,6 +50,7 @@ def __init__( qubits: Optional[Sequence['cirq.Qid']] = None, log_of_measurement_results: Optional[Dict[str, List[int]]] = None, ignore_measurement_results: bool = False, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnArgs. @@ -65,16 +66,21 @@ def __init__( will treat measurement as dephasing instead of collapsing process, and not log the result. This is only applicable to simulators that can represent mixed states. + classical_data: The shared classical data container for this + simulation. """ if prng is None: prng = cast(np.random.RandomState, np.random) if qubits is None: qubits = () - if log_of_measurement_results is None: - log_of_measurement_results = {} self._set_qubits(qubits) self.prng = prng - self._log_of_measurement_results = log_of_measurement_results + self._classical_data = classical_data or value.ClassicalDataDictionaryStore( + _measurements={ + value.MeasurementKey.parse_serialized(k): tuple(v) + for k, v in (log_of_measurement_results or {}).items() + } + ) self._ignore_measurement_results = ignore_measurement_results def _set_qubits(self, qubits: Sequence['cirq.Qid']): @@ -103,9 +109,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[ return bits = self._perform_measurement(qubits) corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] - if key in self._log_of_measurement_results: - raise ValueError(f"Measurement already logged to key {key!r}") - self._log_of_measurement_results[key] = corrected + self._classical_data.record_measurement( + value.MeasurementKey.parse_serialized(key), corrected, qubits + ) def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: return [self.qubit_map[q] for q in qubits] @@ -138,7 +144,7 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: DeprecationWarning, ) self._on_copy(args) - args._log_of_measurement_results = self.log_of_measurement_results.copy() + args._classical_data = self._classical_data.copy() return args def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True): @@ -236,8 +242,8 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ functionality, if supported.""" @property - def log_of_measurement_results(self) -> Dict[str, List[int]]: - return self._log_of_measurement_results + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + return self._classical_data @property def ignore_measurement_results(self) -> bool: diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index d5ee3d6dbc4..b00960b2b49 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -12,25 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import abc import inspect +import warnings +from collections import abc from typing import ( Dict, - TYPE_CHECKING, Generic, - Sequence, - Optional, Iterator, - Any, - Tuple, List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, Union, ) -import warnings import numpy as np -from cirq import ops, protocols +from cirq import ops, protocols, value from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( TActOnArgs, @@ -52,7 +51,8 @@ def __init__( args: Dict[Optional['cirq.Qid'], TActOnArgs], qubits: Sequence['cirq.Qid'], split_untangled_states: bool, - log_of_measurement_results: Dict[str, Any], + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Initializes the class. @@ -65,11 +65,18 @@ def __init__( at the end. log_of_measurement_results: A mutable object that measurements are being recorded into. + classical_data: The shared classical data container for this + simulation. """ self.args = args self._qubits = tuple(qubits) self.split_untangled_states = split_untangled_states - self._log_of_measurement_results = log_of_measurement_results + self._classical_data = classical_data or value.ClassicalDataDictionaryStore( + _measurements={ + value.MeasurementKey.parse_serialized(k): tuple(v) + for k, v in (log_of_measurement_results or {}).items() + } + ) def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: @@ -135,7 +142,7 @@ def _act_on_fallback_( return True def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]': - logs = self.log_of_measurement_results.copy() + classical_data = self._classical_data.copy() copies = {} for act_on_args in set(self.args.values()): if 'deep_copy_buffers' in inspect.signature(act_on_args.copy).parameters: @@ -150,17 +157,19 @@ def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActO ) copies[act_on_args] = act_on_args.copy() for copy in copies.values(): - copy._log_of_measurement_results = logs + copy._classical_data = classical_data args = {q: copies[a] for q, a in self.args.items()} - return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, logs) + return ActOnArgsContainer( + args, self.qubits, self.split_untangled_states, classical_data=classical_data + ) @property def qubits(self) -> Tuple['cirq.Qid', ...]: return self._qubits @property - def log_of_measurement_results(self) -> Dict[str, Any]: - return self._log_of_measurement_results + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + return self._classical_data def sample( self, diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 1660b2ca652..35ed79de68d 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -45,11 +45,12 @@ def __init__( available_buffer: Optional[List[np.ndarray]] = None, qid_shape: Optional[Tuple[int, ...]] = None, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, ignore_measurement_results: bool = False, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, dtype: Type[np.number] = np.complex64, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnDensityMatrixArgs. @@ -78,12 +79,20 @@ def __init__( dtype: The `numpy.dtype` of the inferred state vector. One of `numpy.complex64` or `numpy.complex128`. Only used when `target_tenson` is None. + classical_data: The shared classical data container for this + simulation. Raises: ValueError: The dimension of `target_tensor` is not divisible by 2 and `qid_shape` is not provided. """ - super().__init__(prng, qubits, log_of_measurement_results, ignore_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + ignore_measurement_results=ignore_measurement_results, + classical_data=classical_data, + ) if target_tensor is None: qubits_qid_shape = protocols.qid_shape(self.qubits) initial_matrix = qis.to_valid_density_matrix( diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 06f07ed6e37..a1c2618e66f 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -48,10 +48,11 @@ def __init__( target_tensor: Optional[np.ndarray] = None, available_buffer: Optional[np.ndarray] = None, prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, Any]] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, dtype: Type[np.number] = np.complex64, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnStateVectorArgs. @@ -76,8 +77,15 @@ def __init__( dtype: The `numpy.dtype` of the inferred state vector. One of `numpy.complex64` or `numpy.complex128`. Only used when `target_tenson` is None. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) if target_tensor is None: qid_shape = protocols.qid_shape(self.qubits) state = qis.to_valid_state_vector( @@ -304,7 +312,7 @@ def _strat_act_on_state_vector_from_mixture( args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.log_of_measurement_results[key] = [index] + args._classical_data.record_channel_measurement(key, index) return True @@ -353,5 +361,5 @@ def prepare_into_buffer(k: int): args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) - args.log_of_measurement_results[key] = [index] + args._classical_data.record_channel_measurement(key, index) return True diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 9629e7a7a9b..f2be23883ad 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -14,7 +14,7 @@ """A protocol for implementing high performance clifford tableau evolutions for Clifford Simulator.""" -from typing import Any, Dict, TYPE_CHECKING, List, Sequence +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING import numpy as np @@ -32,9 +32,10 @@ class ActOnCliffordTableauArgs(ActOnStabilizerArgs): def __init__( self, tableau: 'cirq.CliffordTableau', - prng: np.random.RandomState, - log_of_measurement_results: Dict[str, Any], - qubits: Sequence['cirq.Qid'] = None, + prng: Optional[np.random.RandomState] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits ActOnCliffordTableauArgs. @@ -48,8 +49,15 @@ def __init__( effects. log_of_measurement_results: A mutable object that measurements are being recorded into. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) self.tableau = tableau def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 1c9e2c25b26..c84f0020406 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, TYPE_CHECKING, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING, Union import numpy as np @@ -42,6 +42,7 @@ def __init__( log_of_measurement_results: Optional[Dict[str, Any]] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0, + classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Initializes with the given state and the axes for the operation. @@ -58,8 +59,15 @@ def __init__( initial_state: The initial state for the simulation. This can be a full CH form passed by reference which will be modified inplace, or a big-endian int in the computational basis. + classical_data: The shared classical data container for this + simulation. """ - super().__init__(prng, qubits, log_of_measurement_results) + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) initial_state = state or initial_state if isinstance(initial_state, int): qubit_map = {q: i for i, q in enumerate(self.qubits)} @@ -92,19 +100,19 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - measurements: Dict[str, List[np.ndarray]] = {} + measurements = value.ClassicalDataDictionaryStore() prng = value.parse_random_state(seed) for i in range(repetitions): op = ops.measure(*qubits, key=str(i)) state = self.state.copy() ch_form_args = ActOnStabilizerCHFormArgs( + classical_data=measurements, prng=prng, - log_of_measurement_results=measurements, qubits=self.qubits, initial_state=state, ) protocols.act_on(op, ch_form_args) - return np.array(list(measurements.values()), dtype=bool) + return np.array(list(measurements.measurements.values()), dtype=bool) def _x(self, g: common_gates.XPowGate, axis: int): exponent = g.exponent diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 746758cfb95..e04559272d7 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -77,7 +77,7 @@ def _create_partial_act_on_args( self, initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. @@ -88,6 +88,8 @@ def _create_partial_act_on_args( is often used in specifying the initial state, i.e. the ordering of the computational basis states. logs: A log of the results of measurement that is added to. + classical_data: The shared classical data container for this + simulation. Returns: ActOnStabilizerChFormArgs for the circuit. @@ -97,7 +99,7 @@ def _create_partial_act_on_args( return clifford.ActOnStabilizerCHFormArgs( prng=self._prng, - log_of_measurement_results=logs, + classical_data=classical_data, qubits=qubits, initial_state=initial_state, ) @@ -254,7 +256,6 @@ def state_vector(self): def apply_unitary(self, op: 'cirq.Operation'): ch_form_args = clifford.ActOnStabilizerCHFormArgs( prng=np.random.RandomState(), - log_of_measurement_results={}, qubits=self.qubit_map.keys(), initial_state=self.ch_form, ) @@ -284,10 +285,12 @@ def apply_measurement( else: state = self.copy() + classical_data = value.ClassicalDataDictionaryStore() ch_form_args = clifford.ActOnStabilizerCHFormArgs( prng=prng, - log_of_measurement_results=measurements, + classical_data=classical_data, qubits=self.qubit_map.keys(), initial_state=state.ch_form, ) act_on(op, ch_form_args) + measurements.update({str(k): list(v) for k, v in classical_data.measurements.items()}) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index d5c2d72fc0d..6f1dc38b4bd 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -64,14 +64,15 @@ def test_run(): ) for _ in range(10): state = cirq.StabilizerStateChForm(num_qubits=3) - measurements = {} + classical_data = cirq.ClassicalDataDictionaryStore() for op in circuit.all_operations(): args = cirq.ActOnStabilizerCHFormArgs( qubits=list(circuit.all_qubits()), prng=np.random.RandomState(), - log_of_measurement_results=measurements, + classical_data=classical_data, initial_state=state, ) cirq.act_on(op, args) + measurements = {str(k): list(v) for k, v in classical_data.measurements.items()} assert measurements['1'] == [1] assert measurements['0'] != measurements['2'] diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index b10875f2529..9408016e9f4 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -176,7 +176,7 @@ def _create_partial_act_on_args( self, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> 'cirq.ActOnDensityMatrixArgs': """Creates the ActOnDensityMatrixArgs for a circuit. @@ -186,7 +186,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: The log of measurement results that is added into. + classical_data: The shared classical data container for this + simulation. Returns: ActOnDensityMatrixArgs for the circuit. @@ -197,7 +198,7 @@ def _create_partial_act_on_args( return act_on_density_matrix_args.ActOnDensityMatrixArgs( qubits=qubits, prng=self._prng, - log_of_measurement_results=logs, + classical_data=classical_data, ignore_measurement_results=self._ignore_measurement_results, initial_state=initial_state, dtype=self._dtype, diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 3b208c3e33a..e54916303cd 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -14,16 +14,16 @@ """An interface for quantum states as targets for operations.""" import abc from typing import ( - TypeVar, - TYPE_CHECKING, - Generic, - Dict, Any, - Tuple, - Optional, + Dict, + Generic, Iterator, List, + Optional, Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, Union, ) @@ -86,9 +86,14 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: """Gets the qubit order maintained by this target.""" @property - @abc.abstractmethod def log_of_measurement_results(self) -> Dict[str, Any]: """Gets the log of measurement results.""" + return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} + + @property + @abc.abstractmethod + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + """The shared classical data container for this simulation..""" @abc.abstractmethod def sample( diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 29a24d2dfb3..610e369611c 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -17,6 +17,7 @@ import abc import collections import inspect +import warnings from typing import ( Any, Dict, @@ -31,7 +32,6 @@ Optional, TypeVar, ) -import warnings import numpy as np @@ -126,7 +126,7 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ) -> TActOnArgs: """Creates an instance of the TActOnArgs class for the simulator. @@ -137,8 +137,8 @@ def _create_partial_act_on_args( understood to be a pure state. Other state representations are simulator-dependent. qubits: The sequence of qubits to represent. - logs: The structure to hold measurement logs. A single instance - should be shared among all ActOnArgs within the simulation. + classical_data: The shared classical data container for this + simulation. """ @abc.abstractmethod @@ -352,7 +352,7 @@ def _create_act_on_args( if isinstance(initial_state, OperationTarget): return initial_state - log: Dict[str, Any] = {} + classical_data = value.ClassicalDataDictionaryStore() if self._split_untangled_states: args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} if isinstance(initial_state, int): @@ -360,24 +360,26 @@ def _create_act_on_args( args_map[q] = self._create_partial_act_on_args( initial_state=initial_state % q.dimension, qubits=[q], - logs=log, + classical_data=classical_data, ) initial_state = int(initial_state / q.dimension) else: args = self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, - logs=log, + classical_data=classical_data, ) for q in qubits: args_map[q] = args - args_map[None] = self._create_partial_act_on_args(0, (), log) - return ActOnArgsContainer(args_map, qubits, self._split_untangled_states, log) + args_map[None] = self._create_partial_act_on_args(0, (), classical_data) + return ActOnArgsContainer( + args_map, qubits, self._split_untangled_states, classical_data=classical_data + ) else: return self._create_partial_act_on_args( initial_state=initial_state, qubits=qubits, - logs=log, + classical_data=classical_data, ) diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index d848e2012d3..a99527f3722 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -25,10 +25,10 @@ class CountingActOnArgs(cirq.ActOnArgs): gate_count = 0 measurement_count = 0 - def __init__(self, state, qubits, logs): + def __init__(self, state, qubits, classical_data): super().__init__( qubits=qubits, - log_of_measurement_results=logs, + classical_data=classical_data, ) self.state = state @@ -39,7 +39,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: def copy(self, deep_copy_buffers: bool = True) -> 'CountingActOnArgs': args = CountingActOnArgs( qubits=self.qubits, - logs=self.log_of_measurement_results.copy(), + classical_data=self.classical_data.copy(), state=self.state, ) args.gate_count = self.gate_count @@ -115,9 +115,9 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> CountingActOnArgs: - return CountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return CountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) def _create_simulator_trial_result( self, @@ -145,9 +145,11 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> CountingActOnArgs: - return SplittableCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return SplittableCountingActOnArgs( + qubits=qubits, state=initial_state, classical_data=classical_data + ) q0, q1 = cirq.LineQubit.range(2) @@ -270,9 +272,11 @@ def _create_partial_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> MockCountingActOnArgs: - return MockCountingActOnArgs(qubits=qubits, state=initial_state, logs=logs) + return MockCountingActOnArgs( + qubits=qubits, state=initial_state, classical_data=classical_data + ) def _create_simulator_trial_result( self, diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index d722970509d..7da4d31bab5 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -16,7 +16,6 @@ from typing import ( Any, - Dict, Iterator, List, Type, @@ -175,7 +174,7 @@ def _create_partial_act_on_args( self, initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], qubits: Sequence['cirq.Qid'], - logs: Dict[str, Any], + classical_data: 'cirq.ClassicalDataStore', ): """Creates the ActOnStateVectorArgs for a circuit. @@ -185,7 +184,8 @@ def _create_partial_act_on_args( qubits: Determines the canonical ordering of the qubits. This is often used in specifying the initial state, i.e. the ordering of the computational basis states. - logs: Log of the measurement results. + classical_data: The shared classical data container for this + simulation. Returns: ActOnStateVectorArgs for the circuit. @@ -196,7 +196,7 @@ def _create_partial_act_on_args( return act_on_state_vector_args.ActOnStateVectorArgs( qubits=qubits, prng=self._prng, - log_of_measurement_results=logs, + classical_data=classical_data, initial_state=initial_state, dtype=self._dtype, ) diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index bbf81d71817..bd34876530c 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -25,6 +25,13 @@ chosen_angle_to_half_turns, ) +from cirq.value.classical_data import ( + ClassicalDataDictionaryStore, + ClassicalDataStore, + ClassicalDataStoreReader, + MeasurementType, +) + from cirq.value.condition import ( Condition, KeyCondition, diff --git a/cirq-core/cirq/value/classical_data.py b/cirq-core/cirq/value/classical_data.py new file mode 100644 index 00000000000..5596be02efd --- /dev/null +++ b/cirq-core/cirq/value/classical_data.py @@ -0,0 +1,245 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import enum +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, TypeVar + +from cirq.value import digits, value_equality_attr + +if TYPE_CHECKING: + import cirq + + +class MeasurementType(enum.IntEnum): + MEASUREMENT = 1 + CHANNEL = 2 + + def __repr__(self): + return f'cirq.{str(self)}' + + +TSelf = TypeVar('TSelf', bound='ClassicalDataStoreReader') + + +class ClassicalDataStoreReader(abc.ABC): + @abc.abstractmethod + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the measurement keys in the order they were stored.""" + + @abc.abstractmethod + def get_int(self, key: 'cirq.MeasurementKey') -> int: + """Gets the integer corresponding to the measurement. + + The integer is determined by summing the qubit-dimensional basis value + of each measured qubit. For example, if the measurement of qubits + [q1, q0] produces [1, 0], then the corresponding integer is 2, the big- + endian equivalent. If they are qutrits and the measurement is [2, 1], + then the integer is 2 * 3 + 1 = 7. + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + + @abc.abstractmethod + def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + """Gets the values of the qubits that were measured into this key. + + For example, if the measurement of qubits [q0, q1] produces [0, 1], + this function will return (0, 1). + + Args: + key: The measurement key. + + Raises: + KeyError: If the key has not been used. + """ + + @abc.abstractmethod + def copy(self: TSelf) -> TSelf: + """Creates a copy of the object.""" + + +class ClassicalDataStore(ClassicalDataStoreReader, abc.ABC): + @abc.abstractmethod + def record_measurement( + self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] + ): + """Records a measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + qubits: The qubits that were measured. + + Raises: + ValueError: If the measurement shape does not match the qubits + measured or if the measurement key was already used. + """ + + @abc.abstractmethod + def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): + """Records a channel measurement. + + Args: + key: The measurement key to hold the measurement. + measurement: The measurement result. + + Raises: + ValueError: If the measurement key was already used. + """ + + +@value_equality_attr.value_equality(unhashable=True) +class ClassicalDataDictionaryStore(ClassicalDataStore): + """Classical data representing measurements and metadata.""" + + def __init__( + self, + *, + _measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = None, + _measured_qubits: Dict['cirq.MeasurementKey', Tuple['cirq.Qid', ...]] = None, + _channel_measurements: Dict['cirq.MeasurementKey', int] = None, + _measurement_types: Dict['cirq.MeasurementKey', 'cirq.MeasurementType'] = None, + ): + """Initializes a `ClassicalDataDictionaryStore` object.""" + if not _measurement_types: + _measurement_types = {} + if _measurements: + _measurement_types.update( + {k: MeasurementType.MEASUREMENT for k, v in _measurements.items()} + ) + if _channel_measurements: + _measurement_types.update( + {k: MeasurementType.CHANNEL for k, v in _channel_measurements.items()} + ) + if _measurements is None: + _measurements = {} + if _measured_qubits is None: + _measured_qubits = {} + if _channel_measurements is None: + _channel_measurements = {} + self._measurements: Dict['cirq.MeasurementKey', Tuple[int, ...]] = _measurements + self._measured_qubits: Dict[ + 'cirq.MeasurementKey', Tuple['cirq.Qid', ...] + ] = _measured_qubits + self._channel_measurements: Dict['cirq.MeasurementKey', int] = _channel_measurements + self._measurement_types: Dict[ + 'cirq.MeasurementKey', 'cirq.MeasurementType' + ] = _measurement_types + + @property + def measurements(self) -> Mapping['cirq.MeasurementKey', Tuple[int, ...]]: + """Gets the a mapping from measurement key to measurement.""" + return self._measurements + + @property + def channel_measurements(self) -> Mapping['cirq.MeasurementKey', int]: + """Gets the a mapping from measurement key to channel measurement.""" + return self._channel_measurements + + @property + def measured_qubits(self) -> Mapping['cirq.MeasurementKey', Tuple['cirq.Qid', ...]]: + """Gets the a mapping from measurement key to the qubits measured.""" + return self._measured_qubits + + @property + def measurement_types(self) -> Mapping['cirq.MeasurementKey', 'cirq.MeasurementType']: + """Gets the a mapping from measurement key to the measurement type.""" + return self._measurement_types + + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + return tuple(self._measurement_types.keys()) + + def record_measurement( + self, key: 'cirq.MeasurementKey', measurement: Sequence[int], qubits: Sequence['cirq.Qid'] + ): + if len(measurement) != len(qubits): + raise ValueError(f'{len(measurement)} measurements but {len(qubits)} qubits.') + if key in self._measurement_types: + raise ValueError(f"Measurement already logged to key {key}") + self._measurement_types[key] = MeasurementType.MEASUREMENT + self._measurements[key] = tuple(measurement) + self._measured_qubits[key] = tuple(qubits) + + def record_channel_measurement(self, key: 'cirq.MeasurementKey', measurement: int): + if key in self._measurement_types: + raise ValueError(f"Measurement already logged to key {key}") + self._measurement_types[key] = MeasurementType.CHANNEL + self._channel_measurements[key] = measurement + + def get_digits(self, key: 'cirq.MeasurementKey') -> Tuple[int, ...]: + return ( + self._measurements[key] + if self._measurement_types[key] == MeasurementType.MEASUREMENT + else (self._channel_measurements[key],) + ) + + def get_int(self, key: 'cirq.MeasurementKey') -> int: + if key not in self._measurement_types: + raise KeyError(f'The measurement key {key} is not in {self._measurements}') + measurement_type = self._measurement_types[key] + if measurement_type == MeasurementType.CHANNEL: + return self._channel_measurements[key] + if key not in self._measured_qubits: + return digits.big_endian_bits_to_int(self._measurements[key]) + return digits.big_endian_digits_to_int( + self._measurements[key], base=[q.dimension for q in self._measured_qubits[key]] + ) + + def copy(self): + return ClassicalDataDictionaryStore( + _measurements=self._measurements.copy(), + _measured_qubits=self._measured_qubits.copy(), + _channel_measurements=self._channel_measurements.copy(), + _measurement_types=self._measurement_types.copy(), + ) + + def _json_dict_(self): + return { + 'measurements': list(self.measurements.items()), + 'measured_qubits': list(self.measured_qubits.items()), + 'channel_measurements': list(self.channel_measurements.items()), + 'measurement_types': list(self.measurement_types.items()), + } + + @classmethod + def _from_json_dict_( + cls, measurements, measured_qubits, channel_measurements, measurement_types, **kwargs + ): + return cls( + _measurements=dict(measurements), + _measured_qubits=dict(measured_qubits), + _channel_measurements=dict(channel_measurements), + _measurement_types=dict(measurement_types), + ) + + def __repr__(self): + return ( + f'cirq.ClassicalDataDictionaryStore(_measurements={self.measurements!r},' + f' _measured_qubits={self.measured_qubits!r},' + f' _channel_measurements={self.channel_measurements!r},' + f' _measurement_types={self.measurement_types!r})' + ) + + def _value_equality_values_(self): + return ( + self._measurements, + self._channel_measurements, + self._measurement_types, + self._measured_qubits, + ) diff --git a/cirq-core/cirq/value/classical_data_test.py b/cirq-core/cirq/value/classical_data_test.py new file mode 100644 index 00000000000..00cfe475d0e --- /dev/null +++ b/cirq-core/cirq/value/classical_data_test.py @@ -0,0 +1,136 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import cirq + +mkey_m = cirq.MeasurementKey('m') +mkey_c = cirq.MeasurementKey('c') +two_qubits = tuple(cirq.LineQubit.range(2)) + + +def test_init(): + cd = cirq.ClassicalDataDictionaryStore() + assert cd.measurements == {} + assert cd.keys() == () + assert cd.measured_qubits == {} + assert cd.channel_measurements == {} + assert cd.measurement_types == {} + cd = cirq.ClassicalDataDictionaryStore( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + ) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m, mkey_c) + assert cd.measured_qubits == {mkey_m: two_qubits} + assert cd.channel_measurements == {mkey_c: 3} + assert cd.measurement_types == { + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + } + + +def test_record_measurement(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + assert cd.measurements == {mkey_m: (0, 1)} + assert cd.keys() == (mkey_m,) + assert cd.measured_qubits == {mkey_m: two_qubits} + + +def test_record_measurement_errors(): + cd = cirq.ClassicalDataDictionaryStore() + with pytest.raises(ValueError, match='3 measurements but 2 qubits'): + cd.record_measurement(mkey_m, (0, 1, 2), two_qubits) + cd.record_measurement(mkey_m, (0, 1), two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + + +def test_record_channel_measurement(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_channel_measurement(mkey_m, 1) + assert cd.channel_measurements == {mkey_m: 1} + assert cd.keys() == (mkey_m,) + + +def test_record_channel_measurement_errors(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_channel_measurement(mkey_m, 1) + with pytest.raises(ValueError, match='Measurement already logged to key m'): + cd.record_measurement(mkey_m, (0, 1), two_qubits) + + +def test_get_int(): + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (0, 1), two_qubits) + assert cd.get_int(mkey_m) == 1 + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (1, 1), two_qubits) + assert cd.get_int(mkey_m) == 3 + cd = cirq.ClassicalDataDictionaryStore() + cd.record_channel_measurement(mkey_m, 1) + assert cd.get_int(mkey_m) == 1 + cd = cirq.ClassicalDataDictionaryStore() + cd.record_measurement(mkey_m, (1, 1), (cirq.LineQid.range(2, dimension=3))) + assert cd.get_int(mkey_m) == 4 + cd = cirq.ClassicalDataDictionaryStore() + with pytest.raises(KeyError, match='The measurement key m is not in {}'): + cd.get_int(mkey_m) + + +def test_copy(): + cd = cirq.ClassicalDataDictionaryStore( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + _measurement_types={ + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + }, + ) + cd1 = cd.copy() + assert cd1 is not cd + assert cd1 == cd + assert cd1.measurements is not cd.measurements + assert cd1.measurements == cd.measurements + assert cd1.measured_qubits is not cd.measured_qubits + assert cd1.measured_qubits == cd.measured_qubits + assert cd1.channel_measurements is not cd.channel_measurements + assert cd1.channel_measurements == cd.channel_measurements + assert cd1.measurement_types is not cd.measurement_types + assert cd1.measurement_types == cd.measurement_types + + +def test_repr(): + cd = cirq.ClassicalDataDictionaryStore( + _measurements={mkey_m: (0, 1)}, + _measured_qubits={mkey_m: two_qubits}, + _channel_measurements={mkey_c: 3}, + _measurement_types={ + mkey_m: cirq.MeasurementType.MEASUREMENT, + mkey_c: cirq.MeasurementType.CHANNEL, + }, + ) + cirq.testing.assert_equivalent_repr(cd) diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py index ef432b7506f..7c594eb2d95 100644 --- a/cirq-core/cirq/value/condition.py +++ b/cirq-core/cirq/value/condition.py @@ -14,13 +14,13 @@ import abc import dataclasses -from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet +from typing import Dict, Tuple, TYPE_CHECKING, FrozenSet import sympy from cirq._compat import proper_repr from cirq.protocols import json_serialization, measurement_key_protocol as mkp -from cirq.value import digits, measurement_key +from cirq.value import measurement_key if TYPE_CHECKING: import cirq @@ -39,7 +39,10 @@ def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.Measure """Replaces the control keys.""" @abc.abstractmethod - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + def resolve( + self, + classical_data: 'cirq.ClassicalDataStoreReader', + ) -> bool: """Resolves the condition based on the measurements.""" @property @@ -98,11 +101,13 @@ def __str__(self): def __repr__(self): return f'cirq.KeyCondition({self.key!r})' - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: - key = str(self.key) - if key not in measurements: - raise ValueError(f'Measurement key {key} missing when testing classical control') - return any(measurements[key]) + def resolve( + self, + classical_data: 'cirq.ClassicalDataStoreReader', + ) -> bool: + if self.key not in classical_data.keys(): + raise ValueError(f'Measurement key {self.key} missing when testing classical control') + return classical_data.get_int(self.key) != 0 def _json_dict_(self): return json_serialization.dataclass_json_dict(self) @@ -143,15 +148,15 @@ def __str__(self): def __repr__(self): return f'cirq.SympyCondition({proper_repr(self.expr)})' - def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: - missing = [str(k) for k in self.keys if str(k) not in measurements] + def resolve( + self, + classical_data: 'cirq.ClassicalDataStoreReader', + ) -> bool: + missing = [str(k) for k in self.keys if k not in classical_data.keys()] if missing: raise ValueError(f'Measurement keys {missing} missing when testing classical control') - def value(k): - return digits.big_endian_bits_to_int(measurements[str(k)]) - - replacements = {str(k): value(k) for k in self.keys} + replacements = {str(k): classical_data.get_int(k) for k in self.keys} return bool(self.expr.subs(replacements)) def _json_dict_(self): diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py index fd80033a29a..e92029b1bfb 100644 --- a/cirq-core/cirq/value/condition_test.py +++ b/cirq-core/cirq/value/condition_test.py @@ -42,22 +42,26 @@ def test_key_condition_repr(): def test_key_condition_resolve(): - assert init_key_condition.resolve({'0:a': [1]}) - assert init_key_condition.resolve({'0:a': [2]}) - assert init_key_condition.resolve({'0:a': [0, 1]}) - assert init_key_condition.resolve({'0:a': [1, 0]}) - assert not init_key_condition.resolve({'0:a': [0]}) - assert not init_key_condition.resolve({'0:a': [0, 0]}) - assert not init_key_condition.resolve({'0:a': []}) - assert not init_key_condition.resolve({'0:a': [0], 'b': [1]}) + def resolve(measurements): + classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + return init_key_condition.resolve(classical_data) + + assert resolve({'0:a': [1]}) + assert resolve({'0:a': [2]}) + assert resolve({'0:a': [0, 1]}) + assert resolve({'0:a': [1, 0]}) + assert not resolve({'0:a': [0]}) + assert not resolve({'0:a': [0, 0]}) + assert not resolve({'0:a': []}) + assert not resolve({'0:a': [0], 'b': [1]}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({}) + _ = resolve({}) with pytest.raises( ValueError, match='Measurement key 0:a missing when testing classical control' ): - _ = init_key_condition.resolve({'0:b': [1]}) + _ = resolve({'0:b': [1]}) def test_key_condition_qasm(): @@ -80,24 +84,28 @@ def test_sympy_condition_repr(): def test_sympy_condition_resolve(): - assert init_sympy_condition.resolve({'0:a': [1]}) - assert init_sympy_condition.resolve({'0:a': [2]}) - assert init_sympy_condition.resolve({'0:a': [0, 1]}) - assert init_sympy_condition.resolve({'0:a': [1, 0]}) - assert not init_sympy_condition.resolve({'0:a': [0]}) - assert not init_sympy_condition.resolve({'0:a': [0, 0]}) - assert not init_sympy_condition.resolve({'0:a': []}) - assert not init_sympy_condition.resolve({'0:a': [0], 'b': [1]}) + def resolve(measurements): + classical_data = cirq.ClassicalDataDictionaryStore(_measurements=measurements) + return init_sympy_condition.resolve(classical_data) + + assert resolve({'0:a': [1]}) + assert resolve({'0:a': [2]}) + assert resolve({'0:a': [0, 1]}) + assert resolve({'0:a': [1, 0]}) + assert not resolve({'0:a': [0]}) + assert not resolve({'0:a': [0, 0]}) + assert not resolve({'0:a': []}) + assert not resolve({'0:a': [0], 'b': [1]}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({}) + _ = resolve({}) with pytest.raises( ValueError, match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), ): - _ = init_sympy_condition.resolve({'0:b': [1]}) + _ = resolve({'0:b': [1]}) def test_sympy_condition_qasm(): diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index ee4c12bb051..e53eac47fde 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -77,6 +77,16 @@ def __hash__(self): object.__setattr__(self, '_hash', hash(str(self))) return self._hash + def __lt__(self, other): + if isinstance(other, MeasurementKey): + if self.path != other.path: + return self.path < other.path + return self.name < other.name + return NotImplemented + + def __le__(self, other): + return self == other or self < other + def _json_dict_(self): return { 'name': self.name, diff --git a/cirq-core/cirq/value/measurement_key_test.py b/cirq-core/cirq/value/measurement_key_test.py index c7f01de7d9a..e04a8be9c62 100644 --- a/cirq-core/cirq/value/measurement_key_test.py +++ b/cirq-core/cirq/value/measurement_key_test.py @@ -98,3 +98,22 @@ def test_with_measurement_key_mapping(): mkey3 = cirq.with_measurement_key_mapping(mkey3, {'new_key': 'newer_key'}) assert mkey3.name == 'newer_key' assert mkey3.path == ('a',) + + +def test_compare(): + assert cirq.MeasurementKey('a') < cirq.MeasurementKey('b') + assert cirq.MeasurementKey('a') <= cirq.MeasurementKey('b') + assert cirq.MeasurementKey('a') <= cirq.MeasurementKey('a') + assert cirq.MeasurementKey('b') > cirq.MeasurementKey('a') + assert cirq.MeasurementKey('b') >= cirq.MeasurementKey('a') + assert cirq.MeasurementKey('a') >= cirq.MeasurementKey('a') + assert not cirq.MeasurementKey('a') > cirq.MeasurementKey('b') + assert not cirq.MeasurementKey('a') >= cirq.MeasurementKey('b') + assert not cirq.MeasurementKey('b') < cirq.MeasurementKey('a') + assert not cirq.MeasurementKey('b') <= cirq.MeasurementKey('a') + assert cirq.MeasurementKey(path=(), name='b') < cirq.MeasurementKey(path=('0',), name='a') + assert cirq.MeasurementKey(path=('0',), name='n') < cirq.MeasurementKey(path=('1',), name='a') + with pytest.raises(TypeError): + _ = cirq.MeasurementKey('a') < 'b' + with pytest.raises(TypeError): + _ = cirq.MeasurementKey('a') <= 'b' diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index d2c383f059b..29bd90f0913 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -474,7 +474,7 @@ def _create_partial_act_on_args( self, initial_state: Union[int, cirq.ActOnStateVectorArgs], qubits: Sequence[cirq.Qid], - logs: Dict[str, Any], + classical_data: cirq.ClassicalDataStore, ) -> cirq.ActOnStateVectorArgs: # Needs an implementation since it's abstract but will never actually be called. raise NotImplementedError()