diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index 57f279a0..9a7133e3 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -108,6 +108,7 @@ def setUp(self): self.smart_explainer._case, self.smart_explainer._classes = check_model(model) self.smart_explainer.state = MultiDecorator(SmartState()) self.smart_explainer.y_pred = None + self.smart_explainer.proba_values = None self.smart_explainer.features_desc = dict(self.x_init.nunique()) self.smart_explainer.features_compacity = self.features_compacity @@ -863,7 +864,7 @@ def test_contribution_plot_8(self): xpl.model = model np_hv = [f"Id: {x}
Predict: {y}" for x, y in zip(xpl.x_init.index, xpl.y_pred.iloc[:, 0].tolist())] np_hv.sort() - output = xpl.plot.contribution_plot(col) + output = xpl.plot.contribution_plot(col, proba=False) annot_list = [] for data_plot in output.data: annot_list.extend(data_plot.hovertext.tolist()) @@ -895,7 +896,7 @@ def test_contribution_plot_9(self): model = lambda: None model.classes_ = np.array([0, 1]) xpl.model = model - output = xpl.plot.contribution_plot(col, max_points=39) + output = xpl.plot.contribution_plot(col, max_points=39, proba=False) assert len(output.data) == 4 for elem in output.data: assert elem.type == "violin"