From 04814e1b9c32ccb09aa437496ab34488fc701259 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Tue, 9 May 2023 20:00:58 +0100 Subject: [PATCH 1/5] Add support for > 32 qudits to cirq.sample_state_vector. Fix for #6031 --- cirq-core/cirq/linalg/__init__.py | 2 + cirq-core/cirq/linalg/transformations.py | 64 +++++++++++++++++++ cirq-core/cirq/linalg/transformations_test.py | 16 +++++ cirq-core/cirq/sim/state_vector.py | 41 ++++++------ cirq-core/cirq/sim/state_vector_test.py | 61 +++++++++++++----- 5 files changed, 146 insertions(+), 38 deletions(-) diff --git a/cirq-core/cirq/linalg/__init__.py b/cirq-core/cirq/linalg/__init__.py index 62d21593551..e0181859837 100644 --- a/cirq-core/cirq/linalg/__init__.py +++ b/cirq-core/cirq/linalg/__init__.py @@ -81,4 +81,6 @@ targeted_conjugate_about, targeted_left_multiply, to_special, + transpose_flattened_array, + can_numpy_support_shape, ) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 6a04ce1dda0..ba02322c62f 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -29,6 +29,8 @@ # user provides a different np.array([]) value. RaiseValueErrorIfNotProvided: np.ndarray = np.array([]) +_NPY_MAXDIMS = 32 # Should be changed once numpy/numpy#5744 is resolved. + def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float): """Raises a matrix with two opposing eigenvalues to a power. @@ -746,3 +748,65 @@ def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]): """ axes = list(axes) + [i + len(axes) for i in axes] return transpose_state_vector_to_axis_order(t, axes) + + +def _volumes(shape: Sequence[int]) -> List[int]: + r"""Returns a list of the volume spanned by each dimension. + + Given a shape=[d_0, d_1, .., d_n] the volume spanned by each dimension is + volume[i] = `\prod_{j=i+1}^n d_j` + + Args: + shape: Sequence of the size of each dimension. + + Returns: + Sequence of the volume spanned of each dimension. + """ + volume = [0] * len(shape) + v = 1 + for i in reversed(range(len(shape))): + volume[i] = v + v *= shape[i] + return volume + + +def _coordinates_from_index(idx: int, volume: Sequence[int]) -> Sequence[int]: + ret = [] + for v in volume: + ret.append(idx // v) + idx %= v + return tuple(ret) + + +def _index_from_coordinates(s: Sequence[int], volume: Sequence[int]) -> int: + return np.dot(s, volume) + + +def transpose_flattened_array(t: np.ndarray, shape: Sequence[int], axes: Sequence[int]): + """Transposes a flattened array. + + Equivalent to np.transpose(t.reshape(shape), axes).reshape((-1,)). + + Args: + t: flat array. + shape: the shape of `t` before flattening. + axes: permutation of range(len(shape)). + + Returns: + Flattened transpose of `t`. + """ + if len(t.shape) != 1: + t = t.reshape((-1,)) + cur_volume = _volumes(shape) + new_volume = _volumes([shape[i] for i in axes]) + ret = np.zeros_like(t) + for idx in range(t.shape[0]): + cell = _coordinates_from_index(idx, cur_volume) + new_cell = [cell[i] for i in axes] + ret[_index_from_coordinates(new_cell, new_volume)] = t[idx] + return ret + + +def can_numpy_support_shape(shape: Sequence[int]) -> bool: + """Returns whether numpy supports the given shape or not numpy/numpy#5744.""" + return len(shape) <= _NPY_MAXDIMS diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index b92ff5faa13..8950db78bc6 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -17,6 +17,7 @@ import cirq import cirq.testing +from cirq import linalg def test_reflection_matrix_pow_consistent_results(): @@ -632,3 +633,18 @@ def test_factor_state_vector(state_1: int, state_2: int): # All phase goes into a1, and b1 is just the dephased state vector assert np.allclose(a1, a * phase) assert np.allclose(b1, b) + + +@pytest.mark.parametrize('num_dimensions', [*range(1, 7)]) +def test_transpose_flattened_array(num_dimensions): + np.random.seed(0) + for _ in range(10): + shape = np.random.randint(1, 5, (num_dimensions,)).tolist() + axes = np.random.permutation(num_dimensions).tolist() + volume = np.prod(shape) + A = np.random.permutation(volume) + want = np.transpose(A.reshape(shape), axes) + got = linalg.transpose_flattened_array(A, shape, axes).reshape(want.shape) + assert np.array_equal(want, got) + got = linalg.transpose_flattened_array(A.reshape(shape), shape, axes).reshape(want.shape) + assert np.array_equal(want, got) diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index 2beea5539f5..c204192a9c0 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -20,6 +20,7 @@ from cirq import linalg, qis, value from cirq.sim import simulator +from cirq import linalg if TYPE_CHECKING: import cirq @@ -325,30 +326,24 @@ def measure_state_vector( def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: """Returns the probabilities for a measurement on the given indices.""" - tensor = np.reshape(state, qid_shape) - # Calculate the probabilities for measuring the particular results. - if len(indices) == len(qid_shape): - # We're measuring every qudit, so no need for fancy indexing - probs = np.abs(tensor) ** 2 - probs = np.transpose(probs, indices) - probs = probs.reshape(-1) + state = state.reshape((-1,)) + probs = np.abs(state) ** 2 + not_measured = [i for i in range(len(qid_shape)) if i not in indices] + if linalg.can_numpy_support_shape(qid_shape): + # Use numpy transpose if we can since it's more efficient. + probs = probs.reshape(qid_shape) + probs = np.transpose(probs, list(indices) + not_measured) + probs = probs.reshape((-1,)) else: - # Fancy indexing required - meas_shape = tuple(qid_shape[i] for i in indices) - probs = ( - np.abs( - [ - tensor[ - linalg.slice_for_qubits_equal_to( - indices, big_endian_qureg_value=b, qid_shape=qid_shape - ) - ] - for b in range(np.prod(meas_shape, dtype=np.int64)) - ] - ) - ** 2 - ) - probs = np.sum(probs, axis=tuple(range(1, len(probs.shape)))) + # If we can't use numpy due to numpy/numpy#5744, use a slower method. + probs = linalg.transpose_flattened_array(probs, qid_shape, list(indices) + not_measured) + + if len(not_measured): + # Not all qudits are measured. + volume = np.prod([qid_shape[i] for i in indices]) + # Reshape into a 2D array in which each of the measured states correspond to a row. + probs = probs.reshape((volume, -1)) + probs = np.sum(probs, axis=-1) # To deal with rounding issues, ensure that the probabilities sum to 1. return probs / np.sum(probs) diff --git a/cirq-core/cirq/sim/state_vector_test.py b/cirq-core/cirq/sim/state_vector_test.py index 4f5f2ea342c..206c2220fb1 100644 --- a/cirq-core/cirq/sim/state_vector_test.py +++ b/cirq-core/cirq/sim/state_vector_test.py @@ -21,6 +21,7 @@ import cirq import cirq.testing +from cirq import linalg def test_state_mixin(): @@ -172,7 +173,9 @@ def test_sample_no_indices_repetitions(): ) -def test_measure_state_computational_basis(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_computational_basis(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose results = [] for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -183,7 +186,9 @@ def test_measure_state_computational_basis(): assert results == expected -def test_measure_state_reshape(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_reshape(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose results = [] for x in range(8): initial_state = np.reshape(cirq.to_valid_state_vector(x, 3), [2] * 3) @@ -194,7 +199,9 @@ def test_measure_state_reshape(): assert results == expected -def test_measure_state_partial_indices(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_partial_indices(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose for index in range(3): for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -203,7 +210,9 @@ def test_measure_state_partial_indices(): assert bits == [bool(1 & (x >> (2 - index)))] -def test_measure_state_partial_indices_order(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_partial_indices_order(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) bits, state = cirq.measure_state_vector(initial_state, [2, 1]) @@ -211,7 +220,9 @@ def test_measure_state_partial_indices_order(): assert bits == [bool(1 & (x >> 0)), bool(1 & (x >> 1))] -def test_measure_state_partial_indices_all_orders(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_partial_indices_all_orders(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose for perm in itertools.permutations([0, 1, 2]): for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -220,7 +231,9 @@ def test_measure_state_partial_indices_all_orders(): assert bits == [bool(1 & (x >> (2 - p))) for p in perm] -def test_measure_state_collapse(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_collapse(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -243,7 +256,9 @@ def test_measure_state_collapse(): assert bits == [False] -def test_measure_state_seed(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_seed(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose n = 10 initial_state = np.ones(2**n) / 2 ** (n / 2) @@ -262,7 +277,9 @@ def test_measure_state_seed(): np.testing.assert_allclose(state1, state2) -def test_measure_state_out_is_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_out_is_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -273,7 +290,9 @@ def test_measure_state_out_is_state(): assert state is initial_state -def test_measure_state_out_is_not_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_out_is_not_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -283,14 +302,18 @@ def test_measure_state_out_is_not_state(): assert out is state -def test_measure_state_not_power_of_two(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_not_power_of_two(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose with pytest.raises(ValueError, match='3'): _, _ = cirq.measure_state_vector(np.array([1, 0, 0]), [1]) with pytest.raises(ValueError, match='5'): cirq.measure_state_vector(np.array([0, 1, 0, 0, 0]), [1]) -def test_measure_state_index_out_of_range(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_index_out_of_range(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose state = cirq.to_valid_state_vector(0, 3) with pytest.raises(IndexError, match='-2'): cirq.measure_state_vector(state, [-2]) @@ -298,14 +321,18 @@ def test_measure_state_index_out_of_range(): cirq.measure_state_vector(state, [3]) -def test_measure_state_no_indices(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_no_indices(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = cirq.to_valid_state_vector(0, 3) bits, state = cirq.measure_state_vector(initial_state, []) assert [] == bits np.testing.assert_almost_equal(state, initial_state) -def test_measure_state_no_indices_out_is_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_no_indices_out_is_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = cirq.to_valid_state_vector(0, 3) bits, state = cirq.measure_state_vector(initial_state, [], out=initial_state) assert [] == bits @@ -313,7 +340,9 @@ def test_measure_state_no_indices_out_is_state(): assert state is initial_state -def test_measure_state_no_indices_out_is_not_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = cirq.to_valid_state_vector(0, 3) out = np.zeros_like(initial_state) bits, state = cirq.measure_state_vector(initial_state, [], out=out) @@ -323,7 +352,9 @@ def test_measure_state_no_indices_out_is_not_state(): assert out is not initial_state -def test_measure_state_empty_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_empty_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.array([1.0]) bits, state = cirq.measure_state_vector(initial_state, []) assert [] == bits From 655adee28ab5c3abc0cc7ae8702f9ddc625d39ee Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 12 May 2023 14:15:23 +0100 Subject: [PATCH 2/5] refactor the method into a seprate util module --- cirq-core/cirq/sim/density_matrix_utils.py | 28 +------------- cirq-core/cirq/sim/simulation_util.py | 45 ++++++++++++++++++++++ cirq-core/cirq/sim/state_vector.py | 32 ++------------- 3 files changed, 50 insertions(+), 55 deletions(-) create mode 100644 cirq-core/cirq/sim/simulation_util.py diff --git a/cirq-core/cirq/sim/density_matrix_utils.py b/cirq-core/cirq/sim/density_matrix_utils.py index a6235c70dba..f405b0fe55b 100644 --- a/cirq-core/cirq/sim/density_matrix_utils.py +++ b/cirq-core/cirq/sim/density_matrix_utils.py @@ -18,6 +18,7 @@ import numpy as np from cirq import linalg, value +from cirq.sim import simulation_util if TYPE_CHECKING: import cirq @@ -188,33 +189,8 @@ def _probs( """Returns the probabilities for a measurement on the given indices.""" # Only diagonal elements matter. all_probs = np.diagonal(np.reshape(density_matrix, (np.prod(qid_shape, dtype=np.int64),) * 2)) - # Shape into a tensor - tensor = np.reshape(all_probs, qid_shape) - - # Calculate the probabilities for measuring the particular results. - if len(indices) == len(qid_shape): - # We're measuring every qudit, so no need for fancy indexing - probs = np.abs(tensor) - probs = np.transpose(probs, indices) - probs = probs.reshape(-1) - else: - # Fancy indexing required - meas_shape = tuple(qid_shape[i] for i in indices) - probs = np.abs( - [ - tensor[ - linalg.slice_for_qubits_equal_to( - indices, big_endian_qureg_value=b, qid_shape=qid_shape - ) - ] - for b in range(np.prod(meas_shape, dtype=np.int64)) - ] - ) - probs = np.sum(probs, axis=tuple(range(1, len(probs.shape)))) - # To deal with rounding issues, ensure that the probabilities sum to 1. - probs /= np.sum(probs) - return probs + return simulation_util.state_probabilities(all_probs, indices, qid_shape) def _validate_density_matrix_qid_shape( diff --git a/cirq-core/cirq/sim/simulation_util.py b/cirq-core/cirq/sim/simulation_util.py new file mode 100644 index 00000000000..1d84e7b588d --- /dev/null +++ b/cirq-core/cirq/sim/simulation_util.py @@ -0,0 +1,45 @@ +from typing import Sequence, Tuple + +import numpy as np + +from cirq import linalg + + +def state_probabilities(state_vector: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: + """Returns the probabilities for a state/measurement on the given indices. + + Args: + state_vector: The multi-qubit state vector to be sampled. This is an + array of 2 to the power of the number of qubit complex numbers, and + so state must be of size ``2**integer``. The `state_vector` can be + a vector of size ``2**integer`` or a tensor of shape + ``(2, 2, ..., 2)``. + indices: Which qubits are measured. The `state_vector` is assumed to be + supplied in big endian order. That is the xth index of v, when + expressed as a bitstring, has its largest values in the 0th index. + qid_shape: The qid shape of the `state_vector`. + + Returns: + State probabilities. + """ + state = state_vector.reshape((-1,)) + probs = (state * state.conj()).real + not_measured = [i for i in range(len(qid_shape)) if i not in indices] + if linalg.can_numpy_support_shape(qid_shape): + # Use numpy transpose if we can since it's more efficient. + probs = probs.reshape(qid_shape) + probs = np.transpose(probs, list(indices) + not_measured) + probs = probs.reshape((-1,)) + else: + # If we can't use numpy due to numpy/numpy#5744, use a slower method. + probs = linalg.transpose_flattened_array(probs, qid_shape, list(indices) + not_measured) + + if len(not_measured): + # Not all qudits are measured. + volume = np.prod([qid_shape[i] for i in indices]) + # Reshape into a 2D array in which each of the measured states correspond to a row. + probs = probs.reshape((volume, -1)) + probs = np.sum(probs, axis=-1) + + # To deal with rounding issues, ensure that the probabilities sum to 1. + return probs / np.sum(probs) \ No newline at end of file diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index c204192a9c0..beb3bbc88ab 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -19,8 +19,7 @@ import numpy as np from cirq import linalg, qis, value -from cirq.sim import simulator -from cirq import linalg +from cirq.sim import simulator, simulation_util if TYPE_CHECKING: import cirq @@ -216,7 +215,7 @@ def sample_state_vector( prng = value.parse_random_state(seed) # Calculate the measurement probabilities. - probs = _probs(state_vector, indices, shape) + probs = simulation_util.state_probabilities(state_vector, indices, shape) # We now have the probability vector, correctly ordered, so sample over # it. Note that we us ints here, since numpy's choice does not allow for @@ -289,7 +288,7 @@ def measure_state_vector( initial_shape = state_vector.shape # Calculate the measurement probabilities and then make the measurement. - probs = _probs(state_vector, indices, shape) + probs = simulation_util.state_probabilities(state_vector, indices, shape) result = prng.choice(len(probs), p=probs) ###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))] # Convert to individual qudit measurements. @@ -322,28 +321,3 @@ def measure_state_vector( assert out is not None # We mutate and return out, so mypy cannot identify that the out cannot be None. return measurement_bits, out - - -def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: - """Returns the probabilities for a measurement on the given indices.""" - state = state.reshape((-1,)) - probs = np.abs(state) ** 2 - not_measured = [i for i in range(len(qid_shape)) if i not in indices] - if linalg.can_numpy_support_shape(qid_shape): - # Use numpy transpose if we can since it's more efficient. - probs = probs.reshape(qid_shape) - probs = np.transpose(probs, list(indices) + not_measured) - probs = probs.reshape((-1,)) - else: - # If we can't use numpy due to numpy/numpy#5744, use a slower method. - probs = linalg.transpose_flattened_array(probs, qid_shape, list(indices) + not_measured) - - if len(not_measured): - # Not all qudits are measured. - volume = np.prod([qid_shape[i] for i in indices]) - # Reshape into a 2D array in which each of the measured states correspond to a row. - probs = probs.reshape((volume, -1)) - probs = np.sum(probs, axis=-1) - - # To deal with rounding issues, ensure that the probabilities sum to 1. - return probs / np.sum(probs) From 2fff50581296ef2f6349224bfc48c2ea6b156992 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 12 May 2023 14:34:50 +0100 Subject: [PATCH 3/5] fix lint --- cirq-core/cirq/sim/simulation_util.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/simulation_util.py b/cirq-core/cirq/sim/simulation_util.py index 1d84e7b588d..f7fa0dccb9c 100644 --- a/cirq-core/cirq/sim/simulation_util.py +++ b/cirq-core/cirq/sim/simulation_util.py @@ -1,3 +1,16 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Sequence, Tuple import numpy as np @@ -5,7 +18,9 @@ from cirq import linalg -def state_probabilities(state_vector: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: +def state_probabilities( + state_vector: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...] +) -> np.ndarray: """Returns the probabilities for a state/measurement on the given indices. Args: @@ -42,4 +57,4 @@ def state_probabilities(state_vector: np.ndarray, indices: Sequence[int], qid_sh probs = np.sum(probs, axis=-1) # To deal with rounding issues, ensure that the probabilities sum to 1. - return probs / np.sum(probs) \ No newline at end of file + return probs / np.sum(probs) From c431e4f798b351cee8e2d181ac001989c8445995 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 12 May 2023 15:14:32 +0100 Subject: [PATCH 4/5] accept only probabilities not complex numbers --- cirq-core/cirq/sim/density_matrix_utils.py | 2 +- cirq-core/cirq/sim/simulation_util.py | 15 +++++++-------- cirq-core/cirq/sim/state_vector.py | 6 ++++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/cirq-core/cirq/sim/density_matrix_utils.py b/cirq-core/cirq/sim/density_matrix_utils.py index f405b0fe55b..8c56c6a5dbc 100644 --- a/cirq-core/cirq/sim/density_matrix_utils.py +++ b/cirq-core/cirq/sim/density_matrix_utils.py @@ -190,7 +190,7 @@ def _probs( # Only diagonal elements matter. all_probs = np.diagonal(np.reshape(density_matrix, (np.prod(qid_shape, dtype=np.int64),) * 2)) - return simulation_util.state_probabilities(all_probs, indices, qid_shape) + return simulation_util.state_probabilities(all_probs.real, indices, qid_shape) def _validate_density_matrix_qid_shape( diff --git a/cirq-core/cirq/sim/simulation_util.py b/cirq-core/cirq/sim/simulation_util.py index f7fa0dccb9c..681193aa66d 100644 --- a/cirq-core/cirq/sim/simulation_util.py +++ b/cirq-core/cirq/sim/simulation_util.py @@ -19,26 +19,25 @@ def state_probabilities( - state_vector: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...] + state_probability: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...] ) -> np.ndarray: """Returns the probabilities for a state/measurement on the given indices. Args: - state_vector: The multi-qubit state vector to be sampled. This is an - array of 2 to the power of the number of qubit complex numbers, and - so state must be of size ``2**integer``. The `state_vector` can be + state_probability: The multi-qubit state probability vector. This is an + array of 2 to the power of the number of real numbers, and + so state must be of size ``2**integer``. The `state_probability` can be a vector of size ``2**integer`` or a tensor of shape ``(2, 2, ..., 2)``. - indices: Which qubits are measured. The `state_vector` is assumed to be + indices: Which qubits are measured. The `state_probability` is assumed to be supplied in big endian order. That is the xth index of v, when expressed as a bitstring, has its largest values in the 0th index. - qid_shape: The qid shape of the `state_vector`. + qid_shape: The qid shape of the `state_probability`. Returns: State probabilities. """ - state = state_vector.reshape((-1,)) - probs = (state * state.conj()).real + probs = state_probability.reshape((-1,)) not_measured = [i for i in range(len(qid_shape)) if i not in indices] if linalg.can_numpy_support_shape(qid_shape): # Use numpy transpose if we can since it's more efficient. diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index beb3bbc88ab..056d9e10480 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -215,7 +215,8 @@ def sample_state_vector( prng = value.parse_random_state(seed) # Calculate the measurement probabilities. - probs = simulation_util.state_probabilities(state_vector, indices, shape) + probs = (state_vector * state_vector.conj()).real + probs = simulation_util.state_probabilities(probs, indices, shape) # We now have the probability vector, correctly ordered, so sample over # it. Note that we us ints here, since numpy's choice does not allow for @@ -288,7 +289,8 @@ def measure_state_vector( initial_shape = state_vector.shape # Calculate the measurement probabilities and then make the measurement. - probs = simulation_util.state_probabilities(state_vector, indices, shape) + probs = (state_vector * state_vector.conj()).real + probs = simulation_util.state_probabilities(probs, indices, shape) result = prng.choice(len(probs), p=probs) ###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))] # Convert to individual qudit measurements. From 5905c69fee70dad64731ddb4781ece167103a9ca Mon Sep 17 00:00:00 2001 From: Noureldin Date: Thu, 25 May 2023 15:36:15 +0100 Subject: [PATCH 5/5] added tests --- cirq-core/cirq/sim/density_matrix_utils.py | 4 +-- ...simulation_util.py => simulation_utils.py} | 2 +- cirq-core/cirq/sim/simulation_utils_test.py | 32 +++++++++++++++++++ cirq-core/cirq/sim/state_vector.py | 6 ++-- 4 files changed, 38 insertions(+), 6 deletions(-) rename cirq-core/cirq/sim/{simulation_util.py => simulation_utils.py} (98%) create mode 100644 cirq-core/cirq/sim/simulation_utils_test.py diff --git a/cirq-core/cirq/sim/density_matrix_utils.py b/cirq-core/cirq/sim/density_matrix_utils.py index 8c56c6a5dbc..bb3d87195e7 100644 --- a/cirq-core/cirq/sim/density_matrix_utils.py +++ b/cirq-core/cirq/sim/density_matrix_utils.py @@ -18,7 +18,7 @@ import numpy as np from cirq import linalg, value -from cirq.sim import simulation_util +from cirq.sim import simulation_utils if TYPE_CHECKING: import cirq @@ -190,7 +190,7 @@ def _probs( # Only diagonal elements matter. all_probs = np.diagonal(np.reshape(density_matrix, (np.prod(qid_shape, dtype=np.int64),) * 2)) - return simulation_util.state_probabilities(all_probs.real, indices, qid_shape) + return simulation_utils.state_probabilities_by_indices(all_probs.real, indices, qid_shape) def _validate_density_matrix_qid_shape( diff --git a/cirq-core/cirq/sim/simulation_util.py b/cirq-core/cirq/sim/simulation_utils.py similarity index 98% rename from cirq-core/cirq/sim/simulation_util.py rename to cirq-core/cirq/sim/simulation_utils.py index 681193aa66d..00934962d2e 100644 --- a/cirq-core/cirq/sim/simulation_util.py +++ b/cirq-core/cirq/sim/simulation_utils.py @@ -18,7 +18,7 @@ from cirq import linalg -def state_probabilities( +def state_probabilities_by_indices( state_probability: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...] ) -> np.ndarray: """Returns the probabilities for a state/measurement on the given indices. diff --git a/cirq-core/cirq/sim/simulation_utils_test.py b/cirq-core/cirq/sim/simulation_utils_test.py new file mode 100644 index 00000000000..2e8736029ca --- /dev/null +++ b/cirq-core/cirq/sim/simulation_utils_test.py @@ -0,0 +1,32 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import numpy as np + +from cirq.sim import simulation_utils +from cirq import testing + + +@pytest.mark.parametrize('n,m', [(n, m) for n in range(1, 4) for m in range(1, n + 1)]) +def test_state_probabilities_by_indices(n: int, m: int): + np.random.seed(0) + state = testing.random_superposition(1 << n) + d = (state.conj() * state).real + desired_axes = list(np.random.choice(n, m, replace=False)) + not_wanted = [i for i in range(n) if i not in desired_axes] + got = simulation_utils.state_probabilities_by_indices(d, desired_axes, (2,) * n) + want = np.transpose(d.reshape((2,) * n), desired_axes + not_wanted) + want = np.sum(want.reshape((1 << len(desired_axes), -1)), axis=-1) + np.testing.assert_allclose(want, got) diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index 056d9e10480..7250e6cb1f6 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -19,7 +19,7 @@ import numpy as np from cirq import linalg, qis, value -from cirq.sim import simulator, simulation_util +from cirq.sim import simulator, simulation_utils if TYPE_CHECKING: import cirq @@ -216,7 +216,7 @@ def sample_state_vector( # Calculate the measurement probabilities. probs = (state_vector * state_vector.conj()).real - probs = simulation_util.state_probabilities(probs, indices, shape) + probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape) # We now have the probability vector, correctly ordered, so sample over # it. Note that we us ints here, since numpy's choice does not allow for @@ -290,7 +290,7 @@ def measure_state_vector( # Calculate the measurement probabilities and then make the measurement. probs = (state_vector * state_vector.conj()).real - probs = simulation_util.state_probabilities(probs, indices, shape) + probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape) result = prng.choice(len(probs), p=probs) ###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))] # Convert to individual qudit measurements.