Skip to content

Commit

Permalink
feat: Add vector_similarity_threshold support within RagRetrievalConf…
Browse files Browse the repository at this point in the history
…ig in rag_store and rag_retrieval GA and preview versions

PiperOrigin-RevId: 700812116
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 27, 2024
1 parent 47a5a6d commit 9402b3d
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 33 deletions.
8 changes: 8 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,11 @@
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
)
TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_similarity_threshold=0.5),
)
TEST_RAG_RETRIEVAL_ERROR_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5, vector_similarity_threshold=0.5),
)
9 changes: 9 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,12 @@
filter=Filter(vector_distance_threshold=0.5),
hybrid_search=HybridSearch(alpha=0.5),
)
TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5),
hybrid_search=HybridSearch(alpha=0.5),
)
TEST_RAG_RETRIEVAL_ERROR_CONFIG = RagRetrievalConfig(
top_k=2,
filter=Filter(vector_distance_threshold=0.5, vector_similarity_threshold=0.5),
)
31 changes: 31 additions & 0 deletions tests/unit/vertex_rag/test_rag_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ def test_retrieval_query_rag_resources_success(self):
)
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_resources_similarity_success(self):
response = rag.retrieval_query(
rag_resources=[tc.TEST_RAG_RESOURCE],
text=tc.TEST_QUERY_TEXT,
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
)
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)

@pytest.mark.usefixtures("rag_client_mock_exception")
def test_retrieval_query_failure(self):
with pytest.raises(RuntimeError) as e:
Expand Down Expand Up @@ -105,3 +114,25 @@ def test_retrieval_query_multiple_rag_resources(self):
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
)
e.match("Currently only support 1 RagResource")

def test_retrieval_query_similarity_multiple_rag_resources(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
text=tc.TEST_QUERY_TEXT,
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
)
e.match("Currently only support 1 RagResource")

def test_retrieval_query_invalid_config_filter(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_resources=[tc.TEST_RAG_RESOURCE],
text=tc.TEST_QUERY_TEXT,
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
)
e.match(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
36 changes: 36 additions & 0 deletions tests/unit/vertex_rag/test_rag_retrieval_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ def test_retrieval_query_rag_resources_config_success(self):
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_resources_similarity_config_success(self):
response = rag.retrieval_query(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
)
retrieve_contexts_eq(
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
)

@pytest.mark.usefixtures("retrieve_contexts_mock")
def test_retrieval_query_rag_resources_default_config_success(self):
response = rag.retrieval_query(
Expand Down Expand Up @@ -223,3 +234,28 @@ def test_retrieval_query_multiple_rag_resources_config(self):
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
)
e.match("Currently only support 1 RagResource")

def test_retrieval_query_multiple_rag_resources_similarity_config(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_resources=[
test_rag_constants_preview.TEST_RAG_RESOURCE,
test_rag_constants_preview.TEST_RAG_RESOURCE,
],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
)
e.match("Currently only support 1 RagResource")

def test_retrieval_query_invalid_config_filter(self):
with pytest.raises(ValueError) as e:
rag.retrieval_query(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
text=test_rag_constants_preview.TEST_QUERY_TEXT,
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
)
e.match(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
18 changes: 17 additions & 1 deletion tests/unit/vertex_rag/test_rag_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_retrieval_tool_invalid_name(self):
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME],
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
),
)
)
Expand All @@ -45,3 +45,19 @@ def test_retrieval_tool_multiple_rag_resources(self):
)
)
e.match("Currently only support 1 RagResource")

def test_retrieval_tool_invalid_config_filter(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[tc.TEST_RAG_RESOURCE],
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
)
)
)
e.match(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
29 changes: 29 additions & 0 deletions tests/unit/vertex_rag/test_rag_store_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ def test_retrieval_tool_config_success(self):
)
)

def test_retrieval_tool_similarity_config_success(self):
with pytest.warns(DeprecationWarning):
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_corpora=[
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
),
)
)

def test_retrieval_tool_invalid_name(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
Expand Down Expand Up @@ -137,3 +150,19 @@ def test_retrieval_tool_multiple_rag_resources_config(self):
)
)
e.match("Currently only support 1 RagResource")

