Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creates fake grid device for testing qubit connectivity in routing #5830

Prev Previous commit
Next Next commit
addressed comments
  • Loading branch information
ammareltigani committed Aug 19, 2022
commit 8827b91a79c1fc325498db948c4409978b493585
45 changes: 20 additions & 25 deletions cirq-core/cirq/testing/routing_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Provides test devices that can validate circuits during a routing procedure."""

from typing import Hashable, Optional, Dict, TYPE_CHECKING
from typing import Dict, TYPE_CHECKING

import networkx as nx

Expand All @@ -24,44 +24,39 @@


class RoutingTestingDevice(devices.Device):
"""Testing device to be used only for testing qubit connectivity in routing procedures."""

def __init__(self, nx_graph: nx.Graph, qubit_type: str = 'NamedQubit') -> None:
relabeling_map: Dict[Hashable, 'cirq.Qid'] = {}
if qubit_type == 'GridQubit':
relabeling_map = {old: devices.GridQubit(*old) for old in nx_graph}
elif qubit_type == 'LineQubit':
relabeling_map = {old: devices.LineQubit(old) for old in nx_graph}
else:
relabeling_map = {old: ops.NamedQubit(str(old)) for old in nx_graph}
"""Testing device to be used for testing qubit connectivity in routing procedures."""

def __init__(self, nx_graph: nx.Graph) -> None:
relabeling_map = {
old: ops.q(old) if isinstance(old, (int, str)) else ops.q(*old) for old in nx_graph
}
# Relabel nodes in-place.
nx.relabel_nodes(nx_graph, relabeling_map, copy=False)

self._metadata = devices.DeviceMetadata(relabeling_map.values(), nx_graph)

@property
def metadata(self) -> Optional[devices.DeviceMetadata]:
def metadata(self) -> devices.DeviceMetadata:
return self._metadata

def validate_operation(self, operation: 'cirq.Operation') -> None:
for q in operation.qubits:
if q not in self._metadata.qubit_set:
raise ValueError(f'Qubit not on device: {q!r}.')
if not self._metadata.qubit_set.issuperset(operation.qubits):
raise ValueError(f'Qubits not on device: {operation.qubits!r}.')

if len(operation.qubits) > 1:
if len(operation.qubits) == 2:
if operation.qubits not in self._metadata.nx_graph.edges:
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.')
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved
return

if len(operation.qubits) == 2 and operation.qubits not in self._metadata.nx_graph.edges:
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.')
if operation not in ops.GateFamily(ops.MeasurementGate):
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f'Gate {operation.gate!r} is not supported on more than 2 qubits.')
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved


def construct_grid_device(m: int, n: int) -> RoutingTestingDevice:
return RoutingTestingDevice(nx.grid_2d_graph(m, n), qubit_type="GridQubit")
return RoutingTestingDevice(nx.grid_2d_graph(m, n))


def construct_ring_device(l: int, directed: bool = False) -> RoutingTestingDevice:
if directed:
# If create_using is directed, the direction is in increasing order.
nx_graph = nx.cycle_graph(l, create_using=nx.DiGraph)
else:
nx_graph = nx.cycle_graph(l)

return RoutingTestingDevice(nx_graph, qubit_type="LineQubit")
nx_graph = nx.cycle_graph(l, create_using=nx.DiGraph if directed else nx.Graph)
return RoutingTestingDevice(nx_graph)
35 changes: 27 additions & 8 deletions cirq-core/cirq/testing/routing_devices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def test_grid_device():
def test_grid_op_validation():
device = cirq.testing.construct_grid_device(5, 7)

with pytest.raises(ValueError, match="Qubit not on device"):
with pytest.raises(ValueError, match="Qubits not on device"):
device.validate_operation(cirq.X(cirq.NamedQubit("a")))
with pytest.raises(ValueError, match="Qubit not on device"):
with pytest.raises(ValueError, match="Qubits not on device"):
device.validate_operation(cirq.CNOT(cirq.NamedQubit("a"), cirq.GridQubit(0, 0)))
with pytest.raises(ValueError, match="Qubit not on device"):
with pytest.raises(ValueError, match="Qubits not on device"):
device.validate_operation(cirq.CNOT(cirq.GridQubit(5, 4), cirq.GridQubit(4, 4)))
with pytest.raises(ValueError, match="Qubit not on device"):
with pytest.raises(ValueError, match="Qubits not on device"):
device.validate_operation(cirq.CNOT(cirq.GridQubit(4, 7), cirq.GridQubit(4, 6)))

with pytest.raises(ValueError, match="Qubit pair is not valid on device"):
Expand All @@ -50,7 +50,9 @@ def test_grid_op_validation():
device.validate_operation(cirq.CNOT(cirq.GridQubit(2, 0), cirq.GridQubit(0, 0)))

device.validate_operation(cirq.CNOT(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)))
device.validate_operation(cirq.CNOT(cirq.GridQubit(0, 1), cirq.GridQubit(0, 0)))
device.validate_operation(cirq.CNOT(cirq.GridQubit(1, 0), cirq.GridQubit(0, 0)))
device.validate_operation(cirq.CNOT(cirq.GridQubit(0, 0), cirq.GridQubit(1, 0)))


def test_ring_device():
Expand All @@ -75,9 +77,9 @@ def test_ring_op_validation():
directed_device = cirq.testing.construct_ring_device(5, directed=True)
undirected_device = cirq.testing.construct_ring_device(5, directed=False)

with pytest.raises(ValueError, match="Qubit not on device"):
with pytest.raises(ValueError, match="Qubits not on device"):
directed_device.validate_operation(cirq.X(cirq.LineQubit(5)))
with pytest.raises(ValueError, match="Qubit not on device"):
with pytest.raises(ValueError, match="Qubits not on device"):
undirected_device.validate_operation(cirq.X(cirq.LineQubit(5)))

with pytest.raises(ValueError, match="Qubit pair is not valid on device"):
Expand All @@ -90,10 +92,27 @@ def test_ring_op_validation():
directed_device.validate_operation(cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1)))


def test_allowed_multi_qubit_gates():
device = cirq.testing.construct_ring_device(5)

device.validate_operation(cirq.measure(cirq.LineQubit(0)))
device.validate_operation(cirq.measure(cirq.LineQubit.range(2)))
device.validate_operation(cirq.measure(cirq.LineQubit.range(3)))

with pytest.raises(
ValueError, match=f"Gate {cirq.CCNOT!r} is not supported on more than 2 qubits."
):
device.validate_operation(cirq.CCNOT(*cirq.LineQubit.range(3)))

device.validate_operation(cirq.CNOT(*cirq.LineQubit.range(2)))


def test_namedqubit_device():
nx_graph = nx.star_graph(10)
# 4-star graph
nx_graph = nx.Graph([("a", "b"), ("a", "c"), ("a", "d")])

device = cirq.testing.RoutingTestingDevice(nx_graph)
relabeled_graph = device.metadata.nx_graph
qubit_set = {cirq.NamedQubit(str(n)) for n in range(11)}
qubit_set = {cirq.NamedQubit(n) for n in "abcd"}
assert set(relabeled_graph.nodes) == qubit_set
assert nx.is_isomorphic(nx_graph, relabeled_graph)