From e0a64dd236e075d3280287959595adda3fe758d5 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Mon, 21 Mar 2022 09:16:04 -0700 Subject: [PATCH] Extract BufferedDM/SV/MPS from ActOnDM/SV/MPSArgs (#4979) * Extract BufferedDensityMatrix from ActOnDensityMatrixArgs * state vector * clean up code * clean up code * clean up code * format * docs * test * coverage * improve state vector * improve state vector * replace deleted functions * replace deleted functions * replace deleted functions * replace deleted functions * lint * mps quantum state * mps quantum state * mps quantum state * mps quantum state * mps quantum state * mps quantum state * mps quantum state * mps quantum state * coverage * fix merge errors * Code review comments * Remove todo Co-authored-by: Orion Martin <40585662+95-martin-orion@users.noreply.github.com> --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 422 +++++++++++----- .../cirq/contrib/quimb/mps_simulator_test.py | 14 + .../cirq/protocols/json_test_data/spec.py | 1 + .../cirq/sim/act_on_density_matrix_args.py | 322 +++++++++---- .../sim/act_on_density_matrix_args_test.py | 5 + .../cirq/sim/act_on_state_vector_args.py | 453 +++++++++++++----- .../cirq/sim/act_on_state_vector_args_test.py | 18 +- cirq-core/cirq/sim/state_vector.py | 2 +- 8 files changed, 890 insertions(+), 347 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 34daf9e10e3..cfa48ab651a 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -19,12 +19,13 @@ import dataclasses import math -from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union import numpy as np import quimb.tensor as qtn -from cirq import devices, ops, protocols, value +from cirq import devices, protocols, value +from cirq._compat import deprecated from cirq.sim import simulator_base from cirq.sim.act_on_args import ActOnArgs @@ -219,59 +220,67 @@ def _simulator_state(self): @value.value_equality -class MPSState(ActOnArgs): - """A state of the MPS simulation.""" +class _MPSHandler: + """Quantum state of the MPS simulation.""" def __init__( self, - qubits: Sequence['cirq.Qid'], - prng: np.random.RandomState, + qid_shape: Tuple[int, ...], + grouping: Dict[int, int], + M: List[qtn.Tensor], + format_i: str, + estimated_gate_error_list: List[float], simulation_options: MPSOptions = MPSOptions(), - grouping: Optional[Dict['cirq.Qid', int]] = None, - initial_state: int = 0, - log_of_measurement_results: Dict[str, Any] = None, - classical_data: 'cirq.ClassicalDataStore' = None, ): - """Creates and MPSState + """Creates an MPSQuantumState Args: - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - prng: A random number generator, used to simulate measurements. + qid_shape: Dimensions of the qubits represented. + grouping: How to group qubits together, if None all are individual. + M: The tensor list for maintaining the MPS state. + format_i: A string for formatting the group labels. + estimated_gate_error_list: The error estimations. simulation_options: Numerical options for the simulation. + """ + self._qid_shape = qid_shape + self._grouping = grouping + self._M = M + self._format_i = format_i + self._format_mu = 'mu_{}_{}' + self._simulation_options = simulation_options + self._estimated_gate_error_list = estimated_gate_error_list + + @classmethod + def create( + cls, + *, + qid_shape: Tuple[int, ...], + grouping: Dict[int, int], + initial_state: int = 0, + simulation_options: MPSOptions = MPSOptions(), + ): + """Creates an MPSQuantumState + + Args: + qid_shape: Dimensions of the qubits represented. grouping: How to group qubits together, if None all are individual. - initial_state: An integer representing the initial state. - log_of_measurement_results: A mutable object that measurements are - being recorded into. - classical_data: The shared classical data container for this - simulation. + initial_state: The initial computational basis state. + simulation_options: Numerical options for the simulation. Raises: ValueError: If the grouping does not cover the qubits. """ - super().__init__( - prng=prng, - qubits=qubits, - log_of_measurement_results=log_of_measurement_results, - classical_data=classical_data, - ) - qubit_map = self.qubit_map - self.grouping = qubit_map if grouping is None else grouping - if self.grouping.keys() != self.qubit_map.keys(): - raise ValueError('Grouping must cover exactly the qubits.') - self.M = [] - for _ in range(max(self.grouping.values()) + 1): - self.M.append(qtn.Tensor()) + M = [] + for _ in range(max(grouping.values()) + 1): + M.append(qtn.Tensor()) # The order of the qubits matters, because the state |01> is different from |10>. Since # Quimb uses strings to name tensor indices, we want to be able to sort them too. If we are # working with, say, 123 qubits then we want qubit 3 to come before qubit 100, but then # we want write the string '003' which comes before '100' in lexicographic order. The code # below is just simple string formatting. - max_num_digits = len(f'{max(qubit_map.values())}') - self.format_i = f'i_{{:0{max_num_digits}}}' - self.format_mu = 'mu_{}_{}' + max_num_digits = len(f'{max(grouping.values())}') + format_i = f'i_{{:0{max_num_digits}}}' # TODO(tonybruguier): Instead of relying on sortable indices could you keep a parallel # mapping of e.g. qubit to string-index and do all "logic" on the qubits themselves and @@ -280,21 +289,26 @@ def __init__( # TODO(tonybruguier): Refactor out so that the code below can also be used by # circuit_to_tensors in cirq.contrib.quimb.state_vector. - for qubit in reversed(list(qubit_map.keys())): - d = qubit.dimension + for axis in reversed(range(len(qid_shape))): + d = qid_shape[axis] x = np.zeros(d) x[initial_state % d] = 1.0 - i = qubit_map[qubit] - n = self.grouping[qubit] - self.M[n] @= qtn.Tensor(x, inds=(self.i_str(i),)) + n = grouping[axis] + M[n] @= qtn.Tensor(x, inds=(format_i.format(axis),)) initial_state = initial_state // d - self.simulation_options = simulation_options - self.estimated_gate_error_list: List[float] = [] + return _MPSHandler( + qid_shape=qid_shape, + grouping=grouping, + M=M, + format_i=format_i, + estimated_gate_error_list=[], + simulation_options=simulation_options, + ) def i_str(self, i: int) -> str: # Returns the index name for the i'th qid. - return self.format_i.format(i) + return self._format_i.format(i) def mu_str(self, i: int, j: int) -> str: # Returns the index name for the pair of the i'th and j'th qids. Note @@ -302,19 +316,30 @@ def mu_str(self, i: int, j: int) -> str: # string. smallest = min(i, j) largest = max(i, j) - return self.format_mu.format(smallest, largest) + return self._format_mu.format(smallest, largest) def __str__(self) -> str: - return str(qtn.TensorNetwork(self.M)) + return str(qtn.TensorNetwork(self._M)) def _value_equality_values_(self) -> Any: - return self.qubit_map, self.M, self.simulation_options, self.grouping + return self._qid_shape, self._M, self._simulation_options, self._grouping - def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True): - target.simulation_options = self.simulation_options - target.grouping = self.grouping - target.M = [x.copy() for x in self.M] - target.estimated_gate_error_list = self.estimated_gate_error_list + def copy(self, deep_copy_buffers: bool = True) -> '_MPSHandler': + """Copies the object. + + Args: + deep_copy_buffers: True by default, False to reuse the existing buffers. + Returns: + A copy of the object. + """ + return _MPSHandler( + simulation_options=self._simulation_options, + grouping=self._grouping, + qid_shape=self._qid_shape, + M=[x.copy() for x in self._M], + estimated_gate_error_list=self._estimated_gate_error_list.copy(), + format_i=self._format_i, + ) def state_vector(self) -> np.ndarray: """Returns the full state vector. @@ -322,7 +347,7 @@ def state_vector(self) -> np.ndarray: Returns: A vector that contains the full state. """ - tensor_network = qtn.TensorNetwork(self.M) + tensor_network = qtn.TensorNetwork(self._M) state_vector = tensor_network.contract(inplace=False) # Here, we rely on the formatting of the indices, and the fact that we have enough @@ -330,11 +355,11 @@ def state_vector(self) -> np.ndarray: sorted_ind = tuple(sorted(state_vector.inds)) return state_vector.fuse({'i': sorted_ind}).data - def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray: - """Traces out all qubits except keep_qubits. + def partial_trace(self, keep_axes: Set[int]) -> np.ndarray: + """Traces out all qubits except keep_axes. Args: - keep_qubits: The set of qubits that are left after computing the + keep_axes: The set of axes that are left after computing the partial trace. For example, if we have a circuit for 3 qubits and this parameter only has one qubit, the entire density matrix would be 8x8, but this function returns a 2x2 matrix. @@ -343,13 +368,11 @@ def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray: An array that contains the partial trace. """ - contracted_inds = set( - [self.i_str(i) for qubit, i in self.qubit_map.items() if qubit not in keep_qubits] - ) + contracted_inds = set(map(self.i_str, set(range(len(self._qid_shape))) - keep_axes)) conj_pfx = "conj_" - tensor_network = qtn.TensorNetwork(self.M) + tensor_network = qtn.TensorNetwork(self._M) # Rename the internal indices to avoid collisions. Also rename the qubit # indices that are kept. We do not rename the qubit indices that are @@ -363,7 +386,7 @@ def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray: conj_tensor_network.reindex(reindex_mapping, inplace=True) partial_trace = conj_tensor_network @ tensor_network - forward_inds = [self.i_str(self.qubit_map[keep_qubit]) for keep_qubit in keep_qubits] + forward_inds = list(map(self.i_str, keep_axes)) backward_inds = [conj_pfx + forward_ind for forward_ind in forward_inds] return partial_trace.to_dense(forward_inds, backward_inds) @@ -371,7 +394,7 @@ def to_numpy(self) -> np.ndarray: """An alias for the state vector.""" return self.state_vector() - def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): + def apply_op(self, op: Any, axes: Sequence[int], prng: np.random.RandomState): """Applies a unitary operation, mutating the object to represent the new state. op: @@ -379,7 +402,7 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): and 2- qubit operations are currently supported. """ - old_inds = tuple([self.i_str(self.qubit_map[qubit]) for qubit in op.qubits]) + old_inds = tuple(map(self.i_str, axes)) new_inds = tuple(['new_' + old_ind for old_ind in old_inds]) if protocols.has_unitary(op): @@ -389,40 +412,40 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): mixture_idx = int(prng.choice(len(mixtures), p=[mixture[0] for mixture in mixtures])) U = mixtures[mixture_idx][1] U = qtn.Tensor( - U.reshape([qubit.dimension for qubit in op.qubits] * 2), inds=(new_inds + old_inds) + U.reshape([self._qid_shape[axis] for axis in axes] * 2), inds=(new_inds + old_inds) ) # TODO(tonybruguier): Explore using the Quimb's tensor network natively. - if len(op.qubits) == 1: - n = self.grouping[op.qubits[0]] + if len(axes) == 1: + n = self._grouping[axes[0]] - self.M[n] = (U @ self.M[n]).reindex({new_inds[0]: old_inds[0]}) - elif len(op.qubits) == 2: - n, p = [self.grouping[qubit] for qubit in op.qubits] + self._M[n] = (U @ self._M[n]).reindex({new_inds[0]: old_inds[0]}) + elif len(axes) == 2: + n, p = [self._grouping[axis] for axis in axes] if n == p: - self.M[n] = (U @ self.M[n]).reindex( + self._M[n] = (U @ self._M[n]).reindex( {new_inds[0]: old_inds[0], new_inds[1]: old_inds[1]} ) else: # This is the index on which we do the contraction. We need to add it iff it's # the first time that we do the joining for that specific pair. mu_ind = self.mu_str(n, p) - if mu_ind not in self.M[n].inds: - self.M[n].new_ind(mu_ind) - if mu_ind not in self.M[p].inds: - self.M[p].new_ind(mu_ind) + if mu_ind not in self._M[n].inds: + self._M[n].new_ind(mu_ind) + if mu_ind not in self._M[p].inds: + self._M[p].new_ind(mu_ind) - T = U @ self.M[n] @ self.M[p] + T = U @ self._M[n] @ self._M[p] - left_inds = tuple(set(T.inds) & set(self.M[n].inds)) + (new_inds[0],) + left_inds = tuple(set(T.inds) & set(self._M[n].inds)) + (new_inds[0],) X, Y = T.split( left_inds, - method=self.simulation_options.method, - max_bond=self.simulation_options.max_bond, - cutoff=self.simulation_options.cutoff, - cutoff_mode=self.simulation_options.cutoff_mode, + method=self._simulation_options.method, + max_bond=self._simulation_options.max_bond, + cutoff=self._simulation_options.cutoff, + cutoff_mode=self._simulation_options.cutoff_mode, get='tensors', absorb='both', bond_ind=mu_ind, @@ -437,11 +460,11 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): # The renormalization would then have to be done manually. # # However, for now, e_n are just the estimated value. - e_n = self.simulation_options.cutoff - self.estimated_gate_error_list.append(e_n) + e_n = self._simulation_options.cutoff + self._estimated_gate_error_list.append(e_n) - self.M[n] = X.reindex({new_inds[0]: old_inds[0]}) - self.M[p] = Y.reindex({new_inds[1]: old_inds[1]}) + self._M[n] = X.reindex({new_inds[0]: old_inds[0]}) + self._M[p] = Y.reindex({new_inds[1]: old_inds[1]}) else: # NOTE(tonybruguier): There could be a way to handle higher orders. I think this could # involve HOSVDs: @@ -452,28 +475,17 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): raise ValueError('Can only handle 1 and 2 qubit operations') return True - def _act_on_fallback_( - self, - action: Union['cirq.Operation', 'cirq.Gate'], - qubits: Sequence['cirq.Qid'], - allow_decompose: bool = True, - ) -> bool: - """Delegates the action to self.apply_op""" - if isinstance(action, ops.Gate): - action = ops.GateOperation(action, qubits) - return self.apply_op(action, self.prng) - def estimation_stats(self): """Returns some statistics about the memory usage and quality of the approximation.""" - num_coefs_used = sum([Mi.data.size for Mi in self.M]) - memory_bytes = sum([Mi.data.nbytes for Mi in self.M]) + num_coefs_used = sum([Mi.data.size for Mi in self._M]) + memory_bytes = sum([Mi.data.nbytes for Mi in self._M]) # The computation below is done for numerical stability, instead of directly using the # formula: # estimated_fidelity = \prod_i (1 - estimated_gate_error_list_i) estimated_fidelity = 1.0 + np.expm1( - sum(np.log1p(-x) for x in self.estimated_gate_error_list) + sum(np.log1p(-x) for x in self._estimated_gate_error_list) ) estimated_fidelity = round(estimated_fidelity, ndigits=3) @@ -483,21 +495,9 @@ def estimation_stats(self): "estimated_fidelity": estimated_fidelity, } - def perform_measurement( - self, qubits: Sequence['cirq.Qid'], prng: np.random.RandomState, collapse_state_vector=True + def _measure( + self, axes: Sequence[int], prng: np.random.RandomState, collapse_state_vector=True ) -> List[int]: - """Performs a measurement over one or more qubits. - - Args: - qubits: The sequence of qids to measure, in that order. - prng: A random number generator, used to simulate measurements. - collapse_state_vector: A Boolean specifying whether we should mutate - the state after the measurement. - - Raises: - ValueError: If the probabilities for the measurements differ too much from one for the - tolerance specified in simulation options. - """ results: List[int] = [] if collapse_state_vector: @@ -505,55 +505,219 @@ def perform_measurement( else: state = self.copy() - for qubit in qubits: - n = state.qubit_map[qubit] - + for axis in axes: # Trace out other qubits - M = state.partial_trace(keep_qubits={qubit}) + M = state.partial_trace(keep_axes={axis}) probs = np.diag(M).real sum_probs = sum(probs) # Because the computation is approximate, the probabilities do not # necessarily add up to 1.0, and thus we re-normalize them. - if abs(sum_probs - 1.0) > self.simulation_options.sum_prob_atol: + if abs(sum_probs - 1.0) > self._simulation_options.sum_prob_atol: raise ValueError(f'Sum of probabilities exceeds tolerance: {sum_probs}') norm_probs = [x / sum_probs for x in probs] - d = qubit.dimension + d = self._qid_shape[axis] result: int = int(prng.choice(d, p=norm_probs)) collapser = np.zeros((d, d)) collapser[result][result] = 1.0 / math.sqrt(probs[result]) - old_n = state.i_str(n) + old_n = state.i_str(axis) new_n = 'new_' + old_n collapser = qtn.Tensor(collapser, inds=(new_n, old_n)) - state.M[n] = (collapser @ state.M[n]).reindex({new_n: old_n}) + state._M[axis] = (collapser @ state._M[axis]).reindex({new_n: old_n}) results.append(result) return results - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Measures the axes specified by the simulator.""" - return self.perform_measurement(qubits, self.prng) + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the MPS. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in axis order. + """ + return self._measure(axes, value.parse_random_state(seed)) def sample( self, - qubits: Sequence['cirq.Qid'], + axes: Sequence[int], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: + """Samples the MPS. + + Args: + axes: The axes to sample. + repetitions: The number of samples to make. + seed: The random number seed to use. + Returns: + The samples in order. + """ measurements: List[List[int]] = [] + prng = value.parse_random_state(seed) for _ in range(repetitions): - measurements.append( - self.perform_measurement( - qubits, value.parse_random_state(seed), collapse_state_vector=False - ) - ) + measurements.append(self._measure(axes, prng, collapse_state_vector=False)) return np.array(measurements, dtype=int) + + +@value.value_equality +class MPSState(ActOnArgs): + """A state of the MPS simulation.""" + + def __init__( + self, + qubits: Sequence['cirq.Qid'], + prng: np.random.RandomState, + simulation_options: MPSOptions = MPSOptions(), + grouping: Optional[Dict['cirq.Qid', int]] = None, + initial_state: int = 0, + log_of_measurement_results: Dict[str, Any] = None, + classical_data: 'cirq.ClassicalDataStore' = None, + ): + """Creates and MPSState + + Args: + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + prng: A random number generator, used to simulate measurements. + simulation_options: Numerical options for the simulation. + grouping: How to group qubits together, if None all are individual. + initial_state: An integer representing the initial state. + log_of_measurement_results: A mutable object that measurements are + being recorded into. + classical_data: The shared classical data container for this + simulation. + + Raises: + ValueError: If the grouping does not cover the qubits. + """ + super().__init__( + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) + final_grouping = self.qubit_map if grouping is None else grouping + if final_grouping.keys() != self.qubit_map.keys(): + raise ValueError('Grouping must cover exactly the qubits.') + self._state = _MPSHandler.create( + initial_state=initial_state, + qid_shape=tuple(q.dimension for q in qubits), + simulation_options=simulation_options, + grouping={self.qubit_map[k]: v for k, v in final_grouping.items()}, + ) + + def i_str(self, i: int) -> str: + # Returns the index name for the i'th qid. + return self._state.i_str(i) + + def mu_str(self, i: int, j: int) -> str: + # Returns the index name for the pair of the i'th and j'th qids. Note + # that by convention, the lower index is always the first in the output + # string. + return self._state.mu_str(i, j) + + def __str__(self) -> str: + return str(self._state) + + def _value_equality_values_(self) -> Any: + return self.qubits, self._state + + def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True): + target._state = self._state.copy(deep_copy_buffers) + + def state_vector(self) -> np.ndarray: + """Returns the full state vector. + + Returns: + A vector that contains the full state. + """ + return self._state.state_vector() + + def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray: + """Traces out all qubits except keep_qubits. + + Args: + keep_qubits: The set of qubits that are left after computing the + partial trace. For example, if we have a circuit for 3 qubits + and this parameter only has one qubit, the entire density matrix + would be 8x8, but this function returns a 2x2 matrix. + + Returns: + An array that contains the partial trace. + """ + return self._state.partial_trace(set(self.get_axes(list(keep_qubits)))) + + def to_numpy(self) -> np.ndarray: + """An alias for the state vector.""" + return self._state.to_numpy() + + @deprecated(deadline="v0.15", fix="Use cirq.act_on(op, mps_state)") + def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState): + """Applies a unitary operation, mutating the object to represent the new state. + + op: + The operation that mutates the object. Note that currently, only 1- + and 2- qubit operations are currently supported. + """ + return self._state.apply_op(op, self.get_axes(op.qubits), prng) + + def _act_on_fallback_( + self, + action: Union['cirq.Operation', 'cirq.Gate'], + qubits: Sequence['cirq.Qid'], + allow_decompose: bool = True, + ) -> bool: + """Delegates the action to self.apply_op""" + return self._state.apply_op(action, self.get_axes(qubits), self.prng) + + def estimation_stats(self): + """Returns some statistics about the memory usage and quality of the approximation.""" + return self._state.estimation_stats() + + @property + def M(self): + return self._state._M + + @deprecated(deadline="v0.15", fix="Use cirq.act_on(measurement, mps_state)") + def perform_measurement( + self, qubits: Sequence['cirq.Qid'], prng: np.random.RandomState, collapse_state_vector=True + ) -> List[int]: + """Performs a measurement over one or more qubits. + + Args: + qubits: The sequence of qids to measure, in that order. + prng: A random number generator, used to simulate measurements. + collapse_state_vector: A Boolean specifying whether we should mutate + the state after the measurement. + + Raises: + ValueError: If the probabilities for the measurements differ too much from one for the + tolerance specified in simulation options. + """ + return self._state._measure(self.get_axes(qubits), prng, collapse_state_vector) + + def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: + """Measures the axes specified by the simulator.""" + return self._state.measure(self.get_axes(qubits), self.prng) + + def sample( + self, + qubits: Sequence['cirq.Qid'], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + return self._state.sample(self.get_axes(qubits), repetitions, seed) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index aef78565ad2..e6386d216fb 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -569,3 +569,17 @@ def test_act_on_gate(): args.state_vector().reshape((2, 2, 2)), cirq.one_hot(index=(0, 1, 0), shape=(2, 2, 2), dtype=np.complex64), ) + + +def test_deprectated(): + q0 = cirq.LineQubit(0) + prng = np.random.RandomState(0) + args = ccq.mps_simulator.MPSState( + qubits=cirq.LineQubit.range(3), + prng=prng, + log_of_measurement_results={}, + ) + with cirq.testing.assert_deprecated(deadline='0.15'): + args.perform_measurement([q0], prng) + with cirq.testing.assert_deprecated(deadline='0.15'): + args.apply_op(cirq.X(q0), prng) diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 57374985c77..2e351d7b084 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -85,6 +85,7 @@ 'ActOnArgsContainer', 'ActOnCliffordTableauArgs', 'ActOnDensityMatrixArgs', + 'ActOnStabilizerArgs', 'ActOnStabilizerCHFormArgs', 'ActOnStateVectorArgs', 'ApplyChannelArgs', diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index b92a3695a17..015e373f8c5 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -13,7 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a density matrix.""" -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union import numpy as np @@ -24,6 +24,204 @@ if TYPE_CHECKING: import cirq + from numpy.typing import DTypeLike + + +class _BufferedDensityMatrix: + """Contains the density matrix and buffers for efficient state evolution.""" + + def __init__(self, density_matrix: np.ndarray, buffer: Optional[List[np.ndarray]] = None): + """Initializes the object with the inputs. + + This initializer creates the buffer if necessary. + + Args: + density_matrix: The density matrix, must be correctly formatted. The data is not + checked for validity here due to performance concerns. + buffer: Optional, must be length 3 and same shape as the density matrix. If not + provided, a buffer will be created automatically. + Raises: + ValueError: If the array is not the shape of a density matrix. + """ + self._density_matrix = density_matrix + if buffer is None: + buffer = [np.empty_like(density_matrix) for _ in range(3)] + self._buffer = buffer + if len(density_matrix.shape) % 2 != 0: + raise ValueError('The dimension of target_tensor is not divisible by 2.') + self._qid_shape = density_matrix.shape[: len(density_matrix.shape) // 2] + + @classmethod + def create( + cls, + *, + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, + qid_shape: Optional[Tuple[int, ...]] = None, + dtype: Optional['DTypeLike'] = None, + buffer: Optional[List[np.ndarray]] = None, + ): + """Creates a buffered density matrix with the requested state. + + Args: + initial_state: The initial state for the simulation in the computational basis. + qid_shape: The shape of the density matrix, if the initial state is provided as an int. + dtype: The desired dtype of the density matrix. + buffer: Optional, must be length 3 and same shape as the density matrix. If not + provided, a buffer will be created automatically. + Raises: + ValueError: If initial state is provided as integer, but qid_shape is not provided. + """ + if not isinstance(initial_state, np.ndarray): + if qid_shape is None: + raise ValueError('qid_shape must be provided if initial_state is not ndarray') + density_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape * 2) + else: + if qid_shape is not None: + density_matrix = initial_state.reshape(qid_shape * 2) + else: + density_matrix = initial_state + if np.may_share_memory(density_matrix, initial_state): + density_matrix = density_matrix.copy() + density_matrix = density_matrix.astype(dtype, copy=False) + return cls(density_matrix, buffer) + + def copy(self, deep_copy_buffers: bool = True) -> '_BufferedDensityMatrix': + """Copies the object. + + Args: + deep_copy_buffers: True by default, False to reuse the existing buffers. + Returns: + A copy of the object. + """ + return _BufferedDensityMatrix( + density_matrix=self._density_matrix.copy(), + buffer=[b.copy() for b in self._buffer] if deep_copy_buffers else self._buffer, + ) + + def kron(self, other: '_BufferedDensityMatrix') -> '_BufferedDensityMatrix': + """Creates the Kronecker product with the other density matrix. + + Args: + other: The density matrix with which to kron. + Returns: + The Kronecker product of the two density matrices. + """ + density_matrix = transformations.density_matrix_kronecker_product( + self._density_matrix, other._density_matrix + ) + return _BufferedDensityMatrix(density_matrix=density_matrix) + + def factor( + self, axes: Sequence[int], *, validate=True, atol=1e-07 + ) -> Tuple['_BufferedDensityMatrix', '_BufferedDensityMatrix']: + """Factors out the desired axes. + + Args: + axes: The axes to factor out. Only the left axes should be provided. For example, to + extract [C,A] from density matrix of shape [A,B,C,D,A,B,C,D], `axes` should be + [2,0], and the return value will be two density matrices ([C,A,C,A], [B,D,B,D]). + validate: Perform a validation that the density matrix factors cleanly. + atol: The absolute tolerance for the validation. + Returns: + A tuple with the `(extracted, remainder)` density matrices, where `extracted` means + the sub-matrix which corresponds to the axes requested, and with the axes in the + requested order, and where `remainder` means the sub-matrix on the remaining axes, + in the same order as the original density matrix. + """ + extracted_tensor, remainder_tensor = transformations.factor_density_matrix( + self._density_matrix, axes, validate=validate, atol=atol + ) + extracted = _BufferedDensityMatrix(density_matrix=extracted_tensor) + remainder = _BufferedDensityMatrix(density_matrix=remainder_tensor) + return extracted, remainder + + def reindex(self, axes: Sequence[int]) -> '_BufferedDensityMatrix': + """Transposes the axes of a density matrix to a specified order. + + Args: + axes: The desired axis order. Only the left axes should be provided. For example, to + transpose [A,B,C,A,B,C] to [C,B,A,C,B,A], `axes` should be [2,1,0]. + Returns: + The transposed density matrix. + """ + new_tensor = transformations.transpose_density_matrix_to_axis_order( + self._density_matrix, axes + ) + return _BufferedDensityMatrix(density_matrix=new_tensor) + + def apply_channel(self, action: Any, axes: Sequence[int]) -> bool: + """Apply channel to state. + + Args: + action: The value with a channel to apply. + axes: The axes on which to apply the channel. + Returns: + True if the action succeeded. + """ + result = protocols.apply_channel( + action, + args=protocols.ApplyChannelArgs( + target_tensor=self._density_matrix, + out_buffer=self._buffer[0], + auxiliary_buffer0=self._buffer[1], + auxiliary_buffer1=self._buffer[2], + left_axes=axes, + right_axes=[e + len(self._qid_shape) for e in axes], + ), + default=None, + ) + if result is None: + return False + for i in range(len(self._buffer)): + if result is self._buffer[i]: + self._buffer[i] = self._density_matrix + self._density_matrix = result + return True + + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the density matrix. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in order. + """ + bits, _ = sim.measure_density_matrix( + self._density_matrix, + axes, + out=self._density_matrix, + qid_shape=self._qid_shape, + seed=seed, + ) + return bits + + def sample( + self, + axes: Sequence[int], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + """Samples the density matrix. + + Args: + axes: The axes to sample. + repetitions: The number of samples to make. + seed: The random number seed to use. + Returns: + The samples in order. + """ + return sim.sample_density_matrix( + self._density_matrix, + axes, + qid_shape=self._qid_shape, + repetitions=repetitions, + seed=seed, + ) class ActOnDensityMatrixArgs(ActOnArgs): @@ -113,29 +311,12 @@ def __init__( log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - if target_tensor is None: - qubits_qid_shape = protocols.qid_shape(self.qubits) - initial_matrix = qis.to_valid_density_matrix( - initial_state, len(qubits_qid_shape), qid_shape=qubits_qid_shape, dtype=dtype - ) - if np.may_share_memory(initial_matrix, initial_state): - initial_matrix = initial_matrix.copy() - target_tensor = initial_matrix.reshape(qubits_qid_shape * 2) - self.target_tensor = target_tensor - - if available_buffer is None: - available_buffer = [np.empty_like(target_tensor) for _ in range(3)] - self.available_buffer = available_buffer - - if qid_shape is None: - target_shape = target_tensor.shape - if len(target_shape) % 2 != 0: - raise ValueError( - 'The dimension of target_tensor is not divisible by 2.' - ' Require explicit qid_shape.' - ) - qid_shape = target_shape[: len(target_shape) // 2] - self.qid_shape = qid_shape + self._state = _BufferedDensityMatrix.create( + initial_state=target_tensor if target_tensor is not None else initial_state, + qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, + dtype=dtype, + buffer=available_buffer, + ) def _act_on_fallback_( self, @@ -143,11 +324,11 @@ def _act_on_fallback_( qubits: Sequence['cirq.Qid'], allow_decompose: bool = True, ) -> bool: - strats = [ + strats: List[Callable[[Any, Any, Sequence['cirq.Qid']], bool]] = [ _strat_apply_channel_to_state, ] if allow_decompose: - strats.append(strat_act_on_from_apply_decompose) # type: ignore + strats.append(strat_act_on_from_apply_decompose) # Try each strategy, stopping if one works. for strat in strats: @@ -165,33 +346,15 @@ def _act_on_fallback_( def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: """Delegates the call to measure the density matrix.""" - bits, _ = sim.measure_density_matrix( - self.target_tensor, - self.get_axes(qubits), - out=self.target_tensor, - qid_shape=self.qid_shape, - seed=self.prng, - ) - return bits + return self._state.measure(self.get_axes(qubits), self.prng) def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs', deep_copy_buffers: bool = True): - target.target_tensor = self.target_tensor.copy() - if deep_copy_buffers: - target.available_buffer = [b.copy() for b in self.available_buffer] - else: - target.available_buffer = self.available_buffer + target._state = self._state.copy(deep_copy_buffers) def _on_kronecker_product( self, other: 'cirq.ActOnDensityMatrixArgs', target: 'cirq.ActOnDensityMatrixArgs' ): - target_tensor = transformations.density_matrix_kronecker_product( - self.target_tensor, other.target_tensor - ) - target.target_tensor = target_tensor - target.available_buffer = [ - np.empty_like(target_tensor) for _ in range(len(self.available_buffer)) - ] - target.qid_shape = target_tensor.shape[: int(target_tensor.ndim / 2)] + target._state = self._state.kron(other._state) def _on_factor( self, @@ -202,19 +365,7 @@ def _on_factor( atol=1e-07, ): axes = self.get_axes(qubits) - extracted_tensor, remainder_tensor = transformations.factor_density_matrix( - self.target_tensor, axes, validate=validate, atol=atol - ) - extracted.target_tensor = extracted_tensor - extracted.available_buffer = [ - np.empty_like(extracted_tensor) for _ in self.available_buffer - ] - extracted.qid_shape = extracted_tensor.shape[: int(extracted_tensor.ndim / 2)] - remainder.target_tensor = remainder_tensor - remainder.available_buffer = [ - np.empty_like(remainder_tensor) for _ in self.available_buffer - ] - remainder.qid_shape = remainder_tensor.shape[: int(remainder_tensor.ndim / 2)] + extracted._state, remainder._state = self._state.factor(axes, validate=validate, atol=atol) @property def allows_factoring(self): @@ -223,14 +374,7 @@ def allows_factoring(self): def _on_transpose_to_qubit_order( self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnDensityMatrixArgs' ): - axes = self.get_axes(qubits) - new_tensor = transformations.transpose_density_matrix_to_axis_order( - self.target_tensor, axes - ) - buffer = [np.empty_like(new_tensor) for _ in self.available_buffer] - target.target_tensor = new_tensor - target.available_buffer = buffer - target.qid_shape = new_tensor.shape[: int(new_tensor.ndim / 2)] + target._state = self._state.reindex(self.get_axes(qubits)) def sample( self, @@ -238,14 +382,7 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - indices = [self.qubit_map[q] for q in qubits] - return sim.sample_density_matrix( - self.target_tensor, - indices, - qid_shape=tuple(q.dimension for q in self.qubits), - repetitions=repetitions, - seed=seed, - ) + return self._state.sample(self.get_axes(qubits), repetitions, seed) @property def can_represent_mixed_states(self) -> bool: @@ -261,28 +398,21 @@ def __repr__(self) -> str: f' log_of_measurement_results={proper_repr(self.log_of_measurement_results)})' ) + @property + def target_tensor(self): + return self._state._density_matrix + + @property + def available_buffer(self): + return self._state._buffer + + @property + def qid_shape(self): + return self._state._qid_shape + def _strat_apply_channel_to_state( action: Any, args: 'cirq.ActOnDensityMatrixArgs', qubits: Sequence['cirq.Qid'] ) -> bool: """Apply channel to state.""" - axes = args.get_axes(qubits) - result = protocols.apply_channel( - action, - args=protocols.ApplyChannelArgs( - target_tensor=args.target_tensor, - out_buffer=args.available_buffer[0], - auxiliary_buffer0=args.available_buffer[1], - auxiliary_buffer1=args.available_buffer[2], - left_axes=axes, - right_axes=[e + len(args.qubits) for e in axes], - ), - default=None, - ) - if result is None: - return NotImplemented - for i in range(len(args.available_buffer)): - if result is args.available_buffer[i]: - args.available_buffer[i] = args.target_tensor - args.target_tensor = result - return True + return True if args._state.apply_channel(action, args.get_axes(qubits)) else NotImplemented diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py index a9604adbfa8..82fd6dba366 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py @@ -111,3 +111,8 @@ def test_with_qubits(): np.array([[1, 0], [0, 0]], dtype=np.complex64), ), ) + + +def test_qid_shape_error(): + with pytest.raises(ValueError, match="qid_shape must be provided"): + cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0) diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index f3f803945ad..7d9aeeb14c4 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -13,7 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a state vector.""" -from typing import Any, Optional, Tuple, TYPE_CHECKING, Type, Union, Dict, List, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union import numpy as np @@ -24,6 +24,302 @@ if TYPE_CHECKING: import cirq + from numpy.typing import DTypeLike + + +class _BufferedStateVector: + """Contains the state vector and buffer for efficient state evolution.""" + + def __init__(self, state_vector: np.ndarray, buffer: Optional[np.ndarray] = None): + """Initializes the object with the inputs. + + This initializer creates the buffer if necessary. + + Args: + state_vector: The state vector, must be correctly formatted. The data is not checked + for validity here due to performance concerns. + buffer: Optional, must be same shape as the state vector. If not provided, a buffer + will be created automatically. + """ + self._state_vector = state_vector + if buffer is None: + buffer = np.empty_like(state_vector) + self._buffer = buffer + self._qid_shape = state_vector.shape + + @classmethod + def create( + cls, + *, + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, + qid_shape: Optional[Tuple[int, ...]] = None, + dtype: Optional['DTypeLike'] = None, + buffer: Optional[List[np.ndarray]] = None, + ): + """Initializes the object with the inputs. + + This initializer creates the buffer if necessary. + + Args: + initial_state: The density matrix, must be correctly formatted. The data is not + checked for validity here due to performance concerns. + qid_shape: The shape of the density matrix, if the initial state is provided as an int. + dtype: The dtype of the density matrix, if the initial state is provided as an int. + buffer: Optional, must be length 3 and same shape as the density matrix. If not + provided, a buffer will be created automatically. + Raises: + ValueError: If initial state is provided as integer, but qid_shape is not provided. + """ + if not isinstance(initial_state, np.ndarray): + if qid_shape is None: + raise ValueError('qid_shape must be provided if initial_state is not ndarray') + state_vector = qis.to_valid_state_vector( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape) + else: + if qid_shape is not None: + state_vector = initial_state.reshape(qid_shape) + else: + state_vector = initial_state + if np.may_share_memory(state_vector, initial_state): + state_vector = state_vector.copy() + state_vector = state_vector.astype(dtype, copy=False) + return cls(state_vector, buffer) + + def copy(self, deep_copy_buffers: bool = True) -> '_BufferedStateVector': + """Copies the object. + + Args: + deep_copy_buffers: True by default, False to reuse the existing buffers. + Returns: + A copy of the object. + """ + return _BufferedStateVector( + state_vector=self._state_vector.copy(), + buffer=self._buffer.copy() if deep_copy_buffers else self._buffer, + ) + + def kron(self, other: '_BufferedStateVector') -> '_BufferedStateVector': + """Creates the Kronecker product with the other state vector. + + Args: + other: The state vector with which to kron. + Returns: + The Kronecker product of the two state vectors. + """ + target_tensor = transformations.state_vector_kronecker_product( + self._state_vector, other._state_vector + ) + return _BufferedStateVector( + state_vector=target_tensor, + buffer=np.empty_like(target_tensor), + ) + + def factor( + self, axes: Sequence[int], *, validate=True, atol=1e-07 + ) -> Tuple['_BufferedStateVector', '_BufferedStateVector']: + """Factors a state vector into two independent state vectors. + + This function should only be called on state vectors that are known to be separable, such + as immediately after a measurement or reset operation. It does not verify that the provided + state vector is indeed separable, and will return nonsense results for vectors + representing entangled states. + + Args: + axes: The axes to factor out. + validate: Perform a validation that the state vector factors cleanly. + atol: The absolute tolerance for the validation. + + Returns: + A tuple with the `(extracted, remainder)` state vectors, where `extracted` means the + sub-state vector which corresponds to the axes requested, and with the axes in the + requested order, and where `remainder` means the sub-state vector on the remaining + axes, in the same order as the original state vector. + """ + extracted_tensor, remainder_tensor = transformations.factor_state_vector( + self._state_vector, axes, validate=validate, atol=atol + ) + extracted = _BufferedStateVector( + state_vector=extracted_tensor, + buffer=np.empty_like(extracted_tensor), + ) + remainder = _BufferedStateVector( + state_vector=remainder_tensor, + buffer=np.empty_like(remainder_tensor), + ) + return extracted, remainder + + def reindex(self, axes: Sequence[int]) -> '_BufferedStateVector': + """Transposes the axes of a state vector to a specified order. + + Args: + axes: The desired axis order. + Returns: + The transposed state vector. + """ + new_tensor = transformations.transpose_state_vector_to_axis_order(self._state_vector, axes) + return _BufferedStateVector( + state_vector=new_tensor, + buffer=np.empty_like(new_tensor), + ) + + def apply_unitary(self, action: Any, axes: Sequence[int]) -> bool: + """Apply unitary to state. + + Args: + action: The value with a unitary to apply. + axes: The axes on which to apply the unitary. + Returns: + True if the operation succeeded. + """ + new_target_tensor = protocols.apply_unitary( + action, + protocols.ApplyUnitaryArgs( + target_tensor=self._state_vector, + available_buffer=self._buffer, + axes=axes, + ), + allow_decompose=False, + default=NotImplemented, + ) + if new_target_tensor is NotImplemented: + return False + self._swap_target_tensor_for(new_target_tensor) + return True + + def apply_mixture(self, action: Any, axes: Sequence[int], prng) -> Optional[int]: + """Apply mixture to state. + + Args: + action: The value with a mixture to apply. + axes: The axes on which to apply the mixture. + prng: The pseudo random number generator to use. + Returns: + The mixture index if the operation succeeded, otherwise None. + """ + mixture = protocols.mixture(action, default=None) + if mixture is None: + return None + probabilities, unitaries = zip(*mixture) + + index = prng.choice(range(len(unitaries)), p=probabilities) + shape = protocols.qid_shape(action) * 2 + unitary = unitaries[index].astype(self._state_vector.dtype).reshape(shape) + linalg.targeted_left_multiply(unitary, self._state_vector, axes, out=self._buffer) + self._swap_target_tensor_for(self._buffer) + return index + + def apply_channel(self, action: Any, axes: Sequence[int], prng) -> Optional[int]: + """Apply channel to state. + + Args: + action: The value with a channel to apply. + axes: The axes on which to apply the channel. + prng: The pseudo random number generator to use. + Returns: + The kraus index if the operation succeeded, otherwise None. + """ + kraus_operators = protocols.kraus(action, default=None) + if kraus_operators is None: + return None + + def prepare_into_buffer(k: int): + linalg.targeted_left_multiply( + left_matrix=kraus_tensors[k], + right_target=self._state_vector, + target_axes=axes, + out=self._buffer, + ) + + shape = protocols.qid_shape(action) + kraus_tensors = [ + e.reshape(shape * 2).astype(self._state_vector.dtype) for e in kraus_operators + ] + p = prng.random() + weight = None + fallback_weight = 0 + fallback_weight_index = 0 + index = None + for index in range(len(kraus_tensors)): + prepare_into_buffer(index) + weight = np.linalg.norm(self._buffer) ** 2 + + if weight > fallback_weight: + fallback_weight_index = index + fallback_weight = weight + + p -= weight + if p < 0: + break + + assert weight is not None, "No Kraus operators" + if p >= 0 or weight == 0: + # Floating point error resulted in a malformed sample. + # Fall back to the most likely case. + prepare_into_buffer(fallback_weight_index) + weight = fallback_weight + index = fallback_weight_index + + self._buffer /= np.sqrt(weight) + self._swap_target_tensor_for(self._buffer) + return index + + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the state vector. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in order. + """ + bits, _ = sim.measure_state_vector( + self._state_vector, + axes, + out=self._state_vector, + qid_shape=self._qid_shape, + seed=seed, + ) + return bits + + def sample( + self, + axes: Sequence[int], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + """Samples the state vector. + + Args: + axes: The axes to sample. + repetitions: The number of samples to make. + seed: The random number seed to use. + Returns: + The samples in order. + """ + return sim.sample_state_vector( + self._state_vector, + axes, + qid_shape=self._qid_shape, + repetitions=repetitions, + seed=seed, + ) + + def _swap_target_tensor_for(self, new_target_tensor: np.ndarray): + """Gives a new state vector for the system. + + Typically, the new state vector should be `args.available_buffer` where + `args` is this `cirq.ActOnStateVectorArgs` instance. + + Args: + new_target_tensor: The new system state. Must have the same shape + and dtype as the old system state. + """ + if new_target_tensor is self._buffer: + self._buffer = self._state_vector + self._state_vector = new_target_tensor class ActOnStateVectorArgs(ActOnArgs): @@ -92,18 +388,17 @@ def __init__( log_of_measurement_results=log_of_measurement_results, classical_data=classical_data, ) - if target_tensor is None: - qid_shape = protocols.qid_shape(self.qubits) - state = qis.to_valid_state_vector( - initial_state, len(self.qubits), qid_shape=qid_shape, dtype=dtype - ) - target_tensor = np.reshape(state, qid_shape) - self.target_tensor = target_tensor - - if available_buffer is None: - available_buffer = np.empty_like(target_tensor) - self.available_buffer = available_buffer + self._state = _BufferedStateVector.create( + initial_state=target_tensor if target_tensor is not None else initial_state, + qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, + dtype=dtype, + buffer=available_buffer, + ) + @_compat.deprecated( + deadline='v0.16', + fix='None, this function was unintentionally made public.', + ) def swap_target_tensor_for(self, new_target_tensor: np.ndarray): """Gives a new state vector for the system. @@ -114,10 +409,12 @@ def swap_target_tensor_for(self, new_target_tensor: np.ndarray): new_target_tensor: The new system state. Must have the same shape and dtype as the old system state. """ - if new_target_tensor is self.available_buffer: - self.available_buffer = self.target_tensor - self.target_tensor = new_target_tensor + self._state._swap_target_tensor_for(new_target_tensor) + @_compat.deprecated( + deadline='v0.16', + fix='None, this function was unintentionally made public.', + ) def subspace_index( self, axes: Sequence[int], little_endian_bits_int: int = 0, *, big_endian_bits_int: int = 0 ) -> Tuple[Union[slice, int, 'ellipsis'], ...]: @@ -177,7 +474,7 @@ def _act_on_fallback_( qubits: Sequence['cirq.Qid'], allow_decompose: bool = True, ) -> bool: - strats = [ + strats: List[Callable[[Any, Any, Sequence['cirq.Qid']], bool]] = [ _strat_act_on_state_vector_from_apply_unitary, _strat_act_on_state_vector_from_mixture, _strat_act_on_state_vector_from_channel, @@ -201,30 +498,15 @@ def _act_on_fallback_( def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: """Delegates the call to measure the state vector.""" - bits, _ = sim.measure_state_vector( - self.target_tensor, - self.get_axes(qubits), - out=self.target_tensor, - qid_shape=self.target_tensor.shape, - seed=self.prng, - ) - return bits + return self._state.measure(self.get_axes(qubits), self.prng) def _on_copy(self, target: 'cirq.ActOnStateVectorArgs', deep_copy_buffers: bool = True): - target.target_tensor = self.target_tensor.copy() - if deep_copy_buffers: - target.available_buffer = self.available_buffer.copy() - else: - target.available_buffer = self.available_buffer + target._state = self._state.copy(deep_copy_buffers) def _on_kronecker_product( self, other: 'cirq.ActOnStateVectorArgs', target: 'cirq.ActOnStateVectorArgs' ): - target_tensor = transformations.state_vector_kronecker_product( - self.target_tensor, other.target_tensor - ) - target.target_tensor = target_tensor - target.available_buffer = np.empty_like(target_tensor) + target._state = self._state.kron(other._state) def _on_factor( self, @@ -235,13 +517,7 @@ def _on_factor( atol=1e-07, ): axes = self.get_axes(qubits) - extracted_tensor, remainder_tensor = transformations.factor_state_vector( - self.target_tensor, axes, validate=validate, atol=atol - ) - extracted.target_tensor = extracted_tensor - extracted.available_buffer = np.empty_like(extracted_tensor) - remainder.target_tensor = remainder_tensor - remainder.available_buffer = np.empty_like(remainder_tensor) + extracted._state, remainder._state = self._state.factor(axes, validate=validate, atol=atol) @property def allows_factoring(self): @@ -250,10 +526,7 @@ def allows_factoring(self): def _on_transpose_to_qubit_order( self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnStateVectorArgs' ): - axes = self.get_axes(qubits) - new_tensor = transformations.transpose_state_vector_to_axis_order(self.target_tensor, axes) - target.target_tensor = new_tensor - target.available_buffer = np.empty_like(new_tensor) + target._state = self._state.reindex(self.get_axes(qubits)) def sample( self, @@ -261,14 +534,7 @@ def sample( repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: - indices = [self.qubit_map[q] for q in qubits] - return sim.sample_state_vector( - self.target_tensor, - indices, - qid_shape=tuple(q.dimension for q in self.qubits), - repetitions=repetitions, - seed=seed, - ) + return self._state.sample(self.get_axes(qubits), repetitions, seed) def __repr__(self) -> str: return ( @@ -279,43 +545,27 @@ def __repr__(self) -> str: f' log_of_measurement_results={proper_repr(self.log_of_measurement_results)})' ) + @property + def target_tensor(self): + return self._state._state_vector + + @property + def available_buffer(self): + return self._state._buffer + def _strat_act_on_state_vector_from_apply_unitary( - unitary_value: Any, - args: 'cirq.ActOnStateVectorArgs', - qubits: Sequence['cirq.Qid'], + action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] ) -> bool: - new_target_tensor = protocols.apply_unitary( - unitary_value, - protocols.ApplyUnitaryArgs( - target_tensor=args.target_tensor, - available_buffer=args.available_buffer, - axes=args.get_axes(qubits), - ), - allow_decompose=False, - default=NotImplemented, - ) - if new_target_tensor is NotImplemented: - return NotImplemented - args.swap_target_tensor_for(new_target_tensor) - return True + return True if args._state.apply_unitary(action, args.get_axes(qubits)) else NotImplemented def _strat_act_on_state_vector_from_mixture( action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] ) -> bool: - mixture = protocols.mixture(action, default=None) - if mixture is None: + index = args._state.apply_mixture(action, args.get_axes(qubits), args.prng) + if index is None: return NotImplemented - probabilities, unitaries = zip(*mixture) - - index = args.prng.choice(range(len(unitaries)), p=probabilities) - shape = protocols.qid_shape(action) * 2 - unitary = unitaries[index].astype(args.target_tensor.dtype).reshape(shape) - linalg.targeted_left_multiply( - unitary, args.target_tensor, args.get_axes(qubits), out=args.available_buffer - ) - args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) args._classical_data.record_channel_measurement(key, index) @@ -325,46 +575,9 @@ def _strat_act_on_state_vector_from_mixture( def _strat_act_on_state_vector_from_channel( action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] ) -> bool: - kraus_operators = protocols.kraus(action, default=None) - if kraus_operators is None: + index = args._state.apply_channel(action, args.get_axes(qubits), args.prng) + if index is None: return NotImplemented - - def prepare_into_buffer(k: int): - linalg.targeted_left_multiply( - left_matrix=kraus_tensors[k], - right_target=args.target_tensor, - target_axes=args.get_axes(qubits), - out=args.available_buffer, - ) - - shape = protocols.qid_shape(action) - kraus_tensors = [e.reshape(shape * 2).astype(args.target_tensor.dtype) for e in kraus_operators] - p = args.prng.random() - weight = None - fallback_weight = 0 - fallback_weight_index = 0 - for index in range(len(kraus_tensors)): - prepare_into_buffer(index) - weight = np.linalg.norm(args.available_buffer) ** 2 - - if weight > fallback_weight: - fallback_weight_index = index - fallback_weight = weight - - p -= weight - if p < 0: - break - - assert weight is not None, "No Kraus operators" - if p >= 0 or weight == 0: - # Floating point error resulted in a malformed sample. - # Fall back to the most likely case. - prepare_into_buffer(fallback_weight_index) - weight = fallback_weight - index = fallback_weight_index - - args.available_buffer /= np.sqrt(weight) - args.swap_target_tensor_for(args.available_buffer) if protocols.is_measurement(action): key = protocols.measurement_key_name(action) args._classical_data.record_channel_measurement(key, index) diff --git a/cirq-core/cirq/sim/act_on_state_vector_args_test.py b/cirq-core/cirq/sim/act_on_state_vector_args_test.py index 9efa9181ec7..46078db300e 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args_test.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args_test.py @@ -75,7 +75,10 @@ def test_infer_target_tensor(): def test_shallow_copy_buffers(): - args = cirq.ActOnStateVectorArgs() + args = cirq.ActOnStateVectorArgs( + qubits=cirq.LineQubit.range(1), + initial_state=0, + ) copy = args.copy(deep_copy_buffers=False) assert copy.available_buffer is args.available_buffer @@ -328,3 +331,16 @@ def test_with_qubits(): np.array([[1.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 0.0 + 0.0j]], dtype=np.complex64), ), ) + + +def test_qid_shape_error(): + with pytest.raises(ValueError, match="qid_shape must be provided"): + cirq.sim.act_on_state_vector_args._BufferedStateVector.create(initial_state=0) + + +def test_deprecated_methods(): + args = cirq.ActOnStateVectorArgs(qubits=[cirq.LineQubit(0)]) + with cirq.testing.assert_deprecated('unintentionally made public', deadline='v0.16'): + args.subspace_index([0], 0) + with cirq.testing.assert_deprecated('unintentionally made public', deadline='v0.16'): + args.swap_target_tensor_for(np.array([])) diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index d448caca883..7364cbcd352 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -163,7 +163,7 @@ def bloch_vector_of(self, qubit: 'cirq.Qid') -> np.ndarray: def sample_state_vector( state_vector: np.ndarray, - indices: List[int], + indices: Sequence[int], *, # Force keyword args qid_shape: Optional[Tuple[int, ...]] = None, repetitions: int = 1,