LangChainLLMs cache_obj update (#508)
* [refactor]: Explicit cache_obj input for langchain_models, allowing a custom cache_obj to be passed in once rather than again and again during inference (see the usage sketch below)

* [tests]: Updated tests for langchain_models
keenborder786 authored Aug 7, 2023
1 parent 4a63925 commit fdf1630
Showing 2 changed files with 15 additions and 9 deletions.
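
For context, a minimal usage sketch of what this change enables (not part of the commit). Assumptions: a dedicated Cache initialized with get_prompt as the pre-embedding function, mirroring the setup in the unit tests, and a placeholder prompt.

from langchain.llms import OpenAI

from gptcache.adapter.langchain_models import LangChainLLMs
from gptcache.core import Cache
from gptcache.processor.pre import get_prompt  # assumed pre-embedding helper, as in the tests

# Dedicated cache instance, initialized once (assumed setup).
llm_cache = Cache()
llm_cache.init(pre_embedding_func=get_prompt)

# Before: the cache had to be repeated on every call.
#   llm = LangChainLLMs(llm=OpenAI(model_name="text-ada-001"))
#   answer = llm(prompt="What is GitHub?", cache_obj=llm_cache)

# After: bind the cache once at construction; calls no longer repeat it.
llm = LangChainLLMs(llm=OpenAI(model_name="text-ada-001"), cache_obj=llm_cache)
answer = llm(prompt="What is GitHub?")

Because the new field defaults to the global cache, callers that never pass cache_obj keep the previous behaviour.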
6 changes: 6 additions & 0 deletions gptcache/adapter/langchain_models.py
@@ -4,6 +4,7 @@
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.session import Session
 from gptcache.utils import import_pydantic, import_langchain
+from gptcache.core import Cache,cache
 
 import_pydantic()
 import_langchain()
@@ -50,6 +51,7 @@ class LangChainLLMs(LLM, BaseModel):
 
     llm: Any
     session: Session = None
+    cache_obj: Cache = cache
     tmp_args: Any = None
 
     @property
@@ -80,6 +82,7 @@ def _call(
             _update_cache_callback,
             prompt=prompt,
             stop=stop,
+            cache_obj=self.cache_obj,
             session=session,
             **self.tmp_args,
         )
@@ -151,6 +154,7 @@ def _llm_type(self) -> str:
 
     chat: Any
     session: Session = None
+    cache_obj: Cache = cache
     tmp_args: Any = None
 
     def _generate(
@@ -170,6 +174,7 @@ def _generate(
             _update_cache_msg_callback,
             messages=messages,
             stop=stop,
+            cache_obj=self.cache_obj,
             session=session,
             run_manager=run_manager,
             **self.tmp_args,
@@ -193,6 +198,7 @@ async def _agenerate(
             _update_cache_msg_callback,
             messages=messages,
             stop=stop,
+            cache_obj=self.cache_obj,
             session=session,
             run_manager=run_manager,
             **self.tmp_args,
18 changes: 9 additions & 9 deletions tests/unit_tests/adapter/test_langchain_models.py
@@ -30,7 +30,7 @@ def test_langchain_llms():
 
     os.environ["OPENAI_API_KEY"] = "API"
     langchain_openai = OpenAI(model_name="text-ada-001")
-    llm = LangChainLLMs(llm=langchain_openai)
+    llm = LangChainLLMs(llm=langchain_openai,cache_obj=llm_cache)
     assert str(langchain_openai) == str(llm)
 
     with patch("openai.Completion.create") as mock_create:
@@ -53,10 +53,10 @@ def test_langchain_llms():
             }
         }
 
-        answer = llm(prompt=question, cache_obj=llm_cache)
+        answer = llm(prompt=question)
         assert expect_answer == answer
 
-        answer = llm(prompt=question, cache_obj=llm_cache)
+        answer = llm(prompt=question)
         assert expect_answer == answer
 
 
@@ -77,7 +77,7 @@ def test_langchain_chats():
 
     os.environ["OPENAI_API_KEY"] = "API"
     langchain_openai = ChatOpenAI(temperature=0)
-    chat = LangChainChat(chat=langchain_openai)
+    chat = LangChainChat(chat=langchain_openai,cache_obj=llm_cache)
 
     assert chat.get_num_tokens("hello") == langchain_openai.get_num_tokens("hello")
     assert chat.get_num_tokens_from_messages(messages=[HumanMessage(content="test_langchain_chats")]) \
@@ -104,7 +104,7 @@ def test_langchain_chats():
             }
         }
 
-        answer = chat(messages=question, cache_obj=llm_cache)
+        answer = chat(messages=question)
         assert answer == _cache_msg_data_convert(msg).generations[0].message
 
     with patch("openai.ChatCompletion.acreate") as mock_create:
@@ -128,16 +128,16 @@ def test_langchain_chats():
             }
         }
 
-        answer = asyncio.run(chat.agenerate([question2], cache_obj=llm_cache))
+        answer = asyncio.run(chat.agenerate([question2]))
         assert answer.generations[0][0].text == _cache_msg_data_convert(msg).generations[0].text
 
-        answer = chat(messages=question, cache_obj=llm_cache)
+        answer = chat(messages=question)
         assert answer == _cache_msg_data_convert(msg).generations[0].message
 
-        answer = asyncio.run(chat.agenerate([question], cache_obj=llm_cache))
+        answer = asyncio.run(chat.agenerate([question]))
         assert answer.generations[0][0].text == _cache_msg_data_convert(msg).generations[0].text
 
-        answer = asyncio.run(chat.agenerate([question2], cache_obj=llm_cache))
+        answer = asyncio.run(chat.agenerate([question2]))
         assert answer.generations[0][0].text == _cache_msg_data_convert(msg).generations[0].text
 
 
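
The chat adapter gains the same field, so a custom cache can likewise be bound once to LangChainChat. A matching sketch under the same assumptions, with get_messages_last_content assumed as the chat pre-embedding helper:

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

from gptcache.adapter.langchain_models import LangChainChat
from gptcache.core import Cache
from gptcache.processor.pre import get_messages_last_content  # assumed helper for chat messages

# Dedicated cache for chat models (assumed setup).
chat_cache = Cache()
chat_cache.init(pre_embedding_func=get_messages_last_content)

# cache_obj is set once here and forwarded internally to _generate/_agenerate.
chat = LangChainChat(chat=ChatOpenAI(temperature=0), cache_obj=chat_cache)
answer = chat(messages=[HumanMessage(content="What is GitHub?")])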
