Skip to content

Commit

Permalink
Allow repeated measurements in deferred transformer (#5857)
Browse files Browse the repository at this point in the history
* Add handling for sympy conditions in deferred measurement transformer

* docstring

* mypy

* mypy

* cover

* Make this more generic, covers all kinds of conditions.

* Better docs

* Sympy can also be CX

* docs

* docs

* Allow repeated measurements in deferred transformer

* Coverage

* Add mixed tests, simplify loop, add simplification in ControlledGate

* Fix error message

* Simplify error message

* Inline variable

* fix merge

* qudit sympy test

* fix build

* Fix test

* Fix test

* nits

* mypy

* mypy

* mypy

* Add some code comments

* Add test for repeated measurement diagram

* change test back

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
  • Loading branch information
daxfohl and tanujkhattar authored Dec 19, 2022
1 parent 7019adc commit af1267d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 52 deletions.
89 changes: 54 additions & 35 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,8 @@
# limitations under the License.

import itertools
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand All @@ -43,30 +33,32 @@ class _MeasurementQid(ops.Qid):
Exactly one qubit will be created per qubit in the measurement gate.
"""

def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'):
def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid', index: int = 0):
"""Initializes the qubit.
Args:
key: The key of the measurement gate being deferred.
qid: One qubit that is being measured. Each deferred measurement
should create one new _MeasurementQid per qubit being measured
by that gate.
index: For repeated measurement keys, this represents the index of that measurement.
"""
self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key
self._qid = qid
self._index = index

@property
def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return str(self._key), self._qid._comparison_key()
return str(self._key), self._index, self._qid._comparison_key()

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"
return f"M('{self._key}[{self._index}]', q={self._qid})"

def __repr__(self) -> str:
return f'_MeasurementQid({self._key!r}, {self._qid!r})'
return f'_MeasurementQid({self._key!r}, {self._qid!r}, {self._index})'


@transformer_api.transformer
Expand Down Expand Up @@ -102,16 +94,18 @@ def defer_measurements(

circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None)
terminal_measurements = {op for _, op in find_terminal_measurements(circuit)}
measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {}
measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = defaultdict(
list
)

def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if op in terminal_measurements:
return op
gate = op.gate
if isinstance(gate, ops.MeasurementGate):
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
targets = [_MeasurementQid(key, q, len(measurement_qubits[key])) for q in op.qubits]
measurement_qubits[key].append(tuple(targets))
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
confusions = [
_ConfusionChannel(m, [op.qubits[i].dimension for i in indexes]).on(
Expand All @@ -125,10 +119,24 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return [defer(op, None) for op in protocols.decompose_once(op)]
elif op.classical_controls:
# Convert to a quantum control
keys = sorted(set(key for c in op.classical_controls for key in c.keys))
for key in keys:

# First create a sorted set of the indexed keys for this control.
keys = sorted(
set(
indexed_key
for condition in op.classical_controls
for indexed_key in (
[(condition.key, condition.index)]
if isinstance(condition, value.KeyCondition)
else [(k, -1) for k in condition.keys]
)
)
)
for key, index in keys:
if key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={key} not found.')
if index >= len(measurement_qubits[key]) or index < -len(measurement_qubits[key]):
raise ValueError(f'Invalid index for {key}')

# Try every possible datastore state (exponential in the number of keys) against the
# condition, and the ones that work are the control values for the new op.
Expand All @@ -140,12 +148,11 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':

# Rearrange these into the format expected by SumOfProducts
products = [
[i for key in keys for i in store.records[key][0]]
[val for k, i in keys for val in store.records[k][i]]
for store in compatible_datastores
]

control_values = ops.SumOfProducts(products)
qs = [q for key in keys for q in measurement_qubits[key]]
qs = [q for k, i in keys for q in measurement_qubits[k][i]]
return op.without_classical_controls().controlled_by(*qs, control_values=control_values)
return op

Expand All @@ -155,14 +162,15 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
tags_to_ignore=context.tags_to_ignore if context else (),
raise_if_add_qubits=False,
).unfreeze()
for k, qubits in measurement_qubits.items():
circuit.append(ops.measure(*qubits, key=k))
for k, qubits_list in measurement_qubits.items():
for qubits in qubits_list:
circuit.append(ops.measure(*qubits, key=k))
return circuit


def _all_possible_datastore_states(
keys: Iterable['cirq.MeasurementKey'],
measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']],
keys: Iterable[Tuple['cirq.MeasurementKey', int]],
measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]],
) -> Iterable['cirq.ClassicalDataStoreReader']:
"""The cartesian product of all possible DataStore states for the given keys."""
# First we get the list of all possible values. So if we have a key mapped to qubits of shape
Expand All @@ -179,17 +187,28 @@ def _all_possible_datastore_states(
# ((1, 1), (0,)),
# ((1, 1), (1,)),
# ((1, 1), (2,))]
all_values = itertools.product(
all_possible_measurements = itertools.product(
*[
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]]))
for k in keys
tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k][i]]))
for k, i in keys
]
)
# Then we create the ClassicalDataDictionaryStore for each of the above.
for sequences in all_values:
lookup = {k: [sequence] for k, sequence in zip(keys, sequences)}
# Then we create the ClassicalDataDictionaryStore for each of the above. A `measurement_list`
# is a single row of the above example, and can be zipped with `keys`.
for measurement_list in all_possible_measurements:
# Initialize a set of measurement records for this iteration. This will have the same shape
# as `measurement_qubits` but zeros for all measurements.
records = {
key: [(0,) * len(qubits) for qubits in qubits_list]
for key, qubits_list in measurement_qubits.items()
}
# Set the measurement values from the current row of the above, for each key/index we care
# about.
for (k, i), measurement in zip(keys, measurement_list):
records[k][i] = measurement
# Finally yield this sample to the consumer.
yield value.ClassicalDataDictionaryStore(
_records=lookup, _measured_qubits={k: [tuple(measurement_qubits[k])] for k in keys}
_records=records, _measured_qubits=measurement_qubits
)


Expand Down
68 changes: 51 additions & 17 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,40 @@ def test_multi_qubit_control():
)


@pytest.mark.parametrize('index', [-3, -2, -1, 0, 1, 2])
def test_repeated(index: int):
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'), # The control measurement when `index` is 0 or -2
cirq.X(q0),
cirq.measure(q0, key='a'), # The control measurement when `index` is 1 or -1
cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index)),
cirq.measure(q1, key='b'),
)
if index in [-3, 2]:
with pytest.raises(ValueError, match='Invalid index'):
_ = cirq.defer_measurements(circuit)
return
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0) # The ancilla qubit created for the first `a` measurement
q_ma1 = _MeasurementQid('a', q0, 1) # The ancilla qubit created for the second `a` measurement
# The ancilla used for control should match the measurement used for control above.
q_expected_control = q_ma if index in [0, -2] else q_ma1
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
cirq.CX(q0, q_ma),
cirq.X(q0),
cirq.CX(q0, q_ma1),
cirq.Moment(cirq.CX(q_expected_control, q1)),
cirq.measure(q_ma, key='a'),
cirq.measure(q_ma1, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_diagram():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
circuit = cirq.Circuit(
Expand All @@ -457,23 +491,23 @@ def test_diagram():
cirq.testing.assert_has_diagram(
deferred,
"""
┌────┐
0: ─────────────────@───────X────────M('c')───
│ │
1: ─────────────────┼─@──────────────M────────
│ │ │
2: ─────────────────┼@┼──────────────M────────
│││ │
3: ─────────────────┼┼┼@─────────────M────────
││││
M('a', q=q(0)): ────X┼┼┼────M('a')────────────
│││ │
M('a', q=q(2)): ─────X┼┼────M─────────────────
││
M('b', q=q(1)): ──────X┼────M('b')────────────
│ │
M('b', q=q(3)): ───────X────M─────────────────
└────┘
┌────┐
0: ────────────────────@───────X────────M('c')───
│ │
1: ────────────────────┼─@──────────────M────────
│ │ │
2: ────────────────────┼@┼──────────────M────────
│││ │
3: ────────────────────┼┼┼@─────────────M────────
││││
M('a[0]', q=q(0)): ────X┼┼┼────M('a')────────────
│││ │
M('a[0]', q=q(2)): ─────X┼┼────M─────────────────
││
M('b[0]', q=q(1)): ──────X┼────M('b')────────────
│ │
M('b[0]', q=q(3)): ───────X────M─────────────────
└────┘
""",
use_unicode_characters=True,
)
Expand Down

0 comments on commit af1267d

Please sign in to comment.