# Source code for torchio.transforms.augmentation.intensity.random_motion

from collections import defaultdict
from collections.abc import Sequence
from typing import Union

import numpy as np
import SimpleITK as sitk
import torch

from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....types import TypeTripletFloat
from ...fourier import FourierTransform
from ...intensity_transform import IntensityTransform
from .. import RandomTransform


class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of
            the simulated movements. The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        translation: Tuple :math:`(a, b)` defining the translation in mm of
            the simulated movements. The translations along each axis are
            :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`t` is provided,
            :math:`t_i \sim \mathcal{U}(-t, t)`.
            Larger values generate more distorted images.
        num_transforms: Number of simulated movements.
            Larger values generate more distorted images.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``num_transforms`` is not a strictly positive integer.

    .. warning:: Large numbers of movements lead to longer execution times for
        3D images.
    """

    def __init__(
        self,
        degrees: Union[float, tuple[float, float]] = 10,
        translation: Union[float, tuple[float, float]] = 10,  # in mm
        num_transforms: int = 2,
        image_interpolation: str = 'linear',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.degrees_range = self.parse_degrees(degrees)
        self.translation_range = self.parse_translation(translation)
        # Check the type before comparing, so that non-numeric input gets the
        # intended ValueError instead of a TypeError raised by ``<``.
        if not isinstance(num_transforms, int) or num_transforms < 1:
            message = (
                'Number of transforms must be a strictly positive natural'
                f' number, not {num_transforms}'
            )
            raise ValueError(message)
        self.num_transforms = num_transforms
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample motion parameters per image and delegate to :class:`Motion`.

        Each image gets its own independently sampled times, rotations and
        translations. Returns ``subject`` unchanged if it has no images.
        """
        images_dict = self.get_images_dict(subject)
        if not images_dict:
            return subject

        # Maps argument name -> {image name: value}, the per-image dictionary
        # form that Motion accepts.
        arguments: dict[str, dict] = defaultdict(dict)
        for name, image in images_dict.items():
            params = self.get_params(
                self.degrees_range,
                self.translation_range,
                self.num_transforms,
                is_2d=image.is_2d(),
            )
            times_params, degrees_params, translation_params = params
            arguments['times'][name] = times_params
            arguments['degrees'][name] = degrees_params
            arguments['translation'][name] = translation_params
            arguments['image_interpolation'][name] = self.image_interpolation
        transform = Motion(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        degrees_range: tuple[float, float],
        translation_range: tuple[float, float],
        num_transforms: int,
        perturbation: float = 0.3,
        is_2d: bool = False,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample motion times, rotations and translations.

        Args:
            degrees_range: ``(low, high)`` bounds for each rotation angle.
            translation_range: ``(low, high)`` bounds for each translation.
            num_transforms: Number of simulated movements.
            perturbation: Maximum jitter of each motion time, as a fraction of
                the spacing between evenly spaced times. If it is 0, the time
                intervals between movements are constant.
            is_2d: If ``True``, restrict motion to the first two spatial axes.

        Returns:
            ``(times, degrees, translations)``. ``times`` has shape
            ``(num_transforms,)``; the other two have shape
            ``(num_transforms, 3)``.
        """
        degrees_params = self.get_params_array(
            degrees_range,
            num_transforms,
        )
        translation_params = self.get_params_array(
            translation_range,
            num_transforms,
        )
        if is_2d:
            # 2D images are assumed to have a singleton last spatial axis,
            # e.g. (W, H, 1) (TODO confirm against Image.is_2d). Motion is kept
            # in-plane: rotate around the third axis only and translate along
            # the first two axes only.
            degrees_params[:, :-1] = 0
            translation_params[:, 2] = 0

        # Evenly spaced times strictly inside (0, 1): step, 2 * step, ...,
        # num_transforms * step. They are built from integer indices because
        # torch.arange(0, 1, step) with a float step can yield one extra
        # element due to rounding, which then no longer matches the shape of
        # ``noise`` below.
        step = 1 / (num_transforms + 1)
        times = torch.arange(
            1,
            num_transforms + 1,
            dtype=torch.get_default_dtype(),
        ) * step
        # Jitter each time by up to +/- perturbation * step. With
        # perturbation < 0.5, the times stay sorted and inside (0, 1).
        noise = torch.FloatTensor(num_transforms)
        noise.uniform_(-step * perturbation, step * perturbation)
        times += noise
        times_params = times.numpy()
        return times_params, degrees_params, translation_params

    @staticmethod
    def get_params_array(nums_range: tuple[float, float], num_transforms: int):
        """Return a ``(num_transforms, 3)`` array sampled from ``U(*nums_range)``."""
        tensor = torch.FloatTensor(num_transforms, 3).uniform_(*nums_range)
        return tensor.numpy()
class Motion(IntensityTransform, FourierTransform):
    r"""Add MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Each channel is resampled once per simulated movement and each resampled
    volume is Fourier-transformed. The output spectrum is then assembled from
    contiguous k-space segments, delimited by ``times``, taken from successive
    volumes. The untransformed volume is used for the segment that contains
    the middle of the split axis.

    Args:
        degrees: Rotations in degrees, one triplet
            :math:`(\theta_1, \theta_2, \theta_3)` per simulated movement.
        translation: Translations in mm, one triplet :math:`(t_1, t_2, t_3)`
            per simulated movement.
        times: Sequence of times from 0 to 1 at which the motions happen,
            one per movement, in increasing order.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Each of the four arguments may instead be a dictionary that maps image
    names to per-image values.
    """

    def __init__(
        self,
        degrees: Union[TypeTripletFloat, dict[str, TypeTripletFloat]],
        translation: Union[TypeTripletFloat, dict[str, TypeTripletFloat]],
        times: Union[Sequence[float], dict[str, Sequence[float]]],
        image_interpolation: Union[str, dict[str, str]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.degrees = degrees
        self.translation = translation
        self.times = times
        self.image_interpolation = image_interpolation
        # Names of the arguments that may be given per image (presumably read
        # by the base class, e.g. by arguments_are_dict; TODO confirm).
        self.args_names = [
            'degrees',
            'translation',
            'times',
            'image_interpolation',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Add the motion artifact to every image (and channel) of ``subject``."""
        degrees = self.degrees
        translation = self.translation
        times = self.times
        image_interpolation = self.image_interpolation
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                assert isinstance(self.degrees, dict)
                assert isinstance(self.translation, dict)
                assert isinstance(self.times, dict)
                assert isinstance(self.image_interpolation, dict)
                degrees = self.degrees[image_name]
                translation = self.translation[image_name]
                times = self.times[image_name]
                image_interpolation = self.image_interpolation[image_name]
            assert isinstance(image_interpolation, str)
            times_array = np.asarray(times)

            result_arrays = []
            for channel in image.data:
                # Each channel is processed as a single-channel 3D SimpleITK
                # image that shares the image affine.
                sitk_image = nib_to_sitk(
                    channel[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                transforms = self.get_rigid_transforms(
                    np.asarray(degrees),
                    np.asarray(translation),
                    sitk_image,
                )
                transformed_channel = self.add_artifact(
                    sitk_image,
                    transforms,
                    times_array,
                    image_interpolation,
                )
                result_arrays.append(transformed_channel)
            result = np.stack(result_arrays)
            image.set_data(torch.as_tensor(result))
        return subject

    def get_rigid_transforms(
        self,
        degrees_params: np.ndarray,
        translation_params: np.ndarray,
        image: sitk.Image,
    ) -> list[sitk.Euler3DTransform]:
        """Build one rigid transform per movement, preceded by the identity.

        Rotations are about the physical point at index ``size / 2`` of
        ``image``. Rows of ``degrees_params`` / ``translation_params`` are
        zipped, so they are expected to have shape ``(num_movements, 3)``.
        """
        center_ijk = np.array(image.GetSize()) / 2
        center_lps = image.TransformContinuousIndexToPhysicalPoint(
            center_ijk.tolist(),
        )
        matrices = [np.eye(4)]  # first transform is the identity
        for degrees, translation in zip(degrees_params, translation_params):
            radians = np.radians(degrees).tolist()
            motion = sitk.Euler3DTransform()
            motion.SetCenter(center_lps)
            motion.SetRotation(*radians)
            motion.SetTranslation(translation.tolist())
            matrices.append(self.transform_to_matrix(motion))
        return [self.matrix_to_transform(m) for m in matrices]

    @staticmethod
    def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
        """Return the 4x4 homogeneous matrix equivalent to ``transform``.

        An Euler3DTransform maps ``x -> R (x - c) + c + t``, so the
        translation column is the offset ``t + c - R c``. Using ``t`` alone
        would silently discard the center of rotation.
        """
        rotation = np.array(transform.GetMatrix()).reshape(3, 3)
        center = np.array(transform.GetCenter())
        translation = np.array(transform.GetTranslation())
        matrix = np.eye(4)
        matrix[:3, :3] = rotation
        matrix[:3, 3] = translation + center - rotation @ center
        return matrix

    @staticmethod
    def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
        """Return an origin-centered Euler3DTransform from a 4x4 matrix."""
        transform = sitk.Euler3DTransform()
        rotation = matrix[:3, :3].flatten().tolist()
        transform.SetMatrix(rotation)
        transform.SetTranslation(matrix[:3, 3].tolist())
        return transform

    def resample_images(
        self,
        image: sitk.Image,
        transforms: Sequence[sitk.Euler3DTransform],
        interpolation: str,
    ) -> list[sitk.Image]:
        """Resample ``image`` with each transform, on its own grid.

        ``transforms[0]`` is assumed to be the identity, so ``image`` itself is
        returned in first position without resampling. Voxels that map outside
        the image are filled with the image minimum.
        """
        default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
        interpolator = self.get_sitk_interpolator(interpolation)
        images = [image]  # first transform is the identity
        for transform in transforms[1:]:
            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(interpolator)
            resampler.SetReferenceImage(image)
            resampler.SetOutputPixelType(sitk.sitkFloat32)
            resampler.SetDefaultPixelValue(default_value)
            resampler.SetTransform(transform)
            images.append(resampler.Execute(image))
        return images

    @staticmethod
    def sort_spectra(spectra: list[torch.Tensor], times: np.ndarray):
        """Use original spectrum to fill the center of k-space.

        Swaps ``spectra[0]`` (the untransformed image) in place into the
        segment that contains the fraction 0.5 of the split axis. That segment
        is the first one whose end time exceeds 0.5, or the last segment if
        no time does.
        """
        num_spectra = len(spectra)
        if np.any(times > 0.5):
            index = np.where(times > 0.5)[0].min()
        else:
            index = num_spectra - 1
        spectra[0], spectra[index] = spectra[index], spectra[0]

    @staticmethod
    def _get_kspace_axis(shape: Sequence[int]) -> int:
        """Return the axis along which k-space is split into time segments.

        This is the last axis, as for 3D volumes, unless that axis is
        singleton (e.g. 2D images stored with a trailing size-1 dimension).
        In that case the last non-singleton axis is used, so that the
        segments still mix.
        """
        for axis in reversed(range(len(shape))):
            if shape[axis] > 1:
                return axis
        return len(shape) - 1

    def add_artifact(
        self,
        image: sitk.Image,
        transforms: Sequence[sitk.Euler3DTransform],
        times: np.ndarray,
        interpolation: str,
    ):
        """Return ``image`` with a simulated motion artifact.

        Segment ``i`` of k-space, between ``times[i - 1]`` and ``times[i]``
        (with implicit bounds 0 and 1), is taken from the spectrum of the
        ``i``-th resampled image, after :meth:`sort_spectra` has reordered
        the spectra.
        """
        images = self.resample_images(image, transforms, interpolation)
        spectra = []
        for resampled in images:
            array = sitk.GetArrayFromImage(resampled).transpose()  # sitk to np
            spectrum = self.fourier_transform(torch.from_numpy(array))
            spectra.append(spectrum)
        self.sort_spectra(spectra, times)

        # Zero-initialized, so that any region not covered by a segment (e.g.
        # if times are unsorted) is deterministic, not uninitialized memory.
        result_spectrum = torch.zeros_like(spectra[0])
        axis = self._get_kspace_axis(result_spectrum.shape)
        last_index = result_spectrum.shape[axis]
        indices_array = (last_index * times).astype(int)
        indices: list[int] = indices_array.tolist()  # type: ignore[assignment]
        indices.append(last_index)
        ini = 0
        for spectrum, fin in zip(spectra, indices):
            region = [slice(None)] * result_spectrum.ndim
            region[axis] = slice(ini, fin)
            result_spectrum[tuple(region)] = spectrum[tuple(region)]
            ini = fin
        result_image = self.inv_fourier_transform(result_spectrum).real.float()
        return result_image