Skip to content

Commit

Permalink
🔧 Refactor code interpreter API to use langchain_core and langchain_openai libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Apr 4, 2024
1 parent 0953050 commit ed61a2b
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 63 deletions.
5 changes: 2 additions & 3 deletions src/codeinterpreterapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from . import _patch_parser # noqa

from codeinterpreterapi.config import settings
from codeinterpreterapi.schema import File
from codeinterpreterapi.session import CodeInterpreterSession

from ._patch_parser import patch

patch()

__all__ = [
"CodeInterpreterSession",
Expand Down
16 changes: 5 additions & 11 deletions src/codeinterpreterapi/_patch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
from json import JSONDecodeError
from typing import List, Union

from langchain.agents.agent import AgentOutputParser
from langchain.agents.openai_functions_agent import base
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation

from langchain.agents.agent import AgentOutputParser


class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
"""Parses a message into agent action/finish.
Expand Down Expand Up @@ -102,8 +99,5 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
raise ValueError("Can only parse messages")


def patch() -> None:
"""Patch the parser."""
from langchain.agents import openai_functions_agent

openai_functions_agent.OpenAIFunctionsAgentOutputParser = OpenAIFunctionsAgentOutputParser # type: ignore
# patch
base.OpenAIFunctionsAgentOutputParser = OpenAIFunctionsAgentOutputParser # type: ignore
7 changes: 4 additions & 3 deletions src/codeinterpreterapi/chains/extract_code.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.language_models import BaseLanguageModel


def extract_python_code(
Expand All @@ -19,7 +18,9 @@ async def aextract_python_code(


async def test() -> None:
llm = ChatAnthropic(model="claude-1.3") # type: ignore
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

code = """
import matplotlib.pyplot as plt
Expand Down
19 changes: 10 additions & 9 deletions src/codeinterpreterapi/chains/modifications_check.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import json
from typing import List, Optional

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.language_models import BaseLanguageModel

from codeinterpreterapi.prompts import determine_modifications_prompt


def get_file_modifications(
code: str,
llm: BaseLanguageModel,
retry: int = 2,
retry: int = 4,
) -> Optional[List[str]]:
if retry < 1:
return None

prompt = determine_modifications_prompt.format(code=code)

result = llm.predict(prompt, stop="```")
result = llm.invoke(prompt)

try:
result = json.loads(result)
result = json.loads(result.content)
except json.JSONDecodeError:
result = ""
if not result or not isinstance(result, dict) or "modifications" not in result:
Expand All @@ -31,17 +30,17 @@ def get_file_modifications(
async def aget_file_modifications(
code: str,
llm: BaseLanguageModel,
retry: int = 2,
retry: int = 4,
) -> Optional[List[str]]:
if retry < 1:
return None

prompt = determine_modifications_prompt.format(code=code)

result = await llm.apredict(prompt, stop="```")
result = await llm.ainvoke(prompt)

try:
result = json.loads(result)
result = json.loads(result.content)
except json.JSONDecodeError:
result = ""
if not result or not isinstance(result, dict) or "modifications" not in result:
Expand All @@ -50,7 +49,9 @@ async def aget_file_modifications(


async def test() -> None:
llm = ChatAnthropic(model="claude-2") # type: ignore
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

code = """
import matplotlib.pyplot as plt
Expand Down
11 changes: 6 additions & 5 deletions src/codeinterpreterapi/chains/rm_dl_link.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import AIMessage, OutputParserException
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

from codeinterpreterapi.prompts import remove_dl_link_prompt

Expand All @@ -12,7 +13,7 @@ def remove_download_link(
messages = remove_dl_link_prompt.format_prompt(
input_response=input_response
).to_messages()
message = llm.predict_messages(messages)
message = llm.invoke(messages)

if not isinstance(message, AIMessage):
raise OutputParserException("Expected an AIMessage")
Expand All @@ -28,7 +29,7 @@ async def aremove_download_link(
messages = remove_dl_link_prompt.format_prompt(
input_response=input_response
).to_messages()
message = await llm.apredict_messages(messages)
message = await llm.ainvoke(messages)

if not isinstance(message, AIMessage):
raise OutputParserException("Expected an AIMessage")
Expand Down
7 changes: 3 additions & 4 deletions src/codeinterpreterapi/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import json
from typing import List

from codeboxapi import CodeBox # type: ignore
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict
from codeboxapi import CodeBox
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict, messages_to_dict


# TODO: This is probably not efficient, but it works for now.
class CodeBoxChatMessageHistory(BaseChatMessageHistory):
"""
Chat message history that stores history inside the codebox.
Expand Down
16 changes: 8 additions & 8 deletions src/codeinterpreterapi/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from typing import Optional

from dotenv import load_dotenv
from langchain.pydantic_v1 import BaseSettings, SecretStr
from langchain.schema import SystemMessage
from langchain_core.messages import SystemMessage
from langchain_core.pydantic_v1 import BaseSettings, SecretStr

from codeinterpreterapi.prompts import code_interpreter_system_message

# .env file
load_dotenv(dotenv_path="./.env")


class CodeInterpreterAPISettings(BaseSettings):
"""
Expand All @@ -18,8 +14,8 @@ class CodeInterpreterAPISettings(BaseSettings):
DEBUG: bool = False

# Models
OPENAI_API_KEY: Optional[str] = None
AZURE_API_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[SecretStr] = None
AZURE_OPENAI_API_KEY: Optional[SecretStr] = None
AZURE_API_BASE: Optional[str] = None
AZURE_API_VERSION: Optional[str] = None
AZURE_DEPLOYMENT_NAME: Optional[str] = None
Expand All @@ -46,5 +42,9 @@ class CodeInterpreterAPISettings(BaseSettings):
# deprecated
VERBOSE: bool = DEBUG

class Config:
env_file = "./.env"
extra = "ignore"


settings = CodeInterpreterAPISettings()
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/prompts/modifications_check.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate

determine_modifications_prompt = PromptTemplate(
input_variables=["code"],
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/prompts/remove_dl_link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

remove_dl_link_prompt = ChatPromptTemplate(
input_variables=["input_response"],
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/prompts/system_message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.schema import SystemMessage
from langchain_core.messages import SystemMessage

system_message = SystemMessage(
content="""
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Any

from codeboxapi.schema import CodeBoxStatus
from langchain.pydantic_v1 import BaseModel
from langchain.schema import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel


class File(BaseModel):
Expand Down
32 changes: 18 additions & 14 deletions src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
ConversationalChatAgent,
)
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import Callbacks
from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import (
ChatMessageHistory,
from langchain.memory.buffer import ConversationBufferMemory
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import (
PostgresChatMessageHistory,
RedisChatMessageHistory,
)
from langchain.prompts.chat import MessagesPlaceholder
from langchain.schema import BaseChatMessageHistory
from langchain.tools import BaseTool, StructuredTool
from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.tools import BaseTool, StructuredTool
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from codeinterpreterapi.chains import (
aget_file_modifications,
Expand Down Expand Up @@ -114,7 +114,8 @@ def _tools(self, additional_tools: list[BaseTool]) -> list[BaseTool]:
"Variables are preserved between runs. "
+ (
(
f"You can use all default python packages specifically also these: {settings.CUSTOM_PACKAGES}"
"You can use all default python packages "
f"specifically also these: {settings.CUSTOM_PACKAGES}"
)
if settings.CUSTOM_PACKAGES
else ""
Expand All @@ -127,7 +128,7 @@ def _tools(self, additional_tools: list[BaseTool]) -> list[BaseTool]:

def _choose_llm(self) -> BaseChatModel:
if (
settings.AZURE_API_KEY
settings.AZURE_OPENAI_API_KEY
and settings.AZURE_API_BASE
and settings.AZURE_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
Expand All @@ -138,12 +139,13 @@ def _choose_llm(self) -> BaseChatModel:
base_url=settings.AZURE_API_BASE,
api_version=settings.AZURE_API_VERSION,
azure_deployment=settings.AZURE_DEPLOYMENT_NAME,
api_key=settings.AZURE_API_KEY,
api_key=settings.AZURE_OPENAI_API_KEY,
max_retries=settings.MAX_RETRY,
timeout=settings.REQUEST_TIMEOUT,
) # type: ignore
if settings.OPENAI_API_KEY:
self.log("Using Chat OpenAI")
from langchain_openai import ChatOpenAI

return ChatOpenAI(
model=settings.MODEL,
api_key=settings.OPENAI_API_KEY,
Expand All @@ -152,6 +154,8 @@ def _choose_llm(self) -> BaseChatModel:
max_retries=settings.MAX_RETRY,
) # type: ignore
if settings.ANTHROPIC_API_KEY:
from langchain_anthropic import ChatAnthropic # type: ignore

if "claude" not in settings.MODEL:
print("Please set the claude model in the settings.")
self.log("Using Chat Anthropic")
Expand All @@ -172,7 +176,7 @@ def _choose_agent(self) -> BaseSingleActionAgent:
MessagesPlaceholder(variable_name="chat_history")
],
)
if isinstance(self.llm, ChatOpenAI)
if isinstance(self.llm, ChatOpenAI) or isinstance(self.llm, AzureChatOpenAI)
else ConversationalChatAgent.from_llm_and_tools(
llm=self.llm,
tools=self.tools,
Expand Down

0 comments on commit ed61a2b

Please sign in to comment.