Skip to content

Commit

Permalink
Add example for sqlalchemy (#67)
Browse files Browse the repository at this point in the history
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
  • Loading branch information
shiyu22 authored Apr 5, 2023
1 parent b6ee20e commit d2a939b
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ coverage.xml
local_settings.py
db.sqlite3
db.sqlite3-journal
*.db

# Flask stuff:
instance/
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ This module is created to extract embeddings from requests for similarity search
- **Cache Storage**:
**Cache Storage** is where the response from LLMs, such as ChatGPT, is stored. Cached responses are retrieved to assist in evaluating similarity and are returned to the requester if there is a good semantic match. At present, GPTCache supports SQLite and offers a universally accessible interface for extension of this module.
- [x] Support [SQLite](https://sqlite.org/docs.html).
- [x] Support [PostgreSQL](https://www.postgresql.org/).
- [x] Support [MySQL](https://www.mysql.com/).
- [x] Support [MariaDB](https://mariadb.org/).
- [x] Support [SQL Server](https://www.microsoft.com/en-us/sql-server/).
- [x] Support [Oracle](https://www.oracle.com/).
- [ ] Support [MongoDB](https://www.mongodb.com/).
- [ ] Support [Redis](https://redis.io/).
- [ ] Support [Minio](https://min.io/).
- [ ] Support [HBase](https://hbase.apache.org/).
Expand Down
48 changes: 48 additions & 0 deletions examples/mariadb_faiss_mock/mariadb_faiss_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

from gptcache.adapter import openai
from gptcache.core import cache, Config
from gptcache.cache.factory import get_ss_data_manager
from gptcache.similarity_evaluation.simple import SearchDistanceEvaluation
import numpy as np


d = 8  # embedding dimensionality used throughout this example


def mock_embeddings(data, **kwargs):
    """Produce a fake embedding: a random float32 vector of length ``d``.

    The input ``data`` is ignored — this stands in for a real embedding model.
    """
    vec = np.random.random(d)
    return vec.astype('float32')


def run():
    """Example: GPTCache with MariaDB as the scalar store and Faiss as the
    vector store, using mock embeddings (no real model needed).
    """
    faiss_file = "faiss.index"
    # Presence of the Faiss index file on disk is used as a marker that the
    # cache was already seeded by a previous run.
    has_data = os.path.isfile(faiss_file)
    data_manager = get_ss_data_manager("mariadb", "faiss",
                                       dimension=d, max_size=8, clean_size=2, top_k=3)
    cache.init(embedding_func=mock_embeddings,
               data_manager=data_manager,
               similarity_evaluation=SearchDistanceEvaluation(),
               config=Config(
                   # Threshold 0 accepts any search hit — the embeddings are
                   # random, so real similarity is meaningless here.
                   similarity_threshold=0,
               ),
               )

    mock_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "foo"}
    ]
    if not has_data:
        # First run: pre-populate the cache with mock question/answer pairs.
        for i in range(10):
            question = f"foo{i}"
            answer = f"receiver the foo {i}"
            cache.data_manager.save(question, answer, cache.embedding_func(question))

    # With the cache populated, this call is expected to be served from the
    # cache via the gptcache openai adapter rather than the real API.
    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=mock_messages,
    )
    print(answer)


if __name__ == '__main__':
    run()
48 changes: 48 additions & 0 deletions examples/mssql_faiss_mock/mssql_faiss_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

from gptcache.adapter import openai
from gptcache.core import cache, Config
from gptcache.cache.factory import get_ss_data_manager
from gptcache.similarity_evaluation.simple import SearchDistanceEvaluation
import numpy as np


d = 8  # embedding dimensionality used throughout this example


def mock_embeddings(data, **kwargs):
    """Produce a fake embedding: a random float32 vector of length ``d``.

    The input ``data`` is ignored — this stands in for a real embedding model.
    """
    vec = np.random.random(d)
    return vec.astype('float32')


def run():
    """Example: GPTCache with SQL Server as the scalar store and Faiss as the
    vector store, using mock embeddings (no real model needed).
    """
    faiss_file = "faiss.index"
    # Presence of the Faiss index file on disk is used as a marker that the
    # cache was already seeded by a previous run.
    has_data = os.path.isfile(faiss_file)
    data_manager = get_ss_data_manager("sqlserver", "faiss",
                                       dimension=d, max_size=8, clean_size=2, top_k=3)
    cache.init(embedding_func=mock_embeddings,
               data_manager=data_manager,
               similarity_evaluation=SearchDistanceEvaluation(),
               config=Config(
                   # Threshold 0 accepts any search hit — the embeddings are
                   # random, so real similarity is meaningless here.
                   similarity_threshold=0,
               ),
               )

    mock_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "foo"}
    ]
    if not has_data:
        # First run: pre-populate the cache with mock question/answer pairs.
        for i in range(10):
            question = f"foo{i}"
            answer = f"receiver the foo {i}"
            cache.data_manager.save(question, answer, cache.embedding_func(question))

    # With the cache populated, this call is expected to be served from the
    # cache via the gptcache openai adapter rather than the real API.
    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=mock_messages,
    )
    print(answer)


if __name__ == '__main__':
    run()
48 changes: 48 additions & 0 deletions examples/mysql_faiss_mock/mysql_faiss_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

from gptcache.adapter import openai
from gptcache.core import cache, Config
from gptcache.cache.factory import get_ss_data_manager
from gptcache.similarity_evaluation.simple import SearchDistanceEvaluation
import numpy as np


d = 8  # embedding dimensionality used throughout this example


def mock_embeddings(data, **kwargs):
    """Produce a fake embedding: a random float32 vector of length ``d``.

    The input ``data`` is ignored — this stands in for a real embedding model.
    """
    vec = np.random.random(d)
    return vec.astype('float32')


def run():
    """Example: GPTCache with MySQL as the scalar store and Faiss as the
    vector store, using mock embeddings (no real model needed).
    """
    faiss_file = "faiss.index"
    # Presence of the Faiss index file on disk is used as a marker that the
    # cache was already seeded by a previous run.
    has_data = os.path.isfile(faiss_file)
    data_manager = get_ss_data_manager("mysql", "faiss",
                                       dimension=d, max_size=8, clean_size=2, top_k=3)
    cache.init(embedding_func=mock_embeddings,
               data_manager=data_manager,
               similarity_evaluation=SearchDistanceEvaluation(),
               config=Config(
                   # Threshold 0 accepts any search hit — the embeddings are
                   # random, so real similarity is meaningless here.
                   similarity_threshold=0,
               ),
               )

    mock_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "foo"}
    ]
    if not has_data:
        # First run: pre-populate the cache with mock question/answer pairs.
        for i in range(10):
            question = f"foo{i}"
            answer = f"receiver the foo {i}"
            cache.data_manager.save(question, answer, cache.embedding_func(question))

    # With the cache populated, this call is expected to be served from the
    # cache via the gptcache openai adapter rather than the real API.
    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=mock_messages,
    )
    print(answer)


if __name__ == '__main__':
    run()
54 changes: 54 additions & 0 deletions examples/oracle_faiss_mock/oracle_faiss_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from gptcache.utils import import_cxoracle
import_cxoracle()

import cx_Oracle
import os

from gptcache.adapter import openai
from gptcache.core import cache, Config
from gptcache.cache.factory import get_ss_data_manager
from gptcache.similarity_evaluation.simple import SearchDistanceEvaluation
import numpy as np


d = 8  # embedding dimensionality used throughout this example


def mock_embeddings(data, **kwargs):
    """Produce a fake embedding: a random float32 vector of length ``d``.

    The input ``data`` is ignored — this stands in for a real embedding model.
    """
    vec = np.random.random(d)
    return vec.astype('float32')


def run():
    """Example: GPTCache with Oracle as the scalar store and Faiss as the
    vector store, using mock embeddings (no real model needed).
    """
    # NOTE(review): machine-specific Oracle Instant Client location — adjust
    # this path for your environment before running.
    cx_Oracle.init_oracle_client(lib_dir="/Users/root/Downloads/instantclient_19_8")

    faiss_file = "faiss.index"
    # Presence of the Faiss index file on disk is used as a marker that the
    # cache was already seeded by a previous run.
    has_data = os.path.isfile(faiss_file)
    data_manager = get_ss_data_manager("oracle", "faiss",
                                       dimension=d, max_size=8, clean_size=2, top_k=3)
    cache.init(embedding_func=mock_embeddings,
               data_manager=data_manager,
               similarity_evaluation=SearchDistanceEvaluation(),
               config=Config(
                   # Threshold 0 accepts any search hit — the embeddings are
                   # random, so real similarity is meaningless here.
                   similarity_threshold=0,
               ),
               )

    mock_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "foo"}
    ]
    if not has_data:
        # First run: pre-populate the cache with mock question/answer pairs.
        for i in range(10):
            question = f"foo{i}"
            answer = f"receiver the foo {i}"
            cache.data_manager.save(question, answer, cache.embedding_func(question))

    # With the cache populated, this call is expected to be served from the
    # cache via the gptcache openai adapter rather than the real API.
    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=mock_messages,
    )
    print(answer)


