Skip to content

Commit

Permalink
feat: Add update_mask to MatchingEngineIndex `upsert_datapoints()…
Browse files Browse the repository at this point in the history
…` to support dynamic metadata update.

PiperOrigin-RevId: 609006985
lingyinw authored and copybara-github committed Feb 21, 2024
1 parent 09d1946 commit 81f6a25
Showing 2 changed files with 31 additions and 0 deletions.
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
@@ -692,12 +692,21 @@ def create_brute_force_index(
def upsert_datapoints(
self,
datapoints: Sequence[gca_matching_engine_index.IndexDatapoint],
update_mask: Optional[Sequence[str]] = None,
) -> "MatchingEngineIndex":
"""Upsert datapoints to this index.
Args:
datapoints (Sequence[gca_matching_engine_index.IndexDatapoint]):
Required. Datapoints to be upserted to this index.
update_mask (Sequence[str]):
Optional. Update mask is used to specify the fields to be
overwritten in the datapoints by the update. The fields
specified in the update_mask are relative to each IndexDatapoint
inside datapoints, not the full request.
Updatable fields:
Use `all_restricts` to update both `restricts` and
`numeric_restricts`.
Returns:
MatchingEngineIndex - Index resource object
@@ -716,6 +725,9 @@ def upsert_datapoints(
gca_index_service.UpsertDatapointsRequest(
index=self.resource_name,
datapoints=datapoints,
update_mask=(
field_mask_pb2.FieldMask(paths=update_mask) if update_mask else None
),
)
)

19 changes: 19 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
@@ -148,6 +148,7 @@
)
_TEST_DATAPOINTS = (_TEST_DATAPOINT_1, _TEST_DATAPOINT_2, _TEST_DATAPOINT_3)
_TEST_TIMEOUT = 1800.0
_TEST_UPDATE_MASK = ["all_restricts"]


def uuid_mock():
@@ -706,6 +707,24 @@ def test_upsert_datapoints(self, upsert_datapoints_mock):

upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request)

@pytest.mark.usefixtures("get_index_mock")
def test_upsert_datapoints_dynamic_metadata_update(self, upsert_datapoints_mock):
aiplatform.init(project=_TEST_PROJECT)

my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
my_index.upsert_datapoints(
datapoints=_TEST_DATAPOINTS,
update_mask=_TEST_UPDATE_MASK,
)

upsert_datapoints_request = gca_index_service.UpsertDatapointsRequest(
index=_TEST_INDEX_NAME,
datapoints=_TEST_DATAPOINTS,
update_mask=field_mask_pb2.FieldMask(paths=_TEST_UPDATE_MASK),
)

upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request)

@pytest.mark.usefixtures("get_index_mock")
def test_remove_datapoints(self, remove_datapoints_mock):
aiplatform.init(project=_TEST_PROJECT)

0 comments on commit 81f6a25

Please sign in to comment.