Skip to content

Commit

Permalink
Improve support for recursively applying transformer primitives on ci…
Browse files Browse the repository at this point in the history
…rcuit operations using `deep=True` (#5103)

- Part of fixing #5039
- Fixes multiple bugs and improves support for `deep=True` flag in transformer primitives.
  • Loading branch information
tanujkhattar authored Mar 21, 2022
1 parent caadb0c commit 2693951
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 56 deletions.
107 changes: 65 additions & 42 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def map_moments(
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
mapped_op = op_untagged.replace(
circuit=map_moments(op_untagged.mapped_circuit(), map_func, deep=deep).freeze()
)
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
).with_tags(*op.tags)
batch_replace.append((i, op, mapped_op))
mutable_circuit = circuit.unfreeze(copy=True)
mutable_circuit.batch_replace(batch_replace)
Expand Down Expand Up @@ -180,7 +180,8 @@ def map_operations_and_unroll(
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
)
),
deep=deep,
)


Expand Down Expand Up @@ -399,12 +400,6 @@ def merge_moments(
return _create_target_circuit_type(merged_moments, circuit)


def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]) -> bool:
return isinstance(op.untagged, circuits.CircuitOperation) and (
tags_to_check is None or any(tag in op.tags for tag in tags_to_check)
)


def unroll_circuit_op(
circuit: CIRCUIT_TYPE,
*,
Expand All @@ -418,8 +413,8 @@ def unroll_circuit_op(
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching
`tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.
Expand All @@ -430,12 +425,18 @@ def unroll_circuit_op(
def map_func(m: circuits.Moment, _: int):
to_zip: List['cirq.AbstractCircuit'] = []
for op in m:
if _check_circuit_op(op, tags_to_check):
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
op_untagged = op.untagged
if isinstance(op_untagged, circuits.CircuitOperation):
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
to_zip.append(
unroll_circuit_op(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
op_untagged.mapped_circuit()
if (tags_to_check is None or set(tags_to_check).intersection(op.tags))
else circuits.Circuit(op_untagged.with_tags(*op.tags))
)
else:
to_zip.append(circuits.Circuit(op))
Expand All @@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest(
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit
operations matching `tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.
Returns:
Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
"""
batch_removals = [*circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check))]
batch_inserts = []
for i, op in batch_removals:
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
sub_circuit = (
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
)
batch_inserts += [(i, sub_circuit.all_operations())]
batch_replace = []
batch_remove = []
batch_insert = []
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op_greedy_earliest(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
batch_remove.append((i, op))
batch_insert.append((i, op_untagged.mapped_circuit().all_operations()))
elif deep:
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
unrolled_circuit = circuit.unfreeze(copy=True)
unrolled_circuit.batch_remove(batch_removals)
unrolled_circuit.batch_insert(batch_inserts)
unrolled_circuit.batch_replace(batch_replace)
unrolled_circuit.batch_remove(batch_remove)
unrolled_circuit.batch_insert(batch_insert)
return _to_target_circuit_type(unrolled_circuit, circuit)


Expand All @@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier(
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit
operations matching `tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.
Expand All @@ -506,16 +516,29 @@ def unroll_circuit_op_greedy_frontier(
"""
unrolled_circuit = circuit.unfreeze(copy=True)
frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0)
for idx, op in circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check)):
idx = max(idx, max(frontier[q] for q in op.qubits))
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
sub_circuit = (
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
)
frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier)
idx = 0
while idx < len(unrolled_circuit):
for op in unrolled_circuit[idx].operations:
# Don't touch stuff inserted by unrolling previous circuit ops.
if not isinstance(op.untagged, circuits.CircuitOperation):
continue
if any(frontier[q] > idx for q in op.qubits):
continue
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op_greedy_frontier(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
frontier = unrolled_circuit.insert_at_frontier(
op_untagged.mapped_circuit().all_operations(), idx, frontier
)
elif deep:
unrolled_circuit.batch_replace([(idx, op, op_untagged.with_tags(*op.tags))])
idx += 1
return _to_target_circuit_type(unrolled_circuit, circuit)


Expand Down
110 changes: 96 additions & 14 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
)


# pylint: disable=line-too-long
def test_map_operations_deep_subcircuits():
q = cirq.LineQubit.range(5)
c_orig = cirq.Circuit(
Expand All @@ -127,9 +128,14 @@ def test_map_operations_deep_subcircuits():
c_orig_with_circuit_ops = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
[cirq.CircuitOperation(cirq.FrozenCircuit(op)) for op in c_orig.all_operations()]
[
cirq.CircuitOperation(cirq.FrozenCircuit(op)).repeat(2).with_tags("internal")
for op in c_orig.all_operations()
]
)
)
.repeat(6)
.with_tags("external")
)

def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
Expand All @@ -139,23 +145,73 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
cirq.Z.on_each(*op.qubits),
] if op.gate == cirq.CX else op

