Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: Allow BOLDWarper to warp between MNI spaces #302

Merged
merged 10 commits into from
Feb 21, 2024
1 change: 1 addition & 0 deletions docs/changes/newsfragments/302.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :class:`.BOLDWarper` to warp BOLD data to other MNI spaces by `Synchon Mandal`_
177 changes: 133 additions & 44 deletions junifer/preprocess/bold_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
Union,
)

import nibabel as nib
from templateflow import api as tflow

from ..api.decorators import register_preprocessor
from ..utils import logger, raise_error
from ..data import get_template, get_xfm
from ..pipeline import WorkDirManager
from ..utils import logger, raise_error, run_ext_cmd
from .ants.ants_apply_transforms_warper import _AntsApplyTransformsWarper
from .base import BasePreprocessor
from .fsl.apply_warper import _ApplyWarper
Expand All @@ -27,7 +32,13 @@ class BOLDWarper(BasePreprocessor):
Parameters
----------
reference : str
The data type to use as reference for warping.
The data type to use as reference for warping, can be either a data
type like "T1w" or a template space like "MNI152NLin2009cAsym".

Raises
------
ValueError
If ``reference`` is invalid.

"""

Expand All @@ -49,9 +60,15 @@ class BOLDWarper(BasePreprocessor):
def __init__(self, reference: str) -> None:
"""Initialize the class."""
self.ref = reference
super().__init__(
on="BOLD", required_data_types=["BOLD", self.ref, "Warp"]
)
# Initialize superclass based on reference
if self.ref == "T1w":
super().__init__(
on="BOLD", required_data_types=["BOLD", self.ref, "Warp"]
)
elif self.ref in tflow.templates():
super().__init__(on="BOLD", required_data_types=["BOLD"])
else:
raise_error(f"Unknown reference: {self.ref}")

def get_valid_inputs(self) -> List[str]:
"""Get valid data types for input.
Expand Down Expand Up @@ -97,7 +114,8 @@ def preprocess(
The BOLD input from the Junifer Data object.
extra_input : dict, optional
The other fields in the Junifer Data object. Must include the
``Warp`` and ``ref`` value's keys.
``Warp`` and ``ref`` value's keys if native space transformation is
needed.

Returns
-------
Expand All @@ -110,46 +128,117 @@ def preprocess(
Raises
------
ValueError
If ``extra_input`` is None.
If ``extra_input`` is None when transforming to native space
i.e., using "T1w" as reference.
RuntimeError
If warp / transformation file extension is not ".mat" or ".h5".
If warp / transformation file extension is not ".mat" or ".h5"
when transforming to native space or
if the BOLD data is in the correct space and does not require
warping.

"""
logger.debug("Warping BOLD using BOLDWarper")
# Check for extra inputs
if extra_input is None:
raise_error(
f"No extra input provided, requires `Warp` and `{self.ref}` "
"data types in particular."
)
# Check for warp file type to use correct tool
warp_file_ext = extra_input["Warp"]["path"].suffix
if warp_file_ext == ".mat":
logger.debug("Using FSL with BOLDWarper")
# Initialize ApplyWarper for computation
apply_warper = _ApplyWarper(reference=self.ref, on="BOLD")
# Replace original BOLD data with warped BOLD data
_, input = apply_warper.preprocess(
input=input,
extra_input=extra_input,
)
elif warp_file_ext == ".h5":
logger.debug("Using ANTs with BOLDWarper")
# Initialize AntsApplyTransformsWarper for computation
ants_apply_transforms_warper = _AntsApplyTransformsWarper(
reference=self.ref, on="BOLD"
)
# Replace original BOLD data with warped BOLD data
_, input = ants_apply_transforms_warper.preprocess(
input=input,
extra_input=extra_input,
)
logger.info(f"Warping BOLD to {self.ref} space using BOLDWarper")
# Transform to native space
if self.ref == "T1w":
# Check for extra inputs
if extra_input is None:
raise_error(
"No extra input provided, requires `Warp` and "
f"`{self.ref}` data types in particular."
)
# Check for warp file type to use correct tool
warp_file_ext = extra_input["Warp"]["path"].suffix
if warp_file_ext == ".mat":
logger.debug("Using FSL with BOLDWarper")
# Initialize ApplyWarper for computation
apply_warper = _ApplyWarper(reference=self.ref, on="BOLD")
# Replace original BOLD data with warped BOLD data
_, input = apply_warper.preprocess(
input=input,
extra_input=extra_input,
)
elif warp_file_ext == ".h5":
logger.debug("Using ANTs with BOLDWarper")
# Initialize AntsApplyTransformsWarper for computation
ants_apply_transforms_warper = _AntsApplyTransformsWarper(
reference=self.ref, on="BOLD"
)
# Replace original BOLD data with warped BOLD data
_, input = ants_apply_transforms_warper.preprocess(
input=input,
extra_input=extra_input,
)
else:
raise_error(
msg=(
"Unknown warp / transformation file extension: "
f"{warp_file_ext}"
),
klass=RuntimeError,
)
# Transform to template space
else:
raise_error(
msg=(
"Unknown warp / transformation file extension: "
f"{warp_file_ext}"
),
klass=RuntimeError,
)
# Check pre-requirements for space manipulation
if self.ref == input["space"]:
raise_error(
(
f"Skipped warping as the BOLD data is in {self.ref} "
"space which would mean that you can remove the "
"BOLDWarper from the preprocess step."
),
klass=RuntimeError,
)
else:
# Get xfm file
xfm_file_path = get_xfm(src=input["space"], dst=self.ref)
# Get template space image
template_space_img = get_template(
space=self.ref,
target_data=input,
extra_input=None,
)

# Create component-scoped tempdir
tempdir = WorkDirManager().get_tempdir(prefix="boldwarper")
# Create element-scoped tempdir so that warped BOLD is
# available later as nibabel stores file path reference for
# loading on computation
element_tempdir = WorkDirManager().get_element_tempdir(
prefix="boldwarper"
)

# Save template
template_space_img_path = tempdir / f"{self.ref}_T1w.nii.gz"
nib.save(template_space_img, template_space_img_path)

# Create a tempfile for warped output
warped_bold_path = (
element_tempdir
/ f"bold_warped_from_{input['space']}_to_{self.ref}.nii.gz"
)

logger.debug(
f"Using ANTs to warp BOLD "
f"from {input['space']} to {self.ref}"
)
# Set antsApplyTransforms command
apply_transforms_cmd = [
"antsApplyTransforms",
"-d 3",
"-e 3",
"-n LanczosWindowedSinc",
f"-i {input['path'].resolve()}",
f"-r {template_space_img_path.resolve()}",
f"-t {xfm_file_path.resolve()}",
f"-o {warped_bold_path.resolve()}",
]
# Call antsApplyTransforms
run_ext_cmd(
name="antsApplyTransforms", cmd=apply_transforms_cmd
)

# Modify target data
input["data"] = nib.load(warped_bold_path)
input["space"] = self.ref

return "BOLD", input
66 changes: 65 additions & 1 deletion junifer/preprocess/tests/test_bold_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, List, Tuple

import pytest
from numpy.testing import assert_array_equal, assert_raises

from junifer.datagrabber import DataladHCP1200, DMCC13Benchmark
from junifer.datareader import DefaultDataReader
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_BOLDWarper_get_output_type(input_: List[str]) -> None:
socket.gethostname() != "juseless",
reason="only for juseless",
)
def test_BOLDWarper_preprocess(
def test_BOLDWarper_preprocess_to_native(
datagrabber: "BaseDataGrabber", element: Tuple[str, ...]
) -> None:
"""Test BOLDWarper preprocess.
Expand All @@ -106,3 +107,66 @@ def test_BOLDWarper_preprocess(
)
assert data_type == "BOLD"
assert isinstance(data, dict)


@pytest.mark.parametrize(
"datagrabber, element, space",
[
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin2009aAsym",
],
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin6Asym",
],
],
)
@pytest.mark.skipif(
_check_ants() is False, reason="requires ANTs to be in PATH"
)
def test_BOLDWarper_preprocess_to_multi_mni(
datagrabber: "BaseDataGrabber", element: Tuple[str, ...], space: str
) -> None:
"""Test BOLDWarper preprocess.

Parameters
----------
datagrabber : DataGrabber-like object
The parametrized DataGrabber objects.
element : tuple of str
The parametrized elements.
space : str
The parametrized template space to transform to.

"""
with datagrabber as dg:
# Read data
element_data = DefaultDataReader().fit_transform(dg[element])
synchon marked this conversation as resolved.
Show resolved Hide resolved
pre_xfm_data = element_data["BOLD"]["data"].get_fdata().copy()
# Preprocess data
data_type, data = BOLDWarper(reference=space).preprocess(
input=element_data["BOLD"],
extra_input=element_data,
)
assert data_type == "BOLD"
assert isinstance(data, dict)
assert data["space"] == space
with assert_raises(AssertionError):
assert_array_equal(pre_xfm_data, data["data"])
Loading