# Source code for torchio.transforms.augmentation.intensity.random_spike

from collections import defaultdict
from numbers import Number
from typing import Dict
from typing import Tuple
from typing import Union

import numpy as np
import torch

from .. import RandomTransform
from ... import FourierTransform
from ... import IntensityTransform
from ....data.subject import Subject

class RandomSpike(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        num_spikes: Number of spikes :math:`n` present in k-space.
            If a tuple :math:`(a, b)` is provided, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
            Larger values generate more distorted images.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum.
            If a tuple :math:`(a, b)` is provided, then
            :math:`r \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`r \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """

    def __init__(
        self,
        num_spikes: Union[int, Tuple[int, int]] = 1,
        intensity: Union[float, Tuple[float, float]] = (1, 3),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.intensity_range = self._parse_range(
            intensity, 'intensity_range',
        )
        self.num_spikes_range: Tuple[int, int] = self._parse_range(  # type: ignore[assignment]  # noqa: E501
            num_spikes, 'num_spikes', min_constraint=0, type_constraint=int,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample spike parameters per image and apply a :class:`Spike`.

        Args:
            subject: Subject whose images (as returned by
                ``self.get_images_dict``) receive the artifact.

        Returns:
            The subject returned by the deterministic :class:`Spike`
            transform built from the sampled parameters.
        """
        # One entry per argument name, each mapping image name -> value, so
        # that every image gets its own independently sampled parameters.
        arguments: Dict[str, dict] = defaultdict(dict)
        for image_name in self.get_images_dict(subject):
            spikes_positions_param, intensity_param = self.get_params(
                self.num_spikes_range,
                self.intensity_range,
            )
            arguments['spikes_positions'][image_name] = spikes_positions_param
            arguments['intensity'][image_name] = intensity_param
        # NOTE(review): this line was missing from the extracted source
        # (``transform`` was used without being defined). It is restored
        # assuming the base transform exposes ``add_include_exclude`` to
        # forward the include/exclude settings -- confirm against the
        # installed base class.
        transform = Spike(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        num_spikes_range: Tuple[int, int],
        intensity_range: Tuple[float, float],
    ) -> Tuple[np.ndarray, float]:
        """Sample the spike positions and the intensity ratio.

        Args:
            num_spikes_range: Inclusive ``(min, max)`` bounds for the number
                of spikes.
            intensity_range: Bounds passed to ``self.sample_uniform`` to draw
                the intensity ratio.

        Returns:
            A tuple ``(spikes_positions, intensity)`` where
            ``spikes_positions`` is an array of shape ``(num_spikes, 3)``
            with values drawn from ``torch.rand`` (i.e. in ``[0, 1)``), and
            ``intensity`` is the sampled ratio.
        """
        ns_min, ns_max = num_spikes_range
        # torch.randint excludes the upper bound, hence the + 1.
        num_spikes_param = int(torch.randint(ns_min, ns_max + 1, (1,)).item())
        intensity_param = self.sample_uniform(*intensity_range)
        spikes_positions = torch.rand(num_spikes_param, 3).numpy()
        return spikes_positions, intensity_param

class Spike(IntensityTransform, FourierTransform):
    r"""Add MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        spikes_positions: Array with one row of three values per spike. Each
            row is multiplied element-wise by the shape of the spectrum and
            floored to obtain the index of the spike in k-space. A
            dictionary mapping image names to such arrays may be passed
            instead.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum. A dictionary mapping image names to ratios may
            be passed instead.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """

    def __init__(
        self,
        spikes_positions: Union[np.ndarray, Dict[str, np.ndarray]],
        intensity: Union[float, Dict[str, float]],
        **kwargs
    ):
        super().__init__(**kwargs)
        self.spikes_positions = spikes_positions
        self.intensity = intensity
        self.args_names = ['spikes_positions', 'intensity']
        # When True, add_artifact subtracts the spike instead of adding it.
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        """Add the spike artifact to every channel of every image.

        Args:
            subject: Subject whose images (as returned by
                ``self.get_images_dict``) are modified in place through
                ``image.set_data``.

        Returns:
            The same ``subject`` instance, with transformed image data.
        """
        spikes_positions = self.spikes_positions
        intensity = self.intensity
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                # Per-image parameters: look up the values for this image.
                spikes_positions = self.spikes_positions[image_name]
                assert isinstance(self.intensity, dict)
                intensity = self.intensity[image_name]
            transformed_tensors = []
            for channel in image.data:
                assert isinstance(intensity, Number)
                # NOTE(review): the opening of this call was missing from the
                # extracted source; the method name is reconstructed together
                # with its definition below -- confirm against upstream.
                transformed_tensor = self.add_artifact(
                    channel,
                    np.asarray(spikes_positions),
                    intensity,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def add_artifact(
        self,
        tensor: torch.Tensor,
        spikes_positions: np.ndarray,
        intensity_factor: float,
    ) -> torch.Tensor:
        """Add spikes to the spectrum of a tensor and transform it back.

        Args:
            tensor: Single-channel tensor to corrupt.
            spikes_positions: One row of three values per spike; each row is
                scaled by the spectrum shape and floored to get an index.
            intensity_factor: Ratio between the spike value and the maximum
                of the spectrum.

        Returns:
            The input ``tensor`` unchanged if ``intensity_factor`` is zero or
            there are no spikes; otherwise the real part of the inverse
            Fourier transform of the modified spectrum, as a float tensor.
        """
        if intensity_factor == 0 or len(spikes_positions) == 0:
            return tensor
        spectrum = self.fourier_transform(tensor)
        shape = np.array(spectrum.shape)
        mid_shape = shape // 2
        indices = np.floor(spikes_positions * shape).astype(int)
        for index in indices:
            diff = index - mid_shape
            i, j, k = mid_shape + diff
            # As of torch 1.7, "max is not yet implemented for complex tensors"
            # The maximum is recomputed for every spike, so it reflects the
            # spikes already written to the spectrum.
            artifact = spectrum.cpu().numpy().max() * intensity_factor
            if self.invert_transform:
                spectrum[i, j, k] -= artifact
            else:
                spectrum[i, j, k] += artifact
            # If we wanted to add a pure cosine, we should add spikes to both
            # sides of k-space. However, having only one is a better
            # representation of the actual cause of the artifact in real
            # scans. Therefore the next two lines have been removed.
            # #i, j, k = mid_shape - diff
            # #spectrum[i, j, k] = spectrum.max() * intensity_factor
        result = self.inv_fourier_transform(spectrum).real.float()
        return result