from collections import defaultdict
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import SimpleITK as sitk
import torch
from .. import RandomTransform
from ... import FourierTransform
from ... import IntensityTransform
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....typing import TypeTripletFloat
class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of
            the simulated movements. The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        translation: Tuple :math:`(a, b)` defining the translation in mm of
            the simulated movements. The translations along each axis are
            :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`t` is provided,
            :math:`t_i \sim \mathcal{U}(-t, t)`.
            Larger values generate more distorted images.
        num_transforms: Number of simulated movements.
            Larger values generate more distorted images.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``num_transforms`` is not a strictly positive integer.

    .. warning:: Large numbers of movements lead to longer execution times for
        3D images.
    """

    def __init__(
        self,
        degrees: Union[float, Tuple[float, float]] = 10,
        translation: Union[float, Tuple[float, float]] = 10,  # in mm
        num_transforms: int = 2,
        image_interpolation: str = 'linear',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.degrees_range = self.parse_degrees(degrees)
        self.translation_range = self.parse_translation(translation)
        # Check the type before comparing, so that non-numeric input reaches
        # the ValueError below instead of raising a TypeError from ``<``.
        if not isinstance(num_transforms, int) or num_transforms < 1:
            message = (
                'Number of transforms must be a strictly positive natural'
                f' number, not {num_transforms}'
            )
            raise ValueError(message)
        self.num_transforms = num_transforms
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample motion parameters per image and delegate to :class:`Motion`.

        Each image gets its own independently sampled times, rotations and
        translations. The sampled values are collected into per-image dicts
        and passed to a deterministic :class:`Motion` transform.
        """
        arguments: Dict[str, dict] = defaultdict(dict)
        for name, image in self.get_images_dict(subject).items():
            times, degrees, translation = self.get_params(
                self.degrees_range,
                self.translation_range,
                self.num_transforms,
                is_2d=image.is_2d(),
            )
            arguments['times'][name] = times
            arguments['degrees'][name] = degrees
            arguments['translation'][name] = translation
            arguments['image_interpolation'][name] = self.image_interpolation
        transform = Motion(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        degrees_range: Tuple[float, float],
        translation_range: Tuple[float, float],
        num_transforms: int,
        perturbation: float = 0.3,
        is_2d: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample times, rotations and translations for the movements.

        Args:
            degrees_range: ``(low, high)`` bounds for every rotation angle.
            translation_range: ``(low, high)`` bounds for every translation.
            num_transforms: Number of movements to sample.
            perturbation: Half-width of the uniform jitter added to each
                movement time, as a fraction of the nominal interval
                ``1 / (num_transforms + 1)``. With 0 the movements are evenly
                spaced. Values below 0.5 keep the times strictly increasing.
                Values below 1 keep them inside ``(0, 1)``.
            is_2d: If ``True``, only rotation about the third axis and
                translation along the first two axes are kept.

        Returns:
            Tuple ``(times, degrees, translations)``. ``times`` has shape
            ``(num_transforms,)``; the other two arrays have shape
            ``(num_transforms, 3)``.
        """
        # Draw order (degrees, translations, then time noise) is kept stable
        # so that results are reproducible under a fixed torch seed.
        degrees_params = self.get_params_array(
            degrees_range,
            num_transforms,
        )
        translation_params = self.get_params_array(
            translation_range,
            num_transforms,
        )
        if is_2d:
            # In-plane motion only: zero the rotations about the first two
            # axes and the translation along the third one.
            # NOTE(review): assumes 2D images have a singleton third spatial
            # axis; confirm against Image.is_2d.
            degrees_params[:, :-1] = 0
            translation_params[:, 2] = 0
        step = 1 / (num_transforms + 1)
        # Build the nominal times from an integer range: an arange with a
        # floating-point step may produce an off-by-one number of elements,
        # which would make ``times`` and ``noise`` shapes disagree.
        times = torch.arange(1, num_transforms + 1, dtype=torch.float32) * step
        noise = torch.FloatTensor(num_transforms)
        noise.uniform_(-step * perturbation, step * perturbation)
        times += noise
        times_params = times.numpy()
        return times_params, degrees_params, translation_params

    @staticmethod
    def get_params_array(nums_range: Tuple[float, float], num_transforms: int):
        """Return a ``(num_transforms, 3)`` array drawn from ``U(*nums_range)``."""
        tensor = torch.FloatTensor(num_transforms, 3).uniform_(*nums_range)
        return tensor.numpy()
class Motion(IntensityTransform, FourierTransform):
    r"""Add MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Each argument may also be a dict mapping image names to per-image values.

    Args:
        degrees: Rotations :math:`(\theta_1, \theta_2, \theta_3)` in degrees,
            one triplet per movement.
        translation: Translations :math:`(t_1, t_2, t_3)` in mm, one triplet
            per movement.
        times: Times from 0 to 1 at which the motions happen, one per
            movement, expected in increasing order.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        degrees: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]],
        translation: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]],
        times: Union[Sequence[float], Dict[str, Sequence[float]]],
        image_interpolation: Union[str, Dict[str, str]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.degrees = degrees
        self.translation = translation
        self.times = times
        self.image_interpolation = image_interpolation
        self.args_names = [
            'degrees',
            'translation',
            'times',
            'image_interpolation',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Add the motion artifact to every channel of every selected image."""
        use_dicts = self.arguments_are_dict()
        for image_name, image in self.get_images_dict(subject).items():
            if use_dicts:
                assert isinstance(self.degrees, dict)
                assert isinstance(self.translation, dict)
                assert isinstance(self.times, dict)
                assert isinstance(self.image_interpolation, dict)
                degrees = self.degrees[image_name]
                translation = self.translation[image_name]
                times = self.times[image_name]
                image_interpolation = self.image_interpolation[image_name]
            else:
                degrees = self.degrees
                translation = self.translation
                times = self.times
                image_interpolation = self.image_interpolation
            assert isinstance(image_interpolation, str)
            degrees_array = np.asarray(degrees)
            translation_array = np.asarray(translation)
            times_array = np.asarray(times)

            # All channels share the affine and spatial shape, so the rigid
            # transforms (which depend only on the grid geometry) are built
            # once, from the first channel.
            transforms = None
            channels = []
            for channel in image.data:
                sitk_image = nib_to_sitk(
                    channel[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                if transforms is None:
                    transforms = self.get_rigid_transforms(
                        degrees_array,
                        translation_array,
                        sitk_image,
                    )
                channels.append(
                    self.add_artifact(
                        sitk_image,
                        transforms,
                        times_array,
                        image_interpolation,
                    )
                )
            image.set_data(torch.stack(channels))
        return subject

    def get_rigid_transforms(
        self,
        degrees_params: np.ndarray,
        translation_params: np.ndarray,
        image: sitk.Image,
    ) -> List[sitk.Euler3DTransform]:
        """Return the identity followed by one rigid transform per movement.

        Rotations are applied about the physical point at the center of the
        image grid.
        """
        # ITK continuous indices refer to voxel centers, so the geometric
        # center of the grid is at (size - 1) / 2.
        center_ijk = (np.array(image.GetSize()) - 1) / 2
        center_lps = image.TransformContinuousIndexToPhysicalPoint(
            center_ijk.tolist(),
        )
        matrices = [np.eye(4)]  # the subject's original position
        for degrees, translation in zip(degrees_params, translation_params):
            motion = sitk.Euler3DTransform()
            motion.SetCenter(center_lps)
            motion.SetRotation(*np.radians(degrees).tolist())
            motion.SetTranslation(translation.tolist())
            matrices.append(self.transform_to_matrix(motion))
        return [self.matrix_to_transform(m) for m in matrices]

    @staticmethod
    def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
        """Return the 4x4 homogeneous matrix equivalent to ``transform``.

        ITK maps ``x`` to ``R @ (x - c) + c + t``, where ``c`` is the center
        and ``t`` the translation parameter. The homogeneous offset is
        therefore ``t + c - R @ c``. Copying only ``t`` would silently move
        the rotation center to the physical origin.
        """
        rotation = np.array(transform.GetMatrix()).reshape(3, 3)
        center = np.array(transform.GetCenter())
        translation = np.array(transform.GetTranslation())
        matrix = np.eye(4)
        matrix[:3, :3] = rotation
        matrix[:3, 3] = translation + center - rotation @ center
        return matrix

    @staticmethod
    def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
        """Build a zero-centered Euler3DTransform from a 4x4 matrix."""
        transform = sitk.Euler3DTransform()
        transform.SetMatrix(matrix[:3, :3].flatten().tolist())
        transform.SetTranslation(matrix[:3, 3].tolist())
        return transform

    def resample_images(
        self,
        image: sitk.Image,
        transforms: Sequence[sitk.Euler3DTransform],
        interpolation: str,
    ) -> List[sitk.Image]:
        """Resample ``image`` with every transform except the first.

        The first transform is the identity, so the input image itself is
        returned in its place. Points that fall outside the input are filled
        with the image minimum. Resampled images are Float32.
        """
        default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
        # The filter configuration is shared by all movements; only the
        # transform changes between executions.
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(image)
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetDefaultPixelValue(default_value)
        images = [image]  # identity: no resampling needed
        for transform in transforms[1:]:
            resampler.SetTransform(transform)
            images.append(resampler.Execute(image))
        return images

    @staticmethod
    def sort_spectra(spectra: List[torch.Tensor], times: np.ndarray):
        """Use original spectrum to fill the center of k-space.

        Spectrum ``i`` fills the segment of k-space ending at ``times[i]``.
        The segment containing 0.5 is the first whose end time exceeds 0.5,
        or the last one if none does. The original (unmoved) spectrum at
        index 0 is swapped into that position, in place.
        """
        num_spectra = len(spectra)
        if np.any(times > 0.5):
            index = np.where(times > 0.5)[0].min()
        else:
            index = num_spectra - 1
        spectra[0], spectra[index] = spectra[index], spectra[0]

    def add_artifact(
        self,
        image: sitk.Image,
        transforms: Sequence[sitk.Euler3DTransform],
        times: np.ndarray,
        interpolation: str,
    ):
        """Combine the k-space of each moved image into one corrupted image.

        The last axis of k-space is split at ``times`` (scaled to its
        length). Each consecutive segment is taken from the spectrum of the
        corresponding moved image. The result is the real part of the inverse
        transform, as float32.
        """
        images = self.resample_images(image, transforms, interpolation)
        spectra = []
        for resampled in images:
            array = sitk.GetArrayFromImage(resampled).transpose()  # sitk to np
            spectra.append(self.fourier_transform(torch.from_numpy(array)))
        self.sort_spectra(spectra, times)
        # Every column along the last axis is written below as long as the
        # times lie in [0, 1].
        result_spectrum = torch.empty_like(spectra[0])
        last_index = result_spectrum.shape[-1]
        indices = (last_index * times).astype(int).tolist()
        indices.append(last_index)
        ini = 0
        for spectrum, fin in zip(spectra, indices):
            result_spectrum[..., ini:fin] = spectrum[..., ini:fin]
            ini = fin
        result_image = self.inv_fourier_transform(result_spectrum).real.float()
        return result_image