Skip to content

Commit

Permalink
Merge pull request #2903 from plotly/feat/on-error
Browse files Browse the repository at this point in the history
Add callback on_error handler
  • Loading branch information
T4rk1n authored Jul 11, 2024
2 parents 501b715 + db28caf commit 351a81f
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ max-bool-expr=5
max-branches=15

# Maximum number of locals for function / method body
max-locals=20
max-locals=25

# Maximum number of parents for a class (see R0901).
max-parents=7
Expand Down
73 changes: 57 additions & 16 deletions dash/_callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import hashlib
from functools import wraps
from typing import Callable, Optional, Any

import flask

Expand Down Expand Up @@ -67,6 +68,7 @@ def callback(
cancel=None,
manager=None,
cache_args_to_ignore=None,
on_error: Optional[Callable[[Exception], Any]] = None,
**_kwargs,
):
"""
Expand Down Expand Up @@ -137,6 +139,10 @@ def callback(
this should be a list of argument indices as integers.
:param interval:
Time to wait between the long callback update requests.
:param on_error:
Function to call when the callback raises an exception. Receives the
exception object as first argument. The callback_context can be used
to access the original callback inputs, states and output.
"""

long_spec = None
Expand Down Expand Up @@ -186,6 +192,7 @@ def callback(
long=long_spec,
manager=manager,
running=running,
on_error=on_error,
)


Expand Down Expand Up @@ -226,7 +233,7 @@ def insert_callback(
long=None,
manager=None,
running=None,
dynamic_creator=False,
dynamic_creator: Optional[bool] = False,
no_output=False,
):
if prevent_initial_call is None:
Expand Down Expand Up @@ -272,8 +279,16 @@ def insert_callback(
return callback_id


# pylint: disable=R0912, R0915
def register_callback( # pylint: disable=R0914
def _set_side_update(ctx, response) -> bool:
side_update = dict(ctx.updated_props)
if len(side_update) > 0:
response["sideUpdate"] = side_update
return True
return False


# pylint: disable=too-many-branches,too-many-statements
def register_callback(
callback_list, callback_map, config_prevent_initial_callbacks, *_args, **_kwargs
):
(
Expand All @@ -297,6 +312,7 @@ def register_callback( # pylint: disable=R0914
long = _kwargs.get("long")
manager = _kwargs.get("manager")
running = _kwargs.get("running")
on_error = _kwargs.get("on_error")
if running is not None:
if not isinstance(running[0], (list, tuple)):
running = [running]
Expand Down Expand Up @@ -342,6 +358,8 @@ def add_context(*args, **kwargs):
"callback_context", AttributeDict({"updated_props": {}})
)
callback_manager = long and long.get("manager", app_callback_manager)
error_handler = on_error or kwargs.pop("app_on_error", None)

if has_output:
_validate.validate_output_spec(insert_output, output_spec, Output)

Expand All @@ -351,7 +369,7 @@ def add_context(*args, **kwargs):
args, inputs_state_indices
)

response = {"multi": True}
response: dict = {"multi": True}
has_update = False

if long is not None:
Expand Down Expand Up @@ -440,10 +458,24 @@ def add_context(*args, **kwargs):
isinstance(output_value, dict)
and "long_callback_error" in output_value
):
error = output_value.get("long_callback_error")
raise LongCallbackError(
error = output_value.get("long_callback_error", {})
exc = LongCallbackError(
f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}"
)
if error_handler:
output_value = error_handler(exc)

if output_value is None:
output_value = NoUpdate()
# set_props from the error handler uses the original ctx
# instead of manager.get_updated_props since it runs in the
# request process.
has_update = (
_set_side_update(callback_ctx, response)
or output_value is not None
)
else:
raise exc

if job_running and output_value is not callback_manager.UNDEFINED:
# cached results.
Expand All @@ -462,10 +494,22 @@ def add_context(*args, **kwargs):
if output_value is callback_manager.UNDEFINED:
return to_json(response)
else:
output_value = _invoke_callback(func, *func_args, **func_kwargs)

if NoUpdate.is_no_update(output_value):
raise PreventUpdate
try:
output_value = _invoke_callback(func, *func_args, **func_kwargs)
except PreventUpdate as err:
raise err
except Exception as err: # pylint: disable=broad-exception-caught
if error_handler:
output_value = error_handler(err)

# If the error returns nothing, automatically puts NoUpdate for response.
if output_value is None:
if not multi:
output_value = NoUpdate()
else:
output_value = [NoUpdate for _ in output_spec]
else:
raise err

component_ids = collections.defaultdict(dict)

Expand All @@ -487,12 +531,12 @@ def add_context(*args, **kwargs):
)

