Fix result corruption when re-using calculators (#206)
* Add test cases for re-using calculators

* Fix result corruption when re-using calculators

* Fix incorrect calculator being used in drift test

* Fix lint error
michael-nml authored Feb 9, 2023
1 parent b9bb9e0 commit e15d76c
Showing 10 changed files with 91 additions and 3 deletions.
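In short: a fitted calculator or estimator kept appending each new batch of analysis results onto the result object cached from the previous call, so rows accumulated and later results were corrupted. A minimal pandas sketch of that pre-fix behaviour (a simplified model, not nannyml's actual classes):

import pandas as pd

# Simplified model of the pre-fix behaviour: the cached result still holds the
# analysis rows from the previous call, and every new call just appends more.
cached = pd.DataFrame({'period': ['reference'], 'value': [0.1]})
new_analysis = pd.DataFrame({'period': ['analysis'], 'value': [0.2]})

for call in range(1, 4):
    cached = pd.concat([cached, new_analysis]).reset_index(drop=True)
    print(f"after call {call}: {len(cached)} rows")  # 2, 3, 4 rows: the frame keeps growing

The changes below reset the cached result to the reference period before appending, so repeated calls on the same data return equivalent results.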
1 change: 0 additions & 1 deletion nannyml/base.py
@@ -342,7 +342,6 @@ def __init__(
        )
        self.timestamp_column_name = timestamp_column_name

-
        self.result: Optional[Result] = None

    @property
@@ -249,6 +249,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
                continuous_column_names=self.continuous_column_names,
            )
        else:
+            self.result = self.result.filter(period='reference')  # type: ignore
            self.result.data = pd.concat([self.result.data, res]).reset_index(drop=True)

        return self.result
1 change: 1 addition & 0 deletions nannyml/performance_calculation/calculator.py
@@ -209,6 +209,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
                problem_type=self.problem_type,
            )
        else:
+            self.result = self.result.filter(period='reference')  # type: ignore
            self.result.data = pd.concat([self.result.data, res]).reset_index(drop=True)
            self.result.analysis_data = data.copy()

1 change: 1 addition & 0 deletions nannyml/performance_estimation/confidence_based/cbpe.py
@@ -309,6 +309,7 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
                problem_type=self.problem_type,
            )
        else:
+            self.result = self.result.filter(period='reference')  # type: ignore
            self.result.data = pd.concat([self.result.data, res]).reset_index(drop=True)

        return self.result
@@ -293,6 +293,7 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> ResultType:
                hyperparameters=self.hyperparameters,
            )
        else:
+            self.result = self.result.filter(period='reference')  # type: ignore
            self.result.data = pd.concat([self.result.data, res]).reset_index(drop=True)

        return self.result
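All four calculator and estimator hunks above apply the same guard: when a result is already cached from an earlier call, it is filtered back to the reference period before the newly calculated chunk rows are concatenated. A condensed, self-contained sketch of that shape (a hypothetical ToyCalculator working on plain DataFrames; the real classes wrap the frame in a Result object and call result.filter(period='reference')):

import pandas as pd


class ToyCalculator:
    """Toy stand-in for the shared _calculate/_estimate caching pattern."""

    def __init__(self, reference_res: pd.DataFrame):
        # Results computed on the reference data at fit time.
        self.result = reference_res.assign(period='reference')

    def calculate(self, analysis_res: pd.DataFrame) -> pd.DataFrame:
        res = analysis_res.assign(period='analysis')
        # The fix: drop analysis rows left over from a previous call before
        # appending the fresh ones, mirroring self.result.filter(period='reference').
        self.result = self.result[self.result['period'] == 'reference']
        self.result = pd.concat([self.result, res]).reset_index(drop=True)
        return self.result


calc = ToyCalculator(pd.DataFrame({'value': [0.1]}))
first = calc.calculate(pd.DataFrame({'value': [0.2]}))
second = calc.calculate(pd.DataFrame({'value': [0.2]}))

assert first is not second
pd.testing.assert_frame_equal(first, second)  # repeated calls now agree

The new tests below exercise exactly this re-use pattern against the real calculators and estimators.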
10 changes: 10 additions & 0 deletions tests/drift/test_data_reconstruction_drift.py
@@ -533,3 +533,13 @@ def test_result_comparison_to_cbpe_plots_raise_no_exceptions(sample_drift_data):
        _ = result.compare(result2).plot()
    except Exception as exc:
        pytest.fail(f"an unexpected exception occurred: {exc}")


def test_data_reconstruction_drift_calculator_returns_distinct_but_consistent_results_when_reused(sample_drift_data):
    ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference']
    sut = DataReconstructionDriftCalculator(column_names=['f1', 'f2', 'f3', 'f4']).fit(ref_data)
    result1 = sut.calculate(data=sample_drift_data)
    result2 = sut.calculate(data=sample_drift_data)

    assert result1 is not result2
    pd.testing.assert_frame_equal(result1.to_df(), result2.to_df())
24 changes: 23 additions & 1 deletion tests/drift/test_drift.py
@@ -569,7 +569,13 @@ def test_result_comparison_to_cbpe_plots_raise_no_exceptions(sample_drift_data):
    ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference']
    ana_data = sample_drift_data.loc[sample_drift_data['period'] == 'analysis']

-    calc = DataReconstructionDriftCalculator(column_names=['f1', 'f2', 'f3', 'f4']).fit(ref_data)
+    calc = UnivariateDriftCalculator(
+        column_names=['f1', 'f2', 'f3', 'f4'],
+        continuous_methods=['kolmogorov_smirnov'],
+        categorical_methods=['chi2'],
+        timestamp_column_name='timestamp',
+        calculation_method='auto',
+    ).fit(ref_data)
    result = calc.calculate(ana_data)

    calc2 = CBPE(
@@ -586,3 +592,19 @@ def test_result_comparison_to_cbpe_plots_raise_no_exceptions(sample_drift_data):
        _ = result.compare(result2).plot()
    except Exception as exc:
        pytest.fail(f"an unexpected exception occurred: {exc}")


def test_univariate_drift_calculator_returns_distinct_but_consistent_results_when_reused(sample_drift_data):
    ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference']
    sut = UnivariateDriftCalculator(
        column_names=['f1', 'f3'],
        timestamp_column_name='timestamp',
        continuous_methods=['kolmogorov_smirnov'],
        categorical_methods=['chi2'],
    )
    sut.fit(ref_data)
    result1 = sut.calculate(data=sample_drift_data)
    result2 = sut.calculate(data=sample_drift_data)

    assert result1 is not result2
    pd.testing.assert_frame_equal(result1.to_df(), result2.to_df())
16 changes: 15 additions & 1 deletion tests/performance_calculation/test_performance_calculator.py
@@ -38,7 +38,7 @@ def performance_calculator() -> PerformanceCalculator:
        timestamp_column_name='timestamp',
        y_pred_proba='y_pred_proba',
        y_pred='y_pred',
-        y_true='y_true',
+        y_true='work_home_actual',
        metrics=['roc_auc', 'f1'],
        problem_type='classification_binary',
    )
@@ -256,3 +256,17 @@ def test_binary_classification_result_plots_raise_no_exceptions(calc_args, plot_
        _ = sut.plot(**plot_args)
    except Exception as exc:
        pytest.fail(f"an unexpected exception occurred: {exc}")


def test_calculator_returns_distinct_but_consistent_results_when_reused(data, performance_calculator):
    reference, analysis, target = data

    data = analysis.merge(target, on='identifier')
    performance_calculator.fit(reference)
    result1 = performance_calculator.calculate(data)
    result2 = performance_calculator.calculate(data)

    # Checks two distinct results are returned. Previously there was a bug causing the previous result instance to be
    # modified on subsequent calls.
    assert result1 is not result2
    pd.testing.assert_frame_equal(result1.to_df(), result2.to_df())
22 changes: 22 additions & 0 deletions tests/performance_estimation/CBPE/test_cbpe.py
@@ -504,3 +504,25 @@ def test_cbpe_for_multiclass_classification_chunked_by_period_should_include_var

    assert (metric, 'sampling_error') in sut.columns
    assert np.array_equal(np.round(sut.loc[:, (metric, 'sampling_error')], 4), np.round(sampling_error, 4))


def test_cbpe_returns_distinct_but_consistent_results_when_reused(binary_classification_data):
    reference, analysis = binary_classification_data

    sut = CBPE(
        # timestamp_column_name='timestamp',
        chunk_size=50_000,
        y_true='work_home_actual',
        y_pred='y_pred',
        y_pred_proba='y_pred_proba',
        metrics=['roc_auc'],
        problem_type='classification_binary',
    )
    sut.fit(reference)
    result1 = sut.estimate(analysis)
    result2 = sut.estimate(analysis)

    # Checks two distinct results are returned. Previously there was a bug causing the previous result instance to be
    # modified on subsequent estimates.
    assert result1 is not result2
    pd.testing.assert_frame_equal(result1.to_df(), result2.to_df())
17 changes: 17 additions & 0 deletions tests/performance_estimation/DLE/test_dle.py
@@ -318,3 +318,20 @@ def test_binary_classification_result_plots_raise_no_exceptions(estimator_args,
        _ = sut.plot(**plot_args)
    except Exception as exc:
        pytest.fail(f"an unexpected exception occurred: {exc}")


def test_dle_returns_distinct_but_consistent_results_when_reused(regression_data, direct_error_estimator):
    reference, analysis = regression_data

    # Get rid of negative values for log based metrics
    reference = reference[~(reference['y_pred'] < 0)]
    analysis = analysis[~(analysis['y_pred'] < 0)]

    direct_error_estimator.fit(reference)
    estimate1 = direct_error_estimator.estimate(analysis)
    estimate2 = direct_error_estimator.estimate(analysis)

    # Checks two distinct results are returned. Previously there was a bug causing the previous result instance to be
    # modified on subsequent estimates.
    assert estimate1 is not estimate2
    pd.testing.assert_frame_equal(estimate1.to_df(), estimate2.to_df())
