feat(model): Support model cache and first version of Agentic Workflow Expression Language(AWEL)

fangyinc committed Nov 15, 2023
1 parent 8eaf369 commit 6db8c49
Showing 43 changed files with 3,030 additions and 21 deletions.
57 changes: 57 additions & 0 deletions pilot/awel/__init__.py
@@ -0,0 +1,57 @@
"""Agentic Workflow Expression Language (AWEL)"""

from .dag.base import DAGContext, DAG

from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchOperator,
InputOperator,
)

from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
)

from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
SimpleStreamTaskOutput,
)
from .runner.local_runner import DefaultWorkflowRunner

__all__ = [
"initialize_awel",
"DAGContext",
"DAG",
"BaseOperator",
"JoinOperator",
"ReduceStreamOperator",
"MapOperator",
"BranchOperator",
"InputOperator",
"WorkflowRunner",
"TaskState",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
"SimpleTaskOutput",
"SimpleStreamTaskOutput",
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
]
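
A minimal usage sketch of this public API (not part of the commit). The operator constructor arguments below are hypothetical, since common_operator.py is not reproduced here; only the names exported above are taken from the diff.

from pilot.awel import DAG, InputOperator, MapOperator, SimpleInputSource

with DAG("awel_example") as dag:
    # Hypothetical argument names, for illustration only.
    read_task = InputOperator(input_source=SimpleInputSource([1, 2, 3]))
    double_task = MapOperator(map_function=lambda x: x * 2)
    # Wire the dependency with the >> operator provided by DependencyMixin.
    read_task >> double_task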
Empty file added pilot/awel/dag/__init__.py
Empty file.
252 changes: 252 additions & 0 deletions pilot/awel/dag/base.py
@@ -0,0 +1,252 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any
import uuid
import contextvars
import threading
import asyncio
from collections import deque

from ..resource.base import ResourceGroup
from ..task.base import TaskContext

DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]


def _is_async_context():
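    """Return True when called from inside a running asyncio task."""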
try:
loop = asyncio.get_running_loop()
return asyncio.current_task(loop=loop) is not None
except RuntimeError:
return False


class DependencyMixin(ABC):
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
"""

@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
"""

def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
Example:
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes

def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
Example:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes

def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
self.__lshift__(nodes)
return self

def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
self.__rshift__(nodes)
return self


class DAGVar:
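    """Tracks the DAG currently being built: a thread-local stack for sync code and a ContextVar stack for async code."""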
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())

@classmethod
def enter_dag(cls, dag) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
stack.append(dag)
cls._async_local.set(stack)
else:
if not hasattr(cls._thread_local, "current_dag_stack"):
cls._thread_local.current_dag_stack = deque()
cls._thread_local.current_dag_stack.append(dag)

@classmethod
def exit_dag(cls) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
if stack:
stack.pop()
cls._async_local.set(stack)
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
cls._thread_local.current_dag_stack.pop()

@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
return stack[-1] if stack else None
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
return cls._thread_local.current_dag_stack[-1]
return None


class DAGNode(DependencyMixin, ABC):
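    """A single node in a DAG; tracks its upstream and downstream nodes and the DAG it belongs to."""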
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""

    def __init__(self, dag: Optional["DAG"] = None, node_id: Optional[str] = None) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id

@property
def node_id(self) -> str:
return self._node_id

def set_node_id(self, node_id: str) -> None:
self._node_id = node_id

    @property
    def dag(self) -> Optional["DAG"]:
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes)
        return self

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes, is_upstream=False)
        return self

@property
def upstream(self) -> List["DAGNode"]:
return self._upstream

@property
def downstream(self) -> List["DAGNode"]:
return self._downstream

def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
if self.dag:
dags.add(self.dag)
if not dags:
raise ValueError("set dependency to current node must in a DAG context")
if len(dags) != 1:
raise ValueError(
"set dependency to current node just support in one DAG context"
)
dag = dags.pop()
self._dag = dag

dag._append_node(self)
        for node in nodes:
            if is_upstream and node not in self.upstream:
                node._dag = dag
                dag._append_node(node)
                self._upstream.append(node)
                node._downstream.append(self)
            elif not is_upstream and node not in self._downstream:
                node._dag = dag
                dag._append_node(node)
                self._downstream.append(node)
                node._upstream.append(self)


class DAGContext:
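    """Per-run context for a DAG: holds the current task context and a simple key/value share store."""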
def __init__(self) -> None:
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}

@property
def current_task_context(self) -> TaskContext:
return self._curr_task_ctx

def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx

async def get_share_data(self, key: str) -> Any:
return self._share_data.get(key)

async def save_to_share_data(self, key: str, data: Any) -> None:
self._share_data[key] = data


class DAG:
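    """A directed acyclic graph of DAGNodes; used as a context manager so nodes created inside the block attach to it via DAGVar."""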
    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        self._dag_id = dag_id
        self._resource_group = resource_group
        self.node_map: Dict[str, DAGNode] = {}

def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node

def _new_node_id(self) -> str:
return str(uuid.uuid4())

def __enter__(self):
DAGVar.enter_dag(self)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
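
A small sketch (not part of the diff) of how the pieces above fit together: DAG is a context manager that registers itself through DAGVar, and DAGNode instances created inside the with-block attach to that DAG and are wired with >> / <<. DAGNode is instantiated directly here purely for illustration.

from pilot.awel.dag.base import DAG, DAGNode

with DAG("demo_dag") as dag:
    a, b, c = DAGNode(), DAGNode(), DAGNode()
    # a -> b -> c; equivalent to a.set_downstream(b) followed by b.set_downstream(c)
    a >> b >> c
    assert b in a.downstream
    assert c in b.downstream
    # every node was appended to the shared DAG's node_map
    assert len(dag.node_map) == 3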
Empty file.
51 changes: 51 additions & 0 deletions pilot/awel/dag/tests/test_dag.py
@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio
from ..base import DAG, DAGVar


def test_dag_context_sync():
    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    with dag1:
        assert DAGVar.get_current_dag() == dag1
        with dag2:
            assert DAGVar.get_current_dag() == dag2
        assert DAGVar.get_current_dag() == dag1
    assert DAGVar.get_current_dag() is None


def test_dag_context_threading():
    def thread_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    thread1 = threading.Thread(target=thread_function, args=(dag1,))
    thread2 = threading.Thread(target=thread_function, args=(dag2,))

    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()

    assert DAGVar.get_current_dag() is None


@pytest.mark.asyncio
async def test_dag_context_async():
    async def async_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    await asyncio.gather(async_function(dag1), async_function(dag2))

    assert DAGVar.get_current_dag() is None
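
Note: the async test above relies on the pytest-asyncio plugin for the @pytest.mark.asyncio marker; with that plugin installed, the file runs under a plain pytest invocation.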
Empty file added pilot/awel/operator/__init__.py
Empty file.