Skip to content

Commit

Permalink
Improved Read Tool,PAUSED/TERMINATED working, Resources updating in DB
Browse files Browse the repository at this point in the history
  • Loading branch information
luciferlinx101 committed May 31, 2023
1 parent bf87229 commit 30a6a8a
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 70 deletions.
3 changes: 2 additions & 1 deletion config_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ SEARCH_ENGINE_ID: YOUR_SEARCH_ENIGNE_ID
# IF YOU DONT HAVE GOOGLE SERACH KEY, USE THIS
SERP_API_KEY: YOUR_SERP_API_KEY

#FILE DIRECTORIES
RESOURCES_OUTPUT_ROOT_DIR: workspace/output
RESOURCES_INPUT_ROOT_DIR: workspace/output
RESOURCES_INPUT_ROOT_DIR: workspace/input

#ENTER YOUR EMAIL CREDENTIALS TO ACCESS EMAIL TOOL
EMAIL_ADDRESS: YOUR_EMAIL_ADDRESS
Expand Down
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def load_module_from_file(file_path):
# Function to process the files and extract class information
def process_files(folder_path):
existing_tools = session.query(Tool).all()
print("Exisiting Tool")
# print("Exisiting Tool")
existing_tools = [Tool(id=None, name=tool.name, folder_name=tool.folder_name, class_name=tool.class_name) for tool
in existing_tools]
print(existing_tools)
# print(existing_tools)

new_tools = []
# Iterate over all subfolders
Expand All @@ -205,13 +205,13 @@ def process_files(folder_path):
# filtered_classes = [clazz for clazz in classes if
# clazz["class_name"].endswith("Tool") and clazz["class_name"] != "BaseTool"]
for clazz in classes:
print("Class : ", clazz)
# print("Class : ", clazz)
new_tool = Tool(class_name=clazz["class_name"], folder_name=folder_name, file_name=file_name,
name=clazz["class_attribute"])
new_tools.append(new_tool)

print(existing_tools)
print(new_tools)
# print(existing_tools)
# print(new_tools)

