Skip to content

Commit

Permalink
Merge pull request #196 from fooof-tools/df
Browse files Browse the repository at this point in the history
[ENH] - Add support for converting model results, including to DFs
  • Loading branch information
TomDonoghue authored Jun 29, 2023
2 parents 208fc09 + 7ba9834 commit 9fd6893
Show file tree
Hide file tree
Showing 10 changed files with 246 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ There are also optional dependencies, which are not required for model fitting i

- `matplotlib <https://github.com/matplotlib/matplotlib>`_ is needed to visualize data and model fits
- `tqdm <https://github.com/tqdm/tqdm>`_ is needed to print progress bars when fitting many models
- `pandas <https://github.com/pandas-dev/pandas>`_ is needed to for exporting model fit results to dataframes
- `pytest <https://github.com/pytest-dev/pytest>`_ is needed to run the test suite locally

We recommend using the `Anaconda <https://www.anaconda.com/distribution/>`_ distribution to manage these requirements.
Expand Down
106 changes: 106 additions & 0 deletions fooof/data/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Conversion functions for organizing model results into alternate representations."""

import numpy as np

from fooof import Bands
from fooof.core.funcs import infer_ap_func
from fooof.core.info import get_ap_indices, get_peak_indices
from fooof.core.modutils import safe_import, check_dependency
from fooof.analysis.periodic import get_band_peak

pd = safe_import('pandas')

###################################################################################################
###################################################################################################

def model_to_dict(fit_results, peak_org):
"""Convert model fit results to a dictionary.
Parameters
----------
fit_results : FOOOFResults
Results of a model fit.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
Returns
-------
dict
Model results organized into a dictionary.
"""

fr_dict = {}

# aperiodic parameters
for label, param in zip(get_ap_indices(infer_ap_func(fit_results.aperiodic_params)),
fit_results.aperiodic_params):
fr_dict[label] = param

# periodic parameters
peaks = fit_results.peak_params

if isinstance(peak_org, int):

if len(peaks) < peak_org:
nans = [np.array([np.nan] * 3) for ind in range(peak_org-len(peaks))]
peaks = np.vstack((peaks, nans))

for ind, peak in enumerate(peaks[:peak_org, :]):
for pe_label, pe_param in zip(get_peak_indices(), peak):
fr_dict[pe_label.lower() + '_' + str(ind)] = pe_param

elif isinstance(peak_org, Bands):
for band, f_range in peak_org:
for label, param in zip(get_peak_indices(), get_band_peak(peaks, f_range)):
fr_dict[band + '_' + label.lower()] = param

# goodness-of-fit metrics
fr_dict['error'] = fit_results.error
fr_dict['r_squared'] = fit_results.r_squared

return fr_dict

@check_dependency(pd, 'pandas')
def model_to_dataframe(fit_results, peak_org):
"""Convert model fit results to a dataframe.
Parameters
----------
fit_results : FOOOFResults
Results of a model fit.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
Returns
-------
pd.Series
Model results organized into a dataframe.
"""

return pd.Series(model_to_dict(fit_results, peak_org))


@check_dependency(pd, 'pandas')
def group_to_dataframe(fit_results, peak_org):
"""Convert a group of model fit results into a dataframe.
Parameters
----------
fit_results : list of FOOOFResults
List of FOOOFResults objects.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
Returns
-------
pd.DataFrame
Model results organized into a dataframe.
"""

return pd.DataFrame([model_to_dataframe(f_res, peak_org) for f_res in fit_results])
20 changes: 20 additions & 0 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from fooof.utils.data import trim_spectrum
from fooof.utils.params import compute_gauss_std
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
from fooof.data.conversions import model_to_dataframe
from fooof.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model

###################################################################################################
Expand Down Expand Up @@ -716,6 +717,25 @@ def set_check_data_mode(self, check_data):
self._check_data = check_data


def to_df(self, peak_org):
"""Convert and extract the model results as a pandas object.
Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
Returns
-------
pd.Series
Model results organized into a pandas object.
"""

return model_to_dataframe(self.get_results(), peak_org)


def _check_width_limits(self):
"""Check and warn about peak width limits / frequency resolution interaction."""

Expand Down
20 changes: 20 additions & 0 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fooof.core.strings import gen_results_fg_str
from fooof.core.io import save_fg, load_jsonlines
from fooof.core.modutils import copy_doc_func_to_method, safe_import
from fooof.data.conversions import group_to_dataframe

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -541,6 +542,25 @@ def print_results(self, concise=False):
print(gen_results_fg_str(self, concise))


def to_df(self, peak_org):
"""Convert and extract the model results as a pandas object.
Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
Returns
-------
pd.DataFrame
Model results organized into a pandas object.
"""

return group_to_dataframe(self.get_results(), peak_org)


def _fit(self, *args, **kwargs):
"""Create an alias to FOOOF.fit for FOOOFGroup object, for internal use."""

Expand Down
11 changes: 10 additions & 1 deletion fooof/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fooof.core.modutils import safe_import

from fooof.tests.tutils import get_tfm, get_tfg, get_tbands
from fooof.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults
from fooof.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH,
TEST_REPORTS_PATH, TEST_PLOTS_PATH)

Expand Down Expand Up @@ -48,7 +48,16 @@ def tfg():
def tbands():
yield get_tbands()

@pytest.fixture(scope='session')
def tresults():
yield get_tresults()

@pytest.fixture(scope='session')
def skip_if_no_mpl():
if not safe_import('matplotlib'):
pytest.skip('Matplotlib not available: skipping test.')

@pytest.fixture(scope='session')
def skip_if_no_pandas():
if not safe_import('pandas'):
pytest.skip('Pandas not available: skipping test.')
53 changes: 53 additions & 0 deletions fooof/tests/data/test_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for the fooof.data.conversions."""

