Commit

Merge branch 'dev'
Conflicts:
	codeinterpreterapi/session.py
	pyproject.toml
shroominic committed Aug 2, 2023
2 parents c10cc2c + 977252a commit 4bb4006
Showing 21 changed files with 578 additions and 2,571 deletions.
9 changes: 6 additions & 3 deletions .env.example
@@ -1,3 +1,6 @@
OPENAI_API_KEY= # your openai api key (required)
CODEBOX_API_KEY= # your codebox api key (optional, required for production)
VERBOSE=False # set to True to enable verbose logging
# (required)
OPENAI_API_KEY=
# (optional, required for production)
# CODEBOX_API_KEY=
# (set True to enable logging)
VERBOSE=False
5 changes: 4 additions & 1 deletion README.md
@@ -19,9 +19,12 @@ You can run everything local except the LLM using your own OpenAI API Key.
Get your OpenAI API Key [here](https://platform.openai.com/account/api-keys) and install the package.

```bash
pip install codeinterpreterapi
pip install "codeinterpreterapi[all]"
```

Everything needed for local experiments is installed with the `all` extra.
For deployments, you can use `pip install codeinterpreterapi` instead, which does not install the additional dependencies.

## Usage

Make sure to set the `OPENAI_API_KEY` environment variable (or use a `.env` file).
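
This commit does not change the README's usage example itself; a minimal sketch of how a session might be driven, assuming the session exposes `astart`, `generate_response`, and `astop` coroutines (only `CodeInterpreterSession`, `File`, and the new `CodeInterpreterResponse.show()` appear in the changed files):

```python
import asyncio

from dotenv import load_dotenv
from codeinterpreterapi import CodeInterpreterSession


async def main() -> None:
    load_dotenv()  # picks up OPENAI_API_KEY (and optionally CODEBOX_API_KEY) from .env

    session = CodeInterpreterSession()
    await session.astart()  # assumed start coroutine, not shown in this diff
    try:
        response = await session.generate_response(
            "Plot y = x**2 for x from 0 to 10"
        )
        response.show()  # prints the answer and displays any returned image files
    finally:
        await session.astop()  # assumed stop coroutine, not shown in this diff


if __name__ == "__main__":
    asyncio.run(main())
```
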
2 changes: 1 addition & 1 deletion codeinterpreterapi/__init__.py
@@ -1,2 +1,2 @@
from codeinterpreterapi.session import CodeInterpreterSession
from codeinterpreterapi.schema import File
from codeinterpreterapi.schema import File
1 change: 1 addition & 0 deletions codeinterpreterapi/agents/__init__.py
@@ -0,0 +1 @@
from .functions_agent import OpenAIFunctionsAgent
2 changes: 2 additions & 0 deletions codeinterpreterapi/agents/custom_agent.py
@@ -0,0 +1,2 @@
# TODO: override some methods of the ConversationalAgent class
# to improve the agent's performance
3 changes: 3 additions & 0 deletions codeinterpreterapi/chains/__init__.py
@@ -0,0 +1,3 @@
from .modifications_check import get_file_modifications
from .rm_dl_link import remove_download_link
from .extract_code import extract_python_code
45 changes: 45 additions & 0 deletions codeinterpreterapi/chains/extract_code.py
@@ -0,0 +1,45 @@
import json
from typing import List, Optional

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.schema import AIMessage, OutputParserException

# from codeinterpreterapi.prompts import extract_code_prompt


async def extract_python_code(
text: str,
llm: BaseLanguageModel,
retry: int = 2,
) -> Optional[str]:
pass


async def test():
llm = ChatAnthropic(model="claude-1.3") # type: ignore

code = \
"""
import matplotlib.pyplot as plt
x = list(range(1, 11))
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
plt.plot(x, y, marker='o')
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Data Plot')
plt.show()
"""

print(await extract_python_code(code, llm))


if __name__ == "__main__":
import asyncio, dotenv
dotenv.load_dotenv()

asyncio.run(test())
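
In this commit `extract_python_code` is only a stub (`pass`) and `extract_code_prompt` is still commented out; a minimal sketch of one possible implementation, reusing the open-codeblock-plus-stop pattern that `get_file_modifications` switches to below; the prompt wording and retry handling here are assumptions, not part of the commit:

```python
from typing import Optional

from langchain.base_language import BaseLanguageModel

# Hypothetical prompt text; the real `extract_code_prompt` is not yet wired up in this commit.
_EXTRACT_PROMPT = (
    "Extract the python code from the text below and answer with only that code.\n\n"
    "{text}\n\n"
    "Answer:\n"
    "```python\n"
)


async def extract_python_code(
    text: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> Optional[str]:
    if retry < 1:
        return None
    # The completion continues inside the open codeblock and is cut off at the closing fence.
    result = await llm.apredict(_EXTRACT_PROMPT.format(text=text), stop="```")
    code = result.strip()
    if not code:
        return await extract_python_code(text, llm, retry=retry - 1)
    return code
```
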
63 changes: 32 additions & 31 deletions codeinterpreterapi/chains/modifications_check.py
@@ -3,9 +3,10 @@

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.schema import AIMessage, OutputParserException

from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt
from codeinterpreterapi.prompts import determine_modifications_prompt


async def get_file_modifications(
@@ -15,44 +16,44 @@ async def get_file_modifications(
) -> Optional[List[str]]:
if retry < 1:
return None
messages = determine_modifications_prompt.format_prompt(code=code).to_messages()
message = await llm.apredict_messages(messages, functions=[determine_modifications_function])

if not isinstance(message, AIMessage):
raise OutputParserException("Expected an AIMessage")
prompt = determine_modifications_prompt.format(code=code)

function_call = message.additional_kwargs.get("function_call", None)
result = await llm.apredict(prompt, stop="```")

if function_call is None:

try:
result = json.loads(result)
except json.JSONDecodeError:
result = ""
if not result or not isinstance(result, dict) or "modifications" not in result:
return await get_file_modifications(code, llm, retry=retry - 1)
else:
function_call = json.loads(function_call["arguments"])
return function_call["modifications"]

return result["modifications"]


async def test():
llm = ChatOpenAI(model="gpt-3.5") # type: ignore

code = """
import matplotlib.pyplot as plt
x = list(range(1, 11))
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
plt.plot(x, y, marker='o')
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Data Plot')
plt.show()
"""

llm = ChatAnthropic(model="claude-1.3") # type: ignore

code = \
"""
import matplotlib.pyplot as plt
x = list(range(1, 11))
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
plt.plot(x, y, marker='o')
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Data Plot')
plt.show()
"""

print(await get_file_modifications(code, llm))


if __name__ == "__main__":
import asyncio
from dotenv import load_dotenv
load_dotenv()
import asyncio, dotenv
dotenv.load_dotenv()

asyncio.run(test())
File renamed without changes.
2 changes: 1 addition & 1 deletion codeinterpreterapi/prompts/__init__.py
@@ -1,3 +1,3 @@
from .system_message import system_message as code_interpreter_system_message
from .modifications_check import determine_modifications_function, determine_modifications_prompt
from .modifications_check import determine_modifications_prompt
from .remove_dl_link import remove_dl_link_prompt
83 changes: 52 additions & 31 deletions codeinterpreterapi/prompts/modifications_check.py
@@ -1,35 +1,56 @@
from langchain.prompts import PromptTemplate

from langchain.schema import SystemMessage
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate


determine_modifications_prompt = ChatPromptTemplate(
determine_modifications_prompt = PromptTemplate(
input_variables=["code"],
messages=[
SystemMessage(
content="The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
"With changes it means creating new files or modifying exsisting ones.\n"
"Answer with a function call `determine_modifications` and list them inside.\n"
"If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
),
HumanMessagePromptTemplate.from_template("{code}"),
],
template=
"The user will input some code and you need to determine if the code makes any changes to the file system. \n"
"With changes it means creating new files or modifying exsisting ones.\n"
"Format your answer as JSON inside a codeblock with a list of filenames that are modified by the code.\n"
"If the code does not make any changes to the file system, return an empty list.\n\n"
"Determine modifications:\n"
"```python\n"
"import matplotlib.pyplot as plt\n"
"import numpy as np\n\n"
"t = np.arange(0.0, 4.0*np.pi, 0.1)\n\n"
"s = np.sin(t)\n\n"
"fig, ax = plt.subplots()\n\n"
"ax.plot(t, s)\n\n"
"ax.set(xlabel=\"time (s)\", ylabel=\"sin(t)\",\n"
" title=\"Simple Sin Wave\")\n"
"ax.grid()\n\n"
"plt.savefig(\"sin_wave.png\")\n"
"```\n\n"
"Answer:\n"
"```json\n"
"{{\n"
" \"modifications\": [\"sin_wave.png\"]\n"
"}}\n"
"```\n\n"
"Determine modifications:\n"
"```python\n"
"import matplotlib.pyplot as plt\n"
"import numpy as np\n\n"
"x = np.linspace(0, 10, 100)\n"
"y = x**2\n\n"
"plt.figure(figsize=(8, 6))\n"
"plt.plot(x, y)\n"
"plt.title(\"Simple Quadratic Function\")\n"
"plt.xlabel(\"x\")\n"
"plt.ylabel(\"y = x^2\")\n"
"plt.grid(True)\n"
"plt.show()\n"
"```\n\n"
"Answer:\n"
"```json\n"
"{{\n"
" \"modifications\": []\n"
"}}\n"
"```\n\n"
"Determine modifications:\n"
"```python\n"
"{code}\n"
"```\n\n"
"Answer:\n"
"```json\n",
)


determine_modifications_function = {
"name": "determine_modifications",
"description": "Based on code of the user determine if the code makes any changes to the file system. \n"
"With changes it means creating new files or modifying exsisting ones.\n",
"parameters": {
"type": "object",
"properties": {
"modifications": {
"type": "array",
"items": {"type": "string"},
"description": "The filenames that are modified by the code.",
},
},
"required": ["modifications"],
},
}
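
The `determine_modifications` function-call schema above is dropped in this commit; the new `PromptTemplate` instead ends inside an open JSON codeblock and relies on the updated `get_file_modifications` chain stopping generation at the closing fence. A small sketch of that interplay, where the completion string is illustrative rather than a recorded model output:

```python
import json

from codeinterpreterapi.prompts import determine_modifications_prompt

# The formatted prompt ends inside an open ```json fence, so the model continues
# with a JSON object and stop="```" (in get_file_modifications) cuts it at the closing fence.
prompt = determine_modifications_prompt.format(code='open("notes.txt", "w").write("hi")')

# Illustrative completion, e.g. what llm.apredict(prompt, stop="```") might return
completion = '{\n    "modifications": ["notes.txt"]\n}\n'
print(json.loads(completion)["modifications"])  # -> ['notes.txt']
```
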
2 changes: 1 addition & 1 deletion codeinterpreterapi/schema/file.py
@@ -68,7 +68,7 @@ def show_image(self):

# If the shell is ZMQInteractiveShell, it means we're in a Jupyter notebook or similar.
if shell == 'ZMQInteractiveShell':
from IPython.display import display
from IPython.display import display # type: ignore
display(img)
else:
# We're not in a Jupyter notebook.
5 changes: 5 additions & 0 deletions codeinterpreterapi/schema/response.py
@@ -17,6 +17,11 @@ class CodeInterpreterResponse(AIMessage):
# final_code: str = "" TODO: implement
# final_output: str = "" TODO: implement

def show(self):
print("AI: ", self.content)
for file in self.files:
file.show_image()

def __str__(self):
return self.content
