Skip to content

Commit

Permalink
feat: add encryption_spec_key_name to MatchingEngineIndex `create…
Browse files Browse the repository at this point in the history
…_tree_ah_index` and

`create_brute_force_index`

PiperOrigin-RevId: 580908192
  • Loading branch information
lingyinw authored and copybara-github committed Nov 9, 2023
1 parent 36d4086 commit 1a9e36f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
45 changes: 45 additions & 0 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.aiplatform.compat.types import (
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
matching_engine_index as gca_matching_engine_index,
encryption_spec as gca_encryption_spec,
)
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.matching_engine import (
Expand Down Expand Up @@ -109,6 +110,7 @@ def _create(
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
index_update_method: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource.
Expand Down Expand Up @@ -162,6 +164,18 @@ def _create(
Optional. The update method to use with this index. Choose
stream_update or batch_update. If not set, batch update will be
used by default.
encryption_spec_key_name (str):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the index. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this index and all sub-resources of this index will be
secured by this key.
The key needs to be in the same region as where the index is
created.
Returns:
MatchingEngineIndex - Index resource object
Expand All @@ -181,6 +195,9 @@ def _create(
"contentsDeltaUri": contents_delta_uri,
},
index_update_method=index_update_method_enum,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
),
)

if labels:
Expand Down Expand Up @@ -394,6 +411,7 @@ def create_tree_ah_index(
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
index_update_method: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the tree-AH algorithm.
Expand Down Expand Up @@ -472,6 +490,18 @@ def create_tree_ah_index(
Optional. The update method to use with this index. Choose
STREAM_UPDATE or BATCH_UPDATE. If not set, batch update will be
used by default.
encryption_spec_key_name (str):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the index. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this index and all sub-resources of this index will be
secured by this key.
The key needs to be in the same region as where the index is
created.
Returns:
MatchingEngineIndex - Index resource object
Expand Down Expand Up @@ -502,6 +532,7 @@ def create_tree_ah_index(
request_metadata=request_metadata,
sync=sync,
index_update_method=index_update_method,
encryption_spec_key_name=encryption_spec_key_name,
)

@classmethod
Expand All @@ -521,6 +552,7 @@ def create_brute_force_index(
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
index_update_method: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "MatchingEngineIndex":
"""Creates a MatchingEngineIndex resource that uses the brute force algorithm.
Expand Down Expand Up @@ -588,6 +620,18 @@ def create_brute_force_index(
Optional. The update method to use with this index. Choose
stream_update or batch_update. If not set, batch update will be
used by default.
encryption_spec_key_name (str):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the index. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this index and all sub-resources of this index will be
secured by this key.
The key needs to be in the same region as where the index is
created.
Returns:
MatchingEngineIndex - Index resource object
Expand All @@ -614,6 +658,7 @@ def create_brute_force_index(
request_metadata=request_metadata,
sync=sync,
index_update_method=index_update_method,
encryption_spec_key_name=encryption_spec_key_name,
)


Expand Down
16 changes: 15 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
index_service_client,
)

from google.cloud.aiplatform.compat.types import index as gca_index
from google.cloud.aiplatform.compat.types import (
index as gca_index,
encryption_spec as gca_encryption_spec,
)
import constants as test_constants

# project
Expand Down Expand Up @@ -104,6 +107,9 @@
_TEST_INDEX_INVALID_UPDATE_METHOD: None,
}

# Encryption spec
_TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC"


def uuid_mock():
return uuid.UUID(int=1)
Expand Down Expand Up @@ -309,6 +315,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
labels=_TEST_LABELS,
sync=sync,
index_update_method=index_update_method,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
)

if not sync:
Expand Down Expand Up @@ -337,6 +344,9 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
index_update_method
],
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
),
)

create_index_mock.assert_called_once_with(
Expand Down Expand Up @@ -370,6 +380,7 @@ def test_create_brute_force_index(
labels=_TEST_LABELS,
sync=sync,
index_update_method=index_update_method,
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
)

if not sync:
Expand All @@ -393,6 +404,9 @@ def test_create_brute_force_index(
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
index_update_method
],
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
),
)

create_index_mock.assert_called_once_with(
Expand Down

0 comments on commit 1a9e36f

Please sign in to comment.