Skip to content

Commit

Permalink
🎞️ chat history backends + session_id management
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Aug 9, 2023
1 parent 5e86938 commit f25d814
Showing 1 changed file with 46 additions and 31 deletions.
77 changes: 46 additions & 31 deletions codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import base64
import os
import re
import traceback
import uuid
from io import BytesIO
from os import getenv
from typing import Optional
from uuid import UUID, uuid4

from codeboxapi import CodeBox # type: ignore
from codeboxapi.schema import CodeBoxOutput # type: ignore
Expand All @@ -19,11 +18,12 @@
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import (
FileChatMessageHistory,
ChatMessageHistory,
PostgresChatMessageHistory,
RedisChatMessageHistory,
)
from langchain.prompts.chat import MessagesPlaceholder
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import BaseChatMessageHistory, BaseLanguageModel
from langchain.tools import BaseTool, StructuredTool

from codeinterpreterapi.agents import OpenAIFunctionsAgent
Expand All @@ -33,6 +33,7 @@
get_file_modifications,
remove_download_link,
)
from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
from codeinterpreterapi.config import settings
from codeinterpreterapi.parser import CodeAgentOutputParser, CodeChatAgentOutputParser
from codeinterpreterapi.prompts import code_interpreter_system_message
Expand All @@ -48,35 +49,38 @@
class CodeInterpreterSession:
def __init__(
self,
session_id=None,
llm: Optional[BaseLanguageModel] = None,
additional_tools: list[BaseTool] = [],
**kwargs,
) -> None:
if not session_id:
self.session_id = str(uuid.uuid4())
else:
self.session_id = session_id
self.history_storage = os.environ["STORAGE_TYPE"]
if self.history_storage == "file" and not os.path.exists(
os.environ["SESSION_FOLDER_STORAGE"]
):
os.mkdir(os.environ["SESSION_FOLDER_STORAGE"])

self.codebox = CodeBox()
self.verbose = kwargs.get("verbose", settings.VERBOSE)
self.tools: list[BaseTool] = self._tools(additional_tools)
self.llm: BaseLanguageModel = llm or self._choose_llm(**kwargs)
self.agent_executor: AgentExecutor = self._agent_executor()
self.agent_executor: AgentExecutor | None = None
self.input_files: list[File] = []
self.output_files: list[File] = []
self.code_log: list[tuple[str, str]] = []

@classmethod
def from_id(cls, session_id: UUID) -> "CodeInterpreterSession":
session = cls()
session.codebox = CodeBox.from_id(session_id)
return session

@property
def session_id(self) -> UUID | None:
return self.codebox.session_id

def start(self) -> SessionStatus:
return SessionStatus.from_codebox_status(self.codebox.start())
status = SessionStatus.from_codebox_status(self.codebox.start())
self.agent_executor = self._agent_executor()
return status

async def astart(self) -> SessionStatus:
return SessionStatus.from_codebox_status(await self.codebox.astart())
status = SessionStatus.from_codebox_status(await self.codebox.astart())
self.agent_executor = self._agent_executor()
return status

def _tools(self, additional_tools: list[BaseTool]) -> list[BaseTool]:
return additional_tools + [
Expand Down Expand Up @@ -124,15 +128,15 @@ def _choose_llm(
openai_api_key=openai_api_key,
max_retries=3,
request_timeout=60 * 3,
)
) # type: ignore
else:
return ChatOpenAI(
temperature=0.03,
model=model,
openai_api_key=openai_api_key,
max_retries=3,
request_timeout=60 * 3,
)
) # type: ignore
elif "claude" in model:
return ChatAnthropic(model=model)
else:
Expand Down Expand Up @@ -164,24 +168,33 @@ def _choose_agent(self) -> BaseSingleActionAgent:
)
)

def _agent_executor(self) -> AgentExecutor:
if self.history_storage == "file":
history = FileChatMessageHistory(
f"{os.environ['SESSION_FOLDER_STORAGE']}/{self.session_id}.json"
def _history_backend(self) -> BaseChatMessageHistory:
return (
CodeBoxChatMessageHistory(codebox=self.codebox)
if settings.HISTORY_BACKEND == "codebox"
else RedisChatMessageHistory(
session_id=str(self.session_id),
url=settings.REDIS_URL,
)
elif self.history_storage == "redis":
history = RedisChatMessageHistory(
session_id=self.session_id,
url=f"redis://{os.environ['RD_HOST']}:{os.environ['RD_PORT']}",
if settings.HISTORY_BACKEND == "redis"
else PostgresChatMessageHistory(
session_id=str(self.session_id),
connection_string=settings.POSTGRES_URL,
)
if settings.HISTORY_BACKEND == "postgres"
else ChatMessageHistory()
)

def _agent_executor(self) -> AgentExecutor:
return AgentExecutor.from_agent_and_tools(
agent=self._choose_agent(),
max_iterations=9,
tools=self.tools,
verbose=self.verbose,
memory=ConversationBufferMemory(
memory_key="chat_history", return_messages=True, chat_memory=history
memory_key="chat_history",
return_messages=True,
chat_memory=self._history_backend(),
),
)

Expand All @@ -204,7 +217,7 @@ def _run_handler(self, code: str):
raise TypeError("Expected output.content to be a string.")

if output.type == "image/png":
filename = f"image-{uuid.uuid4()}.png"
filename = f"image-{uuid4()}.png"
file_buffer = BytesIO(base64.b64decode(output.content))
file_buffer.name = filename
self.output_files.append(File(name=filename, content=file_buffer.read()))
Expand Down Expand Up @@ -251,7 +264,7 @@ async def _arun_handler(self, code: str):
raise TypeError("Expected output.content to be a string.")

if output.type == "image/png":
filename = f"image-{uuid.uuid4()}.png"
filename = f"image-{uuid4()}.png"
file_buffer = BytesIO(base64.b64decode(output.content))
file_buffer.name = filename
self.output_files.append(File(name=filename, content=file_buffer.read()))
Expand Down Expand Up @@ -375,6 +388,7 @@ def generate_response_sync(
user_request = UserRequest(content=user_msg, files=files)
try:
self._input_handler(user_request)
assert self.agent_executor, "Session not initialized."
response = self.agent_executor.run(input=user_request.content)
return self._output_handler(response)
except Exception as e:
Expand Down Expand Up @@ -418,6 +432,7 @@ async def agenerate_response(
user_request = UserRequest(content=user_msg, files=files)
try:
await self._ainput_handler(user_request)
assert self.agent_executor, "Session not initialized."
response = await self.agent_executor.arun(input=user_request.content)
return await self._aoutput_handler(response)
except Exception as e:
Expand Down

0 comments on commit f25d814

Please sign in to comment.