# Source code for torchio.transforms.augmentation.intensity.random_ghosting

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable

import numpy as np
import torch

from ....data.subject import Subject
from ...fourier import FourierTransform
from ...intensity_transform import IntensityTransform
from .. import RandomTransform


class RandomGhosting(RandomTransform, IntensityTransform):
    r"""Add random MRI ghosting artifact.

    Discrete "ghost" artifacts may occur along the phase-encode direction
    whenever the position or signal intensity of imaged structures within the
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
    of blood or CSF, cardiac motion, and respiratory motion are the most
    important patient-related causes of ghost artifacts in clinical MR imaging
    (from `mriquestions.com`_).

    .. _mriquestions.com: https://mriquestions.com/why-discrete-ghosts.html

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
        axes: Axis along which the ghosts will be created. If
            :attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Anatomical labels may also be used (see
            :class:`~torchio.transforms.augmentation.RandomFlip`); a single
            string such as ``'LR'`` is treated as one label rather than being
            split into characters.
        intensity: Positive number representing the artifact strength
            :math:`s`: the :math:`k`-space planes that generate the ghosts
            are multiplied by :math:`1 - s`. If ``0``, the ghosts will not be
            visible. If a tuple :math:`(a, b)` is provided then
            :math:`s \sim \mathcal{U}(a, b)`. If only one value :math:`d` is
            provided, :math:`s \sim \mathcal{U}(0, d)`.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the
            planes that generate the artifact. If ``None``, only the central
            slice will be restored. If a tuple :math:`(a, b)` is provided
            then :math:`r \sim \mathcal{U}(a, b)`. If only one value :math:`d`
            is provided, :math:`r \sim \mathcal{U}(0, d)`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If an integer axis is not in ``(0, 1, 2)``, or if no
            valid axis remains for a 2D image (e.g. ``axes=(2,)``).

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """

    def __init__(
        self,
        num_ghosts: int | tuple[int, int] = (4, 10),
        axes: int | str | tuple[int | str, ...] = (0, 1, 2),
        intensity: float | tuple[float, float] = (0.5, 1),
        restore: float | tuple[float, float] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.axes = self._parse_ghost_axes(axes)
        self.num_ghosts_range = self._parse_range(
            num_ghosts,
            'num_ghosts',
            min_constraint=0,
            type_constraint=int,
        )
        self.intensity_range = self._parse_range(
            intensity,
            'intensity_range',
            min_constraint=0,
        )
        if restore is None:
            self.restore = None
        else:
            self.restore = self._parse_range(
                restore,
                'restore',
                min_constraint=0,
                max_constraint=1,
            )

    @staticmethod
    def _parse_ghost_axes(axes) -> tuple:
        """Normalize ``axes`` to a tuple and validate integer entries."""
        if isinstance(axes, str):
            # tuple('Left') would give ('L', 'e', 'f', 't'); keep the label whole
            axes = (axes,)
        elif not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:  # a scalar such as 1
                axes = (axes,)
        for axis in axes:
            if not isinstance(axis, str) and axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        return axes

    def _get_candidate_axes(self, name: str, image) -> tuple[int, ...]:
        """Return the axes of ``image`` that ghosts may be created along.

        Anatomical labels are converted to spatial indices in ``(0, 1, 2)``
        and, for 2D images, the third spatial axis is discarded.
        """
        axes = []
        for axis in self.axes:
            if isinstance(axis, str):
                # NOTE(review): the index convention of axis_name_to_index is
                # not visible here; ``% 3`` maps both negative (-3..-1) and
                # non-negative spatial indices onto 0..2.
                axis = image.axis_name_to_index(axis) % 3
            axes.append(axis)
        if image.is_2d():
            axes = [axis for axis in axes if axis != 2]
        if not axes:
            raise ValueError(
                f'No valid axis to create ghosts along for 2D image "{name}"'
                f' (axes: {self.axes})',
            )
        return tuple(axes)

    def apply_transform(self, subject: Subject) -> Subject:
        images_dict = self.get_images_dict(subject)
        if not images_dict:
            return subject
        if any(isinstance(axis, str) for axis in self.axes):
            # Anatomical labels require all images to share one orientation
            subject.check_consistent_orientation()
        arguments: dict[str, dict] = defaultdict(dict)
        min_ghosts, max_ghosts = self.num_ghosts_range
        for name, image in images_dict.items():
            axes = self._get_candidate_axes(name, image)
            num_ghosts, axis, intensity, restore = self.get_params(
                (int(min_ghosts), int(max_ghosts)),
                axes,
                self.intensity_range,
                self.restore,
            )
            arguments['num_ghosts'][name] = num_ghosts
            arguments['axis'][name] = axis
            arguments['intensity'][name] = intensity
            arguments['restore'][name] = restore
        transform = Ghosting(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        num_ghosts_range: tuple[int, int],
        axes: tuple[int, ...],
        intensity_range: tuple[float, float],
        restore_range: tuple[float, float] | None,
    ) -> tuple[int, int, float, float | None]:
        """Sample the ghosting parameters for one image.

        Returns:
            ``(num_ghosts, axis, intensity, restore)`` where ``num_ghosts`` is
            drawn uniformly from the inclusive range, ``axis`` is one of
            ``axes`` and ``restore`` is ``None`` if ``restore_range`` is.
        """
        ng_min, ng_max = num_ghosts_range
        num_ghosts = int(torch.randint(ng_min, ng_max + 1, (1,)).item())
        axis = axes[int(torch.randint(0, len(axes), (1,)).item())]
        intensity = self.sample_uniform(*intensity_range)
        if restore_range is None:
            restore = None
        else:
            restore = self.sample_uniform(*restore_range)
        return num_ghosts, axis, intensity, restore
class Ghosting(IntensityTransform, FourierTransform):
    r"""Add MRI ghosting artifact.

    Discrete "ghost" artifacts may occur along the phase-encode direction
    whenever the position or signal intensity of imaged structures within the
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
    of blood or CSF, cardiac motion, and respiratory motion are the most
    important patient-related causes of ghost artifacts in clinical MR imaging
    (from `mriquestions.com`_).

    .. _mriquestions.com: https://mriquestions.com/why-discrete-ghosts.html

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image. Every
            :math:`n`-th :math:`k`-space plane along :attr:`axis` is
            attenuated.
        axis: Axis along which the ghosts will be created.
        intensity: Positive number representing the artifact strength
            :math:`s`: the attenuated planes are multiplied by :math:`1 - s`.
            If ``0``, the ghosts will not be visible.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the
            planes that generate the artifact. The restored region spans
            ``round(restore * size)`` planes around the middle plane of
            :attr:`axis`. If ``None``, only the central slice will be
            restored.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Each of ``num_ghosts``, ``axis``, ``intensity`` and ``restore`` may also
    be a dictionary mapping image names to per-image values.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """

    def __init__(
        self,
        num_ghosts: int | dict[str, int],
        axis: int | dict[str, int],
        intensity: float | dict[str, float],
        restore: float | None | dict[str, float | None],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.axis = axis
        self.num_ghosts = num_ghosts
        self.intensity = intensity
        self.restore = restore
        self.args_names = ['num_ghosts', 'axis', 'intensity', 'restore']

    def apply_transform(self, subject: Subject) -> Subject:
        axis: int | dict[str, int]
        num_ghosts: int | dict[str, int]
        intensity: float | dict[str, float]
        restore: float | None | dict[str, float | None]
        for name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                assert isinstance(self.axis, dict)
                assert isinstance(self.num_ghosts, dict)
                assert isinstance(self.intensity, dict)
                assert isinstance(self.restore, dict)
                axis = self.axis[name]
                num_ghosts = self.num_ghosts[name]
                intensity = self.intensity[name]
                restore = self.restore[name]
            else:
                axis = self.axis
                num_ghosts = self.num_ghosts
                intensity = self.intensity
                restore = self.restore
            assert isinstance(num_ghosts, int)
            assert isinstance(axis, int)
            assert isinstance(intensity, (int, float))
            if restore is not None:
                # Integer fractions such as 0 or 1 are valid too
                assert isinstance(restore, (int, float))
            # Each channel is transformed independently
            transformed_tensors = [
                self.add_artifact(tensor, num_ghosts, axis, intensity, restore)
                for tensor in image.data
            ]
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def add_artifact(
        self,
        tensor: torch.Tensor,
        num_ghosts: int,
        axis: int,
        intensity: float,
        restore_center: float | None,
    ) -> torch.Tensor:
        """Return ``tensor`` with ghosting added along ``axis``.

        If ``num_ghosts`` or ``intensity`` is zero, ``tensor`` itself is
        returned unchanged. Otherwise the result is the real part of the
        inverse Fourier transform, as ``float32``.
        """
        if not num_ghosts or not intensity:
            return tensor

        spectrum = self.fourier_transform(tensor)
        # Copy the region to restore before attenuating any plane, because
        # "planes" below is a view that modifies "spectrum" in place
        tensor_restore, slices = self._get_slices_to_restore(
            spectrum,
            axis,
            restore_center,
        )
        tensor_restore = tensor_restore.clone()
        planes = self._get_planes_to_modify(spectrum, axis, num_ghosts)
        # Multiply by 0 if intensity is 1
        planes *= 1 - intensity
        # Restore the center of k-space to avoid extreme artifacts
        spectrum[slices] = tensor_restore
        tensor_ghosts = self.inv_fourier_transform(spectrum)
        return tensor_ghosts.real.float()

    @staticmethod
    def _get_planes_to_modify(
        spectrum: torch.Tensor,
        axis: int,
        num_ghosts: int,
    ) -> torch.Tensor:
        """Return a view of every ``num_ghosts``-th plane along ``axis``."""
        slices = [slice(None)] * spectrum.ndim
        slices[axis] = slice(None, None, num_ghosts)
        return spectrum[tuple(slices)]

    @staticmethod
    def _get_slices_to_restore(
        spectrum: torch.Tensor,
        axis: int,
        restore_center: float | None,
    ) -> tuple[torch.Tensor, tuple[slice, ...]]:
        """Return the central region along ``axis`` and its index tuple.

        The region has ``round(restore_center * size)`` planes (one plane
        when ``restore_center`` is ``None``) and starts at
        ``size // 2 - n // 2``, so it always contains the middle plane when
        non-empty. NOTE(review): this assumes ``fourier_transform`` returns a
        centered (fftshift-ed) spectrum; confirm.
        """
        dim_shape = spectrum.shape[axis]
        mid_idx = dim_shape // 2
        if restore_center is None:
            size_restore = 1
        else:
            size_restore = int(np.round(restore_center * dim_shape))
        # start + size_restore never exceeds dim_shape since restore <= 1
        start = mid_idx - size_restore // 2
        slices = [slice(None)] * spectrum.ndim
        slices[axis] = slice(start, start + size_restore)
        slices_tuple = tuple(slices)
        restore_tensor = spectrum[slices_tuple]
        return restore_tensor, slices_tuple