Skip to content

Commit

Permalink
fix: updated proto message formatting logic for batch predict model m…
Browse files Browse the repository at this point in the history
…onitoring

PiperOrigin-RevId: 499377219
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jan 4, 2023
1 parent 65300c4 commit f87fef0
Show file tree
Hide file tree
Showing 5 changed files with 422 additions and 101 deletions.
100 changes: 56 additions & 44 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
io as gca_io_compat,
job_state as gca_job_state,
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
study as gca_study_compat,
model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat,
)
job_state_v1beta1 as gca_job_state_v1beta1,
model_monitoring_v1beta1 as gca_model_monitoring_v1beta1,
) # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA

from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform import initializer
Expand All @@ -63,16 +63,23 @@

_LOGGER = base.Logger(__name__)

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
_JOB_COMPLETE_STATES = (
gca_job_state.JobState.JOB_STATE_SUCCEEDED,
gca_job_state.JobState.JOB_STATE_FAILED,
gca_job_state.JobState.JOB_STATE_CANCELLED,
gca_job_state.JobState.JOB_STATE_PAUSED,
gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED,
gca_job_state_v1beta1.JobState.JOB_STATE_FAILED,
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
gca_job_state_v1beta1.JobState.JOB_STATE_PAUSED,
)

_JOB_ERROR_STATES = (
gca_job_state.JobState.JOB_STATE_FAILED,
gca_job_state.JobState.JOB_STATE_CANCELLED,
gca_job_state_v1beta1.JobState.JOB_STATE_FAILED,
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
)

# _block_until_complete wait times
Expand Down Expand Up @@ -583,6 +590,23 @@ def create(
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
batch_prediction_job_v1beta1 as gca_bp_job_compat,
io_v1beta1 as gca_io_compat,
explanation_v1beta1 as gca_explanation_v1beta1,
machine_resources_v1beta1 as gca_machine_resources_compat,
manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
io as gca_io_compat,
explanation as gca_explanation_v1beta1,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
)
if not job_display_name:
job_display_name = cls._generate_display_name()

Expand Down Expand Up @@ -629,18 +653,7 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
batch_prediction_job_v1beta1 as gca_bp_job_compat,
model_monitoring_v1beta1 as gca_model_monitoring_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)

gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