def test_retrieval_tool_invalid_config_filter(self):
with pytest.raises(ValueError) as e:
Tool.from_retrieval(
retrieval=rag.Retrieval(
source=rag.VertexRagStore(
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
)
)
)
e.match(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
33 changes: 28 additions & 5 deletions vertexai/preview/rag/rag_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,12 @@ def retrieval_query(
else:
# If rag_retrieval_config is specified, check for missing parameters.
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
api_retrival_config.top_k = (
rag_retrieval_config.top_k
if rag_retrieval_config.top_k
else similarity_top_k
)
# Set top_k to config value if specified
if rag_retrieval_config.top_k:
api_retrival_config.top_k = rag_retrieval_config.top_k
else:
api_retrival_config.top_k = similarity_top_k
# Set alpha to config value if specified
if (
rag_retrieval_config.hybrid_search
and rag_retrieval_config.hybrid_search.alpha
Expand All @@ -204,6 +205,19 @@ def retrieval_query(
)
else:
api_retrival_config.hybrid_search.alpha = vector_search_alpha
# Check if both vector_distance_threshold and vector_similarity_threshold
# are specified.
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
and rag_retrieval_config.filter.vector_similarity_threshold
):
raise ValueError(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
# Set vector_distance_threshold to config value if specified
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
Expand All @@ -215,6 +229,15 @@ def retrieval_query(
api_retrival_config.filter.vector_distance_threshold = (
vector_distance_threshold
)
# Set vector_similarity_threshold to config value if specified
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_similarity_threshold
):
api_retrival_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)

query = aiplatform_v1beta1.RagQuery(
text=text,
rag_retrieval_config=api_retrival_config,
Expand Down
40 changes: 33 additions & 7 deletions vertexai/preview/rag/rag_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,42 @@ def __init__(
else:
# If rag_retrieval_config is specified, check for missing parameters.
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
if not rag_retrieval_config.top_k:
# Set top_k to config value if specified
if rag_retrieval_config.top_k:
api_retrival_config.top_k = rag_retrieval_config.top_k
else:
api_retrival_config.top_k = similarity_top_k
# Check if both vector_distance_threshold and vector_similarity_threshold
# are specified.
if (
not rag_retrieval_config.filter
or not rag_retrieval_config.filter.vector_distance_threshold
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
and rag_retrieval_config.filter.vector_similarity_threshold
):
api_retrival_config.filter = (
aiplatform_v1beta1.RagRetrievalConfig.Filter(
vector_distance_threshold=vector_distance_threshold
),
raise ValueError(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
# Set vector_distance_threshold to config value if specified
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
):
api_retrival_config.filter.vector_distance_threshold = (
rag_retrieval_config.filter.vector_distance_threshold
)
else:
api_retrival_config.filter.vector_distance_threshold = (
vector_distance_threshold
)
# Set vector_similarity_threshold to config value if specified
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_similarity_threshold
):
api_retrival_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)

if rag_resources:
Expand Down
27 changes: 19 additions & 8 deletions vertexai/rag/rag_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,27 @@ def retrieval_query(
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
else:
# If rag_retrieval_config is specified, check for missing parameters.
api_retrival_config = aiplatform_v1.RagRetrievalConfig(
top_k=rag_retrieval_config.top_k,
)
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
api_retrival_config.top_k = rag_retrieval_config.top_k
# Set vector_distance_threshold to config value if specified
if rag_retrieval_config.filter:
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
vector_distance_threshold=rag_retrieval_config.filter.vector_distance_threshold
# Check if both vector_distance_threshold and vector_similarity_threshold
# are specified.
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
and rag_retrieval_config.filter.vector_similarity_threshold
):
raise ValueError(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
api_retrival_config.filter.vector_distance_threshold = (
rag_retrieval_config.filter.vector_distance_threshold
)
else:
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
vector_distance_threshold=None
api_retrival_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)

query = aiplatform_v1.RagQuery(
Expand Down
33 changes: 21 additions & 12 deletions vertexai/rag/rag_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,29 @@ def __init__(
)

# If rag_retrieval_config is not specified, set it to default values.
if not rag_retrieval_config:
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
else:
# If rag_retrieval_config is specified, check for missing parameters.
api_retrival_config = aiplatform_v1.RagRetrievalConfig(
top_k=rag_retrieval_config.top_k,
)
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
# If rag_retrieval_config is specified, populate the default config.
if rag_retrieval_config:
api_retrival_config.top_k = rag_retrieval_config.top_k
# Set vector_distance_threshold to config value if specified
if rag_retrieval_config.filter:
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
vector_distance_threshold=rag_retrieval_config.filter.vector_distance_threshold
# Check if both vector_distance_threshold and
# vector_similarity_threshold are specified.
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
and rag_retrieval_config.filter.vector_similarity_threshold
):
raise ValueError(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
api_retrival_config.filter.vector_distance_threshold = (
rag_retrieval_config.filter.vector_distance_threshold
)
else:
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
vector_distance_threshold=None
api_retrival_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)

if rag_resources:
Expand Down

0 comments on commit 9402b3d

Please sign in to comment.