Commit 02949b9
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
1 parent: ca09f95
Showing 3 changed files with 139 additions and 0 deletions.
gptcache/adapter/dolly.py
@@ -0,0 +1,68 @@
from typing import Any

from gptcache.adapter.adapter import adapt
from gptcache.utils import import_huggingface, import_torch
from gptcache.manager.scalar_data.base import Answer, DataType

import_torch()
import_huggingface()

from transformers import pipeline  # pylint: disable=wrong-import-position

class Dolly:
    """Wrapper for Dolly (https://github.com/databrickslabs/dolly.git).

    Example using from_model:
        .. code-block:: python

            import torch

            from gptcache import cache
            from gptcache.processor.pre import get_inputs
            cache.init(pre_embedding_func=get_inputs)

            from gptcache.adapter.dolly import Dolly
            dolly = Dolly.from_model(
                model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
            )

    Example passing a pipeline in directly:
        .. code-block:: python

            import torch
            from transformers import pipeline

            from gptcache import cache
            from gptcache.processor.pre import get_inputs
            cache.init(pre_embedding_func=get_inputs)

            from gptcache.adapter.dolly import Dolly
            pipe = pipeline(
                model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
            )
            dolly = Dolly(pipe)
    """

    def __init__(self, dolly_pipeline: Any):
        self._dolly_pipeline = dolly_pipeline

    @classmethod
    def from_model(cls, model: str, **kwargs):
        # Build the underlying transformers pipeline from a model name.
        pipe = pipeline(model=model, **kwargs)
        return cls(pipe)

    def __call__(self, prompt: str, **kwargs):
        # adapt() consults the cache first and only runs the pipeline on a miss;
        # the two callbacks below shape hits and store fresh generations.
        return adapt(
            self._dolly_pipeline,
            cache_data_convert,
            update_cache_callback,
            inputs=prompt,
            **kwargs
        )


def cache_data_convert(cache_data):
    # On a cache hit, shape the stored text like a pipeline response and
    # flag it so callers can tell it came from the cache.
    return [{"generated_text": cache_data, "gptcache": True}]


def update_cache_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
    # On a miss, store the freshly generated text, then return the response unchanged.
    update_cache_func(Answer(llm_data[0]["generated_text"], DataType.STR))
    return llm_data
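Taken together, the two callbacks give callers a simple hit/miss signal: a response served from the cache carries "gptcache": True, while a fresh generation does not. A minimal end-to-end sketch of checking that flag, under assumptions not in this commit (the global default cache settings, and the databricks/dolly-v2-3b checkpoint as a lighter stand-in for the 12b model in the docstring):

import torch

from gptcache import cache
from gptcache.processor.pre import get_inputs

# Default cache configuration; the docstring examples above use the same call.
cache.init(pre_embedding_func=get_inputs)

from gptcache.adapter.dolly import Dolly

# Model name is an illustrative assumption; any Dolly checkpoint works here.
dolly = Dolly.from_model(
    model="databricks/dolly-v2-3b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
)

answer = dolly("Explain what GPTCache does in one sentence.")
if answer[0].get("gptcache", False):
    print("cache hit:", answer[0]["generated_text"])
else:
    print("fresh generation:", answer[0]["generated_text"])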
test_dolly.py
@@ -0,0 +1,67 @@
import unittest
from unittest.mock import patch
from tempfile import TemporaryDirectory

from gptcache import Cache
from gptcache.processor.pre import get_inputs
from gptcache.manager.factory import manager_factory
from gptcache.embedding import Onnx

question = "test_dolly"
expect_answer = "hello world"
onnx = Onnx()

class MockDolly:
    """Stand-in for the transformers pipeline: always returns a fixed answer."""

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, inputs, **kwargs):
        return [{"generated_text": expect_answer}]

class TestDolly(unittest.TestCase):
    def test_normal(self):
        with patch('gptcache.utils.import_torch'), \
             patch('gptcache.utils.import_huggingface'), \
             patch('transformers.pipeline') as mock_pipeline:

            # Case 1: build the wrapper via Dolly.from_model.
            with TemporaryDirectory(dir="./") as root:
                m = manager_factory('sqlite,faiss,local', data_dir=root, vector_params={"dimension": onnx.dimension})
                llm_cache = Cache()
                llm_cache.init(
                    pre_embedding_func=get_inputs,
                    data_manager=m,
                    embedding_func=onnx.to_embeddings
                )

                from gptcache.adapter.dolly import Dolly

                mock_pipeline.return_value = MockDolly()
                dolly = Dolly.from_model('dolly_model')
                # First call misses the cache and runs the (mocked) model ...
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertFalse(answer[0].get("gptcache", False))
                # ... the second call is served from the cache.
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertTrue(answer[0].get("gptcache", False))

            # Case 2: pass a pipeline instance to Dolly directly.
            with TemporaryDirectory(dir="./") as root:
                m = manager_factory('sqlite,faiss,local', data_dir=root, vector_params={"dimension": onnx.dimension})
                llm_cache = Cache()
                llm_cache.init(
                    pre_embedding_func=get_inputs,
                    data_manager=m,
                    embedding_func=onnx.to_embeddings
                )

                from gptcache.adapter.dolly import Dolly
                from transformers import pipeline

                dolly = Dolly(pipeline('dolly'))
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertFalse(answer[0].get("gptcache", False))
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertTrue(answer[0].get("gptcache", False))
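The same manager_factory/Onnx wiring the test uses also works outside the mock, giving a persistent similarity-based cache instead of the default in-memory one. A minimal sketch under stated assumptions (the ./dolly_cache directory and the model name are illustrative, not part of this commit):

import torch

from gptcache import Cache
from gptcache.processor.pre import get_inputs
from gptcache.manager.factory import manager_factory
from gptcache.embedding import Onnx
from gptcache.adapter.dolly import Dolly

onnx = Onnx()
llm_cache = Cache()
llm_cache.init(
    pre_embedding_func=get_inputs,
    # sqlite for scalar data, faiss for vectors, local object storage,
    # exactly as in the test above but with a durable directory.
    data_manager=manager_factory(
        'sqlite,faiss,local', data_dir='./dolly_cache',
        vector_params={"dimension": onnx.dimension},
    ),
    embedding_func=onnx.to_embeddings,
)

# Hypothetical checkpoint choice; any Dolly model name is accepted.
dolly = Dolly.from_model(model="databricks/dolly-v2-3b", torch_dtype=torch.bfloat16, device=0)

# First call generates and stores; a repeat of the same question should be a hit.
answer = dolly("What is a vector database?", cache_obj=llm_cache)
print(answer[0]["generated_text"], "| from cache:", answer[0].get("gptcache", False))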