Skip to content

Commit

Permalink
Remove channels_last kwarg from Image (#685)
Browse files Browse the repository at this point in the history
* Remove channels_last kwarg from Image

Fixes #684.

SimpleITK is not very good for fMRI and dMRI images.
After this commit, if 4 spatial dimensions are detected, we assume that
the read image is of this nature and we move convert the temporal
dimension into the channels dimension.

The `channels_last` kwarg should no longer be necessary.

It is important to take into account that images saved by TorchIO
always put the channels in the 5th NIfTI dimension, not the 4th.

* Show deprecation message
  • Loading branch information
fepegar authored Oct 10, 2021
1 parent 9b09c86 commit 141aa81
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
3 changes: 2 additions & 1 deletion tests/data/test_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_no_sample(self):
input_dict = {'image': tio.ScalarImage(f.name)}
subject = tio.Subject(input_dict)
with self.assertRaises(RuntimeError):
tio.RandomFlip()(subject)
with self.assertWarns(UserWarning):
tio.RandomFlip()(subject)

def test_history(self):
transformed = tio.RandomGamma()(self.sample_subject)
Expand Down
15 changes: 7 additions & 8 deletions torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ class Image(dict):
check_nans: If ``True``, issues a warning if NaNs are found
in the image. If ``False``, images will not be checked for the
presence of NaNs.
channels_last: If ``True``, the read tensor will be permuted so the
last dimension becomes the first. This is useful, e.g., when
NIfTI images have been saved with the channels dimension being the
fourth instead of the fifth.
reader: Callable object that takes a path and returns a 4D tensor and a
2D, :math:`4 \times 4` affine matrix. This can be used if your data
is saved in a custom format, such as ``.npy`` (see example below).
Expand Down Expand Up @@ -121,12 +117,10 @@ def __init__(
tensor: Optional[TypeData] = None,
affine: Optional[TypeData] = None,
check_nans: bool = False, # removed by ITK by default
channels_last: bool = False,
reader: Callable = read_image,
**kwargs: Dict[str, Any],
):
self.check_nans = check_nans
self.channels_last = channels_last
self.reader = reader

if type is None:
Expand All @@ -151,6 +145,13 @@ def __init__(
if key in kwargs:
message = f'Key "{key}" is reserved. Use a different one'
raise ValueError(message)
if 'channels_last' in kwargs:
message = (
'The "channels_last" keyword argument is deprecated after'
' https://github.com/fepegar/torchio/pull/685 and will be'
' removed in the future'
)
warnings.warn(message, DeprecationWarning)

super().__init__(**kwargs)
self.path = self._parse_path(path)
Expand Down Expand Up @@ -505,8 +506,6 @@ def read_and_check(self, path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
tensor = self._parse_tensor_shape(tensor)
tensor = self._parse_tensor(tensor)
affine = self._parse_affine(affine)
if self.channels_last:
tensor = tensor.permute(3, 0, 1, 2)
if self.check_nans and torch.isnan(tensor).any():
warnings.warn(f'NaNs found in file "{path}"', RuntimeWarning)
return tensor, affine
Expand Down
34 changes: 26 additions & 8 deletions torchio/data/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,21 @@
def read_image(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
try:
result = _read_sitk(path)
except RuntimeError: # try with NiBabel
except RuntimeError as e: # try with NiBabel
message = (
f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...'
)
warnings.warn(message)
try:
result = _read_nibabel(path)
except nib.loadsave.ImageFileError:
except nib.loadsave.ImageFileError as e:
message = (
f'File "{path}" not understood.'
' Check supported formats by at'
' https://simpleitk.readthedocs.io/en/master/IO.html#images'
' and https://nipy.org/nibabel/api.html#file-formats'
)
raise RuntimeError(message)
raise RuntimeError(message) from e
return result


Expand Down Expand Up @@ -83,9 +87,13 @@ def read_shape(path: TypePath) -> Tuple[int, int, int, int]:
reader.ReadImageInformation()
num_channels = reader.GetNumberOfComponents()
spatial_shape = reader.GetSize()
if reader.GetDimension() == 2:
num_dimensions = reader.GetDimension()
if num_dimensions == 2:
spatial_shape = *spatial_shape, 1
return (num_channels,) + spatial_shape
elif num_dimensions == 4: # assume bad NIfTI
*spatial_shape, num_channels = spatial_shape
shape = (num_channels,) + tuple(spatial_shape)
return shape


def read_affine(path: TypePath) -> np.ndarray:
Expand Down Expand Up @@ -314,10 +322,15 @@ def sitk_to_nib(
input_spatial_dims = image.GetDimension()
if input_spatial_dims == 2:
data = data[..., np.newaxis]
elif input_spatial_dims == 4: # probably a bad NIfTI (1, sx, sy, sz, c)
# Try to fix it
num_components = data.shape[-1]
data = data[0]
data = data.transpose(3, 0, 1, 2)
input_spatial_dims = 3
if not keepdim:
data = ensure_4d(data, num_spatial_dims=input_spatial_dims)
assert data.shape[0] == num_components
assert data.shape[1: 1 + input_spatial_dims] == image.GetSize()
affine = get_ras_affine_from_sitk(image)
return data, affine

Expand All @@ -328,14 +341,19 @@ def get_ras_affine_from_sitk(
spacing = np.array(sitk_object.GetSpacing())
direction_lps = np.array(sitk_object.GetDirection())
origin_lps = np.array(sitk_object.GetOrigin())
if len(direction_lps) == 9:
direction_length = len(direction_lps)
if direction_length == 9:
rotation_lps = direction_lps.reshape(3, 3)
elif len(direction_lps) == 4: # ignore last dimension if 2D (1, W, H, 1)
elif direction_length == 4: # ignore last dimension if 2D (1, W, H, 1)
rotation_lps_2d = direction_lps.reshape(2, 2)
rotation_lps = np.eye(3)
rotation_lps[:2, :2] = rotation_lps_2d
spacing = np.append(spacing, 1)
origin_lps = np.append(origin_lps, 0)
elif direction_length == 16: # probably a bad NIfTI. Let's try to fix it
rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
spacing = spacing[:-1]
origin_lps = origin_lps[:-1]
rotation_ras = np.dot(FLIPXY_33, rotation_lps)
rotation_ras_zoom = rotation_ras * spacing
translation_ras = np.dot(FLIPXY_33, origin_lps)
Expand Down

0 comments on commit 141aa81

Please sign in to comment.