Skip to content

Commit

Permalink
Merge branch 'main' into docs/add_link_to_README
Browse files Browse the repository at this point in the history
  • Loading branch information
maxschulz-COL authored Sep 8, 2023
2 parents b492279 + 6bef620 commit 0789bd5
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Removed
- A bullet item for the Removed category.
-->
<!--
### Added
- A bullet item for the Added category.
-->
<!--
### Changed
- A bullet item for the Changed category.
-->
<!--
### Deprecated
- A bullet item for the Deprecated category.
-->
<!--
### Fixed
- A bullet item for the Fixed category.
-->
<!--
### Security
- A bullet item for the Security category.
-->
32 changes: 17 additions & 15 deletions vizro-core/src/vizro/models/_controls/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,35 @@ def check_duplicate_parameter_target(cls, targets):

@_log_call
def pre_build(self):
self._set_slider_values()
self._set_categorical_selectors_options()
self._set_selector()
self._set_actions()

@_log_call
def build(self):
return self.selector.build()

def _set_slider_values(self):
self.selector: SelectorType
if isinstance(self.selector, (Slider, RangeSlider)):
if self.selector.min is None or self.selector.max is None:
raise TypeError(
f"{self.selector.type} requires the arguments 'min' and 'max' when used within Parameter."
)

def _set_categorical_selectors_options(self):
self.selector: SelectorType
if isinstance(self.selector, (Checklist, Dropdown, RadioItems)) and not self.selector.options:
raise TypeError(f"{self.selector.type} requires the argument 'options' when used within Parameter.")

def _set_selector(self):
self.selector: SelectorType
if not self.selector.title:
self.selector.title = ", ".join({target.rsplit(".")[-1] for target in self.targets})

def _set_actions(self):
self.selector: SelectorType
if not self.selector.actions:
self.selector.actions = [
Action(
Expand All @@ -87,18 +104,3 @@ def pre_build(self):
),
)
]

@_log_call
def build(self):
return self.selector.build()


if __name__ == "__main__":
print( # noqa: T201
repr(
Parameter(
targets=["scatter.x"],
selector=Slider(min=0, max=1, value=0.8, title="Bubble opacity"),
)
)
)
153 changes: 135 additions & 18 deletions vizro-core/tests/unit/vizro/models/_components/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,156 @@
"""Unit tests for vizro.models.Graph."""
import json

import plotly
import plotly.graph_objects as go
import pytest
from dash import dcc
from pydantic import ValidationError

import vizro.models as vm
import vizro.plotly.express as px
from vizro.managers import data_manager
from vizro.models._action._action import Action
from vizro.models._components.graph import create_empty_fig


@pytest.fixture
def standard_go_chart(gapminder):
return go.Figure(data=go.Scatter(x=gapminder["gdpPercap"], y=gapminder["lifeExp"], mode="markers"))


def test_create_graph(standard_px_chart):
graph = vm.Graph(figure=standard_px_chart)
@pytest.fixture
def standard_px_chart_with_str_dataframe():
return px.scatter(
data_frame="gapminder",
x="gdpPercap",
y="lifeExp",
size="pop",
color="continent",
hover_name="country",
size_max=60,
)


assert hasattr(graph, "id")
assert graph.type == "graph"
assert graph.figure == standard_px_chart._captured_callable
assert graph.actions == []
@pytest.fixture
def expected_empty_chart():
figure = go.Figure()
figure.add_trace(go.Scatter(x=[None], y=[None], showlegend=False, hoverinfo="none"))
figure.update_layout(
xaxis={"visible": False},
yaxis={"visible": False},
annotations=[{"text": "NO DATA", "showarrow": False, "font": {"size": 16}}],
)
return figure


@pytest.fixture
def expected_graph():
return dcc.Loading(
dcc.Graph(
id="text_graph",
figure=create_empty_fig(""),
config={
"autosizable": True,
"frameMargins": 0,
"responsive": True,
},
className="chart_container",
),
color="grey",
parent_className="chart_container",
)


