
Commit 3339e52
Fix the langchain chat pydantic bug (#538)
Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG authored Sep 14, 2023
1 parent d19bf74 commit 3339e52
Showing 4 changed files with 15 additions and 9 deletions.
18 changes: 9 additions & 9 deletions gptcache/adapter/langchain_models.py
@@ -1,10 +1,10 @@
 from typing import Optional, List, Any, Mapping
 
 from gptcache.adapter.adapter import adapt, aadapt
+from gptcache.core import cache
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.session import Session
 from gptcache.utils import import_pydantic, import_langchain
-from gptcache.core import Cache,cache
 
 import_pydantic()
 import_langchain()
@@ -51,7 +51,6 @@ class LangChainLLMs(LLM, BaseModel):

     llm: Any
     session: Session = None
-    cache_obj: Cache = cache
     tmp_args: Any = None
 
     @property
@@ -76,13 +75,14 @@ def _call(
if "session" not in self.tmp_args
else self.tmp_args.pop("session")
)
cache_obj = self.tmp_args.pop("cache_obj", cache)
return adapt(
self.llm,
_cache_data_convert,
_update_cache_callback,
prompt=prompt,
stop=stop,
cache_obj=self.cache_obj,
cache_obj=cache_obj,
session=session,
**self.tmp_args,
)
@@ -153,9 +153,8 @@ def _llm_type(self) -> str:
return "gptcache_llm_chat"

chat: Any
session: Session = None
cache_obj: Cache = cache
tmp_args: Any = None
session: Optional[Session] = None
tmp_args: Optional[Any] = None

def _generate(
self,
@@ -168,13 +167,14 @@ def _generate(
if "session" not in self.tmp_args
else self.tmp_args.pop("session")
)
cache_obj = self.tmp_args.pop("cache_obj", cache)
return adapt(
self.chat._generate,
_cache_msg_data_convert,
_update_cache_msg_callback,
messages=messages,
stop=stop,
cache_obj=self.cache_obj,
cache_obj=cache_obj,
session=session,
run_manager=run_manager,
**self.tmp_args,
@@ -191,14 +191,14 @@ async def _agenerate(
if "session" not in self.tmp_args
else self.tmp_args.pop("session")
)

cache_obj = self.tmp_args.pop("cache_obj", cache)
return await aadapt(
self.chat._agenerate,
_cache_msg_data_convert,
_update_cache_msg_callback,
messages=messages,
stop=stop,
cache_obj=self.cache_obj,
cache_obj=cache_obj,
session=session,
run_manager=run_manager,
**self.tmp_args,
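
For context, a minimal sketch of the failure this commit sidesteps. Under pydantic v1 (assumed here as the version in play for langchain's chat models at the time), declaring a field whose annotation is an arbitrary class on a BaseModel that does not set arbitrary_types_allowed raises at class-definition time. The Cache stand-in below is hypothetical; the actual fix simply drops the cache_obj field and pops it from tmp_args per call, defaulting to the global cache.

    # Hypothetical reproduction of the pydantic failure mode; assumes pydantic
    # v1 semantics and a model config that does not allow arbitrary types.
    from pydantic import BaseModel


    class Cache:
        """Stand-in for gptcache.core.Cache, a plain (non-pydantic) class."""


    try:
        class BrokenChat(BaseModel):
            # pydantic v1 has no validator for an arbitrary class, so
            # evaluating this class body raises RuntimeError.
            cache_obj: Cache = Cache()
    except RuntimeError as exc:
        print(f"pydantic rejects the field: {exc}")
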
2 changes: 2 additions & 0 deletions gptcache/adapter/openai.py
@@ -117,6 +117,8 @@ def cache_data_convert(cache_data):
                 saved_token = [input_token, output_token]
             else:
                 saved_token = [0, 0]
+            if kwargs.get("stream", False):
+                return _construct_stream_resp_from_cache(cache_data, saved_token)
             return _construct_resp_from_cache(cache_data, saved_token)
 
         kwargs = cls.fill_base_args(**kwargs)
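
A usage sketch for the streaming fix (the call shape follows gptcache's openai adapter; the chunk fields are an assumption based on OpenAI's pre-1.0 streaming format): with this change, a cache hit on a stream=True request is replayed as an iterable of delta chunks rather than a single response object, so streaming consumers behave the same on hits and misses.

    # Minimal sketch; assumes a default cache setup and a valid OpenAI key.
    from gptcache import cache
    from gptcache.adapter import openai

    cache.init()

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "what is github?"}],
        stream=True,
    )
    # On a cache hit this loop consumes the chunks built by
    # _construct_stream_resp_from_cache; on a miss, the live OpenAI stream.
    for chunk in response:
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="")
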
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -16,6 +16,7 @@ pytest-sugar==0.9.5
 pytest-parallel
 psycopg2-binary
 transformers==4.29.2
+anyio==3.6.2
 torch
 mock
 pexpect
3 changes: 3 additions & 0 deletions tests/unit_tests/test_client.py
@@ -1,5 +1,8 @@
 from unittest.mock import patch, Mock
 
+from gptcache.utils import import_httpx
+
+import_httpx()
 from gptcache.client import Client
 
 
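
The test change follows gptcache's lazy-import convention: calling import_httpx() makes sure the optional httpx dependency is importable before gptcache.client, which needs it, is imported. Below is a rough sketch of what such a helper does; the pip-install fallback is an assumption about gptcache.utils' internals, not its exact code.

    import importlib.util
    import subprocess
    import sys


    def import_httpx_sketch() -> None:
        # Install httpx on demand if it is missing, so the subsequent
        # import of gptcache.client cannot fail on the optional dependency.
        if importlib.util.find_spec("httpx") is None:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "httpx"])


    import_httpx_sketch()
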