# Required Fields
Expand Down Expand Up @@ -721,40 +734,44 @@ def create(
gapic_batch_prediction_job.generate_explanation = generate_explanation

if explanation_metadata or explanation_parameters:
gapic_batch_prediction_job.explanation_spec = (
gca_explanation_compat.ExplanationSpec(
metadata=explanation_metadata, parameters=explanation_parameters
)
explanation_spec = gca_explanation_compat.ExplanationSpec(
metadata=explanation_metadata, parameters=explanation_parameters
)
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
if model_monitoring_objective_config:

# Model Monitoring
if model_monitoring_objective_config:
if model_monitoring_objective_config.drift_detection_config:
_LOGGER.info(
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
)
if model_monitoring_objective_config.explanation_config:
_LOGGER.info(
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
explanation_spec = gca_explanation_v1beta1.ExplanationSpec.deserialize(
gca_explanation_compat.ExplanationSpec.serialize(explanation_spec)
)
gapic_batch_prediction_job.model_monitoring_config = (
gca_model_monitoring_compat.ModelMonitoringConfig(
objective_configs=[
model_monitoring_objective_config.as_proto(config_for_bp=True)
],
alert_config=model_monitoring_alert_config.as_proto(
config_for_bp=True
),
analysis_instance_schema_uri=analysis_instance_schema_uri,
)
)
gapic_batch_prediction_job.explanation_spec = explanation_spec

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
credentials=credentials,
)
if model_monitoring_objective_config:
empty_batch_prediction_job.api_client = (
empty_batch_prediction_job.api_client.select_version("v1beta1")
)

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
model_monitoring_objective_config._config_for_bp = True
if model_monitoring_alert_config is not None:
model_monitoring_alert_config._config_for_bp = True
gapic_mm_config = gca_model_monitoring_v1beta1.ModelMonitoringConfig(
objective_configs=[model_monitoring_objective_config.as_proto()],
alert_config=model_monitoring_alert_config.as_proto()
if model_monitoring_alert_config is not None
else None,
analysis_instance_schema_uri=analysis_instance_schema_uri
if analysis_instance_schema_uri is not None
else None,
)
gapic_batch_prediction_job.model_monitoring_config = gapic_mm_config

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
return cls._create(
empty_batch_prediction_job=empty_batch_prediction_job,
model_or_model_name=model_name,
Expand All @@ -763,11 +780,6 @@ def create(
sync=sync,
create_request_timeout=create_request_timeout,
)
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)

@classmethod
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
Expand Down
17 changes: 8 additions & 9 deletions google/cloud/aiplatform/model_monitoring/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
model_monitoring as gca_model_monitoring_v1,
)

# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
model_monitoring as gca_model_monitoring_v1beta1,
)
Expand All @@ -46,17 +46,16 @@ def __init__(
"""
self.enable_logging = enable_logging
self.user_emails = user_emails
self._config_for_bp = False

# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
def as_proto(self, config_for_bp: bool = False):
"""Returns EmailAlertConfig as a proto message.
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
def as_proto(self) -> gca_model_monitoring.ModelMonitoringAlertConfig:
"""Converts EmailAlertConfig to a proto message.
Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
Returns:
The GAPIC representation of the email alert config.
"""
if config_for_bp:
if self._config_for_bp:
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_model_monitoring = gca_model_monitoring_v1
Expand Down
67 changes: 43 additions & 24 deletions google/cloud/aiplatform/model_monitoring/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
from typing import Optional, Dict, Union

from google.cloud.aiplatform_v1.types import (
io as gca_io_v1,
io as gca_io,
model_monitoring as gca_model_monitoring_v1,
)

# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
io as gca_io_v1beta1,
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1
gca_io = gca_io_v1

TF_RECORD = "tf-record"
CSV = "csv"
Expand Down Expand Up @@ -92,8 +90,14 @@ def __init__(
self.data_format = data_format
self.target_field = target_field

def as_proto(self):
"""Returns _SkewDetectionConfig as a proto message."""
def as_proto(
self,
) -> gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig:
"""Converts _SkewDetectionConfig to a proto message.
Returns:
The GAPIC representation of the skew detection config.
"""
skew_thresholds_mapping = {}
attribution_score_skew_thresholds_mapping = {}
default_skew_threshold = None
Expand Down Expand Up @@ -147,8 +151,14 @@ def __init__(
self.drift_thresholds = drift_thresholds
self.attribute_drift_thresholds = attribute_drift_thresholds

def as_proto(self):
"""Returns drift detection config as a proto message."""
def as_proto(
self,
) -> gca_model_monitoring.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig:
"""Converts _DriftDetectionConfig to a proto message.
Returns:
The GAPIC representation of the drift detection config.
"""
drift_thresholds_mapping = {}
attribution_score_drift_thresholds_mapping = {}
if self.drift_thresholds is not None:
Expand Down Expand Up @@ -178,8 +188,14 @@ def __init__(self):
"""Base class for ExplanationConfig."""
self.enable_feature_attributes = False

def as_proto(self):
"""Returns _ExplanationConfig as a proto message."""
def as_proto(
self,
) -> gca_model_monitoring.ModelMonitoringObjectiveConfig.ExplanationConfig:
"""Converts _ExplanationConfig to a proto message.
Returns:
The GAPIC representation of the explanation config.
"""
return gca_model_monitoring.ModelMonitoringObjectiveConfig.ExplanationConfig(
enable_feature_attributes=self.enable_feature_attributes
)
Expand Down Expand Up @@ -208,22 +224,15 @@ def __init__(
self.skew_detection_config = skew_detection_config
self.drift_detection_config = drift_detection_config
self.explanation_config = explanation_config
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
self._config_for_bp = False

# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
def as_proto(self, config_for_bp: bool = False):
"""Returns _SkewDetectionConfig as a proto message.
def as_proto(self) -> gca_model_monitoring.ModelMonitoringObjectiveConfig:
"""Converts _ObjectiveConfig to a proto message.
Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
Returns:
The GAPIC representation of the objective config.
"""
if config_for_bp:
gca_io = gca_io_v1beta1
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_io = gca_io_v1
gca_model_monitoring = gca_model_monitoring_v1
training_dataset = None
if self.skew_detection_config is not None:
training_dataset = (
Expand Down Expand Up @@ -252,7 +261,8 @@ def as_proto(self, config_for_bp: bool = False):
else:
training_dataset.dataset = self.skew_detection_config.data_source

return gca_model_monitoring.ModelMonitoringObjectiveConfig(
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
gapic_config = gca_model_monitoring.ModelMonitoringObjectiveConfig(
training_dataset=training_dataset,
training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
if self.skew_detection_config is not None
Expand All @@ -264,6 +274,15 @@ def as_proto(self, config_for_bp: bool = False):
if self.explanation_config is not None
else None,
)
if self._config_for_bp:
return (
gca_model_monitoring_v1beta1.ModelMonitoringObjectiveConfig.deserialize(
gca_model_monitoring.ModelMonitoringObjectiveConfig.serialize(
gapic_config
)
)
)
return gapic_config


class SkewDetectionConfig(_SkewDetectionConfig):
Expand Down
Loading

0 comments on commit f87fef0

Please sign in to comment.