Skip to content

Commit

Permalink
Optimize qid comparisons (#6386)
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo authored Dec 18, 2023
1 parent 2f3c1e2 commit e3fbd98
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 23 deletions.
42 changes: 34 additions & 8 deletions cirq-core/cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ class _BaseGridQid(ops.Qid):
_row: int
_col: int
_dimension: int
_comp_key: Optional[Tuple[int, int]] = None
_hash: Optional[int] = None

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._row, self._col, self._dimension))
return self._hash

def __eq__(self, other):
def __eq__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return self is other or (
Expand All @@ -50,7 +51,7 @@ def __eq__(self, other):
)
return NotImplemented

def __ne__(self, other):
def __ne__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return self is not other and (
Expand All @@ -60,8 +61,38 @@ def __ne__(self, other):
)
return NotImplemented

def __lt__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 < k1 or (k0 == k1 and self._dimension < other._dimension)
return super().__lt__(other)

def __le__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 < k1 or (k0 == k1 and self._dimension <= other._dimension)
return super().__le__(other)

def __ge__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 > k1 or (k0 == k1 and self._dimension >= other._dimension)
return super().__ge__(other)

def __gt__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 > k1 or (k0 == k1 and self._dimension > other._dimension)
return super().__gt__(other)

def _comparison_key(self):
return self._row, self._col
if self._comp_key is None:
self._comp_key = self._row, self._col
return self._comp_key

@property
def row(self) -> int:
Expand Down Expand Up @@ -359,11 +390,6 @@ def __getnewargs__(self):
def _with_row_col(self, row: int, col: int) -> 'GridQubit':
return GridQubit(row, col)

def _cmp_tuple(self):
cls = GridQid if type(self) is GridQubit else type(self)
# Must be same as Qid._cmp_tuple but with cls in place of type(self).
return (cls.__name__, repr(cls), self._comparison_key(), self.dimension)

@staticmethod
def square(diameter: int, top: int = 0, left: int = 0) -> List['GridQubit']:
"""Returns a square of GridQubits.
Expand Down
42 changes: 34 additions & 8 deletions cirq-core/cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,52 @@ def __hash__(self) -> int:
self._hash = hash((self._x, self._dimension))
return self._hash

def __eq__(self, other):
def __eq__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self is other or (self._x == other._x and self._dimension == other._dimension)
return NotImplemented

def __ne__(self, other):
def __ne__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self is not other and (
self._x != other._x or self._dimension != other._dimension
)
return NotImplemented

def __lt__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x < other._x or (
self._x == other._x and self._dimension < other._dimension
)
return super().__lt__(other)

def __le__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x < other._x or (
self._x == other._x and self._dimension <= other._dimension
)
return super().__le__(other)

def __ge__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x > other._x or (
self._x == other._x and self._dimension >= other._dimension
)
return super().__ge__(other)

def __gt__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x > other._x or (
self._x == other._x and self._dimension > other._dimension
)
return super().__gt__(other)

def _comparison_key(self):
return self._x

Expand Down Expand Up @@ -279,12 +311,6 @@ def __getnewargs__(self):
def _with_x(self, x: int) -> 'LineQubit':
return LineQubit(x)

def _cmp_tuple(self):
cls = LineQid if type(self) is LineQubit else type(self)
# Must be the same as Qid._cmp_tuple but with cls in place of
# type(self).
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)

@staticmethod
def range(*range_args) -> List['LineQubit']:
"""Returns a range of line qubits.
Expand Down
37 changes: 30 additions & 7 deletions cirq-core/cirq/ops/named_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,50 @@ def __hash__(self) -> int:
self._hash = hash((self._name, self._dimension))
return self._hash

def __eq__(self, other):
def __eq__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseNamedQid):
return self is other or (
self._name == other._name and self._dimension == other._dimension
)
return NotImplemented

def __ne__(self, other):
def __ne__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseNamedQid):
return self is not other and (
self._name != other._name or self._dimension != other._dimension
)
return NotImplemented

def __lt__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseNamedQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 < k1 or (k0 == k1 and self._dimension < other._dimension)
return super().__lt__(other)

def __le__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseNamedQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 < k1 or (k0 == k1 and self._dimension <= other._dimension)
return super().__le__(other)

def __ge__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseNamedQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 > k1 or (k0 == k1 and self._dimension >= other._dimension)
return super().__ge__(other)

def __gt__(self, other) -> bool:
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseNamedQid):
k0, k1 = self._comparison_key(), other._comparison_key()
return k0 > k1 or (k0 == k1 and self._dimension > other._dimension)
return super().__gt__(other)

def _comparison_key(self):
if self._comp_key is None:
self._comp_key = _pad_digits(self._name)
Expand Down Expand Up @@ -174,11 +202,6 @@ def __getnewargs__(self):
"""Returns a tuple of args to pass to __new__ when unpickling."""
return (self._name,)

def _cmp_tuple(self):
cls = NamedQid if type(self) is NamedQubit else type(self)
# Must be same as Qid._cmp_tuple but with cls in place of type(self).
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)

def __str__(self) -> str:
return self._name

Expand Down

0 comments on commit e3fbd98

Please sign in to comment.