Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for recursively applying transformer primitives on circuit operations using deep=True #5103

Merged
merged 2 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 65 additions & 42 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def map_moments(
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
mapped_op = op_untagged.replace(
circuit=map_moments(op_untagged.mapped_circuit(), map_func, deep=deep).freeze()
)
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
).with_tags(*op.tags)
batch_replace.append((i, op, mapped_op))
mutable_circuit = circuit.unfreeze(copy=True)
mutable_circuit.batch_replace(batch_replace)
Expand Down Expand Up @@ -180,7 +180,8 @@ def map_operations_and_unroll(
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
)
),
deep=deep,
)


Expand Down Expand Up @@ -399,12 +400,6 @@ def merge_moments(
return _create_target_circuit_type(merged_moments, circuit)


def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]) -> bool:
return isinstance(op.untagged, circuits.CircuitOperation) and (
tags_to_check is None or any(tag in op.tags for tag in tags_to_check)
)


def unroll_circuit_op(
circuit: CIRCUIT_TYPE,
*,
Expand All @@ -418,8 +413,8 @@ def unroll_circuit_op(

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching
`tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.

Expand All @@ -430,12 +425,18 @@ def unroll_circuit_op(
def map_func(m: circuits.Moment, _: int):
to_zip: List['cirq.AbstractCircuit'] = []
for op in m:
if _check_circuit_op(op, tags_to_check):
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
op_untagged = op.untagged
if isinstance(op_untagged, circuits.CircuitOperation):
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
to_zip.append(
unroll_circuit_op(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
op_untagged.mapped_circuit()
if (tags_to_check is None or set(tags_to_check).intersection(op.tags))
else circuits.Circuit(op_untagged.with_tags(*op.tags))
)
else:
to_zip.append(circuits.Circuit(op))
Expand All @@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest(

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit
operations matching `tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.

Returns:
Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
"""
batch_removals = [*circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check))]
batch_inserts = []
for i, op in batch_removals:
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
sub_circuit = (
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
)
batch_inserts += [(i, sub_circuit.all_operations())]
batch_replace = []
batch_remove = []
batch_insert = []
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op_greedy_earliest(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
batch_remove.append((i, op))
batch_insert.append((i, op_untagged.mapped_circuit().all_operations()))
elif deep:
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
unrolled_circuit = circuit.unfreeze(copy=True)
unrolled_circuit.batch_remove(batch_removals)
unrolled_circuit.batch_insert(batch_inserts)
unrolled_circuit.batch_replace(batch_replace)
unrolled_circuit.batch_remove(batch_remove)
unrolled_circuit.batch_insert(batch_insert)
return _to_target_circuit_type(unrolled_circuit, circuit)


Expand All @@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier(

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit
operations matching `tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.

Expand All @@ -506,16 +516,29 @@ def unroll_circuit_op_greedy_frontier(
"""
unrolled_circuit = circuit.unfreeze(copy=True)
frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0)
for idx, op in circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check)):
idx = max(idx, max(frontier[q] for q in op.qubits))
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
sub_circuit = (
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
)
frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier)
idx = 0
while idx < len(unrolled_circuit):
for op in unrolled_circuit[idx].operations:
# Don't touch stuff inserted by unrolling previous circuit ops.
if not isinstance(op.untagged, circuits.CircuitOperation):
continue
if any(frontier[q] > idx for q in op.qubits):
continue
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op_greedy_frontier(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
frontier = unrolled_circuit.insert_at_frontier(
op_untagged.mapped_circuit().all_operations(), idx, frontier
)
elif deep:
unrolled_circuit.batch_replace([(idx, op, op_untagged.with_tags(*op.tags))])
idx += 1
return _to_target_circuit_type(unrolled_circuit, circuit)


Expand Down
110 changes: 96 additions & 14 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
)


# pylint: disable=line-too-long
def test_map_operations_deep_subcircuits():
q = cirq.LineQubit.range(5)
c_orig = cirq.Circuit(
Expand All @@ -127,9 +128,14 @@ def test_map_operations_deep_subcircuits():
c_orig_with_circuit_ops = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
[cirq.CircuitOperation(cirq.FrozenCircuit(op)) for op in c_orig.all_operations()]
[
cirq.CircuitOperation(cirq.FrozenCircuit(op)).repeat(2).with_tags("internal")
for op in c_orig.all_operations()
]
)
)
.repeat(6)
.with_tags("external")
)

def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
Expand All @@ -139,23 +145,73 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
cirq.Z.on_each(*op.qubits),
] if op.gate == cirq.CX else op

