Skip to content

Commit

Permalink
Fix matplotlib typing (#6290)
Browse files Browse the repository at this point in the history
* Fix matplotlib typing

matplotlib 3.8.0 was released this week and included typing hints.
This fixes the resulting CI breakages.

* Fix issues.

* formatting

* Change to seaborn v0_8
  • Loading branch information
dstrain115 authored Sep 18, 2023
1 parent f715527 commit b630298
Show file tree
Hide file tree
Showing 13 changed files with 56 additions and 32 deletions.
4 changes: 3 additions & 1 deletion cirq-core/cirq/contrib/svg/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import TYPE_CHECKING, List, Tuple, cast, Dict

import matplotlib.textpath
import matplotlib.font_manager


if TYPE_CHECKING:
import cirq

QBLUE = '#1967d2'
FONT = "Arial"
FONT = matplotlib.font_manager.FontProperties(family="Arial")
EMPTY_MOMENT_COLWIDTH = float(21) # assumed default column width


Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/devices/named_topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _node_and_coordinates(


def draw_gridlike(
graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs
graph: nx.Graph, ax: Optional[plt.Axes] = None, tilted: bool = True, **kwargs
) -> Dict[Any, Tuple[int, int]]:
"""Draw a grid-like graph using Matplotlib.
Expand Down
11 changes: 6 additions & 5 deletions cirq-core/cirq/experiments/qubit_characterizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import dataclasses
import itertools

from typing import Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING
from typing import Any, cast, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING
import numpy as np

from matplotlib import pyplot as plt

# this is for older systems with matplotlib <3.2 otherwise 3d projections fail
from mpl_toolkits import mplot3d # pylint: disable=unused-import
from mpl_toolkits import mplot3d
from cirq import circuits, ops, protocols

if TYPE_CHECKING:
Expand Down Expand Up @@ -89,8 +89,9 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes:
"""
show_plot = not ax
if not ax:
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.set_ylim([0, 1])
fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover
ax = cast(plt.Axes, ax) # pragma: no cover
ax.set_ylim((0.0, 1.0)) # pragma: no cover
ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs)
ax.set_xlabel(r"Number of Cliffords")
ax.set_ylabel('Ground State Probability')
Expand Down Expand Up @@ -541,7 +542,7 @@ def _find_inv_matrix(mat: np.ndarray, mat_sequence: np.ndarray) -> int:
def _matrix_bar_plot(
mat: np.ndarray,
z_label: str,
ax: plt.Axes,
ax: mplot3d.axes3d.Axes3D,
kets: Optional[Sequence[str]] = None,
title: Optional[str] = None,
ylim: Tuple[int, int] = (-1, 1),
Expand Down
11 changes: 6 additions & 5 deletions cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import (
Any,
Callable,
cast,
Iterable,
List,
Optional,
Expand All @@ -33,7 +34,7 @@
import matplotlib.pyplot as plt

# this is for older systems with matplotlib <3.2 otherwise 3d projections fail
from mpl_toolkits import mplot3d # pylint: disable=unused-import
from mpl_toolkits import mplot3d
import numpy as np

from cirq import value, protocols
Expand Down Expand Up @@ -554,7 +555,7 @@ def scatter_plot_normalized_kak_interaction_coefficients(
interactions: Iterable[Union[np.ndarray, 'cirq.SupportsUnitary', 'KakDecomposition']],
*,
include_frame: bool = True,
ax: Optional[plt.Axes] = None,
ax: Optional[mplot3d.axes3d.Axes3D] = None,
**kwargs,
):
r"""Plots the interaction coefficients of many two-qubit operations.
Expand Down Expand Up @@ -633,13 +634,13 @@ def scatter_plot_normalized_kak_interaction_coefficients(
show_plot = not ax
if not ax:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d'))

def coord_transform(
pts: Union[List[Tuple[int, int, int]], np.ndarray]
) -> Tuple[Iterable[float], Iterable[float], Iterable[float]]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if len(pts) == 0:
return [], [], []
return np.array([]), np.array([]), np.array([])
xs, ys, zs = np.transpose(pts)
return xs, zs, ys

Expand Down
11 changes: 7 additions & 4 deletions cirq-core/cirq/vis/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import astuple, dataclass
from typing import (
Any,
cast,
Dict,
List,
Mapping,
Expand Down Expand Up @@ -217,7 +218,7 @@ def _plot_colorbar(
)
position = self._config['colorbar_position']
orien = 'vertical' if position in ('left', 'right') else 'horizontal'
colorbar = ax.figure.colorbar(
colorbar = cast(plt.Figure, ax.figure).colorbar(
mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {})
)
colorbar_ax.tick_params(axis='y', direction='out')
Expand All @@ -230,15 +231,15 @@ def _write_annotations(
ax: plt.Axes,
) -> None:
"""Writes annotations to the center of cells. Internal."""
for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()):
for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolor()):
# Calculate the center of the cell, assuming that it is a square
# centered at (x=col, y=row).
if not annotation:
continue
x, y = center
face_luminance = vis_utils.relative_luminance(facecolor)
face_luminance = vis_utils.relative_luminance(facecolor) # type: ignore
text_color = 'black' if face_luminance > 0.4 else 'white'
text_kwargs = dict(color=text_color, ha="center", va="center")
text_kwargs: Dict[str, Any] = dict(color=text_color, ha="center", va="center")
text_kwargs.update(self._config.get('annotation_text_kwargs', {}))
ax.text(x, y, annotation, **text_kwargs)

Expand Down Expand Up @@ -295,6 +296,7 @@ def plot(
show_plot = not ax
if not ax:
fig, ax = plt.subplots(figsize=(8, 8))
ax = cast(plt.Axes, ax)
original_config = copy.deepcopy(self._config)
self.update_config(**kwargs)
collection = self._plot_on_axis(ax)
Expand Down Expand Up @@ -381,6 +383,7 @@ def plot(
show_plot = not ax
if not ax:
fig, ax = plt.subplots(figsize=(8, 8))
ax = cast(plt.Axes, ax)
original_config = copy.deepcopy(self._config)
self.update_config(**kwargs)
qubits = set([q for qubits in self._value_map.keys() for q in qubits])
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/vis/heatmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def ax():
return figure.add_subplot(111)


def test_default_ax():
row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
test_value_map = {
grid_qubit.GridQubit(row, col): np.random.random() for (row, col) in row_col_list
}
_, _ = heatmap.Heatmap(test_value_map).plot()


@pytest.mark.parametrize('tuple_keys', [True, False])
def test_cells_positions(ax, tuple_keys):
row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
Expand Down Expand Up @@ -61,6 +69,8 @@ def test_two_qubit_heatmap(ax):
title = "Two Qubit Interaction Heatmap"
heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax)
assert ax.get_title() == title
# Test default axis
heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot()


def test_invalid_args():
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/vis/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def integrated_histogram(
plot_options.update(kwargs)

if cdf_on_x:
ax.step(bin_values, parameter_values, **plot_options)
ax.step(bin_values, parameter_values, **plot_options) # type: ignore
else:
ax.step(parameter_values, bin_values, **plot_options)
ax.step(parameter_values, bin_values, **plot_options) # type: ignore

set_semilog = ax.semilogy if cdf_on_x else ax.semilogx
set_lim = ax.set_xlim if cdf_on_x else ax.set_ylim
Expand All @@ -128,15 +128,15 @@ def integrated_histogram(

if median_line:
set_line(
np.median(float_data),
float(np.median(float_data)),
linestyle='--',
color=plot_options['color'],
alpha=0.5,
label=median_label,
)
if mean_line:
set_line(
np.mean(float_data),
float(np.mean(float_data)),
linestyle='-.',
color=plot_options['color'],
alpha=0.5,
Expand Down
16 changes: 10 additions & 6 deletions cirq-core/cirq/vis/state_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Tool to visualize the results of a study."""

from typing import Union, Optional, Sequence, SupportsFloat
from typing import cast, Optional, Sequence, SupportsFloat, Union
import collections
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -51,13 +51,13 @@ def get_state_histogram(result: 'result.Result') -> np.ndarray:

def plot_state_histogram(
data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]],
ax: Optional['plt.Axis'] = None,
ax: Optional[plt.Axes] = None,
*,
tick_label: Optional[Sequence[str]] = None,
xlabel: Optional[str] = 'qubit state',
ylabel: Optional[str] = 'result count',
title: Optional[str] = 'Result State Histogram',
) -> 'plt.Axis':
) -> plt.Axes:
"""Plot the state histogram from either a single result with repetitions or
a histogram computed using `result.histogram()` or a flattened histogram
of measurement results computed using `get_state_histogram`.
Expand Down Expand Up @@ -87,6 +87,7 @@ def plot_state_histogram(
show_fig = not ax
if not ax:
fig, ax = plt.subplots(1, 1)
ax = cast(plt.Axes, ax)
if isinstance(data, result.Result):
values = get_state_histogram(data)
elif isinstance(data, collections.Counter):
Expand All @@ -96,9 +97,12 @@ def plot_state_histogram(
if tick_label is None:
tick_label = [str(i) for i in range(len(values))]
ax.bar(np.arange(len(values)), values, tick_label=tick_label)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
if xlabel:
ax.set_xlabel(xlabel)
if ylabel:
ax.set_ylabel(ylabel)
if title:
ax.set_title(title)
if show_fig:
fig.show()
return ax
2 changes: 2 additions & 0 deletions cirq-core/cirq/vis/state_histogram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def test_plot_state_histogram_result():
for r1, r2 in zip(ax1.get_children(), ax2.get_children()):
if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle):
assert str(r1) == str(r2)
# Test default axis
state_histogram.plot_state_histogram(expected_values)


@pytest.mark.usefixtures('closefigures')
Expand Down
5 changes: 3 additions & 2 deletions cirq-google/cirq_google/engine/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import abc, defaultdict
import datetime
from itertools import cycle
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Sequence
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union, Sequence

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -277,6 +277,7 @@ def plot_histograms(
show_plot = not ax
if not ax:
fig, ax = plt.subplots(1, 1)
ax = cast(plt.Axes, ax)

if isinstance(keys, str):
keys = [keys]
Expand Down Expand Up @@ -322,7 +323,7 @@ def plot(
show_plot = not fig
if not fig:
fig = plt.figure()
axs = fig.subplots(1, 2)
axs = cast(List[plt.Axes], fig.subplots(1, 2))
self.heatmap(key).plot(axs[0])
self.plot_histograms(key, axs[1])
if show_plot:
Expand Down
2 changes: 1 addition & 1 deletion docs/experiments/textbook_algorithms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@
"outputs": [],
"source": [
"\"\"\"Plot the results.\"\"\"\n",
"plt.style.use(\"seaborn-whitegrid\")\n",
"plt.style.use(\"seaborn-v0_8-whitegrid\")\n",
"\n",
"plt.plot(nvals, estimates, \"--o\", label=\"Phase estimation\")\n",
"plt.axhline(theta, label=\"True value\", color=\"black\")\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/start/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1453,7 +1453,7 @@
" probs.append(prob[0])\n",
"\n",
"# Plot the probability of the ground state at each simulation step.\n",
"plt.style.use('seaborn-whitegrid')\n",
"plt.style.use('seaborn-v0_8-whitegrid')\n",
"plt.plot(probs, 'o')\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Probability of ground state\");"
Expand Down Expand Up @@ -1490,7 +1490,7 @@
"\n",
"\n",
"# Plot the probability of the ground state at each simulation step.\n",
"plt.style.use('seaborn-whitegrid')\n",
"plt.style.use('seaborn-v0_8-whitegrid')\n",
"plt.plot(sampled_probs, 'o')\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Probability of ground state\");"
Expand Down
2 changes: 1 addition & 1 deletion examples/two_qubit_gate_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def main(samples: int = 1000, max_infidelity: float = 0.01):
print(f'Maximum infidelity of "failed" compilation: {np.max(failed_infidelities_arr)}')

plt.figure()
plt.hist(infidelities_arr, bins=25, range=[0, max_infidelity * 1.1])
plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover
ylim = plt.ylim()
plt.plot([max_infidelity] * 2, ylim, '--', label='Maximum tabulation infidelity')
plt.xlabel('Compiled gate infidelity vs target')
Expand Down

0 comments on commit b630298

Please sign in to comment.