Fix the wrong LangChainChat comment
Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG committed May 24, 2023
1 parent 74536ec commit 550c1a7
Showing 6 changed files with 64 additions and 21 deletions.
32 changes: 32 additions & 0 deletions examples/integrate/langchain/langchain_qa_chain.py
@@ -0,0 +1,32 @@
import time

from langchain import OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.schema import Document

from gptcache import cache
from gptcache.adapter.api import init_similar_cache
from gptcache.adapter.langchain_models import LangChainLLMs


def get_content_func(data, **_):
    return data.get("prompt").split("Question:")[-1]


init_similar_cache(pre_func=get_content_func)
cache.set_openai_key()

mkt_qa = load_qa_chain(llm=LangChainLLMs(llm=OpenAI(temperature=0)), chain_type="stuff")

msg = "What is Traditional marketing?"


before = time.time()
answer = mkt_qa.run(question=msg, input_documents=[Document(page_content="marketing is hello world")])
print(answer)
print("Time Spent:", time.time() - before)

before = time.time()
answer = mkt_qa.run(question=msg, input_documents=[Document(page_content="marketing is hello world")])
print(answer)
print("Time Spent:", time.time() - before)
21 changes: 5 additions & 16 deletions examples/integrate/langchain/langchain_similaritycache_openai.py
@@ -1,16 +1,14 @@
 import os
+import time
 
 import openai
-import time
-from langchain.llms import OpenAI
 from langchain import PromptTemplate
+from langchain.llms import OpenAI
 
-from gptcache.adapter.langchain_models import LangChainLLMs
-from gptcache.manager import get_data_manager, CacheBase, VectorBase
 from gptcache import Cache
-from gptcache.embedding import Onnx
+from gptcache.adapter.api import init_similar_cache
+from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache.processor.pre import get_prompt
-from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
 
 openai.api_key = os.getenv("OPENAI_API_KEY")

@@ -25,16 +23,7 @@
question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"

llm_cache = Cache()
onnx = Onnx()
cache_base = CacheBase('sqlite')
vector_base = VectorBase('faiss', dimension=onnx.dimension)
data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)
llm_cache.init(
pre_embedding_func=get_prompt,
embedding_func=onnx.to_embeddings,
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation(),
)
init_similar_cache(pre_func=get_prompt, cache_obj=llm_cache)


before = time.time()
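The deleted block and the new init_similar_cache call are intended to be roughly equivalent. A sketch of that equivalence, reconstructed from the removed lines above; the helper's actual defaults (for example, eviction settings like max_size) may differ:

from gptcache import Cache
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.processor.pre import get_prompt
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

llm_cache = Cache()
onnx = Onnx()  # default ONNX embedding model
data_manager = get_data_manager(
    CacheBase("sqlite"),                            # scalar storage
    VectorBase("faiss", dimension=onnx.dimension),  # vector index
)
llm_cache.init(
    pre_embedding_func=get_prompt,
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
)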
6 changes: 3 additions & 3 deletions gptcache/adapter/api.py
@@ -1,5 +1,5 @@
 # pylint: disable=wrong-import-position
-from typing import Any, Optional, Callable, List
+from typing import Any, Optional, Callable
 
 import gptcache.processor.post
 import gptcache.processor.pre
@@ -21,7 +21,7 @@
 from gptcache.embedding.base import BaseEmbedding
 from gptcache.manager import manager_factory
 from gptcache.manager.data_manager import DataManager
-from gptcache.processor.post import first
+from gptcache.processor.post import temperature_softmax
 from gptcache.processor.pre import get_prompt
 from gptcache.similarity_evaluation import (
     SearchDistanceEvaluation,
@@ -126,7 +126,7 @@ def init_similar_cache(
     pre_func: Callable = get_prompt,
     embedding: Optional[BaseEmbedding] = None,
     data_manager: Optional[DataManager] = None,
-    post_func: Callable[[List[Any]], Any] = first,
+    post_func: Callable = temperature_softmax,
     config: Config = Config(),
 ):
     """Provide a quick way to initialize cache for api service
4 changes: 2 additions & 2 deletions gptcache/adapter/langchain_models.py
@@ -71,9 +71,9 @@ class LangChainChat(BaseChatModel, BaseModel):
         .. code-block:: python
 
             from gptcache import cache
-            from gptcache.processor.pre import get_prompt
+            from gptcache.processor.pre import get_messages_last_content
             # init gptcache
-            cache.init(pre_embedding_func=get_prompt)
+            cache.init(pre_embedding_func=get_messages_last_content)
             cache.set_openai_key()
             from langchain.chat_models import ChatOpenAI
             from gptcache.adapter.langchain_models import LangChainChat
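For context, a sketch assembling the corrected docstring into an end-to-end flow. The chat= keyword and message handling follow GPTCache's other chat examples and are assumptions here, and a real OpenAI key is needed to run it:

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

from gptcache import cache
from gptcache.adapter.langchain_models import LangChainChat
from gptcache.processor.pre import get_messages_last_content

# Key the cache on the content of the last chat message.
cache.init(pre_embedding_func=get_messages_last_content)
cache.set_openai_key()

chat = LangChainChat(chat=ChatOpenAI(temperature=0))
print(chat([HumanMessage(content="What is GPTCache?")]))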
17 changes: 17 additions & 0 deletions gptcache/processor/pre.py
@@ -226,6 +226,23 @@ def get_inputs(data: Dict[str, Any], **_: Dict[str, Any]):
return data.get("inputs")


def get_messages_last_content(data: Dict[str, Any], **_: Any) -> str:
""" get the last content of the llm request messages array
:param data: the user llm request data
:type data: Dict[str, Any]
Example:
.. code-block:: python
from gptcache.processor.pre import get_messages_last_content
content = get_messages_last_content({"messages": [{"content": "hello"}, {"content": "world"}]})
# "world"
"""
return data.get("messages")[-1].content


def get_openai_moderation_input(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
"""get the input param of the openai moderation request params
5 changes: 5 additions & 0 deletions tests/unit_tests/processor/test_pre.py
@@ -49,3 +49,9 @@ def test_get_prompt():
 def test_get_openai_moderation_input():
     content = get_openai_moderation_input({"input": ["hello", "world"]})
     assert content == "['hello', 'world']"
+
+
+def test_get_messages_last_content():
+    from types import SimpleNamespace  # message objects must expose `.content`
+    content = get_messages_last_content({"messages": [SimpleNamespace(content="foo1"), SimpleNamespace(content="foo2")]})
+    assert content == "foo2"
