Skip to content

Commit

Permalink
Improve the uform embedding
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG committed May 24, 2023
1 parent 0f4243d commit f91a8da
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 21 deletions.
37 changes: 24 additions & 13 deletions gptcache/embedding/uform.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
from typing import Union, Any

from gptcache.embedding.base import BaseEmbedding
from gptcache.utils import import_uform, import_pillow
from gptcache.utils.error import ParamError

import_pillow()
import_uform()

from uform import get_model # pylint: disable=C0413 # nopep8
from uform import TritonClient, get_model # pylint: disable=C0413 # nopep8
from PIL import Image # pylint: disable=C0413 # nopep8


Expand All @@ -15,6 +16,8 @@ class UForm(BaseEmbedding):
:param model: model name, defaults to 'unum-cloud/uform-vl-english'.
:type model: str
:param embedding_type: type of embedding, defaults to 'text'. options: text, image
:type embedding_type: str
Example:
.. code-block:: python
Expand All @@ -30,27 +33,35 @@ class UForm(BaseEmbedding):
embed = encoder.to_embeddings(test_sentence)
"""

def __init__(self, model: str = "unum-cloud/uform-vl-english"):
self.model = get_model(model)
self.__dimension = self.model.image_encoder.dim
def __init__(self, model: Union[str, TritonClient] = "unum-cloud/uform-vl-english", embedding_type: str = "text"):
if isinstance(model, str):
self.__model = get_model(model)
else:
self.__model = model
self.__embedding_type = embedding_type
if embedding_type == "text":
self.__dimension = self.__model.text_encoder.proj.out_features
elif embedding_type == "image":
self.__dimension = self.__model.img_encoder.proj.out_features
else:
raise ParamError(f"Unknown embedding type: {embedding_type}")

def to_embeddings(self, data: str, **_):
def to_embeddings(self, data: Any, **_):
"""Generate embedding given text input or a path to a file.
:param data: text in string, or a path to an image file.
:type data: str
:return: an embedding in shape of (dim,).
"""
if os.path.exists(data):
if self.__embedding_type == "image":
data = Image.open(data)
data = self.model.preprocess_image(data)
emb = self.model.encode_image(data)
return emb.detach().numpy().flatten()
data = self.__model.preprocess_image(data)
emb = self.__model.encode_image(data)
else:
data = self.model.preprocess_text(data)
emb = self.model.encode_text(data)
return emb.detach().numpy().flatten()
data = self.__model.preprocess_text(data)
emb = self.__model.encode_text(data)
return emb.detach().numpy().flatten()

@property
def dimension(self):
Expand Down
1 change: 1 addition & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"import_cohere",
"import_fasttext",
"import_huggingface",
"import_uform",
"import_torch",
"import_huggingface_hub",
"import_onnxruntime",
Expand Down
2 changes: 0 additions & 2 deletions tests/unit_tests/embedding/test_data2vec.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from io import BytesIO

import pytest
import requests

from gptcache.adapter.api import _get_model
from gptcache.embedding import Data2VecAudio


@pytest.mark.tags("L2")
def test_data2vec_audio():
url = "https://github.com/towhee-io/examples/releases/download/data/ah_yes.wav"
req = requests.get(url)
Expand Down
3 changes: 0 additions & 3 deletions tests/unit_tests/embedding/test_rwkv.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest

from gptcache.adapter.api import _get_model
from gptcache.embedding import Rwkv


@pytest.mark.tags("L2")
def test_rwkv():
t = Rwkv("sgugger/rwkv-430M-pile")
data = t.to_embeddings("foo")
Expand Down
3 changes: 0 additions & 3 deletions tests/unit_tests/embedding/test_sbert.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest

from gptcache.adapter.api import _get_model
from gptcache.embedding import SBERT


@pytest.mark.tags("L2")
def test_sbert():
t = SBERT("all-MiniLM-L6-v2")
dimension = t.dimension
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/embedding/test_uform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from io import BytesIO

import requests

from gptcache.embedding.uform import UForm
from gptcache.utils import import_uform, import_pillow
from gptcache.utils.error import ParamError

import_uform()
import_pillow()


def test_uform():
    """Exercise UForm in text mode, in image mode, and with an invalid type.

    For both valid modes the embedding length must equal the encoder's
    reported ``dimension``. Constructing ``UForm`` with an unknown
    ``embedding_type`` must raise ``ParamError``.
    """
    # Default construction: embed a plain string.
    text_encoder = UForm()
    text_embed = text_encoder.to_embeddings("Hello, world.")
    assert len(text_embed) == text_encoder.dimension

    url = 'https://raw.githubusercontent.com/zilliztech/GPTCache/main/docs/GPTCache.png'
    # Bound the download so a stalled connection cannot hang the test run, and
    # fail on HTTP errors instead of handing an error page to the image loader.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image_file = BytesIO(response.content)

    image_encoder = UForm(embedding_type="image")
    image_embed = image_encoder.to_embeddings(image_file)
    assert len(image_embed) == image_encoder.dimension

    # An unsupported embedding type must be rejected at construction time.
    try:
        UForm(embedding_type="foo")
    except ParamError:
        pass
    else:
        raise AssertionError("UForm accepted an unknown embedding type")

0 comments on commit f91a8da

Please sign in to comment.