if __name__ == '__main__':
    run()
48 changes: 48 additions & 0 deletions examples/postgresql_faiss_mock/postgresql_faiss_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

from gptcache.adapter import openai
from gptcache.core import cache, Config
from gptcache.cache.factory import get_ss_data_manager
from gptcache.similarity_evaluation.simple import SearchDistanceEvaluation
import numpy as np


d = 8  # embedding dimensionality used throughout this example


def mock_embeddings(data, **kwargs):
    """Produce a fake embedding: a random float32 vector of length ``d``.

    The input ``data`` is ignored — this stands in for a real embedding model.
    """
    vec = np.random.random(d)
    return vec.astype('float32')


def run():
    """Example: GPTCache with PostgreSQL as the scalar store and Faiss as the
    vector store, using mock embeddings (no real model needed).
    """
    faiss_file = "faiss.index"
    # Presence of the Faiss index file on disk is used as a marker that the
    # cache was already seeded by a previous run.
    has_data = os.path.isfile(faiss_file)
    data_manager = get_ss_data_manager("postgresql", "faiss",
                                       dimension=d, max_size=8, clean_size=2, top_k=3)
    cache.init(embedding_func=mock_embeddings,
               data_manager=data_manager,
               similarity_evaluation=SearchDistanceEvaluation(),
               config=Config(
                   # Threshold 0 accepts any search hit — the embeddings are
                   # random, so real similarity is meaningless here.
                   similarity_threshold=0,
               ),
               )

    mock_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "foo"}
    ]
    if not has_data:
        # First run: pre-populate the cache with mock question/answer pairs.
        for i in range(10):
            question = f"foo{i}"
            answer = f"receiver the foo {i}"
            cache.data_manager.save(question, answer, cache.embedding_func(question))

    # With the cache populated, this call is expected to be served from the
    # cache via the gptcache openai adapter rather than the real API.
    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=mock_messages,
    )
    print(answer)


