Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up cirq.map_operations and cirq.map_operations_and_unroll #6250

Merged
merged 8 commits into from
Aug 24, 2023
121 changes: 93 additions & 28 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,89 @@ def map_moments(
)


def _map_operations_impl(
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
circuit: CIRCUIT_TYPE,
map_func: Callable[[ops.Operation, int], ops.OP_TREE],
*,
deep: bool = False,
raise_if_add_qubits=True,
tags_to_ignore: Sequence[Hashable] = (),
wrap_in_circuit_op: bool = True,
) -> CIRCUIT_TYPE:
tags_to_ignore_set = set(tags_to_ignore)

def apply_map_func(op: 'cirq.Operation', idx: int) -> List['cirq.Operation']:
if tags_to_ignore_set.intersection(op.tags):
return [op]
if deep and isinstance(op.untagged, circuits.CircuitOperation):
mapped_op = op.untagged.replace(
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
circuit=_map_operations_impl(
op.untagged.circuit,
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
wrap_in_circuit_op=wrap_in_circuit_op,
)
).with_tags(*op.tags)
op = mapped_op
mapped_ops = [*ops.flatten_to_ops(map_func(op, idx))]
op_qubits = set(op.qubits)
mapped_ops_qubits: Set['cirq.Qid'] = set()
has_overlapping_ops = False
for mapped_op in mapped_ops:
if raise_if_add_qubits and not op_qubits.issuperset(mapped_op.qubits):
raise ValueError(
f"Mapped operations {mapped_ops} should act on a subset "
f"of qubits of the original operation {op}"
)
if mapped_ops_qubits.intersection(mapped_op.qubits):
has_overlapping_ops = True
mapped_ops_qubits = mapped_ops_qubits.union(mapped_op.qubits)
if wrap_in_circuit_op and has_overlapping_ops:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
mapped_ops = [
circuits.CircuitOperation(circuits.FrozenCircuit(mapped_ops)).with_tags(
MAPPED_CIRCUIT_OP_TAG
)
]
return mapped_ops

new_moments: List[List['cirq.Operation']] = []

# Keep track of the latest time index for each qubit, measurement key, and control key.
qubit_time_index: Dict['cirq.Qid', int] = {}
measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
control_time_index: Dict['cirq.MeasurementKey', int] = {}

# New mapped operations in the current moment should be inserted after `last_moment_time_index`.
last_moment_time_index = -1

for idx, moment in enumerate(circuit):
if wrap_in_circuit_op:
new_moments.append([])
for op in moment:
mapped_ops = apply_map_func(op, idx)

for mapped_op in mapped_ops:
# Identify the earliest moment that can accommodate this op.
placement_index = circuits.circuit.get_earliest_accommodating_moment_index(
mapped_op, qubit_time_index, measurement_time_index, control_time_index
)
placement_index = max(placement_index, last_moment_time_index + 1)
new_moments += [[] for _ in range(placement_index - len(new_moments) + 1)]
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
new_moments[placement_index].append(mapped_op)
for qubit in mapped_op.qubits:
qubit_time_index[qubit] = placement_index
for key in protocols.measurement_key_objs(mapped_op):
measurement_time_index[key] = placement_index
for key in protocols.control_keys(mapped_op):
control_time_index[key] = placement_index

last_moment_time_index = len(new_moments) - 1

return _create_target_circuit_type([circuits.Moment(moment) for moment in new_moments], circuit)


def map_operations(
circuit: CIRCUIT_TYPE,
map_func: Callable[[ops.Operation, int], ops.OP_TREE],
Expand Down Expand Up @@ -139,29 +222,13 @@ def map_operations(
Returns:
Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation).
"""

def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
if not set(op.tags).isdisjoint(tags_to_ignore):
return op
c = circuits.FrozenCircuit(map_func(op, idx))
if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits):
raise ValueError(
f"Mapped operations {c.all_operations()} should act on a subset "
f"of qubits of the original operation {op}"
)
if len(c) <= 1:
# Either empty circuit or all operations act in the same moment;
# So, we don't need to wrap them in a circuit_op.
return c[0].operations if c else []
circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG)
return circuit_op

return map_moments(
return _map_operations_impl(
circuit,
lambda m, i: circuits.Circuit(apply_map(op, i) for op in m.operations).moments
or [circuits.Moment()],
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
wrap_in_circuit_op=True,
)


Expand Down Expand Up @@ -191,15 +258,13 @@ def map_operations_and_unroll(
Returns:
Copy of input circuit with mapped operations, unrolled in a moment preserving way.
"""
return unroll_circuit_op(
map_operations(
circuit,
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
),
return _map_operations_impl(
circuit,
map_func,
deep=deep,
raise_if_add_qubits=raise_if_add_qubits,
tags_to_ignore=tags_to_ignore,
wrap_in_circuit_op=False,
)


Expand Down