diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c1299c8cd..2d89d5e1b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ * Fix `c` argument in `plot_khat` ([1592](https://github.com/arviz-devs/arviz/pull/1592)) * Fix `ax` argument in `plot_elpd` ([1593](https://github.com/arviz-devs/arviz/pull/1593)) * Remove warning in `stats.py` compare function ([1607](https://github.com/arviz-devs/arviz/pull/1607)) +* Fix `ess/rhat` plots in `plot_forest` ([1606](https://github.com/arviz-devs/arviz/pull/1606)) * Fix `from_numpyro` crash when importing model with `thinning=x` for `x > 1` ([1619](https://github.com/arviz-devs/arviz/pull/1619)) ### Deprecation diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index e11dff3103..d6724abeee 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -109,15 +109,15 @@ def plot_forest( for i, width_r in zip(range(ncols), width_ratios): backend_kwargs_i = backend_kwargs.copy() - backend_kwargs_i.setdefault("width", int(figsize[0] * dpi)) + backend_kwargs_i.setdefault("height", int(figsize[1] * dpi)) backend_kwargs_i.setdefault( - "height", int(figsize[1] * (width_r / sum(width_ratios)) * dpi * 1.25) + "width", int(figsize[0] * (width_r / sum(width_ratios)) * dpi * 1.25) ) if i == 0: ax = bkp.figure( **backend_kwargs_i, ) - backend_kwargs_i.setdefault("y_range", ax.y_range) + backend_kwargs.setdefault("y_range", ax.y_range) else: ax = bkp.figure(**backend_kwargs_i) axes.append(ax) @@ -172,6 +172,11 @@ def plot_forest( plot_handler.legend(axes[0, idx], plotted_r_hat) idx += 1 + all_plotters = list(plot_handler.plotters.values()) + y_max = plot_handler.y_max() - all_plotters[-1].group_offset + if kind == "ridgeplot": # space at the top + y_max += ridgeplot_overlap + for i, ax_ in enumerate(axes.ravel()): if kind == "ridgeplot": ax_.xgrid.grid_line_color = None @@ -186,24 +191,17 @@ def plot_forest( ax_.x_range = DataRange1d(bounds=backend_config["bounds_x_range"], min_interval=1) ax_.y_range = DataRange1d(bounds=backend_config["bounds_y_range"], min_interval=2) + ax_.y_range._property_values["start"] = -all_plotters[ # pylint: disable=protected-access + 0 + ].group_offset + ax_.y_range._property_values["end"] = y_max # pylint: disable=protected-access + labels, ticks = plot_handler.labels_and_ticks() ticks = [int(tick) if (tick).is_integer() else tick for tick in ticks] axes[0, 0].yaxis.ticker = FixedTicker(ticks=ticks) axes[0, 0].yaxis.major_label_overrides = dict(zip(map(str, ticks), map(str, labels))) - all_plotters = list(plot_handler.plotters.values()) - y_max = plot_handler.y_max() - all_plotters[-1].group_offset - if kind == "ridgeplot": # space at the top - y_max += ridgeplot_overlap - - axes[0, 0].y_range._property_values[ - "start" - ] = -all_plotters[ # pylint: disable=protected-access - 0 - ].group_offset - axes[0, 0].y_range._property_values["end"] = y_max # pylint: disable=protected-access - if legend: plot_handler.legend(axes[0, 0], plotted) show_layout(axes, show) @@ -283,6 +281,14 @@ def label_idxs(): labels, idxs = [], [] for plotter in val: sub_labels, sub_idxs, _, _, _ = plotter.labels_ticks_and_vals() + labels_to_idxs = defaultdict(list) + for label, idx in zip(sub_labels, sub_idxs): + labels_to_idxs[label].append(idx) + sub_idxs = [] + sub_labels = [] + for label, all_idx in labels_to_idxs.items(): + sub_labels.append(label) + sub_idxs.append(np.mean([j for j in all_idx])) labels.append(sub_labels) idxs.append(sub_idxs) return np.concatenate(labels), np.concatenate(idxs) @@ -295,8 +301,8 @@ def legend(self, ax, plotted): for (model_name, glyphs) in plotted.items(): legend_it.append((model_name, glyphs)) - legend = Legend(items=legend_it) - ax.add_layout(legend, "right") + legend = Legend(items=legend_it, orientation="vertical", location="top_left") + ax.add_layout(legend, "above") ax.legend.click_policy = "hide" def display_multiple_ropes( @@ -675,12 +681,13 @@ def labels_ticks_and_vals(self): for y, label, model_name, _, _, vals, color in self.iterator(): y_ticks[label].append((y, vals, color, model_name)) labels, ticks, vals, colors, model_names = [], [], [], [], [] - for label, data in y_ticks.items(): - labels.append(label) - ticks.append(np.mean([j[0] for j in data])) - vals.append(np.vstack([j[1] for j in data])) - model_names.append(data[0][3]) - colors.append(data[0][2]) # the colors are all the same + for label, all_data in y_ticks.items(): + for data in all_data: + labels.append(label) + ticks.append(data[0]) + vals.append(np.array(data[1])) + model_names.append(data[3]) + colors.append(data[2]) # the colors are all the same return labels, ticks, vals, colors, model_names def treeplot(self, qlist, hdi_prob): diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index ba41ebbc43..88196b9b5f 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -238,6 +238,14 @@ def label_idxs(): labels, idxs = [], [] for plotter in val: sub_labels, sub_idxs, _, _ = plotter.labels_ticks_and_vals() + labels_to_idxs = defaultdict(list) + for label, idx in zip(sub_labels, sub_idxs): + labels_to_idxs[label].append(idx) + sub_idxs = [] + sub_labels = [] + for label, all_idx in labels_to_idxs.items(): + sub_labels.append(label) + sub_idxs.append(np.mean([j for j in all_idx])) labels.append(sub_labels) idxs.append(sub_idxs) return np.concatenate(labels), np.concatenate(idxs) @@ -567,11 +575,12 @@ def labels_ticks_and_vals(self): for y, label, _, _, vals, color in self.iterator(): y_ticks[label].append((y, vals, color)) labels, ticks, vals, colors = [], [], [], [] - for label, data in y_ticks.items(): - labels.append(label) - ticks.append(np.mean([j[0] for j in data])) - vals.append(np.vstack([j[1] for j in data])) - colors.append(data[0][2]) # the colors are all the same + for label, all_data in y_ticks.items(): + for data in all_data: + labels.append(label) + ticks.append(data[0]) + vals.append(np.array(data[1])) + colors.append(data[2]) # the colors are all the same return labels, ticks, vals, colors def treeplot(self, qlist, hdi_prob):