From fdf1630688f25e6da4e3a543000231e0cf2a7a06 Mon Sep 17 00:00:00 2001
From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com>
Date: Mon, 7 Aug 2023 12:08:11 +0500
Subject: [PATCH] LangChainLLMs cache_obj update (#508)

* [refactor]: Explicit input for cache_obj in langchain_models, allowing a
  custom cache_obj to be passed once rather than on every inference call

* [tests]: Updated tests for langchain_models
---
 gptcache/adapter/langchain_models.py |  6 ++++++
 .../adapter/test_langchain_models.py | 18 +++++++++---------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/gptcache/adapter/langchain_models.py b/gptcache/adapter/langchain_models.py
index 945d9eb2..4113c543 100644
--- a/gptcache/adapter/langchain_models.py
+++ b/gptcache/adapter/langchain_models.py
@@ -4,6 +4,7 @@
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.session import Session
 from gptcache.utils import import_pydantic, import_langchain
+from gptcache.core import Cache,cache
 
 import_pydantic()
 import_langchain()
@@ -50,6 +51,7 @@ class LangChainLLMs(LLM, BaseModel):
 
     llm: Any
     session: Session = None
+    cache_obj: Cache = cache
     tmp_args: Any = None
 
     @property
@@ -80,6 +82,7 @@ def _call(
             _update_cache_callback,
             prompt=prompt,
             stop=stop,
+            cache_obj=self.cache_obj,
             session=session,
             **self.tmp_args,
         )
@@ -151,6 +154,7 @@ def _llm_type(self) -> str:
 
     chat: Any
     session: Session = None
+    cache_obj: Cache = cache
     tmp_args: Any = None
 
     def _generate(
@@ -170,6 +174,7 @@ def _generate(
             _update_cache_msg_callback,
             messages=messages,
             stop=stop,
+            cache_obj=self.cache_obj,
             session=session,
             run_manager=run_manager,
             **self.tmp_args,
@@ -193,6 +198,7 @@ async def _agenerate(
             _update_cache_msg_callback,
             messages=messages,
             stop=stop,
+            cache_obj=self.cache_obj,
             session=session,
             run_manager=run_manager,
             **self.tmp_args,
diff --git a/tests/unit_tests/adapter/test_langchain_models.py b/tests/unit_tests/adapter/test_langchain_models.py
index 8cb40f5d..f7e8e58e 100644
--- a/tests/unit_tests/adapter/test_langchain_models.py
+++ b/tests/unit_tests/adapter/test_langchain_models.py
@@ -30,7 +30,7 @@ def test_langchain_llms():
 
     os.environ["OPENAI_API_KEY"] = "API"
     langchain_openai = OpenAI(model_name="text-ada-001")
-    llm = LangChainLLMs(llm=langchain_openai)
+    llm = LangChainLLMs(llm=langchain_openai,cache_obj=llm_cache)
     assert str(langchain_openai) == str(llm)
 
     with patch("openai.Completion.create") as mock_create:
@@ -53,10 +53,10 @@
             }
         }
 
-        answer = llm(prompt=question, cache_obj=llm_cache)
+        answer = llm(prompt=question)
         assert expect_answer == answer
 
-        answer = llm(prompt=question, cache_obj=llm_cache)
+        answer = llm(prompt=question)
         assert expect_answer == answer
@@ -77,7 +77,7 @@ def test_langchain_chats():
 
     os.environ["OPENAI_API_KEY"] = "API"
     langchain_openai = ChatOpenAI(temperature=0)
-    chat = LangChainChat(chat=langchain_openai)
+    chat = LangChainChat(chat=langchain_openai,cache_obj=llm_cache)
 
     assert chat.get_num_tokens("hello") == langchain_openai.get_num_tokens("hello")
     assert chat.get_num_tokens_from_messages(messages=[HumanMessage(content="test_langchain_chats")]) \
@@ -104,7 +104,7 @@
             }
         }
 
-        answer = chat(messages=question, cache_obj=llm_cache)
+        answer = chat(messages=question)
         assert answer == _cache_msg_data_convert(msg).generations[0].message
 
     with patch("openai.ChatCompletion.acreate") as mock_create:
@@ -128,16 +128,16 @@
             }
         }
 
-        answer = asyncio.run(chat.agenerate([question2], cache_obj=llm_cache))
+        answer = asyncio.run(chat.agenerate([question2]))
         assert answer.generations[0][0].text == _cache_msg_data_convert(msg).generations[0].text
 
-    answer = chat(messages=question, cache_obj=llm_cache)
+    answer = chat(messages=question)
     assert answer == _cache_msg_data_convert(msg).generations[0].message
 
-    answer = asyncio.run(chat.agenerate([question], cache_obj=llm_cache))
+    answer = asyncio.run(chat.agenerate([question]))
     assert answer.generations[0][0].text == _cache_msg_data_convert(msg).generations[0].text
 
-    answer = asyncio.run(chat.agenerate([question2], cache_obj=llm_cache))
+    answer = asyncio.run(chat.agenerate([question2]))
     assert answer.generations[0][0].text == _cache_msg_data_convert(msg).generations[0].text
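
Usage sketch for the change above (a minimal sketch, not part of the patch): it shows cache_obj being bound once at construction instead of on every call. It assumes GPTCache's top-level Cache class and the get_prompt pre-processor; the model name, prompts, and API key are illustrative placeholders.

    import os

    from langchain.llms import OpenAI

    from gptcache import Cache
    from gptcache.processor.pre import get_prompt
    from gptcache.adapter.langchain_models import LangChainLLMs

    os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key, illustrative only

    # Build a dedicated cache for LLM calls; defaults are used for the data manager.
    llm_cache = Cache()
    llm_cache.init(pre_embedding_func=get_prompt)

    # Before this patch, cache_obj had to accompany every call:
    #   llm = LangChainLLMs(llm=OpenAI(model_name="text-ada-001"))
    #   answer = llm(prompt="What is GitHub?", cache_obj=llm_cache)

    # After this patch, cache_obj is passed once at construction:
    llm = LangChainLLMs(llm=OpenAI(model_name="text-ada-001"), cache_obj=llm_cache)
    answer = llm(prompt="What is GitHub?")  # first call populates the cache
    cached = llm(prompt="What is GitHub?")  # repeated call can be served from the cache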