Skip to content

Commit

Permalink
Refactoring the resource manager structure
Browse files Browse the repository at this point in the history
  • Loading branch information
TransformerOptimus committed Jul 2, 2023
1 parent 2587b52 commit 09cf26d
Show file tree
Hide file tree
Showing 23 changed files with 294 additions and 132 deletions.
1 change: 1 addition & 0 deletions config_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ OPENAI_API_BASE: https://api.openai.com/v1

# "gpt-3.5-turbo-0301": 4032, "gpt-4-0314": 8092, "gpt-3.5-turbo": 4032, "gpt-4": 8092, "gpt-4-32k": 32768, "gpt-4-32k-0314": 32768, "llama":2048, "mpt-7b-storywriter":45000
MODEL_NAME: "gpt-3.5-turbo-0301"
RESOURCES_EMBEDDING_MODEL: "text-davinci-003"
MAX_TOOL_TOKEN_LIMIT: 800
MAX_MODEL_TOKEN_LIMIT: 4032 # set to 2048 for llama

Expand Down
7 changes: 4 additions & 3 deletions superagi/controllers/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si
db.session.commit()
db.session.flush()

file_path = file_path if storage_type == StorageTypes.FILE else None
file_object = file if storage_type == StorageTypes.S3 else None
documents = await ResourceManager.create_llama_document(file_path, file_object)
if storage_type == StorageTypes.S3:
documents = await ResourceManager(agent.id).create_llama_document_s3(file)
else:
documents = await ResourceManager(agent.id).create_llama_document(file_path)
summarize_resource.delay(agent_id, resource.id, documents)
logger.info(resource)

Expand Down
2 changes: 1 addition & 1 deletion superagi/helper/google_calendar_creds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.models.tool_config import ToolConfig
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager

class GoogleCalendarCreds:

