diff --git a/cirq-core/cirq/transformers/eject_phased_paulis.py b/cirq-core/cirq/transformers/eject_phased_paulis.py index a354fe9f6f1..c58fb6b14ed 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis.py @@ -26,7 +26,7 @@ import cirq -@transformer_api.transformer +@transformer_api.transformer(add_deep_support=True) def eject_phased_paulis( circuit: 'cirq.AbstractCircuit', *, diff --git a/cirq-core/cirq/transformers/eject_phased_paulis_test.py b/cirq-core/cirq/transformers/eject_phased_paulis_test.py index 5f4663da5f4..b0e545250dc 100644 --- a/cirq-core/cirq/transformers/eject_phased_paulis_test.py +++ b/cirq-core/cirq/transformers/eject_phased_paulis_test.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Iterable, cast +import dataclasses import numpy as np import pytest import sympy @@ -53,29 +54,42 @@ def assert_optimizes( ) # And match the expected circuit. - assert circuit == expected, ( - "Circuit wasn't optimized as expected.\n" - "INPUT:\n" - "{}\n" - "\n" - "EXPECTED OUTPUT:\n" - "{}\n" - "\n" - "ACTUAL OUTPUT:\n" - "{}\n" - "\n" - "EXPECTED OUTPUT (detailed):\n" - "{!r}\n" - "\n" - "ACTUAL OUTPUT (detailed):\n" - "{!r}" - ).format(before, expected, circuit, expected, circuit) + cirq.testing.assert_same_circuits(circuit, expected) # And it should be idempotent. circuit = cirq.eject_phased_paulis( circuit, eject_parameterized=eject_parameterized, context=context ) - assert circuit == expected + cirq.testing.assert_same_circuits(circuit, expected) + + # Nested sub-circuits should also get optimized. + q = before.all_qubits() + c_nested = cirq.Circuit( + [cirq.PhasedXPowGate(phase_exponent=0.5).on_each(*q), (cirq.Z ** 0.5).on_each(*q)], + cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore"), + [cirq.Y.on_each(*q), cirq.X.on_each(*q)], + cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag"), + ) + c_expected = cirq.Circuit( + cirq.PhasedXPowGate(phase_exponent=0.75).on_each(*q), + cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")), + cirq.Z.on_each(*q), + cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")), + ) + if context is None: + context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True) + else: + context = dataclasses.replace( + context, tags_to_ignore=context.tags_to_ignore + ("ignore",), deep=True + ) + c_nested = cirq.eject_phased_paulis( + c_nested, context=context, eject_parameterized=eject_parameterized + ) + cirq.testing.assert_same_circuits(c_nested, c_expected) + c_nested = cirq.eject_phased_paulis( + c_nested, context=context, eject_parameterized=eject_parameterized + ) + cirq.testing.assert_same_circuits(c_nested, c_expected) def quick_circuit(*moments: Iterable[cirq.OP_TREE]) -> cirq.Circuit: