Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for recursively applying transformer primitives on circuit operations using deep=True #5103

Merged
merged 2 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Improve support of deep=True flag to work with circuit operations in …
…transformer primitives
  • Loading branch information
tanujkhattar committed Mar 19, 2022
commit 0ed09d4e53af48916a4456b0a78d9b0fd9d2a897
110 changes: 68 additions & 42 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def map_moments(
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
mapped_op = op_untagged.replace(
circuit=map_moments(op_untagged.mapped_circuit(), map_func, deep=deep).freeze()
)
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
).with_tags(*op.tags)
batch_replace.append((i, op, mapped_op))
mutable_circuit = circuit.unfreeze(copy=True)
mutable_circuit.batch_replace(batch_replace)
Expand Down Expand Up @@ -180,7 +180,8 @@ def map_operations_and_unroll(
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
)
),
deep=deep,
)


Expand Down Expand Up @@ -399,12 +400,6 @@ def merge_moments(
return _create_target_circuit_type(merged_moments, circuit)


def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]) -> bool:
return isinstance(op.untagged, circuits.CircuitOperation) and (
tags_to_check is None or any(tag in op.tags for tag in tags_to_check)
)


def unroll_circuit_op(
circuit: CIRCUIT_TYPE,
*,
Expand All @@ -418,8 +413,8 @@ def unroll_circuit_op(

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op` is recursively called on all circuit operations matching
`tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.

Expand All @@ -430,12 +425,18 @@ def unroll_circuit_op(
def map_func(m: circuits.Moment, _: int):
to_zip: List['cirq.AbstractCircuit'] = []
for op in m:
if _check_circuit_op(op, tags_to_check):
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
op_untagged = op.untagged
if isinstance(op_untagged, circuits.CircuitOperation):
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
to_zip.append(
unroll_circuit_op(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
op_untagged.mapped_circuit()
if (tags_to_check is None or set(tags_to_check).intersection(op.tags))
else circuits.Circuit(op_untagged.with_tags(*op.tags))
)
else:
to_zip.append(circuits.Circuit(op))
Expand All @@ -458,27 +459,36 @@ def unroll_circuit_op_greedy_earliest(

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op_greedy_earliest` is recursively called on all circuit
operations matching `tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.

Returns:
Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
"""
batch_removals = [*circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check))]
batch_inserts = []
for i, op in batch_removals:
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
sub_circuit = (
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
)
batch_inserts += [(i, sub_circuit.all_operations())]
batch_replace = []
batch_remove = []
batch_insert = []
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op_greedy_earliest(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
batch_remove.append((i, op))
batch_insert.append((i, op_untagged.mapped_circuit().all_operations()))
elif deep:
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
unrolled_circuit = circuit.unfreeze(copy=True)
unrolled_circuit.batch_remove(batch_removals)
unrolled_circuit.batch_insert(batch_inserts)
unrolled_circuit.batch_replace(batch_replace)
unrolled_circuit.batch_remove(batch_remove)
unrolled_circuit.batch_insert(batch_insert)
return _to_target_circuit_type(unrolled_circuit, circuit)


Expand All @@ -496,8 +506,8 @@ def unroll_circuit_op_greedy_frontier(

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
deep: If True, `unroll_circuit_op_greedy_frontier` is recursively called on all circuit
operations matching `tags_to_check`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
are unrolled.

Expand All @@ -506,16 +516,32 @@ def unroll_circuit_op_greedy_frontier(
"""
unrolled_circuit = circuit.unfreeze(copy=True)
frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0)
for idx, op in circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check)):
idx = max(idx, max(frontier[q] for q in op.qubits))
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
sub_circuit = cast(circuits.CircuitOperation, op.untagged).mapped_circuit()
sub_circuit = (
unroll_circuit_op_greedy_earliest(sub_circuit, deep=deep, tags_to_check=tags_to_check)
if deep
else sub_circuit
)
frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier)
idx = 0
while idx < len(unrolled_circuit):
for op in unrolled_circuit[idx].operations:
# for idx, op in circuit.findall_operations(
# lambda o: isinstance(o.untagged, circuits.CircuitOperation)
# ):
# Don't touch stuff inserted by unrolling previous circuit ops.
if not isinstance(op.untagged, circuits.CircuitOperation):
continue
if any(frontier[q] > idx for q in op.qubits):
continue
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if deep:
op_untagged = op_untagged.replace(
circuit=unroll_circuit_op_greedy_frontier(
op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
)
)
if tags_to_check is None or set(tags_to_check).intersection(op.tags):
unrolled_circuit.clear_operations_touching(op.qubits, [idx])
frontier = unrolled_circuit.insert_at_frontier(
op_untagged.mapped_circuit().all_operations(), idx, frontier
)
elif deep:
unrolled_circuit.batch_replace([(idx, op, op_untagged.with_tags(*op.tags))])
idx += 1
return _to_target_circuit_type(unrolled_circuit, circuit)


Expand Down
Loading