Skip to content

Commit

Permalink
🧑‍💻 refactor: Enhance Error Handling and Introduce Alternate Variable…
Browse files Browse the repository at this point in the history
…s for Configuration (#19)

* fix: correctly handle embedding errors

* chore: log query error

* refactor: add override environment variables to be used as default

* chore: rename OPENAI_BASEURL to RAG_OPENAI_BASEURL as standard
  • Loading branch information
danny-avila authored Apr 3, 2024
1 parent 3cec7fe commit 24d6b6d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ uvicorn main:app

The following environment variables are required to run the application:

- `OPENAI_API_KEY`: The API key for OpenAI API Embeddings (if using default settings).
- `RAG_OPENAI_API_KEY`: The API key for OpenAI API Embeddings (if using default settings).
  - Note: `OPENAI_API_KEY` will work, but `RAG_OPENAI_API_KEY` will override it, so that this service's setting does not conflict with the LibreChat setting.
- `RAG_OPENAI_BASEURL`: (Optional) The base URL for your OpenAI API Embeddings
- `RAG_OPENAI_PROXY`: (Optional) Proxy for OpenAI API Embeddings
- `POSTGRES_DB`: (Optional) The name of the PostgreSQL database.
- `POSTGRES_USER`: (Optional) The username for connecting to the PostgreSQL database.
- `POSTGRES_PASSWORD`: (Optional) The password for connecting to the PostgreSQL database.
Expand All @@ -65,8 +68,11 @@ The following environment variables are required to run the application:
- huggingface: "sentence-transformers/all-MiniLM-L6-v2"
  - huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses the model defined when the TEI service is launched.
- ollama: "nomic-embed-text"
- `AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
- `AZURE_OPENAI_ENDPOINT`: (Optional) The endpoint URL for Azure OpenAI service, including the resource. Example: `https://example-resource.azure.openai.com/`.
- `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
  - Note: `AZURE_OPENAI_API_KEY` will work, but `RAG_AZURE_OPENAI_API_KEY` will override it, so that this service's setting does not conflict with the LibreChat setting.
- `RAG_AZURE_OPENAI_ENDPOINT`: (Optional) The endpoint URL for Azure OpenAI service, including the resource.
- Example: `https://example-resource.azure.openai.com`.
  - Note: `AZURE_OPENAI_ENDPOINT` will work, but `RAG_AZURE_OPENAI_ENDPOINT` will override it, so that this service's setting does not conflict with the LibreChat setting.
- `HF_TOKEN`: (Optional) if needed for `huggingface` option.
- `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`.

Expand Down
23 changes: 18 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
load_dotenv(find_dotenv())


def get_env_variable(var_name: str, default_value: str = None, required: bool = False) -> str:
    """Read an environment variable with an optional fallback.

    Args:
        var_name: Name of the environment variable to look up.
        default_value: Value returned when the variable is unset. May be
            ``None``, in which case an unset, non-required variable yields
            ``None`` (callers such as the RAG_* settings rely on this).
        required: When True and the variable is unset with no default,
            raise instead of returning ``None``.

    Returns:
        The environment value, or ``default_value`` when unset.

    Raises:
        ValueError: If the variable is unset, ``default_value`` is ``None``,
            and ``required`` is True.
    """
    value = os.getenv(var_name)
    if value is None:
        # Only a missing *required* variable with no fallback is fatal;
        # everything else degrades to the supplied default.
        if default_value is None and required:
            raise ValueError(f"Environment variable '{var_name}' not found.")
        return default_value
    return value
Expand Down Expand Up @@ -130,8 +130,13 @@ async def dispatch(self, request, call_next):
## Credentials

# Plain provider variables are read first so existing deployments keep working;
# each RAG_-prefixed variable defaults to its plain counterpart and therefore
# overrides it when set, avoiding clashes with LibreChat's own environment.
OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "")
RAG_OPENAI_API_KEY = get_env_variable("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
RAG_OPENAI_BASEURL = get_env_variable("RAG_OPENAI_BASEURL", None)  # None -> use library default base URL
RAG_OPENAI_PROXY = get_env_variable("RAG_OPENAI_PROXY", None)  # None -> no proxy
AZURE_OPENAI_API_KEY = get_env_variable("AZURE_OPENAI_API_KEY", "")
RAG_AZURE_OPENAI_API_KEY = get_env_variable("RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY)
AZURE_OPENAI_ENDPOINT = get_env_variable("AZURE_OPENAI_ENDPOINT", "")
# Trailing slash is stripped so the endpoint concatenates cleanly with API paths.
RAG_AZURE_OPENAI_ENDPOINT = get_env_variable("RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")

Expand All @@ -140,10 +145,18 @@ async def dispatch(self, request, call_next):

def init_embeddings(provider, model):
if provider == "openai":
return OpenAIEmbeddings(model=model, api_key=OPENAI_API_KEY)
return OpenAIEmbeddings(
model=model,
api_key=RAG_OPENAI_API_KEY,
openai_api_base=RAG_OPENAI_BASEURL,
openai_proxy=RAG_OPENAI_PROXY
)
elif provider == "azure":
return AzureOpenAIEmbeddings(model=model,
api_key=AZURE_OPENAI_API_KEY) # AZURE_OPENAI_ENDPOINT is being grabbed from the environment
return AzureOpenAIEmbeddings(
model=model,
api_key=RAG_AZURE_OPENAI_API_KEY,
azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT
)
elif provider == "huggingface":
return HuggingFaceEmbeddings(model_name=model, encode_kwargs={
'normalize_embeddings': True})
Expand Down
21 changes: 19 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):

return authorized_documents
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))


Expand Down Expand Up @@ -310,6 +311,8 @@ async def embed_local_file(document: StoreDocument, request: Request):

@app.post("/embed")
async def embed_file(request: Request, file_id: str = Form(...), file: UploadFile = File(...)):
response_status = True
response_message = "File processed successfully."
known_type = None
if not hasattr(request.state, 'user'):
user_id = "public"
Expand All @@ -335,11 +338,25 @@ async def embed_file(request: Request, file_id: str = Form(...), file: UploadFil
result = await store_data_in_vector_db(data, file_id, user_id)

if not result:
response_status = False
response_message = "Failed to process/store the file data."
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process/store the file data.",
)
elif 'error' in result:
response_status = False
response_message = "Failed to process/store the file data."
if isinstance(result['error'], str):
response_message = result['error']
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unspecified error occurred.",
)
except Exception as e:
response_status = False
response_message = f"Error during file processing: {str(e)}"
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}")
finally:
Expand All @@ -349,8 +366,8 @@ async def embed_file(request: Request, file_id: str = Form(...), file: UploadFil
logger.info(f"Failed to remove temporary file: {str(e)}")

return {
"status": True,
"message": "File processed successfully.",
"status": response_status,
"message": response_message,
"file_id": file_id,
"filename": file.filename,
"known_type": known_type,
Expand Down

0 comments on commit 24d6b6d

Please sign in to comment.