Skip to content

Commit

Permalink
feat: local runtime linked to AIOS (agiresearch#202)
Browse files Browse the repository at this point in the history
* chore: uvicorn solitary dep

* feat: local runtime linked to AIOS
  • Loading branch information
BRama10 authored Aug 7, 2024
1 parent 2f9dba6 commit da04e43
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 10 deletions.
12 changes: 6 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def main():
max_gpu_memory = args.max_gpu_memory
eval_device = args.eval_device
max_new_tokens = args.max_new_tokens
# scheduler_log_mode = args.scheduler_log_mode
# agent_log_mode = args.agent_log_mode
# llm_kernel_log_mode = args.llm_kernel_log_mode
scheduler_log_mode = args.scheduler_log_mode
agent_log_mode = args.agent_log_mode
llm_kernel_log_mode = args.llm_kernel_log_mode
use_backend = args.use_backend
load_dotenv()

Expand All @@ -44,7 +44,7 @@ def main():
max_gpu_memory=max_gpu_memory,
eval_device=eval_device,
max_new_tokens=max_new_tokens,
log_mode='console',
log_mode=llm_kernel_log_mode,
use_backend=use_backend
)

Expand All @@ -54,12 +54,12 @@ def main():

startScheduler, stopScheduler = useFIFOScheduler(
llm=llm,
log_mode='console',
log_mode=scheduler_log_mode,
get_queue_message=None
)

submitAgent, awaitAgentExecution = useFactory(
log_mode='console',
log_mode=agent_log_mode,
max_workers=500
)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Pympler==1.0.1
click==8.1.7
ollama
fastapi
uvicorn
torch
transformers
accelerate
Expand Down
100 changes: 96 additions & 4 deletions runtime/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware

from aios.hooks.llm import useFIFOScheduler, useFactory, useKernel
from aios.hooks.types.llm import AgentSubmitDeclaration, LLMParams

from state import useGlobalState

app = FastAPI()

app.add_middleware(
Expand All @@ -11,9 +16,96 @@
allow_headers=["*"],
)

@app.post("/run_agent")
async def run_agent(*args, **kwargs):
    # TODO(review): unimplemented placeholder — accepts any payload and does
    # nothing; implement or remove before exposing this endpoint.
    pass

# Module-level state cells: each useGlobalState() call yields an independent
# (getter, setter, set_callback) triple backed by its own private value.
getLLMState, setLLMState, setLLMCallback = useGlobalState()
getScheduler, setScheduler, setSchedulerCallback = useGlobalState()
getFactory, setFactory, setFactoryCallback = useGlobalState()
isRunning, setIsRunning, setIsRunningCallback = useGlobalState()

# Scheduler starts out stopped.
# NOTE(review): nothing in this module ever sets this flag to True, so the
# `if not is_running` guard in /add_agent will call the scheduler's start()
# on every request — confirm whether a setIsRunning(True) is missing there.
setIsRunning(False)

# Initial default kernel so the server works before /set_kernel is called.
# NOTE(review): defaults (gpt-4o-mini, 256 new tokens, console logging) are
# hard-coded here rather than read from configuration.
setLLMState(
    useKernel(
        llm_name='gpt-4o-mini',
        max_gpu_memory=None,
        eval_device=None,
        max_new_tokens=256,
        log_mode='console',
        use_backend=None
    )
)

# FIFO scheduler bound to the initial kernel.
# NOTE(review): the scheduler captures the kernel returned by getLLMState()
# at import time; a later /set_kernel call replaces the state cell's value
# but not this already-constructed scheduler — verify that is intended.
startScheduler, stopScheduler = useFIFOScheduler(
    llm=getLLMState(),
    log_mode='console',
    get_queue_message=None
)

# Expose start/stop as a dict so handlers can inject it via Depends().
setScheduler({
    'start': startScheduler,
    'stop': stopScheduler
})

# Agent factory: 'submit' enqueues an agent run, 'execute' awaits results.
submitAgent, awaitAgentExecution = useFactory(
    log_mode='console',
    max_workers=500
)

