Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for > 32 qudits to cirq.sample_state_vector. Fix for #6031 #6090

Merged
merged 11 commits into from
Jun 28, 2023
2 changes: 2 additions & 0 deletions cirq-core/cirq/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,6 @@
targeted_conjugate_about,
targeted_left_multiply,
to_special,
transpose_flattened_array,
can_numpy_support_shape,
)
64 changes: 64 additions & 0 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
# user provides a different np.array([]) value.
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])

_NPY_MAXDIMS = 32 # Should be changed once numpy/numpy#5744 is resolved.


def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float):
"""Raises a matrix with two opposing eigenvalues to a power.
Expand Down Expand Up @@ -746,3 +748,65 @@ def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]):
"""
axes = list(axes) + [i + len(axes) for i in axes]
return transpose_state_vector_to_axis_order(t, axes)


def _volumes(shape: Sequence[int]) -> List[int]:
r"""Returns a list of the volume spanned by each dimension.

Given a shape=[d_0, d_1, .., d_n] the volume spanned by each dimension is
volume[i] = `\prod_{j=i+1}^n d_j`

Args:
shape: Sequence of the size of each dimension.

Returns:
Sequence of the volume spanned of each dimension.
"""
volume = [0] * len(shape)
v = 1
for i in reversed(range(len(shape))):
volume[i] = v
v *= shape[i]
return volume