Expand Down
5 changes: 3 additions & 2 deletions superagi/jobs/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from superagi.models.project import Project
from superagi.models.tool import Tool
from superagi.models.tool_config import ToolConfig
from superagi.resource_manager.llama_document_summary import LlamaDocumentSummary
from superagi.tools.base_tool import BaseToolkitConfiguration
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.resource_manager.resource_manager import ResourceManager
from superagi.tools.thinking.tools import ThinkingTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
Expand Down Expand Up @@ -321,7 +322,7 @@ def generate_resource_summary(self, agent_id: int, session: Session, openai_api_
if len(texts) == 0:
return
if len(texts) > 1:
resource_summary = ResourceManager.generate_summary_of_texts(texts, openai_api_key)
resource_summary = LlamaDocumentSummary().generate_summary_of_texts(texts, openai_api_key)
else:
resource_summary = texts[0]

Expand Down
8 changes: 4 additions & 4 deletions superagi/jobs/resource_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from superagi.models.agent_config import AgentConfiguration
from superagi.models.db import connect_db
from superagi.models.resource import Resource
from superagi.resource_manager.llama_document_summary import LlamaDocumentSummary
from superagi.resource_manager.resource_manager import ResourceManager

engine = connect_db()
Expand All @@ -25,12 +26,12 @@ def add_to_vector_store_and_create_summary(cls, agent_id: int, resource_id: int,
"""
db = session()
try:
ResourceManager.save_document_to_vector_store(documents, str(agent_id), str(resource_id))
ResourceManager(agent_id).save_document_to_vector_store(documents, str(resource_id))
except Exception as e:
logger.error(e)
summary = None
try:
summary = ResourceManager.generate_summary_of_document(documents)
summary = LlamaDocumentSummary().generate_summary_of_document(documents)
except Exception as e:
logger.error(e)
resource = db.query(Resource).filter(Resource.id == resource_id).first()
Expand All @@ -44,8 +45,7 @@ def add_to_vector_store_and_create_summary(cls, agent_id: int, resource_id: int,
if len(summary_texts) == 1:
resource_summary = summary_texts[0]
else:
openai_api_key = get_config("OPENAI_API_KEY")
resource_summary = ResourceManager.generate_summary_of_texts(summary_texts, openai_api_key)
resource_summary = LlamaDocumentSummary().generate_summary_of_texts(summary_texts)

agent_config_resource_summary = db.query(AgentConfiguration). \
filter(AgentConfiguration.agent_id == agent_id,
Expand Down
File renamed without changes.
39 changes: 39 additions & 0 deletions superagi/resource_manager/llama_document_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from llama_index.indices.response import ResponseMode
from llama_index.schema import Document

from superagi.config.config import get_config


class LlamaDocumentSummary:
    """Generates natural-language summaries of llama_index documents.

    Uses a ``DocumentSummaryIndex`` with tree-summarize response synthesis,
    backed by an OpenAI chat model selected via configuration.
    """

    def __init__(self, model_name=None):
        # Resolve the model lazily (at construction, not import) and read the
        # key "RESOURCES_EMBEDDING_MODEL" — the key actually declared in
        # config_template.yaml.  The previous key, "RESOURCES_EMBEDDING_MODEL_NAME",
        # does not exist in the template, so configured values were silently ignored.
        # NOTE(review): the fallback "text-davinci-003" is not in the supported
        # chat-model list in _build_llm(), so an unconfigured deployment will
        # raise on first use — confirm the intended default model.
        self.model_name = model_name if model_name is not None \
            else get_config("RESOURCES_EMBEDDING_MODEL", "text-davinci-003")

    def generate_summary_of_document(self, documents: "list[Document]", openai_api_key: str = None):
        """Build a summary index over *documents* and return the first document's summary.

        :param documents: non-empty list of llama_index ``Document`` objects.
        :param openai_api_key: optional API key overriding the configured one.
        :return: summary text for ``documents[0]``.
        """
        from llama_index import LLMPredictor, ServiceContext, ResponseSynthesizer, DocumentSummaryIndex

        llm_predictor_chatgpt = LLMPredictor(llm=self._build_llm(openai_api_key))
        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt, chunk_size=1024)
        response_synthesizer = ResponseSynthesizer.from_args(response_mode=ResponseMode.TREE_SUMMARIZE, use_async=True)
        doc_summary_index = DocumentSummaryIndex.from_documents(
            documents=documents,
            service_context=service_context,
            response_synthesizer=response_synthesizer
        )

        return doc_summary_index.get_document_summary(documents[0].doc_id)

    def generate_summary_of_texts(self, texts: list, openai_api_key: str = None):
        """Wrap raw strings in Documents and summarize them.

        The optional ``openai_api_key`` parameter restores compatibility with
        callers that pass the key positionally (e.g. the agent executor's
        ``generate_summary_of_texts(texts, openai_api_key)`` call), which
        would otherwise fail with a TypeError.

        :param texts: non-empty list of raw text strings.
        :param openai_api_key: optional API key overriding the configured one.
        :raises ValueError: if *texts* is empty (previously surfaced later as
            an opaque IndexError inside generate_summary_of_document).
        """
        if not texts:
            raise ValueError("texts must be a non-empty list")
        from llama_index import Document
        documents = [Document(doc_id=f"doc_id_{i}", text=text) for i, text in enumerate(texts)]
        return self.generate_summary_of_document(documents, openai_api_key)

    def _build_llm(self, openai_api_key: str = None):
        """Return a langchain chat LLM for the configured model.

        :param openai_api_key: optional API key; falls back to OPENAI_API_KEY config.
        :raises Exception: if the configured model is not a supported OpenAI chat model.
        """
        open_ai_models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
        if self.model_name in open_ai_models:
            from langchain.chat_models import ChatOpenAI

            openai_api_key = openai_api_key or get_config("OPENAI_API_KEY")
            return ChatOpenAI(temperature=0, model_name=self.model_name,
                              openai_api_key=openai_api_key)

        raise Exception(f"Model name {self.model_name} not supported for document summary")
46 changes: 46 additions & 0 deletions superagi/resource_manager/llama_vector_store_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from llama_index.vector_stores.types import VectorStore

from superagi.config.config import get_config
from superagi.types.vector_store_types import VectorStoreType


class LlamaVectorStoreFactory:
    """Constructs the llama_index vector-store client for a configured backend.

    Supported backends: Pinecone, Redis, Chroma, Qdrant.  Backend libraries
    are imported lazily so only the selected backend's dependency is required.
    """

    def __init__(self, vector_store_name: VectorStoreType, index_name: str):
        # Which backend to build, and the index/collection name inside it.
        self.vector_store_name = vector_store_name
        self.index_name = index_name

    def get_vector_store(self) -> VectorStore:
        """Return a vector store instance for ``self.vector_store_name``.

        :raises ValueError: when the named backend has no adapter here.
        """
        if self.vector_store_name == VectorStoreType.PINECONE:
            return self._pinecone_store()
        if self.vector_store_name == VectorStoreType.REDIS:
            return self._redis_store()
        if self.vector_store_name == VectorStoreType.CHROMA:
            return self._chroma_store()
        if self.vector_store_name == VectorStoreType.QDRANT:
            return self._qdrant_store()
        raise ValueError(str(self.vector_store_name) + " vector store is not supported yet.")

    def _pinecone_store(self):
        # Pinecone adapter keyed by index name; connection config comes from elsewhere.
        from llama_index.vector_stores import PineconeVectorStore
        return PineconeVectorStore(self.index_name)

    def _redis_store(self):
        # Redis adapter; metadata fields enable filtering by agent/resource.
        redis_url = get_config("REDIS_VECTOR_STORE_URL") or "redis://super__redis:6379"
        from llama_index.vector_stores import RedisVectorStore
        return RedisVectorStore(
            index_name=self.index_name,
            redis_url=redis_url,
            metadata_fields=["agent_id", "resource_id"]
        )

    def _chroma_store(self):
        # Chroma over its REST API; the collection is created on first use.
        from llama_index.vector_stores import ChromaVectorStore
        import chromadb
        from chromadb.config import Settings
        host = get_config("CHROMA_HOST_NAME") or "localhost"
        port = get_config("CHROMA_PORT") or 8000
        client = chromadb.Client(
            Settings(chroma_api_impl="rest", chroma_server_host=host,
                     chroma_server_http_port=port))
        collection = client.get_or_create_collection(self.index_name)
        return ChromaVectorStore(collection)

    def _qdrant_store(self):
        # Qdrant via its native client; the collection name doubles as the index name.
        from llama_index.vector_stores import QdrantVectorStore
        host = get_config("QDRANT_HOST_NAME") or "localhost"
        port = get_config("QDRANT_PORT") or 6333
        from qdrant_client import QdrantClient
        return QdrantVectorStore(client=QdrantClient(host=host, port=port),
                                 collection_name=self.index_name)
135 changes: 29 additions & 106 deletions superagi/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,144 +4,67 @@

from superagi.config.config import get_config
from superagi.helper.resource_helper import ResourceHelper
from llama_index.schema import Document

from superagi.lib.logger import logger
from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
from superagi.types.vector_store_types import VectorStoreType
from llama_index.indices.response import ResponseMode


class ResourceManager:
@classmethod
async def create_llama_document(cls, file_path: str = None, file_object=None):
def __init__(self, agent_id: str = None):
self.agent_id = agent_id

async def create_llama_document(self, file_path: str):
"""
Creates a document index from a given directory.
"""

if file_path is None and file_object is None:
raise Exception("Either file_path or file_object must be provided")

if file_path is not None and file_object is not None:
raise Exception("Only one of file_path or file_object must be provided")

save_directory = ResourceHelper.get_root_input_dir() + "/"

if file_object is not None:
file_path = save_directory + file_object.filename
with open(file_path, "wb") as f:
contents = await file_object.read()
f.write(contents)
file_object.file.close()

if file_path is None:
raise Exception("Either file_path must be provided")
documents = SimpleDirectoryReader(input_files=[file_path]).load_data()

if file_object is not None:
os.remove(file_path)

return documents

@classmethod
def generate_summary_of_document(cls, documents: list[Document], openai_api_key: str = None):
openai_api_key = openai_api_key or get_config("OPENAI_API_KEY")
from llama_index import LLMPredictor, ServiceContext, ResponseSynthesizer, DocumentSummaryIndex
from langchain.chat_models import ChatOpenAI
os.environ["OPENAI_API_KEY"] = openai_api_key
llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo",
openai_api_key=openai_api_key))
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt, chunk_size=1024)
response_synthesizer = ResponseSynthesizer.from_args(response_mode=ResponseMode.TREE_SUMMARIZE, use_async=True)
doc_summary_index = DocumentSummaryIndex.from_documents(
documents=documents,
service_context=service_context,
response_synthesizer=response_synthesizer
)

return doc_summary_index.get_document_summary(documents[0].doc_id)

@classmethod
def generate_summary_of_texts(cls, texts: list[str], openai_api_key: str):
from llama_index import Document
documents = [Document(doc_id=f"doc_id_{i}", text=text) for i, text in enumerate(texts)]
return cls.generate_summary_of_document(documents, openai_api_key)

@classmethod
def llama_vector_store_factory(cls, vector_store_name: VectorStoreType, index_name, embedding_model):
async def create_llama_document_s3(self, file_object):
"""
Creates a llama vector store.
Creates a document index from a given directory.
"""
from superagi.vector_store.vector_factory import VectorFactory

vector_factory_support = [VectorStoreType.PINECONE, VectorStoreType.WEAVIATE]
if vector_store_name in vector_factory_support:
vector_store = VectorFactory.get_vector_storage(vector_store_name, index_name,
embedding_model)
if vector_store_name == VectorStoreType.PINECONE:
from llama_index.vector_stores import PineconeVectorStore
return PineconeVectorStore(vector_store.index)

if vector_store_name == VectorStoreType.WEAVIATE:
raise ValueError("Weaviate vector store is not supported yet.")

if vector_store_name == VectorStoreType.REDIS:
redis_url = get_config("REDIS_VECTOR_STORE_URL") or "redis://super__redis:6379"
from llama_index.vector_stores import RedisVectorStore
return RedisVectorStore(
index_name=index_name,
redis_url=redis_url,
metadata_fields=["agent_id", "resource_id"]
)
if file_object is None:
raise Exception("Either file_path or file_object must be provided")

if vector_store_name == VectorStoreType.CHROMA:
from llama_index.vector_stores import ChromaVectorStore
import chromadb
from chromadb.config import Settings
chroma_host_name = get_config("CHROMA_HOST_NAME") or "localhost"
chroma_port = get_config("CHROMA_PORT") or 8000
chroma_client = chromadb.Client(
Settings(chroma_api_impl="rest", chroma_server_host=chroma_host_name,
chroma_server_http_port=chroma_port))
chroma_collection = chroma_client.get_or_create_collection(index_name)
return ChromaVectorStore(chroma_collection), chroma_collection
save_directory = ResourceHelper.get_root_input_dir() + "/"
file_path = save_directory + file_object.filename
with open(file_path, "wb") as f:
contents = await file_object.read()
f.write(contents)
file_object.file.close()

if vector_store_name == VectorStoreType.QDRANT:
from llama_index.vector_stores import QdrantVectorStore
qdrant_host_name = get_config("QDRANT_HOST_NAME") or "localhost"
qdrant_port = get_config("QDRANT_PORT") or 6333
from qdrant_client import QdrantClient
qdrant_client = QdrantClient(host=qdrant_host_name, port=qdrant_port)
return QdrantVectorStore(client=qdrant_client, collection_name=index_name)
documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
os.remove(file_path)
return documents

@classmethod
def save_document_to_vector_store(cls, documents: list, agent_id: str, resource_id: str):
def save_document_to_vector_store(self, documents: list, resource_id: str):
from llama_index import VectorStoreIndex, StorageContext
import openai
from superagi.vector_store.embedding.openai import OpenAiEmbedding
model_api_key = get_config("OPENAI_API_KEY")
openai.api_key = get_config("OPENAI_API_KEY")
for docs in documents:
if docs.metadata is None:
docs.metadata = {"agent_id": agent_id, "resource_id": resource_id}
else:
docs.metadata["agent_id"] = agent_id
docs.metadata["resource_id"] = resource_id
os.environ["OPENAI_API_KEY"] = get_config("OPENAI_API_KEY")
docs.metadata = {}
docs.metadata["agent_id"] = str(self.agent_id)
docs.metadata["resource_id"] = resource_id
vector_store = None
storage_context = None
vector_store_name = VectorStoreType.get_vector_store_type(get_config("RESOURCE_VECTOR_STORE") or "Redis")
vector_store_index_name = get_config("RESOURCE_VECTOR_STORE_INDEX_NAME") or "super-agent-index"
try:
print(vector_store_name, vector_store_index_name)
vector_store = cls.llama_vector_store_factory(vector_store_name, vector_store_index_name,
OpenAiEmbedding(model_api_key))
if vector_store_name == VectorStoreType.CHROMA:
vector_store, chroma_collection = vector_store
vector_store = LlamaVectorStoreFactory(vector_store_name, vector_store_index_name).get_vector_store()
storage_context = StorageContext.from_defaults(vector_store=vector_store)
except ValueError as e:
logger.error(f"Vector store not found{e}")
openai.api_key = get_config("OPENAI_API_KEY")
try:
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
index.set_index_id(f'Agent {agent_id}')
index.set_index_id(f'Agent {self.agent_id}')
except Exception as e:
print(e)
# persisting the data in case of redis
if vector_store_name == VectorStoreType.REDIS:
vector_store.persist(persist_path="")
vector_store.persist(persist_path="")
2 changes: 1 addition & 1 deletion superagi/tools/code/write_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager

Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/code/write_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool


Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/code/write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager

Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/file/read_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel, Field

from superagi.helper.resource_helper import ResourceHelper
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool


Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/file/write_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, Field

# from superagi.helper.s3_helper import upload_to_s3
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool


Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/google_calendar/list_calendar_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from superagi.tools.base_tool import BaseTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate
from superagi.resource_manager.manager import FileManager
from superagi.resource_manager.file_manager import FileManager
from superagi.helper.s3_helper import S3Helper
from urllib.parse import urlparse, parse_qs
from sqlalchemy.orm import sessionmaker
Expand Down
Loading

0 comments on commit 09cf26d

Please sign in to comment.