Add dolly (#311)
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
junjiejiangjjj authored Apr 28, 2023
1 parent ca09f95 commit 02949b9
Showing 3 changed files with 139 additions and 0 deletions.
68 changes: 68 additions & 0 deletions gptcache/adapter/dolly.py
@@ -0,0 +1,68 @@
from typing import Any

from gptcache.adapter.adapter import adapt
from gptcache.utils import import_huggingface, import_torch
from gptcache.manager.scalar_data.base import Answer, DataType

import_torch()
import_huggingface()

from transformers import pipeline # pylint: disable=wrong-import-position


class Dolly:
    """Wrapper for Dolly (https://github.com/databrickslabs/dolly.git).

    Example using from_model:
        .. code-block:: python

            import torch

            from gptcache import cache
            from gptcache.processor.pre import get_inputs
            from gptcache.adapter.dolly import Dolly

            cache.init(pre_embedding_func=get_inputs)
            dolly = Dolly.from_model(
                model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
            )

    Example passing a pipeline in directly:
        .. code-block:: python

            import torch
            from transformers import pipeline

            from gptcache import cache
            from gptcache.processor.pre import get_inputs
            from gptcache.adapter.dolly import Dolly

            cache.init(pre_embedding_func=get_inputs)
            pipe = pipeline(
                model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
            )
            dolly = Dolly(pipe)
    """

    def __init__(self, dolly_pipeline: Any):
        self._dolly_pipeline = dolly_pipeline

    @classmethod
    def from_model(cls, model: str, **kwargs):
        pipe = pipeline(model=model, **kwargs)
        return cls(pipe)

    def __call__(self, prompt: str, **kwargs):
        return adapt(
            self._dolly_pipeline,
            cache_data_convert,
            update_cache_callback,
            inputs=prompt,
            **kwargs
        )


def cache_data_convert(cache_data):
    # On a cache hit, rebuild the pipeline's output shape around the cached
    # text and flag the result as coming from the cache.
    return [{"generated_text": cache_data, "gptcache": True}]


def update_cache_callback(llm_data, update_cache_func, *args, **kwargs):  # pylint: disable=unused-argument
    # On a cache miss, store the freshly generated text, then pass the
    # pipeline's output through unchanged.
    update_cache_func(Answer(llm_data[0]["generated_text"], DataType.STR))
    return llm_data
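
Taken together, the adapter flow is: adapt() first looks the prompt up in the cache; on a miss it calls the wrapped pipeline and update_cache_callback stores the generated text, while on a hit cache_data_convert rebuilds the pipeline's output shape and marks it with "gptcache": True. A minimal usage sketch under those semantics (model name and device taken from the docstring above; whether the second call hits depends on how the cache is configured):

import torch

from gptcache import cache
from gptcache.processor.pre import get_inputs
from gptcache.adapter.dolly import Dolly

# Cache on the pipeline's `inputs` kwarg, i.e. the prompt.
cache.init(pre_embedding_func=get_inputs)

dolly = Dolly.from_model(
    model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device=0
)

answer = dolly("What is GPTCache?")  # miss: runs the model, caches the text
answer = dolly("What is GPTCache?")  # repeated prompt: may be served from the cache
print(answer[0].get("gptcache", False))  # True on a cache hit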
4 changes: 4 additions & 0 deletions gptcache/processor/pre.py
@@ -61,3 +61,7 @@ def get_image_question(data: Dict[str, Any], **_: Dict[str, Any]) -> str:  # pragma: no cover

def get_image(data: Dict[str, Any], **_: Dict[str, Any]) -> str:  # pragma: no cover
    return data.get("image")


def get_inputs(data: Dict[str, Any], **_: Dict[str, Any]):
    # Pull out the prompt that adapters pass through as `inputs=...` so it
    # can serve as the cache key.
    return data.get("inputs")
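
The new get_inputs pre-processor is the piece that ties the adapter to the cache: adapt() hands the adapter's keyword arguments to the pre-embedding function, and get_inputs pulls out the prompt that Dolly.__call__ passed as inputs=prompt. A quick sketch of its behavior:

from gptcache.processor.pre import get_inputs

assert get_inputs({"inputs": "Who are you?"}) == "Who are you?"
assert get_inputs({"other": 1}) is None  # missing key: dict.get falls back to None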
67 changes: 67 additions & 0 deletions tests/unit_tests/adapter/test_dolly.py
@@ -0,0 +1,67 @@
import os
import sys
import unittest
from unittest.mock import patch
from tempfile import TemporaryDirectory

from gptcache import Cache
from gptcache.processor.pre import get_inputs
from gptcache.manager.factory import manager_factory
from gptcache.embedding import Onnx

question = "test_dolly"
expect_answer = "hello world"
onnx = Onnx()

class MockDolly:
    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, inputs, **kwargs):
        return [{"generated_text": expect_answer}]


class TestDolly(unittest.TestCase):
    def test_normal(self):
        with patch('gptcache.utils.import_torch'), \
                patch('gptcache.utils.import_huggingface'), \
                patch('transformers.pipeline') as mock_pipeline:

            # Scenario 1: build the adapter with Dolly.from_model.
            with TemporaryDirectory(dir="./") as root:
                m = manager_factory('sqlite,faiss,local', data_dir=root, vector_params={"dimension": onnx.dimension})
                llm_cache = Cache()
                llm_cache.init(
                    pre_embedding_func=get_inputs,
                    data_manager=m,
                    embedding_func=onnx.to_embeddings
                )

                from gptcache.adapter.dolly import Dolly

                mock_pipeline.return_value = MockDolly()
                dolly = Dolly.from_model('dolly_model')
                # First call: cache miss, answered by the (mocked) pipeline.
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertFalse(answer[0].get("gptcache", False))
                # Second call: same prompt, answered from the cache.
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertTrue(answer[0].get("gptcache", False))

            # Scenario 2: pass a pipeline in directly.
            with TemporaryDirectory(dir="./") as root:
                m = manager_factory('sqlite,faiss,local', data_dir=root, vector_params={"dimension": onnx.dimension})
                llm_cache = Cache()
                llm_cache.init(
                    pre_embedding_func=get_inputs,
                    data_manager=m,
                    embedding_func=onnx.to_embeddings
                )

                from gptcache.adapter.dolly import Dolly
                from transformers import pipeline
                dolly = Dolly(pipeline('dolly'))
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertFalse(answer[0].get("gptcache", False))
                answer = dolly(question, cache_obj=llm_cache)
                self.assertEqual(answer[0]["generated_text"], expect_answer)
                self.assertTrue(answer[0].get("gptcache", False))
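
For reference, a minimal sketch for running just this test from the repository root (assumes standard unittest discovery over the repo layout; the project's CI may invoke pytest instead):

import unittest

# Discover and run only the Dolly adapter test.
suite = unittest.defaultTestLoader.discover(
    "tests/unit_tests/adapter", pattern="test_dolly.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)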
