Skip to content

Commit

Permalink
feat: add POP analysis (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie authored Sep 3, 2024
1 parent dba4478 commit 5a0cc95
Show file tree
Hide file tree
Showing 13 changed files with 790 additions and 1 deletion.
44 changes: 44 additions & 0 deletions docs/api_reference/_autosummary/xeofs.single.POP.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
POP
===

.. currentmodule:: xeofs.single

.. autoclass:: POP
   :members:
   :inherited-members:


   .. automethod:: __init__


   .. rubric:: Methods

   .. autosummary::

      ~POP.__init__
      ~POP.components
      ~POP.components_amplitude
      ~POP.components_phase
      ~POP.compute
      ~POP.damping_times
      ~POP.deserialize
      ~POP.eigenvalues
      ~POP.fit
      ~POP.fit_transform
      ~POP.get_params
      ~POP.get_serialization_attrs
      ~POP.inverse_transform
      ~POP.load
      ~POP.periods
      ~POP.save
      ~POP.scores
      ~POP.scores_amplitude
      ~POP.scores_phase
      ~POP.serialize
      ~POP.transform






1 change: 1 addition & 0 deletions docs/api_reference/single_set_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Methods that investigate relationships or patterns between variables within a si
~xeofs.single.ComplexEOF
~xeofs.single.HilbertEOF
~xeofs.single.ExtendedEOF
~xeofs.single.POP
~xeofs.single.OPA
~xeofs.single.GWPCA
~xeofs.single.SparsePCA
Expand Down
Binary file modified docs/auto_examples/auto_examples_jupyter.zip
Binary file not shown.
Binary file modified docs/auto_examples/auto_examples_python.zip
Binary file not shown.
219 changes: 219 additions & 0 deletions tests/models/single/test_pop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import numpy as np
import pytest
import xarray as xr

from xeofs.single import POP


def test_init():
    """Check that constructing a POP model sets up its core attributes."""
    model = POP(n_modes=5, standardize=True, use_coslat=True)

    # The parameter dict, the preprocessor and the whitener must all exist
    # on a freshly constructed (not yet fitted) model.
    for attribute in ("_params", "preprocessor", "whitener"):
        assert hasattr(model, attribute)


def test_fit(mock_data_array):
    """Smoke test: fitting a default POP model along "time" must not raise."""
    model = POP()
    model.fit(mock_data_array, "time")


def test_eigenvalues(mock_data_array):
    """The eigenvalues of a fitted POP model are returned as a DataArray."""
    model = POP()
    model.fit(mock_data_array, "time")

    result = model.eigenvalues()
    assert isinstance(result, xr.DataArray)


def test_damping_times(mock_data_array):
    """The damping times of a fitted POP model are returned as a DataArray."""
    model = POP()
    model.fit(mock_data_array, "time")

    result = model.damping_times()
    assert isinstance(result, xr.DataArray)


def test_periods(mock_data_array):
    """The periods of a fitted POP model are returned as a DataArray."""
    model = POP()
    model.fit(mock_data_array, "time")

    result = model.periods()
    assert isinstance(result, xr.DataArray)


def test_components(mock_data_array):
    """Check the type and the dimensions of the POP components."""
    sample_dim = ("time",)
    model = POP()
    model.fit(mock_data_array, sample_dim)

    patterns = model.components()
    assert isinstance(patterns, xr.DataArray), "Components is not a DataArray"

    # Expected dims: every non-sample dimension of the input, plus "mode".
    expected_dims = (set(mock_data_array.dims) - set(sample_dim)) | {"mode"}
    assert (
        set(patterns.dims) == expected_dims
    ), "Components does not have the right feature dimensions"


def test_scores(mock_data_array):
    """Check the type and the dimensions of the POP scores."""
    sample_dim = ("time",)
    model = POP()
    model.fit(mock_data_array, sample_dim)

    time_series = model.scores()
    assert isinstance(time_series, xr.DataArray), "Scores is not a DataArray"

    # Expected dims: the sample dimension(s) plus "mode".
    expected_dims = {*sample_dim, "mode"}
    assert (
        set(time_series.dims) == expected_dims
    ), "Scores does not have the right dimensions"


def test_transform(mock_data_array):
    """Project both the training data and unseen data onto the fitted POPs."""
    dim = ("time",)
    sample = dim[0]

    # Fit on everything except the first time step; that first step is held
    # out and used as "unseen" data further below.
    training = mock_data_array.isel({sample: slice(1, None)})
    unseen = mock_data_array.isel({sample: slice(0, 1)})

    model = POP(n_modes=2, solver="full")
    model.fit(training, dim)
    reference = model.scores()

    # --- Training data: the projection must match the stored scores ---
    projected = model.transform(training)

    assert projected.dims == reference.dims, "Projection has wrong dimensions"  # type: ignore
    assert isinstance(projected, xr.DataArray), "Projection is not a DataArray"
    assert projected.name == "scores", "Projection has wrong name: {}".format(
        projected.name
    )

    mode_range = slice(1, 3)
    np.testing.assert_allclose(
        reference.sel(mode=mode_range), projected.sel(mode=mode_range), rtol=1e-3
    )

    # --- Unseen data: same structure as the scores, and no missing values ---
    projected_unseen = model.transform(unseen)

    assert projected_unseen.dims == reference.dims, "Projection has wrong dimensions"  # type: ignore
    assert isinstance(
        projected_unseen, xr.DataArray
    ), "Projection is not a DataArray"
    assert projected_unseen.name == "scores", "Projection has wrong name: {}".format(
        projected_unseen.name
    )
    assert np.all(projected_unseen.notnull().values), "New projections contain NaNs"


def test_inverse_transform(mock_data_array):
    """Reconstruct data from POP scores for scalar, list and full mode selections."""
    dim = ("time",)

    model = POP(n_modes=20, standardize=True)
    model.fit(mock_data_array, dim=dim)
    all_scores = model.scores()

    # Mode selected as a scalar label
    rec_scalar = model.inverse_transform(all_scores.sel(mode=1))
    assert isinstance(rec_scalar, xr.DataArray)

    # Same mode selected through a one-element list
    rec_listed = model.inverse_transform(all_scores.sel(mode=[1]))
    assert isinstance(rec_listed, xr.DataArray)

    # Both ways of selecting the mode must give the same reconstruction
    xr.testing.assert_allclose(rec_scalar, rec_listed)

    # Reconstruction from all modes
    rec_full = model.inverse_transform(all_scores)
    assert isinstance(rec_full, xr.DataArray)

    # The reconstruction carries the same set of dimensions as the input
    assert set(rec_full.dims) == set(mock_data_array.dims)


@pytest.mark.parametrize("engine", ["zarr"])
def test_save_load(mock_data_array, tmp_path, engine):
    """Round-trip a fitted POP model through save/load.

    The loaded model must have the same parameters and data entries as the
    original and must give the same results for transform and
    inverse_transform.
    """
    # NOTE: netcdf4 does not support complex data types, so only zarr is tested
    dim = "time"
    store = tmp_path / "pop"

    original = POP()
    original.fit(mock_data_array, dim)

    # Write the model to disk and make sure something was created
    original.save(store, engine=engine)
    assert store.exists()

    # Rebuild the model from the saved store
    loaded = POP.load(store, engine=engine)

    # Parameters and DataContainer keys must survive the round trip
    assert original.get_params() == loaded.get_params()
    assert all(key in loaded.data for key in original.data)
    for key in original.data:
        if not original.data._allow_compute[key]:
            # These entries (the input data) are not saved by default, so
            # only a placeholder is expected after loading.
            placeholder = loaded.data[key]
            assert placeholder.size <= 1
            assert placeholder.attrs["placeholder"] is True
            continue
        assert loaded.data[key].equals(original.data[key])

    # The loaded model must transform data exactly like the original
    scores_original = original.transform(mock_data_array)
    scores_loaded = loaded.transform(mock_data_array)
    assert np.allclose(scores_original, scores_loaded)

    # ... and the same holds for inverse_transform
    rec_original = original.inverse_transform(original.scores())
    rec_loaded = loaded.inverse_transform(loaded.scores())
    assert np.allclose(rec_original, rec_loaded)


def test_serialize_deserialize_dataarray(mock_data_array):
    """Serialization round trip for a model fitted on a DataArray."""
    dim = "time"
    model = POP()
    model.fit(mock_data_array, dim)

    # Serialize and immediately rebuild the model from the result
    rebuilt_model = POP.deserialize(model.serialize())

    # Both models must project the same data onto the same scores
    expected = model.transform(mock_data_array)
    actual = rebuilt_model.transform(mock_data_array)
    assert np.allclose(expected, actual)


def test_serialize_deserialize_dataset(mock_dataset):
    """Serialization round trip for a model fitted on a Dataset."""
    dim = "time"
    model = POP()
    model.fit(mock_dataset, dim)

    # Serialize and immediately rebuild the model from the result
    rebuilt_model = POP.deserialize(model.serialize())

    # Both models must project the same data onto the same scores
    expected = model.transform(mock_dataset)
    actual = rebuilt_model.transform(mock_dataset)
    assert np.allclose(expected, actual)
1 change: 1 addition & 0 deletions xeofs/cross/base_model_cross_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def fit(

if self.get_params()["compute"]:
self.data.compute()
self._post_compute()

return self

Expand Down
3 changes: 3 additions & 0 deletions xeofs/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .utils import total_variance

__all__ = ["total_variance"]
14 changes: 13 additions & 1 deletion xeofs/linalg/_numpy/_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
solver: str = "auto",
random_state: np.random.Generator | int | None = None,
solver_kwargs: dict = {},
is_complex: bool | str = "auto",
):
sanity_check_n_modes(n_modes)
self.is_based_on_variance = True if isinstance(n_modes, float) else False
Expand All @@ -83,6 +84,7 @@ def __init__(
self.solver = solver
self.random_state = random_state
self.solver_kwargs = solver_kwargs
self.is_complex = is_complex

def _get_n_modes_precompute(self, rank: int) -> int:
if self.is_based_on_variance:
Expand Down Expand Up @@ -122,7 +124,17 @@ def fit_transform(self, X):
# Source: https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html

use_dask = True if isinstance(X, DaskArray) else False
use_complex = True if np.iscomplexobj(X) else False

match self.is_complex:
case bool():
use_complex = self.is_complex
case "auto":
use_complex = True if np.iscomplexobj(X) else False
case _:
raise ValueError(
f"Unrecognized value for is_complex '{self.is_complex}'. "
"Valid options are True, False, and 'auto'."
)

is_small_data = max(X.shape) < 500

Expand Down
3 changes: 3 additions & 0 deletions xeofs/linalg/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class SVD:
def __init__(
self,
n_modes: int | float | str,
is_complex: bool | str = "auto",
init_rank_reduction: float = 0.3,
flip_signs: bool = True,
solver: str = "auto",
Expand All @@ -20,6 +21,7 @@ def __init__(
feature_name: str = "feature",
):
self.n_modes = n_modes
self.is_complex = is_complex
self.init_rank_reduction = init_rank_reduction
self.flip_signs = flip_signs
self.solver = solver
Expand Down Expand Up @@ -54,6 +56,7 @@ def fit_transform(self, X: DataArray) -> tuple[DataArray, DataArray, DataArray]:
flip_signs=self.flip_signs,
solver=self.solver,
random_state=self.random_state,
is_complex=self.is_complex,
**self.solver_kwargs,
)
U, s, V = xr.apply_ufunc(
Expand Down
6 changes: 6 additions & 0 deletions xeofs/linalg/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ..utils.data_types import DataArray


def total_variance(X: DataArray, dim: str) -> DataArray:
    """Compute the total variance of the centered data.

    The element-wise product of ``X`` with its complex conjugate is summed
    over all elements and divided by ``n - 1``, where ``n`` is the size of
    ``X`` along ``dim``.

    Parameters
    ----------
    X:
        Input data. It is expected to be centered already; no mean is
        removed here.
    dim:
        Name of the dimension whose size ``n`` is used for the ``n - 1``
        normalization.

    Returns
    -------
    DataArray
        The summed squared magnitude divided by ``n - 1``.

    """
    squared_magnitude = X * X.conj()
    degrees_of_freedom = X[dim].size - 1
    return squared_magnitude.sum() / degrees_of_freedom
2 changes: 2 additions & 0 deletions xeofs/single/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from .eof_rotator import ComplexEOFRotator, EOFRotator, HilbertEOFRotator
from .gwpca import GWPCA
from .opa import OPA
from .pop import POP
from .sparse_pca import SparsePCA

__all__ = [
"EOF",
"ExtendedEOF",
"SparsePCA",
"POP",
"OPA",
"GWPCA",
"ComplexEOF",
Expand Down
1 change: 1 addition & 0 deletions xeofs/single/base_model_single_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fit(

if self._params["compute"]:
self.data.compute()
self._post_compute()

return self

Expand Down
Loading

0 comments on commit 5a0cc95

Please sign in to comment.