Skip to content

Commit

Permalink
feat(agents-api): Set/get steps based on workflow state
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <diwank.singh@gmail.com>
  • Loading branch information
creatorrr committed Sep 3, 2024
1 parent cb8ce3e commit 511842f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 27 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ async def run_task_execution_workflow(
task_queue=temporal_task_queue,
id=str(job_id),
run_timeout=timedelta(days=31),
# TODO: Should add search_attributes for queryability
)
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
simple_jinja_regex = re.compile(r"{{|{%.+}}|%}", re.DOTALL)


# FIXME: This does not work for some reason
# TODO: This does not work for some reason
def is_simple_jinja(template_string: str) -> bool:
return simple_jinja_regex.search(template_string) is None

Expand Down
73 changes: 48 additions & 25 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
# EvaluateStep # ✅
# | ToolCallStep # ❌
# | PromptStep # 🟡
# | GetStep #
# | SetStep #
# | GetStep #
# | SetStep #
# | LogStep # ✅
# | EmbedStep # ❌
# | SearchStep # ❌
Expand Down Expand Up @@ -142,13 +142,32 @@ async def transition(
raise ApplicationError(f"Error in transition: {e}") from e


async def continue_as_child(
execution_input: ExecutionInput,
start: TransitionTarget,
previous_inputs: list[Any],
user_state: dict[str, Any] = {},
) -> Any:
return await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[
execution_input,
start,
previous_inputs,
user_state,
],
# TODO: Should add search_attributes for queryability
)


@workflow.defn
class TaskExecutionWorkflow:
user_state: dict[str, Any] = {}

def __init__(self) -> None:
self.user_state = {}

# TODO: Add endpoints for getting and setting user state for an execution
@workflow.query
def get_user_state(self) -> dict[str, Any]:
return self.user_state
Expand All @@ -171,7 +190,11 @@ async def run(
execution_input: ExecutionInput,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: list[Any] = [],
user_state: dict[str, Any] = {},
) -> Any:
# Set the initial user state
self.user_state = user_state

workflow.logger.info(
f"TaskExecutionWorkflow for task {execution_input.task.id}"
f" [LOC {start.workflow}.{start.step}]"
Expand Down Expand Up @@ -297,9 +320,9 @@ async def run(
]

# Execute the chosen branch and come back here
result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=case_args,
result = await continue_as_child(
*case_args,
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -342,9 +365,9 @@ async def run(
]

# Execute the chosen branch and come back here
result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=if_else_args,
result = await continue_as_child(
*if_else_args,
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -376,9 +399,9 @@ async def run(
]

# Execute the chosen branch and come back here
result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=foreach_args,
result = await continue_as_child(
*foreach_args,
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -417,9 +440,9 @@ async def run(

# TODO: We should parallelize this
# Execute the chosen branch and come back here
output = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=map_reduce_args,
output = await continue_as_child(
*map_reduce_args,
user_state=self.user_state,
)

# Reduce the result with the initial value
Expand Down Expand Up @@ -483,9 +506,11 @@ async def run(
next=yield_next_target,
)

result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[execution_input, yield_next_target, [output]],
result = await continue_as_child(
execution_input=execution_input,
start=yield_next_target,
previous_inputs=[output],
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -555,7 +580,7 @@ async def run(
raise ApplicationError("Not implemented")

case _:
# FIXME: Add steps that are not yet supported
# TODO: Add steps that are not yet supported
workflow.logger.error(
f"Unhandled step type: {type(context.current_step).__name__}"
)
Expand Down Expand Up @@ -585,11 +610,9 @@ async def run(
)

# TODO: Should use a continue_as_new workflow if history grows too large
return await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[
execution_input,
final_state.next,
previous_inputs + [final_state.output],
],
return await continue_as_child(
execution_input=execution_input,
start=final_state.next,
previous_inputs=previous_inputs + [final_state.output],
user_state=self.user_state,
)
2 changes: 1 addition & 1 deletion agents-api/tests/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
content="test entry content",
)

# FIXME: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep
# TODO: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep
time.sleep(1)

create_entries(
Expand Down

0 comments on commit 511842f

Please sign in to comment.