Skip to content

Commit

Permalink
switch to masked arrays for indices
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Oct 26, 2023
1 parent 6fe682f commit 921cf9d
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 133 deletions.
39 changes: 18 additions & 21 deletions examples/handling_ragged_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,25 @@
# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'),
# np.array([[2, 3, 4], [4 ]], dtype='object'))
#
# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be
# specified.**
# **N.B. Since NumPy v1.19.0, dtype='object' must be specified when forming
# ragged arrays.**
#
# Just as for bivariate connectivity, the length of ``indices[0]`` and
# ``indices[1]`` is equal (i.e. the number of connections), however information
# about the multiple channel indices for each connection is stored in a nested
# array. Importantly, these indices are ragged, as the first connection will be
# computed between 2 seed and 3 target channels, and the second connection
# between 4 seed and 1 target channel. The connectivity functions will
# recognise the indices as being ragged, and pad them accordingly to make them
# easier to work with and compatible with the h5netcdf saving engine. The known
# value used to pad the arrays is ``-1``, an invalid channel index. The above
# indices would be padded to::
# recognise the indices as being ragged, and pad them to a 'full' array by
# adding placeholder values which are masked accordingly. This makes the
# indices easier to work with, and also compatible with the engine used to save
# connectivity objects. For example, the above indices would become::
#
# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, -1], [4, -1, -1, -1]]))
# padded_indices = (np.array([[0, 1, --, --], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, --], [4, --, --, --]]))
#
# These indices are what is stored in the connectivity object, and is also the
# format of indices returned from the helper functions
# :func:`~mne_connectivity.check_multivariate_indices` and
# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also
# possible to pass the padded indices to the connectivity functions directly.
# where ``--`` are masked entries. These indices are what is stored in the
# returned connectivity objects.
#
# For the connectivity results themselves, the methods available in
# MNE-Connectivity combine information across the different channels into a
Expand Down Expand Up @@ -118,11 +115,11 @@
max_n_chans = max(
[len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])])

# show that the padded indices entries are all -1
assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels
assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels
assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels
assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels
# show that the padded entries of the indices are masked
assert np.sum(padded_indices[0][0].mask) == 2 # 2 padded channels
assert np.sum(padded_indices[1][0].mask) == 1 # 1 padded channels
assert np.sum(padded_indices[0][1].mask) == 0 # 0 padded channels
assert np.sum(padded_indices[1][1].mask) == 3 # 3 padded channels

# patterns have shape [seeds/targets x cons x max channels x freqs (x times)]
assert patterns.shape == (2, n_cons, max_n_chans, n_freqs)
Expand All @@ -137,11 +134,11 @@
seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])]
target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])]

# extract patterns for second connection using the padded indices (pad = -1)
# extract patterns for second connection using the padded, masked indices
seed_patterns_con2 = (
patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)])
patterns[0, 1, :padded_indices[0][1].count()])
target_patterns_con2 = (
patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)])
patterns[1, 1, :padded_indices[1][1].count()])

# show that shapes of patterns are correct
assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1)
Expand Down
6 changes: 4 additions & 2 deletions mne_connectivity/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def _xarray_to_conn(array, cls_func):
event_id = dict(zip(event_id_keys, event_id_vals))
array.attrs['event_id'] = event_id

# convert indices numpy arrays to a tuple of arrays
# convert indices numpy arrays to a tuple of masked arrays
# (only multivariate connectivity indices saved as arrays)
if isinstance(array.attrs['indices'], np.ndarray):
array.attrs['indices'] = tuple(array.attrs['indices'])
array.attrs['indices'] = tuple(
np.ma.masked_values(array.attrs['indices'], -1))

# create the connectivity class
conn = cls_func(
Expand Down
53 changes: 23 additions & 30 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ def _check_indices(indices, method, n_signals):
logger.info('using all indices for multivariate connectivity')
indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :],
np.arange(n_signals, dtype=int)[np.newaxis, :])
indices_use = np.ma.masked_array(indices_use,
mask=False, fill_value=-1)
else:
indices_use = check_multivariate_indices(indices) # pad with -1
indices_use = check_multivariate_indices(indices) # mask indices
indices_use = np.ma.concatenate([inds[np.newaxis] for inds in
indices_use])
np.ma.set_fill_value(indices_use, -1) # else 99999 after concat.
if any(this_method in _gc_methods for this_method in method):
for seed, target in zip(indices[0], indices[1]):
intersection = np.intersect1d(seed, target)
if np.any(intersection != -1): # ignore padded entries
for seed, target in zip(indices_use[0], indices_use[1]):
intersection = np.intersect1d(seed.compressed(),
target.compressed())
if intersection.size > 0:
raise ValueError(
'seed and target indices must not intersect when '
'computing Granger causality')
Expand All @@ -59,16 +65,10 @@ def _check_rank_input(rank, data, indices):
else:
data_arr = data