setFactory({
    'submit': submitAgent,
    'execute': awaitAgentExecution
})

@app.post("/set_kernel")
async def set_kernel(req: LLMParams):
    """Replace the global LLM kernel with one built from the posted params.

    Fix: the previous code did ``useKernel(**req)``, but a pydantic request
    model is not a mapping, so ``**`` unpacking raises TypeError. Convert
    the model to a dict of its fields first (``.dict()`` on pydantic models;
    ``dict(req)`` as a fallback if LLMParams is already mapping-like).
    """
    params = req.dict() if hasattr(req, "dict") else dict(req)
    setLLMState(
        useKernel(**params)
    )

@app.post("/add_agent")
async def add_agent(
    req: AgentSubmitDeclaration,
    factory: dict = Depends(getFactory),
    is_running: bool = Depends(isRunning),
    scheduler: dict = Depends(getScheduler),
):
    """Submit an agent to the factory, starting the scheduler on first use.

    Returns ``{'success': True, 'agent': <name>}`` on submission, or
    ``{'success': False}`` if the factory rejects the request.
    """
    if not is_running:
        scheduler.get('start')()
        # Fix: record that the scheduler is now running; the original never
        # flipped this flag, so start() was re-invoked on every request.
        setIsRunning(True)

    try:
        submit_agent = factory.get('submit')
        # Fix: a pydantic model is not a mapping, so ``submit_agent(**req)``
        # raised TypeError (silently swallowed below). Unpack its fields.
        payload = req.dict() if hasattr(req, "dict") else dict(req)
        submit_agent(**payload)

        return {
            'success': True,
            'agent': req.agent_name
        }
    except Exception:
        # Best-effort endpoint: report failure rather than surfacing a 500.
        return {
            'success': False
        }

@app.get("/execute_agents")
async def execute_agents(
    factory: dict = Depends(getFactory),
    # is_running: bool = Depends(isRunning),
    scheduler: dict = Depends(getScheduler),
):
    """Await all submitted agents, stop the scheduler, and clear the flag.

    On any failure (including during stop/flag update) a plain
    ``{'success': False}`` is returned instead of a 500.
    """
    try:
        outcome = factory.get('execute')()
        scheduler.get('stop')()
        setIsRunning(False)
    except Exception:
        return {
            'success': False
        }

    return {
        'success': True,
        'response': outcome
    }


@app.get("/get_all_agents")
async def get_all_agents(*args, **kwargs):
Expand Down
33 changes: 33 additions & 0 deletions runtime/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Any, Callable

class GlobalState:
    """A single observable value.

    Reads go through the ``attribute`` property; every write stores the new
    value and then notifies the observer registered via ``set_callback``,
    if one is present.
    """

    def __init__(self):
        # Value starts unset; no observer registered yet.
        self._attribute = None
        self._callback = None

    @property
    def attribute(self):
        """Current stored value (None until first assignment)."""
        return self._attribute

    @attribute.setter
    def attribute(self, value):
        # Store first, then notify, so the observer sees consistent state.
        self._attribute = value
        observer = self._callback
        if observer:
            observer(value)

    def set_callback(self, callback_func):
        """Register a one-argument observer invoked after each assignment."""
        self._callback = callback_func

def useGlobalState():
    """Create an isolated state cell and return its accessor closures.

    Returns a ``(getter, setter, set_callback)`` triple. Each call builds a
    fresh ``GlobalState`` captured by the three closures, so separate cells
    never interfere with one another.
    """
    cell = GlobalState()

    def getGlobalState():
        # Read the cell's current value.
        return cell.attribute

    def setGlobalState(value: Any):
        # Write the cell; fires any registered callback.
        cell.attribute = value

    def setCallback(cb: Callable[[Any], Any]):
        # Register an observer for subsequent writes.
        cell.set_callback(cb)

    return getGlobalState, setGlobalState, setCallback

0 comments on commit da04e43

Please sign in to comment.