Skip to content

Commit

Permalink
🧼 cleanup and move vars/config to settings
Browse files Browse the repository at this point in the history
+ openrouter
  • Loading branch information
shroominic committed Aug 28, 2023
1 parent 7ef01d8 commit 816bfab
Showing 1 changed file with 70 additions and 62 deletions.
132 changes: 70 additions & 62 deletions codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
import traceback
from io import BytesIO
from os import getenv
from typing import Optional
from uuid import UUID, uuid4

Expand All @@ -24,7 +23,7 @@
RedisChatMessageHistory,
)
from langchain.prompts.chat import MessagesPlaceholder
from langchain.schema import BaseChatMessageHistory, SystemMessage
from langchain.schema import BaseChatMessageHistory
from langchain.tools import BaseTool, StructuredTool

from codeinterpreterapi.agents import OpenAIFunctionsAgent
Expand All @@ -37,7 +36,6 @@
from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
from codeinterpreterapi.config import settings
from codeinterpreterapi.parser import CodeAgentOutputParser, CodeChatAgentOutputParser
from codeinterpreterapi.prompts import code_interpreter_system_message
from codeinterpreterapi.schema import (
CodeInput,
CodeInterpreterResponse,
Expand All @@ -47,21 +45,25 @@
)


def _handle_deprecated_kwargs(kwargs: dict) -> None:
settings.MODEL = kwargs.get("max_retry", settings.MAX_RETRY)
settings.OPENAI_API_KEY = kwargs.get("openai_api_key", settings.OPENAI_API_KEY)
settings.SYSTEM_MESSAGE = kwargs.get("system_message", settings.SYSTEM_MESSAGE)
settings.MAX_ITERATIONS = kwargs.get("max_iterations", settings.MAX_ITERATIONS)


class CodeInterpreterSession:
def __init__(
self,
llm: Optional[BaseLanguageModel] = None,
system_message: SystemMessage = code_interpreter_system_message,
max_iterations: int = 9,
additional_tools: list[BaseTool] = [],
**kwargs,
) -> None:
self.codebox = CodeBox()
_handle_deprecated_kwargs(kwargs)
self.codebox = CodeBox(requirements=settings.CUSTOM_PACKAGES)
self.verbose = kwargs.get("verbose", settings.VERBOSE)
self.tools: list[BaseTool] = self._tools(additional_tools)
self.llm: BaseLanguageModel = llm or self._choose_llm(**kwargs)
self.max_iterations = max_iterations
self.system_message = system_message
self.agent_executor: Optional[AgentExecutor] = None
self.input_files: list[File] = []
self.output_files: list[File] = []
Expand All @@ -88,72 +90,82 @@ async def astart(self) -> SessionStatus:
self.agent_executor = self._agent_executor()
return status

async def ainstall_additional_packages(self) -> None:
for package in settings.CUSTOM_PACKAGES:
# check if already installed
if await self.codebox.arun(f"import {package}"):
continue
if settings.VERBOSE:
print(f"Installing {package}...")
await self.codebox.ainstall(package)

def _tools(self, additional_tools: list[BaseTool]) -> list[BaseTool]:
return additional_tools + [
StructuredTool(
name="python",
description="Input a string of code to a ipython interpreter. "
"Write the entire code in a single string. This string can "
"be really long, so you can use the `;` character to split lines. "
"Variables are preserved between runs. ",
"Variables are preserved between runs. "
+ (
f"You have access to all default python packages + {settings.CUSTOM_PACKAGES} "
)
if settings.CUSTOM_PACKAGES
else "", # TODO: or include this in the system message
func=self._run_handler,
coroutine=self._arun_handler,
args_schema=CodeInput, # type: ignore
),
]

