Skip to content

Commit

Permalink
Merge pull request #257 from mwprestonjr/flexible_report
Browse files Browse the repository at this point in the history
extend reports
  • Loading branch information
TomDonoghue authored Jun 30, 2023
2 parents 6a1ba83 + 454fcbd commit b942280
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
6 changes: 4 additions & 2 deletions fooof/core/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
###################################################################################################

@check_dependency(plt, 'matplotlib')
def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=True):
def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=True, **plot_kwargs):
"""Generate and save out a PDF report for a power spectrum model fit.
Parameters
Expand All @@ -37,6 +37,8 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=Tr
Whether or not to plot the frequency axis in log space.
add_settings : bool, optional, default: True
Whether to add a print out of the model settings to the end of the report.
plot_kwargs : keyword arguments
Keyword arguments to pass into the plot method.
"""

# Define grid settings based on what is to be plotted
Expand All @@ -56,7 +58,7 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=Tr

# Second - data plot
ax1 = plt.subplot(grid[1])
fm.plot(plt_log=plt_log, ax=ax1)
fm.plot(plt_log=plt_log, ax=ax1, **plot_kwargs)

# Third - FOOOF settings
if add_settings:
Expand Down
11 changes: 7 additions & 4 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def add_results(self, fooof_result):
self._check_loaded_results(fooof_result._asdict())


def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False):
def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False, **plot_kwargs):
"""Run model fit, and display a report, which includes a plot, and printed results.
Parameters
Expand All @@ -392,14 +392,16 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False
If not provided, fits across the entire given range.
plt_log : bool, optional, default: False
Whether or not to plot the frequency axis in log space.
**plot_kwargs
Keyword arguments to pass into the plot method.
Notes
-----
Data is optional, if data has already been added to the object.
"""

self.fit(freqs, power_spectrum, freq_range)
self.plot(plt_log=plt_log)
self.plot(plt_log=plt_log, **plot_kwargs)
self.print_results(concise=False)


Expand Down Expand Up @@ -648,9 +650,10 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,


@copy_doc_func_to_method(save_report_fm)
def save_report(self, file_name, file_path=None, plt_log=False, add_settings=True):
def save_report(self, file_name, file_path=None, plt_log=False,
add_settings=True, **plot_kwargs):

save_report_fm(self, file_name, file_path, plt_log, add_settings)
save_report_fm(self, file_name, file_path, plt_log, add_settings, **plot_kwargs)


@copy_doc_func_to_method(save_fm)
Expand Down
24 changes: 24 additions & 0 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,30 @@ def print_results(self, concise=False):
print(gen_results_fg_str(self, concise))


def save_model_report(self, index, file_name, file_path=None, plt_log=False,
add_settings=True, **plot_kwargs):
""""Save out an individual model report for a specified model fit.
Parameters
----------
index : int
Index of the model fit to save out.
file_name : str
Name to give the saved out file.
file_path : str, optional
Path to directory to save to. If None, saves to current directory.
plt_log : bool, optional, default: False
Whether or not to plot the frequency axis in log space.
add_settings : bool, optional, default: True
Whether to add a print out of the model settings to the end of the report.
plot_kwargs : keyword arguments
Keyword arguments to pass into the plot method.
"""

self.get_fooof(ind=index, regenerate=True).save_report(\
file_name, file_path, plt_log, **plot_kwargs)


def to_df(self, peak_org):
"""Convert and extract the model results as a pandas object.
Expand Down
11 changes: 10 additions & 1 deletion fooof/tests/objs/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
They serve rather as 'smoke tests', for if anything fails completely.
"""

import os

import numpy as np
from numpy.testing import assert_equal

Expand All @@ -17,7 +19,7 @@

pd = safe_import('pandas')

from fooof.tests.settings import TEST_DATA_PATH
from fooof.tests.settings import TEST_DATA_PATH, TEST_REPORTS_PATH
from fooof.tests.tutils import default_group_params, plot_test

from fooof.objs.group import *
Expand Down Expand Up @@ -212,6 +214,13 @@ def test_fg_print(tfg):
tfg.print_results()
assert True

def test_save_model_report(tfg):

file_name = 'test_group_model_report'
tfg.save_model_report(0, file_name, TEST_REPORTS_PATH)

assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf'))

def test_get_results(tfg):
"""Check get results method."""

Expand Down

0 comments on commit b942280

Please sign in to comment.