Source code for torchio.transforms.augmentation.intensity.random_spike

from collections import defaultdict
from numbers import Number
from typing import Dict
from typing import Tuple
from typing import Union

import numpy as np
import torch

from .. import RandomTransform
from ... import FourierTransform
from ... import IntensityTransform
from ....data.subject import Subject


[docs] class RandomSpike(RandomTransform, IntensityTransform, FourierTransform): r"""Add random MRI spike artifacts. Also known as `Herringbone artifact <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_, crisscross artifact or corduroy artifact, it creates stripes in different directions in image space due to spikes in k-space. Args: num_spikes: Number of spikes :math:`n` present in k-space. If a tuple :math:`(a, b)` is provided, then :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`. If only one value :math:`d` is provided, :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`. Larger values generate more distorted images. intensity: Ratio :math:`r` between the spike intensity and the maximum of the spectrum. If a tuple :math:`(a, b)` is provided, then :math:`r \sim \mathcal{U}(a, b)`. If only one value :math:`d` is provided, :math:`r \sim \mathcal{U}(-d, d)`. Larger values generate more distorted images. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. .. note:: The execution time of this transform does not depend on the number of spikes. """ def __init__( self, num_spikes: Union[int, Tuple[int, int]] = 1, intensity: Union[float, Tuple[float, float]] = (1, 3), **kwargs, ): super().__init__(**kwargs) self.intensity_range = self._parse_range( intensity, 'intensity_range', ) self.num_spikes_range: Tuple[int, int] = self._parse_range( # type: ignore[assignment] # noqa: B950 num_spikes, 'num_spikes', min_constraint=0, type_constraint=int, ) def apply_transform(self, subject: Subject) -> Subject: arguments: Dict[str, dict] = defaultdict(dict) for image_name in self.get_images_dict(subject): spikes_positions_param, intensity_param = self.get_params( self.num_spikes_range, self.intensity_range, ) arguments['spikes_positions'][image_name] = spikes_positions_param arguments['intensity'][image_name] = intensity_param transform = Spike(**self.add_include_exclude(arguments)) transformed = transform(subject) assert isinstance(transformed, Subject) return transformed def get_params( self, num_spikes_range: Tuple[int, int], intensity_range: Tuple[float, float], ) -> Tuple[np.ndarray, float]: ns_min, ns_max = num_spikes_range num_spikes_param = int(torch.randint(ns_min, ns_max + 1, (1,)).item()) intensity_param = self.sample_uniform(*intensity_range) spikes_positions = torch.rand(num_spikes_param, 3).numpy() return spikes_positions, intensity_param
class Spike(IntensityTransform, FourierTransform): r"""Add MRI spike artifacts. Also known as `Herringbone artifact <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_, crisscross artifact or corduroy artifact, it creates stripes in different directions in image space due to spikes in k-space. Args: spikes_positions: intensity: Ratio :math:`r` between the spike intensity and the maximum of the spectrum. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. .. note:: The execution time of this transform does not depend on the number of spikes. """ def __init__( self, spikes_positions: Union[np.ndarray, Dict[str, np.ndarray]], intensity: Union[float, Dict[str, float]], **kwargs, ): super().__init__(**kwargs) self.spikes_positions = spikes_positions self.intensity = intensity self.args_names = ['spikes_positions', 'intensity'] self.invert_transform = False def apply_transform(self, subject: Subject) -> Subject: spikes_positions = self.spikes_positions intensity = self.intensity for image_name, image in self.get_images_dict(subject).items(): if self.arguments_are_dict(): spikes_positions = self.spikes_positions[image_name] assert isinstance(self.intensity, dict) intensity = self.intensity[image_name] transformed_tensors = [] for channel in image.data: assert isinstance(intensity, Number) transformed_tensor = self.add_artifact( channel, np.asarray(spikes_positions), intensity, ) transformed_tensors.append(transformed_tensor) image.set_data(torch.stack(transformed_tensors)) return subject def add_artifact( self, tensor: torch.Tensor, spikes_positions: np.ndarray, intensity_factor: float, ): if intensity_factor == 0 or len(spikes_positions) == 0: return tensor spectrum = self.fourier_transform(tensor) shape = np.array(spectrum.shape) mid_shape = shape // 2 indices = np.floor(spikes_positions * shape).astype(int) for index in indices: diff = index - mid_shape i, j, k = mid_shape + diff # As of torch 1.7, "max is not yet implemented for complex tensors" artifact = spectrum.cpu().numpy().max() * intensity_factor if self.invert_transform: spectrum[i, j, k] -= artifact else: spectrum[i, j, k] += artifact # If we wanted to add a pure cosine, we should add spikes to both # sides of k-space. However, having only one is a better # representation og the actual cause of the artifact in real # scans. Therefore the next two lines have been removed. # #i, j, k = mid_shape - diff # #spectrum[i, j, k] = spectrum.max() * intensity_factor result = self.inv_fourier_transform(spectrum).real.float() return result