if __name__ == '__main__':
    run()
4 changes: 3 additions & 1 deletion gptcache/cache/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .scalar_data import SQLDataBase, SQL_URL
from .vector_data import Milvus, Faiss, Chromadb
from ..utils.error import NotFoundStoreError, ParamError
from ..utils import import_sql_client


def get_data_manager(data_manager_name: str, **kwargs) -> DataManager:
Expand All @@ -26,7 +27,8 @@ def get_data_manager(data_manager_name: str, **kwargs) -> DataManager:
def _get_scalar_store(scalar_store: str, **kwargs):
    """Create the scalar (SQL) cache store for *scalar_store*.

    :param scalar_store: backend name, e.g. 'sqlite', 'mysql', 'oracle'.
    :param kwargs: may contain 'sql_url' to override the backend's default
        connection URL from ``SQL_URL``.
    :raises NotFoundStoreError: if the backend name is not supported.
    """
    # Guard clause replaces the if/else; also drops a stale duplicate
    # `store = SQLDataBase(url=sql_url)` left over from an earlier revision.
    if scalar_store not in ("sqlite", "postgresql", "mysql", "mariadb", "sqlserver", "oracle"):
        raise NotFoundStoreError("scalar store", scalar_store)
    sql_url = kwargs.pop("sql_url", SQL_URL[scalar_store])
    import_sql_client(scalar_store)  # lazily import the backend's DB driver
    return SQLDataBase(url=sql_url, db_type=scalar_store)
Expand Down
12 changes: 6 additions & 6 deletions gptcache/cache/scalar_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def SQLDataBase(**kwargs):


# Default connection URL per supported backend. The credentials and hosts
# are example defaults for local development; callers override them via the
# ``sql_url`` kwarg. (Stale duplicate keys from a previous revision removed.)
SQL_URL = {
    'sqlite': 'sqlite:///./sqlite.db',
    'postgresql': 'postgresql+psycopg2://postgres:123456@127.0.0.1:5432/postgres',
    'mysql': 'mysql+pymysql://root:123456@127.0.0.1:3306/mysql',
    'mariadb': 'mariadb+pymysql://root:123456@127.0.0.1:3307/mysql',
    'sqlserver': 'mssql+pyodbc://sa:Strongpsw_123@127.0.0.1:1434/msdb?driver=ODBC+Driver+17+for+SQL+Server',
    'oracle': 'oracle+cx_oracle://oracle:123456@127.0.0.1:1521/?service_name=helowin&encoding=UTF-8&nencoding=UTF-8',
}
1 change: 1 addition & 0 deletions gptcache/cache/scalar_data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

