Skip to content

Commit

Permalink
Add support for deep=True to merge_single_qubit_gates* transformers (#…
Browse files Browse the repository at this point in the history
…5123)

- Adds support to recursively run `cirq.merge_single_qubit_moments_to_phxz` transformer on circuits wrapped inside a circuit operation by setting deep=True in transformer context.
- Also adds tests for `cirq.merge_single_qubit_gates_to_phxz` and `cirq.merge_single_qubit_gates_to_phased_x_and_z`, both of which automatically support deep=True flag after #5122
- Part of #5039
  • Loading branch information
tanujkhattar authored Mar 22, 2022
1 parent 92d19f6 commit 518d828
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 11 deletions.
7 changes: 6 additions & 1 deletion cirq-core/cirq/transformers/merge_single_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,9 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
ret_ops.append(gate(q))
return circuits.Moment(ret_ops)

return transformer_primitives.merge_moments(circuit, merge_func).unfreeze(copy=False)
return transformer_primitives.merge_moments(
circuit,
merge_func,
deep=context.deep if context else False,
tags_to_ignore=tuple(tags_to_ignore),
).unfreeze(copy=False)
110 changes: 100 additions & 10 deletions cirq-core/cirq/transformers/merge_single_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,38 @@ def test_merge_single_qubit_gates_into_phased_x_z():
)


def test_merge_single_qubit_gates_into_phxz():
def phxz(a, x, z):
return cirq.PhasedXZGate(
axis_phase_exponent=a,
x_exponent=x,
z_exponent=z,
)
def test_merge_single_qubit_gates_into_phased_x_z_deep():
a = cirq.NamedQubit("a")
c_nested = cirq.FrozenCircuit(cirq.H(a), cirq.Z(a), cirq.H(a).with_tags("ignore"))
c_nested_merged = cirq.FrozenCircuit(
cirq.PhasedXPowGate(phase_exponent=-0.5, exponent=0.5).on(a), cirq.H(a).with_tags("ignore")
)
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tags"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(6),
)
c_expected = cirq.Circuit(
c_nested_merged,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested_merged,
cirq.CircuitOperation(c_nested_merged).repeat(5).with_tags("preserve_tags"),
c_nested_merged,
cirq.CircuitOperation(c_nested_merged).repeat(6),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
c_new = cirq.merge_single_qubit_gates_to_phased_x_and_z(c_orig, context=context)
cirq.testing.assert_same_circuits(c_new, c_expected)


def _phxz(a: float, x: float, z: float):
return cirq.PhasedXZGate(axis_phase_exponent=a, x_exponent=x, z_exponent=z)


def test_merge_single_qubit_gates_into_phxz():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(
cirq.X(a),
Expand All @@ -75,16 +99,41 @@ def phxz(a, x, z):
assert_optimizes(
optimized=cirq.merge_single_qubit_gates_to_phxz(c),
expected=cirq.Circuit(
phxz(-1, 1, 0).on(a),
phxz(0.5, 0.5, 0).on(b),
_phxz(-1, 1, 0).on(a),
_phxz(0.5, 0.5, 0).on(b),
cirq.CZ(a, b),
phxz(-0.5, 0.5, 0).on(a),
_phxz(-0.5, 0.5, 0).on(a),
cirq.measure(b, key="m"),
cirq.H(a).with_classical_controls("m"),
),
)


def test_merge_single_qubit_gates_into_phxz_deep():
a = cirq.NamedQubit("a")
c_nested = cirq.FrozenCircuit(cirq.H(a), cirq.Z(a), cirq.H(a).with_tags("ignore"))
c_nested_merged = cirq.FrozenCircuit(_phxz(-0.5, 0.5, 0).on(a), cirq.H(a).with_tags("ignore"))
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tags"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(6),
)
c_expected = cirq.Circuit(
c_nested_merged,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested_merged,
cirq.CircuitOperation(c_nested_merged).repeat(5).with_tags("preserve_tags"),
c_nested_merged,
cirq.CircuitOperation(c_nested_merged).repeat(6),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
c_new = cirq.merge_single_qubit_gates_to_phxz(c_orig, context=context)
cirq.testing.assert_same_circuits(c_new, c_expected)


def test_merge_single_qubit_moments_to_phxz():
q = cirq.LineQubit.range(3)
c_orig = cirq.Circuit(
Expand Down Expand Up @@ -127,3 +176,44 @@ def test_merge_single_qubit_moments_to_phxz():
a: ══════════════════════════════════════════════════════════════════════════════════@═══^═══════
''',
)


def test_merge_single_qubit_moments_to_phxz_deep():
q = cirq.LineQubit.range(3)
x_t_y = cirq.FrozenCircuit(
cirq.Moment(cirq.X.on_each(*q[:2])),
cirq.Moment(cirq.T.on_each(*q[1:])),
cirq.Moment(cirq.Y.on_each(*q[:2])),
)
c_nested = cirq.FrozenCircuit(
x_t_y,
cirq.Moment(cirq.CZ(*q[:2]), cirq.Y(q[2])),
x_t_y,
cirq.Moment(cirq.Y(q[0]).with_tags("ignore"), cirq.Z.on_each(*q[1:])),
)

c_nested_merged = cirq.FrozenCircuit(
[_phxz(-0.25, 0.0, 0.75)(q[1]), _phxz(0.25, 0.0, 0.25)(q[2]), _phxz(-0.5, 0.0, -1.0)(q[0])],
[cirq.CZ(q[0], q[1]), cirq.Y(q[2])],
[_phxz(-0.25, 0.0, 0.75)(q[1]), _phxz(0.25, 0.0, 0.25)(q[2]), _phxz(-0.5, 0.0, -1.0)(q[0])],
cirq.Moment(cirq.Y(q[0]).with_tags("ignore"), cirq.Z.on_each(*q[1:])),
)
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tags"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(6),
)
c_expected = cirq.Circuit(
c_nested_merged,
cirq.CircuitOperation(c_nested).repeat(4).with_tags("ignore"),
c_nested_merged,
cirq.CircuitOperation(c_nested_merged).repeat(5).with_tags("preserve_tags"),
c_nested_merged,
cirq.CircuitOperation(c_nested_merged).repeat(6),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context)
cirq.testing.assert_same_circuits(c_new, c_expected)

0 comments on commit 518d828

Please sign in to comment.