Skip to content

Commit

Permalink
Bugfix in handling of deep=True flag in `cirq.merge_k_qubit_unitaries…
Browse files Browse the repository at this point in the history
…` transformer (#5125)

- Fixes a bug in `cirq.merge_k_qubit_unitaries` due to which the transformer was applied recursively only on circuit operations satisfying `cirq.num_qubits(op) <= k and cirq.has_unitary(op)`.  Fixed the bug and added more tests. 
- Part of #5039
  • Loading branch information
tanujkhattar authored Mar 22, 2022
1 parent 89e7210 commit 091582e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cirq-core/cirq/transformers/merge_k_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def _rewrite_merged_k_qubit_unitaries(
deep = context.deep if context else False

def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
return op
op_untagged = op.untagged
if (
deep
Expand All @@ -51,6 +49,8 @@ def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
merged_circuit_op_tag=merged_circuit_op_tag,
).freeze()
).with_tags(*op.tags)
if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)):
return op
if rewriter:
return rewriter(
cast(circuits.CircuitOperation, op_untagged)
Expand Down
24 changes: 24 additions & 0 deletions cirq-core/cirq/transformers/merge_k_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,27 @@ def _wrap_in_matrix_gate(ops: cirq.OP_TREE):
)
c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context)
cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix)


def test_merge_k_qubit_unitaries_deep_recurses_on_large_circuit_op():
q = cirq.LineQubit.range(2)
c_orig = cirq.Circuit(
cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]), cirq.CNOT(*q)))
)
c_expected = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]))).with_tags(
"merged"
),
cirq.CNOT(*q),
)
)
)
c_new = cirq.merge_k_qubit_unitaries(
c_orig,
context=cirq.TransformerContext(deep=True),
k=1,
rewriter=lambda op: op.with_tags("merged"),
)
cirq.testing.assert_same_circuits(c_new, c_expected)

0 comments on commit 091582e

Please sign in to comment.