Skip to content

Commit

Permalink
Optimize circuit/moment resolution to reuse instances if possible (qu…
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo authored and rht committed May 1, 2023
1 parent afd49cf commit 88d6754
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 25 deletions.
34 changes: 14 additions & 20 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,20 @@ def _is_parameterized_(self) -> bool:
def _parameter_names_(self) -> AbstractSet[str]:
return {name for op in self.all_operations() for name in protocols.parameter_names(op)}

def _resolve_parameters_(
self: CIRCUIT_TYPE, resolver: 'cirq.ParamResolver', recursive: bool
) -> CIRCUIT_TYPE:
changed = False
resolved_moments: List['cirq.Moment'] = []
for moment in self:
resolved_moment = protocols.resolve_parameters(moment, resolver, recursive)
if resolved_moment is not moment:
changed = True
resolved_moments.append(resolved_moment)
if not changed:
return self
return self._from_moments(resolved_moments)

def _qasm_(self) -> str:
return self.to_qasm()

Expand Down Expand Up @@ -2377,17 +2391,6 @@ def clear_operations_touching(
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.Circuit':
resolved_moments = []
for moment in self:
resolved_operations = _resolve_operations(moment.operations, resolver, recursive)
new_moment = Moment(resolved_operations)
resolved_moments.append(new_moment)

return Circuit(resolved_moments)

@property
def moments(self) -> Sequence['cirq.Moment']:
return self._moments
Expand Down Expand Up @@ -2441,15 +2444,6 @@ def _pick_inserted_ops_moment_indices(
return moment_indices, frontier


def _resolve_operations(
operations: Iterable['cirq.Operation'], param_resolver: 'cirq.ParamResolver', recursive: bool
) -> List['cirq.Operation']:
resolved_operations: List['cirq.Operation'] = []
for op in operations:
resolved_operations.append(protocols.resolve_parameters(op, param_resolver, recursive))
return resolved_operations


def _get_moment_annotations(moment: 'cirq.Moment') -> Iterator['cirq.Operation']:
for op in moment.operations:
if op.qubits:
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3039,6 +3039,23 @@ def test_resolve_parameters(circuit_cls, resolve_fn):
cirq.testing.assert_same_circuits(expected_circuit, resolved_circuit)


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_resolve_parameters_no_change(circuit_cls, resolve_fn):
a, b = cirq.LineQubit.range(2)
circuit = circuit_cls(cirq.CZ(a, b), cirq.X(a), cirq.Y(b))
resolved_circuit = resolve_fn(circuit, cirq.ParamResolver({'u': 0.1, 'v': 0.3, 'w': 0.2}))
assert resolved_circuit is circuit

circuit = circuit_cls(
cirq.CZ(a, b) ** sympy.Symbol('u'),
cirq.X(a) ** sympy.Symbol('v'),
cirq.Y(b) ** sympy.Symbol('w'),
)
resolved_circuit = resolve_fn(circuit, cirq.ParamResolver({}))
assert resolved_circuit is circuit


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_parameter_names(circuit_cls, resolve_fn):
Expand Down
5 changes: 0 additions & 5 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,6 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit':
except:
return NotImplemented

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.FrozenCircuit':
return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze()

def concat_ragged(
*circuits: 'cirq.AbstractCircuit', align: Union['cirq.Alignment', str] = Alignment.LEFT
) -> 'cirq.FrozenCircuit':
Expand Down
22 changes: 22 additions & 0 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import itertools
from typing import (
AbstractSet,
Any,
Callable,
Dict,
FrozenSet,
Iterable,
Iterator,
List,
Mapping,
overload,
Optional,
Expand Down Expand Up @@ -236,6 +238,26 @@ def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Mom
if qubits.isdisjoint(frozenset(operation.qubits))
)

def _is_parameterized_(self) -> bool:
return any(protocols.is_parameterized(op) for op in self)

def _parameter_names_(self) -> AbstractSet[str]:
return {name for op in self for name in protocols.parameter_names(op)}

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.Moment':
changed = False
resolved_ops: List['cirq.Operation'] = []
for op in self:
resolved_op = protocols.resolve_parameters(op, resolver, recursive)
if resolved_op != op:
changed = True
resolved_ops.append(resolved_op)
if not changed:
return self
return Moment(resolved_ops)

def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
return Moment(
protocols.with_measurement_key_mapping(op, key_map)
Expand Down
33 changes: 33 additions & 0 deletions cirq-core/cirq/circuits/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import pytest
import sympy

import cirq
import cirq.testing
Expand Down Expand Up @@ -274,6 +275,38 @@ def test_without_operations_touching():
)


def test_is_parameterized():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
assert cirq.is_parameterized(moment)
assert not cirq.is_parameterized(cirq.Moment(cirq.X(a), cirq.Y(b)))


def test_resolve_parameters():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({'v': 0.1, 'w': 0.2}))
assert resolved_moment == cirq.Moment(cirq.X(a) ** 0.1, cirq.Y(b) ** 0.2)


def test_resolve_parameters_no_change():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a), cirq.Y(b))
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({'v': 0.1, 'w': 0.2}))
assert resolved_moment is moment

moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
resolved_moment = cirq.resolve_parameters(moment, cirq.ParamResolver({}))
assert resolved_moment is moment


def test_parameter_names():
a, b = cirq.LineQubit.range(2)
moment = cirq.Moment(cirq.X(a) ** sympy.Symbol('v'), cirq.Y(b) ** sympy.Symbol('w'))
assert cirq.parameter_names(moment) == {'v', 'w'}
assert cirq.parameter_names(cirq.Moment(cirq.X(a), cirq.Y(b))) == set()


def test_with_measurement_keys():
a, b = cirq.LineQubit.range(2)
m = cirq.Moment(cirq.measure(a, key='m1'), cirq.measure(b, key='m2'))
Expand Down

0 comments on commit 88d6754

Please sign in to comment.