c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True)
c_mapped = cirq.unroll_circuit_op(c_mapped, deep=True, tags_to_check=None)
cirq.testing.assert_has_diagram(
c_mapped,
c_orig_with_circuit_ops,
'''
0: ───Z───@───Z───────────────
1: ───Z───X───Z───────────────
2: ───Z───X───Z───────────────
3: ───Z───@───Z───Z───@───Z───
4: ───────────────Z───X───Z───
[ [ 0: ───@─── ] ]
[ 0: ───[ │ ]────────────────────────────────────────────────────────────── ]
[ [ 1: ───X─── ](loops=2)['internal'] ]
[ │ ]
[ 1: ───#2────────────────────────────────────────────────────────────────────────── ]
[ ]
[ [ 2: ───X─── ] ]
0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────── ]────────────────────────
[ [ 3: ───@─── ](loops=2)['internal'] ]
[ │ ]
[ │ [ 3: ───@─── ] ]
[ 3: ───#2────────────────────────────────────[ │ ]──────────────────────── ]
[ [ 4: ───X─── ](loops=2)['internal'] ]
[ │ ]
[ 4: ─────────────────────────────────────────#2──────────────────────────────────── ](loops=6)['external']
1: ───#2────────────────────────────────────────────────────────────────────────────────────────────────────────────
2: ───#3────────────────────────────────────────────────────────────────────────────────────────────────────────────
3: ───#4────────────────────────────────────────────────────────────────────────────────────────────────────────────
4: ───#5────────────────────────────────────────────────────────────────────────────────────────────────────────────
''',
)

c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True)
for unroller in [
cirq.unroll_circuit_op,
cirq.unroll_circuit_op_greedy_earliest,
cirq.unroll_circuit_op_greedy_frontier,
]:
cirq.testing.assert_has_diagram(
unroller(c_mapped, deep=True),
'''
[ [ 0: ───Z───@───Z─── ] ]
[ 0: ───[ │ ]────────────────────────────────────────────────────────────────────── ]
[ [ 1: ───Z───X───Z─── ](loops=2)['internal'] ]
[ │ ]
[ 1: ───#2────────────────────────────────────────────────────────────────────────────────────────── ]
[ ]
[ [ 2: ───Z───X───Z─── ] ]
0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────────────── ]────────────────────────
[ [ 3: ───Z───@───Z─── ](loops=2)['internal'] ]
[ │ ]
[ │ [ 3: ───Z───@───Z─── ] ]
[ 3: ───#2────────────────────────────────────────────[ │ ]──────────────────────── ]
[ [ 4: ───Z───X───Z─── ](loops=2)['internal'] ]
[ │ ]
[ 4: ─────────────────────────────────────────────────#2──────────────────────────────────────────── ](loops=6)['external']
1: ───#2────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
2: ───#3────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
3: ───#4────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
4: ───#5────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
''',
)


# pylint: enable=line-too-long


def test_map_operations_respects_tags_to_ignore():
q = cirq.LineQubit.range(2)
Expand Down Expand Up @@ -204,13 +260,29 @@ def test_unroll_circuit_op_and_variants():
[cirq.Moment(cirq.CircuitOperation(cirq.FrozenCircuit(m))) for m in mapped_circuit[:-1]],
mapped_circuit[-1],
)
cirq.testing.assert_has_diagram(
mapped_circuit_deep,
'''
0: ───[ 0: ───X─── ]────────────────────────────────────────────────────────────X───
1: ────────────────────[ 1: ───[ 1: ───Z───Z─── ]['<mapped_circuit_op>']─── ]───────
''',
)
for unroller in [
cirq.unroll_circuit_op_greedy_earliest,
cirq.unroll_circuit_op_greedy_frontier,
cirq.unroll_circuit_op,
]:
cirq.testing.assert_same_circuits(
unroller(mapped_circuit), unroller(mapped_circuit_deep, tags_to_check=None, deep=True)
unroller(mapped_circuit), unroller(mapped_circuit_deep, deep=True, tags_to_check=None)
)
cirq.testing.assert_has_diagram(
unroller(mapped_circuit_deep, deep=True),
'''
0: ───[ 0: ───X─── ]────────────────────────X───
1: ────────────────────[ 1: ───Z───Z─── ]───────
''',
)

cirq.testing.assert_has_diagram(
Expand Down Expand Up @@ -239,6 +311,16 @@ def test_unroll_circuit_op_and_variants():
)


def test_unroll_circuit_op_greedy_frontier_doesnt_touch_same_op_twice():
q = cirq.NamedQubit("q")
nested_ops = [cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q)))] * 5
nested_circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit(nested_ops))
c = cirq.Circuit(nested_circuit_op, nested_circuit_op, nested_circuit_op)
c_expected = cirq.Circuit(nested_ops, nested_ops, nested_ops)
c_unrolled = cirq.unroll_circuit_op_greedy_frontier(c, tags_to_check=None)
cirq.testing.assert_same_circuits(c_unrolled, c_expected)


def test_unroll_circuit_op_deep():
q0, q1, q2 = cirq.LineQubit.range(3)
c = cirq.Circuit(
Expand Down

0 comments on commit 2693951

Please sign in to comment.