for val, spec in zip(flat_output_values, output_spec):
if isinstance(val, NoUpdate):
if NoUpdate.is_no_update(val):
continue
for vali, speci in (
zip(val, spec) if isinstance(spec, list) else [[val, spec]]
):
if not isinstance(vali, NoUpdate):
if not NoUpdate.is_no_update(vali):
has_update = True
id_str = stringify_id(speci["id"])
prop = clean_property_name(speci["property"])
Expand All @@ -506,10 +550,7 @@ def add_context(*args, **kwargs):
flat_output_values = []

if not long:
side_update = dict(callback_ctx.updated_props)
if len(side_update) > 0:
has_update = True
response["sideUpdate"] = side_update
has_update = _set_side_update(callback_ctx, response) or has_update

if not has_update:
raise PreventUpdate
Expand Down
9 changes: 8 additions & 1 deletion dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import base64
import traceback
from urllib.parse import urlparse
from typing import Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import flask

Expand Down Expand Up @@ -369,6 +369,10 @@ class Dash:
:param description: Sets a default description for meta tags on Dash pages (use_pages=True).
:param on_error: Global callback error handler to call when
an exception is raised. Receives the exception object as first argument.
The callback_context can be used to access the original callback inputs,
states and output.
"""

_plotlyjs_url: str
Expand Down Expand Up @@ -409,6 +413,7 @@ def __init__( # pylint: disable=too-many-statements
hooks: Union[RendererHooks, None] = None,
routing_callback_inputs: Optional[Dict[str, Union[Input, State]]] = None,
description=None,
on_error: Optional[Callable[[Exception], Any]] = None,
**obsolete,
):
_validate.check_obsolete(obsolete)
Expand Down Expand Up @@ -520,6 +525,7 @@ def __init__( # pylint: disable=too-many-statements
self._layout = None
self._layout_is_function = False
self.validation_layout = None
self._on_error = on_error
self._extra_components = []

self._setup_dev_tools()
Expand Down Expand Up @@ -1377,6 +1383,7 @@ def dispatch(self):
outputs_list=outputs_list,
long_callback_manager=self._background_manager,
callback_context=g,
app_on_error=self._on_error,
)
)
)
Expand Down
6 changes: 5 additions & 1 deletion dash/long_callback/managers/celery_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def run():
c.ignore_register_page = False
c.updated_props = ProxySetProps(_set_props)
context_value.set(c)
errored = False
try:
if isinstance(user_callback_args, dict):
user_callback_output = fn(*maybe_progress, **user_callback_args)
Expand All @@ -170,13 +171,15 @@ def run():
user_callback_output = fn(*maybe_progress, user_callback_args)
except PreventUpdate:
# Put NoUpdate dict directly to avoid circular imports.
errored = True
cache.set(
result_key,
json.dumps(
{"_dash_no_update": "_dash_no_update"}, cls=PlotlyJSONEncoder
),
)
except Exception as err: # pylint: disable=broad-except
errored = True
cache.set(
result_key,
json.dumps(
Expand All @@ -188,7 +191,8 @@ def run():
},
),
)
else:

if not errored:
cache.set(
result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder)
)
Expand Down
9 changes: 7 additions & 2 deletions dash/long_callback/managers/diskcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def call_job_fn(self, key, job_fn, args, context):

# pylint: disable-next=not-callable
proc = Process(
target=job_fn, args=(key, self._make_progress_key(key), args, context)
target=job_fn,
args=(key, self._make_progress_key(key), args, context),
)
proc.start()
return proc.pid
Expand Down Expand Up @@ -187,6 +188,7 @@ def run():
c.ignore_register_page = False
c.updated_props = ProxySetProps(_set_props)
context_value.set(c)
errored = False
try:
if isinstance(user_callback_args, dict):
user_callback_output = fn(*maybe_progress, **user_callback_args)
Expand All @@ -195,8 +197,10 @@ def run():
else:
user_callback_output = fn(*maybe_progress, user_callback_args)
except PreventUpdate:
errored = True
cache.set(result_key, {"_dash_no_update": "_dash_no_update"})
except Exception as err: # pylint: disable=broad-except
errored = True
cache.set(
result_key,
{
Expand All @@ -206,7 +210,8 @@ def run():
}
},
)
else:

if not errored:
cache.set(result_key, user_callback_output)

ctx.run(run)
Expand Down
46 changes: 46 additions & 0 deletions tests/integration/callbacks/test_callback_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dash import Dash, html, Input, Output, set_props


def test_cber001_error_handler(dash_duo):
def global_callback_error_handler(err):
set_props("output-global", {"children": f"global: {err}"})

app = Dash(on_error=global_callback_error_handler)

app.layout = [
html.Button("start", id="start-local"),
html.Button("start-global", id="start-global"),
html.Div(id="output"),
html.Div(id="output-global"),
html.Div(id="error-message"),
]

def on_callback_error(err):
set_props("error-message", {"children": f"message: {err}"})
return f"callback: {err}"

@app.callback(
Output("output", "children"),
Input("start-local", "n_clicks"),
on_error=on_callback_error,
prevent_initial_call=True,
)
def on_start(_):
raise Exception("local error")

@app.callback(
Output("output-global", "children"),
Input("start-global", "n_clicks"),
prevent_initial_call=True,
)
def on_start_global(_):
raise Exception("global error")

dash_duo.start_server(app)
dash_duo.find_element("#start-local").click()

dash_duo.wait_for_text_to_equal("#output", "callback: local error")
dash_duo.wait_for_text_to_equal("#error-message", "message: local error")

dash_duo.find_element("#start-global").click()
dash_duo.wait_for_text_to_equal("#output-global", "global: global error")
50 changes: 50 additions & 0 deletions tests/integration/long_callback/app_bg_on_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from dash import Dash, Input, Output, html, set_props
from tests.integration.long_callback.utils import get_long_callback_manager

long_callback_manager = get_long_callback_manager()
handle = long_callback_manager.handle


def global_error_handler(err):
set_props("global-output", {"children": f"global: {err}"})


app = Dash(
__name__, long_callback_manager=long_callback_manager, on_error=global_error_handler
)

app.layout = [
html.Button("callback on_error", id="start-cb-onerror"),
html.Div(id="cb-output"),
html.Button("global on_error", id="start-global-onerror"),
html.Div(id="global-output"),
]


def callback_on_error(err):
set_props("cb-output", {"children": f"callback: {err}"})


@app.callback(
Output("cb-output", "children"),
Input("start-cb-onerror", "n_clicks"),
prevent_initial_call=True,
background=True,
on_error=callback_on_error,
)
def on_click(_):
raise Exception("callback error")


@app.callback(
Output("global-output", "children"),
Input("start-global-onerror", "n_clicks"),
prevent_initial_call=True,
background=True,
)
def on_click_global(_):
raise Exception("global error")


if __name__ == "__main__":
app.run(debug=True)
Loading

0 comments on commit 351a81f

Please sign in to comment.