def test_failed_graph_with_wrong_figure(standard_go_chart):
with pytest.raises(ValidationError, match="must provide a valid CapturedCallable object"):
vm.Graph(
figure=standard_go_chart,
class TestDunderMethodsGraph:
def test_create_graph_mandatory_only(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart)

assert hasattr(graph, "id")
assert graph.type == "graph"
assert graph.figure == standard_px_chart._captured_callable
assert graph.actions == []

@pytest.mark.parametrize("id", ["id_1", "id_2"])
def test_create_graph_mandatory_and_optional(self, standard_px_chart, id):
graph = vm.Graph(
figure=standard_px_chart,
id=id,
actions=[],
)

assert graph.id == id
assert graph.type == "graph"
assert graph.figure == standard_px_chart._captured_callable

def test_mandatory_figure_missing(self):
with pytest.raises(ValidationError, match="field required"):
vm.Graph()

def test_failed_graph_with_wrong_figure(self, standard_go_chart):
with pytest.raises(ValidationError, match="must provide a valid CapturedCallable object"):
vm.Graph(
figure=standard_go_chart,
)

def test_getitem_known_args(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart)

assert graph["x"] == "gdpPercap"
assert graph["type"] == "graph"

@pytest.mark.parametrize("title, expected", [(None, 24), ("Test", None)])
def test_title_margin_adjustment(gapminder, title, expected):
figure = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()
def test_getitem_unknown_args(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart)
with pytest.raises(KeyError):
graph["unknown_args"]

@pytest.mark.parametrize("title, expected", [(None, 24), ("Test", None)])
def test_title_margin_adjustment(self, gapminder, title, expected):
figure = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()

assert figure.layout.margin.t == expected
assert figure.layout.template.layout.margin.t == 64
assert figure.layout.template.layout.margin.l == 80
assert figure.layout.template.layout.margin.b == 64
assert figure.layout.template.layout.margin.r == 12

def test_set_action_via_validator(self, standard_px_chart, test_action_function):
graph = vm.Graph(figure=standard_px_chart, actions=[Action(function=test_action_function)])
actions_chain = graph.actions[0]
assert actions_chain.trigger.component_property == "clickData"


class TestProcessFigureDataFrame:
def test_process_figure_data_frame_str_df(self, standard_px_chart_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
graph_with_str_df = vm.Graph(
id="text_graph",
figure=standard_px_chart_with_str_dataframe,
)
assert data_manager._get_component_data("text_graph").equals(gapminder)
assert graph_with_str_df["data_frame"] == "gapminder"

def test_process_figure_data_frame_df(self, standard_px_chart, gapminder):
graph_with_df = vm.Graph(
id="text_graph",
figure=standard_px_chart,
)
assert data_manager._get_component_data("text_graph").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
graph_with_df.figure["data_frame"]


class TestBuild:
def test_create_empty_fig(self, expected_empty_chart):
result = create_empty_fig("NO DATA")
assert result == expected_empty_chart

def test_graph_build(self, standard_px_chart, expected_graph):
graph = vm.Graph(
id="text_graph",
figure=standard_px_chart,
)

assert figure.layout.margin.t == expected
assert figure.layout.template.layout.margin.t == 64
assert figure.layout.template.layout.margin.l == 80
assert figure.layout.template.layout.margin.b == 64
assert figure.layout.template.layout.margin.r == 12
result = json.loads(json.dumps(graph.build(), cls=plotly.utils.PlotlyJSONEncoder))
expected = json.loads(json.dumps(expected_graph, cls=plotly.utils.PlotlyJSONEncoder))
assert result == expected
28 changes: 20 additions & 8 deletions vizro-core/tests/unit/vizro/models/_controls/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import vizro.models as vm
from vizro.managers import model_manager
from vizro.models._action._actions_chain import ActionsChain
from vizro.models._controls.filter import Filter, _filter_between, _filter_isin
from vizro.models.types import CapturedCallable


class TestFilterFunctions:
Expand Down Expand Up @@ -75,15 +77,15 @@ def test_check_target_present_invalid(self):

@pytest.mark.usefixtures("managers_one_page_two_graphs")
class TestPreBuildMethod:
def test_target_auto_generation_valid(self):
def test_set_targets_valid(self):
# Core of tests is still interface level
filter = vm.Filter(column="country")
# Special case - need filter in the context of page in order to run filter.pre_build
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert set(filter.targets) == {"scatter_chart", "bar_chart"}

def test_target_auto_generation_invalid(self):
def test_set_targets_invalid(self):
filter = vm.Filter(column="invalid_choice")
model_manager["test_page"].controls = [filter]

Expand All @@ -93,7 +95,7 @@ def test_target_auto_generation_invalid(self):
@pytest.mark.parametrize(
"test_input,expected", [("country", "categorical"), ("year", "numerical"), ("lifeExp", "numerical")]
)
def test_column_type_inference(self, test_input, expected):
def test_set_column_type(self, test_input, expected):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
Expand All @@ -102,7 +104,7 @@ def test_column_type_inference(self, test_input, expected):
@pytest.mark.parametrize(
"test_input,expected", [("country", vm.Dropdown), ("year", vm.RangeSlider), ("lifeExp", vm.RangeSlider)]
)
def test_determine_selectors(self, test_input, expected):
def test_set_selector(self, test_input, expected):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
Expand All @@ -119,23 +121,23 @@ def test_determine_slider_defaults_invalid_selector(self, test_input):
filter.pre_build()

@pytest.mark.parametrize("test_input", [vm.Slider(), vm.RangeSlider()])
def test_determine_slider_defaults_min_max_none(self, test_input, gapminder):
def test_set_slider_values_defaults_min_max_none(self, test_input, gapminder):
filter = vm.Filter(column="lifeExp", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert filter.selector.min == gapminder.lifeExp.min()
assert filter.selector.max == gapminder.lifeExp.max()

@pytest.mark.parametrize("test_input", [vm.Slider(min=3, max=5), vm.RangeSlider(min=3, max=5)])
def test_determine_slider_defaults_min_max_fix(self, test_input):
def test_set_slider_values_defaults_min_max_fix(self, test_input):
filter = vm.Filter(column="lifeExp", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert filter.selector.min == 3
assert filter.selector.max == 5

@pytest.mark.parametrize("test_input", [vm.Checklist(), vm.Dropdown(), vm.RadioItems()])
def test_determine_selector_defaults_options_none(self, test_input, gapminder):
def test_set_categorical_selectors_options_defaults_options_none(self, test_input, gapminder):
filter = vm.Filter(column="continent", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
Expand All @@ -149,7 +151,7 @@ def test_determine_selector_defaults_options_none(self, test_input, gapminder):
vm.RadioItems(options=["Africa", "Europe"]),
],
)
def test_determine_selector_defaults_options_fix(self, test_input):
def test_set_categorical_selectors_options_defaults_options_fix(self, test_input):
filter = vm.Filter(column="continent", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
Expand Down Expand Up @@ -182,3 +184,13 @@ def test_filter_build(self, test_column, test_selector):
result = str(filter.build())
expected = str(test_selector.build())
assert result == expected

@pytest.mark.parametrize("test_input", ["country", "year", "lifeExp"])
def test_set_actions(self, test_input):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
default_action = filter.selector.actions[0]
assert isinstance(default_action, ActionsChain)
assert isinstance(default_action.actions[0].function, CapturedCallable)
assert default_action.actions[0].id == f"filter_action_{filter.id}"
Loading

0 comments on commit 0789bd5

Please sign in to comment.