Skip to content

Commit

Permalink
Langchain track token usage (#409)
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG authored Jun 6, 2023
1 parent f5f02f0 commit 09a7b59
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 23 deletions.
174 changes: 152 additions & 22 deletions gptcache/adapter/langchain_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List, Any
from typing import Optional, List, Any, Mapping

from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
Expand All @@ -12,9 +12,20 @@
from pydantic import BaseModel
from langchain.llms.base import LLM
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, LLMResult, AIMessage, ChatGeneration, ChatResult
from langchain.schema import (
BaseMessage,
LLMResult,
AIMessage,
ChatGeneration,
ChatResult,
)
from langchain.callbacks.manager import (
Callbacks,
CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
)


# pylint: disable=protected-access
class LangChainLLMs(LLM, BaseModel):
"""LangChain LLM Wrapper.
Expand All @@ -39,25 +50,77 @@ class LangChainLLMs(LLM, BaseModel):

llm: Any
session: Session = None
tmp_args: Any = None

@property
def _llm_type(self) -> str:
    """Return the LLM type string reported by the wrapped model.

    Delegates to ``self.llm._llm_type`` rather than returning the fixed
    ``"gptcache_llm"`` label, so the wrapper reports the same type as the
    model it wraps.
    """
    return self.llm._llm_type

@property
def _identifying_params(self) -> Mapping[str, Any]:
    """Expose the identifying parameters of the wrapped LLM unchanged."""
    wrapped_llm = self.llm
    return wrapped_llm._identifying_params

def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
session = self.session if "session" not in kwargs else kwargs.pop("session")
def __str__(self) -> str:
    """Render this wrapper as the string form of the wrapped LLM."""
    return f"{self.llm!s}"

def _call(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    _: Optional[CallbackManagerForLLMRun] = None,
) -> str:
    """Run one prompt through GPTCache's ``adapt`` around the wrapped LLM.

    Extra keyword arguments stashed in ``self.tmp_args`` by ``generate`` /
    ``agenerate`` are forwarded to ``adapt``. A ``session`` entry in those
    arguments overrides ``self.session``.

    Args:
        prompt: The prompt text forwarded to ``adapt``.
        stop: Optional stop sequences forwarded to ``adapt``.
        _: Run manager slot; accepted but not used by this wrapper.

    Returns:
        Whatever ``adapt`` returns for this prompt.
    """
    # Work on a copy: ``_call`` runs once per prompt, so popping from the
    # shared dict would drop the ``session`` override after the first prompt.
    # ``or {}`` also covers a direct call made before ``generate`` has set
    # ``tmp_args`` (it defaults to None).
    extra_args = dict(self.tmp_args or {})
    session = extra_args.pop("session", self.session)
    return adapt(
        self.llm,
        _cache_data_convert,
        _update_cache_callback,
        prompt=prompt,
        stop=stop,
        session=session,
        **extra_args,
    )

def __call__(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
return self._call(prompt=prompt, stop=stop, **kwargs)
async def _acall(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> str:
    """Asynchronously run the prompt via the base class implementation.

    Args:
        prompt: The prompt text.
        stop: Optional stop sequences.
        run_manager: Optional async callback manager for this run.

    Returns:
        The awaited result of the parent class's ``_acall``.
    """
    # The parent coroutine must be awaited; returning it un-awaited would
    # hand the caller a coroutine object instead of the declared ``str``.
    return await super()._acall(prompt, stop=stop, run_manager=run_manager)

def generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
    **kwargs,
) -> LLMResult:
    """Stash extra keyword arguments, then defer to the base ``generate``.

    The extra ``kwargs`` are stored on ``self.tmp_args`` (where ``_call``
    reads them) because they are not passed on to the parent call.
    """
    self.tmp_args = kwargs
    result = super().generate(prompts, stop=stop, callbacks=callbacks)
    return result

async def agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
    **kwargs,
) -> LLMResult:
    """Async variant: stash extra keyword arguments, then await the base ``agenerate``.

    The extra ``kwargs`` are stored on ``self.tmp_args`` and are not passed
    on to the parent call.
    """
    self.tmp_args = kwargs
    result = await super().agenerate(prompts, stop=stop, callbacks=callbacks)
    return result

def __call__(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
    **kwargs,
) -> str:
    """Check Cache and run the LLM on the given prompt and input."""
    # Wrap the single prompt in a list and unwrap the first generation's text.
    result = self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
    first_generation = result.generations[0][0]
    return first_generation.text


# pylint: disable=protected-access
Expand Down Expand Up @@ -88,53 +151,120 @@ def _llm_type(self) -> str:

chat: Any
session: Session = None
tmp_args: Any = None

def _generate(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
session = self.session if "session" not in kwargs else kwargs.pop("session")
def _generate(
    self,
    messages: Any,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
    """Run the wrapped chat model's ``_generate`` through GPTCache's ``adapt``.

    Extra keyword arguments stashed in ``self.tmp_args`` by ``generate`` /
    ``agenerate`` are forwarded to ``adapt``. A ``session`` entry in those
    arguments overrides ``self.session``.

    Args:
        messages: The messages forwarded to ``adapt``.
        stop: Optional stop sequences forwarded to ``adapt``.
        run_manager: Optional callback manager forwarded to ``adapt``.

    Returns:
        Whatever ``adapt`` returns for these messages.
    """
    # Copy before popping so the ``session`` override survives repeated
    # ``_generate`` calls within one ``generate`` invocation; ``or {}``
    # covers ``tmp_args`` still being its default None.
    extra_args = dict(self.tmp_args or {})
    session = extra_args.pop("session", self.session)
    return adapt(
        self.chat._generate,
        _cache_msg_data_convert,
        _update_cache_msg_callback,
        messages=messages,
        stop=stop,
        session=session,
        run_manager=run_manager,
        **extra_args,
    )

async def _agenerate(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, **kwargs) -> LLMResult:
session = self.session if "session" not in kwargs else kwargs.pop("session")
async def _agenerate(
    self,
    messages: List[List[BaseMessage]],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
    """Run the wrapped chat model's ``_agenerate`` through GPTCache's ``adapt``.

    Extra keyword arguments stashed in ``self.tmp_args`` by ``generate`` /
    ``agenerate`` are forwarded to ``adapt``. A ``session`` entry in those
    arguments overrides ``self.session``.

    Args:
        messages: The messages forwarded to ``adapt``.
        stop: Optional stop sequences forwarded to ``adapt``.
        run_manager: Optional async callback manager forwarded to ``adapt``.

    Returns:
        Whatever ``adapt`` returns for these messages.
    """
    # Copy before popping so the ``session`` override is not consumed by the
    # first call; ``or {}`` covers ``tmp_args`` still being its default None.
    extra_args = dict(self.tmp_args or {})
    session = extra_args.pop("session", self.session)
    # NOTE(review): ``adapt`` is called without ``await`` here even though the
    # handler is presumably a coroutine function - confirm ``adapt`` copes
    # with async handlers.
    return adapt(
        self.chat._agenerate,
        _cache_msg_data_convert,
        _update_cache_msg_callback,
        messages=messages,
        stop=stop,
        session=session,
        run_manager=run_manager,
        **extra_args,
    )

def generate(
    self,
    messages: List[List[BaseMessage]],
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
    **kwargs,
) -> LLMResult:
    """Stash extra keyword arguments, then defer to the base ``generate``.

    The extra ``kwargs`` are stored on ``self.tmp_args`` (where
    ``_generate`` reads them) and are not passed on to the parent call.
    """
    self.tmp_args = kwargs
    result = super().generate(messages, stop=stop, callbacks=callbacks)
    return result

async def agenerate(
    self,
    messages: List[List[BaseMessage]],
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
    **kwargs,
) -> LLMResult:
    """Async variant: stash extra keyword arguments, then await the base ``agenerate``.

    The extra ``kwargs`` are stored on ``self.tmp_args`` and are not passed
    on to the parent call.
    """
    self.tmp_args = kwargs
    result = await super().agenerate(messages, stop=stop, callbacks=callbacks)
    return result

@property
def _identifying_params(self):
    """Expose the identifying parameters of the wrapped chat model unchanged."""
    wrapped_chat = self.chat
    return wrapped_chat._identifying_params

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
    """Combine per-call LLM outputs using the wrapped chat model's own logic."""
    combine = self.chat._combine_llm_outputs
    return combine(llm_outputs)

def get_num_tokens(self, text: str) -> int:
    """Return the token count for ``text`` as computed by the wrapped chat model."""
    count_tokens = self.chat.get_num_tokens
    return count_tokens(text)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Return the token count for ``messages`` as computed by the wrapped chat model."""
    count_tokens = self.chat.get_num_tokens_from_messages
    return count_tokens(messages)

def __call__(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
    """Run the chat model on one list of messages and return the reply message.

    Routes through ``self.generate`` so that the extra ``kwargs`` are
    stashed on ``self.tmp_args`` before ``_generate`` reads them.

    Args:
        messages: The messages for a single conversation.
        stop: Optional stop sequences.
        **kwargs: Extra arguments passed to ``generate``.

    Returns:
        The ``message`` of the first generation.

    Raises:
        ValueError: If the first generation is not a ``ChatGeneration``.
    """
    generation = self.generate([messages], stop=stop, **kwargs).generations[0][0]
    if not isinstance(generation, ChatGeneration):
        raise ValueError("Unexpected generation type")
    return generation.message


def _cache_data_convert(cache_data):
    """Return cached data unchanged; no conversion is applied."""
    converted = cache_data
    return converted


def _update_cache_callback(
    llm_data, update_cache_func, *args, **kwargs
):  # pylint: disable=unused-argument
    """Store the LLM output via ``update_cache_func`` and pass it through.

    Args:
        llm_data: The value produced by the LLM; wrapped in an ``Answer``
            with ``DataType.STR`` before being handed to the cache.
        update_cache_func: Callable that receives the ``Answer``.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        ``llm_data`` unchanged.
    """
    update_cache_func(Answer(llm_data, DataType.STR))
    return llm_data


def _cache_msg_data_convert(cache_data):
    """Wrap cached content in a ``ChatResult`` holding a single AI message.

    Args:
        cache_data: The cached content, used as the ``AIMessage`` content.

    Returns:
        A ``ChatResult`` with one ``ChatGeneration`` (empty ``text``, no
        ``generation_info``) and ``llm_output`` set to None.
    """
    cached_message = AIMessage(content=cache_data, additional_kwargs={})
    generation = ChatGeneration(
        text="",
        generation_info=None,
        message=cached_message,
    )
    return ChatResult(generations=[generation], llm_output=None)


def _update_cache_msg_callback(
    llm_data, update_cache_func, *args, **kwargs
):  # pylint: disable=unused-argument
    """Store the first generation's text via ``update_cache_func`` and pass the result through.

    Args:
        llm_data: Result object whose ``generations[0].text`` is stored.
        update_cache_func: Callable that receives that text.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        ``llm_data`` unchanged.
    """
    update_cache_func(llm_data.generations[0].text)
    return llm_data
3 changes: 2 additions & 1 deletion gptcache/manager/vector_data/usearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import_usearch()

from usearch.index import Index # pylint: disable=C0413
from usearch.compiled import MetricKind # pylint: disable=C0413


class USearch(VectorBase):
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
self._top_k = top_k
self._index = Index(
ndim=self._dimension,
metric=metric,
metric=getattr(MetricKind, metric.lower().capitalize()),
dtype=dtype,
connectivity=connectivity,
expansion_add=expansion_add,
Expand Down
2 changes: 2 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ typing_extensions<4.6.0
stability-sdk
grpcio==1.53.0
protobuf==3.20.0
milvus==2.2.8
pymilvus==2.2.8

0 comments on commit 09a7b59

Please sign in to comment.