Skip to content

Commit

Permalink
refactor(sanitizer): reuse utility function to test input types (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie authored Aug 25, 2024
1 parent 1f38f48 commit db03601
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions xeofs/preprocessing/sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
- from typing import Optional, Dict
- from typing_extensions import Self
+ from typing import Dict, Optional

- import dask
import xarray as xr
+ from dask.base import compute
+ from typing_extensions import Self

+ from ..utils.data_types import Data, DataArray, Dims
+ from ..utils.sanity_checks import assert_single_dataarray
from .transformer import Transformer
- from ..utils.data_types import Dims, DataArray, Data


class Sanitizer(Transformer):
Expand All @@ -30,10 +31,6 @@ def get_serialization_attrs(self) -> Dict:
is_valid_feature=self.is_valid_feature,
)

- def _check_input_type(self, X) -> None:
- if not isinstance(X, xr.DataArray):
- raise ValueError("Input must be an xarray DataArray")

def _check_input_dims(self, X) -> None:
if set(X.dims) != set([self.sample_name, self.feature_name]):
raise ValueError(
Expand Down Expand Up @@ -68,7 +65,7 @@ def fit(
**kwargs,
) -> Self:
# Check if input is a DataArray
- self._check_input_type(X)
+ assert_single_dataarray(X)

# Check if input has the correct dimensions
self._check_input_dims(X)
Expand All @@ -84,7 +81,7 @@ def fit(

def transform(self, X: DataArray) -> DataArray:
# Check if input is a DataArray
- self._check_input_type(X)
+ assert_single_dataarray(X)

# Check if input has the correct dimensions
self._check_input_dims(X)
Expand All @@ -103,7 +100,7 @@ def transform(self, X: DataArray) -> DataArray:
X_valid_features,
X_valid_samples,
X_valid_features_per_sample,
- ) = dask.compute(
+ ) = compute(
self.is_valid_feature,
X_valid_features,
X_valid_samples,
Expand Down

0 comments on commit db03601

Please sign in to comment.