# Name of the main cache table (native autoincrement primary key).
TABLE_NAME = 'cache_table'
# Name of the sequence-backed variant table (selected for Oracle by SQLDataBase).
TABLE_NAME_SEQ = 'cache_table_sequence'


class ScalarStorage(metaclass=ABCMeta):
Expand Down
41 changes: 31 additions & 10 deletions gptcache/cache/scalar_data/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import numpy as np
from datetime import datetime
from sqlalchemy import func, create_engine, Column
from sqlalchemy.types import String, DateTime, VARBINARY, Integer
from sqlalchemy import func, create_engine, Column, Sequence
from sqlalchemy.types import String, DateTime, LargeBinary, Integer
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from .base import ScalarStorage, TABLE_NAME
from .base import ScalarStorage, TABLE_NAME, TABLE_NAME_SEQ

Base = declarative_base()

class CacheTable(Base):
    """
    cache_table

    ORM model for one cached request/response pair. (Stale duplicate column
    definitions from a previous revision removed.)
    """
    __tablename__ = TABLE_NAME

    # Surrogate primary key; relies on the backend's native autoincrement.
    uid = Column(Integer, primary_key=True, autoincrement=True)
    id = Column(String(500), nullable=False)       # cache key for the entry
    data = Column(String(1000), nullable=False)    # cached request payload
    reply = Column(String(1000), nullable=False)   # cached response payload
    create_on = Column(DateTime, default=datetime.now)
    last_access = Column(DateTime, default=datetime.now)
    # Serialized embedding bytes (see insert()); LargeBinary is portable
    # across the supported SQL backends, unlike VARBINARY.
    embedding_data = Column(LargeBinary, nullable=True)
    state = Column(Integer, default=0)


class CacheTableSequence(Base):
    """
    cache_table_sequence

    Same columns as ``CacheTable``, but the primary key is driven by an
    explicit ``Sequence`` — SQLDataBase selects this model for Oracle.
    """
    __tablename__ = TABLE_NAME_SEQ

    # Primary key generated from the 'id_seq' database sequence.
    uid = Column(Integer, Sequence('id_seq', start=1), primary_key=True, autoincrement=True)
    id = Column(String(500), nullable=False)       # cache key for the entry
    data = Column(String(1000), nullable=False)    # cached request payload
    reply = Column(String(1000), nullable=False)   # cached response payload
    create_on = Column(DateTime, default=datetime.now)
    last_access = Column(DateTime, default=datetime.now)
    # Serialized embedding bytes; optional (column is nullable).
    embedding_data = Column(LargeBinary, nullable=True)
    state = Column(Integer, default=0)


class SQLDataBase(ScalarStorage):
"""
Using sqlalchemy to manage SQLite, PostgreSQL, MySQL, MariaDB, SQL Server and Oracle.
"""
def __init__(self, url: str = 'sqlite:///./gpt_cache.db', db_type: str = 'sqlite'):
    """Configure the SQL-backed scalar store and open the connection.

    :param url: SQLAlchemy connection URL for the target database.
    :param db_type: backend name; 'oracle' switches to the sequence-based
        table model (Oracle needs an explicit Sequence for generated keys —
        NOTE(review): confirm against the supported Oracle versions).

    (Stale duplicate signature and ``self._model`` assignment from a
    previous revision removed.)
    """
    self._url = url
    self._engine = None
    self._session = None
    self._db_type = db_type
    if self._db_type == 'oracle':
        self._model = CacheTableSequence
    else:
        self._model = CacheTable
    self.init()

def init(self):
def create(self):
    """Create the backing table in the database if it does not already exist."""
    table = self._model.__table__
    table.create(bind=self._engine, checkfirst=True)

def insert(self, key, data, reply, embedding_data: np.ndarray = None):
    """Persist one cache entry and commit.

    :param key: cache key stored in the ``id`` column.
    :param data: request payload (the question text in the examples).
    :param reply: response payload to cache.
    :param embedding_data: optional embedding vector; serialized to raw
        bytes for the nullable binary column.
    """
    # Fix: previously .tobytes() was called unconditionally, raising
    # AttributeError when embedding_data was left at its None default
    # (the column is nullable, so None must be allowed through).
    if embedding_data is not None:
        embedding_data = embedding_data.tobytes()
    model_obj = self._model(id=key, data=data, reply=reply, embedding_data=embedding_data)
    self._session.add(model_obj)
    self._session.commit()
Expand Down
Loading

0 comments on commit d2a939b

Please sign in to comment.