Skip to content

Commit

Permalink
feat: add ENABLE_LOCAL_WEB_FETCH to protect against SSRF attacks
Browse files Browse the repository at this point in the history
  • Loading branch information
cheahjs committed Apr 29, 2024
1 parent e8abaa8 commit 1c4e63f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
39 changes: 38 additions & 1 deletion backend/apps/rag/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
)
from langchain.text_splitter import RecursiveCharacterTextSplitter

import validators
import urllib.parse
import socket


from pydantic import BaseModel
from typing import Optional
import mimetypes
Expand Down Expand Up @@ -84,6 +89,7 @@
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
)

from constants import ERROR_MESSAGES
Expand Down Expand Up @@ -454,7 +460,7 @@ def query_collection_handler(
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = WebBaseLoader(form_data.url)
loader = get_web_loader(form_data.url)
data = loader.load()

collection_name = form_data.collection_name
Expand All @@ -475,6 +481,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)


def get_web_loader(url: str):
    """Validate ``url`` and return a ``WebBaseLoader`` for it.

    SSRF guard: when ``ENABLE_LOCAL_WEB_FETCH`` is off, URLs whose hostname
    resolves to a private IPv4/IPv6 address are rejected.

    Raises:
        ValueError: with ``ERROR_MESSAGES.INVALID_URL`` if the URL is
            malformed, unresolvable, or resolves to a private address.
    """
    # validators.url() returns a ValidationError instance (it does not raise)
    # when the URL is malformed.
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    if not ENABLE_LOCAL_WEB_FETCH:
        # Local web fetch is disabled; filter out URLs that resolve to
        # private IP addresses.
        parsed_url = urllib.parse.urlparse(url)
        try:
            # Get IPv4 and IPv6 addresses for the hostname.
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
        except socket.gaierror:
            # Unresolvable hostname: surface the same clean validation error
            # rather than letting getaddrinfo's failure bubble up as a 500.
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        # Check if any of the resolved addresses are private.
        # NOTE: still technically vulnerable to DNS rebinding, as we don't
        # control how WebBaseLoader re-resolves the hostname later.
        for ip in ipv4_addresses:
            if validators.ipv4(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(url)


def resolve_hostname(hostname):
    """Resolve ``hostname`` into its IPv4 and IPv6 addresses.

    Returns:
        tuple[list[str], list[str]]: ``(ipv4_addresses, ipv6_addresses)`` in
        the order ``socket.getaddrinfo`` reports them (duplicates across
        socket types are kept, matching getaddrinfo's raw output).

    Raises:
        socket.gaierror: if the hostname cannot be resolved.
    """
    ipv4_addresses = []
    ipv6_addresses = []
    # getaddrinfo entries are (family, type, proto, canonname, sockaddr);
    # the address string is the first element of sockaddr for both families.
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])
    return ipv4_addresses, ipv6_addresses


def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:

text_splitter = RecursiveCharacterTextSplitter(
Expand Down
2 changes: 2 additions & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,8 @@ def create_config_file(file_path):
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)

# SSRF guard toggle: when false (the default), web fetches to URLs that
# resolve to private IP addresses are rejected. Enabled only when the env
# var is the string "true" (case-insensitive).
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"

####################################
# Transcribe
####################################
Expand Down
4 changes: 4 additions & 0 deletions backend/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,7 @@ def __str__(self) -> str:
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."

DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."

# User-facing message raised for malformed URLs and for URLs blocked by the
# private-address (SSRF) check.
INVALID_URL = (
    "Oops! The URL you provided is invalid. Please double-check and try again."
)
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pandas
openpyxl
pyxlsb
xlrd
validators

opencv-python-headless
rapidocr-onnxruntime
Expand Down

0 comments on commit 1c4e63f

Please sign in to comment.