def _coordinates_from_index(idx: int, volume: Sequence[int]) -> Sequence[int]:
ret = []
for v in volume:
ret.append(idx // v)
idx %= v
return tuple(ret)


def _index_from_coordinates(s: Sequence[int], volume: Sequence[int]) -> int:
return np.dot(s, volume)


def transpose_flattened_array(t: np.ndarray, shape: Sequence[int], axes: Sequence[int]):
"""Transposes a flattened array.

Equivalent to np.transpose(t.reshape(shape), axes).reshape((-1,)).

Args:
t: flat array.
shape: the shape of `t` before flattening.
axes: permutation of range(len(shape)).

Returns:
Flattened transpose of `t`.
"""
if len(t.shape) != 1:
t = t.reshape((-1,))
cur_volume = _volumes(shape)
new_volume = _volumes([shape[i] for i in axes])
ret = np.zeros_like(t)
for idx in range(t.shape[0]):
cell = _coordinates_from_index(idx, cur_volume)
new_cell = [cell[i] for i in axes]
ret[_index_from_coordinates(new_cell, new_volume)] = t[idx]
return ret


def can_numpy_support_shape(shape: Sequence[int]) -> bool:
"""Returns whether numpy supports the given shape or not numpy/numpy#5744."""
return len(shape) <= _NPY_MAXDIMS
16 changes: 16 additions & 0 deletions cirq-core/cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import cirq
import cirq.testing
from cirq import linalg


def test_reflection_matrix_pow_consistent_results():
Expand Down Expand Up @@ -632,3 +633,18 @@ def test_factor_state_vector(state_1: int, state_2: int):
# All phase goes into a1, and b1 is just the dephased state vector
assert np.allclose(a1, a * phase)
assert np.allclose(b1, b)


@pytest.mark.parametrize('num_dimensions', [*range(1, 7)])
def test_transpose_flattened_array(num_dimensions):
np.random.seed(0)
for _ in range(10):
shape = np.random.randint(1, 5, (num_dimensions,)).tolist()
axes = np.random.permutation(num_dimensions).tolist()
volume = np.prod(shape)
A = np.random.permutation(volume)
want = np.transpose(A.reshape(shape), axes)
got = linalg.transpose_flattened_array(A, shape, axes).reshape(want.shape)
assert np.array_equal(want, got)
got = linalg.transpose_flattened_array(A.reshape(shape), shape, axes).reshape(want.shape)
assert np.array_equal(want, got)
28 changes: 2 additions & 26 deletions cirq-core/cirq/sim/density_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

from cirq import linalg, value
from cirq.sim import simulation_utils

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -188,33 +189,8 @@ def _probs(
"""Returns the probabilities for a measurement on the given indices."""
# Only diagonal elements matter.
all_probs = np.diagonal(np.reshape(density_matrix, (np.prod(qid_shape, dtype=np.int64),) * 2))
# Shape into a tensor
tensor = np.reshape(all_probs, qid_shape)

# Calculate the probabilities for measuring the particular results.
if len(indices) == len(qid_shape):
# We're measuring every qudit, so no need for fancy indexing
probs = np.abs(tensor)
probs = np.transpose(probs, indices)
probs = probs.reshape(-1)
else:
# Fancy indexing required
meas_shape = tuple(qid_shape[i] for i in indices)
probs = np.abs(
[
tensor[
linalg.slice_for_qubits_equal_to(
indices, big_endian_qureg_value=b, qid_shape=qid_shape
)
]
for b in range(np.prod(meas_shape, dtype=np.int64))
]
)
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
probs /= np.sum(probs)
return probs
return simulation_utils.state_probabilities_by_indices(all_probs.real, indices, qid_shape)


def _validate_density_matrix_qid_shape(
Expand Down
59 changes: 59 additions & 0 deletions cirq-core/cirq/sim/simulation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Tuple

import numpy as np

from cirq import linalg


def state_probabilities_by_indices(
state_probability: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]
) -> np.ndarray:
"""Returns the probabilities for a state/measurement on the given indices.

Args:
state_probability: The multi-qubit state probability vector. This is an
array of 2 to the power of the number of real numbers, and
so state must be of size ``2**integer``. The `state_probability` can be
a vector of size ``2**integer`` or a tensor of shape
``(2, 2, ..., 2)``.
indices: Which qubits are measured. The `state_probability` is assumed to be
supplied in big endian order. That is the xth index of v, when
expressed as a bitstring, has its largest values in the 0th index.
qid_shape: The qid shape of the `state_probability`.

Returns:
State probabilities.
"""
probs = state_probability.reshape((-1,))
not_measured = [i for i in range(len(qid_shape)) if i not in indices]
if linalg.can_numpy_support_shape(qid_shape):
# Use numpy transpose if we can since it's more efficient.
probs = probs.reshape(qid_shape)
probs = np.transpose(probs, list(indices) + not_measured)
probs = probs.reshape((-1,))
else:
# If we can't use numpy due to numpy/numpy#5744, use a slower method.
probs = linalg.transpose_flattened_array(probs, qid_shape, list(indices) + not_measured)

if len(not_measured):
# Not all qudits are measured.
volume = np.prod([qid_shape[i] for i in indices])
# Reshape into a 2D array in which each of the measured states correspond to a row.
probs = probs.reshape((volume, -1))
probs = np.sum(probs, axis=-1)

# To deal with rounding issues, ensure that the probabilities sum to 1.
return probs / np.sum(probs)
32 changes: 32 additions & 0 deletions cirq-core/cirq/sim/simulation_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import numpy as np

from cirq.sim import simulation_utils
from cirq import testing


@pytest.mark.parametrize('n,m', [(n, m) for n in range(1, 4) for m in range(1, n + 1)])
def test_state_probabilities_by_indices(n: int, m: int):
np.random.seed(0)
state = testing.random_superposition(1 << n)
d = (state.conj() * state).real
desired_axes = list(np.random.choice(n, m, replace=False))
not_wanted = [i for i in range(n) if i not in desired_axes]
got = simulation_utils.state_probabilities_by_indices(d, desired_axes, (2,) * n)
want = np.transpose(d.reshape((2,) * n), desired_axes + not_wanted)
want = np.sum(want.reshape((1 << len(desired_axes), -1)), axis=-1)
np.testing.assert_allclose(want, got)
39 changes: 5 additions & 34 deletions cirq-core/cirq/sim/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from cirq import linalg, qis, value
from cirq.sim import simulator
from cirq.sim import simulator, simulation_utils

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -215,7 +215,8 @@ def sample_state_vector(
prng = value.parse_random_state(seed)

# Calculate the measurement probabilities.
probs = _probs(state_vector, indices, shape)
probs = (state_vector * state_vector.conj()).real
probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)

# We now have the probability vector, correctly ordered, so sample over
# it. Note that we us ints here, since numpy's choice does not allow for
Expand Down Expand Up @@ -288,7 +289,8 @@ def measure_state_vector(
initial_shape = state_vector.shape

# Calculate the measurement probabilities and then make the measurement.
probs = _probs(state_vector, indices, shape)
probs = (state_vector * state_vector.conj()).real
probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)
result = prng.choice(len(probs), p=probs)
###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))]
# Convert to individual qudit measurements.
Expand Down Expand Up @@ -321,34 +323,3 @@ def measure_state_vector(
assert out is not None
# We mutate and return out, so mypy cannot identify that the out cannot be None.
return measurement_bits, out


def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray:
"""Returns the probabilities for a measurement on the given indices."""
tensor = np.reshape(state, qid_shape)
# Calculate the probabilities for measuring the particular results.
if len(indices) == len(qid_shape):
# We're measuring every qudit, so no need for fancy indexing
probs = np.abs(tensor) ** 2
probs = np.transpose(probs, indices)
probs = probs.reshape(-1)
else:
# Fancy indexing required
meas_shape = tuple(qid_shape[i] for i in indices)
probs = (
np.abs(
[
tensor[
linalg.slice_for_qubits_equal_to(
indices, big_endian_qureg_value=b, qid_shape=qid_shape
)
]
for b in range(np.prod(meas_shape, dtype=np.int64))
]
)
** 2
)
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
return probs / np.sum(probs)
Loading