feat: Add COMET and MetricX to the evaluation SDK
PiperOrigin-RevId: 696878382
vertex-sdk-bot authored and copybara-github committed Nov 15, 2024
1 parent c39334a commit 4135810
Showing 7 changed files with 408 additions and 45 deletions.
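For orientation, here is a minimal usage sketch of the new translation metrics, assembled from the tests in this diff. The import paths, `vertexai.init` call, project/location values, and example sentences are assumptions for illustration, not part of the commit; only the metric constructors, column names, and result keys come from the test file.

import pandas as pd
import vertexai
from vertexai.evaluation import EvalTask
from vertexai.evaluation.metrics import pointwise_metric

# Assumed placeholder project/location; the tests call aiplatform.init() equivalently.
vertexai.init(project="my-project", location="us-central1")

# Translation metrics added by this commit; version strings mirror the tests.
comet = pointwise_metric.Comet(
    version="COMET_22_SRC_REF", source_language="en", target_language="zh"
)
metricx = pointwise_metric.MetricX(
    version="METRICX_24_SRC", source_language="en", target_language="zh"
)

# COMET and MetricX score translations, so the dataset carries a `source`
# column alongside the model `response`; a `reference` column is also accepted
# (see the validation changes in vertexai/evaluation/_evaluation.py below).
eval_dataset = pd.DataFrame(
    {
        "source": ["Hello", "See you tomorrow"],
        "response": ["你好", "明天见"],
        "reference": ["你好", "明天见"],
    }
)

result = EvalTask(dataset=eval_dataset, metrics=[comet, metricx]).evaluate()
# Per the tests, result.metrics_table gains per-row `comet/score` and
# `metricx/score` columns, and result.summary_metrics includes `comet/mean`,
# `comet/std`, `metricx/mean`, and `metricx/std`.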
163 changes: 144 additions & 19 deletions tests/unit/vertexai/test_evaluation.py
@@ -99,6 +99,16 @@
evaluation_steps=_EVALUATION_STEPS,
),
)
_TEST_COMET = pointwise_metric.Comet(
version="COMET_22_SRC_REF",
source_language="en",
target_language="zh",
)
_TEST_METRICX = pointwise_metric.MetricX(
version="METRICX_24_SRC",
source_language="en",
target_language="zh",
)
_TEST_METRICS = (
"exact_match",
"bleu",
@@ -139,6 +149,7 @@
"reference": ["test", "ref"],
"context": ["test", "context"],
"instruction": ["test", "instruction"],
"source": ["test", "source"],
}
)
_TEST_EVAL_DATASET_SINGLE = pd.DataFrame({"prompt": ["test_prompt", "text_prompt"]})
@@ -305,7 +316,7 @@
)
),
)
_MOCK_POINTEWISE_RESULT = (
_MOCK_POINTWISE_RESULT = (
gapic_evaluation_service_types.EvaluateInstancesResponse(
pointwise_metric_result=gapic_evaluation_service_types.PointwiseMetricResult(
score=5, explanation="explanation"
@@ -423,6 +434,29 @@
)
),
)
_EXPECTED_COLUMN_MAPPING = {
"context": "context",
"reference": "reference",
"response": "response",
"instruction": "instruction",
"prompt": "prompt",
"source": "source",
}
_MOCK_MODEL_BASED_TRANSLATION_RESULT = (
# The order of the responses is important.
gapic_evaluation_service_types.EvaluateInstancesResponse(
comet_result=gapic_evaluation_service_types.CometResult(score=0.1)
),
gapic_evaluation_service_types.EvaluateInstancesResponse(
metricx_result=gapic_evaluation_service_types.MetricxResult(score=5)
),
gapic_evaluation_service_types.EvaluateInstancesResponse(
comet_result=gapic_evaluation_service_types.CometResult(score=0.9)
),
gapic_evaluation_service_types.EvaluateInstancesResponse(
metricx_result=gapic_evaluation_service_types.MetricxResult(score=20)
),
)


@pytest.fixture(scope="module")
@@ -465,16 +499,10 @@ def test_create_eval_task(self):
assert test_eval_task.dataset.equals(_TEST_EVAL_DATASET_ALL_INCLUDED)
assert test_eval_task.metrics == _TEST_METRICS
assert test_eval_task.experiment == _TEST_EXPERIMENT
assert test_eval_task._metric_column_mapping == {
"context": "context",
"reference": "reference",
"response": "response",
"instruction": "instruction",
"prompt": "prompt",
}
assert test_eval_task._metric_column_mapping == _EXPECTED_COLUMN_MAPPING

@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_compute_automatic_metrics(self, api_transport):
def test_compute_exact_match_metric(self, api_transport):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
@@ -521,7 +549,7 @@ def test_compute_pointwise_metrics(self, api_transport):
test_eval_task = EvalTask(
dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
)
mock_metric_results = _MOCK_POINTEWISE_RESULT
mock_metric_results = _MOCK_POINTWISE_RESULT
with mock.patch.object(
target=gapic_evaluation_services.EvaluationServiceClient,
attribute="evaluate_instances",
"reference",
"test_pointwise_metric/score",
"test_pointwise_metric/explanation",
"source",
]
)
assert test_result.metrics_table["response"].equals(
@@ -567,7 +596,7 @@ def test_compute_pointwise_metrics_free_string(self):
metrics=[_TEST_POINTWISE_METRIC_FREE_STRING],
metric_column_mapping={"abc": "prompt"},
)
mock_metric_results = _MOCK_POINTEWISE_RESULT
mock_metric_results = _MOCK_POINTWISE_RESULT
with mock.patch.object(
target=gapic_evaluation_services.EvaluationServiceClient,
attribute="evaluate_instances",
"reference",
"test_pointwise_metric_str/score",
"test_pointwise_metric_str/explanation",
"source",
]
)
assert test_result.metrics_table["response"].equals(
@@ -695,6 +725,7 @@ def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
"response",
"summarization_quality/score",
"summarization_quality/explanation",
"source",
]
)
assert list(
"explanation",
]

@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_compute_model_based_translation_metrics_without_model_inference(
self, api_transport
):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
api_transport=api_transport,
)
test_metrics = [_TEST_COMET, _TEST_METRICX]
test_eval_task = EvalTask(
dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
)

mock_metric_results = _MOCK_MODEL_BASED_TRANSLATION_RESULT
with mock.patch.object(
target=gapic_evaluation_services.EvaluationServiceClient,
attribute="evaluate_instances",
side_effect=mock_metric_results,
):
test_result = test_eval_task.evaluate()

assert test_result.summary_metrics["row_count"] == 2
assert test_result.summary_metrics["comet/mean"] == 0.5
assert test_result.summary_metrics["metricx/mean"] == 12.5
assert test_result.summary_metrics["comet/std"] == pytest.approx(0.5, 0.6)
assert test_result.summary_metrics["metricx/std"] == pytest.approx(10, 11)
assert set(test_result.metrics_table.columns.values) == set(
[
"context",
"instruction",
"reference",
"prompt",
"response",
"source",
"comet/score",
"metricx/score",
]
)
assert list(test_result.metrics_table["comet/score"].values) == [0.1, 0.9]
assert list(test_result.metrics_table["metricx/score"].values) == [5, 20]

@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_compute_automatic_metrics_with_custom_metric_spec(self, api_transport):
aiplatform.init(
@@ -940,6 +1013,7 @@ def test_compute_pairwise_metrics_without_model_inference(self, api_transport):
"instruction",
"pairwise_summarization_quality/pairwise_choice",
"pairwise_summarization_quality/explanation",
"source",
]
)
assert list(
@@ -1281,7 +1355,7 @@ def test_evaluate_response_column_and_model_not_provided(self):
):
test_eval_task.evaluate()

def test_evaluate_baseline_response_column_and_baseline_model_not_provided(
def test_evaluate_baseline_model_response_column_not_provided(
self,
):
test_eval_dataset = _TEST_EVAL_DATASET_SINGLE.copy(deep=True)
@@ -1302,6 +1376,63 @@ def test_evaluate_baseline_response_column_and_baseline_model_not_provided(
):
test_eval_task.evaluate()

def test_evaluate_response_column_not_provided(
self,
):
test_eval_dataset = _TEST_EVAL_DATASET_SINGLE
test_eval_task = EvalTask(
dataset=test_eval_dataset,
metrics=["exact_match"],
)
with pytest.raises(
KeyError,
match=re.escape(
(
"Required column `response` not found in the evaluation "
"dataset. The columns in the evaluation dataset are ['prompt']"
)
),
):
test_eval_task.evaluate()

def test_evaluate_reference_column_not_provided(
self,
):
test_eval_dataset = pd.DataFrame({"response": ["test", "text"]})
test_eval_task = EvalTask(
dataset=test_eval_dataset,
metrics=["exact_match"],
)
with pytest.raises(
KeyError,
match=re.escape(
(
"Required column `reference` not found in the evaluation "
"dataset. The columns in the evaluation dataset are ['response']"
)
),
):
test_eval_task.evaluate()

def test_evaluate_reference_or_source_column_not_provided(
self,
):
test_eval_dataset = pd.DataFrame({"response": ["test", "text"]})
test_eval_task = EvalTask(
dataset=test_eval_dataset,
metrics=[_TEST_COMET, _TEST_METRICX],
)
with pytest.raises(
KeyError,
match=re.escape(
(
"Required column `source` not found in the evaluation "
"dataset. The columns in the evaluation dataset are ['response']"
)
),
):
test_eval_task.evaluate()

def test_evaluate_invalid_prompt_template_variables(self):
test_eval_task = EvalTask(
dataset=_TEST_EVAL_DATASET_SINGLE,
@@ -1530,13 +1661,7 @@ def test_initialize_metric_column_mapping(self):
metric_column_mapping=metric_column_mapping,
dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
)
assert converted_metric_column_mapping == {
"prompt": "prompt",
"response": "response",
"reference": "reference",
"context": "context",
"instruction": "instruction",
}
assert converted_metric_column_mapping == _EXPECTED_COLUMN_MAPPING


class TestPromptTemplate:
93 changes: 71 additions & 22 deletions vertexai/evaluation/_evaluation.py
Expand Up @@ -124,33 +124,73 @@ def _validate_metric_column_map(
)


def _validate_dataset_for_automatic_metrics(
def _validate_dataset(
evaluation_run_config: evaluation_base.EvaluationRunConfig,
):
"""Validates the required columns exist in the dataset for automatic metrics."""
) -> None:
"""Validates the required columns exists in the dataset."""
_validate_response_column_required(evaluation_run_config)
_validate_reference_column_required(evaluation_run_config)
_validate_reference_or_source_column_required(evaluation_run_config)


def _validate_response_column_required(
evaluation_run_config: evaluation_base.EvaluationRunConfig,
) -> None:
"""Validates the response column exists in the dataset."""
for metric in evaluation_run_config.metrics:
if metric in constants.Metric.AUTOMATIC_METRIC_LIST or isinstance(
metric, metrics_base._TranslationMetric # pylint: disable=protected-access
):
_validate_column_provided(
evaluation_run_config,
constants.Dataset.MODEL_RESPONSE_COLUMN,
)


def _validate_reference_column_required(
evaluation_run_config: evaluation_base.EvaluationRunConfig,
) -> None:
"""Validates the reference column exists in the dataset."""
if set(evaluation_run_config.metrics).intersection(
set(constants.Metric.AUTOMATIC_METRIC_LIST)
):
if (
constants.Dataset.REFERENCE_COLUMN
not in evaluation_run_config.metric_column_mapping
):
evaluation_run_config.metric_column_mapping[
constants.Dataset.REFERENCE_COLUMN
] = constants.Dataset.REFERENCE_COLUMN
evaluation_run_config.validate_dataset_column(
constants.Dataset.REFERENCE_COLUMN
_validate_column_provided(
evaluation_run_config,
constants.Dataset.REFERENCE_COLUMN,
)
if (
constants.Dataset.MODEL_RESPONSE_COLUMN
not in evaluation_run_config.metric_column_mapping


def _validate_column_provided(
evaluation_run_config: evaluation_base.EvaluationRunConfig,
column_name: str,
) -> None:
"""Validates the required column exist in the dataset."""
if column_name not in evaluation_run_config.metric_column_mapping:
evaluation_run_config.metric_column_mapping[column_name] = column_name
evaluation_run_config.validate_dataset_column(column_name)


def _validate_reference_or_source_column_required(
evaluation_run_config: evaluation_base.EvaluationRunConfig,
) -> None:
"""Validates one of reference or source columns exist in the dataset."""
for metric in evaluation_run_config.metrics:
if isinstance(
metric, metrics_base._TranslationMetric # pylint: disable=protected-access
):
evaluation_run_config.metric_column_mapping[
constants.Dataset.MODEL_RESPONSE_COLUMN
] = constants.Dataset.MODEL_RESPONSE_COLUMN
evaluation_run_config.validate_dataset_column(
constants.Dataset.MODEL_RESPONSE_COLUMN
)
# Validate the reference column.
# This is optional if source column is provided.
try:
_validate_column_provided(
evaluation_run_config,
constants.Dataset.REFERENCE_COLUMN,
)
except KeyError:
# Reference column is optional. Checking for source column.
_validate_column_provided(
evaluation_run_config,
constants.Dataset.SOURCE_COLUMN,
)


def _compute_custom_metrics(
@@ -639,6 +679,15 @@ def _parse_metric_results_to_dataframe(
metrics_table,
constants.MetricResult.SCORE_KEY,
)
elif isinstance(
metric, metrics_base._TranslationMetric # pylint: disable=protected-access
):
_set_metric_table(
str(metric),
metric_results,
metrics_table,
constants.MetricResult.SCORE_KEY,
)
else:
_LOGGER.warning(
f"Metric name: {str(metric)} is not supported when parsing"
@@ -889,7 +938,7 @@ def evaluate(
evaluation_run_config=evaluation_run_config,
response_column_name=constants.Dataset.MODEL_RESPONSE_COLUMN,
)
_validate_dataset_for_automatic_metrics(evaluation_run_config)
_validate_dataset(evaluation_run_config)

pairwise_metric_exists = any(
isinstance(metric, pairwise_metric.PairwiseMetric)

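Taken together with the tests above, the reworked validation requires a `response` column for both automatic and translation metrics, and translation metrics additionally need either a `reference` or a `source` column (reference is tried first, with source as the fallback). A small sketch of the failure mode when neither is present, mirroring test_evaluate_reference_or_source_column_not_provided (API shape assumed, as in the sketch near the top):

import pandas as pd
from vertexai.evaluation import EvalTask
from vertexai.evaluation.metrics import pointwise_metric

comet = pointwise_metric.Comet(
    version="COMET_22_SRC_REF", source_language="en", target_language="zh"
)

# Only `response` is present, so neither `reference` nor `source` can satisfy
# the translation-metric check and evaluate() raises a KeyError along the
# lines of: "Required column `source` not found in the evaluation dataset.
# The columns in the evaluation dataset are ['response']".
task = EvalTask(
    dataset=pd.DataFrame({"response": ["test", "text"]}),
    metrics=[comet],
)
# task.evaluate()  # raises KeyError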