Abstract scalar/vector store and vector index
Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG committed Mar 29, 2023
1 parent f96c4ff commit 365cad1
Showing 21 changed files with 520 additions and 87 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -132,3 +132,5 @@ dmypy.json
 **/data_map.txt
 **/faiss.index
 **/sqlite.db
+**/example.py
+**/example.db
11 changes: 6 additions & 5 deletions example/benchmark/benchmark_sf_towhee.py
@@ -3,8 +3,8 @@
 
 from gpt_cache.view import openai
 from gpt_cache.core import cache
-from gpt_cache.cache.factory import get_data_manager
-from gpt_cache.similarity_evaluation.faiss import faiss_evaluation
+from gpt_cache.cache.factory import get_si_data_manager
+from gpt_cache.similarity_evaluation.simple import pair_evaluation
 from gpt_cache.embedding.towhee import Towhee
 
 
@@ -13,10 +13,11 @@ def run():
         mock_data = json.load(mock_file)
 
     towhee = Towhee()
+    data_manager = get_si_data_manager("sqlite", "faiss", dimension=towhee.dimension())
     cache.init(embedding_func=towhee.to_embeddings,
-               data_manager=get_data_manager("sqlite_faiss", dimension=towhee.dimension()),
-               evaluation_func=faiss_evaluation,
-               similarity_threshold=50,
+               data_manager=data_manager,
+               evaluation_func=pair_evaluation,
+               similarity_threshold=0.5,
                similarity_positive=False)
 
     i = 0
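Why the threshold here drops from 50 to 0.5: pair_evaluation scores the query/cache pair on a different scale than the old faiss_evaluation, and with similarity_positive=False the score is read as a distance, where smaller means a closer match. A minimal sketch of that decision rule, assuming this reading of similarity_positive (is_cache_hit is a hypothetical helper for illustration; the real check lives in gpt_cache.core):

# Hypothetical helper, not part of this commit: with positive=False the
# evaluation score is treated as a distance, so smaller is better.
def is_cache_hit(score: float, threshold: float, positive: bool) -> bool:
    return score >= threshold if positive else score <= threshold


assert is_cache_hit(0.3, 0.5, positive=False)      # close pair -> cache hit
assert not is_cache_hit(0.9, 0.5, positive=False)  # distant pair -> miss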
4 changes: 2 additions & 2 deletions example/map/map_manager.py
@@ -16,8 +16,8 @@ def run():
     ]
 
     # you should OPEN it if you FIRST run it
-    # for i in range(10):
-    #     cache.data_manager.save(f"receiver the foo {i}", cache.embedding_func(f"foo{i}"))
+    for i in range(10):
+        cache.data_manager.save(f"receiver the foo {i}", cache.embedding_func(f"foo{i}"))
     answer = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
         messages=mock_messages,
9 changes: 5 additions & 4 deletions example/sf_mock/sf_manager.py
@@ -2,8 +2,8 @@
 
 from gpt_cache.view import openai
 from gpt_cache.core import cache, Config
-from gpt_cache.cache.factory import get_data_manager
-from gpt_cache.similarity_evaluation.faiss import faiss_evaluation
+from gpt_cache.cache.factory import get_si_data_manager
+from gpt_cache.similarity_evaluation.simple import pair_evaluation
 import numpy as np
 
 
@@ -15,9 +15,10 @@ def mock_embeddings(data, **kwargs):
 
 
 def run():
+    data_manager = get_si_data_manager("sqlite", "faiss", dimension=d, max_size=8, clean_size=2, top_k=3)
     cache.init(embedding_func=mock_embeddings,
-               data_manager=get_data_manager("sqlite_faiss", dimension=d, max_size=8, clean_size=2, top_k=3),
-               evaluation_func=faiss_evaluation,
+               data_manager=data_manager,
+               evaluation_func=pair_evaluation,
                similarity_threshold=10000,
                similarity_positive=False,
                config=Config(),
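This example caps the cache at max_size=8 and frees clean_size=2 entries per cleanup. A self-contained sketch of that bookkeeping, with FIFO standing in for the store's least_accessed_data strategy (illustrative only; the actual logic is SIDataManager.save in gpt_cache/cache/data_manager.py below):

# Illustrative stand-in: FIFO eviction instead of least_accessed_data.
max_size, clean_size = 8, 2
cache_keys = []

for i in range(10):
    if len(cache_keys) >= max_size:
        # drop clean_size entries, as clean_cache(clean_size) does;
        # the vector index is then rebuilt from the survivors
        cache_keys = cache_keys[clean_size:]
    cache_keys.append(f"foo{i}")

print(len(cache_keys))  # 8: the cache never exceeds max_size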
11 changes: 6 additions & 5 deletions example/sf_towhee/sf_manager.py
@@ -2,21 +2,22 @@
 
 from gpt_cache.view import openai
 from gpt_cache.core import cache
-from gpt_cache.cache.factory import get_data_manager
-from gpt_cache.similarity_evaluation.faiss import faiss_evaluation
+from gpt_cache.cache.factory import get_si_data_manager
+from gpt_cache.similarity_evaluation.simple import pair_evaluation
 from gpt_cache.embedding.towhee import Towhee
 
 
 def run():
     towhee = Towhee()
+    data_manager = get_si_data_manager("sqlite", "faiss", dimension=towhee.dimension(), max_size=2000)
     cache.init(embedding_func=towhee.to_embeddings,
-               data_manager=get_data_manager("sqlite_faiss", dimension=towhee.dimension(), max_size=2000),
-               evaluation_func=faiss_evaluation,
+               data_manager=data_manager,
+               evaluation_func=pair_evaluation,
                similarity_threshold=10000,
                similarity_positive=False)
 
     # you should OPEN it if you FIRST run it
-    # cache.data_manager.save("chatgpt is a good application", cache.embedding_func("what do you think about chatgpt"))
+    cache.data_manager.save("chatgpt is a good application", cache.embedding_func("what do you think about chatgpt"))
 
     # distance 77
     mock_messages = [
48 changes: 48 additions & 0 deletions example/sqlite_milvus_mock/sqlite_milvus_mock.py
@@ -0,0 +1,48 @@
+from gpt_cache.view import openai
+from gpt_cache.core import cache, Config
+from gpt_cache.cache.factory import get_ss_data_manager
+from gpt_cache.similarity_evaluation.simple import pair_evaluation
+import numpy as np
+
+
+d = 8
+
+
+def mock_embeddings(data, **kwargs):
+    return np.random.random((1, d)).astype('float32')
+
+
+def run():
+    # milvus
+    data_manager = get_ss_data_manager("sqlite", "milvus", dimension=d, max_size=8, clean_size=2)
+    # milvus cloud
+    # data_manager = get_ss_data_manager("sqlite", "milvus", dimension=d, max_size=8, clean_size=2,
+    #                                    host="xxx.zillizcloud.com",
+    #                                    port=19530,
+    #                                    user="xxx", password="xxx", is_https=True,
+    #                                    )
+    cache.init(embedding_func=mock_embeddings,
+               data_manager=data_manager,
+               evaluation_func=pair_evaluation,
+               similarity_threshold=10000,
+               similarity_positive=False,
+               config=Config(),
+               )
+
+    mock_messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "foo"}
+    ]
+    # you should OPEN it if you FIRST run it
+    for i in range(10):
+        cache.data_manager.save(f"receiver the foo {i}", cache.embedding_func("foo"))
+
+    answer = openai.ChatCompletion.create(
+        model="gpt-3.5-turbo",
+        messages=mock_messages,
+    )
+    print(answer)
+
+
+if __name__ == '__main__':
+    run()
96 changes: 86 additions & 10 deletions gpt_cache/cache/data_manager.py
@@ -1,13 +1,14 @@
 import hashlib
-import threading
-import time
 from abc import abstractmethod, ABCMeta
 import pickle
 
 import cachetools
 
 from .scalar_data.sqllite3 import SQLite
+from .scalar_data.scalar_store import ScalarStore
 from .vector_data.faiss import Faiss
+from .vector_data.vector_store import VectorStore
+from .vector_data.vector_index import VectorIndex
 
 
 class DataManager(metaclass=ABCMeta):
@@ -65,7 +66,7 @@ def close(self):
 
 def sha_data(data):
     m = hashlib.sha1()
-    m.update(data.tobytes())
+    m.update(data.astype('float32').tobytes())
     return m.hexdigest()
 
 
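The astype('float32') added above makes the SHA-1 cache key independent of the incoming array's dtype: a float64 copy of an embedding now hashes to the same key as its float32 original, since float32 -> float64 -> float32 round-trips losslessly. A self-contained check of that property (NumPy only, not part of the commit):

import hashlib

import numpy as np


def sha_data(data):
    m = hashlib.sha1()
    m.update(data.astype('float32').tobytes())
    return m.hexdigest()


v32 = np.random.random(8).astype('float32')
v64 = v32.astype('float64')  # same values, different raw bytes

assert sha_data(v32) == sha_data(v64)  # the key is dtype-stable
assert hashlib.sha1(v32.tobytes()).hexdigest() != hashlib.sha1(v64.tobytes()).hexdigest()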
@@ -93,21 +94,20 @@ def init(self, **kwargs):
         self.f = Faiss(self.index_path, self.dimension, self.top_k)
 
     def rebuild_index(self, all_data, top_k=1):
-        print("rebuild index")
         bak = Faiss(self.index_path, self.dimension, top_k=top_k, skip_file=True)
         bak.mult_add(all_data)
         self.f = bak
         self.clean_cache_thread = None
 
     def save(self, data, embedding_data, **kwargs):
         if self.cur_size >= self.max_size and self.clean_cache_thread is None:
-            self.s.clean_cache_func(self.clean_size)
+            self.s.clean_cache(self.clean_size)
             all_data = self.s.select_all_embedding_data()
             self.cur_size = len(all_data)
-            self.rebuild_index(all_data, kwargs.get("top_k", 1))
+            self.rebuild_index(all_data, self.top_k)
             # TODO async
             # self.clean_cache_thread = threading.Thread(target=self.rebuild_index,
-            #                                            args=(all_data, kwargs.get("top_k", 1)),
+            #                                            args=(all_data, self.top_k),
             #                                            daemon=True)
             # self.clean_cache_thread.start()
 
@@ -119,7 +119,7 @@ def save(self, data, embedding_data, **kwargs):
     def get_scalar_data(self, search_data, **kwargs):
         distance, vector_data = search_data
         key = sha_data(vector_data)
-        return self.s.select(key)
+        return self.s.select_data(key)
 
     def search(self, embedding_data, **kwargs):
         return self.f.search(embedding_data)
@@ -130,5 +130,81 @@ def close(self):
 
 
 # SVDataManager scalar store and vector store
-class SVDataManager(DataManager):
-    pass
+class SSDataManager(DataManager):
+    s: ScalarStore
+    v: VectorStore
+
+    def __init__(self, max_size, clean_size, s, v):
+        self.max_size = max_size
+        self.cur_size = 0
+        self.clean_size = clean_size
+        self.s = s
+        self.v = v
+
+    def init(self, **kwargs):
+        self.s.init(**kwargs)
+        self.v.init(**kwargs)
+        self.cur_size = self.s.count()
+
+    def save(self, data, embedding_data, **kwargs):
+        if self.cur_size >= self.max_size:
+            ids = self.s.clean_cache(self.clean_size)
+            self.cur_size = self.s.count()
+            self.v.delete(ids)
+        key = sha_data(embedding_data)
+        self.s.insert(key, data, embedding_data)
+        self.v.add(key, embedding_data)
+        self.cur_size += 1
+
+    def get_scalar_data(self, search_data, **kwargs):
+        distance, vector_data = search_data
+        key = sha_data(vector_data)
+        return self.s.select_data(key)
+
+    def search(self, embedding_data, **kwargs):
+        return self.v.search(embedding_data)
+
+    def close(self):
+        self.s.close()
+        self.v.close()
+
+
+# SIDataManager scalar store and vector index
+class SIDataManager(DataManager):
+    s: ScalarStore
+    v: VectorIndex
+
+    def __init__(self, max_size, clean_size, s, v):
+        self.max_size = max_size
+        self.cur_size = 0
+        self.clean_size = clean_size
+        self.s = s
+        self.v = v
+
+    def init(self, **kwargs):
+        self.s.init(**kwargs)
+        self.v.init(**kwargs)
+        self.cur_size = self.s.count()
+
+    def save(self, data, embedding_data, **kwargs):
+        if self.cur_size >= self.max_size:
+            self.s.clean_cache(self.clean_size)
+            all_data = self.s.select_all_embedding_data()
+            self.cur_size = len(all_data)
+            self.v = self.v.rebuild_index(all_data)
+        key = sha_data(embedding_data)
+        self.s.insert(key, data, embedding_data)
+        self.v.add(key, embedding_data)
+        self.cur_size += 1
+
+    def get_scalar_data(self, search_data, **kwargs):
+        distance, vector_data = search_data
+        key = sha_data(vector_data)
+        return self.s.select_data(key)
+
+    def search(self, embedding_data, **kwargs):
+        return self.v.search(embedding_data)
+
+    def close(self):
+        self.s.close()
+        self.v.close()
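The calls SSDataManager makes on self.v above pin down the surface a vector store must offer: init, add(key, embedding), search(embedding) returning a (distance, vector) pair that get_scalar_data unpacks, delete(ids), and close. A minimal in-memory stand-in written against that surface, as a sketch of what a custom VectorStore could look like (not part of the commit; brute-force NumPy search in place of Milvus):

import numpy as np


class InMemoryVectorStore:
    """Sketch of the VectorStore surface exercised by SSDataManager."""

    def __init__(self, top_k=1):
        self.top_k = top_k
        self.vectors = {}  # key -> embedding

    def init(self, **kwargs):
        pass  # nothing to set up for an in-memory store

    def add(self, key, embedding_data):
        self.vectors[key] = np.asarray(embedding_data, dtype='float32').reshape(-1)

    def search(self, embedding_data):
        # brute-force L2 scan; returns (distance, vector) so that
        # get_scalar_data can re-derive the scalar key via sha_data
        q = np.asarray(embedding_data, dtype='float32').reshape(-1)
        best = min(self.vectors.values(), key=lambda v: float(np.linalg.norm(v - q)))
        return float(np.linalg.norm(best - q)), best

    def delete(self, ids):
        for key in ids:
            self.vectors.pop(key, None)

    def close(self):
        self.vectors.clear()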
97 changes: 81 additions & 16 deletions gpt_cache/cache/factory.py
@@ -1,25 +1,90 @@
-from gpt_cache.cache.data_manager import DataManager
+from .data_manager import DataManager, SIDataManager, SSDataManager
+from .scalar_data.sqllite3 import SQLite
+from .vector_data.faiss import Faiss
+from .vector_data.milvus import Milvus
 
 
-def get_data_manager(name: str, **kwargs) -> DataManager:
-    if name == "map":
-        from gpt_cache.cache.data_manager import MapDataManager
+def get_data_manager(data_manager_name: str, **kwargs) -> DataManager:
+    if data_manager_name == "map":
+        from .data_manager import MapDataManager
 
         return MapDataManager(kwargs.get("data_path", "data_map.txt"),
                               kwargs.get("max_size", 100))
-    elif name == "sqlite_faiss":
-        from gpt_cache.cache.data_manager import SFDataManager
-
-        dimension = kwargs.get("dimension", 0)
-        if dimension <= 0:
-            raise ValueError(f"the sqlite_faiss data manager should set the 'dimension' parameter greater than zero, "
-                             f"current: {dimension}")
-        top_k = kwargs.get("top_k", 1)
-        sqlite_path = kwargs.get("sqlite_path", "sqlite.db")
-        index_path = kwargs.get("index_path", "faiss.index")
+    elif data_manager_name == "scalar_vector":
+        scalar_store = kwargs.get("scalar_store", None)
+        vector_store = kwargs.get("vector_store", None)
         max_size = kwargs.get("max_size", 1000)
         clean_size = kwargs.get("clean_size", int(max_size * 0.2))
+        if scalar_store is None or vector_store is None:
+            raise ValueError(f"Missing scalar_store or vector_store parameter for scalar_vector")
+        return SSDataManager(max_size, clean_size, scalar_store, vector_store)
+    elif data_manager_name == "scalar_vector_index":
+        scalar_store = kwargs.get("scalar_store", None)
+        vector_index = kwargs.get("vector_index", None)
+        max_size = kwargs.get("max_size", 1000)
+        clean_size = kwargs.get("clean_size", int(max_size * 0.2))
+        if scalar_store is None or vector_index is None:
+            raise ValueError(f"Missing scalar_store or vector_index parameter for scalar_vector_index")
+        return SIDataManager(max_size, clean_size, scalar_store, vector_index)
+    # elif data_manager_name == "sqlite_faiss":
+    #     from .data_manager import SFDataManager
+    #
+    #     dimension = kwargs.get("dimension", 0)
+    #     if dimension <= 0:
+    #         raise ValueError(f"the sqlite_faiss data manager should set the 'dimension' parameter greater than zero, "
+    #                          f"current: {dimension}")
+    #     top_k = kwargs.get("top_k", 1)
+    #     sqlite_path = kwargs.get("sqlite_path", "sqlite.db")
+    #     index_path = kwargs.get("index_path", "faiss.index")
+    #     max_size = kwargs.get("max_size", 1000)
+    #     clean_size = kwargs.get("clean_size", int(max_size * 0.2))
+    #     clean_cache_strategy = kwargs.get("clean_cache_strategy", "least_accessed_data")
+    #     return SFDataManager(sqlite_path, index_path, dimension, top_k, max_size, clean_size, clean_cache_strategy)
+    else:
+        raise ValueError(f"Unsupported data manager: {data_manager_name}")
+
+
+def _get_scalar_store(scalar_store: str, **kwargs):
+    if scalar_store == "sqlite":
+        sqlite_path = kwargs.get("sqlite_path", "sqlite.db")
         clean_cache_strategy = kwargs.get("clean_cache_strategy", "least_accessed_data")
-        return SFDataManager(sqlite_path, index_path, dimension, top_k, max_size, clean_size, clean_cache_strategy)
+        store = SQLite(sqlite_path, clean_cache_strategy)
     else:
-        raise ValueError(f"Unsupported data manager: {name}")
+        raise ValueError(f"Unsupported scalar store: {scalar_store}")
+    return store
+
+
+def _get_common_params(**kwargs):
+    max_size = kwargs.get("max_size", 1000)
+    clean_size = kwargs.get("clean_size", int(max_size * 0.2))
+    top_k = kwargs.get("top_k", 1)
+    dimension = kwargs.get("dimension", 0)
+    if dimension <= 0:
+        raise ValueError(f"the data manager should set the 'dimension' parameter greater than zero, "
+                         f"current: {dimension}")
+    return max_size, clean_size, dimension, top_k
+
+
+# scalar_store + vector_store
+def get_ss_data_manager(scalar_store: str, vector_store: str, **kwargs):
+    max_size, clean_size, dimension, top_k = _get_common_params(**kwargs)
+    scalar = _get_scalar_store(scalar_store, **kwargs)
+    if vector_store == "milvus":
+        vector = Milvus(collection_name="gpt_cache", dim=dimension, top_k=top_k, **kwargs)
+    else:
+        raise ValueError(f"Unsupported vector store: {vector_store}")
+    return SSDataManager(max_size, clean_size, scalar, vector)
+
+
+# scalar_store + vector_index
+def get_si_data_manager(scalar_store: str, vector_index: str, **kwargs):
+    max_size, clean_size, dimension, top_k = _get_common_params(**kwargs)
+    store = _get_scalar_store(scalar_store, **kwargs)
+
+    if vector_index == "faiss":
+        index_path = kwargs.get("index_path", "faiss.index")
+        index = Faiss(index_path, dimension, top_k)
+    else:
+        raise ValueError(f"Unsupported vector index: {vector_index}")
+
+    return SIDataManager(max_size, clean_size, store, index)
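Taken together, the two new entry points read like this from caller code; a sketch assembled from the examples in this commit (mock embeddings throughout; the Milvus variant additionally assumes a reachable Milvus server):

import numpy as np

from gpt_cache.core import cache
from gpt_cache.cache.factory import get_si_data_manager, get_ss_data_manager
from gpt_cache.similarity_evaluation.simple import pair_evaluation

d = 8


def mock_embeddings(data, **kwargs):
    return np.random.random((1, d)).astype('float32')


# scalar store + vector index: SQLite rows plus a local Faiss index file
data_manager = get_si_data_manager("sqlite", "faiss", dimension=d)

# scalar store + vector store: SQLite rows plus a Milvus collection
# data_manager = get_ss_data_manager("sqlite", "milvus", dimension=d)

cache.init(embedding_func=mock_embeddings,
           data_manager=data_manager,
           evaluation_func=pair_evaluation,
           similarity_threshold=10000,
           similarity_positive=False)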