Remove channels_last kwarg from Image #685

Merged (2 commits) on Oct 10, 2021
Changes from 1 commit
Remove channels_last kwarg from Image
Fixes #684.

SimpleITK does not handle fMRI and dMRI images well. After this commit,
if 4 spatial dimensions are detected, we assume that the read image is
of this nature and we move the temporal dimension into the channels
dimension.

The `channels_last` kwarg should no longer be necessary.

Note that images saved by TorchIO always store the channels in the
5th NIfTI dimension, not the 4th.
fepegar committed Oct 10, 2021
commit 3548cb3e0186f0ea05cae2ead17111e43e6e4abf
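To illustrate the behavior described in the commit message, here is a minimal, hypothetical usage sketch (the file paths are placeholders and the printed shape depends on the input file):

```python
import torchio as tio

# Hypothetical 4D fMRI/dMRI NIfTI file whose 4th dimension is time.
image = tio.ScalarImage('fmri.nii.gz')

# With this change, the temporal dimension is read into the channels
# dimension, so passing channels_last=True is no longer needed.
print(image.shape)  # expected: (num_timepoints, sx, sy, sz)

# Images saved by TorchIO store the channels in the 5th NIfTI dimension.
image.save('fmri_channels_5th_dim.nii.gz')
```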
3 changes: 2 additions & 1 deletion tests/data/test_subject.py
@@ -24,7 +24,8 @@ def test_no_sample(self):
             input_dict = {'image': tio.ScalarImage(f.name)}
             subject = tio.Subject(input_dict)
             with self.assertRaises(RuntimeError):
-                tio.RandomFlip()(subject)
+                with self.assertWarns(UserWarning):
+                    tio.RandomFlip()(subject)

     def test_history(self):
         transformed = tio.RandomGamma()(self.sample_subject)
8 changes: 0 additions & 8 deletions torchio/data/image.py
@@ -77,10 +77,6 @@ class Image(dict):
         check_nans: If ``True``, issues a warning if NaNs are found
             in the image. If ``False``, images will not be checked for the
             presence of NaNs.
-        channels_last: If ``True``, the read tensor will be permuted so the
-            last dimension becomes the first. This is useful, e.g., when
-            NIfTI images have been saved with the channels dimension being the
-            fourth instead of the fifth.
         reader: Callable object that takes a path and returns a 4D tensor and a
             2D, :math:`4 \times 4` affine matrix. This can be used if your data
             is saved in a custom format, such as ``.npy`` (see example below).
@@ -121,12 +117,10 @@ def __init__(
             tensor: Optional[TypeData] = None,
             affine: Optional[TypeData] = None,
             check_nans: bool = False,  # removed by ITK by default
-            channels_last: bool = False,
             reader: Callable = read_image,
             **kwargs: Dict[str, Any],
             ):
         self.check_nans = check_nans
-        self.channels_last = channels_last
         self.reader = reader

         if type is None:
@@ -505,8 +499,6 @@ def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
         tensor = self._parse_tensor_shape(tensor)
         tensor = self._parse_tensor(tensor)
         affine = self._parse_affine(affine)
-        if self.channels_last:
-            tensor = tensor.permute(3, 0, 1, 2)
         if self.check_nans and torch.isnan(tensor).any():
             warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
         return tensor, affine
34 changes: 26 additions & 8 deletions torchio/data/io.py
@@ -23,17 +23,21 @@
 def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
     try:
         result = _read_sitk(path)
-    except RuntimeError:  # try with NiBabel
+    except RuntimeError as e:  # try with NiBabel
+        message = (
+            f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...'
+        )
+        warnings.warn(message)
         try:
             result = _read_nibabel(path)
-        except nib.loadsave.ImageFileError:
+        except nib.loadsave.ImageFileError as e:
             message = (
                 f'File "{path}" not understood.'
                 ' Check supported formats by at'
                 ' https://simpleitk.readthedocs.io/en/master/IO.html#images'
                 ' and https://nipy.org/nibabel/api.html#file-formats'
             )
-            raise RuntimeError(message)
+            raise RuntimeError(message) from e
     return result


@@ -83,9 +87,13 @@ def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
     reader.ReadImageInformation()
     num_channels = reader.GetNumberOfComponents()
     spatial_shape = reader.GetSize()
-    if reader.GetDimension() == 2:
+    num_dimensions = reader.GetDimension()
+    if num_dimensions == 2:
         spatial_shape = *spatial_shape, 1
-    return (num_channels,) + spatial_shape
+    elif num_dimensions == 4:  # assume bad NIfTI
+        *spatial_shape, num_channels = spatial_shape
+    shape = (num_channels,) + tuple(spatial_shape)
+    return shape


def read_affine(path: TypePath) -> np.ndarray:
@@ -314,10 +322,15 @@ def sitk_to_nib(
     input_spatial_dims = image.GetDimension()
     if input_spatial_dims == 2:
         data = data[..., np.newaxis]
+    elif input_spatial_dims == 4:  # probably a bad NIfTI (1, sx, sy, sz, c)
+        # Try to fix it
+        num_components = data.shape[-1]
+        data = data[0]
+        data = data.transpose(3, 0, 1, 2)
+        input_spatial_dims = 3
     if not keepdim:
         data = ensure_4d(data, num_spatial_dims=input_spatial_dims)
     assert data.shape[0] == num_components
     assert data.shape[1: 1 + input_spatial_dims] == image.GetSize()
     affine = get_ras_affine_from_sitk(image)
     return data, affine
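For readers unfamiliar with the shape juggling in the new elif branch above, this standalone NumPy sketch (with made-up sizes) reproduces the same transposition:

```python
import numpy as np

# A "bad" 4D NIfTI read through SimpleITK ends up as (1, sx, sy, sz, t):
# a singleton channel axis, three spatial axes, and time.
data = np.zeros((1, 4, 5, 6, 10))

num_components = data.shape[-1]    # treat the 10 timepoints as channels
data = data[0]                     # drop the singleton axis -> (4, 5, 6, 10)
data = data.transpose(3, 0, 1, 2)  # move time to the front -> (10, 4, 5, 6)

assert data.shape == (num_components, 4, 5, 6)
```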

@@ -328,14 +341,19 @@ def get_ras_affine_from_sitk(
     spacing = np.array(sitk_object.GetSpacing())
     direction_lps = np.array(sitk_object.GetDirection())
     origin_lps = np.array(sitk_object.GetOrigin())
-    if len(direction_lps) == 9:
+    direction_length = len(direction_lps)
+    if direction_length == 9:
         rotation_lps = direction_lps.reshape(3, 3)
-    elif len(direction_lps) == 4:  # ignore last dimension if 2D (1, W, H, 1)
+    elif direction_length == 4:  # ignore last dimension if 2D (1, W, H, 1)
         rotation_lps_2d = direction_lps.reshape(2, 2)
         rotation_lps = np.eye(3)
         rotation_lps[:2, :2] = rotation_lps_2d
         spacing = np.append(spacing, 1)
         origin_lps = np.append(origin_lps, 0)
+    elif direction_length == 16:  # probably a bad NIfTI. Let's try to fix it
+        rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
+        spacing = spacing[:-1]
+        origin_lps = origin_lps[:-1]
     rotation_ras = np.dot(FLIPXY_33, rotation_lps)
     rotation_ras_zoom = rotation_ras * spacing
     translation_ras = np.dot(FLIPXY_33, origin_lps)
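As a rough sketch of what the new direction_length == 16 branch does, assuming FLIPXY_33 is the diag(-1, -1, 1) LPS-to-RAS flip used elsewhere in io.py: the spatial 3x3 block is kept from the flat 4x4 direction matrix, and the trailing (temporal) spacing and origin entries are dropped before the usual conversion.

```python
import numpy as np

FLIPXY_33 = np.diag((-1, -1, 1))  # assumed LPS -> RAS flip, as the name suggests

# Metadata as reported by SimpleITK for a 4D (fMRI-like) image.
direction_lps = np.eye(4).flatten()            # flat 16-element direction
spacing = np.array([2.0, 2.0, 2.5, 1.0])       # last entry belongs to time
origin_lps = np.array([10.0, -5.0, 3.0, 0.0])  # last entry belongs to time

# The new branch keeps only the spatial part.
rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
spacing = spacing[:-1]
origin_lps = origin_lps[:-1]

# Same LPS -> RAS conversion as the lines that follow in the function.
rotation_ras = np.dot(FLIPXY_33, rotation_lps)
rotation_ras_zoom = rotation_ras * spacing
translation_ras = np.dot(FLIPXY_33, origin_lps)
print(rotation_ras_zoom, translation_ras)
```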