Use temperature to control probability of skip_cache (#306)
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
jaelgu authored Apr 28, 2023
1 parent 568e46f commit f5ba15f
Showing 5 changed files with 87 additions and 36 deletions.
50 changes: 50 additions & 0 deletions README.md
@@ -180,6 +180,56 @@ for question in questions:
    print(f'Answer: {response_text(response)}\n')
```

#### OpenAI API + GPTCache, use temperature

> You can always pass a `temperature` parameter when requesting the API service or model.
>
> The range of `temperature` is [0, 2]; the default value is 0.0.
>
> A higher temperature means a higher probability of skipping the cache search and requesting the large model directly.
> When the temperature is 2, the cache is always skipped and the request is sent straight to the large model. When the temperature is 0, the cache is always searched before the large model service is requested.
>
> The default `post_process_messages_func` is `temperature_softmax`. In this case, refer to the [API reference](https://gptcache.readthedocs.io/en/latest/references/processor.html#module-gptcache.processor.post) to learn how `temperature` affects the output. A minimal illustrative sketch of this post-processing step follows the example below.
```python
import time

from gptcache import cache, Config
from gptcache.manager import manager_factory
from gptcache.embedding import Onnx
from gptcache.processor.post import temperature_softmax
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.adapter import openai

cache.set_openai_key()

onnx = Onnx()
data_manager = manager_factory("sqlite,faiss", vector_params={"dimension": onnx.dimension})

cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    post_process_messages_func=temperature_softmax
)
# cache.config = Config(similarity_threshold=0.2)

question = "what's github"

for _ in range(3):
    start = time.time()
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        temperature=1.0,  # Change temperature here
        messages=[{
            "role": "user",
            "content": question
        }],
    )
    print("Time elapsed:", round(time.time() - start, 3))
    print("Answer:", response["choices"][0]["message"]["content"])
```
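For intuition, here is a minimal, self-contained sketch of a softmax-with-temperature post-processor. It is illustrative only: the helper name `softmax_sample` is made up for this sketch, and the real `gptcache.processor.post.temperature_softmax` may differ in its details.

```python
import math
import random
from typing import Any, List


def softmax_sample(messages: List[Any], scores: List[float], temperature: float) -> Any:
    """Pick one candidate; a higher temperature flattens the distribution over scores."""
    if temperature <= 0:
        # Deterministic: return the highest-scored candidate.
        return max(zip(messages, scores), key=lambda pair: pair[1])[0]
    exps = [math.exp(s / temperature) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return random.choices(messages, weights=weights, k=1)[0]


# With a low temperature the best-scored cached answer is returned almost always;
# as the temperature grows the choice approaches a uniform pick.
print(softmax_sample(["cached answer A", "cached answer B"], [0.9, 0.1], temperature=0.5))
```

The same sampling idea is used in the adapter (see the `gptcache/adapter/adapter.py` change below) to decide whether to skip the cache lookup altogether.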

</details>

To use GPTCache exclusively, only the following lines of code are required, and there is no need to modify any existing code.
56 changes: 21 additions & 35 deletions examples/processor/temperature_example.py
@@ -1,57 +1,43 @@
import os
import time

from gptcache import cache, Config
from gptcache.manager import manager_factory
from gptcache.embedding import Onnx
from gptcache.processor.pre import get_prompt
from gptcache.processor.post import temperature_softmax
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.adapter.api import put, get
from gptcache.adapter import openai

cache.set_openai_key()

# Init cache with vector store
if os.path.exists("faiss.index"):
    os.remove("faiss.index")
if os.path.exists("sqlite.db"):
    os.remove("sqlite.db")
# if os.path.exists("faiss.index"):
#     os.remove("faiss.index")
# if os.path.exists("sqlite.db"):
#     os.remove("sqlite.db")

onnx = Onnx()
data_manager = manager_factory("sqlite,faiss", vector_params={"dimension": onnx.dimension})


cache.init(
    pre_embedding_func=get_prompt,
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    post_process_messages_func=temperature_softmax
)
# cache.config = Config(similarity_threshold=0.2)

# Input some prepared data to mock a cache with data stored
my_data = [
    {"Q": "What is the most popular vector database?", "A": "Milvus!"},
    {"Q": "What are most popular vector databases?", "A": "Milvus, Milvus, still Milvus ..."},
    {"Q": "What is vector database?", "A": "Vector database is xxxx."},
    {"Q": "Is Milvus an open-source vector database?", "A": "Yes, Milvus is open source."},
    {"Q": "What is Zilliz cloud?", "A": "Zilliz cloud provides vector database on cloud."},
    {"Q": "What is Milvus?", "A": "Milvus is an open-source vector database."},
    {"Q": "Can you recommend a vector database?", "A": "Sure, Milvus is a good choice for vector database."},
    {"Q": "Is Zilliz Cloud free?", "A": "No, Zilliz Cloud charges for instance."},
    {"Q": "How many credits can I get for Zilliz Cloud?", "A": "A new user of Zilliz Cloud will get 350 credits."},
    {"Q": "Do you like GPTCache?", "A": "Yea! GPTCache is great!"},
]

for qa in my_data:
    put(prompt=qa["Q"], data=qa["A"], skip_cache=True)


# use cache without temperature (temperature=0.0)
for _ in range(5):
    answer = get(prompt="popular vector database")
    print(answer)

# use cache with temperature (eg. temperature=2.0)
for _ in range(5):
    answer = get(prompt="popular vector database", temperature=2.0)
    print(answer)
question = 'what is github'

for _ in range(3):
    start = time.time()
    response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        temperature=1.0,  # Change temperature here
        messages=[{
            'role': 'user',
            'content': question
        }],
    )
    print(round(time.time() - start, 3))
    print(response["choices"][0]["message"]["content"])
10 changes: 9 additions & 1 deletion gptcache/adapter/adapter.py
@@ -20,7 +20,15 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
    context = kwargs.pop("cache_context", {})
    embedding_data = None
    # retry sending the request to ChatGPT when the cached result is negative
    cache_skip = kwargs.pop("cache_skip", False)
    if 0 < temperature < 2:
        cache_skip_options = [True, False]
        prob_cache_skip = [0, 1]
        cache_skip = kwargs.pop("cache_skip", temperature_softmax(
            messages=cache_skip_options, scores=prob_cache_skip, temperature=temperature))
    elif temperature >= 2:
        cache_skip = kwargs.pop("cache_skip", True)
    else:  # temperature <= 0
        cache_skip = kwargs.pop("cache_skip", False)
    cache_factor = kwargs.pop("cache_factor", 1.0)
    pre_embedding_res = chat_cache.pre_embedding_func(
        kwargs,
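Given `cache_skip_options = [True, False]` and `prob_cache_skip = [0, 1]` in the change above, and assuming `temperature_softmax` samples from `softmax(scores / temperature)`, the chance of skipping the cache at a given temperature can be estimated with a quick, illustrative calculation (not part of the commit):

```python
import math


def skip_probability(temperature: float) -> float:
    """Estimated P(cache_skip=True) for options [True, False] with scores [0, 1]."""
    if temperature <= 0:
        return 0.0  # always search the cache first
    if temperature >= 2:
        return 1.0  # always skip the cache
    w_true = math.exp(0 / temperature)   # score 0 for "skip"
    w_false = math.exp(1 / temperature)  # score 1 for "don't skip"
    return w_true / (w_true + w_false)


for t in (0.0, 0.5, 1.0, 1.5, 2.0):
    print(f"temperature={t}: estimated P(skip) = {skip_probability(t):.3f}")
```

Under this assumption, the skip probability rises with temperature but stays below 0.5 inside (0, 2); only at a temperature of 2 or above is the cache bypassed unconditionally.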
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -16,3 +16,4 @@ pytest-sugar==0.9.5
pytest-parallel
torch
mock
pexpect
6 changes: 6 additions & 0 deletions tests/unit_tests/adapter/test_adapter.py
@@ -141,6 +141,12 @@ def test_cache_temperature():
    for _ in range(5):
        put(prompt=prompt, data=answer, skip_cache=True)

    answers = get(prompt=prompt, temperature=2.0)
    assert answers is None

    answers = get(prompt=prompt, temperature=1.5)
    assert answers in [None, [answer] * 5]

    answers = get(prompt=prompt, temperature=0.0, top_k=3)
    assert len(answers) == 3

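These assertions mirror the behavior described in the README section above: at temperature 2.0 the cache is always skipped, so `get` returns `None`; at 1.5 skipping is probabilistic, so either outcome is accepted; at 0.0 the cache is always searched and the requested `top_k` of 3 answers is returned.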
