fix: change default for create_request_timeout arg to None (#1175)
Change the default value for `create_request_timeout` from `False` to `None`, and add a test covering the case where `create_request_timeout` isn't explicitly set.

Fixes b/229868042 🦕
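
For context, a minimal sketch of why `None` is a cleaner "not set" sentinel than `False` for a timeout that is ultimately forwarded to the underlying create request. The helper below is made up for illustration and is not the SDK's actual plumbing:

```python
from typing import Optional


def _forward_create_request(timeout: Optional[float]) -> None:
    # Illustrative stand-in for the GAPIC create_training_pipeline call:
    # timeout=None means "use the client's default deadline", while any
    # numeric value is treated as an explicit deadline in seconds.
    print(f"create_training_pipeline(..., timeout={timeout!r})")


# Old default: False is falsy but not None, and bool is a subclass of int,
# so an "unset" timeout leaks a bool where Optional[float] is expected and
# risks being read as a 0-second deadline.
_forward_create_request(False)  # create_training_pipeline(..., timeout=False)

# New default: None unambiguously means "no explicit timeout", letting the
# RPC layer fall back to its own default behaviour.
_forward_create_request(None)   # create_training_pipeline(..., timeout=None)
```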
sararob authored Apr 20, 2022
1 parent 4c21993 commit 47791f7
Showing 2 changed files with 155 additions and 1 deletion.
google/cloud/aiplatform/training_jobs.py (1 addition, 1 deletion)
@@ -4736,7 +4736,7 @@ def run(
model_labels: Optional[Dict[str, str]] = None,
disable_early_stopping: bool = False,
sync: bool = True,
- create_request_timeout: Optional[float] = False,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the AutoML Image training job and returns a model.
tests/unit/aiplatform/test_training_jobs.py (154 additions, 0 deletions)
@@ -4737,6 +4737,160 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout(
timeout=180.0,
)

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout_not_explicitly_set(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_tabular_dataset,
mock_model_service_get,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = training_jobs.CustomPythonPackageTrainingJob(
display_name=_TEST_DISPLAY_NAME,
labels=_TEST_LABELS,
python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
python_module_name=_TEST_PYTHON_MODULE_NAME,
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
model_description=_TEST_MODEL_DESCRIPTION,
model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
)

model_from_job = job.run(
dataset=mock_tabular_dataset,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
model_labels=_TEST_MODEL_LABELS,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
args=_TEST_RUN_ARGS,
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
accelerator_count=_TEST_ACCELERATOR_COUNT,
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
sync=sync,
)

if not sync:
model_from_job.wait()

true_args = _TEST_RUN_ARGS
true_env = [
{"name": key, "value": value}
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
]

true_worker_pool_spec = {
"replica_count": _TEST_REPLICA_COUNT,
"machine_spec": {
"machine_type": _TEST_MACHINE_TYPE,
"accelerator_type": _TEST_ACCELERATOR_TYPE,
"accelerator_count": _TEST_ACCELERATOR_COUNT,
},
"disk_spec": {
"boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT,
"boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT,
},
"python_package_spec": {
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
"python_module": _TEST_PYTHON_MODULE_NAME,
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
"args": true_args,
"env": true_env,
},
}

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

env = [
gca_env_var.EnvVar(name=str(key), value=str(value))
for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
]

ports = [
gca_model.Port(container_port=port)
for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
]

true_container_spec = gca_model.ModelContainerSpec(
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
env=env,
ports=ports,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
labels=_TEST_MODEL_LABELS,
description=_TEST_MODEL_DESCRIPTION,
container_spec=true_container_spec,
predict_schemata=gca_model.PredictSchemata(
instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
),
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
dataset_id=mock_tabular_dataset.name,
gcs_destination=gca_io.GcsDestination(
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
),
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
labels=_TEST_LABELS,
training_task_definition=schema.training_job.definition.custom_task,
training_task_inputs=json_format.ParseDict(
{
"worker_pool_specs": [true_worker_pool_spec],
"base_output_directory": {
"output_uri_prefix": _TEST_BASE_OUTPUT_DIR
},
"service_account": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
},
struct_pb2.Value(),
),
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name_nor_model_labels(
self,
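
From a caller's perspective, the behaviour the new test pins down can be sketched as follows. The project, bucket, package, and image values are placeholders and are not part of this commit:

```python
from google.cloud import aiplatform
from google.cloud.aiplatform import training_jobs

aiplatform.init(project="my-project", staging_bucket="gs://my-bucket")  # placeholders

job = training_jobs.CustomPythonPackageTrainingJob(
    display_name="my-training-job",                               # placeholder
    python_package_gcs_uri="gs://my-bucket/trainer-0.1.tar.gz",   # placeholder
    python_module_name="trainer.task",                            # placeholder
    container_uri="gcr.io/my-project/my-training-image:latest",   # placeholder
)

# With create_request_timeout omitted, the SDK now forwards timeout=None to
# the CreateTrainingPipeline RPC, so the client-level default applies.
model = job.run(model_display_name="my-model")

# Passing an explicit float (e.g. create_request_timeout=180.0) would instead
# forward that value as the request timeout, as the existing "_with_timeout"
# test above verifies with timeout=180.0.
```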
