Skip to content

Commit

Permalink
feat: Add fraction_leaf_nodes_to_search_override. Add support for p…
Browse files Browse the repository at this point in the history
…rivate endpoint in `find_neighbors`.

PiperOrigin-RevId: 592387105
  • Loading branch information
lingyinw authored and copybara-github committed Dec 20, 2023
1 parent 77ee692 commit cd31c13
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,15 @@ def create(
cls,
display_name: str,
network: Optional[str] = None,
public_endpoint_enabled: Optional[bool] = False,
public_endpoint_enabled: bool = False,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
enable_private_service_connect: Optional[bool] = False,
enable_private_service_connect: bool = False,
project_allowlist: Optional[Sequence[str]] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndexEndpoint":
Expand Down Expand Up @@ -367,15 +367,15 @@ def _create(
cls,
display_name: str,
network: Optional[str] = None,
public_endpoint_enabled: Optional[bool] = False,
public_endpoint_enabled: bool = False,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
enable_private_service_connect: Optional[bool] = False,
enable_private_service_connect: bool = False,
project_allowlist: Optional[Sequence[str]] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndexEndpoint":
Expand Down Expand Up @@ -1139,21 +1139,23 @@ def find_neighbors(
deployed_index_id: str,
queries: List[List[float]],
num_neighbors: int = 10,
filter: Optional[List[Namespace]] = [],
filter: Optional[List[Namespace]] = None,
per_crowding_attribute_neighbor_count: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
numeric_filter: Optional[List[NumericNamespace]] = [],
numeric_filter: Optional[List[NumericNamespace]] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index, which is deployed to either a public or a private
endpoint.
```
Example usage:
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id'
index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_endpoint_id'
)
my_index_endpoint.find_neighbors(deployed_index_id="public_test1", queries= [[1, 1]],)
my_index_endpoint.find_neighbors(deployed_index_id="deployed_index_id", queries= [[1, 1]],)
```
Args:
deployed_index_id (str):
Expand Down Expand Up @@ -1203,8 +1205,15 @@ def find_neighbors(
"""

if not self._public_match_client:
raise ValueError(
"Please make sure index has been deployed to public endpoint,and follow the example usage to call this method."
# Private endpoint
return self.match(
deployed_index_id=deployed_index_id,
queries=queries,
num_neighbors=num_neighbors,
filter=filter,
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
)

# Create the FindNeighbors request
Expand All @@ -1227,21 +1236,26 @@ def find_neighbors(
)
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
# Token restricts
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
restrict.namespace = namespace.name
restrict.allow_list.extend(namespace.allow_tokens)
restrict.deny_list.extend(namespace.deny_tokens)
datapoint.restricts.append(restrict)
if filter:
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
restrict.namespace = namespace.name
restrict.allow_list.extend(namespace.allow_tokens)
restrict.deny_list.extend(namespace.deny_tokens)
datapoint.restricts.append(restrict)
# Numeric restricts
for numeric_namespace in numeric_filter:
numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
numeric_restrict.namespace = numeric_namespace.name
numeric_restrict.op = numeric_namespace.op
numeric_restrict.value_int = numeric_namespace.value_int
numeric_restrict.value_float = numeric_namespace.value_float
numeric_restrict.value_double = numeric_namespace.value_double
datapoint.numeric_restricts.append(numeric_restrict)
if numeric_filter:
for numeric_namespace in numeric_filter:
numeric_restrict = (
gca_index_v1beta1.IndexDatapoint.NumericRestriction()
)
numeric_restrict.namespace = numeric_namespace.name
numeric_restrict.op = numeric_namespace.op
numeric_restrict.value_int = numeric_namespace.value_int
numeric_restrict.value_float = numeric_namespace.value_float
numeric_restrict.value_double = numeric_namespace.value_double
datapoint.numeric_restricts.append(numeric_restrict)

find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)

Expand Down Expand Up @@ -1364,19 +1378,21 @@ def _batch_get_embeddings(
def match(
self,
deployed_index_id: str,
queries: List[List[float]],
queries: Optional[List[List[float]]] = None,
num_neighbors: int = 1,
filter: Optional[List[Namespace]] = [],
filter: Optional[List[Namespace]] = None,
per_crowding_attribute_num_neighbors: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index for private endpoint only.
Args:
deployed_index_id (str):
Required. The ID of the DeployedIndex to match the queries against.
queries (List[List[float]]):
Required. A list of queries. Each query is a list of floats, representing a single embedding.
Optional. A list of queries. Each query is a list of floats, representing a single embedding.
num_neighbors (int):
Required. The number of nearest neighbors to be retrieved from database for
each query.
Expand All @@ -1394,6 +1410,11 @@ def match(
approx_num_neighbors (int):
The number of neighbors to find via approximate search before exact reordering is performed.
If not set, the default value from scam config is used; if set, this value must be > 0.
fraction_leaf_nodes_to_search_override (float):
Optional. The fraction of the number of leaves to search. Setting it at
query time allows the user to tune search performance. Increasing this
value increases both search accuracy and latency.
The value should be between 0.0 and 1.0.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
Expand All @@ -1408,22 +1429,30 @@ def match(
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
)
batch_request_for_index.deployed_index_id = deployed_index_id
requests = []
for query in queries:
request = match_service_pb2.MatchRequest(
num_neighbors=num_neighbors,
deployed_index_id=deployed_index_id,
float_val=query,
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
)

# Preprocess restricts to be used for each request
restricts = []
if filter:
for namespace in filter:
restrict = match_service_pb2.Namespace()
restrict.name = namespace.name
restrict.allow_tokens.extend(namespace.allow_tokens)
restrict.deny_tokens.extend(namespace.deny_tokens)
request.restricts.append(restrict)
requests.append(request)
restricts.append(restrict)

requests = []
if queries:
for query in queries:
request = match_service_pb2.MatchRequest(
deployed_index_id=deployed_index_id,
float_val=query,
num_neighbors=num_neighbors,
restricts=restricts,
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
)
requests.append(request)

batch_request_for_index.requests.extend(requests)
batch_request.requests.append(batch_request_for_index)
Expand Down
61 changes: 57 additions & 4 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,9 @@ def test_index_endpoint_match_queries_backward_compatibility(
index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
def test_private_index_endpoint_match_queries(
self, index_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
Expand All @@ -1051,11 +1053,12 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):

my_index_endpoint.match(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
filter=_TEST_FILTER,
queries=_TEST_QUERIES,
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
)

batch_request = match_service_pb2.BatchMatchRequest(
Expand All @@ -1066,7 +1069,7 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
match_service_pb2.MatchRequest(
num_neighbors=_TEST_NUM_NEIGHBOURS,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
float_val=_TEST_QUERIES[0],
float_val=_TEST_QUERIES[i],
restricts=[
match_service_pb2.Namespace(
name="class",
Expand All @@ -1076,14 +1079,64 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
],
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
)
for i in range(len(_TEST_QUERIES))
],
)
]
)

index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_private_index_endpoint_find_neighbor_queries(
    self, index_endpoint_match_queries_mock
):
    """Verify `find_neighbors` sends the expected batch match request.

    Calls `find_neighbors` on an endpoint built from `_TEST_INDEX_ENDPOINT_ID`
    and asserts that the match-queries mock was last called with a
    `BatchMatchRequest` containing one `MatchRequest` per entry of
    `_TEST_QUERIES`, each carrying the same restricts and tuning options.
    """
    aiplatform.init(project=_TEST_PROJECT)

    private_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )

    private_index_endpoint.find_neighbors(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        queries=_TEST_QUERIES,
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        filter=_TEST_FILTER,
        per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
        approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
        fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
        return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
    )

    # Build one expected MatchRequest per query, in query order.
    expected_match_requests = []
    for query_vector in _TEST_QUERIES:
        expected_restrict = match_service_pb2.Namespace(
            name="class",
            allow_tokens=["token_1"],
            deny_tokens=["token_2"],
        )
        expected_match_requests.append(
            match_service_pb2.MatchRequest(
                deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
                num_neighbors=_TEST_NUM_NEIGHBOURS,
                float_val=query_vector,
                restricts=[expected_restrict],
                per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
                approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
                fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
            )
        )

    # All per-query requests are grouped under the single deployed index.
    expected_per_index_request = (
        match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
            requests=expected_match_requests,
        )
    )
    expected_batch_request = match_service_pb2.BatchMatchRequest(
        requests=[expected_per_index_request]
    )

    index_endpoint_match_queries_mock.assert_called_with(expected_batch_request)

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_match_queries(
self, index_public_endpoint_match_queries_mock
Expand Down Expand Up @@ -1277,7 +1330,7 @@ def test_index_endpoint_batch_get_embeddings(
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_find_neighbors_for_private(
def test_index_private_endpoint_read_index_datapoints(
self, index_endpoint_batch_get_embeddings_mock
):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit cd31c13

Please sign in to comment.