From 924c5773854f173d606e7a0acdd3eec0909bb9c7 Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 10 Sep 2024 18:06:35 +0800 Subject: [PATCH] feat: add naive broker --- drive_events/broker.py | 11 ++ drive_events/core.py | 92 ++++++++++++--- drive_events/types.py | 66 ++++++++--- drive_events/utils.py | 13 +-- tests/{test_basic.py => test_define.py} | 52 ++++----- tests/test_run.py | 147 ++++++++++++++++++++++++ tests/test_types.py | 4 +- 7 files changed, 305 insertions(+), 80 deletions(-) create mode 100644 drive_events/broker.py rename tests/{test_basic.py => test_define.py} (71%) create mode 100644 tests/test_run.py diff --git a/drive_events/broker.py b/drive_events/broker.py new file mode 100644 index 0000000..7abf1d5 --- /dev/null +++ b/drive_events/broker.py @@ -0,0 +1,11 @@ +from typing import Any +from .types import BaseEvent, EventInput, Task, GroupEventReturns +from .utils import generate_uuid + + +class BaseBroker: + async def append(self, event: BaseEvent, event_input: EventInput) -> Task: + raise NotImplementedError() + + async def callback_after_run_done(self) -> tuple[BaseEvent, Any]: + raise NotImplementedError() diff --git a/drive_events/core.py b/drive_events/core.py index f009f5a..fbdbd9c 100644 --- a/drive_events/core.py +++ b/drive_events/core.py @@ -1,18 +1,35 @@ import inspect -from typing import Callable, Optional -from .types import BaseEvent, EventFunction, EventGroup -from .utils import logger +import asyncio +from typing import Callable, Optional, Union, Any, Tuple +from .types import BaseEvent, EventFunction, EventGroup, EventInput +from .broker import BaseBroker +from .utils import ( + logger, + string_to_md5_hash, +) class EventEngineCls: - def __init__(self): - self.__function_maps: dict[str, EventFunction] = {} + def __init__(self, name="default", broker: Optional[BaseBroker] = None): + self.name = name + self.broker = broker or BaseBroker() self.__event_maps: dict[str, BaseEvent] = {} + self.__max_group_size = 0 def reset(self): - self.__function_maps = {} self.__event_maps = {} + def make_event(self, func: Union[EventFunction, BaseEvent]) -> BaseEvent: + if isinstance(func, BaseEvent): + self.__event_maps[func.id] = func + return func + assert inspect.iscoroutinefunction( + func + ), "Event function must be a coroutine function" + event = BaseEvent(func) + self.__event_maps[event.id] = event + return event + def listen_groups( self, group_markers: list[BaseEvent], group_name: Optional[str] = None ) -> Callable[[BaseEvent], BaseEvent]: @@ -21,14 +38,23 @@ def listen_groups( ), "group_markers must be a list of BaseEvent" assert all( [m.id in self.__event_maps for m in group_markers] - ), "group_markers must be registered in the same event engine" - group_markers = list(set(group_markers)) + ), f"group_markers must be registered in the same event engine, current event engine is {self.name}" + group_markers_in_dict = {event.id: event for event in group_markers} def decorator(func: BaseEvent) -> BaseEvent: if not isinstance(func, BaseEvent): func = self.make_event(func) + assert ( + func.id in self.__event_maps + ), f"Event function must be registered in the same event engine, current event engine is {self.name}" this_group_name = group_name or f"{len(func.parent_groups)}" - new_group = EventGroup(this_group_name, group_markers) + this_group_hash = string_to_md5_hash(":".join(group_markers_in_dict.keys())) + new_group = EventGroup( + this_group_name, this_group_hash, group_markers_in_dict + ) + self.__max_group_size = max( + self.__max_group_size, len(group_markers_in_dict) + ) if new_group.hash() in func.parent_groups: logger.warning(f"Group {group_markers} already listened by {func}") return func @@ -40,13 +66,41 @@ def decorator(func: BaseEvent) -> BaseEvent: def goto(self, group_markers: list[BaseEvent], *args): raise NotImplementedError() - def make_event(self, func: EventFunction) -> BaseEvent: - if isinstance(func, BaseEvent): - return func - assert inspect.iscoroutinefunction( - func - ), "Event function must be a coroutine function" - event = BaseEvent(func) - self.__function_maps[event.id] = func - self.__event_maps[event.id] = event - return event + async def invoke_event( + self, + event: BaseEvent, + event_input: Optional[EventInput] = None, + global_ctx: Any = None, + max_async_events: Optional[int] = None, + ) -> dict[str, Any]: + this_run_ctx = {} + queue: list[Tuple[BaseEvent, EventInput]] = [(event, event_input)] + + async def run_event(current_event, current_event_input): + result = await current_event.solo_run(current_event_input, global_ctx) + this_run_ctx[current_event.id] = result + for cand_event in self.__event_maps.values(): + cand_event_parents = cand_event.parent_groups + for group_hash, group in cand_event_parents.items(): + if current_event.id in group.events and all( + [event_id in this_run_ctx for event_id in group.events] + ): + this_group_returns = { + event_id: this_run_ctx[event_id] + for event_id in group.events + } + build_input = EventInput( + group_name=group.name, results=this_group_returns + ) + queue.append((cand_event, build_input)) + + while len(queue): + this_batch_events = queue[:max_async_events] if max_async_events else queue + queue = queue[max_async_events:] if max_async_events else [] + logger.debug( + f"Running a turn with {len(this_batch_events)} event tasks, left {len(queue)} event tasks in queue" + ) + await asyncio.gather( + *[run_event(*run_event_input) for run_event_input in this_batch_events] + ) + return this_run_ctx diff --git a/drive_events/types.py b/drive_events/types.py index b23f164..7e2248c 100644 --- a/drive_events/types.py +++ b/drive_events/types.py @@ -1,7 +1,8 @@ from copy import copy from enum import Enum -from dataclasses import dataclass -from typing import Callable, Any, Awaitable, Optional +from dataclasses import dataclass, field +from datetime import datetime +from typing import Callable, Any, Awaitable, Optional, TypeVar, Generic from .utils import ( string_to_md5_hash, @@ -9,23 +10,47 @@ function_or_method_to_repr, ) -GroupEventReturns = dict["BaseEvent", Any] -EventInput = tuple[str, GroupEventReturns] -EventFunction = Callable[[EventInput], Awaitable[Any]] +class ReturnBehavior(Enum): + DISPATCH = "dispatch" + GOTO = "goto" + ABORT = "abort" + + +class TaskStatus(Enum): + RUNNING = "running" + SUCCESS = "success" + FAILURE = "failure" + PENDING = "pending" + + +GroupEventReturns = dict[str, Any] + + +@dataclass +class EventGroupInput: + group_name: str + results: GroupEventReturns + behavior: ReturnBehavior = ReturnBehavior.DISPATCH + + +@dataclass +class EventInput(EventGroupInput): + pass + + +# (group_event_results, global ctx set by user) -> result +EventFunction = Callable[[Optional[EventInput], Optional[Any]], Awaitable[Any]] @dataclass class EventGroup: name: str - events: list["BaseEvent"] - - def __post_init__(self): - self.events = sorted(self.events, key=lambda e: e.id) - self._hash = string_to_md5_hash(":".join([e.id for e in self.events])) + events_hash: str + events: dict[str, "BaseEvent"] def hash(self) -> str: - return self._hash + return self.events_hash class BaseEvent: @@ -43,6 +68,7 @@ def __init__( self.func_inst = func_inst self.id = string_to_md5_hash(function_or_method_to_string(self.func_inst)) self.repr_name = function_or_method_to_repr(self.func_inst) + self.meta = {"func_body": function_or_method_to_string(self.func_inst)} def debug_string(self, exclude_events: Optional[set[str]] = None) -> str: exclude_events = exclude_events or set([self.id]) @@ -52,14 +78,18 @@ def debug_string(self, exclude_events: Optional[set[str]] = None) -> str: def __repr__(self) -> str: return f"Node(source={self.repr_name})" - async def solo_run(self, event_input: EventInput) -> Awaitable[Any]: - return await self.func_inst(event_input) + async def solo_run( + self, event_input: EventInput, global_ctx: Any = None + ) -> Awaitable[Any]: + return await self.func_inst(event_input, global_ctx) -class ReturnBehavior(Enum): - DISPATCH = "dispatch" - GOTO = "goto" - ABORT = "abort" +@dataclass +class Task: + task_id: str + status: TaskStatus = TaskStatus.PENDING + created_at: datetime = field(default_factory=datetime.now) + upated_at: datetime = field(default_factory=datetime.now) def format_parents(parents: dict[str, EventGroup], exclude_events: set[str], indent=""): @@ -70,7 +100,7 @@ def format_parents(parents: dict[str, EventGroup], exclude_events: set[str], ind is_last_group = i == len(parents) - 1 group_prefix = "└─ " if is_last_group else "├─ " result.append(indent + group_prefix + f"<{parent_group.name}>") - for j, parent in enumerate(parent_group.events): + for j, parent in enumerate(parent_group.events.values()): root_events = copy(exclude_events) is_last = j == len(parent_group.events) - 1 child_indent = indent + (" " if is_last_group else "│ ") diff --git a/drive_events/utils.py b/drive_events/utils.py index 217d1e0..1fa5a0f 100644 --- a/drive_events/utils.py +++ b/drive_events/utils.py @@ -1,20 +1,15 @@ +import uuid import logging import asyncio import inspect import hashlib -from typing import Callable, Union -from types import MethodType +from typing import Callable logger = logging.getLogger("drive-events") -def always_get_a_event_loop(): - try: - return asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop +def generate_uuid() -> str: + return str(uuid.uuid4()) def function_or_method_to_repr(func_or_method: Callable) -> str: diff --git a/tests/test_basic.py b/tests/test_define.py similarity index 71% rename from tests/test_basic.py rename to tests/test_define.py index acea306..0c5a87f 100644 --- a/tests/test_basic.py +++ b/tests/test_define.py @@ -7,11 +7,11 @@ @pytest.mark.asyncio async def test_set_and_reset(): @default_drive.make_event - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 @default_drive.listen_groups([a]) - async def b(event: EventInput): + async def b(event: EventInput, global_ctx): return 2 default_drive.reset() @@ -19,7 +19,7 @@ async def b(event: EventInput): with pytest.raises(AssertionError): @default_drive.listen_groups([a]) - async def b(event: EventInput): + async def b(event: EventInput, global_ctx): return 2 @@ -27,7 +27,7 @@ async def b(event: EventInput): async def test_duplicate_decorator(): @default_drive.make_event @default_drive.make_event - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 assert isinstance(a, BaseEvent) @@ -36,15 +36,15 @@ async def a(event: EventInput): @pytest.mark.asyncio async def test_order(): @default_drive.make_event - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 @default_drive.listen_groups([a]) - async def b(event: EventInput): + async def b(event: EventInput, global_ctx): return 2 @default_drive.listen_groups([b]) - async def c(event: EventInput): + async def c(event: EventInput, global_ctx): return 3 print(a.debug_string()) @@ -59,15 +59,15 @@ async def c(event: EventInput): @pytest.mark.asyncio async def test_multi_send(): @default_drive.make_event - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 @default_drive.listen_groups([a]) - async def b(event: EventInput): + async def b(event: EventInput, global_ctx): return 2 @default_drive.listen_groups([a]) - async def c(event: EventInput): + async def c(event: EventInput, global_ctx): return 3 print(a.debug_string()) @@ -81,19 +81,19 @@ async def c(event: EventInput): @pytest.mark.asyncio async def test_multi_recv(): @default_drive.make_event - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 @default_drive.listen_groups([a]) - async def a1(event: EventInput): + async def a1(event: EventInput, global_ctx): return 1 @default_drive.make_event - async def b(event: EventInput): + async def b(event: EventInput, global_ctx): return 2 @default_drive.listen_groups([a1, b]) - async def c(event: EventInput): + async def c(event: EventInput, global_ctx): return 3 print(a.debug_string()) @@ -107,48 +107,36 @@ async def c(event: EventInput): @pytest.mark.asyncio async def test_multi_groups(): @default_drive.make_event - async def a0(event: EventInput): + async def a0(event: EventInput, global_ctx): return 0 @default_drive.make_event - async def a1(event: EventInput): + async def a1(event: EventInput, global_ctx): return 0 @default_drive.listen_groups([a0, a1]) @default_drive.listen_groups([a0, a1]) @default_drive.listen_groups([a0, a1]) - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 - @default_drive.make_event - async def b(event: EventInput): - return 2 - - @default_drive.listen_groups([a]) - @default_drive.listen_groups([b, a]) - async def c(event: EventInput): - return 3 - - print(c.debug_string()) assert await a.solo_run(None) == 1 - assert await b.solo_run(None) == 2 - assert await c.solo_run(None) == 3 @pytest.mark.asyncio async def test_loop(): @default_drive.make_event - async def a(event: EventInput): + async def a(event: EventInput, global_ctx): return 1 @default_drive.listen_groups([a]) - async def b(event: EventInput): + async def b(event: EventInput, global_ctx): return 2 a = default_drive.listen_groups([b])(a) @default_drive.listen_groups([a, b]) - async def c(event: EventInput): + async def c(event: EventInput, global_ctx): return 3 print(a.debug_string()) diff --git a/tests/test_run.py b/tests/test_run.py new file mode 100644 index 0000000..f252c8b --- /dev/null +++ b/tests/test_run.py @@ -0,0 +1,147 @@ +import pytest +from drive_events import default_drive, EventInput +from drive_events.types import ReturnBehavior + + +class DeliberateExcepion(Exception): + pass + + +@pytest.mark.asyncio +async def test_simple_order_run(): + @default_drive.make_event + async def a(event: EventInput, global_ctx): + assert global_ctx == {"test_ctx": 1} + return 1 + + @default_drive.listen_groups([a]) + async def b(event: EventInput, global_ctx): + assert global_ctx == {"test_ctx": 1} + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {a.id: 1} + return 2 + + @default_drive.listen_groups([b]) + async def c(event: EventInput, global_ctx): + assert global_ctx == {"test_ctx": 1} + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {b.id: 2} + return 3 + + result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) + print(result) + + +@pytest.mark.asyncio +async def test_multi_send(): + @default_drive.make_event + async def a(event: EventInput, global_ctx): + return 1 + + @default_drive.listen_groups([a]) + async def b(event: EventInput, global_ctx): + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {a.id: 1} + return 2 + + @default_drive.listen_groups([a]) + async def c(event: EventInput, global_ctx): + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {a.id: 1} + return 3 + + result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) + print(result) + + +@pytest.mark.asyncio +async def test_multi_recv(): + @default_drive.make_event + async def start(event: EventInput, global_ctx): + return None + + @default_drive.listen_groups([start]) + async def a(event: EventInput, global_ctx): + return 1 + + @default_drive.listen_groups([start]) + async def b(event: EventInput, global_ctx): + return 2 + + @default_drive.listen_groups([a, b]) + async def c(event: EventInput, global_ctx): + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {a.id: 1, b.id: 2} + return 3 + + result = await default_drive.invoke_event(start, None, {"test_ctx": 1}) + print(result) + + +@pytest.mark.asyncio +async def test_multi_groups(): + @default_drive.make_event + async def a(event: EventInput, global_ctx): + return 1 + + @default_drive.listen_groups([a]) + async def b(event: EventInput, global_ctx): + return 2 + + call_c_count = 0 + + @default_drive.listen_groups([a]) + @default_drive.listen_groups([b, a]) + async def c(event: EventInput, global_ctx): + nonlocal call_c_count + if call_c_count == 0: + assert event.group_name == "1" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {a.id: 1} + elif call_c_count == 1: + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {a.id: 1, b.id: 2} + else: + assert False, "c should only be called twice" + call_c_count += 1 + return 3 + + result = await default_drive.invoke_event(a, None, {"test_ctx": 1}) + print(result) + + +@pytest.mark.asyncio +async def test_loop(): + call_a_count = 0 + + @default_drive.make_event + async def a(event: EventInput, global_ctx): + nonlocal call_a_count + if call_a_count == 0: + pass + elif call_a_count == 1: + assert event.group_name == "0" + assert event.behavior == ReturnBehavior.DISPATCH + assert event.results == {b.id: 2} + raise DeliberateExcepion() + call_a_count += 1 + return 1 + + @default_drive.listen_groups([a]) + async def b(event: EventInput, global_ctx): + return 2 + + a = default_drive.listen_groups([b])(a) + + @default_drive.listen_groups([a, b]) + async def c(event: EventInput, global_ctx): + return 3 + + with pytest.raises(DeliberateExcepion): + await default_drive.invoke_event(a, None, {"test_ctx": 1}) diff --git a/tests/test_types.py b/tests/test_types.py index caff239..45b9d6d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -23,9 +23,9 @@ def mock_b(): return 2 n1 = BaseEvent(mock_a) - g1 = EventGroup("1", [n1]) + g1 = EventGroup("1", "hash-xxxxx", {n1.id: n1}) n2 = BaseEvent(mock_a, parent_groups={g1.hash(): g1}) - g2 = EventGroup("2", [n1, n2]) + g2 = EventGroup("2", "hash-yyyy", {n1.id: n1, n2.id: n2}) n3 = BaseEvent(mock_b, parent_groups={g1.hash(): g1, g2.hash(): g2}) print(n1, n1.debug_string())