fix: retrigger feature
gusye1234 committed Sep 15, 2024
1 parent e1cdcde commit 5f39472
Showing 5 changed files with 96 additions and 15 deletions.
53 changes: 41 additions & 12 deletions drive_flow/core.py
@@ -1,19 +1,17 @@
import inspect
import asyncio
from typing import Callable, Optional, Union, Any, Tuple
from typing import Callable, Optional, Union, Any, Tuple, Literal
from .types import (
BaseEvent,
EventFunction,
EventGroup,
EventInput,
_SpecialEventReturn,
ReturnBehavior,
InvokeInterCache,
)
from .broker import BaseBroker
from .utils import (
logger,
string_to_md5_hash,
)
from .utils import logger, string_to_md5_hash, generate_uuid


class EventEngineCls:
@@ -41,7 +39,10 @@ def make_event(self, func: Union[EventFunction, BaseEvent]) -> BaseEvent:
return event

def listen_group(
self, group_markers: list[BaseEvent], group_name: Optional[str] = None
self,
group_markers: list[BaseEvent],
group_name: Optional[str] = None,
retrigger_type: Literal["all", "any"] = "all",
) -> Callable[[BaseEvent], BaseEvent]:
assert all(
[isinstance(m, BaseEvent) for m in group_markers]
@@ -60,7 +61,10 @@ def decorator(func: BaseEvent) -> BaseEvent:
this_group_name = group_name or f"{len(func.parent_groups)}"
this_group_hash = string_to_md5_hash(":".join(group_markers_in_dict.keys()))
new_group = EventGroup(
this_group_name, this_group_hash, group_markers_in_dict
this_group_name,
this_group_hash,
group_markers_in_dict,
retrigger_type=retrigger_type,
)
self.__max_group_size = max(
self.__max_group_size, len(group_markers_in_dict)
@@ -83,12 +87,15 @@ async def invoke_event(
global_ctx: Any = None,
max_async_events: Optional[int] = None,
) -> dict[str, Any]:
this_run_ctx = {}
this_run_ctx: dict[str, InvokeInterCache] = {}
queue: list[Tuple[BaseEvent, EventInput]] = [(event, event_input)]

async def run_event(current_event: BaseEvent, current_event_input: Any):
result = await current_event.solo_run(current_event_input, global_ctx)
this_run_ctx[current_event.id] = result
this_run_ctx[current_event.id] = {
"result": result,
"already_sent_to_event_group": set(),
}
if isinstance(result, _SpecialEventReturn):
if result.behavior == ReturnBehavior.GOTO:
group_markers, any_return = result.returns
@@ -107,13 +114,35 @@ async def run_event(current_event: BaseEvent, current_event_input: Any):
for cand_event in self.__event_maps.values():
cand_event_parents = cand_event.parent_groups
for group_hash, group in cand_event_parents.items():
if current_event.id in group.events and all(
if_current_event_trigger = current_event.id in group.events
if_ctx_cover = all(
[event_id in this_run_ctx for event_id in group.events]
):
)
event_group_id = f"{cand_event.id}:{group_hash}"
if if_current_event_trigger and if_ctx_cover:
if (
any(
[
event_group_id
in this_run_ctx[event_id][
"already_sent_to_event_group"
]
for event_id in group.events
]
)
and group.retrigger_type == "all"
):
# some of these results were already dispatched to this event group; skip
logger.debug(f"Skip {cand_event} for {current_event}")
continue
this_group_returns = {
event_id: this_run_ctx[event_id]
event_id: this_run_ctx[event_id]["result"]
for event_id in group.events
}
for event_id in group.events:
this_run_ctx[event_id][
"already_sent_to_event_group"
].add(event_group_id)
build_input = EventInput(
group_name=group.name, results=this_group_returns
)
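For readers skimming the `core.py` diff above, here is a minimal, self-contained sketch of the bookkeeping it introduces: each event's latest result is cached together with the set of group-dispatch ids it has already been sent to, and with `retrigger_type="all"` a listener group is skipped until every member result is fresh again. The helper name `should_dispatch` and the literal ids below are illustrative only, not part of drive_flow.

```python
from typing import Any, Literal, TypedDict


class Cache(TypedDict):
    # Same shape as the InvokeInterCache added in this commit.
    result: Any
    already_sent_to_event_group: set[str]


def should_dispatch(
    ctx: dict[str, Cache],
    group_events: list[str],
    event_group_id: str,
    retrigger_type: Literal["all", "any"] = "all",
) -> bool:
    # Every member of the group must have produced a result at least once.
    if not all(event_id in ctx for event_id in group_events):
        return False
    # Under "all", skip the group if any member's result was already consumed
    # by this listener; under "any", each fresh completion re-triggers it.
    if retrigger_type == "all" and any(
        event_group_id in ctx[event_id]["already_sent_to_event_group"]
        for event_id in group_events
    ):
        return False
    # Mark all member results as consumed by this listener group.
    for event_id in group_events:
        ctx[event_id]["already_sent_to_event_group"].add(event_group_id)
    return True


ctx: dict[str, Cache] = {
    "a": {"result": 1, "already_sent_to_event_group": set()},
    "b": {"result": 2, "already_sent_to_event_group": set()},
}
print(should_dispatch(ctx, ["a", "b"], "c:0"))  # True: both results are fresh
print(should_dispatch(ctx, ["a", "b"], "c:0"))  # False: already dispatched once
```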
11 changes: 9 additions & 2 deletions drive_flow/types.py
@@ -2,10 +2,11 @@
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Awaitable, Optional, Union, Callable
from typing import Any, Awaitable, Optional, Union, Callable, TypedDict, Literal

from .utils import (
string_to_md5_hash,
generate_uuid,
function_or_method_to_string,
function_or_method_to_repr,
)
@@ -24,6 +25,11 @@ class TaskStatus(Enum):
PENDING = "pending"


class InvokeInterCache(TypedDict):
result: Any
already_sent_to_event_group: set[str]


GroupEventReturns = dict[str, Any]


@@ -36,7 +42,7 @@ class EventGroupInput:

@dataclass
class EventInput(EventGroupInput):
pass
task_id: str = field(default_factory=generate_uuid)


@dataclass
@@ -62,6 +68,7 @@ class EventGroup:
name: str
events_hash: str
events: dict[str, "BaseEvent"]
retrigger_type: Literal["all", "any"] = "all"

def hash(self) -> str:
return self.events_hash
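As a side note on the `types.py` changes above, the new `task_id` field uses `dataclasses.field(default_factory=...)` so that every `EventInput` gets its own id rather than one value shared across all instances. A small sketch of that pattern; `generate_uuid` is stubbed out here (only its string-returning behaviour is assumed), and `EventGroupInput` is reduced to the fields visible in this diff.

```python
import uuid
from dataclasses import dataclass, field
from typing import Any


def generate_uuid() -> str:
    # Stand-in for drive_flow.utils.generate_uuid; assumed to return a unique string.
    return uuid.uuid4().hex


@dataclass
class EventGroupInput:
    group_name: str
    results: dict[str, Any]


@dataclass
class EventInput(EventGroupInput):
    # default_factory is called once per instance, so ids are never shared.
    task_id: str = field(default_factory=generate_uuid)


a = EventInput(group_name="0", results={"x": 1})
b = EventInput(group_name="0", results={"x": 1})
assert a.task_id != b.task_id  # each instance generates its own id
```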
3 changes: 2 additions & 1 deletion readme.md
@@ -214,4 +214,5 @@ asyncio.run(default_drive.invoke_event(a))
## TODO

- [x] fix: streaming event execution
- [ ] fix: an event never receives the listened events' results twice (de-duplication), unless the group is totally updated.
- [x] fix: an event never receives the listened events' results twice (de-duplication) unless the whole group is updated again, for `retrigger_type='all'`
- [ ] Add ReAct workflow example
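For the de-duplication item in the readme TODO above, a short usage sketch of the new `retrigger_type` knob, assuming the decorator API used elsewhere in this repository's tests; the event names are illustrative only.

```python
import asyncio
from drive_flow import default_drive, EventInput


@default_drive.make_event
async def start(event: EventInput, global_ctx):
    pass


@default_drive.listen_group([start])
async def a(event: EventInput, global_ctx):
    return 1


@default_drive.listen_group([start])
async def b(event: EventInput, global_ctx):
    return 2


# With the default retrigger_type="all", c fires only when the whole {a, b}
# group has fresh results; a repeat result from just one parent is de-duplicated.
@default_drive.listen_group([a, b], retrigger_type="all")
async def c(event: EventInput, global_ctx):
    return event.results


asyncio.run(default_drive.invoke_event(start, None, {"test_ctx": 1}))
```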
1 change: 1 addition & 0 deletions tests/test_dynamic_run.py
@@ -56,4 +56,5 @@ async def c(event: EventInput, global_ctx):
assert False, "should not be called"

result = await default_drive.invoke_event(a, None, {"test_ctx": 1})
assert call_a_count == 1
print(result)
43 changes: 43 additions & 0 deletions tests/test_run.py
@@ -2,6 +2,7 @@
import pytest
from drive_flow import default_drive, EventInput
from drive_flow.types import ReturnBehavior
from drive_flow.dynamic import abort_this


class DeliberateExcepion(Exception):
@@ -143,6 +144,7 @@ async def c(event: EventInput, global_ctx):

result = await default_drive.invoke_event(a, None, {"test_ctx": 1})
print(result)
assert call_c_count == 2


@pytest.mark.asyncio
@@ -174,3 +176,44 @@ async def c(event: EventInput, global_ctx):

with pytest.raises(DeliberateExcepion):
await default_drive.invoke_event(a, None, {"test_ctx": 1})
assert call_a_count == 1


@pytest.mark.asyncio
async def test_duplicate_events_not_send():
call_a_count = 0

@default_drive.make_event
async def start(event: EventInput, global_ctx):
pass

@default_drive.listen_group([start])
async def a(event: EventInput, global_ctx):
nonlocal call_a_count
if call_a_count <= 1:
pass
elif call_a_count == 2:
return abort_this()
call_a_count += 1
return 1

a = default_drive.listen_group([a])(a) # self loop

@default_drive.listen_group([start])
async def b(event: EventInput, global_ctx):
return 2

call_c_count = 0

@default_drive.listen_group([a, b])
async def c(event: EventInput, global_ctx):
nonlocal call_c_count
assert call_c_count < 1, "c should only be called once"
call_c_count += 1
print("Call C")
return 3

r = await default_drive.invoke_event(start, None, {"test_ctx": 1})
assert call_a_count == 2
assert call_c_count == 1
print({default_drive.get_event_from_id(k).repr_name: v for k, v in r.items()})
