Skip to content

Commit

Permalink
Merge pull request #302 from juaml/update/multi-mni-bold-warper
Browse files Browse the repository at this point in the history
[ENH]: Allow `BOLDWarper` to warp between MNI spaces
  • Loading branch information
synchon authored Feb 21, 2024
2 parents 03d5e91 + 37a68eb commit 9316bca
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/302.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :class:`.BOLDWarper` to warp BOLD data to other MNI spaces by `Synchon Mandal`_
177 changes: 133 additions & 44 deletions junifer/preprocess/bold_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
Union,
)

import nibabel as nib
from templateflow import api as tflow

from ..api.decorators import register_preprocessor
from ..utils import logger, raise_error
from ..data import get_template, get_xfm
from ..pipeline import WorkDirManager
from ..utils import logger, raise_error, run_ext_cmd
from .ants.ants_apply_transforms_warper import _AntsApplyTransformsWarper
from .base import BasePreprocessor
from .fsl.apply_warper import _ApplyWarper
Expand All @@ -27,7 +32,13 @@ class BOLDWarper(BasePreprocessor):
Parameters
----------
reference : str
The data type to use as reference for warping.
The data type to use as reference for warping, can be either a data
type like "T1w" or a template space like "MNI152NLin2009cAsym".
Raises
------
ValueError
If ``reference`` is invalid.
"""

Expand All @@ -49,9 +60,15 @@ class BOLDWarper(BasePreprocessor):
def __init__(self, reference: str) -> None:
"""Initialize the class."""
self.ref = reference
super().__init__(
on="BOLD", required_data_types=["BOLD", self.ref, "Warp"]
)
# Initialize superclass based on reference
if self.ref == "T1w":
super().__init__(
on="BOLD", required_data_types=["BOLD", self.ref, "Warp"]
)
elif self.ref in tflow.templates():
super().__init__(on="BOLD", required_data_types=["BOLD"])
else:
raise_error(f"Unknown reference: {self.ref}")

def get_valid_inputs(self) -> List[str]:
"""Get valid data types for input.
Expand Down Expand Up @@ -97,7 +114,8 @@ def preprocess(
The BOLD input from the Junifer Data object.
extra_input : dict, optional
The other fields in the Junifer Data object. Must include the
``Warp`` and ``ref`` value's keys.
``Warp`` and ``ref`` value's keys if native space transformation is
needed.
Returns
-------
Expand All @@ -110,46 +128,117 @@ def preprocess(
Raises
------
ValueError
If ``extra_input`` is None.
If ``extra_input`` is None when transforming to native space
i.e., using "T1w" as reference.
RuntimeError
If warp / transformation file extension is not ".mat" or ".h5".
If warp / transformation file extension is not ".mat" or ".h5"
when transforming to native space or
if the BOLD data is in the correct space and does not require
warping.
"""
logger.debug("Warping BOLD using BOLDWarper")
# Check for extra inputs
if extra_input is None:
raise_error(
f"No extra input provided, requires `Warp` and `{self.ref}` "
"data types in particular."
)
# Check for warp file type to use correct tool
warp_file_ext = extra_input["Warp"]["path"].suffix
if warp_file_ext == ".mat":
logger.debug("Using FSL with BOLDWarper")
# Initialize ApplyWarper for computation
apply_warper = _ApplyWarper(reference=self.ref, on="BOLD")
# Replace original BOLD data with warped BOLD data
_, input = apply_warper.preprocess(
input=input,
extra_input=extra_input,
)
elif warp_file_ext == ".h5":
logger.debug("Using ANTs with BOLDWarper")
# Initialize AntsApplyTransformsWarper for computation
ants_apply_transforms_warper = _AntsApplyTransformsWarper(
reference=self.ref, on="BOLD"
)
# Replace original BOLD data with warped BOLD data
_, input = ants_apply_transforms_warper.preprocess(
input=input,
extra_input=extra_input,
)
logger.info(f"Warping BOLD to {self.ref} space using BOLDWarper")
# Transform to native space
if self.ref == "T1w":
# Check for extra inputs
if extra_input is None:
raise_error(
"No extra input provided, requires `Warp` and "
f"`{self.ref}` data types in particular."
)
# Check for warp file type to use correct tool
warp_file_ext = extra_input["Warp"]["path"].suffix
if warp_file_ext == ".mat":
logger.debug("Using FSL with BOLDWarper")
# Initialize ApplyWarper for computation
apply_warper = _ApplyWarper(reference=self.ref, on="BOLD")
# Replace original BOLD data with warped BOLD data
_, input = apply_warper.preprocess(
input=input,
extra_input=extra_input,
)
elif warp_file_ext == ".h5":
logger.debug("Using ANTs with BOLDWarper")
# Initialize AntsApplyTransformsWarper for computation
ants_apply_transforms_warper = _AntsApplyTransformsWarper(
reference=self.ref, on="BOLD"
)
# Replace original BOLD data with warped BOLD data
_, input = ants_apply_transforms_warper.preprocess(
input=input,
extra_input=extra_input,
)
else:
raise_error(
msg=(
"Unknown warp / transformation file extension: "
f"{warp_file_ext}"
),
klass=RuntimeError,
)
# Transform to template space
else:
raise_error(
msg=(
"Unknown warp / transformation file extension: "
f"{warp_file_ext}"
),
klass=RuntimeError,
)
# Check pre-requirements for space manipulation
if self.ref == input["space"]:
raise_error(
(
f"Skipped warping as the BOLD data is in {self.ref} "
"space which would mean that you can remove the "
"BOLDWarper from the preprocess step."
),
klass=RuntimeError,
)
else:
# Get xfm file
xfm_file_path = get_xfm(src=input["space"], dst=self.ref)
# Get template space image
template_space_img = get_template(
space=self.ref,
target_data=input,
extra_input=None,
)

# Create component-scoped tempdir
tempdir = WorkDirManager().get_tempdir(prefix="boldwarper")
# Create element-scoped tempdir so that warped BOLD is
# available later as nibabel stores file path reference for
# loading on computation
element_tempdir = WorkDirManager().get_element_tempdir(
prefix="boldwarper"
)

# Save template
template_space_img_path = tempdir / f"{self.ref}_T1w.nii.gz"
nib.save(template_space_img, template_space_img_path)

# Create a tempfile for warped output
warped_bold_path = (
element_tempdir
/ f"bold_warped_from_{input['space']}_to_{self.ref}.nii.gz"
)

logger.debug(
f"Using ANTs to warp BOLD "
f"from {input['space']} to {self.ref}"
)
# Set antsApplyTransforms command
apply_transforms_cmd = [
"antsApplyTransforms",
"-d 3",
"-e 3",
"-n LanczosWindowedSinc",
f"-i {input['path'].resolve()}",
f"-r {template_space_img_path.resolve()}",
f"-t {xfm_file_path.resolve()}",
f"-o {warped_bold_path.resolve()}",
]
# Call antsApplyTransforms
run_ext_cmd(
name="antsApplyTransforms", cmd=apply_transforms_cmd
)

# Modify target data
input["data"] = nib.load(warped_bold_path)
input["space"] = self.ref

return "BOLD", input
70 changes: 69 additions & 1 deletion junifer/preprocess/tests/test_bold_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, List, Tuple

import pytest
from numpy.testing import assert_array_equal, assert_raises

from junifer.datagrabber import DataladHCP1200, DMCC13Benchmark
from junifer.datareader import DefaultDataReader
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_BOLDWarper_get_output_type(input_: List[str]) -> None:
socket.gethostname() != "juseless",
reason="only for juseless",
)
def test_BOLDWarper_preprocess(
def test_BOLDWarper_preprocess_to_native(
datagrabber: "BaseDataGrabber", element: Tuple[str, ...]
) -> None:
"""Test BOLDWarper preprocess.
Expand All @@ -106,3 +107,70 @@ def test_BOLDWarper_preprocess(
)
assert data_type == "BOLD"
assert isinstance(data, dict)


@pytest.mark.parametrize(
"datagrabber, element, space",
[
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin2009aAsym",
],
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin6Asym",
],
],
)
@pytest.mark.skipif(
_check_ants() is False, reason="requires ANTs to be in PATH"
)
@pytest.mark.skipif(
socket.gethostname() != "juseless",
reason="only for juseless",
)
def test_BOLDWarper_preprocess_to_multi_mni(
datagrabber: "BaseDataGrabber", element: Tuple[str, ...], space: str
) -> None:
"""Test BOLDWarper preprocess.
Parameters
----------
datagrabber : DataGrabber-like object
The parametrized DataGrabber objects.
element : tuple of str
The parametrized elements.
space : str
The parametrized template space to transform to.
"""
with datagrabber as dg:
# Read data
element_data = DefaultDataReader().fit_transform(dg[element])
pre_xfm_data = element_data["BOLD"]["data"].get_fdata().copy()
# Preprocess data
data_type, data = BOLDWarper(reference=space).preprocess(
input=element_data["BOLD"],
extra_input=element_data,
)
assert data_type == "BOLD"
assert isinstance(data, dict)
assert data["space"] == space
with assert_raises(AssertionError):
assert_array_equal(pre_xfm_data, data["data"])

0 comments on commit 9316bca

Please sign in to comment.