from collections import defaultdict
from typing import Union
import numpy as np
import scipy.ndimage as ndi
import torch
from ....data.subject import Subject
from ....types import TypeData
from ....types import TypeSextetFloat
from ....types import TypeTripletFloat
from ...intensity_transform import IntensityTransform
from .. import RandomTransform
[docs]
class RandomBlur(RandomTransform, IntensityTransform):
    r"""Blur an image using a random-sized Gaussian filter.

    Args:
        std: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` representing the
            ranges (in mm) of the standard deviations
            :math:`(\sigma_1, \sigma_2, \sigma_3)` of the Gaussian kernels
            used to blur the image along each axis, where
            :math:`\sigma_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\sigma_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\sigma_i \sim \mathcal{U}(0, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\sigma_i \sim \mathcal{U}(0, x_i)`.
            All values must be non-negative.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        # The annotation lists every form documented above (scalar, pair,
        # triplet, sextet); previously only the first two were declared.
        std: Union[
            float,
            tuple[float, float],
            TypeTripletFloat,
            TypeSextetFloat,
        ] = (0, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Normalize every accepted form to six (min, max) bounds, one pair
        # per axis; negative standard deviations are rejected.
        self.std_ranges = self.parse_params(std, None, 'std', min_constraint=0)

    def apply_transform(self, subject: Subject) -> Subject:
        images_dict = self.get_images_dict(subject)
        if not images_dict:
            return subject
        # Draw an independent set of per-axis stds for every image, then
        # delegate the actual filtering to the deterministic Blur transform,
        # which receives them as a {image_name: stds} mapping.
        arguments: dict[str, dict] = defaultdict(dict)
        for name in images_dict:
            std = self.get_params(self.std_ranges)  # type: ignore[arg-type]
            arguments['std'][name] = std
        transform = Blur(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat:
        """Sample :math:`(\\sigma_1, \\sigma_2, \\sigma_3)` from six bounds.

        Args:
            std_ranges: ``(a_1, b_1, a_2, b_2, a_3, b_3)`` bounds as produced
                by ``parse_params`` in ``__init__``.

        Returns:
            One sampled standard deviation per axis.
        """
        sx, sy, sz = self.sample_uniform_sextet(std_ranges)
        return sx, sy, sz
class Blur(IntensityTransform):
    r"""Apply a Gaussian blur with fixed, per-axis standard deviations.

    Args:
        std: Tuple :math:`(\sigma_1, \sigma_2, \sigma_3)` giving the
            standard deviations (in mm) of the Gaussian kernel along each
            axis, or a dictionary mapping image names to such tuples so that
            each image can be blurred differently.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        std: Union[TypeTripletFloat, dict[str, TypeTripletFloat]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.std = std
        self.args_names = ['std']

    def apply_transform(self, subject: Subject) -> Subject:
        for image_name, image in self.get_images_dict(subject).items():
            # Pick this image's stds: either the shared triplet or the entry
            # keyed by the image name.
            if self.arguments_are_dict():
                assert isinstance(self.std, dict)
                image_std = self.std[image_name]
            else:
                image_std = self.std

            # Repeat the per-axis stds once per channel (one row per channel)
            # so that every channel is filtered with the same kernel.
            repeats = image.num_channels, 1
            per_channel_stds: np.ndarray
            per_channel_stds = np.tile(image_std, repeats)  # type: ignore[arg-type]

            blurred_channels = [
                blur(channel, image.spacing, channel_std)
                for channel_std, channel in zip(per_channel_stds, image.data)
            ]
            image.set_data(torch.stack(blurred_channels))
        return subject
def blur(
    data: TypeData,
    spacing: TypeTripletFloat,
    std_physical: TypeTripletFloat,
) -> torch.Tensor:
    """Filter a single 3D volume with an axis-aligned Gaussian kernel.

    Args:
        data: Three-dimensional array or tensor holding one channel.
        spacing: Voxel size along each axis, in the same physical units as
            ``std_physical``.
        std_physical: Standard deviation of the kernel along each axis, in
            physical units.

    Returns:
        The filtered volume converted to a :class:`torch.Tensor`.
    """
    assert data.ndim == 3
    # The filter works in voxel units, so divide the physical std by the
    # voxel size: a 2 mm std on a 0.5 mm/voxel grid becomes 2 / 0.5 = 4
    # voxels.
    sigma_voxels = np.divide(np.asarray(std_physical), np.asarray(spacing))
    return torch.as_tensor(ndi.gaussian_filter(data, sigma_voxels))