diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index a23e87a2d4a..7c4c8c20afa 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -149,11 +149,26 @@ def from_moments(cls: Type[CIRCUIT_TYPE], *moments: 'cirq.OP_TREE') -> CIRCUIT_T """Create a circuit from moment op trees. Args: - *moments: Op tree for each moment. + *moments: Op tree for each moment. If an op tree is a moment, it + will be included directly in the new circuit. If an op tree is + a circuit, it will be frozen, wrapped in a CircuitOperation, and + included in its own moment in the new circuit. Otherwise, the + op tree will be passed to `cirq.Moment` to create a new moment + which is then included in the new circuit. Note that in the + latter case we have the normal restriction that operations in a + moment must be applied to disjoint sets of qubits. """ - return cls._from_moments( - moment if isinstance(moment, Moment) else Moment(moment) for moment in moments - ) + return cls._from_moments(cls._make_moments(moments)) + + @staticmethod + def _make_moments(moments: Iterable['cirq.OP_TREE']) -> Iterator['cirq.Moment']: + for m in moments: + if isinstance(m, Moment): + yield m + elif isinstance(m, AbstractCircuit): + yield Moment(m.freeze().to_op()) + else: + yield Moment(m) @classmethod @abc.abstractmethod diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 222ebe414c4..b8d20ecfc08 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -23,7 +23,6 @@ import sympy import cirq -import cirq.testing from cirq import circuits from cirq import ops from cirq.testing.devices import ValidatingTestDevice @@ -72,19 +71,32 @@ def validate_moment(self, moment): def test_from_moments(): a, b, c, d = cirq.LineQubit.range(4) - assert cirq.Circuit.from_moments( + moment = cirq.Moment(cirq.Z(a), cirq.Z(b)) + subcircuit = cirq.FrozenCircuit.from_moments(cirq.X(c), cirq.Y(d)) + circuit = cirq.Circuit.from_moments( + moment, + subcircuit, [cirq.X(a), cirq.Y(b)], [cirq.X(c)], [], cirq.Z(d), [cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')], - ) == cirq.Circuit( + ) + assert circuit == cirq.Circuit( + cirq.Moment(cirq.Z(a), cirq.Z(b)), + cirq.Moment( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.Moment(cirq.X(c)), cirq.Moment(cirq.Y(d))) + ) + ), cirq.Moment(cirq.X(a), cirq.Y(b)), cirq.Moment(cirq.X(c)), cirq.Moment(), cirq.Moment(cirq.Z(d)), cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')), ) + assert circuit[0] is moment + assert circuit[1].operations[0].circuit is subcircuit def test_alignment(): diff --git a/cirq-core/cirq/circuits/frozen_circuit_test.py b/cirq-core/cirq/circuits/frozen_circuit_test.py index d41fcb1b0e5..660fac319ca 100644 --- a/cirq-core/cirq/circuits/frozen_circuit_test.py +++ b/cirq-core/cirq/circuits/frozen_circuit_test.py @@ -24,19 +24,32 @@ def test_from_moments(): a, b, c, d = cirq.LineQubit.range(4) - assert cirq.FrozenCircuit.from_moments( + moment = cirq.Moment(cirq.Z(a), cirq.Z(b)) + subcircuit = cirq.FrozenCircuit.from_moments(cirq.X(c), cirq.Y(d)) + circuit = cirq.FrozenCircuit.from_moments( + moment, + subcircuit, [cirq.X(a), cirq.Y(b)], [cirq.X(c)], [], cirq.Z(d), [cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')], - ) == cirq.FrozenCircuit( + ) + assert circuit == cirq.FrozenCircuit( + cirq.Moment(cirq.Z(a), cirq.Z(b)), + cirq.Moment( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.Moment(cirq.X(c)), cirq.Moment(cirq.Y(d))) + ) + ), cirq.Moment(cirq.X(a), cirq.Y(b)), cirq.Moment(cirq.X(c)), cirq.Moment(), cirq.Moment(cirq.Z(d)), cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')), ) + assert circuit[0] is moment + assert circuit[1].operations[0].circuit is subcircuit def test_freeze_and_unfreeze():