Skip to content

Commit

Permalink
switched to array indices & added inline comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Nov 3, 2023
1 parent ff3b9e9 commit f5c1c3e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 37 deletions.
8 changes: 3 additions & 5 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax,
'causality, as all-to-all connectivity is not supported')
else:
logger.info('using all indices for multivariate connectivity')
# indices expected to be a masked array, even if not ragged
indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :],
np.arange(n_signals, dtype=int)[np.newaxis, :])
indices_use = np.ma.masked_array(indices_use,
Expand All @@ -115,11 +116,8 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax,
indices_use = np.tril_indices(n_signals, -1)
else:
if multivariate_con:
# mask indices
# pad ragged indices and mask the invalid entries
indices_use = _check_multivariate_indices(indices, n_signals)
indices_use = np.ma.concatenate([inds[np.newaxis] for inds in
indices_use])
np.ma.set_fill_value(indices_use, -1) # else 99999 after concat.
if any(this_method in _gc_methods for this_method in method):
for seed, target in zip(indices_use[0], indices_use[1]):
intersection = np.intersect1d(seed.compressed(),
Expand Down Expand Up @@ -1908,7 +1906,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,

# make sure padded indices are stored in the connectivity object
if multivariate_con and indices is not None:
# create a copy
# create a copy so that `indices_use` can be modified
indices = (indices_use[0].copy(), indices_use[1].copy())

# get the window function, wavelets, etc for different modes
Expand Down
8 changes: 3 additions & 5 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
'indices must be specified when computing Granger '
'causality, as all-to-all connectivity is not supported')
logger.info('using all indices for multivariate connectivity')
# indices expected to be a masked array, even if not ragged
indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :],
np.arange(n_signals, dtype=int)[np.newaxis, :])
indices_use = np.ma.masked_array(indices_use,
Expand All @@ -416,11 +417,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
indices_use = np.tril_indices(n_signals, k=-1)
else:
if multivariate_con:
# mask indices
# pad ragged indices and mask the invalid entries
indices_use = _check_multivariate_indices(indices, n_signals)
indices_use = np.ma.concatenate([inds[np.newaxis] for inds in
indices_use])
np.ma.set_fill_value(indices_use, -1) # else 99999 after concat.
if any(this_method in _gc_methods for this_method in method):
for seed, target in zip(indices_use[0], indices_use[1]):
intersection = np.intersect1d(seed.compressed(),
Expand All @@ -430,7 +428,7 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
'seed and target indices must not intersect when '
'computing Granger causality')
# make sure padded indices are stored in the connectivity object
# create a copy
# create a copy so that `indices_use` can be modified
indices = (indices_use[0].copy(), indices_use[1].copy())
else:
indices_use = check_indices(indices)
Expand Down
32 changes: 16 additions & 16 deletions mne_connectivity/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,34 +94,34 @@ def test_check_multivariate_indices():
seeds = [[0, 1], [0, 1]]
targets = [[2, 3], [3, 4]]
indices = _check_multivariate_indices((seeds, targets), n_signals)
assert all(np.ma.isMA(inds) for inds in indices)
assert all(inds.fill_value == mask_value for inds in indices)
assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]),
np.array([[2, 3], [3, 4]])))
assert np.ma.isMA(indices)
assert indices.fill_value == mask_value
assert np.all(indices == np.array([[[0, 1], [0, 1]],
[[2, 3], [3, 4]]]))
# non-ragged indices with negative values
seeds = [[0, 1], [0, 1]]
targets = [[2, 3], [3, -1]]
indices = _check_multivariate_indices((seeds, targets), n_signals)
assert all(np.ma.isMA(inds) for inds in indices)
assert all(inds.fill_value == mask_value for inds in indices)
assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]),
np.array([[2, 3], [3, 4]])))
assert np.ma.isMA(indices)
assert indices.fill_value == mask_value
assert np.all(indices == np.array([[[0, 1], [0, 1]],
[[2, 3], [3, 4]]]))
# ragged indices
seeds = [[0, 1], [0, 1]]
targets = [[2, 3, 4], [4]]
indices = _check_multivariate_indices((seeds, targets), n_signals)
assert all(np.ma.isMA(inds) for inds in indices)
assert all(inds.fill_value == mask_value for inds in indices)
assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]),
np.array([[2, 3, 4], [4, -1, -1]])))
assert np.ma.isMA(indices)
assert indices.fill_value == mask_value
assert np.all(indices == np.array([[[0, 1, -1], [0, 1, -1]],
[[2, 3, 4], [4, -1, -1]]]))
# ragged indices with negative values
seeds = [[0, 1], [0, 1]]
targets = [[2, 3, 4], [-1]]
indices = _check_multivariate_indices((seeds, targets), n_signals)
assert all(np.ma.isMA(inds) for inds in indices)
assert all(inds.fill_value == mask_value for inds in indices)
assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]),
np.array([[2, 3, 4], [4, -1, -1]])))
assert np.ma.isMA(indices)
assert indices.fill_value == mask_value
assert np.all(indices == np.array([[[0, 1, -1], [0, 1, -1]],
[[2, 3, 4], [4, -1, -1]]]))

# test error catching
with pytest.raises(ValueError,
Expand Down
18 changes: 7 additions & 11 deletions mne_connectivity/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def _check_multivariate_indices(indices, n_chans):
Returns
-------
indices : tuple of array of array of int, shape of (2, n_cons, max_n_chans)
The indices as a masked array.
indices : array of array of int, shape of (2, n_cons, max_n_chans)
The padded indices as a masked array.
Notes
-----
Expand Down Expand Up @@ -176,17 +176,13 @@ def _check_multivariate_indices(indices, n_chans):
indices[group_idx][con_idx][chan_idx] = chan % n_chans

# pad indices to avoid ragged arrays
padded_indices = (np.full((n_cons, max_n_chans), invalid, dtype=np.int32),
np.full((n_cons, max_n_chans), invalid, dtype=np.int32))
con_i = 0
for seed, target in zip(indices[0], indices[1]):
padded_indices[0][con_i, :len(seed)] = seed
padded_indices[1][con_i, :len(target)] = target
con_i += 1
padded_indices = np.full((2, n_cons, max_n_chans), invalid, dtype=np.int32)
for con_i, (seed, target) in enumerate(zip(indices[0], indices[1])):
padded_indices[0, con_i, :len(seed)] = seed
padded_indices[1, con_i, :len(target)] = target

# mask invalid indices
masked_indices = (np.ma.masked_values(padded_indices[0], invalid),
np.ma.masked_values(padded_indices[1], invalid))
masked_indices = np.ma.masked_values(padded_indices, invalid)

return masked_indices

Expand Down

0 comments on commit f5c1c3e

Please sign in to comment.