From b79f76858cd02c1ca7c2b6207f49803479e2ebe8 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 19 Mar 2024 18:03:05 -0400 Subject: [PATCH] feat: add huggingface embedding options --- README.md | Bin 6396 -> 7058 bytes config.py | 34 ++++++++++++++++++++++++++++++++-- requirements.txt | 1 + 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f76140b41b9c03250d463ea9f0d7dda0b4058961..e81b8dec3139b77880877204e57d0c133ca85de2 100644 GIT binary patch delta 613 zcmaKq%}T>i5QUFhL6B~`5=vHrMyVo#?xghxlPXnQxM{0R+i262CZ*!0?_gg<@FCpz zNG|*)4JcaV!es8bXU?57H=pH~{nwA;hO3gAs>;_9vBEl2O((iiLzk-SLN_24&g!$M zur>`8DpgN`MpQbAH6$OX1@3^vIo{J)T~^3hR~^pp;W^rbGoe0Y-yIR&2GTlhQEl{FE=xrz_J60c}38awzwzQ&U{b$n0qS^B@dq3^`zT+TMw+5fM z%jFyCj@V|+ZMxb8PrzGr)z9rN9>qfSDaGcvi-bvr+9sO|j+jb}yu~!TB4Fp|Jdezk zr?PgKoXybHW5$0#(b4x9PIV098r$dCuJQV2Y@Pj3v**HUHDsN~`8S%|YQU;6wCME>x diff --git a/config.py b/config.py index 1e5ff507..189a381f 100644 --- a/config.py +++ b/config.py @@ -3,6 +3,7 @@ import os import logging from dotenv import find_dotenv, load_dotenv +from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_openai import OpenAIEmbeddings from store_factory import get_vector_store @@ -36,6 +37,8 @@ def get_env_variable(var_name: str, default_value: str = None) -> str: CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" +## Logging + logger = logging.getLogger() debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true" @@ -49,8 +52,35 @@ def get_env_variable(var_name: str, default_value: str = None) -> str: handler.setFormatter(formatter) logger.addHandler(handler) -OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY") -embeddings = OpenAIEmbeddings() +## Credentials + +OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "") +HF_TOKEN = get_env_variable("HF_TOKEN", "") + +## Embeddings + +def init_embeddings(provider, model): + if provider == "openai": + return OpenAIEmbeddings(model=model) + elif provider == "huggingface": + return HuggingFaceEmbeddings(model_name=model, encode_kwargs={'normalize_embeddings': True}) + else: + raise ValueError(f"Unsupported embeddings provider: {provider}") + +EMBEDDINGS_PROVIDER = get_env_variable("EMBEDDINGS_PROVIDER", "openai").lower() + +if EMBEDDINGS_PROVIDER == "openai": + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") +elif EMBEDDINGS_PROVIDER == "huggingface": + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "sangmini/msmarco-cotmae-MiniLM-L12_en-ko-ja") +else: + raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") + +embeddings = init_embeddings(EMBEDDINGS_PROVIDER, EMBEDDINGS_MODEL) + +logger.info(f"Initialized embeddings of type: {type(embeddings)}") + +## Vector store vector_store = get_vector_store( connection_string=CONNECTION_STRING, diff --git a/requirements.txt b/requirements.txt index a0c1e154..a21c6da0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ pypandoc==1.13 python-jose==3.3.0 asyncpg==0.29.0 python-multipart==0.0.9 +sentence_transformers==2.5.1