From 09a7b592838ae0d32fa672363cc1d0d42b135ce9 Mon Sep 17 00:00:00 2001
From: SimFG
Date: Tue, 6 Jun 2023 10:06:19 +0800
Subject: [PATCH] Langchain track token usage (#409)

Signed-off-by: SimFG
---
 gptcache/adapter/langchain_models.py    | 174 +++++++++++++++++++++---
 gptcache/manager/vector_data/usearch.py |   3 +-
 tests/requirements.txt                  |   2 +
 3 files changed, 156 insertions(+), 23 deletions(-)

diff --git a/gptcache/adapter/langchain_models.py b/gptcache/adapter/langchain_models.py
index 9d1c45ee..3bc47daa 100644
--- a/gptcache/adapter/langchain_models.py
+++ b/gptcache/adapter/langchain_models.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Any
+from typing import Optional, List, Any, Mapping
 
 from gptcache.adapter.adapter import adapt
 from gptcache.manager.scalar_data.base import Answer, DataType
@@ -12,9 +12,20 @@
 from pydantic import BaseModel
 from langchain.llms.base import LLM
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import BaseMessage, LLMResult, AIMessage, ChatGeneration, ChatResult
+from langchain.schema import (
+    BaseMessage,
+    LLMResult,
+    AIMessage,
+    ChatGeneration,
+    ChatResult,
+)
+from langchain.callbacks.manager import (
+    Callbacks,
+    CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
+)
 
 
+# pylint: disable=protected-access
 class LangChainLLMs(LLM, BaseModel):
     """LangChain LLM Wrapper.
@@ -39,13 +50,30 @@ class LangChainLLMs(LLM, BaseModel):
 
     llm: Any
     session: Session = None
+    tmp_args: Any = None
 
     @property
     def _llm_type(self) -> str:
-        return "gptcache_llm"
+        return self.llm._llm_type
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return self.llm._identifying_params
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
-        session = self.session if "session" not in kwargs else kwargs.pop("session")
+    def __str__(self) -> str:
+        return str(self.llm)
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        _: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        session = (
+            self.session
+            if "session" not in self.tmp_args
+            else self.tmp_args.pop("session")
+        )
         return adapt(
             self.llm,
             _cache_data_convert,
@@ -53,11 +81,46 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str
             prompt=prompt,
             stop=stop,
             session=session,
-            **kwargs
+            **self.tmp_args,
         )
 
-    def __call__(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
-        return self._call(prompt=prompt, stop=stop, **kwargs)
+    async def _acall(self, prompt: str, stop: Optional[List[str]] = None,
+                     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None) -> str:
+        return super()._acall(prompt, stop=stop, run_manager=run_manager)
+
+    def generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs,
+    ) -> LLMResult:
+        self.tmp_args = kwargs
+        return super().generate(prompts, stop=stop, callbacks=callbacks)
+
+    async def agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs,
+    ) -> LLMResult:
+        self.tmp_args = kwargs
+        return await super().agenerate(prompts, stop=stop, callbacks=callbacks)
+
+    def __call__(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs,
+    ) -> str:
+        """Check Cache and run the LLM on the given prompt and input."""
+        return (
+            self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
+            .generations[0][0]
+            .text
+        )
 
 
 # pylint: disable=protected-access
@@ -88,9 +151,19 @@ def _llm_type(self) -> str:
 
     chat: Any
     session: Session = None
+    tmp_args: Any = None
 
-    def _generate(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
-        session = self.session if "session" not in kwargs else kwargs.pop("session")
+    def _generate(
+        self,
+        messages: Any,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        session = (
+            self.session
+            if "session" not in self.tmp_args
+            else self.tmp_args.pop("session")
+        )
         return adapt(
             self.chat._generate,
             _cache_msg_data_convert,
@@ -98,11 +171,21 @@ def _generate(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
             messages=messages,
             stop=stop,
             session=session,
-            **kwargs
+            run_manager=run_manager,
+            **self.tmp_args,
         )
 
-    async def _agenerate(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, **kwargs) -> LLMResult:
-        session = self.session if "session" not in kwargs else kwargs.pop("session")
+    async def _agenerate(
+        self,
+        messages: List[List[BaseMessage]],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        session = (
+            self.session
+            if "session" not in self.tmp_args
+            else self.tmp_args.pop("session")
+        )
         return adapt(
             self.chat._agenerate,
             _cache_msg_data_convert,
@@ -110,31 +193,78 @@ async def _agenerate(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, **kwargs) -> LLMResult:
             messages=messages,
             stop=stop,
             session=session,
-            **kwargs
+            run_manager=run_manager,
+            **self.tmp_args,
         )
 
+    def generate(
+        self,
+        messages: List[List[BaseMessage]],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs,
+    ) -> LLMResult:
+        self.tmp_args = kwargs
+        return super().generate(messages, stop=stop, callbacks=callbacks)
+
+    async def agenerate(
+        self,
+        messages: List[List[BaseMessage]],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs,
+    ) -> LLMResult:
+        self.tmp_args = kwargs
+        return await super().agenerate(messages, stop=stop, callbacks=callbacks)
+
+    @property
+    def _identifying_params(self):
+        return self.chat._identifying_params
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        return self.chat._combine_llm_outputs(llm_outputs)
+
+    def get_num_tokens(self, text: str) -> int:
+        return self.chat.get_num_tokens(text)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        return self.chat.get_num_tokens_from_messages(messages)
+
     def __call__(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
-        res = self._generate(messages=messages, stop=stop, **kwargs)
-        return res.generations[0].message
+        generation = self.generate([messages], stop=stop, **kwargs).generations[0][0]
+        if isinstance(generation, ChatGeneration):
+            return generation.message
+        else:
+            raise ValueError("Unexpected generation type")
 
 
 def _cache_data_convert(cache_data):
     return cache_data
 
 
-def _update_cache_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
+def _update_cache_callback(
+    llm_data, update_cache_func, *args, **kwargs
+):  # pylint: disable=unused-argument
     update_cache_func(Answer(llm_data, DataType.STR))
     return llm_data
 
 
 def _cache_msg_data_convert(cache_data):
-    llm_res = ChatResult(generations=[ChatGeneration(text="",
-                                                     generation_info=None,
-                                                     message=AIMessage(content=cache_data, additional_kwargs={}))],
-                         llm_output=None)
+    llm_res = ChatResult(
+        generations=[
+            ChatGeneration(
+                text="",
+                generation_info=None,
+                message=AIMessage(content=cache_data, additional_kwargs={}),
+            )
+        ],
+        llm_output=None,
+    )
     return llm_res
 
 
-def _update_cache_msg_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
+def _update_cache_msg_callback(
+    llm_data, update_cache_func, *args, **kwargs
+):  # pylint: disable=unused-argument
     update_cache_func(llm_data.generations[0].text)
     return llm_data
diff --git a/gptcache/manager/vector_data/usearch.py b/gptcache/manager/vector_data/usearch.py
index f461c080..303d8089 100644
--- a/gptcache/manager/vector_data/usearch.py
+++ b/gptcache/manager/vector_data/usearch.py
@@ -9,6 +9,7 @@
 import_usearch()
 
 from usearch.index import Index  # pylint: disable=C0413
+from usearch.compiled import MetricKind  # pylint: disable=C0413
 
 
 class USearch(VectorBase):
@@ -48,7 +49,7 @@ def __init__(
         self._top_k = top_k
         self._index = Index(
             ndim=self._dimension,
-            metric=metric,
+            metric=getattr(MetricKind, metric.lower().capitalize()),
             dtype=dtype,
             connectivity=connectivity,
             expansion_add=expansion_add,
diff --git a/tests/requirements.txt b/tests/requirements.txt
index f706d145..605b90f2 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -24,3 +24,5 @@ typing_extensions<4.6.0
 stability-sdk
 grpcio==1.53.0
 protobuf==3.20.0
+milvus==2.2.8
+pymilvus==2.2.8
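--
Reviewer notes (illustrative, not part of the patch):

Routing __call__ through generate() means LangChain's callback machinery now
runs on every invocation, which is what makes token-usage tracking work. A
minimal sketch of the intended behavior, assuming the OPENAI_API_KEY
environment variable is set and using GPTCache's default cache.init() setup
(the prompt and model settings are illustrative):

    from langchain.callbacks import get_openai_callback
    from langchain.llms import OpenAI

    from gptcache import cache
    from gptcache.adapter.langchain_models import LangChainLLMs

    cache.init()            # default exact-match cache
    cache.set_openai_key()  # reads OPENAI_API_KEY

    llm = LangChainLLMs(llm=OpenAI(temperature=0))

    with get_openai_callback() as cb:
        print(llm("Tell me a joke"))
        # Token counts come from the OpenAI response on a cache miss;
        # a cache hit skips the API call, so the count stays 0.
        print(f"total tokens: {cb.total_tokens}")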
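The chat wrapper gets the same generate()/agenerate() pass-through, plus
delegation of _combine_llm_outputs() and the token-counting helpers to the
wrapped model, so the same tracking works for chat models. A sketch along the
same lines, assuming the chat wrapper class in this module is LangChainChat:

    from langchain.callbacks import get_openai_callback
    from langchain.chat_models import ChatOpenAI
    from langchain.schema import HumanMessage

    from gptcache.adapter.langchain_models import LangChainChat

    chat = LangChainChat(chat=ChatOpenAI(temperature=0))

    with get_openai_callback() as cb:
        msg = chat([HumanMessage(content="Tell me a joke")])  # returns an AIMessage
        print(msg.content)
        print(f"total tokens: {cb.total_tokens}")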
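The usearch change resolves the string metric name to a MetricKind enum member
instead of passing the raw string to Index. A quick illustration of the
mapping, assuming usearch is installed and that MetricKind member names are
capitalized metric strings (e.g. "cos" -> MetricKind.Cos):

    from usearch.compiled import MetricKind
    from usearch.index import Index

    metric = "cos"
    kind = getattr(MetricKind, metric.lower().capitalize())  # MetricKind.Cos
    index = Index(ndim=64, metric=kind)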