for tool in new_tools:
add_or_update_tool(session, tool_name=tool.name, file_name=tool.file_name, folder_name=tool.folder_name,
Expand Down
52 changes: 40 additions & 12 deletions superagi/agent/super_agi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@
from superagi.tools.base_tool import BaseTool
from superagi.types.common import BaseMessage, HumanMessage, AIMessage, SystemMessage
from superagi.vector_store.base import VectorStore
from superagi.models.agent import Agent
from superagi.models.resource import Resource
from superagi.config.config import get_config
import os

FINISH = "finish"
WRITE_FILE = "write_file"
FILE = "FILE"
S3 = "S3"
# print("\033[91m\033[1m"
# + "\nA bit about me...."
# + "\033[0m\033[0m")
Expand All @@ -38,18 +45,30 @@
session = Session()


def checkExecution(execution_id):
try:
execution = session.query(AgentExecution).filter_by(id=execution_id).first()
if execution and execution.status in ['PAUSED', 'COMPLETED']:
return False
else:
return True
except SQLAlchemyError as e:
print("Error occurred during execution status check:", e)
return False
finally:
session.close()
def make_written_file_resource(file_name: str,project_id:int):
path = get_config("RESOURCES_OUTPUT_ROOT_DIR")
storage_type = get_config("STORAGE_TYPE")
file_type = "application/txt"

root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')

if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
final_path = root_dir + file_name
else:
final_path = os.getcwd() + "/" + file_name
file_size = os.path.getsize(final_path)
resource = None
if storage_type == FILE:
# Save Resource to Database
resource = Resource(name=file_name, path=path + file_name, storage_type=storage_type, size=file_size,
type=file_type,
channel="OUTPUT",
project_id=project_id)
elif storage_type == S3:
pass
return resource


class SuperAgi:
Expand All @@ -60,6 +79,7 @@ def __init__(self,
memory: VectorStore,
tools: List[BaseTool],
agent_config: Any,
agent: Agent,
output_parser: BaseOutputParser = AgentOutputParser(),
):
self.ai_name = ai_name
Expand All @@ -70,6 +90,7 @@ def __init__(self,
self.output_parser = output_parser
self.tools = tools
self.agent_config = agent_config
self.agent = agent
# Init Log
# print("\033[92m\033[1m" + "\nWelcome to SuperAGI - The future of AGI" + "\033[0m\033[0m")

Expand Down Expand Up @@ -117,6 +138,8 @@ def execute(self, goals: List[str]):

superagi_prompt = AgentPromptBuilder.get_superagi_prompt(self.ai_name, self.ai_role, goals, self.tools,
self.agent_config)
# print("BASE PROMPT")
# print(superagi_prompt)
messages = [{"role": "system", "content": superagi_prompt},
{"role": "system", "content": f"The current time and date is {time.strftime('%c')}"}]

Expand Down Expand Up @@ -173,6 +196,11 @@ def execute(self, goals: List[str]):
tool = tools[action.name]
try:
observation = tool.execute(action.args)
if action.name == WRITE_FILE and observation is not None:
resource = make_written_file_resource(file_name=action.args.get('file_name'),
project_id=self.agent.project_id)
if resource is not None:
session.add(resource)
except ValidationError as e:
observation = (
f"Validation Error in args: {str(e)}, args: {action.args}"
Expand Down
2 changes: 1 addition & 1 deletion superagi/controllers/agent_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def update_agent_execution(agent_execution_id: int,
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")
db_agent_execution.agent_id = agent.id
if agent_execution.status != "CREATED" and agent_execution.status != "RUNNING" and agent_execution.status != "PAUSED" and agent_execution.status != "COMPLETED":
if agent_execution.status != "CREATED" and agent_execution.status != "RUNNING" and agent_execution.status != "PAUSED" and agent_execution.status != "COMPLETED" and agent_execution.status != "TERMINATED":
raise HTTPException(status_code=400, detail="Invalid Request")
db_agent_execution.status = agent_execution.status

Expand Down
32 changes: 10 additions & 22 deletions superagi/controllers/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from pathlib import Path
from fastapi.responses import StreamingResponse



router = APIRouter()


Expand Down Expand Up @@ -61,15 +59,7 @@
# return "Hello!"

@router.post("/add/{project_id}", status_code=201)
async def upload(project_id:int, file: UploadFile = File(...), name=Form(...),size=Form(...),type=Form(...)):
# try:
print("Project!")
print(project_id)
# print(storage_type)
print(size)
print(type)
print(name)
# save_directory = "./input"
async def upload(project_id: int, file: UploadFile = File(...), name=Form(...), size=Form(...), type=Form(...)):
project = db.session.query(Project).filter(Project.id == project_id).first()
if project is None:
raise HTTPException(status_code=400, detail="Project does not exists")
Expand All @@ -78,16 +68,13 @@ async def upload(project_id:int, file: UploadFile = File(...), name=Form(...),si
raise HTTPException(status_code=400, detail="File type not supported!")

storage_type = get_config("STORAGE_TYPE")
save_directory = get_config("RESOURCES_DIR_INPUT")
save_directory = get_config("RESOURCES_INPUT_ROOT_DIR")

print(storage_type)
print(save_directory)

path = ""
if storage_type == "FILE":
# Create the save directory if it doesn't exist
# intput_directory = save_directory
# input_directory = os.path.join(save_directory, '/input')
os.makedirs(save_directory, exist_ok=True)
# Create the file path
file_path = os.path.join(save_directory, file.filename)
Expand All @@ -98,14 +85,14 @@ async def upload(project_id:int, file: UploadFile = File(...), name=Form(...),si
f.write(contents)
file.file.close()
elif storage_type == "S3":
#Logic for uploading to S3
# Logic for uploading to S3
bucket_name = get_config("BUCKET_NAME")
s3_key = get_config("S3_KEY")
# path to be added
pass


resource = Resource(name=name, path=path, storage_type=storage_type,size=size,type=type,channel="INPUT",project_id=project.id)
resource = Resource(name=name, path=path, storage_type=storage_type, size=size, type=type, channel="INPUT",
project_id=project.id)
db.session.add(resource)
db.session.commit()
db.session.flush()
Expand All @@ -114,24 +101,25 @@ async def upload(project_id:int, file: UploadFile = File(...), name=Form(...),si


@router.get("/get/all/{project_id}", status_code=200)
def get_all_resources(project_id:int):
def get_all_resources(project_id: int):
resources = db.session.query(Resource).filter(Resource.project_id == project_id).all()
return resources

@router.get("/get/{resource_id}",status_code=200)

@router.get("/get/{resource_id}", status_code=200)
def download_file_by_id(resource_id: int):
resource = db.session.query(Resource).filter(Resource.id == resource_id).first()
download_file_path = resource.path
file_name = resource.name

if not resource:
raise HTTPException(status_code=400,detail="Resource Not found!")
raise HTTPException(status_code=400, detail="Resource Not found!")

abs_file_path = Path(download_file_path).resolve()
if not abs_file_path.is_file():
raise HTTPException(status_code=404, detail="File not found")

print("Resource: ",resource_id)
print("Resource: ", resource_id)
# return FileResponse(str(abs_file_path), media_type="application/txt", filename=file_name)
return StreamingResponse(
open(str(abs_file_path), "rb"),
Expand Down
26 changes: 17 additions & 9 deletions superagi/jobs/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,29 @@ def execute_next_action(self, agent_execution_id):
session = Session()
agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
agent = session.query(Agent).filter(Agent.id == agent_execution.agent_id).first()
if agent_execution.status == "PAUSED" or agent_execution.status == "TERMINATED" \
or agent_execution == "COMPLETED":
return

if not agent:
return "Agent Not found"
print("Agent Under Execution : ")
print(agent)
print("Agent Execution : ")
print(agent_execution)


tools = [
GoogleSearchTool(),
WriteFileTool(),
ReadFileTool(),
ReadEmailTool(),
SendEmailTool(),
SendEmailAttachmentTool(),
CreateIssueTool(),
SearchJiraTool(),
GetProjectsTool(),
EditIssueTool()
# ReadEmailTool(),
# SendEmailTool(),
# SendEmailAttachmentTool(),
# CreateIssueTool(),
# SearchJiraTool(),
# GetProjectsTool(),
# EditIssueTool()
]

parsed_config = self.fetch_agent_configuration(session, agent, agent_execution)
Expand All @@ -88,10 +97,9 @@ def execute_next_action(self, agent_execution_id):
for tool in user_tools:
tools.append(AgentExecutor.create_object(tool.class_name, tool.folder_name, tool.file_name))

# TODO: Generate tools array on fly
spawned_agent = SuperAgi(ai_name=parsed_config["name"], ai_role=parsed_config["description"],
llm=OpenAi(model=parsed_config["model"]), tools=tools, memory=memory,
agent_config=parsed_config)
agent_config=parsed_config, agent=agent)
response = spawned_agent.execute(parsed_config["goal"])

session.commit()
Expand Down
16 changes: 0 additions & 16 deletions superagi/tools/file/read_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,8 @@ class ReadFileTool(BaseTool):
description: str = "Reads the file content in a specified location"

def _execute(self, file_name: str):
# root_dir = get_config('RESOURCES_INPUT_ROOT_DIR')
# final_path = file_name
# if root_dir is not None:
# root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
# root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
# final_path = root_dir + file_name
# else:
# final_path = os.getcwd() + "/" + file_name
#
# directory = os.path.dirname(final_path)
# os.makedirs(directory, exist_ok=True)
#
# file = open(final_path, 'r')
# file_content = file.read()
# return file_content[:1500]
input_root_dir = get_config('RESOURCES_INPUT_ROOT_DIR')
output_root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')

final_path = None

if input_root_dir is not None:
Expand Down
4 changes: 1 addition & 3 deletions superagi/tools/file/write_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ class WriteFileTool(BaseTool):

def _execute(self, file_name: str, content: str):
final_path = file_name
# root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
root_dir = get_config('RESOURCES_INTPUT_ROOT_DIR')

root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
Expand Down
1 change: 0 additions & 1 deletion workspace/random_topics.txt

This file was deleted.

0 comments on commit 30a6a8a

Please sign in to comment.