# XXX: Unpadding of arrays after already padding them is perhaps not so
# efficient. However, we need to remove the padded values to
# ensure only the correct channels are indexed, and having two
# versions of indices is a bit messy currently. A candidate for
# refactoring to simplify code.

for group_i in range(2): # seeds and targets
for con_i, con_idcs in enumerate(indices[group_i]):
con_idcs = con_idcs[con_idcs != -1] # -1 is padded value
s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False)
s = np.linalg.svd(data_arr[:, con_idcs.compressed()],
compute_uv=False)
rank[group_i][con_i] = np.min(
[np.count_nonzero(epoch >= epoch[0] * sv_tol)
for epoch in s])
Expand Down Expand Up @@ -197,8 +197,8 @@ def compute_con(self, indices, ranks, n_epochs=1):
indices[0], indices[1], ranks[0], ranks[1]):
self._log_connection_number(con_i)

seed_idcs = seed_idcs[seed_idcs != -1]
target_idcs = target_idcs[target_idcs != -1]
seed_idcs = seed_idcs.compressed()
target_idcs = target_idcs.compressed()
con_idcs = [*seed_idcs, *target_idcs]

C = csd[np.ix_(times, freqs, con_idcs, con_idcs)]
Expand Down Expand Up @@ -432,8 +432,8 @@ def compute_con(self, indices, ranks, n_epochs=1):
indices[0], indices[1], ranks[0], ranks[1]):
self._log_connection_number(con_i)

seed_idcs = seed_idcs[seed_idcs != -1]
target_idcs = target_idcs[target_idcs != -1]
seed_idcs = seed_idcs.compressed()
target_idcs = target_idcs.compressed()
con_idcs = [*seed_idcs, *target_idcs]

C = csd[np.ix_(times, freqs, con_idcs, con_idcs)]
Expand Down Expand Up @@ -1009,7 +1009,8 @@ def spectral_connectivity_epochs_multivariate(

# make sure padded indices are stored in the connectivity object
if indices is not None:
indices = tuple(np.array(indices_use)) # create a copy
# create a copy
indices = (indices_use[0].copy(), indices_use[1].copy())

# get the window function, wavelets, etc for different modes
(spectral_params, mt_adaptive, n_times_spectrum,
Expand All @@ -1020,20 +1021,12 @@ def spectral_connectivity_epochs_multivariate(
cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask)

# unique signals for which we actually need to compute CSD
sig_idx = np.unique(np.concatenate(np.concatenate(
indices_use)))
sig_idx = sig_idx[sig_idx != -1]
sig_idx = np.unique(indices_use.compressed())
remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)}
remapping[-1] = -1
remapped_inds = (indices_use[0].copy(), indices_use[1].copy())
con_i = 0
for seed, target in zip(indices_use[0], indices_use[1]):
remapped_inds[0][con_i] = np.array([
remapping[idx] for idx in seed])
remapped_inds[1][con_i] = np.array([
remapping[idx] for idx in target])
con_i += 1
remapped_sig = [remapping[idx] for idx in sig_idx]
remapped_inds = indices_use.copy()
for idx in sig_idx:
remapped_inds[indices_use == idx] = remapping[idx]
remapped_sig = np.unique(remapped_inds.compressed())
n_signals_use = len(sig_idx)

