Skip to content

Commit

Permalink
feat: Support a list of GCS URIs in CustomPythonPackageTrainingJob
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 503270789
  • Loading branch information
jaycee-li authored and copybara-github committed Jan 19, 2023
1 parent 4415c10 commit 05bb71f
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 33 deletions.
68 changes: 38 additions & 30 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -5827,7 +5827,7 @@ def __init__(
self,
# TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
python_package_gcs_uri: str,
python_package_gcs_uri: Union[str, List[str]],
python_module_name: str,
container_uri: str,
model_serving_container_image_uri: Optional[str] = None,
Expand Down Expand Up @@ -5891,53 +5891,56 @@ def __init__(
Args:
display_name (str):
Required. The user-defined name of this TrainingPipeline.
python_package_gcs_uri (str):
Required: GCS location of the training python package.
python_package_gcs_uri (Union[str, List[str]]):
Required. GCS location of the training python package.
Could be a string for a single package or a list of strings for
multiple packages.
python_module_name (str):
Required: The module name of the training python package.
Required. The module name of the training python package.
container_uri (str):
Required: Uri of the training container image in the GCR.
Required. Uri of the training container image in the GCR.
model_serving_container_image_uri (str):
If the training produces a managed Vertex AI Model, the URI of the
Model serving container suitable for serving the model produced by the
training script.
Optional. If the training produces a managed Vertex AI Model,
the URI of the model serving container suitable for serving the
model produced by the training script.
model_serving_container_predict_route (str):
If the training produces a managed Vertex AI Model, An HTTP path to
send prediction requests to the container, and which must be supported
by it. If not specified a default HTTP path will be used by Vertex AI.
Optional. If the training produces a managed Vertex AI Model,
an HTTP path to send prediction requests to the container,
and which must be supported by it. If not specified a default
HTTP path will be used by Vertex AI.
model_serving_container_health_route (str):
If the training produces a managed Vertex AI Model, an HTTP path to
send health check requests to the container, and which must be supported
by it. If not specified a standard HTTP path will be used by AI
Platform.
Optional. If the training produces a managed Vertex AI Model,
an HTTP path to send health check requests to the container,
and which must be supported by it. If not specified a standard
HTTP path will be used by AI Platform.
model_serving_container_command (Sequence[str]):
The command with which the container is run. Not executed within a
Optional. The command with which the container is run. Not executed within a
shell. The Docker image's ENTRYPOINT is used if this is not provided.
Variable references $(VAR_NAME) are expanded using the container's
environment. If a variable cannot be resolved, the reference in the
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
expanded, regardless of whether the variable exists or not.
model_serving_container_args (Sequence[str]):
The arguments to the command. The Docker image's CMD is used if this is
not provided. Variable references $(VAR_NAME) are expanded using the
Optional. The arguments to the command. The Docker image's CMD is used if this
is not provided. Variable references $(VAR_NAME) are expanded using the
container's environment. If a variable cannot be resolved, the reference
in the input string will be unchanged. The $(VAR_NAME) syntax can be
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
never be expanded, regardless of whether the variable exists or not.
model_serving_container_environment_variables (Dict[str, str]):
The environment variables that are to be present in the container.
Optional. The environment variables that are to be present in the container.
Should be a dictionary where keys are environment variable names
and values are environment variable values for those names.
model_serving_container_ports (Sequence[int]):
Declaration of ports that are exposed by the container. This field is
primarily informational, it gives Vertex AI information about the
network connections the container uses. Listing or not a port here has
no impact on whether the port is actually exposed, any port listening on
the default "0.0.0.0" address inside a container will be accessible from
the network.
Optional. Declaration of ports that are exposed by the container.
This field is primarily informational, it gives Vertex AI information
about the network connections the container uses. Listing or not
a port here has no impact on whether the port is actually exposed,
any port listening on the default "0.0.0.0" address inside a
container will be accessible from the network.
model_description (str):
The description of the Model.
Optional. The description of the Model.
model_instance_schema_uri (str):
Optional. Points to a YAML file stored on Google Cloud
Storage describing the format of a single instance, which
Expand Down Expand Up @@ -6036,7 +6039,7 @@ def __init__(
Overrides encryption_spec_key_name set in aiplatform.init.
staging_bucket (str):
Bucket used to stage source and training artifacts. Overrides
Optional. Bucket used to stage source and training artifacts. Overrides
staging_bucket set in aiplatform.init.
"""
if not display_name:
Expand Down Expand Up @@ -6066,7 +6069,12 @@ def __init__(
staging_bucket=staging_bucket,
)

self._package_gcs_uri = python_package_gcs_uri
if isinstance(python_package_gcs_uri, str):
self._package_gcs_uri = [python_package_gcs_uri]
elif isinstance(python_package_gcs_uri, list):
self._package_gcs_uri = python_package_gcs_uri
else:
raise ValueError("'python_package_gcs_uri' must be a string or list.")
self._python_module = python_module_name

def run(
Expand Down Expand Up @@ -6668,7 +6676,7 @@ def _run(
spec["python_package_spec"] = {
"executor_image_uri": self._container_uri,
"python_module": self._python_module,
"package_uris": [self._package_gcs_uri],
"package_uris": self._package_gcs_uri,
}

if args:
Expand Down
58 changes: 55 additions & 3 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -161,6 +161,7 @@
_TEST_MODEL_DESCRIPTION = "test description"

_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test-staging-bucket/trainer.tar.gz"
_TEST_PACKAGE_GCS_URIS = [_TEST_OUTPUT_PYTHON_PACKAGE_PATH] * 2
_TEST_PYTHON_MODULE_NAME = "aiplatform.task"

_TEST_MODEL_NAME = f"projects/{_TEST_PROJECT}/locations/us-central1/models/{_TEST_ID}"
Expand Down Expand Up @@ -4987,13 +4988,18 @@ def teardown_method(self):
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
"python_package_gcs_uri",
[_TEST_OUTPUT_PYTHON_PACKAGE_PATH, _TEST_PACKAGE_GCS_URIS],
)
def test_run_call_pipeline_service_create_with_tabular_dataset(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_tabular_dataset,
mock_model_service_get,
sync,
python_package_gcs_uri,
):
aiplatform.init(
project=_TEST_PROJECT,
Expand All @@ -5004,7 +5010,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
job = training_jobs.CustomPythonPackageTrainingJob(
display_name=_TEST_DISPLAY_NAME,
labels=_TEST_LABELS,
python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
python_package_gcs_uri=python_package_gcs_uri,
python_module_name=_TEST_PYTHON_MODULE_NAME,
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
Expand Down Expand Up @@ -5050,6 +5056,11 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

if isinstance(python_package_gcs_uri, str):
package_uris = [python_package_gcs_uri]
else:
package_uris = python_package_gcs_uri

true_worker_pool_spec = {
"replica_count": _TEST_REPLICA_COUNT,
"machine_spec": {
Expand All @@ -5064,7 +5075,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
"python_package_spec": {
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
"python_module": _TEST_PYTHON_MODULE_NAME,
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
"package_uris": package_uris,
"args": true_args,
"env": true_env,
},
Expand Down Expand Up @@ -5164,6 +5175,47 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(

assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED

@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
def test_custom_python_package_training_job_run_raises_with_wrong_package_uris(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_tabular_dataset,
    mock_model_service_get,
):
    """The constructor must reject a `python_package_gcs_uri` that is
    neither a string nor a list (here a dict is passed)."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    # A dict is an unsupported type for the package URI argument.
    invalid_package_uri = {"package": _TEST_OUTPUT_PYTHON_PACKAGE_PATH}

    # Construction itself should raise; no pipeline call is expected.
    with pytest.raises(ValueError) as excinfo:
        training_jobs.CustomPythonPackageTrainingJob(
            display_name=_TEST_DISPLAY_NAME,
            labels=_TEST_LABELS,
            python_package_gcs_uri=invalid_package_uri,
            python_module_name=_TEST_PYTHON_MODULE_NAME,
            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
            model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
            model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
            model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
            model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
            model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
            model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
            model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
            model_description=_TEST_MODEL_DESCRIPTION,
            model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
            model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
            model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
        )

    assert excinfo.match("'python_package_gcs_uri' must be a string or list.")

@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
def test_custom_python_package_training_job_run_raises_with_impartial_explanation_spec(
Expand Down

0 comments on commit 05bb71f

Please sign in to comment.