Skip to content

Commit

Permalink
Code refatoring and agent level resource
Browse files Browse the repository at this point in the history
  • Loading branch information
luciferlinx101 committed Jun 2, 2023
1 parent 57d652a commit 7578f15
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 51 deletions.
30 changes: 30 additions & 0 deletions migrations/versions/2f97c068fab9_resource_modified.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Resource Modified
Revision ID: 2f97c068fab9
Revises: a91808a89623
Create Date: 2023-06-02 13:13:21.670935
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '2f97c068fab9'
down_revision = 'a91808a89623'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('resources', sa.Column('agent_id', sa.Integer(), nullable=True))
op.drop_column('resources', 'project_id')
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('resources', sa.Column('project_id', sa.INTEGER(), autoincrement=False, nullable=True))
op.drop_column('resources', 'agent_id')
# ### end Alembic commands ###
39 changes: 6 additions & 33 deletions superagi/agent/super_agi.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,6 @@
session = Session()


def make_written_file_resource(file_name: str,project_id:int):
path = get_config("RESOURCES_OUTPUT_ROOT_DIR")
storage_type = get_config("STORAGE_TYPE")
file_type = "application/txt"

root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')

if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
final_path = root_dir + file_name
else:
final_path = os.getcwd() + "/" + file_name
file_size = os.path.getsize(final_path)
resource = None
if storage_type == FILE:
# Save Resource to Database
resource = Resource(name=file_name, path=path + "/" + file_name, storage_type=storage_type, size=file_size,
type=file_type,
channel="OUTPUT",
project_id=project_id)
elif storage_type == S3:
pass
return resource


class SuperAgi:
def __init__(self,
ai_name: str,
Expand Down Expand Up @@ -192,15 +166,14 @@ def execute(self, goals: List[str]):
if action.name in tools:
tool = tools[action.name]
try:
observation = tool.execute(action.args)
if hasattr(tool, 'agent_id'):
observation = tool.execute(action.args, agent_id=self.agent.id)
else:
observation = tool.execute(action.args)
print("Tool Observation : ")
print(observation)
if action.name == WRITE_FILE:
resource = make_written_file_resource(file_name=action.args.get('file_name'),
project_id=self.agent.project_id)
if resource is not None:
session.add(resource)
session.commit()
# if action.name == WRITE_FILE:

except ValidationError as e:
observation = (
f"Validation Error in args: {str(e)}, args: {action.args}"
Expand Down
20 changes: 10 additions & 10 deletions superagi/controllers/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@
from typing import Annotated
from superagi.models.resource import Resource
from superagi.config.config import get_config
from superagi.models.project import Project
from superagi.models.agent import Agent
from starlette.responses import FileResponse
from pathlib import Path
from fastapi.responses import StreamingResponse

router = APIRouter()

@router.post("/add/{project_id}", status_code=201)
async def upload(project_id: int, file: UploadFile = File(...), name=Form(...), size=Form(...), type=Form(...)):
project = db.session.query(Project).filter(Project.id == project_id).first()
if project is None:
raise HTTPException(status_code=400, detail="Project does not exists")
@router.post("/add/{agent_id}", status_code=201)
async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), size=Form(...), type=Form(...)):
agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
if agent is None:
raise HTTPException(status_code=400, detail="Agent does not exists")

if not name.endswith(".txt") and not name.endswith(".pdf"):
raise HTTPException(status_code=400, detail="File type not supported!")
Expand All @@ -48,17 +48,17 @@ async def upload(project_id: int, file: UploadFile = File(...), name=Form(...),
pass

resource = Resource(name=name, path=path, storage_type=storage_type, size=size, type=type, channel="INPUT",
project_id=project.id)
agent_id=agent.id)
db.session.add(resource)
db.session.commit()
db.session.flush()
print(resource)
return resource


@router.get("/get/all/{project_id}", status_code=200)
def get_all_resources(project_id: int):
resources = db.session.query(Resource).filter(Resource.project_id == project_id).all()
@router.get("/get/all/{agent_id}", status_code=200)
def get_all_resources(agent_id: int):
resources = db.session.query(Resource).filter(Resource.agent_id == agent_id).all()
return resources


Expand Down
4 changes: 2 additions & 2 deletions superagi/models/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Resource(DBBaseModel):
size = Column(Integer)
type = Column(String) # application/pdf etc
channel = Column(String) #INPUT,OUTPUT
project_id = Column(Integer)
agent_id = Column(Integer)

def __repr__(self):
return f"Resource(id={self.id}, name='{self.name}', storage_type='{self.storage_type}', path='{self.path}, size='{self.size}', type='{self.type}')"
return f"Resource(id={self.id}, name='{self.name}', storage_type='{self.storage_type}', path='{self.path}, size='{self.size}', type='{self.type}', channel={self.channel}, agent_id={self.agent_id})"
2 changes: 1 addition & 1 deletion superagi/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ class User(DBBaseModel):


def __repr__(self):
return f"User(id={self.id}, name='{self.name}', email='{self.email}', password='{self.password}', organisation={self.organisation})"
return f"User(id={self.id}, name='{self.name}', email='{self.email}', password='{self.password}', organisation_id={self.organisation_id})"
45 changes: 40 additions & 5 deletions superagi/tools/file/write_file.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,55 @@
import os
from typing import Type

from pydantic import BaseModel, Field

from superagi.tools.base_tool import BaseTool

from superagi.config.config import get_config
from superagi.models.resource import Resource
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connectDB

engine = connectDB()
Session = sessionmaker(bind=engine)
session = Session()

def make_written_file_resource(file_name: str,agent_id:int):
path = get_config("RESOURCES_OUTPUT_ROOT_DIR")
storage_type = get_config("STORAGE_TYPE")
file_type = "application/txt"

root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')

if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
final_path = root_dir + file_name
else:
final_path = os.getcwd() + "/" + file_name
file_size = os.path.getsize(final_path)
resource = None
if storage_type == "FILE":
# Save Resource to Database
resource = Resource(name=file_name, path=path + "/" + file_name, storage_type=storage_type, size=file_size,
type=file_type,
channel="OUTPUT",
agent_id=agent_id)
elif storage_type == "S3":
pass
return resource

class WriteFileInput(BaseModel):
"""Input for CopyFileTool."""
file_name: str = Field(..., description="Name of the file to write")
content: str = Field(..., description="File content to write")
agent_id: int = Field(..., description="Agent ID associated with the File")


class WriteFileTool(BaseTool):
name: str = "Write File"
args_schema: Type[BaseModel] = WriteFileInput
description: str = "Writes text to a file"
agent_id: int = None

def _execute(self, file_name: str, content: str):
def _execute(self, file_name: str, content: str, agent_id: int):
final_path = file_name
root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
if root_dir is not None:
Expand All @@ -32,6 +62,11 @@ def _execute(self, file_name: str, content: str):
try:
with open(final_path, 'w', encoding="utf-8") as file:
file.write(content)
resource = make_written_file_resource(file_name=file_name,
agent_id=agent_id)
if resource is not None:
session.add(resource)
session.commit()
return f"File written to successfully - {file_name}"
except Exception as err:
return f"Error: {err}"
return f"Error: {err}"

0 comments on commit 7578f15

Please sign in to comment.