# Source code for torchio.transforms.preprocessing.spatial.resample

from pathlib import Path
from numbers import Number
from typing import Union, Tuple, Optional

import torch
import numpy as np
import SimpleITK as sitk

from ....data.io import sitk_to_nib
from ....data.subject import Subject
from ....typing import TypeTripletFloat
from ....data.image import Image, ScalarImage
from ... import SpatialTransform


# Voxel spacing given either as one number (same spacing on every axis)
# or as one number per spatial axis.
TypeSpacing = Union[float, Tuple[float, float, float]]

# Parsed form of a resampling target: a (reference, spacing) pair in which
# the reference is an image, an image name, or None, and the spacing is a
# triplet or None.
TypeTarget = Tuple[
    Union[Image, str, None],
    Union[Tuple[float, float, float], None],
]


class Resample(SpatialTransform):
    """Change voxel spacing by resampling.

    Args:
        target: Tuple :math:`(s_h, s_w, s_d)`. If only one value
            :math:`n` is specified, then :math:`s_h = s_w = s_d = n`.
            If a string or :class:`~pathlib.Path` is given,
            all images will be resampled using the image
            with that name as reference or found at the path.
            An instance of :class:`~torchio.Image` can also be passed.
        pre_affine_name: Name of the *image key* (not subject key)
            storing an affine matrix that will be applied to the image header
            before resampling. If ``None``, the image is resampled with an
            identity transform. See usage in the example below.
        image_interpolation: See :ref:`Interpolation`.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            See :class:`~torchio.transforms.RandomAnisotropy`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample(1)                     # resample all images to 1mm iso
        >>> transform = tio.Resample((2, 2, 2))             # resample all images to 2mm iso
        >>> transform = tio.Resample('t1')                  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space
    """  # noqa: E501
    def __init__(
            self,
            target: Union[TypeSpacing, str, Path, Image, None] = 1,
            image_interpolation: str = 'linear',
            pre_affine_name: Optional[str] = None,
            scalars_only: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        # Keep the raw argument; the parsed form is stored separately below
        self.target = target
        self.reference_image, self.target_spacing = self.parse_target(target)
        parsed_interpolation = self.parse_interpolation(image_interpolation)
        self.image_interpolation = parsed_interpolation
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.args_names = (
            'target',
            'image_interpolation',
            'pre_affine_name',
            'scalars_only',
        )

    def parse_target(
            self,
            target: Union[TypeSpacing, str, Path, Image, None],
            ) -> TypeTarget:
        """Split the target into a reference image and a target spacing.

        Returns:
            A ``(reference_image, target_spacing)`` pair:

            * existing file path: ``(ScalarImage(path), None)``
            * other string or path: ``(str(target), None)``; the string is
              later looked up as an image name in the subject
            * :class:`Image` instance: ``(target, target.spacing)``
            * anything else: ``(None, parse_spacing(target))``

        Raises:
            ValueError: If the target is not a string, path or image and is
                not a valid spacing either (see :meth:`parse_spacing`).
        """
        if isinstance(target, (str, Path)):
            if Path(target).is_file():
                reference_image = ScalarImage(target)
            else:
                # Normalise to str: apply_transform only recognises str
                # references as image names, so a Path left as-is would
                # match none of its branches
                reference_image = str(target)
            target_spacing = None
        elif isinstance(target, Image):
            reference_image = target
            target_spacing = reference_image.spacing
        else:
            reference_image = None
            target_spacing = self.parse_spacing(target)
        return reference_image, target_spacing

    @staticmethod
    def parse_spacing(spacing: TypeSpacing) -> Tuple[float, float, float]:
        """Convert a spacing argument into a triplet of positive numbers.

        Args:
            spacing: A number, repeated for the three axes, or a tuple of
                length 3, used as is.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a 3-tuple, or
                if any value is not strictly positive.
        """
        if isinstance(spacing, tuple) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a tuple of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(result) <= 0):
            raise ValueError(f'Spacing must be positive, not "{spacing}"')
        return result

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        """Validate the pre-affine matrix stored in ``image``, if present.

        Nothing is checked beyond the name type when ``affine_name`` is not a
        key of ``image``.

        Raises:
            TypeError: If ``affine_name`` is not a string, or if the stored
                matrix is neither a NumPy array nor a PyTorch tensor.
            ValueError: If the stored matrix does not have shape (4, 4).
        """
        if not isinstance(affine_name, str):
            message = (
                'Affine name argument must be a string,'
                f' not {type(affine_name)}'
            )
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = (
                    'The affine matrix shape must be (4, 4),'
                    f' not {matrix.shape}'
                )
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        """Ensure at least one image in ``subject`` has the ``affine_name`` key.

        Raises:
            ValueError: If no image in the subject contains the key.
        """
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def _is_reference_image(self, image: Image, subject: Subject) -> bool:
        """Tell whether ``image`` is the image used as resampling reference."""
        reference = self.reference_image
        if image is reference:
            return True
        if isinstance(reference, str):
            # The reference is an image name: compare against that entry.
            # A missing name is reported later, when the reference is needed
            try:
                return subject[reference] is image
            except KeyError:
                return False
        return False

    def _get_reference_sitk(
            self,
            subject: Subject,
            floating_sitk: sitk.Image,
            ) -> sitk.Image:
        """Build the SimpleITK image defining the output grid.

        Raises:
            ValueError: If the reference is an image name that is not in the
                subject.
            TypeError: If the stored reference is of an unsupported type.
        """
        reference = self.reference_image
        if isinstance(reference, str):
            try:
                reference_image = subject[reference]
            except KeyError as error:
                message = (
                    f'Image name "{self.reference_image}"'
                    f' not found in subject. If "{self.reference_image}"'
                    ' is a path, it does not exist or permission has been'
                    ' denied'
                )
                raise ValueError(message) from error
            # force_3d matches the conversion of the floating image, so both
            # conversions are requested with the same dimensionality option
            return reference_image.as_sitk(force_3d=True)
        if isinstance(reference, Image):
            return reference.as_sitk(force_3d=True)
        if reference is None:  # target is a spacing
            return self.get_reference_image(
                floating_sitk,
                self.target_spacing,
            )
        message = (
            'The reference must be an image, an image name or None,'
            f' not {type(reference)}'
        )
        raise TypeError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample the images in ``subject`` onto the target grid.

        The reference image is left untouched. Images that are not
        :class:`ScalarImage` instances are skipped if ``scalars_only`` is
        set, and resampled with nearest neighbor interpolation otherwise.
        Image data and affine are modified in place.

        Returns:
            The same ``subject`` instance.

        Raises:
            ValueError: If a pre-affine name was given but no image has it,
                or if the reference image name is not found in the subject.
        """
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for image in self.get_images(subject):
            # Do not resample the reference image if there is one
            if self._is_reference_image(image, subject):
                continue

            # Choose interpolation
            if isinstance(image, ScalarImage):
                interpolation = self.image_interpolation
            else:
                if self.scalars_only:
                    continue
                interpolation = 'nearest'
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix if found in image
            if use_pre_affine and self.pre_affine_name in image:
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    matrix = matrix.numpy()
                image.affine = matrix @ image.affine

            floating_itk = image.as_sitk(force_3d=True)
            reference_image_sitk = self._get_reference_sitk(
                subject,
                floating_itk,
            )

            # Internal sanity check: both images were requested as 3D
            num_dims_ref = reference_image_sitk.GetDimension()
            num_dims_flo = floating_itk.GetDimension()
            assert num_dims_ref == num_dims_flo

            # Resample
            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            resampler.SetReferenceImage(reference_image_sitk)
            resampled = resampler.Execute(floating_itk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    @staticmethod
    def get_reference_image(
            image: sitk.Image,
            spacing: TypeTripletFloat,
            ) -> sitk.Image:
        """Create an empty image on the grid of ``image`` at a new spacing.

        The new size is ``ceil(old_size * old_spacing / new_spacing)``, with
        singleton dimensions kept. Direction, pixel type and number of
        components are copied from ``image``.
        """
        old_spacing = np.array(image.GetSpacing())
        new_spacing = np.array(spacing)
        old_size = np.array(image.GetSize())
        new_size = old_size * old_spacing / new_spacing
        # A wide integer type avoids the silent wraparound of a 16-bit size
        # when an axis ends up with more than 65535 voxels
        new_size = np.ceil(new_size).astype(np.int64)
        new_size[old_size == 1] = 1  # keep singleton dimensions
        # Origin of the new grid, as a continuous index of the old grid
        new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
        new_origin_lps = image.TransformContinuousIndexToPhysicalPoint(
            new_origin_index)
        reference = sitk.Image(
            new_size.tolist(),
            image.GetPixelID(),
            image.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(image.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def get_sigma(downsampling_factor, spacing):
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., "Scale factor point spread function matching:
        beyond aliasing in image resampling", MICCAI 2015

        Computes ``spacing * sqrt((k ** 2 - 1) / (2 * sqrt(2 * ln 2)) ** 2)``,
        where ``k`` is ``downsampling_factor``.
        """
        k = downsampling_factor
        variance = (k ** 2 - 1 ** 2) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = spacing * np.sqrt(variance)
        return sigma