Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow qudits in deferred measurements #5850

Merged
merged 10 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def all_subclasses(cls):
cirq.Pauli,
# Private gates.
cirq.transformers.analytical_decompositions.two_qubit_to_fsim._BGate,
cirq.transformers.measurement_transformers._Add,
cirq.ops.raw_types._InverseCompositeGate,
cirq.circuits.qasm_output.QasmTwoQubitGate,
cirq.ops.MSGate,
Expand Down
37 changes: 34 additions & 3 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return (str(self._key), self._qid._comparison_key())
return str(self._key), self._qid._comparison_key()

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"
Expand Down Expand Up @@ -104,7 +104,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
cxs = [_cx(q.dimension).on(q, target) for q, target in zip(op.qubits, targets)]
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + xs
elif protocols.is_measurement(op):
Expand All @@ -117,7 +117,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qs = measurement_qubits[c.key]
if len(qs) == 1:
control_values: Any = range(1, qs[0].dimension)
control_values: Any = [range(1, qs[0].dimension)]
else:
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
Expand Down Expand Up @@ -227,3 +227,34 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return transformer_primitives.map_operations(
circuit, flip_inversion, deep=context.deep if context else True, tags_to_ignore=ignored
).unfreeze()


@value.value_equality
class _Add(ops.ArithmeticGate):
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
"""Adds two qudits of the same dimension.

Operates on two qudits by modular addition:

|a,b> -> |a,a+b mod d>"""

def __init__(self, dimension: int):
self._dimension = dimension

def registers(self):
return (self._dimension,), (self._dimension,)

def with_registers(self, *new_registers):
raise NotImplementedError()

def apply(self, input_value, target_value):
return input_value, target_value + input_value

def _value_equality_values_(self):
return self._dimension
daxfohl marked this conversation as resolved.
Show resolved Hide resolved


def _cx(dimension: int):
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
# We can use an Add gate in the qudit case, since the ancilla qudit corresponding to the
# measurement is always zero, so "adding" the measured qudit to it sets the ancilla qudit to
# the same state.
return ops.CX if dimension == 2 else _Add(dimension)
31 changes: 26 additions & 5 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
import sympy

import cirq
from cirq.transformers.measurement_transformers import _MeasurementQid
from cirq.transformers.measurement_transformers import _cx, _MeasurementQid


def assert_equivalent_to_deferred(circuit: cirq.Circuit):
qubits = list(circuit.all_qubits())
sim = cirq.Simulator()
num_qubits = len(qubits)
for i in range(2**num_qubits):
bits = cirq.big_endian_int_to_bits(i, bit_count=num_qubits)
dimensions = [q.dimension for q in qubits]
for i in range(np.prod(dimensions)):
bits = cirq.big_endian_int_to_digits(i, base=dimensions)
modified = cirq.Circuit()
for j in range(num_qubits):
if bits[j]:
modified.append(cirq.X(qubits[j]))
modified.append(cirq.XPowGate(dimension=qubits[j].dimension)(qubits[j]) ** bits[j])
modified.append(circuit)
deferred = cirq.defer_measurements(modified)
result = sim.simulate(modified)
Expand Down Expand Up @@ -58,6 +58,27 @@ def test_basic():
)


def test_qudits():
q0, q1 = cirq.LineQid.range(2, dimension=3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.XPowGate(dimension=3).on(q1).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
_cx(3)(q0, q_ma),
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_nocompile_context():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down