Skip to content

Commit

Permalink
made it possible to pass additional tools to CodeInterpreterSession
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubGetMe authored and Valian committed Jul 18, 2023
1 parent 44422c1 commit 12a62f3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
17 changes: 12 additions & 5 deletions codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional
from codeboxapi import CodeBox # type: ignore
from codeboxapi.schema import CodeBoxOutput # type: ignore
from langchain.tools import StructuredTool
from langchain.tools import StructuredTool, BaseTool
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.prompts.chat import MessagesPlaceholder
Expand All @@ -20,10 +20,16 @@


class CodeInterpreterSession:
def __init__(self, model=None, openai_api_key=settings.OPENAI_API_KEY, verbose=settings.VERBOSE) -> None:
def __init__(
self,
model=None,
openai_api_key=settings.OPENAI_API_KEY,
verbose=settings.VERBOSE,
tools: list[BaseTool] = None
) -> None:
self.codebox = CodeBox()
self.verbose = verbose
self.tools: list[StructuredTool] = self._tools()
self.tools: list[BaseTool] = self._tools(tools)
self.llm: BaseChatModel = self._llm(model, openai_api_key)
self.agent_executor: AgentExecutor = self._agent_executor()
self.input_files: list[File] = []
Expand All @@ -32,8 +38,9 @@ def __init__(self, model=None, openai_api_key=settings.OPENAI_API_KEY, verbose=s
async def astart(self) -> None:
await self.codebox.astart()

def _tools(self) -> list[StructuredTool]:
return [
def _tools(self, additional_tools: list[BaseTool] = None) -> list[BaseTool]:
additional_tools = additional_tools or []
return additional_tools + [
StructuredTool(
name="python",
description=
Expand Down
49 changes: 49 additions & 0 deletions examples/use_additional_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
The exciting part about this example is
that the code interpreter has internet access
so it can download the bitcoin chart from yahoo finance
and plot it for you
"""
import csv
import io
from datetime import datetime
from typing import Any

from langchain.tools import tool, BaseTool

from codeinterpreterapi import CodeInterpreterSession


class ExampleKnowledgeBaseTool(BaseTool):
name = "salary_database"
description = "Use to get salary data of company employees"

def _run(self, *args, **kwargs):
raise NotImplementedError()

async def _arun(self, *args, **kwargs: Any) -> Any:
f = io.StringIO()
writer = csv.writer(f)
writer.writerow(['month', 'employee', 'salary'])
writer.writerow(['march 2022', 'Jan', '1200'])
writer.writerow(['march 2022', 'Ola', '1500'])
writer.writerow(['april 2022', 'Jan', '1800'])
writer.writerow(['april 2022', 'Ola', '2000'])
return f.getvalue()


async def main():
async with CodeInterpreterSession(tools=[ExampleKnowledgeBaseTool()]) as session:
response = await session.generate_response(
f"Plot chart of company employee salaries"
)

print("AI: ", response.content)
for file in response.files:
file.show_image()


if __name__ == "__main__":
import asyncio

asyncio.run(main())

0 comments on commit 12a62f3

Please sign in to comment.