# map indices to unique indices
Expand Down
6 changes: 5 additions & 1 deletion mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,11 @@ def test_multivar_spectral_connectivity_indices_roundtrip_io(
assert isinstance(con.indices, tuple) and isinstance(
read_con.indices, tuple
)
# check indices are masked
assert all([np.ma.isMA(inds) for inds in con.indices] and
[np.ma.isMA(inds) for inds in read_con.indices])
# check indices have same values
assert np.all(np.array(con.indices) == np.array(read_con.indices))
assert np.all([con_inds == read_inds for con_inds, read_inds in
zip(con.indices, read_con.indices)])
else:
assert con.indices is None and read_con.indices is None
42 changes: 23 additions & 19 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,46 +408,50 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False,
'indices must be specified when computing Granger '
'causality, as all-to-all connectivity is not supported')
logger.info('using all indices for multivariate connectivity')
indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]),
np.array([np.arange(n_signals, dtype=np.int32)]))
indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :],
np.arange(n_signals, dtype=int)[np.newaxis, :])
indices_use = np.ma.masked_array(indices_use,
mask=False, fill_value=-1)
else:
logger.info('only using indices for lower-triangular matrix')
indices_use = np.tril_indices(n_signals, k=-1)
else:
if multivariate_con:
indices_use = check_multivariate_indices(indices) # pad with -1
indices_use = check_multivariate_indices(indices) # mask indices
indices_use = np.ma.concatenate([inds[np.newaxis] for inds in
indices_use])
np.ma.set_fill_value(indices_use, -1) # else 99999 after concat.
if any(this_method in _gc_methods for this_method in method):
for seed, target in zip(indices[0], indices[1]):
intersection = np.intersect1d(seed, target)
if np.any(intersection != -1): # ignore padded entries
for seed, target in zip(indices_use[0], indices_use[1]):
intersection = np.intersect1d(seed.compressed(),
target.compressed())
if intersection.size > 0:
raise ValueError(
'seed and target indices must not intersect when '
'computing Granger causality')
# make sure padded indices are stored in the connectivity object
indices = tuple(np.array(indices_use)) # create a copy
# create a copy
indices = (indices_use[0].copy(), indices_use[1].copy())
else:
indices_use = check_indices(indices)
# create copies of indices_use for independent manipulation
source_idx = np.array(indices_use[0])
target_idx = np.array(indices_use[1])
n_cons = len(source_idx)
n_cons = len(indices_use[0])

# unique signals for which we actually need to compute the CSD
if multivariate_con:
signals_use = np.unique(np.concatenate(np.concatenate(indices_use)))
signals_use = signals_use[signals_use != -1]
signals_use = np.unique(indices_use.compressed())
remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)}
remapping[-1] = -1
remapped_inds = indices_use.copy()
# multivariate functions expect seed/target remapping
con_i = 0
for seed, target in zip(indices_use[0], indices_use[1]):
source_idx[con_i] = np.array([remapping[idx] for idx in seed])
target_idx[con_i] = np.array([remapping[idx] for idx in target])
con_i += 1
for idx in signals_use:
remapped_inds[indices_use == idx] = remapping[idx]
source_idx = remapped_inds[0]
target_idx = remapped_inds[1]
max_n_channels = len(indices_use[0][0])
else:
# no indices remapping required for bivariate functions
signals_use = np.unique(np.r_[indices_use[0], indices_use[1]])
source_idx = indices_use[0].copy()
target_idx = indices_use[1].copy()
max_n_channels = len(indices_use[0])

# check rank input and compute data ranks if necessary
Expand Down
16 changes: 12 additions & 4 deletions mne_connectivity/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,22 @@ def test_seed_target_indices():
seeds = [[0, 1]]
targets = [[2, 3], [3, 4]]
indices = seed_target_multivariate_indices(seeds, targets)
assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]),
np.array([[2, 3], [3, 4]])))
match_indices = (np.array([[0, 1], [0, 1]], dtype=object),
np.array([[2, 3], [3, 4]], dtype=object))
for type_i in range(2):
for con_i in range(len(indices[0])):
assert np.all(indices[type_i][con_i] ==
match_indices[type_i][con_i])
# ragged indices
seeds = [[0, 1]]
targets = [[2, 3, 4], [4]]
indices = seed_target_multivariate_indices(seeds, targets)
assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]),
np.array([[2, 3, 4], [4, -1, -1]])))
match_indices = (np.array([[0, 1], [0, 1]], dtype=object),
np.array([[2, 3, 4], [4]], dtype=object))
for type_i in range(2):
for con_i in range(len(indices[0])):
assert np.all(indices[type_i][con_i] ==
match_indices[type_i][con_i])
# test error catching
# non-array-like seeds/targets
with pytest.raises(TypeError,
Expand Down
Loading

0 comments on commit 921cf9d

Please sign in to comment.