Skip to content

Commit

Permalink
Showing 6 changed files with 52 additions and 6 deletions.
7 changes: 4 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
@@ -359,10 +359,11 @@ To create a batch prediction job:
batch_prediction_job = model.batch_predict(
job_display_name='my-batch-prediction-job',
instances_format='csv'
instances_format='csv',
machine_type='n1-standard-4',
gcs_source=['gs://path/to/my/file.csv']
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
gcs_source=['gs://path/to/my/file.csv'],
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/',
service_account='my-sa@my-project.iam.gserviceaccount.com'
)
You can also create a batch prediction job asynchronously by including the `sync=False` argument:
7 changes: 4 additions & 3 deletions docs/README.rst
Original file line number Diff line number Diff line change
@@ -284,10 +284,11 @@ To create a batch prediction job:
batch_prediction_job = model.batch_predict(
job_display_name='my-batch-prediction-job',
instances_format='csv'
instances_format='csv',
machine_type='n1-standard-4',
gcs_source=['gs://path/to/my/file.csv']
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
gcs_source=['gs://path/to/my/file.csv'],
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/',
service_account='my-sa@my-project.iam.gserviceaccount.com'
)
You can also create a batch prediction job asynchronously by including the `sync=False` argument:
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
@@ -403,6 +403,7 @@ def create(
"aiplatform.model_monitoring.AlertConfig"
] = None,
analysis_instance_schema_uri: Optional[str] = None,
service_account: Optional[str] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
@@ -586,6 +587,9 @@ def create(
and TFDV instance, this field can be used to override the schema.
For models trained with Vertex AI, this field must be set as all the
fields in predict instance formatted as string.
service_account (str):
Optional. The service account that the batch prediction workload runs as.
Users submitting jobs must have act-as permission on this service account.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
@@ -745,6 +749,9 @@ def create(
)
gapic_batch_prediction_job.explanation_spec = explanation_spec

if service_account:
gapic_batch_prediction_job.service_account = service_account

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
5 changes: 5 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
@@ -3511,6 +3511,7 @@ def batch_predict(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
service_account: Optional[str] = None,
) -> jobs.BatchPredictionJob:
"""Creates a batch prediction job using this Model and outputs
prediction results to the provided destination prefix in the specified
@@ -3673,6 +3674,9 @@ def batch_predict(
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
service_account (str):
Optional. The service account that the batch prediction workload runs as.
Users submitting jobs must have act-as permission on this service account.
Returns:
job (jobs.BatchPredictionJob):
@@ -3705,6 +3709,7 @@ def batch_predict(
encryption_spec_key_name=encryption_spec_key_name,
sync=sync,
create_request_timeout=create_request_timeout,
service_account=service_account,
)

@classmethod
23 changes: 23 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
@@ -76,6 +76,8 @@
_TEST_BQ_JOB_ID = "123459876"
_TEST_BQ_MAX_RESULTS = 100
_TEST_GCS_BUCKET_NAME = "my-bucket"
_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com"


_TEST_BQ_PATH = f"bq://{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}"
_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}"
@@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest(
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
@@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=180.0,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
@@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
@@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=False,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest(
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest(
),
predictions_format="bigquery",
),
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(
@@ -946,6 +957,7 @@ def test_batch_predict_with_all_args(
sync=sync,
create_request_timeout=None,
batch_size=_TEST_BATCH_SIZE,
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -986,6 +998,7 @@ def test_batch_predict_with_all_args(
parameters=_TEST_EXPLANATION_PARAMETERS,
),
labels=_TEST_LABEL,
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
@@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
model_monitoring_objective_config=mm_obj_cfg,
model_monitoring_alert_config=mm_alert_cfg,
analysis_instance_schema_uri="",
service_account=_TEST_SERVICE_ACCOUNT,
)

batch_prediction_job.wait_for_resource_creation()
@@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
generate_explanation=True,
model_monitoring_config=_TEST_MODEL_MONITORING_CFG,
labels=_TEST_LABEL,
service_account=_TEST_SERVICE_ACCOUNT,
)
create_batch_prediction_job_v1beta1_mock.assert_called_once_with(
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
@@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
sync=False,
service_account=_TEST_SERVICE_ACCOUNT,
)

with pytest.raises(RuntimeError) as e:
@@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
model_name=_TEST_MODEL_NAME,
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"source")
@@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX,
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"source")
@@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self):
model_name=_TEST_MODEL_NAME,
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"destination")
@@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
instances_format="wrong",
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"accepted instances format")
@@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self):
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
predictions_format="wrong",
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
service_account=_TEST_SERVICE_ACCOUNT,
)

assert e.match(regexp=r"accepted prediction format")
@@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model(
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=True,
service_account=_TEST_SERVICE_ACCOUNT,
)
assert (
create_batch_prediction_job_mock.call_args_list[0][1][
@@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model(
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=True,
service_account=_TEST_SERVICE_ACCOUNT,
)
assert (
create_batch_prediction_job_mock.call_args_list[0][1][
9 changes: 9 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
@@ -1644,6 +1644,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
@@ -1669,6 +1670,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
predictions_format="jsonl",
),
encryption_spec=_TEST_ENCRYPTION_SPEC,
service_account=_TEST_SERVICE_ACCOUNT,
)
)

@@ -1693,6 +1695,7 @@ def test_batch_predict_gcs_source_and_dest(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
@@ -1711,6 +1714,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
@@ -1733,6 +1737,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock
),
predictions_format="jsonl",
),
service_account=_TEST_SERVICE_ACCOUNT,
)
)

@@ -1757,6 +1762,7 @@ def test_batch_predict_gcs_source_bq_dest(
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
sync=sync,
create_request_timeout=None,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
@@ -1781,6 +1787,7 @@ def test_batch_predict_gcs_source_bq_dest(
),
predictions_format="bigquery",
),
service_account=_TEST_SERVICE_ACCOUNT,
)
)

@@ -1817,6 +1824,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
sync=sync,
create_request_timeout=None,
batch_size=_TEST_BATCH_SIZE,
service_account=_TEST_SERVICE_ACCOUNT,
)

if not sync:
@@ -1857,6 +1865,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
),
labels=_TEST_LABEL,
encryption_spec=_TEST_ENCRYPTION_SPEC,
service_account=_TEST_SERVICE_ACCOUNT,
)

create_batch_prediction_job_mock.assert_called_once_with(

0 comments on commit deba06b

Please sign in to comment.