c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True)
c_mapped = cirq.unroll_circuit_op(c_mapped, deep=True, tags_to_check=None)
cirq.testing.assert_has_diagram(
c_mapped,
c_orig_with_circuit_ops,
'''
0: ───Z───@───Z───────────────
1: ───Z───X───Z───────────────

2: ───Z───X───Z───────────────
3: ───Z───@───Z───Z───@───Z───
4: ───────────────Z───X───Z───
[ [ 0: ───@─── ] ]
[ 0: ───[ │ ]────────────────────────────────────────────────────────────── ]
[ [ 1: ───X─── ](loops=2)['internal'] ]
[ │ ]
[ 1: ───#2────────────────────────────────────────────────────────────────────────── ]
[ ]
[ [ 2: ───X─── ] ]
0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────── ]────────────────────────
[ [ 3: ───@─── ](loops=2)['internal'] ]
[ │ ]
[ │ [ 3: ───@─── ] ]
[ 3: ───#2────────────────────────────────────[ │ ]──────────────────────── ]
[ [ 4: ───X─── ](loops=2)['internal'] ]
[ │ ]
[ 4: ─────────────────────────────────────────#2──────────────────────────────────── ](loops=6)['external']
1: ───#2────────────────────────────────────────────────────────────────────────────────────────────────────────────
2: ───#3────────────────────────────────────────────────────────────────────────────────────────────────────────────
3: ───#4────────────────────────────────────────────────────────────────────────────────────────────────────────────
4: ───#5────────────────────────────────────────────────────────────────────────────────────────────────────────────
''',
)

c_mapped = cirq.map_operations(c_orig_with_circuit_ops, map_func, deep=True)
for unroller in [
cirq.unroll_circuit_op,
cirq.unroll_circuit_op_greedy_earliest,
cirq.unroll_circuit_op_greedy_frontier,
]:
cirq.testing.assert_has_diagram(
unroller(c_mapped, deep=True),
'''
[ [ 0: ───Z───@───Z─── ] ]
[ 0: ───[ │ ]────────────────────────────────────────────────────────────────────── ]
[ [ 1: ───Z───X───Z─── ](loops=2)['internal'] ]
[ │ ]
[ 1: ───#2────────────────────────────────────────────────────────────────────────────────────────── ]
[ ]
[ [ 2: ───Z───X───Z─── ] ]
0: ───[ 2: ───[ │ ]────────────────────────────────────────────────────────────────────── ]────────────────────────
[ [ 3: ───Z───@───Z─── ](loops=2)['internal'] ]
[ │ ]
[ │ [ 3: ───Z───@───Z─── ] ]
[ 3: ───#2────────────────────────────────────────────[ │ ]──────────────────────── ]
[ [ 4: ───Z───X───Z─── ](loops=2)['internal'] ]
[ │ ]
[ 4: ─────────────────────────────────────────────────#2──────────────────────────────────────────── ](loops=6)['external']
1: ───#2────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
2: ───#3────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
3: ───#4────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
4: ───#5────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
''',
)


# pylint: enable=line-too-long


def test_map_operations_respects_tags_to_ignore():
q = cirq.LineQubit.range(2)
Expand Down Expand Up @@ -204,13 +260,29 @@ def test_unroll_circuit_op_and_variants():
[cirq.Moment(cirq.CircuitOperation(cirq.FrozenCircuit(m))) for m in mapped_circuit[:-1]],
mapped_circuit[-1],
)
cirq.testing.assert_has_diagram(
mapped_circuit_deep,
'''
0: ───[ 0: ───X─── ]────────────────────────────────────────────────────────────X───

1: ────────────────────[ 1: ───[ 1: ───Z───Z─── ]['<mapped_circuit_op>']─── ]───────
''',
)
for unroller in [
cirq.unroll_circuit_op_greedy_earliest,
cirq.unroll_circuit_op_greedy_frontier,
cirq.unroll_circuit_op,
]:
cirq.testing.assert_same_circuits(
unroller(mapped_circuit), unroller(mapped_circuit_deep, tags_to_check=None, deep=True)
unroller(mapped_circuit), unroller(mapped_circuit_deep, deep=True, tags_to_check=None)
)
cirq.testing.assert_has_diagram(
unroller(mapped_circuit_deep, deep=True),
'''
0: ───[ 0: ───X─── ]────────────────────────X───

1: ────────────────────[ 1: ───Z───Z─── ]───────
''',
)

cirq.testing.assert_has_diagram(
Expand Down Expand Up @@ -239,6 +311,16 @@ def test_unroll_circuit_op_and_variants():
)


def test_unroll_circuit_op_greedy_frontier_doesnt_touch_same_op_twice():
q = cirq.NamedQubit("q")
nested_ops = [cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q)))] * 5
nested_circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit(nested_ops))
c = cirq.Circuit(nested_circuit_op, nested_circuit_op, nested_circuit_op)
c_expected = cirq.Circuit(nested_ops, nested_ops, nested_ops)
c_unrolled = cirq.unroll_circuit_op_greedy_frontier(c, tags_to_check=None)
cirq.testing.assert_same_circuits(c_unrolled, c_expected)


def test_unroll_circuit_op_deep():
q0, q1, q2 = cirq.LineQubit.range(3)
c = cirq.Circuit(
Expand Down