Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Instance GateFamily check for equality ignoring global phase #4542

Merged
merged 8 commits into from
Oct 13, 2021
Prev Previous commit
Next Next commit
Rename accept_global_phase to accept_global_phase_op
  • Loading branch information
tanujkhattar committed Oct 12, 2021
commit b03c0e45755409518a9811b5d435af67bffb8247
6 changes: 3 additions & 3 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def neutral_atom_gateset(max_parallel_z=None, max_parallel_xy=None):
ops.MeasurementGate,
ops.IdentityGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)


Expand Down Expand Up @@ -100,15 +100,15 @@ def __init__(
ops.ParallelGateFamily(ops.YPowGate),
ops.ParallelGateFamily(ops.PhasedXPowGate),
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.controlled_gateset = ops.Gateset(
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.gateset = neutral_atom_gateset(max_parallel_z, max_parallel_xy)
for q in qubits:
Expand Down
22 changes: 11 additions & 11 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __init__(
*gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
name: Optional[str] = None,
unroll_circuit_op: bool = True,
accept_global_phase: bool = True,
accept_global_phase_op: bool = True,
) -> None:
"""Init Gateset.

Expand All @@ -200,14 +200,14 @@ def __init__(
name: (Optional) Name for the Gateset. Useful for description.
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
validated by validating the underlying `cirq.Circuit`.
accept_global_phase: If True, `cirq.GlobalPhaseOperation` is accepted.
accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
"""
self._name = name
self._gates = frozenset(
g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates
)
self._unroll_circuit_op = unroll_circuit_op
self._accept_global_phase = accept_global_phase
self._accept_global_phase_op = accept_global_phase_op
self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
for g in self._gates:
Expand All @@ -230,7 +230,7 @@ def with_params(
*,
name: Optional[str] = None,
unroll_circuit_op: Optional[bool] = None,
accept_global_phase: Optional[bool] = None,
accept_global_phase_op: Optional[bool] = None,
) -> 'Gateset':
"""Returns a copy of this Gateset with identical gates and new values for named arguments.

Expand All @@ -240,7 +240,7 @@ def with_params(
name: New name for the Gateset.
unroll_circuit_op: If True, new Gateset will recursively validate
`cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
accept_global_phase: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.
accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.

Returns:
`self` if all new values are None or identical to the values of current Gateset.
Expand All @@ -252,18 +252,18 @@ def val_if_none(var: Any, val: Any) -> Any:

name = val_if_none(name, self._name)
unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
accept_global_phase = val_if_none(accept_global_phase, self._accept_global_phase)
accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
if (
name == self._name
and unroll_circuit_op == self._unroll_circuit_op
and accept_global_phase == self._accept_global_phase
and accept_global_phase_op == self._accept_global_phase_op
):
return self
return Gateset(
*self.gates,
name=name,
unroll_circuit_op=cast(bool, unroll_circuit_op),
accept_global_phase=cast(bool, accept_global_phase),
accept_global_phase_op=cast(bool, accept_global_phase_op),
)

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
Expand Down Expand Up @@ -366,7 +366,7 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
)
return self.validate(op_circuit)
elif isinstance(op, global_phase_op.GlobalPhaseOperation):
return self._accept_global_phase
return self._accept_global_phase_op
else:
return False

Expand All @@ -375,7 +375,7 @@ def _value_equality_values_(self) -> Any:
frozenset(self.gates),
self.name,
self._unroll_circuit_op,
self._accept_global_phase,
self._accept_global_phase_op,
)

def __repr__(self) -> str:
Expand All @@ -384,7 +384,7 @@ def __repr__(self) -> str:
f'{",".join([repr(g) for g in self.gates])},'
f'name = "{self.name}",'
f'unroll_circuit_op = {self._unroll_circuit_op},'
f'accept_global_phase = {self._accept_global_phase})'
f'accept_global_phase_op = {self._accept_global_phase_op})'
)

def __str__(self) -> str:
Expand Down
14 changes: 8 additions & 6 deletions cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=use_circuit_op,
accept_global_phase=use_global_phase,
accept_global_phase_op=use_global_phase,
),
op_tree,
True,
Expand All @@ -267,7 +267,7 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
),
op_tree,
False,
Expand All @@ -280,16 +280,16 @@ def test_with_params():
gateset.with_params(
name=gateset.name,
unroll_circuit_op=gateset._unroll_circuit_op,
accept_global_phase=gateset._accept_global_phase,
accept_global_phase_op=gateset._accept_global_phase_op,
)
is gateset
)
gateset_with_params = gateset.with_params(
name='new name', unroll_circuit_op=False, accept_global_phase=False
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
)
assert gateset_with_params.name == 'new name'
assert gateset_with_params._unroll_circuit_op is False
assert gateset_with_params._accept_global_phase is False
assert gateset_with_params._accept_global_phase_op is False


def test_gateset_eq():
Expand All @@ -298,7 +298,9 @@ def test_gateset_eq():
eq.add_equality_group(cirq.Gateset(CustomX ** 3))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset'))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', unroll_circuit_op=False))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase=False))
eq.add_equality_group(
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=False)
)
eq.add_equality_group(
cirq.Gateset(
cirq.GateFamily(CustomX, name='custom_name', description='custom_description')
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __init__(
self.gateset = ops.Gateset(
ops.CZPowGate if allow_partial_czs else ops.CZ,
unroll_circuit_op=False,
accept_global_phase=True,
accept_global_phase_op=True,
)

def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
self.gateset = ops.Gateset(
ops.SQRT_ISWAP_INV if use_sqrt_iswap_inv else ops.SQRT_ISWAP,
unroll_circuit_op=False,
accept_global_phase=True,
accept_global_phase_op=True,
)

def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
Expand Down
2 changes: 1 addition & 1 deletion cirq-ionq/cirq_ionq/ionq_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, qubits: Union[Sequence[cirq.LineQubit], int], atol=1e-8):
cirq.ZZPowGate,
cirq.MeasurementGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)

def qubit_set(self) -> AbstractSet['cirq.Qid']:
Expand Down
2 changes: 1 addition & 1 deletion cirq-pasqal/cirq_pasqal/pasqal_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, qubits: Sequence[cirq.ops.Qid]) -> None:
cirq.IdentityGate,
cirq.MeasurementGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.qubits = qubits

Expand Down