from collections import defaultdict
from collections.abc import Sequence
from typing import Union
import numpy as np
import SimpleITK as sitk
import torch
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....types import TypeTripletFloat
from ...fourier import FourierTransform
from ...intensity_transform import IntensityTransform
from .. import RandomTransform
class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
r"""Add random MRI motion artifact.
Magnetic resonance images suffer from motion artifacts when the subject
moves during image acquisition. This transform follows
`Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
simulate motion artifacts for data augmentation.
Args:
degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of
the simulated movements. The rotation angles around each axis are
:math:`(\theta_1, \theta_2, \theta_3)`,
where :math:`\theta_i \sim \mathcal{U}(a, b)`.
If only one value :math:`d` is provided,
:math:`\theta_i \sim \mathcal{U}(-d, d)`.
Larger values generate more distorted images.
translation: Tuple :math:`(a, b)` defining the translation in mm of
the simulated movements. The translations along each axis are
:math:`(t_1, t_2, t_3)`,
where :math:`t_i \sim \mathcal{U}(a, b)`.
If only one value :math:`t` is provided,
:math:`t_i \sim \mathcal{U}(-t, t)`.
Larger values generate more distorted images.
num_transforms: Number of simulated movements.
Larger values generate more distorted images.
image_interpolation: See :ref:`Interpolation`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
.. warning:: Large numbers of movements lead to longer execution times for
3D images.
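
    Example:
        A minimal usage sketch; ``Colin27`` is simply a sample subject
        bundled with TorchIO, and any :class:`~torchio.Subject` works:

        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> motion = tio.RandomMotion(degrees=10, translation=10, num_transforms=3)
        >>> transformed = motion(subject)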
"""
def __init__(
self,
degrees: Union[float, tuple[float, float]] = 10,
translation: Union[float, tuple[float, float]] = 10, # in mm
num_transforms: int = 2,
image_interpolation: str = 'linear',
**kwargs,
):
super().__init__(**kwargs)
self.degrees_range = self.parse_degrees(degrees)
self.translation_range = self.parse_translation(translation)
        if num_transforms < 1 or not isinstance(num_transforms, int):
            message = (
                'Number of transforms must be a strictly positive natural '
                f'number, not {num_transforms}'
            )
            raise ValueError(message)
self.num_transforms = num_transforms
self.image_interpolation = self.parse_interpolation(
image_interpolation,
)
def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject
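        # Sample one set of motion parameters per image, then delegate the
        # artifact simulation to the deterministic Motion transform below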
arguments: dict[str, dict] = defaultdict(dict)
for name, image in images_dict.items():
params = self.get_params(
self.degrees_range,
self.translation_range,
self.num_transforms,
is_2d=image.is_2d(),
)
times_params, degrees_params, translation_params = params
arguments['times'][name] = times_params
arguments['degrees'][name] = degrees_params
arguments['translation'][name] = translation_params
arguments['image_interpolation'][name] = self.image_interpolation
transform = Motion(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
def get_params(
self,
degrees_range: tuple[float, float],
translation_range: tuple[float, float],
num_transforms: int,
perturbation: float = 0.3,
is_2d: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
# If perturbation is 0, time intervals between movements are constant
degrees_params = self.get_params_array(
degrees_range,
num_transforms,
)
translation_params = self.get_params_array(
translation_range,
num_transforms,
)
if is_2d: # imagine sagittal (1, A, S)
degrees_params[:, :-1] = 0 # rotate around Z axis only
translation_params[:, 2] = 0 # translate in XY plane only
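        # Spread the movements evenly over normalized acquisition time (0, 1)
        # and jitter each time point by up to ``perturbation`` of the interval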
step = 1 / (num_transforms + 1)
times = torch.arange(0, 1, step)[1:]
noise = torch.FloatTensor(num_transforms)
noise.uniform_(-step * perturbation, step * perturbation)
times += noise
times_params = times.numpy()
return times_params, degrees_params, translation_params
@staticmethod
def get_params_array(nums_range: tuple[float, float], num_transforms: int):
tensor = torch.FloatTensor(num_transforms, 3).uniform_(*nums_range)
return tensor.numpy()
class Motion(IntensityTransform, FourierTransform):
r"""Add MRI motion artifact.
Magnetic resonance images suffer from motion artifacts when the subject
moves during image acquisition. This transform follows
`Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
simulate motion artifacts for data augmentation.
Args:
degrees: Sequence of rotations :math:`(\theta_1, \theta_2, \theta_3)`.
translation: Sequence of translations :math:`(t_1, t_2, t_3)` in mm.
times: Sequence of times from 0 to 1 at which the motions happen.
image_interpolation: See :ref:`Interpolation`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
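
    Example:
        A minimal usage sketch for the deterministic version, assuming a
        single simulated movement (each argument holds one entry per
        movement):

        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> motion = tio.Motion(
        ...     degrees=[(0, 0, 10)],
        ...     translation=[(10, 0, 0)],
        ...     times=[0.5],
        ...     image_interpolation='linear',
        ... )
        >>> transformed = motion(subject)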
"""
def __init__(
self,
degrees: Union[TypeTripletFloat, dict[str, TypeTripletFloat]],
translation: Union[TypeTripletFloat, dict[str, TypeTripletFloat]],
times: Union[Sequence[float], dict[str, Sequence[float]]],
        image_interpolation: Union[str, dict[str, str]],
**kwargs,
):
super().__init__(**kwargs)
self.degrees = degrees
self.translation = translation
self.times = times
self.image_interpolation = image_interpolation
self.args_names = [
'degrees',
'translation',
'times',
'image_interpolation',
]
def apply_transform(self, subject: Subject) -> Subject:
degrees = self.degrees
translation = self.translation
times = self.times
image_interpolation = self.image_interpolation
for image_name, image in self.get_images_dict(subject).items():
if self.arguments_are_dict():
assert isinstance(self.degrees, dict)
assert isinstance(self.translation, dict)
assert isinstance(self.times, dict)
assert isinstance(self.image_interpolation, dict)
degrees = self.degrees[image_name]
translation = self.translation[image_name]
times = self.times[image_name]
image_interpolation = self.image_interpolation[image_name]
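            # Convert each channel to SimpleITK, resample it once per
            # simulated movement, and combine the resampled volumes in k-space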
result_arrays = []
for channel in image.data:
sitk_image = nib_to_sitk(
channel[np.newaxis],
image.affine,
force_3d=True,
)
transforms = self.get_rigid_transforms(
np.asarray(degrees),
np.asarray(translation),
sitk_image,
)
assert isinstance(image_interpolation, str)
transformed_channel = self.add_artifact(
sitk_image,
transforms,
np.asarray(times),
image_interpolation,
)
result_arrays.append(transformed_channel)
result = np.stack(result_arrays)
image.set_data(torch.as_tensor(result))
return subject
def get_rigid_transforms(
self,
degrees_params: np.ndarray,
translation_params: np.ndarray,
image: sitk.Image,
) -> list[sitk.Euler3DTransform]:
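        # Build one rigid transform per movement, rotating about the physical
        # (LPS) centre of the image; the identity represents the original pose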
center_ijk = np.array(image.GetSize()) / 2
center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk)
identity = np.eye(4)
matrices = [identity]
for degrees, translation in zip(degrees_params, translation_params):
radians = np.radians(degrees).tolist()
motion = sitk.Euler3DTransform()
motion.SetCenter(center_lps)
motion.SetRotation(*radians)
motion.SetTranslation(translation.tolist())
motion_matrix = self.transform_to_matrix(motion)
matrices.append(motion_matrix)
transforms = [self.matrix_to_transform(m) for m in matrices]
return transforms
@staticmethod
def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
matrix = np.eye(4)
rotation = np.array(transform.GetMatrix()).reshape(3, 3)
matrix[:3, :3] = rotation
matrix[:3, 3] = transform.GetTranslation()
return matrix
@staticmethod
def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
transform = sitk.Euler3DTransform()
rotation = matrix[:3, :3].flatten().tolist()
transform.SetMatrix(rotation)
transform.SetTranslation(matrix[:3, 3])
return transform
def resample_images(
self,
image: sitk.Image,
transforms: Sequence[sitk.Euler3DTransform],
interpolation: str,
) -> list[sitk.Image]:
floating = reference = image
default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
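        # Voxels moved outside the field of view are filled with the image
        # minimum (the default pixel value set below)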
transforms = transforms[1:] # first is identity
images = [image] # first is identity
for transform in transforms:
interpolator = self.get_sitk_interpolator(interpolation)
resampler = sitk.ResampleImageFilter()
resampler.SetInterpolator(interpolator)
resampler.SetReferenceImage(reference)
resampler.SetOutputPixelType(sitk.sitkFloat32)
resampler.SetDefaultPixelValue(default_value)
resampler.SetTransform(transform)
resampled = resampler.Execute(floating)
images.append(resampled)
return images
@staticmethod
def sort_spectra(spectra: list[torch.Tensor], times: np.ndarray):
"""Use original spectrum to fill the center of k-space."""
num_spectra = len(spectra)
if np.any(times > 0.5):
index = np.where(times > 0.5)[0].min()
else:
index = num_spectra - 1
spectra[0], spectra[index] = spectra[index], spectra[0]
def add_artifact(
self,
image: sitk.Image,
transforms: Sequence[sitk.Euler3DTransform],
times: np.ndarray,
interpolation: str,
):
images = self.resample_images(image, transforms, interpolation)
spectra = []
for image in images:
array = sitk.GetArrayFromImage(image).transpose() # sitk to np
spectrum = self.fourier_transform(torch.from_numpy(array))
spectra.append(spectrum)
self.sort_spectra(spectra, times)
result_spectrum = torch.empty_like(spectra[0])
last_index = result_spectrum.shape[2]
indices_array = (last_index * times).astype(int)
indices: list[int] = indices_array.tolist() # type: ignore[assignment]
indices.append(last_index)
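        # Fill consecutive chunks of the last k-space axis, one per spectrum,
        # using the time points as chunk boundaries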
ini = 0
for spectrum, fin in zip(spectra, indices):
result_spectrum[..., ini:fin] = spectrum[..., ini:fin]
ini = fin
result_image = self.inv_fourier_transform(result_spectrum).real.float()
return result_image