From 41a4b7ac61795c6470514a270c449e6328d9c049 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Tue, 20 Jun 2023 18:18:17 +0530 Subject: [PATCH 001/241] pydantic models --- superagi/controllers/agent.py | 5 +- superagi/types/db.py | 195 ++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 2 deletions(-) create mode 100644 superagi/types/db.py diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py index 158104ac5..5897ccaf6 100644 --- a/superagi/controllers/agent.py +++ b/superagi/controllers/agent.py @@ -20,13 +20,14 @@ import json from sqlalchemy import func from superagi.helper.auth import check_auth, get_user_organisation +from superagi.types import db router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(Agent), status_code=201) -def create_agent(agent: sqlalchemy_to_pydantic(Agent, exclude=["id"]), +@router.post("/add", response_model=db.Agent, status_code=201) +def create_agent(agent: db.Agent, Authorize: AuthJWT = Depends(check_auth)): """ Creates a new Agent diff --git a/superagi/types/db.py b/superagi/types/db.py new file mode 100644 index 000000000..7335dfd3d --- /dev/null +++ b/superagi/types/db.py @@ -0,0 +1,195 @@ +from datetime import datetime + +from pydantic.main import BaseModel + + +class Agent(BaseModel): + name: str + project_id: int + description: str + + class Config: + orm_mode = True + + +class AgentConfiguration(BaseModel): + id: int + agent_id: int + key: str + value: str + + class Config: + orm_mode = True + + +class AgentExecution(BaseModel): + id: int + status: str + name: str + agent_id: int + last_execution_time: datetime + num_of_calls: int + num_of_tokens: int + current_step_id: int + permission_id: int + + class Config: + orm_mode = True + + +class AgentExecutionFeed(BaseModel): + id: int + agent_execution_id: int + agent_id: int + feed: str + role: str + extra_info: str + + class Config: + orm_mode = True + + +class AgentExecutionPermission(BaseModel): + id: int + agent_execution_id: int + agent_id: int + status: str + tool_name: str + user_feedback: str + assistant_reply: str + + class Config: + orm_mode = True + + +class AgentTemplate(BaseModel): + id: int + organisation_id: int + agent_workflow_id: int + name: str + description: str + marketplace_template_id: int + + class Config: + orm_mode = True + + +class AgentTemplateConfig(BaseModel): + id: int + agent_template_id: int + key: str + value: str + + class Config: + orm_mode = True + + +class AgentWorkflow(BaseModel): + id: int + name: str + description: str + + class Config: + orm_mode = True + + +class AgentWorkflowStep(BaseModel): + id: int + agent_workflow_id: int + unique_id: str + prompt: str + variables: str + output_type: str + step_type: str + next_step_id: int + history_enabled: bool + completion_prompt: str + + class Config: + orm_mode = True + + +class Budget(BaseModel): + id: int + budget: float + cycle: str + + class Config: + orm_mode = True + + +class Configuration(BaseModel): + id: int + organisation_id: int + key: str + value: str + + class Config: + orm_mode = True + + +class Organisation(BaseModel): + id: int + name: str + description: str + + class Config: + orm_mode = True + + +class Project(BaseModel): + id: int + name: str + organisation_id: int + description: str + + class Config: + orm_mode = True + + +class Resource(BaseModel): + id: int + name: str + storage_type: str + path: str + size: int + type: str + channel: str + agent_id: int + + class Config: + orm_mode = True + + +class Tool(BaseModel): + id: int + name: str + folder_name: str + class_name: str + file_name: str + + class Config: + orm_mode = True + + +class ToolConfig(BaseModel): + id: int + name: str + key: str + value: str + agent_id: int + + class Config: + orm_mode = True + + +class User(BaseModel): + id: int + name: str + email: str + password: str + organisation_id: int + + class Config: + orm_mode = True + From 1d67676ef3ddf3653e75f83932407a99b75727a9 Mon Sep 17 00:00:00 2001 From: Leon Date: Tue, 20 Jun 2023 15:45:18 -0700 Subject: [PATCH 002/241] untested lance implrementation --- gui/pages/Content/Agents/AgentCreate.js | 8 +-- superagi/jobs/agent_executor.py | 8 +-- superagi/vector_store/lancedb.py | 92 +++++++++++++++++++++++++ superagi/vector_store/vector_factory.py | 22 +++++- 4 files changed, 120 insertions(+), 10 deletions(-) create mode 100644 superagi/vector_store/lancedb.py diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 75952212a..a1deea743 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -57,8 +57,8 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen const rollingRef = useRef(null); const [rollingDropdown, setRollingDropdown] = useState(false); - const databases = ["Pinecone"] - const [database, setDatabase] = useState(databases[0]); + const databases = ["Pinecone", "LanceDB"] + const [database, setDatabase] = useState(databases[1]); const databaseRef = useRef(null); const [databaseDropdown, setDatabaseDropdown] = useState(false); @@ -172,7 +172,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen setToolNames((prevArray) => [...prevArray, tool.name]); } }; - + const removeTool = (indexToDelete) => { setMyTools((prevArray) => { const newArray = [...prevArray]; @@ -339,7 +339,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen setCreateClickable(false); - // if permission has word restricted change the permission to + // if permission has word restricted change the permission to let permission_type = permission; if (permission.includes("RESTRICTED")) { permission_type = "RESTRICTED"; diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 636be0aea..198d3e35a 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -149,11 +149,11 @@ def execute_next_action(self, agent_execution_id): if parsed_config["LTM_DB"] == "Pinecone": memory = VectorFactory.get_vector_storage("PineCone", "super-agent-index1", OpenAiEmbedding(model_api_key)) - else: - memory = VectorFactory.get_vector_storage("PineCone", "super-agent-index1", + elif parsed_config["LTM_DB"] == "LanceDB": + memory = VectorFactory.get_vector_storage("LanceDB", "super-agent-index1", OpenAiEmbedding(model_api_key)) except: - logger.info("Unable to setup the pinecone connection...") + logger.info("Unable to setup the connection...") memory = None user_tools = session.query(Tool).filter(Tool.id.in_(parsed_config["tools"])).all() @@ -165,7 +165,7 @@ def execute_next_action(self, agent_execution_id): tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id, model_api_key=model_api_key) - + spawned_agent = SuperAgi(ai_name=parsed_config["name"], ai_role=parsed_config["description"], diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py new file mode 100644 index 000000000..784cec679 --- /dev/null +++ b/superagi/vector_store/lancedb.py @@ -0,0 +1,92 @@ +import uuid + +from superagi.vector_store.document import Document +from superagi.vector_store.base import VectorStore +from typing import Any, Callable, Optional, Iterable, List + +from superagi.vector_store.embedding.openai import BaseEmbedding + + +class LanceDB(VectorStore): + """ + LanceDB vector store. + + Attributes: + tbl : The LanceDB table. + embedding_model : The embedding model. + text_field : The text field is the name of the field where the corresponding text for an embedding is stored. + table_name : Name for the table in the vector database + """ + def __init__( + self, + tbl: Any, + embedding_model: BaseEmbedding, + text_field: str, + table_name : str, + ): + try: + import lancedb + except ImportError: + raise ValueError("Please install LanceDB to use this vector store.") + + self.tbl = tbl + self.embedding_model = embedding_model + self.text_field = text_field + self.table_name = table_name + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional['list[dict]'] = None, + ) -> 'list[str]': + """ + Add texts to the vector store. + + Args: + texts : The texts to add. + fields: Additional fields to add. + Returns: + The list of ids vectors stored in LanceDB. + """ + vectors = [] + ids = ids or [str(uuid.uuid4()) for _ in texts] + if len(ids) < len(texts): + raise ValueError("Number of ids must match number of texts.") + + for text, id in zip(texts, ids): + metadata = metadatas.pop(0) if metadatas else {} + metadata[self.text_field] = text + vectors.append((id, self.embedding_model.get_embedding(text), metadata)) + + self.tbl.add(vectors) + return ids + + def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[Document]: + """ + Return docs most similar to query using specified search type. + + Args: + query : The query to search. + top_k : The top k to search. + **kwargs : The keyword arguments to search. + + Returns: + The list of documents most similar to the query + """ + namespace = kwargs.get("namespace", self.namespace) + + embed_text = self.embedding_model.get_embedding(query) + res = self.tbl.search(embed_text).limit(top_k).to_df() + + documents = [] + + for doc in res['matches']: + documents.append( + Document( + text_content=doc.metadata[self.text_field], + metadata=doc.metadata, + ) + ) + + return documents + diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index 528885694..b16b8f487 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -1,9 +1,11 @@ import os import pinecone +import lancedb from pinecone import UnauthorizedException from superagi.vector_store.pinecone import Pinecone +from superagi.vector_store.lancedb import LanceDB from superagi.vector_store import weaviate from superagi.config.config import get_config @@ -44,9 +46,25 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model): return Pinecone(index, embedding_model, 'text') except UnauthorizedException: raise ValueError("PineCone API key not found") - + + if vector_store == "LanceDB": + try: + # connect lancedb to local directory /lancedb/ + uri = "/lancedb" + db = lancedb.connect(uri) + + # create table if does not exist + try: + tbl = db.createTable(index_name) + except: + tbl = db.openTable(index_name) + + return LanceDB(tbl, embedding_model, 'text', index_name) + except: + raise ValueError("VectorStore setup for LanceDB failed") + if vector_store == "Weaviate": - + use_embedded = get_config("WEAVIATE_USE_EMBEDDED") url = get_config("WEAVIATE_URL") api_key = get_config("WEAVIATE_API_KEY") From 109d3a89d02531fe17cf6c94515b60a39e6d91df Mon Sep 17 00:00:00 2001 From: Leon Date: Tue, 20 Jun 2023 15:50:37 -0700 Subject: [PATCH 003/241] Update lancedb.py --- superagi/vector_store/lancedb.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index 784cec679..6d8cc9570 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -80,7 +80,7 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D documents = [] - for doc in res['matches']: + for doc in res['vector']: documents.append( Document( text_content=doc.metadata[self.text_field], @@ -89,4 +89,3 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D ) return documents - From 6e6109700a2ebfa2a31a67b0f2d7c0c1d8a2c462 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 21 Jun 2023 11:56:15 +0530 Subject: [PATCH 004/241] temp change --- superagi/controllers/agent.py | 10 +++---- superagi/types/db.py | 49 +++++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py index 5897ccaf6..7ee13b017 100644 --- a/superagi/controllers/agent.py +++ b/superagi/controllers/agent.py @@ -20,14 +20,14 @@ import json from sqlalchemy import func from superagi.helper.auth import check_auth, get_user_organisation -from superagi.types import db +from superagi.types import db as db_types router = APIRouter() # CRUD Operations -@router.post("/add", response_model=db.Agent, status_code=201) -def create_agent(agent: db.Agent, +@router.post("/add", response_model=db_types.Agent, status_code=201) +def create_agent(agent: db_types.Agent, Authorize: AuthJWT = Depends(check_auth)): """ Creates a new Agent @@ -80,8 +80,8 @@ def get_agent(agent_id: int, return db_agent -@router.put("/update/{agent_id}", response_model=sqlalchemy_to_pydantic(Agent)) -def update_agent(agent_id: int, agent: sqlalchemy_to_pydantic(Agent, exclude=["id"]), +@router.put("/update/{agent_id}", response_model=db_types.Agent) +def update_agent(agent_id: int, agent: db_types.AgentWithoutID, Authorize: AuthJWT = Depends(check_auth)): """ Update an existing Agent diff --git a/superagi/types/db.py b/superagi/types/db.py index 7335dfd3d..3a810084e 100644 --- a/superagi/types/db.py +++ b/superagi/types/db.py @@ -3,7 +3,16 @@ from pydantic.main import BaseModel -class Agent(BaseModel): +class DBModel(BaseModel): + created_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + + +class Agent(DBModel): + id: int name: str project_id: int description: str @@ -12,7 +21,13 @@ class Config: orm_mode = True -class AgentConfiguration(BaseModel): +class AgentWithoutID(Agent): + class Config: + orm_mode = True + exclude = ['id', 'created_at', 'updated_at'] + + +class AgentConfiguration(DBModel): id: int agent_id: int key: str @@ -22,7 +37,7 @@ class Config: orm_mode = True -class AgentExecution(BaseModel): +class AgentExecution(DBModel): id: int status: str name: str @@ -37,7 +52,7 @@ class Config: orm_mode = True -class AgentExecutionFeed(BaseModel): +class AgentExecutionFeed(DBModel): id: int agent_execution_id: int agent_id: int @@ -49,7 +64,7 @@ class Config: orm_mode = True -class AgentExecutionPermission(BaseModel): +class AgentExecutionPermission(DBModel): id: int agent_execution_id: int agent_id: int @@ -62,7 +77,7 @@ class Config: orm_mode = True -class AgentTemplate(BaseModel): +class AgentTemplate(DBModel): id: int organisation_id: int agent_workflow_id: int @@ -74,7 +89,7 @@ class Config: orm_mode = True -class AgentTemplateConfig(BaseModel): +class AgentTemplateConfig(DBModel): id: int agent_template_id: int key: str @@ -84,7 +99,7 @@ class Config: orm_mode = True -class AgentWorkflow(BaseModel): +class AgentWorkflow(DBModel): id: int name: str description: str @@ -93,7 +108,7 @@ class Config: orm_mode = True -class AgentWorkflowStep(BaseModel): +class AgentWorkflowStep(DBModel): id: int agent_workflow_id: int unique_id: str @@ -109,7 +124,7 @@ class Config: orm_mode = True -class Budget(BaseModel): +class Budget(DBModel): id: int budget: float cycle: str @@ -118,7 +133,7 @@ class Config: orm_mode = True -class Configuration(BaseModel): +class Configuration(DBModel): id: int organisation_id: int key: str @@ -128,7 +143,7 @@ class Config: orm_mode = True -class Organisation(BaseModel): +class Organisation(DBModel): id: int name: str description: str @@ -137,7 +152,7 @@ class Config: orm_mode = True -class Project(BaseModel): +class Project(DBModel): id: int name: str organisation_id: int @@ -147,7 +162,7 @@ class Config: orm_mode = True -class Resource(BaseModel): +class Resource(DBModel): id: int name: str storage_type: str @@ -161,7 +176,7 @@ class Config: orm_mode = True -class Tool(BaseModel): +class Tool(DBModel): id: int name: str folder_name: str @@ -172,7 +187,7 @@ class Config: orm_mode = True -class ToolConfig(BaseModel): +class ToolConfig(DBModel): id: int name: str key: str @@ -183,7 +198,7 @@ class Config: orm_mode = True -class User(BaseModel): +class User(DBModel): id: int name: str email: str From ab068ea653363631dbc418f9792d932b7159f5f9 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 21 Jun 2023 12:48:30 +0530 Subject: [PATCH 005/241] agent with db types --- superagi/controllers/agent.py | 18 ++- superagi/types/db.py | 204 ++++++++++++++++++++++++++++++---- 2 files changed, 192 insertions(+), 30 deletions(-) diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py index 7ee13b017..1b411768e 100644 --- a/superagi/controllers/agent.py +++ b/superagi/controllers/agent.py @@ -1,33 +1,29 @@ from fastapi_sqlalchemy import db -from fastapi import HTTPException, Depends, Request +from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from superagi.models.agent import Agent from superagi.models.agent_template import AgentTemplate -from superagi.models.agent_template_config import AgentTemplateConfig from superagi.models.project import Project from fastapi import APIRouter -from pydantic_sqlalchemy import sqlalchemy_to_pydantic - from superagi.models.agent_workflow import AgentWorkflow from superagi.models.types.agent_with_config import AgentWithConfig from superagi.models.agent_config import AgentConfiguration from superagi.models.agent_execution import AgentExecution -from superagi.models.agent_execution_feed import AgentExecutionFeed from superagi.models.tool import Tool from jsonmerge import merge from superagi.worker import execute_agent from datetime import datetime import json from sqlalchemy import func -from superagi.helper.auth import check_auth, get_user_organisation +from superagi.helper.auth import check_auth from superagi.types import db as db_types router = APIRouter() # CRUD Operations -@router.post("/add", response_model=db_types.Agent, status_code=201) -def create_agent(agent: db_types.Agent, +@router.post("/add", response_model=db_types.AgentOut, status_code=201) +def create_agent(agent: db_types.AgentIn, Authorize: AuthJWT = Depends(check_auth)): """ Creates a new Agent @@ -58,7 +54,7 @@ def create_agent(agent: db_types.Agent, return db_agent -@router.get("/get/{agent_id}", response_model=sqlalchemy_to_pydantic(Agent)) +@router.get("/get/{agent_id}", response_model=db_types.AgentOut) def get_agent(agent_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -80,8 +76,8 @@ def get_agent(agent_id: int, return db_agent -@router.put("/update/{agent_id}", response_model=db_types.Agent) -def update_agent(agent_id: int, agent: db_types.AgentWithoutID, +@router.put("/update/{agent_id}", response_model=db_types.AgentOut) +def update_agent(agent_id: int, agent: db_types.AgentIn, Authorize: AuthJWT = Depends(check_auth)): """ Update an existing Agent diff --git a/superagi/types/db.py b/superagi/types/db.py index 3a810084e..89d4d757d 100644 --- a/superagi/types/db.py +++ b/superagi/types/db.py @@ -11,7 +11,7 @@ class Config: orm_mode = True -class Agent(DBModel): +class AgentOut(DBModel): id: int name: str project_id: int @@ -21,13 +21,16 @@ class Config: orm_mode = True -class AgentWithoutID(Agent): +class AgentIn(BaseModel): + name: str + project_id: int + description: str + class Config: orm_mode = True - exclude = ['id', 'created_at', 'updated_at'] -class AgentConfiguration(DBModel): +class AgentConfigurationOut(DBModel): id: int agent_id: int key: str @@ -37,7 +40,16 @@ class Config: orm_mode = True -class AgentExecution(DBModel): +class AgentConfigurationIn(BaseModel): + agent_id: int + key: str + value: str + + class Config: + orm_mode = True + + +class AgentExecutionOut(DBModel): id: int status: str name: str @@ -52,7 +64,32 @@ class Config: orm_mode = True -class AgentExecutionFeed(DBModel): +class AgentExecutionIn(BaseModel): + status: str + name: str + agent_id: int + last_execution_time: datetime + num_of_calls: int + num_of_tokens: int + current_step_id: int + permission_id: int + + class Config: + orm_mode = True + +class AgentExecutionFeedOut(DBModel): + id: int + agent_execution_id: int + agent_id: int + feed: str + role: str + extra_info: str + + class Config: + orm_mode = True + + +class AgentExecutionFeedIn(BaseModel): id: int agent_execution_id: int agent_id: int @@ -64,7 +101,7 @@ class Config: orm_mode = True -class AgentExecutionPermission(DBModel): +class AgentExecutionPermissionOut(DBModel): id: int agent_execution_id: int agent_id: int @@ -77,7 +114,19 @@ class Config: orm_mode = True -class AgentTemplate(DBModel): +class AgentExecutionPermissionIn(BaseModel): + agent_execution_id: int + agent_id: int + status: str + tool_name: str + user_feedback: str + assistant_reply: str + + class Config: + orm_mode = True + + +class AgentTemplateOut(DBModel): id: int organisation_id: int agent_workflow_id: int @@ -89,7 +138,18 @@ class Config: orm_mode = True -class AgentTemplateConfig(DBModel): +class AgentTemplateIn(BaseModel): + organisation_id: int + agent_workflow_id: int + name: str + description: str + marketplace_template_id: int + + class Config: + orm_mode = True + + +class AgentTemplateConfigOut(DBModel): id: int agent_template_id: int key: str @@ -99,7 +159,25 @@ class Config: orm_mode = True -class AgentWorkflow(DBModel): +class AgentTemplateConfigIn(BaseModel): + agent_template_id: int + key: str + value: str + + class Config: + orm_mode = True + + +class AgentWorkflowOut(DBModel): + id: int + name: str + description: str + + class Config: + orm_mode = True + + +class AgentWorkflowIn(BaseModel): id: int name: str description: str @@ -108,7 +186,7 @@ class Config: orm_mode = True -class AgentWorkflowStep(DBModel): +class AgentWorkflowStepOut(DBModel): id: int agent_workflow_id: int unique_id: str @@ -124,7 +202,22 @@ class Config: orm_mode = True -class Budget(DBModel): +class AgentWorkflowStepIn(BaseModel): + id: int + agent_workflow_id: int + unique_id: str + prompt: str + variables: str + output_type: str + step_type: str + next_step_id: int + history_enabled: bool + completion_prompt: str + + class Config: + orm_mode = True + +class BudgetOut(DBModel): id: int budget: float cycle: str @@ -132,8 +225,16 @@ class Budget(DBModel): class Config: orm_mode = True +class BudgetIn(BaseModel): + budget: float + cycle: str -class Configuration(DBModel): + class Config: + orm_mode = True + + + +class ConfigurationOut(DBModel): id: int organisation_id: int key: str @@ -143,7 +244,16 @@ class Config: orm_mode = True -class Organisation(DBModel): +class ConfigurationIn(BaseModel): + id: int + organisation_id: int + key: str + value: str + + class Config: + orm_mode = True + +class OrganisationOut(DBModel): id: int name: str description: str @@ -152,7 +262,14 @@ class Config: orm_mode = True -class Project(DBModel): +class OrganisationIn(BaseModel): + name: str + description: str + + class Config: + orm_mode = True + +class ProjectOut(DBModel): id: int name: str organisation_id: int @@ -162,7 +279,15 @@ class Config: orm_mode = True -class Resource(DBModel): +class ProjectIn(BaseModel): + name: str + organisation_id: int + description: str + + class Config: + orm_mode = True + +class ResourceOut(DBModel): id: int name: str storage_type: str @@ -175,8 +300,20 @@ class Resource(DBModel): class Config: orm_mode = True +class ResourceIn(BaseModel): + name: str + storage_type: str + path: str + size: int + type: str + channel: str + agent_id: int -class Tool(DBModel): + class Config: + orm_mode = True + + +class ToolOut(DBModel): id: int name: str folder_name: str @@ -187,7 +324,17 @@ class Config: orm_mode = True -class ToolConfig(DBModel): +class ToolIn(BaseModel): + name: str + folder_name: str + class_name: str + file_name: str + + class Config: + orm_mode = True + + +class ToolConfigOut(DBModel): id: int name: str key: str @@ -198,7 +345,17 @@ class Config: orm_mode = True -class User(DBModel): +class ToolConfigIn(BaseModel): + name: str + key: str + value: str + agent_id: int + + class Config: + orm_mode = True + + +class UserOut(DBModel): id: int name: str email: str @@ -208,3 +365,12 @@ class User(DBModel): class Config: orm_mode = True + +class UserIn(BaseModel): + name: str + email: str + password: str + organisation_id: int + + class Config: + orm_mode = True \ No newline at end of file From c4c599303fd4d6dcf4c31c55cb4b30163788398e Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 21 Jun 2023 14:18:50 +0530 Subject: [PATCH 006/241] replaced with pydantic models --- superagi/controllers/agent.py | 12 ++++++------ superagi/controllers/agent_config.py | 10 +++++----- superagi/controllers/agent_execution.py | 11 ++++++----- superagi/controllers/agent_execution_feed.py | 11 ++++++----- .../controllers/agent_execution_permission.py | 10 +++++----- superagi/controllers/agent_template.py | 9 ++++----- superagi/controllers/budget.py | 11 ++++++----- superagi/controllers/config.py | 5 +++-- superagi/controllers/organisation.py | 13 +++++++------ superagi/controllers/project.py | 16 ++++++++-------- superagi/controllers/tool.py | 16 ++++++++-------- superagi/controllers/user.py | 15 ++++++++------- 12 files changed, 72 insertions(+), 67 deletions(-) diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py index 1b411768e..987dbb961 100644 --- a/superagi/controllers/agent.py +++ b/superagi/controllers/agent.py @@ -16,14 +16,14 @@ import json from sqlalchemy import func from superagi.helper.auth import check_auth -from superagi.types import db as db_types +from superagi.types.db import AgentOut, AgentIn router = APIRouter() # CRUD Operations -@router.post("/add", response_model=db_types.AgentOut, status_code=201) -def create_agent(agent: db_types.AgentIn, +@router.post("/add", response_model=AgentOut, status_code=201) +def create_agent(agent: AgentIn, Authorize: AuthJWT = Depends(check_auth)): """ Creates a new Agent @@ -54,7 +54,7 @@ def create_agent(agent: db_types.AgentIn, return db_agent -@router.get("/get/{agent_id}", response_model=db_types.AgentOut) +@router.get("/get/{agent_id}", response_model=AgentOut) def get_agent(agent_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -76,8 +76,8 @@ def get_agent(agent_id: int, return db_agent -@router.put("/update/{agent_id}", response_model=db_types.AgentOut) -def update_agent(agent_id: int, agent: db_types.AgentIn, +@router.put("/update/{agent_id}", response_model=AgentOut) +def update_agent(agent_id: int, agent: AgentIn, Authorize: AuthJWT = Depends(check_auth)): """ Update an existing Agent diff --git a/superagi/controllers/agent_config.py b/superagi/controllers/agent_config.py index a923acea8..cf98c3e27 100644 --- a/superagi/controllers/agent_config.py +++ b/superagi/controllers/agent_config.py @@ -2,18 +2,18 @@ from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db -from pydantic_sqlalchemy import sqlalchemy_to_pydantic from superagi.helper.auth import check_auth from superagi.models.agent import Agent from superagi.models.agent_config import AgentConfiguration from superagi.models.types.agent_config import AgentConfig +from superagi.types.db import AgentConfigurationIn, AgentConfigurationOut router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentConfiguration), status_code=201) -def create_agent_config(agent_config: sqlalchemy_to_pydantic(AgentConfiguration, exclude=["id"], ), +@router.post("/add", response_model=AgentConfigurationOut, status_code=201) +def create_agent_config(agent_config: AgentConfigurationIn, Authorize: AuthJWT = Depends(check_auth)): """ Create a new agent configuration by setting a new key and value related to the agent. @@ -39,7 +39,7 @@ def create_agent_config(agent_config: sqlalchemy_to_pydantic(AgentConfiguration, return db_agent_config -@router.get("/get/{agent_config_id}", response_model=sqlalchemy_to_pydantic(AgentConfiguration)) +@router.get("/get/{agent_config_id}", response_model=AgentConfigurationOut) def get_agent(agent_config_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -61,7 +61,7 @@ def get_agent(agent_config_id: int, return db_agent_config -@router.put("/update", response_model=sqlalchemy_to_pydantic(AgentConfiguration)) +@router.put("/update", response_model=AgentConfigurationOut) def update_agent(agent_config: AgentConfig, Authorize: AuthJWT = Depends(check_auth)): """ diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index ea09a7abd..3da24e3a7 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -11,13 +11,14 @@ from pydantic_sqlalchemy import sqlalchemy_to_pydantic from sqlalchemy import desc from superagi.helper.auth import check_auth +from superagi.types.db import AgentExecutionOut, AgentExecutionIn router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecution), status_code=201) -def create_agent_execution(agent_execution: sqlalchemy_to_pydantic(AgentExecution, exclude=["id"]), +@router.post("/add", response_model=AgentExecutionOut, status_code=201) +def create_agent_execution(agent_execution: AgentExecutionIn, Authorize: AuthJWT = Depends(check_auth)): """ Create a new agent execution/run. @@ -49,7 +50,7 @@ def create_agent_execution(agent_execution: sqlalchemy_to_pydantic(AgentExecutio return db_agent_execution -@router.get("/get/{agent_execution_id}", response_model=sqlalchemy_to_pydantic(AgentExecution)) +@router.get("/get/{agent_execution_id}", response_model=AgentExecutionOut) def get_agent_execution(agent_execution_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -71,9 +72,9 @@ def get_agent_execution(agent_execution_id: int, return db_agent_execution -@router.put("/update/{agent_execution_id}", response_model=sqlalchemy_to_pydantic(AgentExecution)) +@router.put("/update/{agent_execution_id}", response_model=AgentExecutionOut) def update_agent_execution(agent_execution_id: int, - agent_execution: sqlalchemy_to_pydantic(AgentExecution, exclude=["id"]), + agent_execution: AgentExecutionIn, Authorize: AuthJWT = Depends(check_auth)): """Update details of particular agent_execution by agent_execution_id""" diff --git a/superagi/controllers/agent_execution_feed.py b/superagi/controllers/agent_execution_feed.py index 758d0352d..5bd248522 100644 --- a/superagi/controllers/agent_execution_feed.py +++ b/superagi/controllers/agent_execution_feed.py @@ -11,13 +11,14 @@ from superagi.helper.feed_parser import parse_feed from superagi.models.agent_execution import AgentExecution from superagi.models.agent_execution_feed import AgentExecutionFeed +from superagi.types.db import AgentExecutionFeedOut, AgentExecutionFeedIn router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecutionFeed), status_code=201) -def create_agent_execution_feed(agent_execution_feed: sqlalchemy_to_pydantic(AgentExecutionFeed, exclude=["id"]), +@router.post("/add", response_model=AgentExecutionFeedOut, status_code=201) +def create_agent_execution_feed(agent_execution_feed: AgentExecutionFeedIn, Authorize: AuthJWT = Depends(check_auth)): """ Add a new agent execution feed. @@ -45,7 +46,7 @@ def create_agent_execution_feed(agent_execution_feed: sqlalchemy_to_pydantic(Age return db_agent_execution_feed -@router.get("/get/{agent_execution_feed_id}", response_model=sqlalchemy_to_pydantic(AgentExecutionFeed)) +@router.get("/get/{agent_execution_feed_id}", response_model=AgentExecutionFeedOut) def get_agent_execution_feed(agent_execution_feed_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -68,9 +69,9 @@ def get_agent_execution_feed(agent_execution_feed_id: int, return db_agent_execution_feed -@router.put("/update/{agent_execution_feed_id}", response_model=sqlalchemy_to_pydantic(AgentExecutionFeed)) +@router.put("/update/{agent_execution_feed_id}", response_model=AgentExecutionFeedOut) def update_agent_execution_feed(agent_execution_feed_id: int, - agent_execution_feed: sqlalchemy_to_pydantic(AgentExecutionFeed, exclude=["id"]), + agent_execution_feed: AgentExecutionFeedIn, Authorize: AuthJWT = Depends(check_auth)): """ Update a particular agent execution feed. diff --git a/superagi/controllers/agent_execution_permission.py b/superagi/controllers/agent_execution_permission.py index e7c4f9e3b..bbc7eb1d7 100644 --- a/superagi/controllers/agent_execution_permission.py +++ b/superagi/controllers/agent_execution_permission.py @@ -10,6 +10,7 @@ from fastapi import APIRouter from pydantic_sqlalchemy import sqlalchemy_to_pydantic from superagi.helper.auth import check_auth +from superagi.types.db import AgentExecutionPermissionOut, AgentExecutionPermissionIn router = APIRouter() @@ -37,9 +38,9 @@ def get_agent_execution_permission(agent_execution_permission_id: int, return db_agent_execution_permission -@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecutionPermission)) +@router.post("/add", response_model=AgentExecutionPermissionOut) def create_agent_execution_permission( - agent_execution_permission: sqlalchemy_to_pydantic(AgentExecutionPermission, exclude=["id"]) + agent_execution_permission: AgentExecutionPermissionIn , Authorize: AuthJWT = Depends(check_auth)): """ Create a new agent execution permission. @@ -58,10 +59,9 @@ def create_agent_execution_permission( @router.patch("/update/{agent_execution_permission_id}", - response_model=sqlalchemy_to_pydantic(AgentExecutionPermission, exclude=["id"])) + response_model=AgentExecutionPermissionIn) def update_agent_execution_permission(agent_execution_permission_id: int, - agent_execution_permission: sqlalchemy_to_pydantic(AgentExecutionPermission, - exclude=["id"]), + agent_execution_permission: AgentExecutionPermissionIn, Authorize: AuthJWT = Depends(check_auth)): """ Update an AgentExecutionPermission in the database. diff --git a/superagi/controllers/agent_template.py b/superagi/controllers/agent_template.py index 8839cdb3f..e993fe9af 100644 --- a/superagi/controllers/agent_template.py +++ b/superagi/controllers/agent_template.py @@ -1,8 +1,6 @@ from fastapi import APIRouter from fastapi import HTTPException, Depends from fastapi_sqlalchemy import db -from pydantic_sqlalchemy import sqlalchemy_to_pydantic - from main import get_config from superagi.helper.auth import get_user_organisation from superagi.models.agent import Agent @@ -11,12 +9,13 @@ from superagi.models.agent_template_config import AgentTemplateConfig from superagi.models.agent_workflow import AgentWorkflow from superagi.models.tool import Tool +from superagi.types.db import AgentTemplateIn, AgentTemplateOut router = APIRouter() -@router.post("/create", status_code=201, response_model=sqlalchemy_to_pydantic(AgentTemplate)) -def create_agent_template(agent_template: sqlalchemy_to_pydantic(AgentTemplate, exclude=["id"]), +@router.post("/create", status_code=201, response_model=AgentTemplateOut) +def create_agent_template(agent_template: AgentTemplateIn, organisation=Depends(get_user_organisation)): """ Create an agent template. @@ -81,7 +80,7 @@ def get_agent_template(template_source, agent_template_id: int, organisation=Dep return template -@router.post("/update_details/{agent_template_id}", response_model=sqlalchemy_to_pydantic(AgentTemplate)) +@router.post("/update_details/{agent_template_id}", response_model=AgentTemplateOut) def update_agent_template(agent_template_id: int, agent_configs: dict, organisation=Depends(get_user_organisation)): diff --git a/superagi/controllers/budget.py b/superagi/controllers/budget.py index 41452e25a..efbaad1f9 100644 --- a/superagi/controllers/budget.py +++ b/superagi/controllers/budget.py @@ -6,12 +6,13 @@ from superagi.helper.auth import check_auth from superagi.models.budget import Budget +from superagi.types.db import BudgetIn, BudgetOut router = APIRouter() -@router.post("/add", response_model=sqlalchemy_to_pydantic(Budget), status_code=201) -def create_budget(budget: sqlalchemy_to_pydantic(Budget, exclude=["id"]), +@router.post("/add", response_model=BudgetOut, status_code=201) +def create_budget(budget: BudgetIn, Authorize: AuthJWT = Depends(check_auth)): """ Create a new budget. @@ -34,7 +35,7 @@ def create_budget(budget: sqlalchemy_to_pydantic(Budget, exclude=["id"]), return new_budget -@router.get("/get/{budget_id}", response_model=sqlalchemy_to_pydantic(Budget)) +@router.get("/get/{budget_id}", response_model=BudgetOut) def get_budget(budget_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -54,8 +55,8 @@ def get_budget(budget_id: int, return db_budget -@router.put("/update/{budget_id}", response_model=sqlalchemy_to_pydantic(Budget)) -def update_budget(budget_id: int, budget: sqlalchemy_to_pydantic(Budget, exclude=["id"]), +@router.put("/update/{budget_id}", response_model=BudgetOut) +def update_budget(budget_id: int, budget: BudgetIn, Authorize: AuthJWT = Depends(check_auth)): """ Update budget details by budget_id. diff --git a/superagi/controllers/config.py b/superagi/controllers/config.py index 0c16841f8..2130d569c 100644 --- a/superagi/controllers/config.py +++ b/superagi/controllers/config.py @@ -9,14 +9,15 @@ from fastapi_jwt_auth import AuthJWT from superagi.helper.encyption_helper import encrypt_data,decrypt_data from superagi.lib.logger import logger +from superagi.types.db import ConfigurationIn, ConfigurationOut router = APIRouter() # CRUD Operations @router.post("/add/organisation/{organisation_id}", status_code=201, - response_model=sqlalchemy_to_pydantic(Configuration)) -def create_config(config: sqlalchemy_to_pydantic(Configuration, exclude=["id"]), organisation_id: int, + response_model=ConfigurationOut) +def create_config(config: ConfigurationIn, organisation_id: int, Authorize: AuthJWT = Depends(check_auth)): """ Creates a new Organisation level config. diff --git a/superagi/controllers/organisation.py b/superagi/controllers/organisation.py index 17af35318..07013d086 100644 --- a/superagi/controllers/organisation.py +++ b/superagi/controllers/organisation.py @@ -9,13 +9,14 @@ from superagi.models.project import Project from superagi.models.user import User from superagi.lib.logger import logger +from superagi.types.db import OrganisationIn, OrganisationOut router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(Organisation), status_code=201) -def create_organisation(organisation: sqlalchemy_to_pydantic(Organisation, exclude=["id"]), +@router.post("/add", response_model=OrganisationOut, status_code=201) +def create_organisation(organisation: OrganisationIn, Authorize: AuthJWT = Depends(check_auth)): """ Create a new organisation. @@ -43,7 +44,7 @@ def create_organisation(organisation: sqlalchemy_to_pydantic(Organisation, exclu return new_organisation -@router.get("/get/{organisation_id}", response_model=sqlalchemy_to_pydantic(Organisation)) +@router.get("/get/{organisation_id}", response_model=OrganisationOut) def get_organisation(organisation_id: int, Authorize: AuthJWT = Depends(check_auth)): """ Get organisation details by organisation_id. @@ -65,8 +66,8 @@ def get_organisation(organisation_id: int, Authorize: AuthJWT = Depends(check_au return db_organisation -@router.put("/update/{organisation_id}", response_model=sqlalchemy_to_pydantic(Organisation)) -def update_organisation(organisation_id: int, organisation: sqlalchemy_to_pydantic(Organisation, exclude=["id"]), +@router.put("/update/{organisation_id}", response_model=OrganisationOut) +def update_organisation(organisation_id: int, organisation: OrganisationIn, Authorize: AuthJWT = Depends(check_auth)): """ Update organisation details by organisation_id. @@ -94,7 +95,7 @@ def update_organisation(organisation_id: int, organisation: sqlalchemy_to_pydant return db_organisation -@router.get("/get/user/{user_id}", response_model=sqlalchemy_to_pydantic(Organisation), status_code=201) +@router.get("/get/user/{user_id}", response_model=OrganisationOut, status_code=201) def get_organisations_by_user(user_id: int): """ Get organisations associated with a user. diff --git a/superagi/controllers/project.py b/superagi/controllers/project.py index 8894137cb..a84008173 100644 --- a/superagi/controllers/project.py +++ b/superagi/controllers/project.py @@ -1,19 +1,19 @@ -from fastapi_sqlalchemy import DBSessionMiddleware, db -from fastapi import HTTPException, Depends, Request +from fastapi_sqlalchemy import db +from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from superagi.models.project import Project from superagi.models.organisation import Organisation from fastapi import APIRouter -from pydantic_sqlalchemy import sqlalchemy_to_pydantic from superagi.helper.auth import check_auth from superagi.lib.logger import logger +from superagi.types.db import ProjectIn, ProjectOut router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(Project), status_code=201) -def create_project(project: sqlalchemy_to_pydantic(Project, exclude=["id"]), +@router.post("/add", response_model=ProjectOut, status_code=201) +def create_project(project: ProjectIn, Authorize: AuthJWT = Depends(check_auth)): """ Create a new project. @@ -47,7 +47,7 @@ def create_project(project: sqlalchemy_to_pydantic(Project, exclude=["id"]), return project -@router.get("/get/{project_id}", response_model=sqlalchemy_to_pydantic(Project)) +@router.get("/get/{project_id}", response_model=ProjectOut) def get_project(project_id: int, Authorize: AuthJWT = Depends(check_auth)): """ Get project details by project_id. @@ -69,8 +69,8 @@ def get_project(project_id: int, Authorize: AuthJWT = Depends(check_auth)): return db_project -@router.put("/update/{project_id}", response_model=sqlalchemy_to_pydantic(Project)) -def update_project(project_id: int, project: sqlalchemy_to_pydantic(Project, exclude=["id"]), +@router.put("/update/{project_id}", response_model=ProjectOut) +def update_project(project_id: int, project: ProjectIn, Authorize: AuthJWT = Depends(check_auth)): """ Update a project detail by project_id. diff --git a/superagi/controllers/tool.py b/superagi/controllers/tool.py index ed2106a9c..28ce8588f 100644 --- a/superagi/controllers/tool.py +++ b/superagi/controllers/tool.py @@ -2,25 +2,25 @@ from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db -from pydantic_sqlalchemy import sqlalchemy_to_pydantic from superagi.helper.auth import check_auth from superagi.models.tool import Tool +from superagi.types.db import ToolIn, ToolOut router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(Tool), status_code=201) +@router.post("/add", response_model=ToolOut, status_code=201) def create_tool( - tool: sqlalchemy_to_pydantic(Tool, exclude=["id"]), + tool: ToolIn, Authorize: AuthJWT = Depends(check_auth), ): """ Create a new tool. Args: - tool (sqlalchemy_to_pydantic(Tool, exclude=["id"])): Tool data. + tool (ToolIn): Tool data. Returns: Tool: The created tool. @@ -41,7 +41,7 @@ def create_tool( return db_tool -@router.get("/get/{tool_id}", response_model=sqlalchemy_to_pydantic(Tool)) +@router.get("/get/{tool_id}", response_model=ToolOut) def get_tool( tool_id: int, Authorize: AuthJWT = Depends(check_auth), @@ -66,10 +66,10 @@ def get_tool( return db_tool -@router.put("/update/{tool_id}", response_model=sqlalchemy_to_pydantic(Tool)) +@router.put("/update/{tool_id}", response_model=ToolOut) def update_tool( tool_id: int, - tool: sqlalchemy_to_pydantic(Tool, exclude=["id"]), + tool: ToolIn, Authorize: AuthJWT = Depends(check_auth), ): """ @@ -77,7 +77,7 @@ def update_tool( Args: tool_id (int): ID of the tool. - tool (sqlalchemy_to_pydantic(Tool, exclude=["id"])): Updated tool data. + tool (ToolIn): Updated tool data. Returns: Tool: The updated tool details. diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py index c1c278b96..297b4bfb0 100644 --- a/superagi/controllers/user.py +++ b/superagi/controllers/user.py @@ -9,19 +9,20 @@ from pydantic_sqlalchemy import sqlalchemy_to_pydantic from superagi.helper.auth import check_auth from superagi.lib.logger import logger +from superagi.types.db import UserIn, UserOut router = APIRouter() # CRUD Operations -@router.post("/add", response_model=sqlalchemy_to_pydantic(User), status_code=201) -def create_user(user: sqlalchemy_to_pydantic(User, exclude=["id"]), +@router.post("/add", response_model=UserOut, status_code=201) +def create_user(user: UserIn, Authorize: AuthJWT = Depends(check_auth)): """ Create a new user. Args: - user (sqlalchemy_to_pydantic(User, exclude=["id"])): User data. + user (UserIn): User data. Returns: User: The created user. @@ -44,7 +45,7 @@ def create_user(user: sqlalchemy_to_pydantic(User, exclude=["id"]), return db_user -@router.get("/get/{user_id}", response_model=sqlalchemy_to_pydantic(User)) +@router.get("/get/{user_id}", response_model=UserOut) def get_user(user_id: int, Authorize: AuthJWT = Depends(check_auth)): """ @@ -68,16 +69,16 @@ def get_user(user_id: int, return db_user -@router.put("/update/{user_id}", response_model=sqlalchemy_to_pydantic(User)) +@router.put("/update/{user_id}", response_model=UserOut) def update_user(user_id: int, - user: sqlalchemy_to_pydantic(User, exclude=["id"]), + user: UserIn, Authorize: AuthJWT = Depends(check_auth)): """ Update a particular user. Args: user_id (int): ID of the user. - user (sqlalchemy_to_pydantic(User, exclude=["id"])): Updated user data. + user (UserIn): Updated user data. Returns: User: The updated user details. From a26bd4faee31c9dd973d21b9f080c13be120142a Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 21 Jun 2023 16:44:41 +0530 Subject: [PATCH 007/241] pydantic models --- superagi/controllers/agent_config.py | 2 +- superagi/controllers/user.py | 4 ++-- superagi/types/db.py | 30 ++++++++++++++++++---------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/superagi/controllers/agent_config.py b/superagi/controllers/agent_config.py index cf98c3e27..931cc82f8 100644 --- a/superagi/controllers/agent_config.py +++ b/superagi/controllers/agent_config.py @@ -62,7 +62,7 @@ def get_agent(agent_config_id: int, @router.put("/update", response_model=AgentConfigurationOut) -def update_agent(agent_config: AgentConfig, +def update_agent(agent_config: AgentConfigurationIn, Authorize: AuthJWT = Depends(check_auth)): """ Update a particular agent configuration value for the given agent_id and agent_config key. diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py index 297b4bfb0..9562e2a2b 100644 --- a/superagi/controllers/user.py +++ b/superagi/controllers/user.py @@ -9,7 +9,7 @@ from pydantic_sqlalchemy import sqlalchemy_to_pydantic from superagi.helper.auth import check_auth from superagi.lib.logger import logger -from superagi.types.db import UserIn, UserOut +from superagi.types.db import UserBase, UserIn, UserOut router = APIRouter() @@ -71,7 +71,7 @@ def get_user(user_id: int, @router.put("/update/{user_id}", response_model=UserOut) def update_user(user_id: int, - user: UserIn, + user: UserBase, Authorize: AuthJWT = Depends(check_auth)): """ Update a particular user. diff --git a/superagi/types/db.py b/superagi/types/db.py index 89d4d757d..295f85a84 100644 --- a/superagi/types/db.py +++ b/superagi/types/db.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from pydantic.main import BaseModel @@ -58,7 +59,7 @@ class AgentExecutionOut(DBModel): num_of_calls: int num_of_tokens: int current_step_id: int - permission_id: int + permission_id: Optional[int] class Config: orm_mode = True @@ -77,13 +78,14 @@ class AgentExecutionIn(BaseModel): class Config: orm_mode = True + class AgentExecutionFeedOut(DBModel): id: int agent_execution_id: int agent_id: int feed: str role: str - extra_info: str + extra_info: Optional[str] class Config: orm_mode = True @@ -203,7 +205,6 @@ class Config: class AgentWorkflowStepIn(BaseModel): - id: int agent_workflow_id: int unique_id: str prompt: str @@ -217,6 +218,7 @@ class AgentWorkflowStepIn(BaseModel): class Config: orm_mode = True + class BudgetOut(DBModel): id: int budget: float @@ -225,6 +227,7 @@ class BudgetOut(DBModel): class Config: orm_mode = True + class BudgetIn(BaseModel): budget: float cycle: str @@ -233,7 +236,6 @@ class Config: orm_mode = True - class ConfigurationOut(DBModel): id: int organisation_id: int @@ -245,7 +247,6 @@ class Config: class ConfigurationIn(BaseModel): - id: int organisation_id: int key: str value: str @@ -253,6 +254,7 @@ class ConfigurationIn(BaseModel): class Config: orm_mode = True + class OrganisationOut(DBModel): id: int name: str @@ -269,6 +271,7 @@ class OrganisationIn(BaseModel): class Config: orm_mode = True + class ProjectOut(DBModel): id: int name: str @@ -287,6 +290,7 @@ class ProjectIn(BaseModel): class Config: orm_mode = True + class ResourceOut(DBModel): id: int name: str @@ -300,6 +304,7 @@ class ResourceOut(DBModel): class Config: orm_mode = True + class ResourceIn(BaseModel): name: str storage_type: str @@ -355,21 +360,24 @@ class Config: orm_mode = True -class UserOut(DBModel): - id: int +class UserBase(BaseModel): name: str email: str password: str + + class Config: + orm_mode = True + + +class UserOut(UserBase, DBModel): + id: int organisation_id: int class Config: orm_mode = True -class UserIn(BaseModel): - name: str - email: str - password: str +class UserIn(UserBase): organisation_id: int class Config: From c48bbeee3bfc4002d5276ce015a3527521d3093b Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 21 Jun 2023 18:36:57 +0530 Subject: [PATCH 008/241] remove pydantic_sqlalchemy import --- requirements.txt | 4 +--- superagi/controllers/agent_execution.py | 1 - superagi/controllers/agent_execution_feed.py | 2 +- superagi/controllers/agent_execution_permission.py | 2 +- superagi/controllers/budget.py | 2 +- superagi/controllers/config.py | 2 +- superagi/controllers/organisation.py | 2 +- superagi/controllers/resources.py | 2 +- superagi/controllers/user.py | 2 +- 9 files changed, 8 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index 71d7a6dd8..5c5625d40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -57,6 +57,7 @@ json5==0.9.14 jsonmerge==1.9.0 jsonschema==4.17.3 kombu==5.2.4 +llama-index==0.6.29 log-symbols==0.0.14 loguru==0.7.0 lxml==4.9.2 @@ -82,7 +83,6 @@ prompt-toolkit==3.0.38 psycopg2==2.9.6 pycparser==2.21 pydantic==1.10.8 -pydantic-sqlalchemy==0.0.9 PyJWT==1.7.1 PyPDF2==3.0.1 pyquery==2.0.0 @@ -106,7 +106,6 @@ six==1.16.0 sniffio==1.3.0 soupsieve==2.4.1 spinners==0.0.24 -SQLAlchemy==1.4.48 starlette==0.27.0 tenacity==8.2.2 termcolor==2.3.0 @@ -116,7 +115,6 @@ tldextract==3.4.4 tqdm==4.65.0 tweepy==4.14.0 typing-inspect==0.8.0 -typing_extensions==4.6.2 ujson==5.7.0 urllib3==1.26.16 uvicorn==0.22.0 diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index 3da24e3a7..9df960709 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -8,7 +8,6 @@ from superagi.models.agent_execution import AgentExecution from superagi.models.agent import Agent from fastapi import APIRouter -from pydantic_sqlalchemy import sqlalchemy_to_pydantic from sqlalchemy import desc from superagi.helper.auth import check_auth from superagi.types.db import AgentExecutionOut, AgentExecutionIn diff --git a/superagi/controllers/agent_execution_feed.py b/superagi/controllers/agent_execution_feed.py index 5bd248522..afa11ecfc 100644 --- a/superagi/controllers/agent_execution_feed.py +++ b/superagi/controllers/agent_execution_feed.py @@ -2,7 +2,7 @@ from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + from sqlalchemy.sql import asc from superagi.agent.task_queue import TaskQueue diff --git a/superagi/controllers/agent_execution_permission.py b/superagi/controllers/agent_execution_permission.py index bbc7eb1d7..dbf959425 100644 --- a/superagi/controllers/agent_execution_permission.py +++ b/superagi/controllers/agent_execution_permission.py @@ -8,7 +8,7 @@ from superagi.models.agent_execution_permission import AgentExecutionPermission from superagi.worker import execute_agent from fastapi import APIRouter -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + from superagi.helper.auth import check_auth from superagi.types.db import AgentExecutionPermissionOut, AgentExecutionPermissionIn diff --git a/superagi/controllers/budget.py b/superagi/controllers/budget.py index efbaad1f9..95beb68cb 100644 --- a/superagi/controllers/budget.py +++ b/superagi/controllers/budget.py @@ -2,7 +2,7 @@ from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + from superagi.helper.auth import check_auth from superagi.models.budget import Budget diff --git a/superagi/controllers/config.py b/superagi/controllers/config.py index 2130d569c..30fecb8d9 100644 --- a/superagi/controllers/config.py +++ b/superagi/controllers/config.py @@ -1,5 +1,5 @@ from fastapi import APIRouter -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + from superagi.models.configuration import Configuration from superagi.models.organisation import Organisation from fastapi_sqlalchemy import db diff --git a/superagi/controllers/organisation.py b/superagi/controllers/organisation.py index 07013d086..ebbe69c09 100644 --- a/superagi/controllers/organisation.py +++ b/superagi/controllers/organisation.py @@ -2,7 +2,7 @@ from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + from superagi.helper.auth import check_auth from superagi.models.organisation import Organisation diff --git a/superagi/controllers/resources.py b/superagi/controllers/resources.py index 56f65b316..b39b3202c 100644 --- a/superagi/controllers/resources.py +++ b/superagi/controllers/resources.py @@ -4,7 +4,7 @@ from fastapi_jwt_auth.exceptions import AuthJWTException from superagi.models.budget import Budget from fastapi import APIRouter, UploadFile -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + import os from fastapi import FastAPI, File, Form, UploadFile from typing import Annotated diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py index 9562e2a2b..71c2ba39e 100644 --- a/superagi/controllers/user.py +++ b/superagi/controllers/user.py @@ -6,7 +6,7 @@ from superagi.models.project import Project from superagi.models.user import User from fastapi import APIRouter -from pydantic_sqlalchemy import sqlalchemy_to_pydantic + from superagi.helper.auth import check_auth from superagi.lib.logger import logger from superagi.types.db import UserBase, UserIn, UserOut From 21a21ece8a46919ec9570fce10184519da17fc2d Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 21 Jun 2023 23:27:33 +0530 Subject: [PATCH 009/241] add optional --- entrypoint.sh | 2 +- superagi/types/db.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/entrypoint.sh b/entrypoint.sh index d3c8581f5..55b09c9cc 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -2,4 +2,4 @@ alembic upgrade head # Start the app -exec uvicorn main:app --host 0.0.0.0 --port 8001 --reload +exec uvicorn main:app --host 0.0.0.0 --port 8001 --reload --root-path /api diff --git a/superagi/types/db.py b/superagi/types/db.py index 295f85a84..5f7661358 100644 --- a/superagi/types/db.py +++ b/superagi/types/db.py @@ -378,7 +378,7 @@ class Config: class UserIn(UserBase): - organisation_id: int + organisation_id: Optional[int] class Config: orm_mode = True \ No newline at end of file From fa12dc1495aea2a176512bf0fbcbd798eb51dbe8 Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 21 Jun 2023 13:49:43 -0700 Subject: [PATCH 010/241] 2nd untested implementation --- superagi/vector_store/lancedb.py | 49 +++++++++++++++++++------ superagi/vector_store/vector_factory.py | 10 +---- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index 6d8cc9570..c2489868f 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -12,53 +12,68 @@ class LanceDB(VectorStore): LanceDB vector store. Attributes: - tbl : The LanceDB table. + db : The LanceDB connected database. embedding_model : The embedding model. text_field : The text field is the name of the field where the corresponding text for an embedding is stored. table_name : Name for the table in the vector database """ def __init__( self, - tbl: Any, + db: Any, embedding_model: BaseEmbedding, text_field: str, - table_name : str, ): try: import lancedb except ImportError: raise ValueError("Please install LanceDB to use this vector store.") - self.tbl = tbl + self.db = db self.embedding_model = embedding_model self.text_field = text_field - self.table_name = table_name def add_texts( self, texts: Iterable[str], metadatas: Optional['list[dict]'] = None, + ids: Optional['list[str]'] = None, + table_name: Optional[str] = None, ) -> 'list[str]': """ Add texts to the vector store. Args: texts : The texts to add. - fields: Additional fields to add. + metadatas: The metadatas to add. + ids : The ids to add. + table_name : The table to add. Returns: The list of ids vectors stored in LanceDB. """ + vectors = [] ids = ids or [str(uuid.uuid4()) for _ in texts] if len(ids) < len(texts): raise ValueError("Number of ids must match number of texts.") for text, id in zip(texts, ids): + vector = {} metadata = metadatas.pop(0) if metadatas else {} metadata[self.text_field] = text - vectors.append((id, self.embedding_model.get_embedding(text), metadata)) - self.tbl.add(vectors) + vector["id"] = id + vector["vector"] = self.embedding_model.get_embedding(text) + for key, value in metadata.items(): + vector[key] = value + + vectors.append(vector) + + try: + tbl = self.db.create_table(table_name, data=vectors) + except: + tbl = self.db.open_table(table_name) + tbl.add(vectors) + return ids def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[Document]: @@ -75,16 +90,26 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D """ namespace = kwargs.get("namespace", self.namespace) + try: + tbl = self.db.open_table(namespace) + except: + raise ValueError("Table name was not found in LanceDB") + embed_text = self.embedding_model.get_embedding(query) - res = self.tbl.search(embed_text).limit(top_k).to_df() + res = tbl.search(embed_text).limit(top_k).to_df() documents = [] - for doc in res['vector']: + for i in range(len(res)): + meta = {} + for col in res: + if col != 'vector' and col != 'id': + meta[col] = res[col][i] + documents.append( Document( - text_content=doc.metadata[self.text_field], - metadata=doc.metadata, + text_content=res[self.text_field][i], + metadata=meta, ) ) diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index b16b8f487..5ca105c25 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -50,16 +50,10 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model): if vector_store == "LanceDB": try: # connect lancedb to local directory /lancedb/ - uri = "/lancedb" + uri = "/lancedb/" + index_name db = lancedb.connect(uri) - # create table if does not exist - try: - tbl = db.createTable(index_name) - except: - tbl = db.openTable(index_name) - - return LanceDB(tbl, embedding_model, 'text', index_name) + return LanceDB(db, embedding_model, 'text') except: raise ValueError("VectorStore setup for LanceDB failed") From aa3916e6938f499311b68856483cc6b9b4ac5191 Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 21 Jun 2023 15:50:15 -0700 Subject: [PATCH 011/241] Update lancedb.py --- superagi/vector_store/lancedb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index c2489868f..14ac7193a 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -15,7 +15,6 @@ class LanceDB(VectorStore): db : The LanceDB connected database. embedding_model : The embedding model. text_field : The text field is the name of the field where the corresponding text for an embedding is stored. - table_name : Name for the table in the vector database """ def __init__( self, From d093f0e80d983a071e061db2424953c323c7e441 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Thu, 22 Jun 2023 10:48:35 +0530 Subject: [PATCH 012/241] add sqlalchemy requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 5c5625d40..6bc912736 100644 --- a/requirements.txt +++ b/requirements.txt @@ -107,6 +107,7 @@ sniffio==1.3.0 soupsieve==2.4.1 spinners==0.0.24 starlette==0.27.0 +SQLAlchemy==2.0.16 tenacity==8.2.2 termcolor==2.3.0 tiktoken==0.4.0 From d946a9cad1697edd4a03ec72be84843861cffeb9 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Thu, 22 Jun 2023 11:01:11 +0530 Subject: [PATCH 013/241] declarative_base import change --- superagi/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/models/base_model.py b/superagi/models/base_model.py index 8b872b11b..180110237 100644 --- a/superagi/models/base_model.py +++ b/superagi/models/base_model.py @@ -1,7 +1,7 @@ import json from sqlalchemy import Column, DateTime, INTEGER -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base from datetime import datetime Base = declarative_base() From 3ec6e5f37ffa8093a316c6689743bb1b33bda051 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Thu, 22 Jun 2023 13:43:55 +0530 Subject: [PATCH 014/241] temp save --- superagi/controllers/resources.py | 6 +++-- superagi/helper/file_to_index_parser.py | 33 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 superagi/helper/file_to_index_parser.py diff --git a/superagi/controllers/resources.py b/superagi/controllers/resources.py index b39b3202c..cf933797c 100644 --- a/superagi/controllers/resources.py +++ b/superagi/controllers/resources.py @@ -4,7 +4,7 @@ from fastapi_jwt_auth.exceptions import AuthJWTException from superagi.models.budget import Budget from fastapi import APIRouter, UploadFile - +from superagi.helper.file_to_index_parser import create_document_index import os from fastapi import FastAPI, File, Form, UploadFile from typing import Annotated @@ -90,6 +90,8 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si db.session.add(resource) db.session.commit() db.session.flush() + create_document_index(file_path) + logger.info(resource) return resource @@ -155,4 +157,4 @@ def download_file_by_id(resource_id: int, headers={ "Content-Disposition": f"attachment; filename={file_name}" } - ) + ) \ No newline at end of file diff --git a/superagi/helper/file_to_index_parser.py b/superagi/helper/file_to_index_parser.py new file mode 100644 index 000000000..a7ae1424f --- /dev/null +++ b/superagi/helper/file_to_index_parser.py @@ -0,0 +1,33 @@ +from llama_index import SimpleDirectoryReader +import os + +from superagi.jobs.agent_executor import AgentExecutor +from superagi.models.agent_execution import AgentExecution +from superagi.vector_store.embedding.openai import OpenAiEmbedding +from superagi.config.config import get_config + +def create_document_index(file_path: str, agent_id: int, session): + """ + Creates a document index from a given directory. + """ + documents = SimpleDirectoryReader(input_files=[file_path]).load_data() + + return documents + + +def llama_vector_store_factory(vector_store, index_name, embedding_model, session, agent_id): + """ + Creates a llama vector store. + """ + model_api_key = get_config("OPENAI_API_KEY") + # agent_execution = AgentExecution(agent_id=agent_id) + # agent_executor = AgentExecutor() + # model_api_key = agent_executor.get_model_api_key_from_execution(agent_execution, session) + from superagi.vector_store.vector_factory import VectorFactory + vector_store = VectorFactory.get_vector_storage("PineCone", "super-agent-index1", + OpenAiEmbedding(model_api_key)) + if vector_store is None: + raise ValueError("Vector store not found") + if vector_store == "PineCone": + from llama_index.vector_stores import PineconeVectorStore + return PineconeVectorStore(index_name, embedding_model, 'text') From f3601a808261c148b51f6fced5d7cd9b78d00e37 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 22 Jun 2023 11:06:21 -0700 Subject: [PATCH 015/241] added open_table workaround --- gui/pages/Content/Agents/AgentCreate.js | 2 +- superagi/vector_store/lancedb.py | 8 ++++++-- superagi/vector_store/vector_factory.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index a1deea743..31230ef65 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -58,7 +58,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen const [rollingDropdown, setRollingDropdown] = useState(false); const databases = ["Pinecone", "LanceDB"] - const [database, setDatabase] = useState(databases[1]); + const [database, setDatabase] = useState(databases[0]); const databaseRef = useRef(null); const [databaseDropdown, setDatabaseDropdown] = useState(false); diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index 14ac7193a..af2ce2f12 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -89,10 +89,14 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D """ namespace = kwargs.get("namespace", self.namespace) + for table in self.db.table_names(): + if table == namespace: + tbl = self.db.open_table(table) + try: - tbl = self.db.open_table(namespace) + tbl except: - raise ValueError("Table name was not found in LanceDB") + raise ValueError(namespace + " Table not found in LanceDB.") embed_text = self.embedding_model.get_embedding(query) res = tbl.search(embed_text).limit(top_k).to_df() diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index 5ca105c25..dbfdfe2e8 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -49,7 +49,7 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model): if vector_store == "LanceDB": try: - # connect lancedb to local directory /lancedb/ + # connect lancedb to local directory /lancedb/index_name uri = "/lancedb/" + index_name db = lancedb.connect(uri) From 536c51e30ae81a23983a66925cb5e0dd661e86be Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 22 Jun 2023 13:18:41 -0700 Subject: [PATCH 016/241] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 71d7a6dd8..592638549 100644 --- a/requirements.txt +++ b/requirements.txt @@ -57,6 +57,7 @@ json5==0.9.14 jsonmerge==1.9.0 jsonschema==4.17.3 kombu==5.2.4 +lancedb==0.1.8 log-symbols==0.0.14 loguru==0.7.0 lxml==4.9.2 From 48c13ebb74742b599b37561e0e0b842cc005a2e3 Mon Sep 17 00:00:00 2001 From: Leon Yee <43097991+unkn-wn@users.noreply.github.com> Date: Mon, 26 Jun 2023 12:30:28 -0700 Subject: [PATCH 017/241] from branch to personal main (#2) * untested lance implrementation * Update lancedb.py * 2nd untested implementation * Update lancedb.py * added open_table workaround * Update requirements.txt --- gui/pages/Content/Agents/AgentCreate.js | 4 +- requirements.txt | 1 + superagi/jobs/agent_executor.py | 7 +- superagi/vector_store/lancedb.py | 119 ++++++++++++++++++++++++ superagi/vector_store/vector_factory.py | 16 +++- 5 files changed, 139 insertions(+), 8 deletions(-) create mode 100644 superagi/vector_store/lancedb.py diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 6378e83a3..0e7fd63c3 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -59,7 +59,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen const rollingRef = useRef(null); const [rollingDropdown, setRollingDropdown] = useState(false); - const databases = ["Pinecone"] + const databases = ["Pinecone", "LanceDB"] const [database, setDatabase] = useState(databases[0]); const databaseRef = useRef(null); const [databaseDropdown, setDatabaseDropdown] = useState(false); @@ -364,7 +364,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen setCreateClickable(false); - // if permission has word restricted change the permission to + // if permission has word restricted change the permission to let permission_type = permission; if (permission.includes("RESTRICTED")) { permission_type = "RESTRICTED"; diff --git a/requirements.txt b/requirements.txt index 624a2d107..2ca15a9ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,6 +63,7 @@ json5==0.9.14 jsonmerge==1.9.0 jsonschema==4.17.3 kombu==5.2.4 +lancedb==0.1.8 log-symbols==0.0.14 loguru==0.7.0 lxml==4.9.2 diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 53ecdc33a..b5d8cf60b 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -172,11 +172,11 @@ def execute_next_action(self, agent_execution_id): if parsed_config["LTM_DB"] == "Pinecone": memory = VectorFactory.get_vector_storage("PineCone", "super-agent-index1", OpenAiEmbedding(model_api_key)) - else: - memory = VectorFactory.get_vector_storage("PineCone", "super-agent-index1", + elif parsed_config["LTM_DB"] == "LanceDB": + memory = VectorFactory.get_vector_storage("LanceDB", "super-agent-index1", OpenAiEmbedding(model_api_key)) except: - logger.info("Unable to setup the pinecone connection...") + logger.info("Unable to setup the connection...") memory = None user_tools = session.query(Tool).filter(Tool.id.in_(parsed_config["tools"])).all() @@ -187,7 +187,6 @@ def execute_next_action(self, agent_execution_id): tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id, model_api_key=model_api_key, session=session) - spawned_agent = SuperAgi(ai_name=parsed_config["name"], ai_role=parsed_config["description"], llm=OpenAi(model=parsed_config["model"], api_key=model_api_key), tools=tools, memory=memory, diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py new file mode 100644 index 000000000..af2ce2f12 --- /dev/null +++ b/superagi/vector_store/lancedb.py @@ -0,0 +1,119 @@ +import uuid + +from superagi.vector_store.document import Document +from superagi.vector_store.base import VectorStore +from typing import Any, Callable, Optional, Iterable, List + +from superagi.vector_store.embedding.openai import BaseEmbedding + + +class LanceDB(VectorStore): + """ + LanceDB vector store. + + Attributes: + db : The LanceDB connected database. + embedding_model : The embedding model. + text_field : The text field is the name of the field where the corresponding text for an embedding is stored. + """ + def __init__( + self, + db: Any, + embedding_model: BaseEmbedding, + text_field: str, + ): + try: + import lancedb + except ImportError: + raise ValueError("Please install LanceDB to use this vector store.") + + self.db = db + self.embedding_model = embedding_model + self.text_field = text_field + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional['list[dict]'] = None, + ids: Optional['list[str]'] = None, + table_name: Optional[str] = None, + ) -> 'list[str]': + """ + Add texts to the vector store. + + Args: + texts : The texts to add. + metadatas: The metadatas to add. + ids : The ids to add. + table_name : The table to add. + Returns: + The list of ids vectors stored in LanceDB. + """ + + vectors = [] + ids = ids or [str(uuid.uuid4()) for _ in texts] + if len(ids) < len(texts): + raise ValueError("Number of ids must match number of texts.") + + for text, id in zip(texts, ids): + vector = {} + metadata = metadatas.pop(0) if metadatas else {} + metadata[self.text_field] = text + + vector["id"] = id + vector["vector"] = self.embedding_model.get_embedding(text) + for key, value in metadata.items(): + vector[key] = value + + vectors.append(vector) + + try: + tbl = self.db.create_table(table_name, data=vectors) + except: + tbl = self.db.open_table(table_name) + tbl.add(vectors) + + return ids + + def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[Document]: + """ + Return docs most similar to query using specified search type. + + Args: + query : The query to search. + top_k : The top k to search. + **kwargs : The keyword arguments to search. + + Returns: + The list of documents most similar to the query + """ + namespace = kwargs.get("namespace", self.namespace) + + for table in self.db.table_names(): + if table == namespace: + tbl = self.db.open_table(table) + + try: + tbl + except: + raise ValueError(namespace + " Table not found in LanceDB.") + + embed_text = self.embedding_model.get_embedding(query) + res = tbl.search(embed_text).limit(top_k).to_df() + + documents = [] + + for i in range(len(res)): + meta = {} + for col in res: + if col != 'vector' and col != 'id': + meta[col] = res[col][i] + + documents.append( + Document( + text_content=res[self.text_field][i], + metadata=meta, + ) + ) + + return documents diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index 528885694..dbfdfe2e8 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -1,9 +1,11 @@ import os import pinecone +import lancedb from pinecone import UnauthorizedException from superagi.vector_store.pinecone import Pinecone +from superagi.vector_store.lancedb import LanceDB from superagi.vector_store import weaviate from superagi.config.config import get_config @@ -44,9 +46,19 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model): return Pinecone(index, embedding_model, 'text') except UnauthorizedException: raise ValueError("PineCone API key not found") - + + if vector_store == "LanceDB": + try: + # connect lancedb to local directory /lancedb/index_name + uri = "/lancedb/" + index_name + db = lancedb.connect(uri) + + return LanceDB(db, embedding_model, 'text') + except: + raise ValueError("VectorStore setup for LanceDB failed") + if vector_store == "Weaviate": - + use_embedded = get_config("WEAVIATE_USE_EMBEDDED") url = get_config("WEAVIATE_URL") api_key = get_config("WEAVIATE_API_KEY") From 79fbd6383b168cf152694e7f9aaff3d0ba8d9bc1 Mon Sep 17 00:00:00 2001 From: Leon Yee <43097991+unkn-wn@users.noreply.github.com> Date: Mon, 26 Jun 2023 14:22:14 -0700 Subject: [PATCH 018/241] unit tests for lance (#3) * untested lance implrementation * Update lancedb.py * 2nd untested implementation * Update lancedb.py * added open_table workaround * Update requirements.txt * unit tests for lance --- superagi/vector_store/lancedb.py | 7 +- .../vector_store/test_lancedb.py | 106 ++++++++++++++++++ 2 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 tests/integration_tests/vector_store/test_lancedb.py diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index af2ce2f12..7112d162c 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -87,7 +87,7 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D Returns: The list of documents most similar to the query """ - namespace = kwargs.get("namespace", self.namespace) + namespace = kwargs.get("namespace", "None") for table in self.db.table_names(): if table == namespace: @@ -96,17 +96,18 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D try: tbl except: - raise ValueError(namespace + " Table not found in LanceDB.") + raise ValueError(namespace + " Table not found in LanceDB. Please call this function with a valid table name.") embed_text = self.embedding_model.get_embedding(query) res = tbl.search(embed_text).limit(top_k).to_df() + print(res) documents = [] for i in range(len(res)): meta = {} for col in res: - if col != 'vector' and col != 'id': + if col != 'vector' and col != 'id' and col != 'score': meta[col] = res[col][i] documents.append( diff --git a/tests/integration_tests/vector_store/test_lancedb.py b/tests/integration_tests/vector_store/test_lancedb.py new file mode 100644 index 000000000..73719add5 --- /dev/null +++ b/tests/integration_tests/vector_store/test_lancedb.py @@ -0,0 +1,106 @@ +import numpy as np +import shutil +import pytest + +import lancedb +from superagi.vector_store.lancedb import LanceDB +from superagi.vector_store.document import Document +from superagi.vector_store.embedding.openai import OpenAiEmbedding + + +@pytest.fixture +def client(): + db = lancedb.connect(".test_lancedb") + yield db + shutil.rmtree(".test_lancedb") + + +@pytest.fixture +def mock_openai_embedding(monkeypatch): + monkeypatch.setattr( + OpenAiEmbedding, + "get_embedding", + lambda self, text: np.random.random(3).tolist(), + ) + + +@pytest.fixture +def store(client, mock_openai_embedding): + yield LanceDB(client, OpenAiEmbedding(api_key="test_api_key"), "text") + + +@pytest.fixture +def dataset(): + book_titles = [ + "The Great Gatsby", + "To Kill a Mockingbird", + "1984", + "Pride and Prejudice", + "The Catcher in the Rye", + ] + + documents = [] + for i, title in enumerate(book_titles): + author = f"Author {i}" + description = f"A summary of {title}" + text_content = f"This is the text for {title}" + metadata = {"author": author, "description": description} + document = Document(text_content=text_content, metadata=metadata) + + documents.append(document) + + return documents + + +@pytest.fixture +def dataset_no_metadata(): + book_titles = [ + "The Lord of the Rings", + "The Hobbit", + "The Chronicles of Narnia", + ] + + documents = [] + for title in book_titles: + text_content = f"This is the text for {title}" + document = Document(text_content=text_content) + documents.append(document) + + return documents + + +@pytest.mark.parametrize( + "data, results, table_name", + [ + ("dataset", (5, 2), "test_table"), + ("dataset_no_metadata", (3, 0), "test_table_no_metadata"), + ], +) +def test_add_texts(store, client, data, results, table_name, request): + dataset = request.getfixturevalue(data) + count, meta_count = results + ids = store.add_documents(dataset, table_name=table_name) + assert len(ids) == count + + tbl = client.open_table(table_name) + assert len(tbl.to_pandas().columns) - 3 == meta_count + # Subtracting 3 because of the id, vector, and text columns. The rest + # should be metadata columns. + + +@pytest.mark.parametrize( + "data, search_text, table_name, index", + [ + ("dataset", "The Great Gatsby", "test_table", 0), + ("dataset", "1984", "test_table2", 2), + ("dataset_no_metadata", "The Hobbit", "test_table_no_metadata", 1), + ], +) +def test_get_matching_text(store, data, search_text, table_name, index, request): + print("SEARCHING FOR " + search_text) + dataset = request.getfixturevalue(data) + store.add_documents(dataset, table_name=table_name) + results = store.get_matching_text(search_text, top_k=2, namespace=table_name) + print(results[0]) + assert len(results) == 2 + assert results[0] == dataset[index] From 538fbf0383c72d4b461266245629e129468c1e25 Mon Sep 17 00:00:00 2001 From: Leon Date: Mon, 26 Jun 2023 14:59:15 -0700 Subject: [PATCH 019/241] unit test fix --- superagi/vector_store/lancedb.py | 1 - .../vector_store/test_lancedb.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index 7112d162c..cebc2329e 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -100,7 +100,6 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D embed_text = self.embedding_model.get_embedding(query) res = tbl.search(embed_text).limit(top_k).to_df() - print(res) documents = [] diff --git a/tests/integration_tests/vector_store/test_lancedb.py b/tests/integration_tests/vector_store/test_lancedb.py index 73719add5..93f236b3a 100644 --- a/tests/integration_tests/vector_store/test_lancedb.py +++ b/tests/integration_tests/vector_store/test_lancedb.py @@ -89,18 +89,19 @@ def test_add_texts(store, client, data, results, table_name, request): @pytest.mark.parametrize( - "data, search_text, table_name, index", + "data, search_text, table_name, meta_num", [ - ("dataset", "The Great Gatsby", "test_table", 0), - ("dataset", "1984", "test_table2", 2), + ("dataset", "The Great Gatsby", "test_table", 3), + ("dataset", "1984", "test_table2", 3), ("dataset_no_metadata", "The Hobbit", "test_table_no_metadata", 1), ], ) -def test_get_matching_text(store, data, search_text, table_name, index, request): - print("SEARCHING FOR " + search_text) +def test_get_matching_text(store, data, search_text, table_name, meta_num, request): dataset = request.getfixturevalue(data) store.add_documents(dataset, table_name=table_name) results = store.get_matching_text(search_text, top_k=2, namespace=table_name) - print(results[0]) + assert len(results) == 2 - assert results[0] == dataset[index] + assert len(results[0].metadata) == meta_num + # Metadata for dataset with metadata should be 3 (author, desc, text_content) + # Metadata for dataset without metadata should be 1 (text_content) \ No newline at end of file From 34e6c9af4b69203e7884cb9b4b360f9432310dc5 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Tue, 27 Jun 2023 10:27:25 +0530 Subject: [PATCH 020/241] resource added to vector store --- docker-compose.yaml | 2 + requirements.txt | 1 + superagi/controllers/resources.py | 32 +++++++++- superagi/helper/file_to_index_parser.py | 73 ++++++++++++++++++++--- superagi/tools/resource/query_resource.py | 62 +++++++++++++++++++ superagi/vector_store/embedding/openai.py | 4 +- superagi/vector_store/vector_factory.py | 5 +- 7 files changed, 164 insertions(+), 15 deletions(-) create mode 100644 superagi/tools/resource/query_resource.py diff --git a/docker-compose.yaml b/docker-compose.yaml index 5eaf7c295..975e98fdf 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -44,6 +44,8 @@ services: - superagi_postgres_data:/var/lib/postgresql/data/ networks: - super_network + ports: + - "5432:5432" proxy: image: nginx:stable-alpine diff --git a/requirements.txt b/requirements.txt index 6bc912736..55279a848 100644 --- a/requirements.txt +++ b/requirements.txt @@ -131,3 +131,4 @@ tiktoken==0.4.0 psycopg2==2.9.6 slack-sdk==3.21.3 pytest==7.3.2 +transformers==4.30.2 \ No newline at end of file diff --git a/superagi/controllers/resources.py b/superagi/controllers/resources.py index cf933797c..922637816 100644 --- a/superagi/controllers/resources.py +++ b/superagi/controllers/resources.py @@ -1,10 +1,13 @@ +import openai from fastapi_sqlalchemy import DBSessionMiddleware, db from fastapi import HTTPException, Depends, Request from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException + from superagi.models.budget import Budget from fastapi import APIRouter, UploadFile -from superagi.helper.file_to_index_parser import create_document_index +from superagi.helper.file_to_index_parser import create_llama_document, llama_vector_store_factory, \ + save_file_to_vector_store import os from fastapi import FastAPI, File, Form, UploadFile from typing import Annotated @@ -21,6 +24,7 @@ import tempfile import requests from superagi.lib.logger import logger +from superagi.vector_store.vector_factory import VectorFactory router = APIRouter() @@ -87,11 +91,33 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si resource = Resource(name=name, path=path, storage_type=storage_type, size=size, type=type, channel="INPUT", agent_id=agent.id) + + # from llama_index import VectorStoreIndex + # from superagi.vector_store.embedding.openai import OpenAiEmbedding + # from llama_index import StorageContext + # from llama_index import SimpleDirectoryReader + # model_api_key = get_config("OPENAI_API_KEY") + # documents = SimpleDirectoryReader(input_files=[file_path]).load_data() + # for docs in documents: + # if docs.extra_info is None: + # docs.extra_info = {"agent_id": agent_id} + # else: + # docs.extra_info["agent_id"] = agent_id + # os.environ["OPENAI_API_KEY"] = get_config("OPENAI_API_KEY") + # vector_store = llama_vector_store_factory('PineCone', 'super-agent-index1', OpenAiEmbedding(model_api_key)) + # storage_context = StorageContext.from_defaults(vector_store=vector_store) + # if vector_store is None: + # storage_context = StorageContext.from_defaults(persist_dir="workspace/index") + # openai.api_key = get_config("OPENAI_API_KEY") + # index = VectorStoreIndex.from_documents(documents, storage_context=storage_context) + # index.set_index_id(f'Agent {agent_id}') + # if vector_store is None: + # index.storage_context.persist() + db.session.add(resource) db.session.commit() db.session.flush() - create_document_index(file_path) - + save_file_to_vector_store(file_path, agent_id, resource.id) logger.info(resource) return resource diff --git a/superagi/helper/file_to_index_parser.py b/superagi/helper/file_to_index_parser.py index a7ae1424f..84012bbde 100644 --- a/superagi/helper/file_to_index_parser.py +++ b/superagi/helper/file_to_index_parser.py @@ -1,3 +1,6 @@ +import logging + +import llama_index from llama_index import SimpleDirectoryReader import os @@ -6,28 +9,80 @@ from superagi.vector_store.embedding.openai import OpenAiEmbedding from superagi.config.config import get_config -def create_document_index(file_path: str, agent_id: int, session): + +def create_llama_document(file_path: str): """ Creates a document index from a given directory. """ documents = SimpleDirectoryReader(input_files=[file_path]).load_data() + return documents -def llama_vector_store_factory(vector_store, index_name, embedding_model, session, agent_id): +def llama_vector_store_factory(vector_store_name,index_name,embedding_model): """ Creates a llama vector store. """ model_api_key = get_config("OPENAI_API_KEY") - # agent_execution = AgentExecution(agent_id=agent_id) - # agent_executor = AgentExecutor() - # model_api_key = agent_executor.get_model_api_key_from_execution(agent_execution, session) from superagi.vector_store.vector_factory import VectorFactory - vector_store = VectorFactory.get_vector_storage("PineCone", "super-agent-index1", - OpenAiEmbedding(model_api_key)) + vector_store = VectorFactory.get_vector_storage(vector_store_name, index_name, + embedding_model) if vector_store is None: raise ValueError("Vector store not found") - if vector_store == "PineCone": + if vector_store_name == "PineCone": from llama_index.vector_stores import PineconeVectorStore - return PineconeVectorStore(index_name, embedding_model, 'text') + return PineconeVectorStore(vector_store.index) + if vector_store_name == "Weaviate": + from llama_index.vector_stores import WeaviateVectorStore + return WeaviateVectorStore(vector_store.client) + + +def save_file_to_vector_store(file_path: str, agent_id: int, resource_id: str): + from llama_index import VectorStoreIndex + import openai + from superagi.vector_store.embedding.openai import OpenAiEmbedding + from llama_index import StorageContext + from llama_index import SimpleDirectoryReader + model_api_key = get_config("OPENAI_API_KEY") + documents = SimpleDirectoryReader(input_files=[file_path]).load_data() + for docs in documents: + if docs.extra_info is None: + docs.extra_info = {"agent_id": agent_id, "resource_id": resource_id} + else: + docs.extra_info["agent_id"] = agent_id + docs.extra_info["resource_id"] = resource_id + os.environ["OPENAI_API_KEY"] = get_config("OPENAI_API_KEY") + vector_store = None + storage_context = None + try: + vector_store = llama_vector_store_factory('PineCone', 'super-agent-index1', OpenAiEmbedding(model_api_key)) + storage_context = StorageContext.from_defaults(vector_store=vector_store) + except ValueError: + logging.error("Vector store not found") + vector_store = None + # vector_store = llama_vector_store_factory('Weaviate', 'super-agent-index1', OpenAiEmbedding(model_api_key)) + print(vector_store) + # storage_context = StorageContext.from_defaults(persist_dir="workspace/index") + openai.api_key = get_config("OPENAI_API_KEY") + index = VectorStoreIndex.from_documents(documents, storage_context=storage_context) + index.set_index_id(f'Agent {agent_id}') + if vector_store is None: + index.storage_context.persist(persist_dir="workspace/index") + + +def generate_summary_of_document(documents: list[llama_index.Document]): + from llama_index import LLMPredictor + from llama_index import ServiceContext + from langchain.chat_models import ChatOpenAI + from llama_index import ResponseSynthesizer + from llama_index import DocumentSummaryIndex + llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")) + service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt, chunk_size=1024) + response_synthesizer = ResponseSynthesizer.from_args(response_mode="tree_summarize", use_async=True) + doc_summary_index = DocumentSummaryIndex.from_documents( + documents=documents, + service_context=service_context, + response_synthesizer=response_synthesizer + ) + return doc_summary_index.get_document_summary(documents[0].doc_id) diff --git a/superagi/tools/resource/query_resource.py b/superagi/tools/resource/query_resource.py new file mode 100644 index 000000000..e26709d42 --- /dev/null +++ b/superagi/tools/resource/query_resource.py @@ -0,0 +1,62 @@ +import os +from typing import Type + +from pydantic import BaseModel, Field + +from superagi.tools.base_tool import BaseTool +from superagi.config.config import get_config + + +class QueryResource(BaseModel): + """Input for QueryResource tool.""" + query: str = Field(..., description="Description of the information to be queried") + + +class QueryResourceTool(BaseTool): + """ + Read File tool + + Attributes: + name : The name. + description : The description. + args_schema : The args schema. + """ + name: str = "Query Resource" + args_schema: Type[BaseModel] = QueryResource + description: str = "Has the ability to get information from a resource" + + def _execute(self, file_name: str): + """ + Execute the read file tool. + + Args: + file_name : The name of the file to read. + + Returns: + The file content + """ + input_root_dir = get_config('RESOURCES_INPUT_ROOT_DIR') + output_root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + final_path = None + + if input_root_dir is not None: + input_root_dir = input_root_dir if input_root_dir.startswith("/") else os.getcwd() + "/" + input_root_dir + input_root_dir = input_root_dir if input_root_dir.endswith("/") else input_root_dir + "/" + final_path = input_root_dir + file_name + + if final_path is None or not os.path.exists(final_path): + if output_root_dir is not None: + output_root_dir = output_root_dir if output_root_dir.startswith( + "/") else os.getcwd() + "/" + output_root_dir + output_root_dir = output_root_dir if output_root_dir.endswith("/") else output_root_dir + "/" + final_path = output_root_dir + file_name + + if final_path is None or not os.path.exists(final_path): + raise FileNotFoundError(f"File '{file_name}' not found.") + + directory = os.path.dirname(final_path) + os.makedirs(directory, exist_ok=True) + + with open(final_path, 'r') as file: + file_content = file.read() + return file_content[:1500] diff --git a/superagi/vector_store/embedding/openai.py b/superagi/vector_store/embedding/openai.py index d6318a1de..7cef2a382 100644 --- a/superagi/vector_store/embedding/openai.py +++ b/superagi/vector_store/embedding/openai.py @@ -21,8 +21,9 @@ def __init__(self, api_key, model="text-embedding-ada-002"): async def get_embedding_async(self, text): try: # openai.api_key = get_config("OPENAI_API_KEY") - openai.api_key = self.api_key + # openai.api_key = self.api_key response = await openai.Embedding.create( + api_key=self.api_key, input=[text], engine=self.model ) @@ -34,6 +35,7 @@ def get_embedding(self, text): try: # openai.api_key = get_config("OPENAI_API_KEY") response = openai.Embedding.create( + api_key=self.api_key, input=[text], engine=self.model ) diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index 528885694..63e471552 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -33,6 +33,8 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model): if index_name not in pinecone.list_indexes(): sample_embedding = embedding_model.get_embedding("sample") + if "error" in sample_embedding: + print("Error in embedding model", sample_embedding) # if does not exist, create index pinecone.create_index( @@ -58,5 +60,4 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model): ) return weaviate.Weaviate(client, embedding_model, index_name, 'text') - else: - raise Exception("Vector store not supported") + raise ValueError(f"Vector store {vector_store} not supported") From 357c33b3326b577ddb74c75bdd2b036ae6fdcd2c Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Tue, 27 Jun 2023 13:12:56 +0530 Subject: [PATCH 021/241] resource query tool working --- .../c02f3d759bf3_add_summary_to_resource.py | 28 ++++++++ requirements.txt | 3 +- superagi/controllers/resources.py | 9 ++- superagi/helper/file_to_index_parser.py | 6 +- superagi/models/resource.py | 1 + superagi/tools/resource/query_resource.py | 64 ++++++++----------- 6 files changed, 68 insertions(+), 43 deletions(-) create mode 100644 migrations/versions/c02f3d759bf3_add_summary_to_resource.py diff --git a/migrations/versions/c02f3d759bf3_add_summary_to_resource.py b/migrations/versions/c02f3d759bf3_add_summary_to_resource.py new file mode 100644 index 000000000..8ed9cd5f6 --- /dev/null +++ b/migrations/versions/c02f3d759bf3_add_summary_to_resource.py @@ -0,0 +1,28 @@ +"""add summary to resource + +Revision ID: c02f3d759bf3 +Revises: 1d54db311055 +Create Date: 2023-06-27 05:07:29.016704 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'c02f3d759bf3' +down_revision = '1d54db311055' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ## + op.add_column('resources', sa.Column('summary', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('resources', 'summary') + # ### end Alembic commands ### diff --git a/requirements.txt b/requirements.txt index 55279a848..4bf78ce39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -131,4 +131,5 @@ tiktoken==0.4.0 psycopg2==2.9.6 slack-sdk==3.21.3 pytest==7.3.2 -transformers==4.30.2 \ No newline at end of file +transformers==4.30.2 +pypdf==3.11.0 \ No newline at end of file diff --git a/superagi/controllers/resources.py b/superagi/controllers/resources.py index 922637816..722bb0d7c 100644 --- a/superagi/controllers/resources.py +++ b/superagi/controllers/resources.py @@ -7,7 +7,7 @@ from superagi.models.budget import Budget from fastapi import APIRouter, UploadFile from superagi.helper.file_to_index_parser import create_llama_document, llama_vector_store_factory, \ - save_file_to_vector_store + save_file_to_vector_store, generate_summary_of_document import os from fastapi import FastAPI, File, Form, UploadFile from typing import Annotated @@ -117,7 +117,12 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si db.session.add(resource) db.session.commit() db.session.flush() - save_file_to_vector_store(file_path, agent_id, resource.id) + save_file_to_vector_store(file_path, str(agent_id), resource.id) + documents = create_llama_document(file_path) + summary = generate_summary_of_document(documents) + resource.summary = summary + print(summary) + db.session.commit() logger.info(resource) return resource diff --git a/superagi/helper/file_to_index_parser.py b/superagi/helper/file_to_index_parser.py index 84012bbde..bfcd4539c 100644 --- a/superagi/helper/file_to_index_parser.py +++ b/superagi/helper/file_to_index_parser.py @@ -38,7 +38,7 @@ def llama_vector_store_factory(vector_store_name,index_name,embedding_model): return WeaviateVectorStore(vector_store.client) -def save_file_to_vector_store(file_path: str, agent_id: int, resource_id: str): +def save_file_to_vector_store(file_path: str, agent_id: str, resource_id: str): from llama_index import VectorStoreIndex import openai from superagi.vector_store.embedding.openai import OpenAiEmbedding @@ -60,9 +60,9 @@ def save_file_to_vector_store(file_path: str, agent_id: int, resource_id: str): storage_context = StorageContext.from_defaults(vector_store=vector_store) except ValueError: logging.error("Vector store not found") - vector_store = None + # vector_store = None # vector_store = llama_vector_store_factory('Weaviate', 'super-agent-index1', OpenAiEmbedding(model_api_key)) - print(vector_store) + # print(vector_store) # storage_context = StorageContext.from_defaults(persist_dir="workspace/index") openai.api_key = get_config("OPENAI_API_KEY") index = VectorStoreIndex.from_documents(documents, storage_context=storage_context) diff --git a/superagi/models/resource.py b/superagi/models/resource.py index 53091fc1a..7908feaef 100644 --- a/superagi/models/resource.py +++ b/superagi/models/resource.py @@ -28,6 +28,7 @@ class Resource(DBBaseModel): type = Column(String) # application/pdf etc channel = Column(String) # INPUT,OUTPUT agent_id = Column(Integer) + summary = Column(String) def __repr__(self): """ diff --git a/superagi/tools/resource/query_resource.py b/superagi/tools/resource/query_resource.py index e26709d42..a2e818a73 100644 --- a/superagi/tools/resource/query_resource.py +++ b/superagi/tools/resource/query_resource.py @@ -1,10 +1,15 @@ import os from typing import Type +from langchain.chat_models import ChatOpenAI from pydantic import BaseModel, Field - from superagi.tools.base_tool import BaseTool from superagi.config.config import get_config +import openai +from llama_index import VectorStoreIndex, LLMPredictor, ServiceContext +from superagi.helper.file_to_index_parser import llama_vector_store_factory +from superagi.vector_store.embedding.openai import OpenAiEmbedding +from llama_index.vector_stores.types import ExactMatchFilter, MetadataFilters class QueryResource(BaseModel): @@ -24,39 +29,24 @@ class QueryResourceTool(BaseTool): name: str = "Query Resource" args_schema: Type[BaseModel] = QueryResource description: str = "Has the ability to get information from a resource" - - def _execute(self, file_name: str): - """ - Execute the read file tool. - - Args: - file_name : The name of the file to read. - - Returns: - The file content - """ - input_root_dir = get_config('RESOURCES_INPUT_ROOT_DIR') - output_root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - final_path = None - - if input_root_dir is not None: - input_root_dir = input_root_dir if input_root_dir.startswith("/") else os.getcwd() + "/" + input_root_dir - input_root_dir = input_root_dir if input_root_dir.endswith("/") else input_root_dir + "/" - final_path = input_root_dir + file_name - - if final_path is None or not os.path.exists(final_path): - if output_root_dir is not None: - output_root_dir = output_root_dir if output_root_dir.startswith( - "/") else os.getcwd() + "/" + output_root_dir - output_root_dir = output_root_dir if output_root_dir.endswith("/") else output_root_dir + "/" - final_path = output_root_dir + file_name - - if final_path is None or not os.path.exists(final_path): - raise FileNotFoundError(f"File '{file_name}' not found.") - - directory = os.path.dirname(final_path) - os.makedirs(directory, exist_ok=True) - - with open(final_path, 'r') as file: - file_content = file.read() - return file_content[:1500] + agent_id: int = None + + def _execute(self, query: str): + model_api_key = get_config("OPENAI_API_KEY") + openai.api_key = model_api_key + llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo",openai_api_key=get_config("OPENAI_API_KEY"))) + service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt) + vector_store = llama_vector_store_factory('PineCone', 'super-agent-index1', OpenAiEmbedding(model_api_key)) + index = VectorStoreIndex.from_vector_store(vector_store=vector_store,service_context=service_context) + query_engine = index.as_query_engine( + filters=MetadataFilters( + filters=[ + ExactMatchFilter( + key="agent_id", + value=str(self.agent_id) + ) + ] + ) + ) + response = query_engine.query(query) + return response \ No newline at end of file From 77b266a96f18b79dd176c5bda2c5797dfa9271cc Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:25:27 +0530 Subject: [PATCH 022/241] Added oauth button for twitter --- gui/pages/Content/Toolkits/ToolkitWorkspace.js | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/gui/pages/Content/Toolkits/ToolkitWorkspace.js b/gui/pages/Content/Toolkits/ToolkitWorkspace.js index f55047177..f04bd0a72 100644 --- a/gui/pages/Content/Toolkits/ToolkitWorkspace.js +++ b/gui/pages/Content/Toolkits/ToolkitWorkspace.js @@ -25,6 +25,13 @@ export default function ToolkitWorkspace({toolkitDetails}){ window.location.href = `https://accounts.google.com/o/oauth2/v2/auth?client_id=${client_id}&redirect_uri=${redirect_uri}&access_type=offline&response_type=code&scope=${scope}`; } + function getTwitterToken(oauth_data){ + const oauth_token = oauth_data.oauth_token + const oauth_token_secret = oauth_data.oauth_token_secret + const authUrl = `https://api.twitter.com/oauth/authenticate?oauth_token=${oauth_token}` + window.location.href = authUrl + } + useEffect(() => { if(toolkitDetails !== null) { if (toolkitDetails.tools) { @@ -71,6 +78,16 @@ export default function ToolkitWorkspace({toolkitDetails}){ }); }; + const handleTwitterAuthClick = async () => { + authenticateTwitterCred(toolDetails.id) + .then((response) => { + getTwitterToken(response.data); + }) + .catch((error) => { + console.error('Error fetching data: ', error); + }); + }; + return (<>
@@ -116,6 +133,7 @@ export default function ToolkitWorkspace({toolkitDetails}){
{toolkitDetails.name === 'Google Calendar Toolkit' && } + {toolDetails.name === 'Twitter Toolkit' && }
From ececcc0354b5bf73775e8ac99adbb201b50de01b Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:26:58 +0530 Subject: [PATCH 023/241] Added api --- gui/pages/api/DashboardService.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index cc473ce8b..401e5a823 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -128,6 +128,10 @@ export const authenticateGoogleCred = (toolKitId) => { return api.get(`/google/get_google_creds/toolkit_id/${toolKitId}`); } +export const authenticateTwitterCred = (toolKitId) => { + return api.get(`/twitter/get_twitter_creds/toolkit_id/${toolKitId}`); +} + export const fetchToolTemplateList = () => { return api.get(`/toolkits/get/list?page=0`); } From d7b7150c2b1224060a0b73e4e1ea33d947449082 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:27:42 +0530 Subject: [PATCH 024/241] Get tokens --- main.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/main.py b/main.py index 222e9dd4c..fd262b8c7 100644 --- a/main.py +++ b/main.py @@ -411,6 +411,16 @@ def get_google_calendar_tool_configs(toolkit_id: int): "client_id": google_calendar_config.value } +@app.get("/twitter/get_twitter_creds/toolkit_id/{toolkit_id}") +def get_twitter_tool_configs(toolkit_id: int): + twitter_config_key = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_KEY").first() + twitter_config_secret = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_SECRET").first() + api_data = { + "api_key": twitter_config_key.value, + "api_secret": twitter_config_secret.value + } + response = TwitterTokens().get_request_token(api_data) + return response @app.get("/validate-open-ai-key/{open_ai_key}") async def root(open_ai_key: str, Authorize: AuthJWT = Depends()): From f8cb05eb36f228bd9fcbb458e568ac2a38552d6a Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:28:27 +0530 Subject: [PATCH 025/241] Oauth Flow --- main.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/main.py b/main.py index fd262b8c7..0ee309d45 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,10 @@ from sqlalchemy.orm import sessionmaker import superagi +import urllib.parse +import http.client as http_client +from superagi.helper.twitter_tokens import TwitterTokens +from datetime import datetime, timedelta from superagi.agent.agent_prompt_builder import AgentPromptBuilder from superagi.config.config import get_config from superagi.controllers.agent import router as agent_router @@ -320,6 +324,31 @@ async def google_auth_calendar(code: str = Query(...), Authorize: AuthJWT = Depe frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000") return RedirectResponse(frontend_url) +@app.get('/oauth-twitter') +async def twitter_oauth(oauth_token: str = Query(...),oauth_verifier: str = Query(...), Authorize: AuthJWT = Depends()): + token_uri = f'https://api.twitter.com/oauth/access_token?oauth_verifier={oauth_verifier}&oauth_token={oauth_token}' + conn = http_client.HTTPSConnection("api.twitter.com") + conn.request("POST", token_uri, "") + res = conn.getresponse() + response_data = res.read().decode('utf-8') + conn.close() + response = dict(urllib.parse.parse_qsl(response_data)) + root_dir = superagi.config.config.get_config('RESOURCES_OUTPUT_ROOT_DIR') + file_name = "twitter_credentials.pickle" + final_path = file_name + if root_dir is not None: + root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir + root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" + final_path = root_dir + file_name + else: + final_path = os.getcwd() + "/" + file_name + try: + with open(final_path, mode="wb") as file: + pickle.dump(response, file) + except Exception as err: + return f"Error: {err}" + frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000") + return RedirectResponse(frontend_url) @app.get('/github-login') def github_login(): From 1d56496e70f21c7479f0d21e85a85231678f21f0 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:30:29 +0530 Subject: [PATCH 026/241] Helper file for twitter tokens --- superagi/helper/twitter_tokens.py | 93 +++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 superagi/helper/twitter_tokens.py diff --git a/superagi/helper/twitter_tokens.py b/superagi/helper/twitter_tokens.py new file mode 100644 index 000000000..4fe9642be --- /dev/null +++ b/superagi/helper/twitter_tokens.py @@ -0,0 +1,93 @@ +import os +import pickle +import json +import hmac +import time +import random +import base64 +import hashlib +import urllib.parse +import http.client as http_client +from superagi.config.config import get_config +from sqlalchemy.orm import sessionmaker +from superagi.models.db import connect_db +from superagi.models.tool_config import ToolConfig +from superagi.resource_manager.manager import ResourceManager + +class TwitterTokens: + + def get_request_token(self,api_data): + api_key = api_data["api_key"] + api_secret_key = api_data["api_secret"] + http_method = 'POST' + base_url = 'https://api.twitter.com/oauth/request_token' + + params = { + 'oauth_callback': 'http://localhost:3000/api/oauth-twitter', + 'oauth_consumer_key': api_key, + 'oauth_nonce': self.gen_nonce(), + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': int(time.time()), + 'oauth_version': '1.0' + } + + params_sorted = sorted(params.items()) + params_qs = '&'.join([f'{k}={self.percent_encode(str(v))}' for k, v in params_sorted]) + + base_string = f'{http_method}&{self.percent_encode(base_url)}&{self.percent_encode(params_qs)}' + + signing_key = f'{self.percent_encode(api_secret_key)}&' + signature = hmac.new(signing_key.encode(), base_string.encode(), hashlib.sha1) + params['oauth_signature'] = base64.b64encode(signature.digest()).decode() + + auth_header = 'OAuth ' + ', '.join([f'{k}="{self.percent_encode(str(v))}"' for k, v in params.items()]) + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Authorization': auth_header + } + conn = http_client.HTTPSConnection("api.twitter.com") + conn.request("POST", "/oauth/request_token", "", headers) + res = conn.getresponse() + response_data = res.read().decode('utf-8') + conn.close() + request_token_resp = dict(urllib.parse.parse_qsl(response_data)) + return request_token_resp + + def percent_encode(self, val): + return urllib.parse.quote(val, safe='') + + def gen_nonce(self): + nonce = ''.join([str(random.randint(0, 9)) for i in range(32)]) + return nonce + + def get_twitter_creds(self, toolkit_id): + file_name = "twitter_credentials.pickle" + root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + file_path = file_name + if root_dir is not None: + root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir + root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" + file_path = root_dir + file_name + else: + file_path = os.getcwd() + "/" + file_name + if os.path.exists(file_path): + with open(file_path,'rb') as file: + creds = pickle.load(file) + if isinstance(creds, str): + creds = json.loads(creds) + engine = connect_db() + Session = sessionmaker(bind=engine) + session = Session() + twitter_creds = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id).all() + api_key = "" + api_key_secret = "" + for credentials in twitter_creds: + credentials = credentials.__dict__ + if credentials["key"] == "TWITTER_API_KEY": + api_key = credentials["value"] + if credentials["key"] == "TWITTER_API_SECRET": + api_key_secret = credentials["value"] + creds["api_key"] = api_key + creds["api_key_secret"] = api_key_secret + return creds \ No newline at end of file From a91a77a7b616dff7d48afcaf190cd2602cf6dd04 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:33:00 +0530 Subject: [PATCH 027/241] Toolkit for twitter --- superagi/tools/twitter/twitter_toolkit.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 superagi/tools/twitter/twitter_toolkit.py diff --git a/superagi/tools/twitter/twitter_toolkit.py b/superagi/tools/twitter/twitter_toolkit.py new file mode 100644 index 000000000..2adb282a3 --- /dev/null +++ b/superagi/tools/twitter/twitter_toolkit.py @@ -0,0 +1,15 @@ +from abc import ABC +from superagi.tools.base_tool import BaseToolkit, BaseTool +from typing import Type, List +from superagi.tools.twitter.send_tweets import SendTweetsTool + + +class TwitterToolKit(BaseToolkit, ABC): + name: str = "Twitter Toolkit" + description: str = "Twitter Tool kit contains all tools related to Twitter" + + def get_tools(self) -> List[BaseTool]: + return [SendTweetsTool()] + + def get_env_keys(self) -> List[str]: + return ["TWITTER_API_KEY", "TWITTER_API_SECRET"] \ No newline at end of file From c789df45fac819079bbde8fccf972998531d8790 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:33:22 +0530 Subject: [PATCH 028/241] Tool to send tweets with multimediaenabled --- superagi/tools/twitter/send_tweets.py | 87 +++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 superagi/tools/twitter/send_tweets.py diff --git a/superagi/tools/twitter/send_tweets.py b/superagi/tools/twitter/send_tweets.py new file mode 100644 index 000000000..e3d3ffec1 --- /dev/null +++ b/superagi/tools/twitter/send_tweets.py @@ -0,0 +1,87 @@ +import os +import json +import base64 +import requests +from typing import Any, Type +from pydantic import BaseModel, Field +from superagi.tools.base_tool import BaseTool +from superagi.helper.twitter_tokens import TwitterTokens +from requests_oauthlib import OAuth1 +from requests_oauthlib import OAuth1Session +from superagi.helper.resource_helper import ResourceHelper + +class SendTweetsInput(BaseModel): + tweet_text: str = Field(..., description="Tweet text to be posted from twitter handle, if no value is given keep the default value as 'None'") + is_media: bool = Field(..., description="'True' if there is any media to be posted with Tweet else 'False'.") + media_num: int = Field(..., description="Integer value for the number of media files to be uploaded, default value is 0") + media_files: list = Field(..., description="Name of the media files to be uploaded.") + +class SendTweetsTool(BaseTool): + name: str = "Send Tweets Tool" + args_schema: Type[BaseModel] = SendTweetsInput + description: str = "Send and Schedule Tweets for your Twitter Handle" + agent_id: int = None + + def _execute(self, is_media: bool, tweet_text: str = 'None', media_num: int = 0, media_files: list = []): + toolkit_id = self.toolkit_config.toolkit_id + creds = TwitterTokens().get_twitter_creds(toolkit_id) + params = {} + if is_media: + media_ids = self.get_media_ids(media_files, creds) + params["media"] = {"media_ids": media_ids} + if tweet_text is not None: + params["text"] = tweet_text + tweet_response = self.send_tweets(params, creds) + if tweet_response.status_code == 201: + return "Tweet posted successfully!!" + else: + return "Error posting tweet. (Status code: {})".format(tweet_response.status_code) + + + def get_media_ids(self, media_files, creds): + media_ids = [] + oauth = OAuth1(creds["api_key"], + client_secret=creds["api_key_secret"], + resource_owner_key=creds["oauth_token"], + resource_owner_secret=creds["oauth_token_secret"]) + for file in media_files: + file_path = self.get_file_path(file) + image_data = open(file_path, 'rb').read() + b64_image = base64.b64encode(image_data) + upload_endpoint = 'https://upload.twitter.com/1.1/media/upload.json' + headers = {'Authorization': 'application/octet-stream'} + response = requests.post(upload_endpoint, headers=headers, + data={'media_data': b64_image}, + auth=oauth) + ids = json.loads(response.text)['media_id'] + media_ids.append(str(ids)) + + return media_ids + + def get_file_path(self, file_name): + output_root_dir = ResourceHelper.get_root_output_dir() + + final_path = ResourceHelper.get_root_input_dir() + file_name + if "{agent_id}" in final_path: + final_path = final_path.replace("{agent_id}", str(self.agent_id)) + + if final_path is None or not os.path.exists(final_path): + if output_root_dir is not None: + final_path = ResourceHelper.get_root_output_dir() + file_name + if "{agent_id}" in final_path: + final_path = final_path.replace("{agent_id}", str(self.agent_id)) + + if final_path is None or not os.path.exists(final_path): + raise FileNotFoundError(f"File '{file_name}' not found.") + + return final_path + + def send_tweets(self, params, creds): + tweet_endpoint = "https://api.twitter.com/2/tweets" + oauth = OAuth1Session(creds["api_key"], + client_secret=creds["api_key_secret"], + resource_owner_key=creds["oauth_token"], + resource_owner_secret=creds["oauth_token_secret"]) + + response = oauth.post(tweet_endpoint,json=params) + return response \ No newline at end of file From b05d440f961b8b9422aecb035646ca7a9bac35c2 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 17:41:57 +0530 Subject: [PATCH 029/241] Minor changes --- gui/pages/Content/Toolkits/ToolkitWorkspace.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gui/pages/Content/Toolkits/ToolkitWorkspace.js b/gui/pages/Content/Toolkits/ToolkitWorkspace.js index f04bd0a72..c42ba99a8 100644 --- a/gui/pages/Content/Toolkits/ToolkitWorkspace.js +++ b/gui/pages/Content/Toolkits/ToolkitWorkspace.js @@ -1,7 +1,7 @@ import React, {useEffect, useState} from 'react'; import Image from 'next/image'; import {ToastContainer, toast} from 'react-toastify'; -import {updateToolConfig, getToolConfig, authenticateGoogleCred} from "@/pages/api/DashboardService"; +import {updateToolConfig, getToolConfig, authenticateGoogleCred, authenticateTwitterCred} from "@/pages/api/DashboardService"; import styles from './Tool.module.css'; import {EventBus} from "@/utils/eventBus"; @@ -31,7 +31,7 @@ export default function ToolkitWorkspace({toolkitDetails}){ const authUrl = `https://api.twitter.com/oauth/authenticate?oauth_token=${oauth_token}` window.location.href = authUrl } - + useEffect(() => { if(toolkitDetails !== null) { if (toolkitDetails.tools) { @@ -43,7 +43,7 @@ export default function ToolkitWorkspace({toolkitDetails}){ const apiConfigs = response.data || []; setApiConfigs(apiConfigs); }) - .catch((error) => { + .catch((errPor) => { console.log('Error fetching API data:', error); }) .finally(() => { @@ -79,7 +79,7 @@ export default function ToolkitWorkspace({toolkitDetails}){ }; const handleTwitterAuthClick = async () => { - authenticateTwitterCred(toolDetails.id) + authenticateTwitterCred(toolkitDetails.id) .then((response) => { getTwitterToken(response.data); }) @@ -133,7 +133,7 @@ export default function ToolkitWorkspace({toolkitDetails}){
{toolkitDetails.name === 'Google Calendar Toolkit' && } - {toolDetails.name === 'Twitter Toolkit' && } + {toolkitDetails.name === 'Twitter Toolkit' && }
From ed15dc6ddf8dfb859a517531a832bcfe48acc337 Mon Sep 17 00:00:00 2001 From: NishantBorthakur Date: Tue, 27 Jun 2023 17:59:15 +0530 Subject: [PATCH 030/241] unique id added to create agent tab --- gui/pages/Content/Agents/AgentCreate.js | 2 +- .../Content/Agents/AgentTemplatesList.js | 4 +- gui/pages/Content/Agents/Agents.js | 54 +++++++++---------- gui/pages/Dashboard/Content.js | 19 ++++--- gui/utils/utils.js | 37 +++++++++++++ 5 files changed, 78 insertions(+), 38 deletions(-) diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 86fdee027..6dc7301fa 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -7,7 +7,7 @@ import {createAgent, fetchAgentTemplateConfigLocal, getOrganisationConfig, uploa import {formatBytes, openNewTab, removeTab} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; -export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgents, toolkits, organisationId, template}) { +export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgents, toolkits, organisationId, template, internalId}) { const [advancedOptions, setAdvancedOptions] = useState(false); const [agentName, setAgentName] = useState(""); const [agentDescription, setAgentDescription] = useState(""); diff --git a/gui/pages/Content/Agents/AgentTemplatesList.js b/gui/pages/Content/Agents/AgentTemplatesList.js index 9ea6eb146..e3093cf0c 100644 --- a/gui/pages/Content/Agents/AgentTemplatesList.js +++ b/gui/pages/Content/Agents/AgentTemplatesList.js @@ -5,7 +5,7 @@ import {fetchAgentTemplateListLocal} from "@/pages/api/DashboardService"; import AgentCreate from "@/pages/Content/Agents/AgentCreate"; import {EventBus} from "@/utils/eventBus"; -export default function AgentTemplatesList({sendAgentData, selectedProjectId, fetchAgents, toolkits, organisationId}){ +export default function AgentTemplatesList({sendAgentData, selectedProjectId, fetchAgents, toolkits, organisationId, internalId}){ const [agentTemplates, setAgentTemplates] = useState([]) const [createAgentClicked, setCreateAgentClicked] = useState(false) const [sendTemplate, setSendTemplate] = useState(null) @@ -86,7 +86,7 @@ export default function AgentTemplatesList({sendAgentData, selectedProjectId, fe
}
-
: } +
: }
) }; diff --git a/gui/pages/Content/Agents/Agents.js b/gui/pages/Content/Agents/Agents.js index e0c779652..4e232b7e1 100644 --- a/gui/pages/Content/Agents/Agents.js +++ b/gui/pages/Content/Agents/Agents.js @@ -1,38 +1,36 @@ import React from 'react'; import Image from "next/image"; import styles from './Agents.module.css'; -import {ToastContainer, toast} from 'react-toastify'; import 'react-toastify/dist/ReactToastify.css'; +import {createInternalId} from "@/utils/utils"; export default function Agents({sendAgentData, agents}) { - return ( - <> -
-
-

Agents

-
-
- -
- - {agents && agents.length > 0 ?
- {agents.map((agent, index) => ( -
-
sendAgentData(agent)}> - {agent.status &&
active-icon
} -
{agent.name}
-
-
- ))} -
:
- No Agents found -
} + return (<> +
+
+

Agents

+
+
+
- + + {agents && agents.length > 0 ?
+ {agents.map((agent, index) => ( +
+
sendAgentData(agent)}> + {agent.status &&
active-icon
} +
{agent.name}
+
+
+ ))} +
:
+ No Agents found +
} +
); } diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js index f7a04f3ae..cab7e1b1d 100644 --- a/gui/pages/Dashboard/Content.js +++ b/gui/pages/Dashboard/Content.js @@ -10,6 +10,7 @@ import { EventBus } from "@/utils/eventBus"; import {getAgents, getToolKit, getLastActiveAgent} from "@/pages/api/DashboardService"; import Market from "../Content/Marketplace/Market"; import AgentTemplatesList from '../Content/Agents/AgentTemplatesList'; +import {createInternalId, removeInternalId} from "@/utils/utils"; export default function Content({env, selectedView, selectedProjectId, organisationId}) { const [tabs, setTabs] = useState([]); @@ -53,12 +54,12 @@ export default function Content({env, selectedView, selectedProjectId, organisat fetchToolkits(); }, [selectedProjectId]) - const closeTab = (e, index) => { + const closeTab = (e, index, contentType, internalId) => { e.stopPropagation(); - cancelTab(index); + cancelTab(index, contentType, internalId); }; - const cancelTab = (index) => { + const cancelTab = (index, contentType, internalId) => { let updatedTabs = [...tabs]; if (selectedTab === index) { @@ -78,6 +79,10 @@ export default function Content({env, selectedView, selectedProjectId, organisat updatedTabs.splice(index, 1); } + if(contentType === 'Create_Agent' && typeof window !== 'undefined') { + removeInternalId(internalId); + } + setTabs(updatedTabs); }; @@ -136,7 +141,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat const newAgentTabIndex = tabs.findIndex( (tab) => tab.id === eventData.id && tab.name === eventData.name && tab.contentType === eventData.contentType ); - cancelTab(newAgentTabIndex); + cancelTab(newAgentTabIndex, eventData.contentType, eventData.contentType === 'Create_Agent' ? eventData.internalId : 0); }; EventBus.on('openNewTab', openNewTab); @@ -172,7 +177,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat
empty-state
- +
{agents && agents.length > 0 &&
@@ -191,7 +196,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat {tab.contentType === 'Marketplace' &&
marketplace-icon
}
{tab.name}
-
closeTab(e, index)} className={styles.tab_active} style={{order:'1'}}>close-icon
+
closeTab(e, index, tab.contentType, tab.contentType === 'Create_Agent' ? tab.internalId : 0)} className={styles.tab_active} style={{order:'1'}}>close-icon
))}
@@ -205,7 +210,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat {tab.contentType === 'Toolkits' && } {tab.contentType === 'Settings' && } {tab.contentType === 'Marketplace' && } - {tab.contentType === 'Create_Agent' && } + {tab.contentType === 'Create_Agent' && } } ))} diff --git a/gui/utils/utils.js b/gui/utils/utils.js index 4a0ccd940..baddb51ff 100644 --- a/gui/utils/utils.js +++ b/gui/utils/utils.js @@ -104,4 +104,41 @@ export const removeTab = (id, name, contentType) => { EventBus.emit('removeTab', { element: {id: id, name: name, contentType: contentType} }); +} + +export const removeInternalId = (internalId) => { + const internal_ids = localStorage.getItem("agi_internal_ids"); + let idsArray = internal_ids ? internal_ids.split(",").map(Number) : []; + + if(idsArray.length <= 0) { + return; + } + + const internalIdIndex = idsArray.indexOf(internalId); + if (internalIdIndex !== -1) { + idsArray.splice(internalIdIndex, 1); + localStorage.setItem('agi_internal_ids', idsArray.join(',')); + } +} + +export const createInternalId = () => { + let newId = 1; + + if (typeof window !== 'undefined') { + const internal_ids = localStorage.getItem("agi_internal_ids"); + let idsArray = internal_ids ? internal_ids.split(",").map(Number) : []; + let found = false; + + for (let i = 1; !found; i++) { + if (!idsArray.includes(i)) { + newId = i; + found = true; + } + } + + idsArray.push(newId); + localStorage.setItem('agi_internal_ids', idsArray.join(',')); + } + + return newId; } \ No newline at end of file From 7e026007a5cd217883d408729ee68b5ef76195bd Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Tue, 27 Jun 2023 18:46:47 +0530 Subject: [PATCH 031/241] temp commit --- docker-compose.yaml | 1 - entrypoint.sh | 2 +- gui/pages/api/DashboardService.js | 2 +- superagi/helper/file_to_index_parser.py | 14 ++++++++-- superagi/jobs/agent_executor.py | 34 ++++++++++++++++++++--- superagi/models/agent.py | 2 +- superagi/tools/resource/query_resource.py | 5 ++-- superagi/types/db.py | 20 ++++++------- 8 files changed, 58 insertions(+), 22 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 975e98fdf..c155b3938 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -59,7 +59,6 @@ services: volumes: - ./nginx/default.conf:/etc/nginx/conf.d/default.conf - networks: super_network: driver: bridge diff --git a/entrypoint.sh b/entrypoint.sh index 55b09c9cc..d3c8581f5 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -2,4 +2,4 @@ alembic upgrade head # Start the app -exec uvicorn main:app --host 0.0.0.0 --port 8001 --reload --root-path /api +exec uvicorn main:app --host 0.0.0.0 --port 8001 --reload diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index 8761d35a4..b581205f1 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -41,7 +41,7 @@ export const createAgent = (agentData) => { }; export const updateAgents = (agentData) => { - return api.put(`/agentconfigs/update/`, agentData); + return api.put(`/agentconfigs/update`, agentData); }; export const updateExecution = (executionId, executionData) => { diff --git a/superagi/helper/file_to_index_parser.py b/superagi/helper/file_to_index_parser.py index bfcd4539c..1218d6030 100644 --- a/superagi/helper/file_to_index_parser.py +++ b/superagi/helper/file_to_index_parser.py @@ -71,13 +71,17 @@ def save_file_to_vector_store(file_path: str, agent_id: str, resource_id: str): index.storage_context.persist(persist_dir="workspace/index") -def generate_summary_of_document(documents: list[llama_index.Document]): +def generate_summary_of_document(documents: list[llama_index.Document], openai_api_key: str = None): + openai_api_key = openai_api_key or get_config("OPENAI_API_KEY") from llama_index import LLMPredictor from llama_index import ServiceContext from langchain.chat_models import ChatOpenAI from llama_index import ResponseSynthesizer from llama_index import DocumentSummaryIndex - llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")) + print('aaaaaaaaaaaaaaaaa', openai_api_key) + os.environ["OPENAI_API_KEY"] = openai_api_key + llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", + openai_api_key=openai_api_key)) service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt, chunk_size=1024) response_synthesizer = ResponseSynthesizer.from_args(response_mode="tree_summarize", use_async=True) doc_summary_index = DocumentSummaryIndex.from_documents( @@ -86,3 +90,9 @@ def generate_summary_of_document(documents: list[llama_index.Document]): response_synthesizer=response_synthesizer ) return doc_summary_index.get_document_summary(documents[0].doc_id) + + +def generate_summary_of_texts(texts: list[str], openai_api_key: str): + from llama_index import Document + documents = [Document(doc_id=f"doc_id_{i}", text=text) for i, text in enumerate(texts)] + return generate_summary_of_document(documents, openai_api_key) diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 636be0aea..b3788102e 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -163,10 +163,12 @@ def execute_next_action(self, agent_execution_id): tools.append(tool) print(user_tools) + print('////////////', model_api_key) + resource_summary = self.generate_resource_summary(agent.id, session, model_api_key) + resource_summary = resource_summary or parsed_config.get("resource_summary") tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id, - model_api_key=model_api_key) - - + model_api_key=model_api_key, + resource_description=resource_summary) spawned_agent = SuperAgi(ai_name=parsed_config["name"], ai_role=parsed_config["description"], llm=OpenAi(model=parsed_config["model"], api_key=model_api_key), tools=tools, @@ -205,7 +207,7 @@ def execute_next_action(self, agent_execution_id): # finally: engine.dispose() - def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key): + def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key, resource_description=None): """ Set the default parameters for the tools. @@ -214,6 +216,7 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key parsed_config (dict): The parsed configuration. agent_id (int): The ID of the agent. model_api_key (str): The API key of the model. + resource_description (str): The description of the resource. Returns: list: The list of tools with default parameters. @@ -232,6 +235,8 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key tool.image_llm = OpenAi(model=parsed_config["model"], api_key=model_api_key) if hasattr(tool, 'agent_id'): tool.agent_id = agent_id + if tool.name == "Query Resource" and resource_description: + tool.description = resource_description new_tools.append(tool) return tools @@ -267,3 +272,24 @@ def handle_wait_for_permission(self, agent_execution, spawned_agent, session): session.add(agent_execution_feed) agent_execution.status = "RUNNING" session.commit() + + def generate_resource_summary(self,agent_id: int, session: Session, openai_api_key: str): + from superagi.models.resource import Resource + from superagi.models.agent_config import AgentConfiguration + resources = session.query(Resource).filter(Resource.agent_id == agent_id).all() + # get last resource from agent config + last_resource = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id, + AgentConfiguration.key == "last_resource").first() + if last_resource is not None and last_resource.value == resources[-1].id: + return + texts = [resource.summary for resource in resources if resource.summary is not None] + if len(texts) == 0: + return + from superagi.helper.file_to_index_parser import generate_summary_of_texts + resource_summary = generate_summary_of_texts(texts, openai_api_key) + agent_resource_config = AgentConfiguration(agent_id=agent_id, key="resource_summary", value=resource_summary) + agent_last_resource = AgentConfiguration(agent_id=agent_id, key="last_resource", value=resources[-1].id) + session.add(agent_resource_config) + session.add(agent_last_resource) + session.commit() + return resource_summary diff --git a/superagi/models/agent.py b/superagi/models/agent.py index 2786300b7..de836c331 100644 --- a/superagi/models/agent.py +++ b/superagi/models/agent.py @@ -100,7 +100,7 @@ def eval_agent_config(cls, key, value): """ - if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB"]: + if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB", "resource_summary"]: return value elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]: return int(value) diff --git a/superagi/tools/resource/query_resource.py b/superagi/tools/resource/query_resource.py index a2e818a73..e85094586 100644 --- a/superagi/tools/resource/query_resource.py +++ b/superagi/tools/resource/query_resource.py @@ -34,10 +34,11 @@ class QueryResourceTool(BaseTool): def _execute(self, query: str): model_api_key = get_config("OPENAI_API_KEY") openai.api_key = model_api_key - llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo",openai_api_key=get_config("OPENAI_API_KEY"))) + llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", + openai_api_key=get_config("OPENAI_API_KEY"))) service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt) vector_store = llama_vector_store_factory('PineCone', 'super-agent-index1', OpenAiEmbedding(model_api_key)) - index = VectorStoreIndex.from_vector_store(vector_store=vector_store,service_context=service_context) + index = VectorStoreIndex.from_vector_store(vector_store=vector_store, service_context=service_context) query_engine = index.as_query_engine( filters=MetadataFilters( filters=[ diff --git a/superagi/types/db.py b/superagi/types/db.py index 5f7661358..1f4d3a21b 100644 --- a/superagi/types/db.py +++ b/superagi/types/db.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Union, List from pydantic.main import BaseModel @@ -44,7 +44,7 @@ class Config: class AgentConfigurationIn(BaseModel): agent_id: int key: str - value: str + value: Union[str, List[str]] class Config: orm_mode = True @@ -66,14 +66,14 @@ class Config: class AgentExecutionIn(BaseModel): - status: str - name: str - agent_id: int - last_execution_time: datetime - num_of_calls: int - num_of_tokens: int - current_step_id: int - permission_id: int + status: Optional[str] + name: Optional[str] + agent_id: Optional[int] + last_execution_time: Optional[datetime] + num_of_calls: Optional[int] + num_of_tokens: Optional[int] + current_step_id: Optional[int] + permission_id: Optional[int] class Config: orm_mode = True From d58713181261d9c9e29dfc25b1cca318e06cdffa Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 20:22:52 +0530 Subject: [PATCH 032/241] Added unit tests --- .../tools/twitter/test_send_tweets.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/unit_tests/tools/twitter/test_send_tweets.py diff --git a/tests/unit_tests/tools/twitter/test_send_tweets.py b/tests/unit_tests/tools/twitter/test_send_tweets.py new file mode 100644 index 000000000..62ca86eb8 --- /dev/null +++ b/tests/unit_tests/tools/twitter/test_send_tweets.py @@ -0,0 +1,75 @@ +import unittest +import os +from unittest.mock import MagicMock, patch, +from superagi.tools.twitter.send_tweets import SendTweetsTool, SendTweetsInput +from superagi.helper.twitter_tokens import TwitterTokens + +class TestSendTweets(unittest.TestCase): + + def setUp(self): + self.test_input = { + "tweet_text": "Hello, world!", + "is_media": False, + "media_num": 0, + "media_files": [] + } + self.send_tweets_instance = SendTweetsTool() + + def test_execute_success(self): + with patch.object(SendTweetsTool, 'send_tweets', return_value=MagicMock(status_code=201)) as mock_send_tweets, \ + patch.object(TwitterTokens, 'get_twitter_creds', return_value={}) as mock_get_twitter_creds, \ + patch('send_tweets_tool.SendTweetsTool.toolkit_config', new_callable=PropertyMock, return_value=MagicMock(toolkit_id=1)): + send_tweets_tool = SendTweetsTool() + response = send_tweets_tool._execute(False, tweet_text='Test tweet') + self.assertEqual(response, "Tweet posted successfully!!") + mock_send_tweets.assert_called() + + def test_execute_error(self): + with patch.object(SendTweetsTool, 'send_tweets', return_value=MagicMock(status_code=400)) as mock_send_tweets, \ + patch.object(TwitterTokens, 'get_twitter_creds', return_value={}) as mock_get_twitter_creds: + send_tweets_tool = SendTweetsTool() + response = send_tweets_tool._execute(False, tweet_text='Test tweet') + self.assertEqual(response, "Error posting tweet. (Status code: 400)") + mock_send_tweets.assert_called() + + def test_get_media_ids(self): + test_creds = { + "api_key": "test_key", + "api_key_secret": "test_key_secret", + "oauth_token": "test_token", + "oauth_token_secret": "test_token_secret" + } + + with patch('requests.post') as mock_request: + mock_request.return_value.text = '{"media_id": "1234567890"}' + media_ids = self.send_tweets_instance.get_media_ids(["downloads.png"], test_creds) + self.assertEqual(media_ids, ["1234567890"]) + + with patch('requests.post') as mock_request: + mock_request.return_value.text = '{"media_id": "0987654321"}' + media_ids = self.send_tweets_instance.get_media_ids(["testing.png"], test_creds) + self.assertEqual(media_ids, ["0987654321"]) + + def test_send_tweets(self): + test_params = { + "text": "Hello, world!" + } + test_creds = { + "api_key": "test_key", + "api_key_secret": "test_key_secret", + "oauth_token": "test_token", + "oauth_token_secret": "test_token_secret" + } + + with patch('requests_oauthlib.OAuth1Session.post') as mock_oauth_request: + mock_oauth_request.return_value.status_code = 201 + response = self.send_tweets_instance.send_tweets(test_params, test_creds) + self.assertEqual(response.status_code, 201) + + with patch('requests_oauthlib.OAuth1Session.post') as mock_oauth_request: + mock_oauth_request.return_value.status_code = 400 + response = self.send_tweets_instance.send_tweets(test_params, test_creds) + self.assertEqual(response.status_code, 400) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 541236a61a226a8dd7b4e2adbc69e743e8043bb4 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 20:25:42 +0530 Subject: [PATCH 033/241] Minor fix --- tests/unit_tests/tools/twitter/test_send_tweets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/tools/twitter/test_send_tweets.py b/tests/unit_tests/tools/twitter/test_send_tweets.py index 62ca86eb8..8091ba5ac 100644 --- a/tests/unit_tests/tools/twitter/test_send_tweets.py +++ b/tests/unit_tests/tools/twitter/test_send_tweets.py @@ -1,6 +1,6 @@ import unittest import os -from unittest.mock import MagicMock, patch, +from unittest.mock import MagicMock, patch from superagi.tools.twitter.send_tweets import SendTweetsTool, SendTweetsInput from superagi.helper.twitter_tokens import TwitterTokens From bc9c7f7582f4101de4e18aad7514ac0a7a107859 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Tue, 27 Jun 2023 20:29:29 +0530 Subject: [PATCH 034/241] Minor fix --- tests/unit_tests/tools/twitter/test_send_tweets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/tools/twitter/test_send_tweets.py b/tests/unit_tests/tools/twitter/test_send_tweets.py index 8091ba5ac..28a9e011f 100644 --- a/tests/unit_tests/tools/twitter/test_send_tweets.py +++ b/tests/unit_tests/tools/twitter/test_send_tweets.py @@ -1,6 +1,6 @@ import unittest import os -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, PropertyMock from superagi.tools.twitter.send_tweets import SendTweetsTool, SendTweetsInput from superagi.helper.twitter_tokens import TwitterTokens From 26a235826add7b81f40bbb57dba557e1fbfdc9e5 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 28 Jun 2023 00:38:55 +0530 Subject: [PATCH 035/241] poc done --- superagi/jobs/agent_executor.py | 2 +- superagi/models/resource.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index b3788102e..7fcddb7f9 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -288,7 +288,7 @@ def generate_resource_summary(self,agent_id: int, session: Session, openai_api_k from superagi.helper.file_to_index_parser import generate_summary_of_texts resource_summary = generate_summary_of_texts(texts, openai_api_key) agent_resource_config = AgentConfiguration(agent_id=agent_id, key="resource_summary", value=resource_summary) - agent_last_resource = AgentConfiguration(agent_id=agent_id, key="last_resource", value=resources[-1].id) + agent_last_resource = AgentConfiguration(agent_id=agent_id, key="last_resource", value=str(resources[-1].id)) session.add(agent_resource_config) session.add(agent_last_resource) session.commit() diff --git a/superagi/models/resource.py b/superagi/models/resource.py index 7908feaef..15edb3e59 100644 --- a/superagi/models/resource.py +++ b/superagi/models/resource.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, Float +from sqlalchemy import Column, Integer, String, Float, Text from superagi.models.base_model import DBBaseModel from sqlalchemy.orm import sessionmaker @@ -28,7 +28,7 @@ class Resource(DBBaseModel): type = Column(String) # application/pdf etc channel = Column(String) # INPUT,OUTPUT agent_id = Column(Integer) - summary = Column(String) + summary = Column(Text) def __repr__(self): """ From eda16c214e92f9df3c14906a1d0d9a5dead47183 Mon Sep 17 00:00:00 2001 From: rakesh-krishna-a-s Date: Wed, 28 Jun 2023 10:57:50 +0530 Subject: [PATCH 036/241] added summary to background task --- superagi/controllers/resources.py | 39 ++++++++++++++++++++++++------- superagi/jobs/agent_executor.py | 7 +++++- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/superagi/controllers/resources.py b/superagi/controllers/resources.py index 722bb0d7c..1dfdf2115 100644 --- a/superagi/controllers/resources.py +++ b/superagi/controllers/resources.py @@ -1,6 +1,6 @@ import openai from fastapi_sqlalchemy import DBSessionMiddleware, db -from fastapi import HTTPException, Depends, Request +from fastapi import HTTPException, Depends, Request, BackgroundTasks from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException @@ -38,7 +38,8 @@ @router.post("/add/{agent_id}", status_code=201) async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), size=Form(...), type=Form(...), - Authorize: AuthJWT = Depends(check_auth)): + background_tasks: BackgroundTasks = None, + Authorize: AuthJWT = Depends(check_auth)): """ Upload a file as a resource for an agent. @@ -117,12 +118,13 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si db.session.add(resource) db.session.commit() db.session.flush() - save_file_to_vector_store(file_path, str(agent_id), resource.id) - documents = create_llama_document(file_path) - summary = generate_summary_of_document(documents) - resource.summary = summary - print(summary) - db.session.commit() + background_tasks.add_task(add_to_vector_store_and_create_summary, file_path, agent_id, resource.id) + # save_file_to_vector_store(file_path, str(agent_id), resource.id) + # documents = create_llama_document(file_path) + # summary = generate_summary_of_document(documents) + # resource.summary = summary + # print(summary) + # db.session.commit() logger.info(resource) return resource @@ -188,4 +190,23 @@ def download_file_by_id(resource_id: int, headers={ "Content-Disposition": f"attachment; filename={file_name}" } - ) \ No newline at end of file + ) + + +def add_to_vector_store_and_create_summary(file_path: str, agent_id: int, resource_id: int): + """ + Add a file to the vector store and generate a summary for it. + + Args: + file_path (str): Path of the file. + agent_id (str): ID of the agent. + resource_id (int): ID of the resource. + + """ + + save_file_to_vector_store(file_path, str(agent_id), str(resource_id)) + documents = create_llama_document(file_path) + summary = generate_summary_of_document(documents) + resource = db.session.query(Resource).filter(Resource.id == resource_id).first() + resource.summary = summary + db.session.commit() \ No newline at end of file diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 7fcddb7f9..3097af28b 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -277,6 +277,8 @@ def generate_resource_summary(self,agent_id: int, session: Session, openai_api_k from superagi.models.resource import Resource from superagi.models.agent_config import AgentConfiguration resources = session.query(Resource).filter(Resource.agent_id == agent_id).all() + if len(resources) == 0: + return # get last resource from agent config last_resource = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id, AgentConfiguration.key == "last_resource").first() @@ -286,7 +288,10 @@ def generate_resource_summary(self,agent_id: int, session: Session, openai_api_k if len(texts) == 0: return from superagi.helper.file_to_index_parser import generate_summary_of_texts - resource_summary = generate_summary_of_texts(texts, openai_api_key) + if len(texts) > 1: + resource_summary = generate_summary_of_texts(texts, openai_api_key) + else: + resource_summary = texts[0] agent_resource_config = AgentConfiguration(agent_id=agent_id, key="resource_summary", value=resource_summary) agent_last_resource = AgentConfiguration(agent_id=agent_id, key="last_resource", value=str(resources[-1].id)) session.add(agent_resource_config) From c8a5216a987f9bf6b886ffe335222725bda60c76 Mon Sep 17 00:00:00 2001 From: Autocop-Agent <129729746+Autocop-Agent@users.noreply.github.com> Date: Tue, 27 Jun 2023 17:03:49 +0530 Subject: [PATCH 037/241] Added fix for supercoder (#521) * Added logic for requirements and dependencies * Added generate logic prompt * Refactored pytest for unittest --------- Co-authored-by: COLONAYUSH --- superagi/tools/code/prompts/generate_logic.txt | 11 +++++++++++ superagi/tools/code/write_code.py | 2 +- superagi/tools/code/write_test.py | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 superagi/tools/code/prompts/generate_logic.txt diff --git a/superagi/tools/code/prompts/generate_logic.txt b/superagi/tools/code/prompts/generate_logic.txt new file mode 100644 index 000000000..710cd94e4 --- /dev/null +++ b/superagi/tools/code/prompts/generate_logic.txt @@ -0,0 +1,11 @@ +You typically always place distinct classes in separate files. +Always create a run.sh file which act as the entrypoint of the program, create it intellligently after analyzing the file types +For Python, always generate a suitable requirements.txt file. +For NodeJS, consistently produce an appropriate package.json file. +Always include a brief comment that describes the purpose of the function definition. +Attempt to provide comments that explain complicated logic. +Consistently adhere to best practices for the specified languages, ensuring code is defined as a package or project. + +Preferred Python toolbelt: +- pytest +- dataclasses \ No newline at end of file diff --git a/superagi/tools/code/write_code.py b/superagi/tools/code/write_code.py index 14a23e9f3..97b1bb12e 100644 --- a/superagi/tools/code/write_code.py +++ b/superagi/tools/code/write_code.py @@ -61,7 +61,7 @@ def _execute(self, code_description: str) -> str: Returns: Generated codes files or error message. """ - prompt = PromptReader.read_tools_prompt(__file__, "write_code.txt") + prompt = PromptReader.read_tools_prompt(__file__, "write_code.txt") + "\nUseful to know:\n" + PromptReader.read_tools_prompt(__file__, "generate_logic.txt") prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) prompt = prompt.replace("{code_description}", code_description) spec_response = self.tool_response_manager.get_last_response("WriteSpecTool") diff --git a/superagi/tools/code/write_test.py b/superagi/tools/code/write_test.py index b37079dcc..4d889575a 100644 --- a/superagi/tools/code/write_test.py +++ b/superagi/tools/code/write_test.py @@ -26,7 +26,7 @@ class WriteTestSchema(BaseModel): class WriteTestTool(BaseTool): """ - Used to generate pytest unit tests based on the specification. + Used to generate unit tests based on the specification. Attributes: llm: LLM used for test generation. @@ -62,7 +62,7 @@ def _execute(self, test_description: str, test_file_name: str) -> str: test_file_name: The name of the file where the generated tests will be saved. Returns: - Generated pytest unit tests or error message. + Generated unit tests or error message. """ prompt = PromptReader.read_tools_prompt(__file__, "write_test.txt") prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) From 96b83fdfcaa2d17198c29fc0566637875afa9a10 Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 08:16:56 +0530 Subject: [PATCH 038/241] Updated Marketplace URL --- superagi/models/toolkit.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/superagi/models/toolkit.py b/superagi/models/toolkit.py index 39f942ff2..00c576691 100644 --- a/superagi/models/toolkit.py +++ b/superagi/models/toolkit.py @@ -5,10 +5,9 @@ from superagi.models.base_model import DBBaseModel -# marketplace_url = "https://app.superagi.com/api/" - -marketplace_url = "http://localhost:8001" +marketplace_url = "https://app.superagi.com/api" +# marketplace_url = "http://localhost:8001" class Toolkit(DBBaseModel): From 94e057bdb17a2985fa3e08ae789f596fc425587f Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 09:43:04 +0530 Subject: [PATCH 039/241] added bug fix backend --- superagi/helper/feed_parser.py | 9 +++++- superagi/helper/time_helper.py | 33 +++++++++++++++++++++ tests/unit_tests/helper/test_time_helper.py | 15 ++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 superagi/helper/time_helper.py create mode 100644 tests/unit_tests/helper/test_time_helper.py diff --git a/superagi/helper/feed_parser.py b/superagi/helper/feed_parser.py index 12828a89c..8cdad5469 100644 --- a/superagi/helper/feed_parser.py +++ b/superagi/helper/feed_parser.py @@ -1,4 +1,7 @@ import json +from datetime import datetime + +from superagi.helper.time_helper import get_time_difference def parse_feed(feed): @@ -13,6 +16,9 @@ def parse_feed(feed): If parsing fails, the original feed is returned. """ + # Get the current time + feed.time_difference = get_time_difference(feed.updated_at, str(datetime.now())) + # Check if the feed belongs to an assistant role if feed.role == "assistant": try: @@ -31,7 +37,8 @@ def parse_feed(feed): if "command" in parsed: final_output += "Tool: " + parsed["command"]["name"] + "\n" - return {"role": "assistant", "feed": final_output, "updated_at": feed.updated_at} + return {"role": "assistant", "feed": final_output, "updated_at": feed.updated_at, + "time_difference": feed.time_difference} except Exception: return feed if feed.role == "system": diff --git a/superagi/helper/time_helper.py b/superagi/helper/time_helper.py new file mode 100644 index 000000000..ad3fdf11f --- /dev/null +++ b/superagi/helper/time_helper.py @@ -0,0 +1,33 @@ +from datetime import datetime + + +def get_time_difference(timestamp1, timestamp2): + time_format = "%Y-%m-%d %H:%M:%S.%f" + + # Parse the given timestamp + parsed_timestamp1 = datetime.strptime(str(timestamp1), time_format) + parsed_timestamp2 = datetime.strptime(timestamp2, time_format) + + # Calculate the time difference + time_difference = parsed_timestamp2 - parsed_timestamp1 + + # Convert time difference to total seconds + total_seconds = int(time_difference.total_seconds()) + + # Calculate years, months, days, hours, and minutes + years, seconds_remainder = divmod(total_seconds, 31536000) # 1 year = 365 days * 24 hours * 60 minutes * 60 seconds + months, seconds_remainder = divmod(seconds_remainder, + 2592000) # 1 month = 30 days * 24 hours * 60 minutes * 60 seconds + days, seconds_remainder = divmod(seconds_remainder, 86400) # 1 day = 24 hours * 60 minutes * 60 seconds + hours, seconds_remainder = divmod(seconds_remainder, 3600) # 1 hour = 60 minutes * 60 seconds + minutes, _ = divmod(seconds_remainder, 60) # 1 minute = 60 seconds + + # Create a dictionary to store the time difference + time_difference_dict = { + "years": years, + "months": months, + "days": days, + "hours": hours, + "minutes": minutes + } + return time_difference_dict \ No newline at end of file diff --git a/tests/unit_tests/helper/test_time_helper.py b/tests/unit_tests/helper/test_time_helper.py new file mode 100644 index 000000000..c6092ab98 --- /dev/null +++ b/tests/unit_tests/helper/test_time_helper.py @@ -0,0 +1,15 @@ +from superagi.helper.time_helper import get_time_difference + + +def test_get_time_difference(): + # Test case 1: Same timestamp, expect all time components to be zero + timestamp1 = "2023-06-26 17:31:08.884322" + timestamp2 = "2023-06-27 03:57:42.038497" + expected_result = { + "years": 0, + "months": 0, + "days": 0, + "hours": 10, + "minutes": 26 + } + assert get_time_difference(timestamp1, timestamp2) == expected_result From cf136af3e570ef40b3c6f76f5a2fcc4d34076154 Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 09:43:24 +0530 Subject: [PATCH 040/241] updated test --- tests/unit_tests/helper/test_time_helper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/helper/test_time_helper.py b/tests/unit_tests/helper/test_time_helper.py index c6092ab98..79e24a500 100644 --- a/tests/unit_tests/helper/test_time_helper.py +++ b/tests/unit_tests/helper/test_time_helper.py @@ -2,7 +2,6 @@ def test_get_time_difference(): - # Test case 1: Same timestamp, expect all time components to be zero timestamp1 = "2023-06-26 17:31:08.884322" timestamp2 = "2023-06-27 03:57:42.038497" expected_result = { From ea3b806c5bb6feefc12b4d2a1e75d1de2f29949c Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 09:43:50 +0530 Subject: [PATCH 041/241] Updated --- superagi/helper/time_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/helper/time_helper.py b/superagi/helper/time_helper.py index ad3fdf11f..3986ef0e5 100644 --- a/superagi/helper/time_helper.py +++ b/superagi/helper/time_helper.py @@ -30,4 +30,4 @@ def get_time_difference(timestamp1, timestamp2): "hours": hours, "minutes": minutes } - return time_difference_dict \ No newline at end of file + return time_difference_dict From 949e5f98589a2309ba34fb91b6786a2ea919fdf5 Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 11:44:53 +0530 Subject: [PATCH 042/241] Updated --- superagi/models/toolkit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superagi/models/toolkit.py b/superagi/models/toolkit.py index 00c576691..522e69c43 100644 --- a/superagi/models/toolkit.py +++ b/superagi/models/toolkit.py @@ -6,8 +6,8 @@ from superagi.models.base_model import DBBaseModel -marketplace_url = "https://app.superagi.com/api" -# marketplace_url = "http://localhost:8001" +# marketplace_url = "https://app.superagi.com/api" +marketplace_url = "http://localhost:8001" class Toolkit(DBBaseModel): From 7d64384dc8d4799b0a1c56ee12990cc975ca415d Mon Sep 17 00:00:00 2001 From: NishantBorthakur Date: Tue, 27 Jun 2023 12:32:48 +0530 Subject: [PATCH 043/241] time difference issue fixed, added node module in gitignore in main superagi folder --- .gitignore | 1 + gui/pages/Content/Agents/ActionConsole.js | 6 +++--- gui/pages/Content/Agents/ActivityFeed.js | 7 +++---- gui/pages/Content/Agents/RunHistory.js | 4 ++-- gui/utils/utils.js | 25 +++++++++++------------ 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 3c4d9032a..ea17fbcbe 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ superagi/controllers/__pycache__ **agent_dictvenv **/__gitpycache__/ gui/node_modules +node_modules gui/.next .DS_Store .DS_Store? diff --git a/gui/pages/Content/Agents/ActionConsole.js b/gui/pages/Content/Agents/ActionConsole.js index 9081f81f8..c28772707 100644 --- a/gui/pages/Content/Agents/ActionConsole.js +++ b/gui/pages/Content/Agents/ActionConsole.js @@ -2,7 +2,7 @@ import React, { useState, useEffect } from 'react'; import styles from './Agents.module.css'; import Image from 'next/image'; import { updatePermissions } from '@/pages/api/DashboardService'; -import { formatTime } from '@/utils/utils'; +import {formatTimeDifference} from '@/utils/utils'; function ActionBox({ action, index, denied, reasons, handleDeny, handleSelection, setReasons }) { const isDenied = denied[index]; @@ -49,7 +49,7 @@ function ActionBox({ action, index, denied, reasons, handleDeny, handleSelection
schedule-icon
-
{formatTime(action.created_at)}
+
{formatTimeDifference(action.time_difference)}
); @@ -81,7 +81,7 @@ function HistoryBox({ action }){
schedule-icon
-
{formatTime(action.created_at)}
+
{formatTimeDifference(action.time_difference)}
diff --git a/gui/pages/Content/Agents/ActivityFeed.js b/gui/pages/Content/Agents/ActivityFeed.js index bc9c7f29e..2c33b4bf7 100644 --- a/gui/pages/Content/Agents/ActivityFeed.js +++ b/gui/pages/Content/Agents/ActivityFeed.js @@ -2,7 +2,7 @@ import React, {useEffect, useRef, useState} from 'react'; import styles from './Agents.module.css'; import {getExecutionFeeds} from "@/pages/api/DashboardService"; import Image from "next/image"; -import {formatTime, loadingTextEffect} from "@/utils/utils"; +import {loadingTextEffect, formatTimeDifference} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; export default function ActivityFeed({selectedRunId, selectedView, setFetchedData }) { @@ -59,7 +59,6 @@ export default function ActivityFeed({selectedRunId, selectedView, setFetchedDat const data = response.data; setFeeds(data.feeds); setRunStatus(data.status); - console.log(data.permissions) setFetchedData(data.permissions); }) .catch((error) => { @@ -92,13 +91,13 @@ export default function ActivityFeed({selectedRunId, selectedView, setFetchedDat
{f?.feed || ''}
- {f.updated_at && formatTime(f.updated_at) !== 'Invalid Time' &&
+ {f.time_difference && formatTimeDifference(f.time_difference) !== 'Invalid Time' &&
schedule-icon
- {formatTime(f.updated_at)} + {formatTimeDifference(f.time_difference)}
} diff --git a/gui/pages/Content/Agents/RunHistory.js b/gui/pages/Content/Agents/RunHistory.js index 95c3e471b..a41d00fa4 100644 --- a/gui/pages/Content/Agents/RunHistory.js +++ b/gui/pages/Content/Agents/RunHistory.js @@ -1,7 +1,7 @@ import React from 'react'; import styles from './Agents.module.css'; import Image from "next/image"; -import {formatTime, formatNumber} from "@/utils/utils"; +import {formatNumber, formatTimeDifference} from "@/utils/utils"; export default function RunHistory({runs, setHistory, selectedRunId, setSelectedRun}) { return (<> @@ -44,7 +44,7 @@ export default function RunHistory({runs, setHistory, selectedRunId, setSelected schedule-icon
- {formatTime(run.last_execution_time)} + {formatTimeDifference(run.time_difference)}
diff --git a/gui/utils/utils.js b/gui/utils/utils.js index baddb51ff..513b47ef8 100644 --- a/gui/utils/utils.js +++ b/gui/utils/utils.js @@ -1,21 +1,20 @@ -import { formatDistanceToNow, parseISO } from 'date-fns'; import {baseUrl} from "@/pages/api/apiConfig"; import {EventBus} from "@/utils/eventBus"; -export const formatTime = (lastExecutionTime) => { - try { - const parsedTime = parseISO(lastExecutionTime); - if (isNaN(parsedTime.getTime())) { - throw new Error('Invalid time value'); +export const formatTimeDifference = (timeDifference) => { + const units = ['years', 'months', 'days', 'hours', 'minutes']; + + for (const unit of units) { + if (timeDifference[unit] !== 0) { + if (unit === 'minutes') { + return `${timeDifference[unit]} minutes ago`; + } else { + return `${timeDifference[unit]} ${unit} ago`; + } } - return formatDistanceToNow(parsedTime, { - addSuffix: true, - includeSeconds: true, - }).replace(/about\s/, ''); - } catch (error) { - console.error('Error formatting time:', error); - return 'Invalid Time'; } + + return 'Just now'; }; export const formatNumber = (number) => { From b51949f3f7b4e4c38b970246284c8592ca753d81 Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 12:49:35 +0530 Subject: [PATCH 044/241] Fixed Time API --- superagi/controllers/agent_execution.py | 3 +++ superagi/controllers/agent_execution_feed.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index ea09a7abd..16373660b 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -3,6 +3,7 @@ from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT +from superagi.helper.time_helper import get_time_difference from superagi.models.agent_workflow import AgentWorkflow from superagi.worker import execute_agent from superagi.models.agent_execution import AgentExecution @@ -120,6 +121,8 @@ def list_running_agents(agent_id: str, executions = db.session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id).order_by( desc(AgentExecution.status == 'RUNNING'), desc(AgentExecution.last_execution_time)).all() + for execution in executions: + execution.time_difference = get_time_difference(execution.last_execution_time,str(datetime.now())) return executions diff --git a/superagi/controllers/agent_execution_feed.py b/superagi/controllers/agent_execution_feed.py index 6f721b032..81e1f3e35 100644 --- a/superagi/controllers/agent_execution_feed.py +++ b/superagi/controllers/agent_execution_feed.py @@ -1,3 +1,5 @@ +from datetime import datetime + from fastapi import APIRouter from fastapi import HTTPException, Depends from fastapi_jwt_auth import AuthJWT @@ -7,6 +9,7 @@ from superagi.agent.task_queue import TaskQueue from superagi.helper.auth import check_auth +from superagi.helper.time_helper import get_time_difference from superagi.models.agent_execution_permission import AgentExecutionPermission from superagi.helper.feed_parser import parse_feed from superagi.models.agent_execution import AgentExecution @@ -147,7 +150,8 @@ def get_agent_execution_feed(agent_execution_id: int, "response": permission.user_feedback, "status": permission.status, "tool_name": permission.tool_name, - "user_feedback": permission.user_feedback + "user_feedback": permission.user_feedback, + "time_difference":get_time_difference(permission.created_at,str(datetime.now())) } for permission in execution_permissions ] return { From 2de9b62693d3099a6ef387807b237b62ae0bb438 Mon Sep 17 00:00:00 2001 From: NishantBorthakur Date: Tue, 27 Jun 2023 12:46:13 +0530 Subject: [PATCH 045/241] minor fix --- gui/pages/Content/Agents/ActivityFeed.js | 2 +- gui/utils/utils.js | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/gui/pages/Content/Agents/ActivityFeed.js b/gui/pages/Content/Agents/ActivityFeed.js index 2c33b4bf7..7418e9358 100644 --- a/gui/pages/Content/Agents/ActivityFeed.js +++ b/gui/pages/Content/Agents/ActivityFeed.js @@ -91,7 +91,7 @@ export default function ActivityFeed({selectedRunId, selectedView, setFetchedDat
{f?.feed || ''}
- {f.time_difference && formatTimeDifference(f.time_difference) !== 'Invalid Time' &&
+ {f.time_difference &&
schedule-icon diff --git a/gui/utils/utils.js b/gui/utils/utils.js index 513b47ef8..6f5fab3e8 100644 --- a/gui/utils/utils.js +++ b/gui/utils/utils.js @@ -34,7 +34,6 @@ export const formatNumber = (number) => { return scaledNumber.toFixed(1) + suffix; }; - export const formatBytes = (bytes, decimals = 2) => { if (bytes === 0) { return '0 Bytes'; From a8ca92b1a27be83546343a436d4bccac33b0769f Mon Sep 17 00:00:00 2001 From: abhijeet Date: Tue, 27 Jun 2023 16:21:24 +0530 Subject: [PATCH 046/241] Added Test --- superagi/helper/time_helper.py | 8 ++++---- tests/unit_tests/helper/test_feed_parser.py | 22 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 tests/unit_tests/helper/test_feed_parser.py diff --git a/superagi/helper/time_helper.py b/superagi/helper/time_helper.py index 3986ef0e5..e0970ca75 100644 --- a/superagi/helper/time_helper.py +++ b/superagi/helper/time_helper.py @@ -15,11 +15,11 @@ def get_time_difference(timestamp1, timestamp2): total_seconds = int(time_difference.total_seconds()) # Calculate years, months, days, hours, and minutes - years, seconds_remainder = divmod(total_seconds, 31536000) # 1 year = 365 days * 24 hours * 60 minutes * 60 seconds + years, seconds_remainder = divmod(total_seconds, (365 * 24 * 60 * 60)) # 1 year = 365 days * 24 hours * 60 minutes * 60 seconds months, seconds_remainder = divmod(seconds_remainder, - 2592000) # 1 month = 30 days * 24 hours * 60 minutes * 60 seconds - days, seconds_remainder = divmod(seconds_remainder, 86400) # 1 day = 24 hours * 60 minutes * 60 seconds - hours, seconds_remainder = divmod(seconds_remainder, 3600) # 1 hour = 60 minutes * 60 seconds + (30 * 24 * 60 * 60)) # 1 month = 30 days * 24 hours * 60 minutes * 60 seconds + days, seconds_remainder = divmod(seconds_remainder, 24 * 60 * 60) # 1 day = 24 hours * 60 minutes * 60 seconds + hours, seconds_remainder = divmod(seconds_remainder, 60 * 60) # 1 hour = 60 minutes * 60 seconds minutes, _ = divmod(seconds_remainder, 60) # 1 minute = 60 seconds # Create a dictionary to store the time difference diff --git a/tests/unit_tests/helper/test_feed_parser.py b/tests/unit_tests/helper/test_feed_parser.py new file mode 100644 index 000000000..3380610ee --- /dev/null +++ b/tests/unit_tests/helper/test_feed_parser.py @@ -0,0 +1,22 @@ +import unittest +from datetime import datetime + +from superagi.helper.feed_parser import parse_feed +from superagi.models.agent_execution_feed import AgentExecutionFeed + + +class TestParseFeed(unittest.TestCase): + + def test_parse_feed_system(self): + current_time = datetime.now() + + # Create a sample AgentExecutionFeed object with a system role + sample_feed = AgentExecutionFeed(id=2, agent_execution_id=100, agent_id=200, role="system", + feed='System message', + updated_at=current_time) + + # Call the parse_feed function with the sample_feed object + result = parse_feed(sample_feed) + + # In this test case, we only ensure that the parse_feed function doesn't modify the given feed + self.assertEqual(result, sample_feed, "Incorrect output from parse_feed function for system role") From 75079b5d2c7e93b3f6f111e041812eebc77865f3 Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Tue, 27 Jun 2023 19:00:02 +0530 Subject: [PATCH 047/241] fixing docker compose issue --- docker-compose.yaml | 1 - gui/.dockerignore | 2 ++ gui/Dockerfile | 5 +++-- package-lock.json | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 gui/.dockerignore diff --git a/docker-compose.yaml b/docker-compose.yaml index bbeaa1d75..69b8fb1e6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -29,7 +29,6 @@ services: - super_network volumes: - ./gui:/app - - /app/node_modules - /app/.next super__redis: image: "docker.io/library/redis:latest" diff --git a/gui/.dockerignore b/gui/.dockerignore new file mode 100644 index 000000000..220520759 --- /dev/null +++ b/gui/.dockerignore @@ -0,0 +1,2 @@ +node_modules/ +.next/ \ No newline at end of file diff --git a/gui/Dockerfile b/gui/Dockerfile index 89a648969..48365d6b9 100644 --- a/gui/Dockerfile +++ b/gui/Dockerfile @@ -1,9 +1,10 @@ -FROM node:lts +FROM node:lts AS deps WORKDIR /app COPY package*.json ./ -COPY . . RUN npm ci +COPY . . + CMD ["npm", "run", "dev"] \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index be1d0fb5d..de33ede32 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,5 +1,5 @@ { - "name": "SuperAGI", + "name": "super_agi", "lockfileVersion": 2, "requires": true, "packages": { From 9166f46cfb16e64f4c9fa35a423428c6c591fb5f Mon Sep 17 00:00:00 2001 From: NishantBorthakur Date: Wed, 28 Jun 2023 12:18:29 +0530 Subject: [PATCH 048/241] localstorage for agent creation --- gui/pages/Content/Agents/AgentCreate.js | 101 ++++++++++-------- .../Content/Agents/AgentTemplatesList.js | 12 ++- gui/pages/_app.js | 14 +++ gui/utils/utils.js | 17 +++ 4 files changed, 97 insertions(+), 47 deletions(-) diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 6dc7301fa..f3fe94bd5 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -4,15 +4,13 @@ import {ToastContainer, toast} from 'react-toastify'; import 'react-toastify/dist/ReactToastify.css'; import styles from './Agents.module.css'; import {createAgent, fetchAgentTemplateConfigLocal, getOrganisationConfig, uploadFile} from "@/pages/api/DashboardService"; -import {formatBytes, openNewTab, removeTab} from "@/utils/utils"; +import {formatBytes, openNewTab, removeTab, setLocalStorageValue, setLocalStorageArray} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgents, toolkits, organisationId, template, internalId}) { const [advancedOptions, setAdvancedOptions] = useState(false); const [agentName, setAgentName] = useState(""); const [agentDescription, setAgentDescription] = useState(""); - const [selfEvaluation, setSelfEvaluation] = useState(''); - const [basePrompt, setBasePrompt] = useState(''); const [longTermMemory, setLongTermMemory] = useState(true); const [addResources, setAddResources] = useState(true); const [input, setInput] = useState([]); @@ -109,23 +107,24 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen useEffect(() => { if(template !== null) { - setAgentName(template.name) - setAgentDescription(template.description) - setAdvancedOptions(true) + setLocalStorageValue("agent_name_" + String(internalId), template.name, setAgentName); + setLocalStorageValue("agent_description_" + String(internalId), template.description, setAgentDescription); + setAdvancedOptions(true); fetchAgentTemplateConfigLocal(template.id) .then((response) => { const data = response.data || []; - setGoals(data.goal) - setAgentType(data.agent_type) - setConstraints(data.constraints) - setIterations(data.max_iterations) - setRollingWindow(data.memory_window) - setPermission(data.permission_type) - setStepTime(data.iteration_interval) - setInstructions(data.instruction) - setDatabase(data.LTM_DB) - setModel(data.model) + setLocalStorageArray("agent_goals_" + String(internalId), data.goal, setGoals); + setAgentType(data.agent_type); + setLocalStorageArray("agent_constraints_" + String(internalId), data.constraints, setConstraints); + setIterations(data.max_iterations); + setRollingWindow(data.memory_window); + setPermission(data.permission_type); + setStepTime(data.iteration_interval); + setLocalStorageArray("agent_instructions_" + String(internalId), data.instruction, setInstructions); + setDatabase(data.LTM_DB); + setModel(data.model); + data.tools.forEach((item) => { toolkitList.forEach((toolkit) => { toolkit.tools.forEach((tool) => { @@ -249,63 +248,57 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen const handleGoalChange = (index, newValue) => { const updatedGoals = [...goals]; updatedGoals[index] = newValue; - setGoals(updatedGoals); + setLocalStorageArray("agent_goals_" + String(internalId), updatedGoals, setGoals); }; + const handleInstructionChange = (index, newValue) => { const updatedInstructions = [...instructions]; updatedInstructions[index] = newValue; - setInstructions(updatedInstructions); + setLocalStorageArray("agent_instructions_" + String(internalId), updatedInstructions, setInstructions); }; const handleConstraintChange = (index, newValue) => { const updatedConstraints = [...constraints]; updatedConstraints[index] = newValue; - setConstraints(updatedConstraints); + setLocalStorageArray("agent_constraints_" + String(internalId), updatedConstraints, setConstraints); }; const handleGoalDelete = (index) => { const updatedGoals = [...goals]; updatedGoals.splice(index, 1); - setGoals(updatedGoals); + setLocalStorageArray("agent_goals_" + String(internalId), updatedGoals, setGoals); }; const handleInstructionDelete = (index) => { const updatedInstructions = [...instructions]; updatedInstructions.splice(index, 1); - setInstructions(updatedInstructions); + setLocalStorageArray("agent_instructions_" + String(internalId), updatedInstructions, setInstructions); }; const handleConstraintDelete = (index) => { const updatedConstraints = [...constraints]; updatedConstraints.splice(index, 1); - setConstraints(updatedConstraints); + setLocalStorageArray("agent_constraints_" + String(internalId), updatedConstraints, setConstraints); }; const addGoal = () => { - setGoals((prevArray) => [...prevArray, 'new goal']); + setLocalStorageArray("agent_goals_" + String(internalId), (prevArray) => [...prevArray, 'new goal'], setGoals); }; + const addInstruction = () => { - setInstructions((prevArray) => [...prevArray, 'new instructions']); + setLocalStorageArray("agent_instructions_" + String(internalId), (prevArray) => [...prevArray, 'new instructions'], setInstructions); }; const addConstraint = () => { - setConstraints((prevArray) => [...prevArray, 'new constraint']); + setLocalStorageArray("agent_constraints_" + String(internalId), (prevArray) => [...prevArray, 'new constraint'], setConstraints); }; const handleNameChange = (event) => { - setAgentName(event.target.value); + setLocalStorageValue("agent_name_" + String(internalId), event.target.value, setAgentName); }; const handleDescriptionChange = (event) => { - setAgentDescription(event.target.value); - }; - - const handleSelfEvaluationChange = (event) => { - setSelfEvaluation(event.target.value); - }; - - const handleBasePromptChange = (event) => { - setBasePrompt(event.target.value); + setLocalStorageValue("agent_description_" + String(internalId), event.target.value, setAgentDescription); }; const preventDefault = (e) => { @@ -518,6 +511,34 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
); + useEffect(() => { + const agent_name = localStorage.getItem("agent_name_" + String(internalId)); + const agent_description = localStorage.getItem("agent_description_" + String(internalId)); + const agent_goals = localStorage.getItem("agent_goals_" + String(internalId)); + const agent_instructions = localStorage.getItem("agent_instructions_" + String(internalId)); + const agent_constraints = localStorage.getItem("agent_constraints_" + String(internalId)); + + if(agent_name) { + setAgentName(agent_name); + } + + if(agent_description) { + setAgentDescription(agent_description); + } + + if(agent_goals) { + setGoals(JSON.parse(agent_goals)); + } + + if(agent_instructions) { + setInstructions(JSON.parse(agent_instructions)); + } + + if(agent_constraints) { + setConstraints(JSON.parse(agent_constraints)); + } + }, [internalId]) + return (<>
@@ -643,16 +664,6 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
- {/*
*/} - {/*
*/} - {/*

This will defined the agent role definitely and reduces hallucination. This will defined the agent role definitely and reduces hallucination.

*/} - {/*