Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: Allow BOLDWarper to warp between MNI spaces #302

Merged
merged 10 commits into from
Feb 21, 2024
1 change: 1 addition & 0 deletions docs/changes/newsfragments/302.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :class:`.BOLDWarper` to warp BOLD data to other MNI spaces by `Synchon Mandal`_
179 changes: 135 additions & 44 deletions junifer/preprocess/bold_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
Union,
)

import nibabel as nib

from ..api.decorators import register_preprocessor
from ..utils import logger, raise_error
from ..data import get_template, get_xfm
from ..pipeline import WorkDirManager
from ..utils import logger, raise_error, run_ext_cmd
from .ants.ants_apply_transforms_warper import _AntsApplyTransformsWarper
from .base import BasePreprocessor
from .fsl.apply_warper import _ApplyWarper
Expand All @@ -27,7 +31,13 @@ class BOLDWarper(BasePreprocessor):
Parameters
----------
reference : str
The data type to use as reference for warping.
The data type to use as reference for warping, can be either a data
type like "T1w" or a template space like "MNI152NLin2009cAsym".

Raises
------
ValueError
If ``reference`` is invalid.

"""

Expand All @@ -49,9 +59,25 @@ class BOLDWarper(BasePreprocessor):
def __init__(self, reference: str) -> None:
"""Initialize the class."""
self.ref = reference
super().__init__(
on="BOLD", required_data_types=["BOLD", self.ref, "Warp"]
)
# Initialize superclass based on reference
if self.ref == "T1w":
super().__init__(
on="BOLD", required_data_types=["BOLD", self.ref, "Warp"]
)
elif self.ref in [
"MNI152Lin",
synchon marked this conversation as resolved.
Show resolved Hide resolved
"MNI152NLin6Sym",
"MNI152NLin6Asym",
"MNI152NLin2009aSym",
"MNI152NLin2009aAsym",
"MNI152NLin2009bSym",
"MNI152NLin2009bAsym",
"MNI152NLin2009cAsym",
"MNI152NLin2009cAsym",
]:
super().__init__(on="BOLD", required_data_types=["BOLD"])
else:
raise_error(f"Unknown reference: {self.ref}")

def get_valid_inputs(self) -> List[str]:
"""Get valid data types for input.
Expand Down Expand Up @@ -97,7 +123,8 @@ def preprocess(
The BOLD input from the Junifer Data object.
extra_input : dict, optional
The other fields in the Junifer Data object. Must include the
``Warp`` and ``ref`` value's keys.
``Warp`` and ``ref`` value's keys if native space transformation is
needed.

Returns
-------
Expand All @@ -110,46 +137,110 @@ def preprocess(
Raises
------
ValueError
If ``extra_input`` is None.
If ``extra_input`` is None when transforming to native space
i.e., using "T1w" as reference.
RuntimeError
If warp / transformation file extension is not ".mat" or ".h5".
If warp / transformation file extension is not ".mat" or ".h5"
when transforming to native space.

"""
logger.debug("Warping BOLD using BOLDWarper")
# Check for extra inputs
if extra_input is None:
raise_error(
f"No extra input provided, requires `Warp` and `{self.ref}` "
"data types in particular."
)
# Check for warp file type to use correct tool
warp_file_ext = extra_input["Warp"]["path"].suffix
if warp_file_ext == ".mat":
logger.debug("Using FSL with BOLDWarper")
# Initialize ApplyWarper for computation
apply_warper = _ApplyWarper(reference=self.ref, on="BOLD")
# Replace original BOLD data with warped BOLD data
_, input = apply_warper.preprocess(
input=input,
extra_input=extra_input,
)
elif warp_file_ext == ".h5":
logger.debug("Using ANTs with BOLDWarper")
# Initialize AntsApplyTransformsWarper for computation
ants_apply_transforms_warper = _AntsApplyTransformsWarper(
reference=self.ref, on="BOLD"
)
# Replace original BOLD data with warped BOLD data
_, input = ants_apply_transforms_warper.preprocess(
input=input,
extra_input=extra_input,
)
logger.info(f"Warping BOLD to {self.ref} space using BOLDWarper")
# Transform to native space
if self.ref == "T1w":
# Check for extra inputs
if extra_input is None:
raise_error(
"No extra input provided, requires `Warp` and "
f"`{self.ref}` data types in particular."
)
# Check for warp file type to use correct tool
warp_file_ext = extra_input["Warp"]["path"].suffix
if warp_file_ext == ".mat":
logger.debug("Using FSL with BOLDWarper")
# Initialize ApplyWarper for computation
apply_warper = _ApplyWarper(reference=self.ref, on="BOLD")
# Replace original BOLD data with warped BOLD data
_, input = apply_warper.preprocess(
input=input,
extra_input=extra_input,
)
elif warp_file_ext == ".h5":
logger.debug("Using ANTs with BOLDWarper")
# Initialize AntsApplyTransformsWarper for computation
ants_apply_transforms_warper = _AntsApplyTransformsWarper(
reference=self.ref, on="BOLD"
)
# Replace original BOLD data with warped BOLD data
_, input = ants_apply_transforms_warper.preprocess(
input=input,
extra_input=extra_input,
)
else:
raise_error(
msg=(
"Unknown warp / transformation file extension: "
f"{warp_file_ext}"
),
klass=RuntimeError,
)
# Transform to template space
else:
raise_error(
msg=(
"Unknown warp / transformation file extension: "
f"{warp_file_ext}"
),
klass=RuntimeError,
)
# Check pre-requirements for space manipulation
if self.ref == input["space"]:
logger.info(
synchon marked this conversation as resolved.
Show resolved Hide resolved
f"Skipped warping as the BOLD data is in {self.ref} space"
)
else:
# Get xfm file
xfm_file_path = get_xfm(src=input["space"], dst=self.ref)
# Get template space image
template_space_img = get_template(
space=self.ref,
target_data=input,
extra_input=None,
)

# Create component-scoped tempdir
tempdir = WorkDirManager().get_tempdir(prefix="boldwarper")
# Create element-scoped tempdir so that warped BOLD is
# available later as nibabel stores file path reference for
# loading on computation
element_tempdir = WorkDirManager().get_element_tempdir(
prefix="boldwarper"
)

# Save template
template_space_img_path = tempdir / f"{self.ref}_T1w.nii.gz"
nib.save(template_space_img, template_space_img_path)

# Create a tempfile for warped output
warped_bold_path = (
element_tempdir
/ f"bold_warped_from_{input['space']}_to_{self.ref}.nii.gz"
)

logger.debug(
f"Using ANTs to warp BOLD "
f"from {input['space']} to {self.ref}"
)
# Set antsApplyTransforms command
apply_transforms_cmd = [
"antsApplyTransforms",
"-d 3",
"-e 3",
"-n LanczosWindowedSinc",
f"-i {input['path'].resolve()}",
f"-r {template_space_img_path.resolve()}",
f"-t {xfm_file_path.resolve()}",
f"-o {warped_bold_path.resolve()}",
]
# Call antsApplyTransforms
run_ext_cmd(
name="antsApplyTransforms", cmd=apply_transforms_cmd
)

# Modify target data
input["data"] = nib.load(warped_bold_path)
input["space"] = self.ref

return "BOLD", input
74 changes: 73 additions & 1 deletion junifer/preprocess/tests/test_bold_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_BOLDWarper_get_output_type(input_: List[str]) -> None:
socket.gethostname() != "juseless",
reason="only for juseless",
)
def test_BOLDWarper_preprocess(
def test_BOLDWarper_preprocess_to_native(
datagrabber: "BaseDataGrabber", element: Tuple[str, ...]
) -> None:
"""Test BOLDWarper preprocess.
Expand All @@ -106,3 +106,75 @@ def test_BOLDWarper_preprocess(
)
assert data_type == "BOLD"
assert isinstance(data, dict)


@pytest.mark.parametrize(
"datagrabber, element, space",
[
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin2009aAsym",
],
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin6Asym",
],
[
DMCC13Benchmark(
types=["BOLD"],
sessions=["wave1bas"],
tasks=["Rest"],
phase_encodings=["AP"],
runs=["1"],
native_t1w=False,
),
("f9057kp", "wave1bas", "Rest", "AP", "1"),
"MNI152NLin2009cAsym",
],
],
)
@pytest.mark.skipif(
_check_ants() is False, reason="requires ANTs to be in PATH"
)
def test_BOLDWarper_preprocess_to_multi_mni(
datagrabber: "BaseDataGrabber", element: Tuple[str, ...], space: str
) -> None:
"""Test BOLDWarper preprocess.

Parameters
----------
datagrabber : DataGrabber-like object
The parametrized DataGrabber objects.
element : tuple of str
The parametrized elements.
space : str
The parametrized template space to transform to.

"""
with datagrabber as dg:
# Read data
element_data = DefaultDataReader().fit_transform(dg[element])
synchon marked this conversation as resolved.
Show resolved Hide resolved
# Preprocess data
data_type, data = BOLDWarper(reference=space).preprocess(
input=element_data["BOLD"],
extra_input=element_data,
)
assert data_type == "BOLD"
assert isinstance(data, dict)
assert data["space"] == space
Loading