Skip to content

Commit

Permalink
bug fix for multiple cons
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Feb 2, 2023
1 parent a0d1e44 commit f3c9a7d
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,7 @@ def _compute_separate_gc_csd_and_connectivity(self):
this_con, _, = self._compute_connectivity(con_methods, new_indices)

for method_i in range(n_gc_methods):
self.separate_gc_con[method_i].append(this_con[method_i])

self.separate_gc_con = [
np.squeeze(np.array(this_con), 1) for this_con in
self.separate_gc_con
]
self.separate_gc_con[method_i].extend(this_con[method_i])

# finds the methods still needing to be computed
self.remaining_method_types = [
Expand Down Expand Up @@ -581,7 +576,7 @@ def _prepare_csd_computation(self, data, con_method_types, indices):
) = _prepare_connectivity(
epoch_block=self.epoch_blocks[0], times_in=self.times_in,
tmin=self.tmin, tmax=self.tmax, fmin=fmin, fmax=fmax,
sfreq=self.sfreq, indices=self.indices, mode=self.mode,
sfreq=self.sfreq, indices=indices, mode=self.mode,
fskip=self.fskip, n_bands=self.n_bands, cwt_freqs=self.cwt_freqs,
faverage=self.faverage
)
Expand All @@ -600,7 +595,7 @@ def _prepare_csd_computation(self, data, con_method_types, indices):

self._sort_con_indices(indices)

con_methods = self._instantiate_con_estimators(con_method_types)
con_methods = self._instantiate_con_estimators(con_method_types, indices)

self.csd_call_params = dict(
sig_idx=self.sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx,
Expand Down Expand Up @@ -685,7 +680,7 @@ def _get_unique_signals(self, indices):
"""Find the unique signals in a set of indices."""
return np.unique(sum(sum(indices, []), []))

def _instantiate_con_estimators(self, con_method_types):
def _instantiate_con_estimators(self, con_method_types, indices):
"""Create instances of the connectivity estimators and log the methods
being computed."""
con_methods = []
Expand All @@ -694,15 +689,15 @@ def _instantiate_con_estimators(self, con_method_types):
# if a GC method, provide n_lags argument
con_methods.append(
mtype(
self.use_n_signals, self.n_cons, self.n_freqs,
self.use_n_signals, len(indices[0]), self.n_freqs,
self.n_times_spectrum, self.gc_n_lags, self.n_jobs
)
)
else:
# if a coherence method, do not provide n_lags argument
con_methods.append(
mtype(
self.use_n_signals, self.n_cons, self.n_freqs,
self.use_n_signals, len(indices[0]), self.n_freqs,
self.n_times_spectrum, self.n_jobs
)
)
Expand Down Expand Up @@ -732,7 +727,7 @@ def _compute_connectivity(self, con_methods, indices):
)

self._check_correct_results_dimensions(
con_methods, method_con, method_topo
con_methods, method_con, method_topo, indices
)

if self.faverage:
Expand Down Expand Up @@ -919,12 +914,13 @@ def _get_coh_form_to_compute(self, con_methods):
self.compute_coh_form[form_name] = form_info
break # only one form is possible at any one instance

def _check_correct_results_dimensions(self, con_methods, con, topo):
def _check_correct_results_dimensions(self, con_methods, con, topo, indices):
"""Checks that the results of the connectivity computations have the
appropriate dimensions."""
n_cons = len(indices[0])
n_times = con_methods[0].n_times

assert (con.shape[0] == self.n_cons), (
assert (con.shape[0] == n_cons), (
'The first dimension of connectivity scores does not match the '
'number of connections. Please contact the mne-connectivity '
'developers.'
Expand All @@ -944,13 +940,13 @@ def _check_correct_results_dimensions(self, con_methods, con, topo):
)

if topo is not None:
assert (topo[0].shape[0] == self.n_cons and topo[1].shape[0]), (
assert (topo[0].shape[0] == n_cons and topo[1].shape[0]), (
'The first dimension of topographies does not match the number '
'of connections. Please contact the mne-connectivity '
'developers.'
)

for con_i in range(self.n_cons):
for con_i in range(n_cons):
assert (
topo[0][con_i].shape[1] == self.n_freqs and
topo[1][con_i].shape[1] == self.n_freqs
Expand All @@ -972,17 +968,18 @@ def _check_correct_results_dimensions(self, con_methods, con, topo):

def _compute_faverage(self, con, topo):
"""Computes the average connectivity across the frequency bands."""
con_shape = (self.n_cons, self.n_bands) + con.shape[2:]
n_cons = con.shape[0]
con_shape = (n_cons, self.n_bands) + con.shape[2:]
con_bands = np.empty(con_shape, dtype=con.dtype)
for band_idx in range(self.n_bands):
con_bands[:, band_idx] = np.mean(
con[:, self.freq_idx_bands[band_idx]], axis=1
)

if topo is not None:
topo_bands = np.empty((2, self.n_cons), dtype=topo.dtype)
topo_bands = np.empty((2, n_cons), dtype=topo.dtype)
for group_i in range(2):
for con_i in range(self.n_cons):
for con_i in range(n_cons):
band_topo = [
np.mean(topo[group_i][con_i][:, freq_idx_band], axis=1)
for freq_idx_band in self.freq_idx_bands
Expand Down Expand Up @@ -1020,6 +1017,9 @@ def _collate_connectivity_results(self):
# else you only have the remaining (non-GC with SVD/discontinuous
# frequency) results, already in the order in which they were called

self.con = np.array(self.con)
self.topo = np.array(self.topo, dtype=object)

def store_connectivity_results(self):
"""Stores multivariate connectivity results in mne-connectivity
objects."""
Expand Down

0 comments on commit f3c9a7d

Please sign in to comment.