Skip to content

Commit

Permalink
feat: Add vector search alpha to rag retrieval for hybrid search ranking
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 670703417
  • Loading branch information
speedstorm1 authored and copybara-github committed Sep 3, 2024
1 parent 37627de commit 6624ebe
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/unit/vertex_rag/test_rag_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_retrieval_query_rag_resources_success(self):
text=tc.TEST_QUERY_TEXT,
similarity_top_k=2,
vector_distance_threshold=0.5,
vector_search_alpha=0.5,
)
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)

Expand Down
14 changes: 13 additions & 1 deletion vertexai/preview/rag/rag_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def retrieval_query(
rag_corpora: Optional[List[str]] = None,
similarity_top_k: Optional[int] = 10,
vector_distance_threshold: Optional[float] = 0.3,
vector_search_alpha: Optional[float] = 0.5,
) -> RetrieveContextsResponse:
"""Retrieve top k relevant docs/chunks.
Expand All @@ -54,6 +55,7 @@ def retrieval_query(
)],
similarity_top_k=2,
vector_distance_threshold=0.5,
vector_search_alpha=0.5,
)
```
Expand All @@ -67,6 +69,10 @@ def retrieval_query(
similarity_top_k: The number of contexts to retrieve.
vector_distance_threshold: Optional. Only return contexts with vector
distance smaller than the threshold.
vector_search_alpha: Optional. Controls the weight between dense and
sparse vector search results. The range is [0, 1], where 0 means
sparse vector search only and 1 means dense vector search only.
The default value is 0.5.
Returns:
RetrieveContextsResonse.
Expand Down Expand Up @@ -111,7 +117,13 @@ def retrieval_query(
)

vertex_rag_store.vector_distance_threshold = vector_distance_threshold
query = RagQuery(text=text, similarity_top_k=similarity_top_k)
query = RagQuery(
text=text,
similarity_top_k=similarity_top_k,
ranking=RagQuery.Ranking(
alpha=vector_search_alpha,
),
)
request = RetrieveContextsRequest(
vertex_rag_store=vertex_rag_store,
parent=parent,
Expand Down

0 comments on commit 6624ebe

Please sign in to comment.