# Source code for torchio.transforms.augmentation.intensity.random_blur

from collections import defaultdict
from typing import Dict
from typing import Tuple
from typing import Union

import numpy as np
import scipy.ndimage as ndi
import torch

from .. import RandomTransform
from ... import IntensityTransform
from ....data.subject import Subject
from ....typing import TypeData
from ....typing import TypeSextetFloat
from ....typing import TypeTripletFloat

# [docs]  (link anchor left over from the rendered documentation page)
class RandomBlur(RandomTransform, IntensityTransform):
    r"""Blur an image using a random-sized Gaussian filter.

    Args:
        std: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` representing the
            ranges (in mm) of the standard deviations
            :math:`(\sigma_1, \sigma_2, \sigma_3)` of the Gaussian kernels used
            to blur the image along each axis, where
            :math:`\sigma_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\sigma_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\sigma_i \sim \mathcal{U}(0, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\sigma_i \sim \mathcal{U}(0, x_i)`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(self, std: Union[float, Tuple[float, float]] = (0, 2), **kwargs):
        super().__init__(**kwargs)
        # parse_params is inherited; its result is later passed to get_params
        # as the sextet of ranges described in the class docstring.
        self.std_ranges = self.parse_params(std, None, 'std', min_constraint=0)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample standard deviations for each image and blur the subject.

        A separate set of standard deviations is drawn for every image
        returned by ``get_images_dict``; the actual filtering is delegated to
        a :class:`Blur` instance built from those values.

        Args:
            subject: Subject whose images will be blurred.

        Returns:
            The subject returned by the :class:`Blur` transform, or the input
            subject itself when there is no image to process.
        """
        images_dict = self.get_images_dict(subject)
        if not images_dict:
            # Without images no 'std' entry is created, and Blur cannot be
            # instantiated without its required ``std`` argument.
            return subject

        arguments: Dict[str, dict] = defaultdict(dict)
        for name in images_dict:
            arguments['std'][name] = self.get_params(self.std_ranges)

        # The deterministic transform must be created before it is called.
        # NOTE(review): add_include_exclude is inherited from the base
        # transform and is not visible in this file; it is expected to forward
        # the include/exclude selection so that Blur visits the same images
        # the values were sampled for. Confirm the helper name against the
        # installed library version.
        transform = Blur(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat:
        """Draw one standard deviation per axis from ``std_ranges``.

        Args:
            std_ranges: Ranges :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` to
                sample from.

        Returns:
            The value produced by the inherited ``sample_uniform_sextet``
            helper for the given ranges.
        """
        std = self.sample_uniform_sextet(std_ranges)
        return std

class Blur(IntensityTransform):
    r"""Apply a Gaussian filter to the images of a subject.

    Args:
        std: Tuple :math:`(\sigma_1, \sigma_2, \sigma_3)` with the standard
            deviations (in mm) of the Gaussian kernels used to blur the image
            along each axis. A dictionary mapping image names to such tuples
            is also accepted.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        std: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.std = std
        self.args_names = ['std']

    def apply_transform(self, subject: Subject) -> Subject:
        """Blur every channel of every selected image in place.

        Args:
            subject: Subject whose images will be blurred.

        Returns:
            The same subject instance, with the image data replaced.
        """
        image_stds = self.std
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                assert isinstance(self.std, dict)
                # Per-image values take the place of the shared ones
                image_stds = self.std[image_name]

            # Repeat the standard deviations so there is one row per channel
            tile_shape = image.num_channels, 1
            per_channel_stds: np.ndarray = np.tile(image_stds, tile_shape)  # type: ignore[arg-type]

            blurred_channels = [
                blur(channel, image.spacing, channel_std)
                for channel_std, channel in zip(per_channel_stds, image.data)
            ]
            image.set_data(torch.stack(blurred_channels))
        return subject

def blur(
    data: TypeData,
    spacing: TypeTripletFloat,
    std_physical: TypeTripletFloat,
) -> torch.Tensor:
    """Apply a Gaussian filter to a single 3D volume.

    Args:
        data: Volume to blur. It must have exactly three dimensions.
        spacing: Voxel size along each axis, in mm/voxel.
        std_physical: Standard deviation of the kernel along each axis, in mm.

    Returns:
        The filtered volume as a tensor.
    """
    assert data.ndim == 3
    # The filter expects standard deviations expressed in voxels, so physical
    # values are divided by the voxel size: e.g. a 2 mm standard deviation on
    # an image with 0.5 mm/voxel spacing becomes 2 / 0.5 = 4 voxels.
    sigmas_mm = np.array(std_physical)
    voxel_size_mm = np.array(spacing)
    sigmas_voxels = sigmas_mm / voxel_size_mm
    return torch.as_tensor(ndi.gaussian_filter(data, sigmas_voxels))