def _choose_llm(
self, model: str = "gpt-4", openai_api_key: Optional[str] = None, **kwargs
) -> BaseChatModel:
if "gpt" in model:
openai_api_key = (
openai_api_key
or settings.OPENAI_API_KEY
or getenv("OPENAI_API_KEY", None)
def _choose_llm(self) -> BaseChatModel:
if (
settings.AZURE_API_KEY
and settings.AZURE_API_BASE
and settings.AZURE_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
):
return AzureChatOpenAI(
temperature=0.03,
openai_api_base=settings.AZURE_API_BASE,
openai_api_version=settings.AZURE_API_VERSION,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
openai_api_key=settings.AZURE_API_KEY,
max_retries=settings.MAX_RETRY,
request_timeout=settings.REQUEST_TIMEOUT,
) # type: ignore
elif settings.OPENAI_API_KEY:
return ChatOpenAI(
model=settings.MODEL,
openai_api_key=settings.OPENAI_API_KEY,
request_timeout=settings.REQUEST_TIMEOUT,
temperature=settings.TEMPERATURE,
max_retries=settings.MAX_RETRY,
) # type: ignore
elif settings.ANTHROPIC_API_KEY:
if "claude" not in settings.MODEL:
print("Please set the claude model in the settings.")
return ChatAnthropic(
model=settings.MODEL,
temperature=settings.TEMPERATURE,
anthropic_api_key=settings.ANTHROPIC_API_KEY,
)
elif settings.OPENROUTER_API_KEY:
return ChatOpenAI(
model=settings.OPENROUTER_DEFAULT_CHAT_MODEL,
temperature=settings.TEMPERATURE,
openai_api_key=settings.OPENROUTER_API_KEY,
openai_api_base=settings.OPENROUTER_API_BASE,
)
if openai_api_key is None:
raise ValueError(
"OpenAI API key missing. Set OPENAI_API_KEY env variable "
"or pass `openai_api_key` to session."
)
openai_api_version = getenv("OPENAI_API_VERSION")
openai_api_base = getenv("OPENAI_API_BASE")
deployment_name = getenv("DEPLOYMENT_NAME")
openapi_type = getenv("OPENAI_API_TYPE")
if (
openapi_type == "azure"
and openai_api_version
and openai_api_base
and deployment_name
):
return AzureChatOpenAI(
temperature=0.03,
openai_api_base=openai_api_base,
openai_api_version=openai_api_version,
deployment_name=deployment_name,
openai_api_key=openai_api_key,
max_retries=3,
request_timeout=60 * 3,
) # type: ignore
else:
return ChatOpenAI(
temperature=0.03,
model=model,
openai_api_key=openai_api_key,
max_retries=3,
request_timeout=60 * 3,
) # type: ignore
elif "claude" in model:
return ChatAnthropic(model=model)
else:
raise ValueError(f"Unknown model: {model} (expected gpt or claude model)")
raise ValueError("Please set the API key for the LLM you want to use.")

def _choose_agent(self) -> BaseSingleActionAgent:
return (
OpenAIFunctionsAgent.from_llm_and_tools(
llm=self.llm,
tools=self.tools,
system_message=self.system_message,
system_message=settings.SYSTEM_MESSAGE,
extra_prompt_messages=[
MessagesPlaceholder(variable_name="chat_history")
],
Expand All @@ -162,14 +174,14 @@ def _choose_agent(self) -> BaseSingleActionAgent:
else ConversationalChatAgent.from_llm_and_tools(
llm=self.llm,
tools=self.tools,
system_message=code_interpreter_system_message.content,
system_message=settings.SYSTEM_MESSAGE.content,
output_parser=CodeChatAgentOutputParser(self.llm),
)
if isinstance(self.llm, BaseChatModel)
else ConversationalAgent.from_llm_and_tools(
llm=self.llm,
tools=self.tools,
prefix=code_interpreter_system_message.content,
prefix=settings.SYSTEM_MESSAGE.content,
output_parser=CodeAgentOutputParser(),
)
)
Expand All @@ -194,7 +206,7 @@ def _history_backend(self) -> BaseChatMessageHistory:
def _agent_executor(self) -> AgentExecutor:
return AgentExecutor.from_agent_and_tools(
agent=self._choose_agent(),
max_iterations=self.max_iterations,
max_iterations=settings.MAX_ITERATIONS,
tools=self.tools,
verbose=self.verbose,
memory=ConversationBufferMemory(
Expand Down Expand Up @@ -388,7 +400,6 @@ def generate_response_sync(
self,
user_msg: str,
files: list[File] = [],
detailed_error: bool = False,
) -> CodeInterpreterResponse:
"""Generate a Code Interpreter response based on the user's input."""
user_request = UserRequest(content=user_msg, files=files)
Expand All @@ -400,7 +411,7 @@ def generate_response_sync(
except Exception as e:
if self.verbose:
traceback.print_exc()
if detailed_error:
if settings.DETAILED_ERROR:
return CodeInterpreterResponse(
content="Error in CodeInterpreterSession: "
f"{e.__class__.__name__} - {e}"
Expand All @@ -415,7 +426,6 @@ async def generate_response(
self,
user_msg: str,
files: list[File] = [],
detailed_error: bool = False,
) -> CodeInterpreterResponse:
print(
"DEPRECATION WARNING: Use agenerate_response for async generation.\n"
Expand All @@ -425,14 +435,12 @@ async def generate_response(
return await self.agenerate_response(
user_msg=user_msg,
files=files,
detailed_error=detailed_error,
)

async def agenerate_response(
self,
user_msg: str,
files: list[File] = [],
detailed_error: bool = False,
) -> CodeInterpreterResponse:
"""Generate a Code Interpreter response based on the user's input."""
user_request = UserRequest(content=user_msg, files=files)
Expand All @@ -444,7 +452,7 @@ async def agenerate_response(
except Exception as e:
if self.verbose:
traceback.print_exc()
if detailed_error:
if settings.DETAILED_ERROR:
return CodeInterpreterResponse(
content="Error in CodeInterpreterSession: "
f"{e.__class__.__name__} - {e}"
Expand Down

0 comments on commit 816bfab

Please sign in to comment.