Skip to content

Commit

Permalink
Merge serializable_forms and deserialized_forms (#6520)
Browse files Browse the repository at this point in the history
* Merge serializable_forms and deserialized_forms

* Documentation for adding gates

* Update documentation for adding CompilationTargetGatesets
  • Loading branch information
verult authored Mar 22, 2024
1 parent edda3a5 commit edd8393
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 59 deletions.
97 changes: 47 additions & 50 deletions cirq-google/cirq_google/devices/grid_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,31 @@
_SQRT_ISWAP_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])
_SQRT_ISWAP_INV_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])
_CZ_FSIM_GATE_FAMILY = ops.FSimGateFamily(gates_to_accept=[cirq.CZ])
_SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC)
_SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP)
_SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV)
_CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ)


# TODO(#5050) Add GlobalPhaseGate
# Target gates of `cirq_google.GoogleCZTargetGateset`.
_CZ_TARGET_GATES = [_CZ_FSIM_GATE_FAMILY, _PHASED_XZ_GATE_FAMILY, _MEASUREMENT_GATE_FAMILY]
_CZ_TARGET_GATES = [
_CZ_FSIM_GATE_FAMILY,
_CZ_GATE_FAMILY,
_PHASED_XZ_GATE_FAMILY,
_MEASUREMENT_GATE_FAMILY,
]
# Target gates of `cirq_google.SycamoreTargetGateset`.
_SYC_TARGET_GATES = [_SYC_FSIM_GATE_FAMILY, _PHASED_XZ_GATE_FAMILY, _MEASUREMENT_GATE_FAMILY]
_SYC_TARGET_GATES = [
_SYC_FSIM_GATE_FAMILY,
_SYC_GATE_FAMILY,
_PHASED_XZ_GATE_FAMILY,
_MEASUREMENT_GATE_FAMILY,
]
# Target gates of `cirq.SqrtIswapTargetGateset`
_SQRT_ISWAP_TARGET_GATES = [
_SQRT_ISWAP_FSIM_GATE_FAMILY,
_SQRT_ISWAP_GATE_FAMILY,
_PHASED_XZ_GATE_FAMILY,
_MEASUREMENT_GATE_FAMILY,
]
Expand All @@ -77,51 +92,44 @@ class _GateRepresentations:
Attributes:
gate_spec_name: The name of gate type in `GateSpecification`.
deserialized_forms: Gate representations to be included when the corresponding
`GateSpecification` gate type is deserialized into gatesets and gate durations.
serializable_forms: GateFamilies used to check whether a given gate can be serialized to the
gate type in this _GateRepresentation.
supported_gates: A list of gates that can be serialized into the `GateSpecification` with
the matching name.
"""

gate_spec_name: str
deserialized_forms: List[GateOrFamily]
serializable_forms: List[cirq.GateFamily]
supported_gates: List[cirq.GateFamily]


# Gates recognized by the GridDevice class. This controls the (de)serialization between
# `DeviceSpecification.valid_gates` and `cirq.Gateset`.

"""Valid gates for a GridDevice."""
# This is a superset of valid gates for a given `GridDevice` instance. The specific gateset depends
# on the underlying device.

# Edit this list to add support for new gates. If a new `_GateRepresentations` is added, add a new
# `GateSpecification` message in cirq-google/cirq_google/api/v2/device.proto.

# Update `_build_compilation_target_gatesets()` if the gate you are updating affects an existing
# CompilationTargetGateset there, or if you'd like to add another `CompilationTargetGateset` to
# allow users to transform their circuits that include your gate.
_GATES: List[_GateRepresentations] = [
_GateRepresentations(
gate_spec_name='syc',
deserialized_forms=[_SYC_FSIM_GATE_FAMILY],
serializable_forms=[_SYC_FSIM_GATE_FAMILY, cirq.GateFamily(ops.SYC)],
gate_spec_name='syc', supported_gates=[_SYC_FSIM_GATE_FAMILY, _SYC_GATE_FAMILY]
),
_GateRepresentations(
gate_spec_name='sqrt_iswap',
deserialized_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY],
serializable_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP)],
supported_gates=[_SQRT_ISWAP_FSIM_GATE_FAMILY, _SQRT_ISWAP_GATE_FAMILY],
),
_GateRepresentations(
gate_spec_name='sqrt_iswap_inv',
deserialized_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY],
serializable_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP_INV)],
supported_gates=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY, _SQRT_ISWAP_INV_GATE_FAMILY],
),
_GateRepresentations(
gate_spec_name='cz',
deserialized_forms=[_CZ_FSIM_GATE_FAMILY],
serializable_forms=[_CZ_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.CZ)],
gate_spec_name='cz', supported_gates=[_CZ_FSIM_GATE_FAMILY, _CZ_GATE_FAMILY]
),
_GateRepresentations(
gate_spec_name='phased_xz',
deserialized_forms=[
cirq.PhasedXZGate,
cirq.XPowGate,
cirq.YPowGate,
cirq.PhasedXPowGate,
cirq.HPowGate,
cirq.GateFamily(cirq.I),
cirq.ops.SingleQubitCliffordGate,
],
serializable_forms=[
supported_gates=[
# TODO: Extend support to cirq.IdentityGate.
cirq.GateFamily(cirq.I),
cirq.GateFamily(cirq.PhasedXZGate),
Expand All @@ -134,29 +142,20 @@ class _GateRepresentations:
),
_GateRepresentations(
gate_spec_name='virtual_zpow',
deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])],
serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])],
supported_gates=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])],
),
_GateRepresentations(
gate_spec_name='physical_zpow',
deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])],
serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])],
supported_gates=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])],
),
_GateRepresentations(
gate_spec_name='coupler_pulse',
deserialized_forms=[experimental_ops.CouplerPulse],
serializable_forms=[cirq.GateFamily(experimental_ops.CouplerPulse)],
),
_GateRepresentations(
gate_spec_name='meas',
deserialized_forms=[cirq.MeasurementGate],
serializable_forms=[cirq.GateFamily(cirq.MeasurementGate)],
supported_gates=[cirq.GateFamily(experimental_ops.CouplerPulse)],
),
_GateRepresentations(
gate_spec_name='wait',
deserialized_forms=[cirq.WaitGate],
serializable_forms=[cirq.GateFamily(cirq.WaitGate)],
gate_spec_name='meas', supported_gates=[cirq.GateFamily(cirq.MeasurementGate)]
),
_GateRepresentations(gate_spec_name='wait', supported_gates=[cirq.GateFamily(cirq.WaitGate)]),
]


Expand Down Expand Up @@ -216,7 +215,7 @@ def _serialize_gateset_and_gate_durations(
for gate_family in gateset.gates:
gate_spec = v2.device_pb2.GateSpecification()
gate_rep = next(
(gr for gr in _GATES for gf in gr.serializable_forms if gf == gate_family), None
(gr for gr in _GATES for gf in gr.supported_gates if gf == gate_family), None
)
if gate_rep is None:
raise ValueError(f'Unrecognized gate: {gate_family}.')
Expand All @@ -228,13 +227,13 @@ def _serialize_gateset_and_gate_durations(
# Set gate duration
gate_durations_picos = {
int(gate_durations[gf].total_picos())
for gf in gate_rep.serializable_forms
for gf in gate_rep.supported_gates
if gf in gate_durations
}
if len(gate_durations_picos) > 1:
raise ValueError(
'Multiple gate families in the following list exist in the gate duration dict, and '
f'they are expected to have the same duration value: {gate_rep.serializable_forms}'
f'they are expected to have the same duration value: {gate_rep.supported_gates}'
)
elif len(gate_durations_picos) == 1:
gate_spec.gate_duration_picos = gate_durations_picos.pop()
Expand Down Expand Up @@ -269,10 +268,8 @@ def _deserialize_gateset_and_gate_durations(
)
continue

gates_list.extend(gate_rep.deserialized_forms)
for g in gate_rep.deserialized_forms:
if not isinstance(g, cirq.GateFamily):
g = cirq.GateFamily(g)
gates_list.extend(gate_rep.supported_gates)
for g in gate_rep.supported_gates:
gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos)

# TODO(#5050) Add GlobalPhaseGate support
Expand Down
32 changes: 23 additions & 9 deletions cirq-google/cirq_google/devices/grid_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,26 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
cirq.ops.phased_x_z_gate.PhasedXZGate,
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.GateFamily(cirq_google.SYC),
cirq.GateFamily(cirq.SQRT_ISWAP),
cirq.GateFamily(cirq.SQRT_ISWAP_INV),
cirq.GateFamily(cirq.CZ),
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate),
cirq.GateFamily(cirq.ops.common_gates.XPowGate),
cirq.GateFamily(cirq.ops.common_gates.YPowGate),
cirq.GateFamily(cirq.I),
cirq.ops.SingleQubitCliffordGate,
cirq.ops.HPowGate,
cirq.ops.phased_x_gate.PhasedXPowGate,
cirq.GateFamily(cirq.ops.SingleQubitCliffordGate),
cirq.GateFamily(cirq.ops.HPowGate),
cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
),
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
cirq.ops.measurement_gate.MeasurementGate,
cirq.ops.wait_gate.WaitGate,
cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse),
cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate),
cirq.GateFamily(cirq.ops.wait_gate.WaitGate),
)

base_duration = cirq.Duration(picos=1_000)
Expand All @@ -113,6 +117,10 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]): base_duration * 1,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]): base_duration * 2,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]): base_duration * 3,
cirq.GateFamily(cirq_google.SYC): base_duration * 0,
cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1,
cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2,
cirq.GateFamily(cirq.CZ): base_duration * 3,
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4,
Expand All @@ -139,6 +147,9 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq.GateFamily(cirq_google.SYC),
cirq.GateFamily(cirq.SQRT_ISWAP),
cirq.GateFamily(cirq.SQRT_ISWAP_INV),
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.common_gates.HPowGate,
Expand All @@ -161,6 +172,9 @@ def _create_device_spec_with_horizontal_couplings():
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
cirq.GateFamily(cirq_google.SYC),
cirq.GateFamily(cirq.SQRT_ISWAP_INV),
cirq.GateFamily(cirq.CZ),
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.common_gates.HPowGate,
Expand Down

0 comments on commit edd8393

Please sign in to comment.