✨ add both sync/async versions
shroominic committed Aug 8, 2023
1 parent 856cd85 commit 31b239b
Showing 3 changed files with 48 additions and 12 deletions.
6 changes: 4 additions & 2 deletions codeinterpreterapi/chains/__init__.py
@@ -1,9 +1,11 @@
 from .extract_code import extract_python_code
-from .modifications_check import get_file_modifications
-from .rm_dl_link import remove_download_link
+from .modifications_check import aget_file_modifications, get_file_modifications
+from .rm_dl_link import aremove_download_link, remove_download_link

 __all__ = [
     "extract_python_code",
     "get_file_modifications",
+    "aget_file_modifications",
     "remove_download_link",
+    "aremove_download_link",
 ]
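
With these exports in place, the async variants are importable alongside the existing sync names. A minimal import sketch (not part of the diff itself):

from codeinterpreterapi.chains import (
    aget_file_modifications,
    aremove_download_link,
    get_file_modifications,
    remove_download_link,
)
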
29 changes: 25 additions & 4 deletions codeinterpreterapi/chains/modifications_check.py
@@ -7,7 +7,28 @@
 from codeinterpreterapi.prompts import determine_modifications_prompt


-async def get_file_modifications(
+def get_file_modifications(
+    code: str,
+    llm: BaseLanguageModel,
+    retry: int = 2,
+) -> Optional[List[str]]:
+    if retry < 1:
+        return None
+
+    prompt = determine_modifications_prompt.format(code=code)
+
+    result = llm.predict(prompt, stop="```")
+
+    try:
+        result = json.loads(result)
+    except json.JSONDecodeError:
+        result = ""
+    if not result or not isinstance(result, dict) or "modifications" not in result:
+        return get_file_modifications(code, llm, retry=retry - 1)
+    return result["modifications"]
+
+
+async def aget_file_modifications(
     code: str,
     llm: BaseLanguageModel,
     retry: int = 2,
@@ -24,12 +45,12 @@ async def get_file_modifications(
     except json.JSONDecodeError:
         result = ""
     if not result or not isinstance(result, dict) or "modifications" not in result:
-        return await get_file_modifications(code, llm, retry=retry - 1)
+        return await aget_file_modifications(code, llm, retry=retry - 1)
     return result["modifications"]


 async def test():
-    llm = ChatAnthropic(model="claude-1.3")  # type: ignore
+    llm = ChatAnthropic(model="claude-2")  # type: ignore

     code = """
 import matplotlib.pyplot as plt
@@ -45,7 +66,7 @@ async def test():
 plt.show()
 """

-    print(await get_file_modifications(code, llm))
+    print(get_file_modifications(code, llm))


 if __name__ == "__main__":
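
For context, a hedged usage sketch of the two variants added above: get_file_modifications blocks and can be called directly, while aget_file_modifications must be awaited. The ChatOpenAI model name and the sample code string below are illustrative assumptions, not taken from this commit.

import asyncio

from langchain.chat_models import ChatOpenAI

from codeinterpreterapi.chains import aget_file_modifications, get_file_modifications

llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # assumed model name, for illustration only
sample_code = "import pandas as pd\ndf = pd.read_csv('data.csv')\ndf.to_csv('clean.csv')"

# Sync path: plain function call, no event loop required.
print(get_file_modifications(sample_code, llm))

# Async path: awaitable coroutine, here driven by asyncio.run().
print(asyncio.run(aget_file_modifications(sample_code, llm)))
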
25 changes: 19 additions & 6 deletions codeinterpreterapi/chains/rm_dl_link.py
@@ -5,7 +5,22 @@
 from codeinterpreterapi.prompts import remove_dl_link_prompt


-async def remove_download_link(
+def remove_download_link(
+    input_response: str,
+    llm: BaseLanguageModel,
+) -> str:
+    messages = remove_dl_link_prompt.format_prompt(
+        input_response=input_response
+    ).to_messages()
+    message = llm.predict_messages(messages)
+
+    if not isinstance(message, AIMessage):
+        raise OutputParserException("Expected an AIMessage")
+
+    return message.content
+
+
+async def aremove_download_link(
     input_response: str,
     llm: BaseLanguageModel,
 ) -> str:
@@ -20,21 +35,19 @@ async def remove_download_link(
     return message.content


-async def test():
+def test():
     llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # type: ignore

     example = (
         "I have created the plot to your dataset.\n\n"
         "Link to the file [here](sandbox:/plot.png)."
     )
-    print(await remove_download_link(example, llm))
+    print(remove_download_link(example, llm))


 if __name__ == "__main__":
-    import asyncio
-
     from dotenv import load_dotenv

     load_dotenv()

-    asyncio.run(test())
+    test()

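The same sync/async split applies to the download-link removal chain. A rough usage sketch, with a made-up assistant response (the llm setup is assumed, as above):

import asyncio

from langchain.chat_models import ChatOpenAI

from codeinterpreterapi.chains import aremove_download_link, remove_download_link

llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # assumed model name, for illustration only
response = (
    "I have created the plot of your dataset.\n\n"
    "Link to the file [here](sandbox:/plot.png)."
)

# Sync variant: returns the response text with the sandbox link stripped.
print(remove_download_link(response, llm))

# Async variant: same behaviour, awaited.
print(asyncio.run(aremove_download_link(response, llm)))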