from copy import deepcopy

import numpy as np

from fooof.core.modutils import safe_import
pd = safe_import('pandas')

from fooof.data.conversions import *

###################################################################################################
###################################################################################################

def test_model_to_dict(tresults, tbands):

out = model_to_dict(tresults, peak_org=1)
assert isinstance(out, dict)
assert 'cf_0' in out
assert out['cf_0'] == tresults.peak_params[0, 0]
assert not 'cf_1' in out

out = model_to_dict(tresults, peak_org=2)
assert 'cf_0' in out
assert 'cf_1' in out
assert out['cf_1'] == tresults.peak_params[1, 0]

out = model_to_dict(tresults, peak_org=3)
assert 'cf_2' in out
assert np.isnan(out['cf_2'])

out = model_to_dict(tresults, peak_org=tbands)
assert 'alpha_cf' in out

def test_model_to_dataframe(tresults, tbands, skip_if_no_pandas):

for peak_org in [1, 2, 3]:
out = model_to_dataframe(tresults, peak_org=peak_org)
assert isinstance(out, pd.Series)

out = model_to_dataframe(tresults, peak_org=tbands)
assert isinstance(out, pd.Series)

def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas):

fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]

for peak_org in [1, 2, 3]:
out = group_to_dataframe(fit_results, peak_org=peak_org)
assert isinstance(out, pd.DataFrame)

out = group_to_dataframe(fit_results, peak_org=tbands)
assert isinstance(out, pd.DataFrame)
12 changes: 11 additions & 1 deletion fooof/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from fooof.core.items import OBJ_DESC
from fooof.core.errors import FitError
from fooof.core.utils import group_three
from fooof.core.modutils import safe_import
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
from fooof.sim import gen_freqs, gen_power_spectrum
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
from fooof.core.errors import DataError, NoDataError, InconsistentDataError

pd = safe_import('pandas')

from fooof.tests.settings import TEST_DATA_PATH
from fooof.tests.tutils import get_tfm, plot_test
Expand Down Expand Up @@ -425,3 +428,10 @@ def test_fooof_check_data():
# Model fitting should execute, but return a null model fit, given the NaNs, without failing
tfm.fit()
assert not tfm.has_model

def test_fooof_to_df(tfm, tbands, skip_if_no_pandas):

df1 = tfm.to_df(2)
assert isinstance(df1, pd.Series)
df2 = tfm.to_df(tbands)
assert isinstance(df2, pd.Series)
13 changes: 12 additions & 1 deletion fooof/tests/objs/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import numpy as np
from numpy.testing import assert_equal

from fooof.data import FOOOFResults
from fooof.core.items import OBJ_DESC
from fooof.core.modutils import safe_import
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
from fooof.data import FOOOFResults
from fooof.sim import gen_group_power_spectra

pd = safe_import('pandas')

from fooof.tests.settings import TEST_DATA_PATH
from fooof.tests.tutils import default_group_params, plot_test

Expand Down Expand Up @@ -349,3 +353,10 @@ def test_fg_get_group(tfg):
# Check that the correct results are extracted
assert [tfg.group_results[ind] for ind in inds1] == nfg1.group_results
assert [tfg.group_results[ind] for ind in inds2] == nfg2.group_results

def test_fg_to_df(tfg, tbands, skip_if_no_pandas):

df1 = tfg.to_df(2)
assert isinstance(df1, pd.DataFrame)
df2 = tfg.to_df(tbands)
assert isinstance(df2, pd.DataFrame)
11 changes: 11 additions & 0 deletions fooof/tests/tutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from functools import wraps

import numpy as np

from fooof.bands import Bands
from fooof.data import FOOOFResults
from fooof.objs import FOOOF, FOOOFGroup
from fooof.core.modutils import safe_import
from fooof.sim.params import param_sampler
Expand Down Expand Up @@ -43,6 +46,14 @@ def get_tbands():

return Bands({'theta' : (4, 8), 'alpha' : (8, 12), 'beta' : (13, 30)})

def get_tresults():
"""Get a FOOOFResults objet, for testing."""

return FOOOFResults(aperiodic_params=np.array([1.0, 1.00]),
peak_params=np.array([[10.0, 1.25, 2.0], [20.0, 1.0, 3.0]]),
r_squared=0.97, error=0.01,
gaussian_params=np.array([[10.0, 1.25, 1.0], [20.0, 1.0, 1.5]]))

def default_group_params():
"""Create default parameters for generating a test group of power spectra."""

Expand Down
3 changes: 2 additions & 1 deletion optional-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
matplotlib
tqdm
tqdm
pandas

0 comments on commit 9fd6893

Please sign in to comment.