from collections import defaultdict
from numbers import Number
from typing import Dict
from typing import Tuple
from typing import Union
import numpy as np
import torch
from .. import RandomTransform
from ... import FourierTransform
from ... import IntensityTransform
from ....data.subject import Subject
class RandomSpike(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        num_spikes: Number of spikes :math:`n` present in k-space.
            If a tuple :math:`(a, b)` is provided, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
            Larger values generate more distorted images.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum.
            If a tuple :math:`(a, b)` is provided, then
            :math:`r \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`r \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """

    def __init__(
        self,
        num_spikes: Union[int, Tuple[int, int]] = 1,
        intensity: Union[float, Tuple[float, float]] = (1, 3),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_range = self._parse_range(
            intensity,
            'intensity_range',
        )
        self.num_spikes_range: Tuple[int, int] = self._parse_range(  # type: ignore[assignment] # noqa: B950
            num_spikes,
            'num_spikes',
            min_constraint=0,
            type_constraint=int,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample spike parameters per image and delegate to :class:`Spike`.

        Each selected image gets its own independently sampled spike positions
        and intensity. If no image is selected, the subject is returned
        unchanged.
        """
        arguments: Dict[str, dict] = defaultdict(dict)
        for image_name in self.get_images_dict(subject):
            spikes_positions_param, intensity_param = self.get_params(
                self.num_spikes_range,
                self.intensity_range,
            )
            arguments['spikes_positions'][image_name] = spikes_positions_param
            arguments['intensity'][image_name] = intensity_param
        if not arguments:
            # Nothing to transform. Building Spike here would fail, because its
            # required `spikes_positions` and `intensity` would be missing.
            return subject
        transform = Spike(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        num_spikes_range: Tuple[int, int],
        intensity_range: Tuple[float, float],
    ) -> Tuple[np.ndarray, float]:
        """Sample spike positions and an intensity ratio.

        Args:
            num_spikes_range: Inclusive ``(min, max)`` bounds for the number
                of spikes.
            intensity_range: Bounds passed to ``sample_uniform`` to draw the
                intensity ratio.

        Returns:
            A ``(num_spikes, 3)`` array of relative positions drawn uniformly
            from ``[0, 1)``, and the sampled intensity ratio.
        """
        ns_min, ns_max = num_spikes_range
        # torch.randint excludes its upper bound, hence the + 1.
        num_spikes_param = int(torch.randint(ns_min, ns_max + 1, (1,)).item())
        intensity_param = self.sample_uniform(*intensity_range)
        spikes_positions = torch.rand(num_spikes_param, 3).numpy()
        return spikes_positions, intensity_param
class Spike(IntensityTransform, FourierTransform):
    r"""Add MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        spikes_positions: Array of shape :math:`(n, 3)`. Each row is the
            position of one spike, expressed as fractions of each k-space
            dimension. Values are expected in :math:`[0, 1]`; a value of
            exactly 1 is mapped to the last index. Alternatively, a dictionary
            mapping image names to such arrays.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum, or a dictionary mapping image names to ratios.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """

    def __init__(
        self,
        spikes_positions: Union[np.ndarray, Dict[str, np.ndarray]],
        intensity: Union[float, Dict[str, float]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.spikes_positions = spikes_positions
        self.intensity = intensity
        self.args_names = ['spikes_positions', 'intensity']
        # When True, add_artifact subtracts the spikes instead of adding them.
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        """Add the spike artifact to every channel of each selected image."""
        spikes_positions = self.spikes_positions
        intensity = self.intensity
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                spikes_positions = self.spikes_positions[image_name]
                assert isinstance(self.intensity, dict)
                intensity = self.intensity[image_name]
            assert isinstance(intensity, Number)
            # The parameters are identical for all channels, so convert once.
            positions = np.asarray(spikes_positions)
            transformed_tensors = [
                self.add_artifact(channel, positions, intensity)
                for channel in image.data
            ]
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def add_artifact(
        self,
        tensor: torch.Tensor,
        spikes_positions: np.ndarray,
        intensity_factor: float,
    ):
        """Add (or, if inverting, subtract) spikes in k-space.

        Args:
            tensor: A single channel; indexed as a 3D volume in k-space.
            spikes_positions: ``(n, 3)`` relative spike positions.
            intensity_factor: Ratio between each spike and the maximum of the
                original spectrum.

        Returns:
            ``tensor`` itself if there is nothing to add. Otherwise, the real
            part of the inverse Fourier transform of the modified spectrum,
            as a float32 tensor.
        """
        if intensity_factor == 0 or len(spikes_positions) == 0:
            return tensor
        spectrum = self.fourier_transform(tensor)
        shape = np.array(spectrum.shape)
        indices = np.floor(spikes_positions * shape).astype(int)
        # A relative position of exactly 1.0 would otherwise be out of bounds.
        indices = np.minimum(indices, shape - 1)
        # Compute the spike value once, from the unmodified spectrum, so that
        # every spike is `intensity_factor` times the original maximum.
        # Recomputing it per spike after earlier spikes were added would skew
        # later spikes, and would make the cost grow with the number of spikes.
        # As of torch 1.7, "max is not yet implemented for complex tensors",
        # hence the NumPy round-trip. Note that NumPy orders complex values
        # lexicographically (real part first).
        artifact = spectrum.cpu().numpy().max() * intensity_factor
        for i, j, k in indices:
            if self.invert_transform:
                spectrum[i, j, k] -= artifact
            else:
                spectrum[i, j, k] += artifact
            # If we wanted to add a pure cosine, we should add spikes to both
            # sides of k-space. However, having only one is a better
            # representation of the actual cause of the artifact in real
            # scans, so the mirrored spike is intentionally not added.
        result = self.inv_fourier_transform(spectrum).real.float()
        return result