Skip to content

Commit

Permalink
fix(modal): fix timeout edge cases with custom_id reuse and long-ru…
Browse files Browse the repository at this point in the history
…nning callbacks (#914)
  • Loading branch information
shiftinv authored Dec 28, 2024
1 parent 4f6a371 commit df5e391
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog/914.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix :class:`ui.Modal` timeout issues with long-running callbacks, and multiple modals with the same user and ``custom_id``.
104 changes: 77 additions & 27 deletions disnake/ui/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import os
import sys
import traceback
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar, Union
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, TypeVar, Union

from ..enums import TextInputStyle
from ..utils import MISSING
Expand Down Expand Up @@ -38,14 +39,32 @@ class Modal:
components: |components_type|
The components to display in the modal. Up to 5 action rows.
custom_id: :class:`str`
The custom ID of the modal.
The custom ID of the modal. This is usually not required.
If not given, then a unique one is generated for you.
.. note::
:class:`Modal`\\s are identified based on the user ID that triggered the
modal, and this ``custom_id``.
This can result in collisions when a user opens a modal with the same ``custom_id`` on
two separate devices, for example.
To avoid such issues, consider not specifying a ``custom_id`` to use an automatically generated one,
or include a unique value in the custom ID (e.g. the original interaction ID).
timeout: :class:`float`
The time to wait until the modal is removed from cache, if no interaction is made.
Modals without timeouts are not supported, since there's no event for when a modal is closed.
Defaults to 600 seconds.
"""

__slots__ = ("title", "custom_id", "components", "timeout")
__slots__ = (
"title",
"custom_id",
"components",
"timeout",
"__remove_callback",
"__timeout_handle",
)

def __init__(
self,
Expand All @@ -67,6 +86,11 @@ def __init__(
self.components: List[ActionRow] = rows
self.timeout: float = timeout

# function for the modal to remove itself from the store, if any
self.__remove_callback: Optional[Callable[[Modal], None]] = None
# timer handle for the scheduled timeout
self.__timeout_handle: Optional[asyncio.TimerHandle] = None

def __repr__(self) -> str:
return (
f"<Modal custom_id={self.custom_id!r} title={self.title!r} "
Expand Down Expand Up @@ -212,14 +236,46 @@ async def _scheduled_task(self, interaction: ModalInteraction) -> None:
except Exception as e:
await self.on_error(e, interaction)
finally:
# if the interaction was responded to (no matter if in the callback or error handler),
# the modal closed for the user and therefore can be removed from the store
if interaction.response._response_type is not None:
interaction._state._modal_store.remove_modal(
interaction.author.id, interaction.custom_id
)
if interaction.response._response_type is None:
# If the interaction was not successfully responded to, the modal didn't close for the user.
# Since the timeout was already stopped at this point, restart it.
self._start_listening(self.__remove_callback)
else:
# Otherwise, the modal closed for the user; remove it from the store.
self._stop_listening()

def _start_listening(self, remove_callback: Optional[Callable[[Modal], None]]) -> None:
self.__remove_callback = remove_callback

loop = asyncio.get_running_loop()
if self.__timeout_handle is not None:
# shouldn't get here, but handled just in case
self.__timeout_handle.cancel()

# start timeout
self.__timeout_handle = loop.call_later(self.timeout, self._dispatch_timeout)

def _stop_listening(self) -> None:
# cancel timeout
if self.__timeout_handle is not None:
self.__timeout_handle.cancel()
self.__timeout_handle = None

# remove modal from store
if self.__remove_callback is not None:
self.__remove_callback(self)
self.__remove_callback = None

def _dispatch_timeout(self) -> None:
self._stop_listening()
asyncio.create_task(self.on_timeout(), name=f"disnake-ui-modal-timeout-{self.custom_id}")

def dispatch(self, interaction: ModalInteraction) -> None:
# stop the timeout, but don't remove the modal from the store yet in case the
# response fails and the modal stays open
if self.__timeout_handle is not None:
self.__timeout_handle.cancel()

asyncio.create_task(
self._scheduled_task(interaction), name=f"disnake-ui-modal-dispatch-{self.custom_id}"
)
Expand All @@ -232,28 +288,22 @@ def __init__(self, state: ConnectionState) -> None:
self._modals: Dict[Tuple[int, str], Modal] = {}

def add_modal(self, user_id: int, modal: Modal) -> None:
loop = asyncio.get_running_loop()
self._modals[(user_id, modal.custom_id)] = modal
loop.create_task(self.handle_timeout(user_id, modal.custom_id, modal.timeout))
key = (user_id, modal.custom_id)

def remove_modal(self, user_id: int, modal_custom_id: str) -> Modal:
return self._modals.pop((user_id, modal_custom_id))
# if another modal with the same user+custom_id already exists,
# stop its timeout to avoid overlaps/collisions
if (existing := self._modals.get(key)) is not None:
existing._stop_listening()

async def handle_timeout(self, user_id: int, modal_custom_id: str, timeout: float) -> None:
# Waits for the timeout and then removes the modal from cache, this is done just in case
# the user closed the modal, as there isn't an event for that.
# start timeout, store modal
remove_callback = partial(self.remove_modal, user_id)
modal._start_listening(remove_callback)
self._modals[key] = modal

await asyncio.sleep(timeout)
try:
modal = self.remove_modal(user_id, modal_custom_id)
except KeyError:
# The modal has already been removed.
pass
else:
await modal.on_timeout()
def remove_modal(self, user_id: int, modal: Modal) -> None:
self._modals.pop((user_id, modal.custom_id), None)

def dispatch(self, interaction: ModalInteraction) -> None:
key = (interaction.author.id, interaction.custom_id)
modal = self._modals.get(key)
if modal is not None:
if (modal := self._modals.get(key)) is not None:
modal.dispatch(interaction)
2 changes: 1 addition & 1 deletion examples/interactions/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self) -> None:
max_length=1024,
),
]
super().__init__(title="Create Tag", custom_id="create_tag", components=components)
super().__init__(title="Create Tag", components=components)

async def callback(self, inter: disnake.ModalInteraction) -> None:
tag_name = inter.text_values["name"]
Expand Down
2 changes: 1 addition & 1 deletion test_bot/cogs/modals.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self) -> None:
style=TextInputStyle.paragraph,
),
]
super().__init__(title="Create Tag", custom_id="create_tag", components=components)
super().__init__(title="Create Tag", components=components)

async def callback(self, inter: disnake.ModalInteraction[commands.Bot]) -> None:
embed = disnake.Embed(title="Tag Creation")
Expand Down

0 comments on commit df5e391

Please sign in to comment.