Source code for torchio.transforms.augmentation.spatial.random_elastic_deformation
import warnings
from numbers import Number
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import SimpleITK as sitk
import torch
from .. import RandomTransform
from ... import SpatialTransform
from ....data.image import ScalarImage
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....typing import TypeTripletFloat
from ....typing import TypeTripletInt
from ....utils import to_tuple
SPLINE_ORDER = 3
[docs]
class RandomElasticDeformation(RandomTransform, SpatialTransform):
    r"""Apply dense random elastic deformation.

    A random displacement is assigned to a coarse grid of control points around
    and inside the image. The displacement at each voxel is interpolated from
    the coarse grid using cubic B-splines.

    The `'Deformable Registration' <https://www.sciencedirect.com/topics/computer-science/deformable-registration>`_
    topic on ScienceDirect contains useful articles explaining interpolation of
    displacement fields using cubic B-splines.

    .. warning:: This transform is slow as it requires expensive computations.
        If your images are large you might want to use
        :class:`~torchio.transforms.RandomAffine` instead.

    Args:
        num_control_points: Number of control points along each dimension of
            the coarse grid :math:`(n_x, n_y, n_z)`.
            If a single value :math:`n` is passed,
            then :math:`n_x = n_y = n_z = n`.
            Smaller numbers generate smoother deformations.
            The minimum number of control points is ``4`` as this transform
            uses cubic B-splines to interpolate displacement.
        max_displacement: Maximum displacement along each dimension at each
            control point :math:`(D_x, D_y, D_z)`.
            The displacement along dimension :math:`i` at each control point is
            :math:`d_i \sim \mathcal{U}(-D_i, D_i)`.
            If a single value :math:`D` is passed,
            then :math:`D_x = D_y = D_z = D`.
            Note that the total maximum displacement would actually be
            :math:`D_{max} = \sqrt{D_x^2 + D_y^2 + D_z^2}`.
        locked_borders: If ``0``, all displacement vectors are kept.
            If ``1``, displacement of control points at the
            border of the coarse grid will be set to ``0``.
            If ``2``, displacement of control points at the border of the image
            (red dots in the image below) will also be set to ``0``.
        image_interpolation: See :ref:`Interpolation`.
            Note that this is the interpolation used to compute voxel
            intensities when resampling using the dense displacement field.
            The value of the dense displacement at each voxel is always
            interpolated with cubic B-splines from the values at the control
            points of the coarse grid.
        label_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``locked_borders`` is not ``0``, ``1`` or ``2``, or if
            ``locked_borders`` is ``2`` and any axis uses only ``4`` control
            points (every control point would be locked).

    `This gist <https://gist.github.com/fepegar/b723d15de620cd2a3a4dbd71e491b59d>`_
    can also be used to better understand the meaning of the parameters.

    This is an example from the
    `3D Slicer registration FAQ <https://www.slicer.org/wiki/Documentation/4.10/FAQ/Registration#What.27s_the_BSpline_Grid_Size.3F>`_.

    .. image:: https://www.slicer.org/w/img_auth.php/6/6f/RegLib_BSplineGridModel.png
        :alt: B-spline example from 3D Slicer documentation

    To generate a similar grid of control points with TorchIO,
    the transform can be instantiated as follows::

        >>> from torchio import RandomElasticDeformation
        >>> transform = RandomElasticDeformation(
        ...     num_control_points=(7, 7, 7),  # or just 7
        ...     locked_borders=2,
        ... )

    Note that control points outside the image bounds are not shown in the
    example image (they would also be red as we set :attr:`locked_borders`
    to ``2``).

    .. warning:: Image folding may occur if the maximum displacement is larger
        than half the coarse grid spacing. The grid spacing can be computed
        using the image bounds in physical space [#]_ and the number of control
        points::

            >>> import numpy as np
            >>> import torchio as tio
            >>> image = tio.datasets.Slicer().MRHead.as_sitk()
            >>> image.GetSize()  # in voxels
            (256, 256, 130)
            >>> image.GetSpacing()  # in mm
            (1.0, 1.0, 1.2999954223632812)
            >>> bounds = np.array(image.GetSize()) * np.array(image.GetSpacing())
            >>> bounds  # mm
            array([256.        , 256.        , 168.99940491])
            >>> num_control_points = np.array((7, 7, 6))
            >>> grid_spacing = bounds / (num_control_points - 2)
            >>> grid_spacing
            array([51.2       , 51.2       , 42.24985123])
            >>> potential_folding = grid_spacing / 2
            >>> potential_folding  # mm
            array([25.6       , 25.6       , 21.12492561])

        Using a :attr:`max_displacement` larger than the computed
        :attr:`potential_folding` will raise a :class:`RuntimeWarning`.

        .. [#] Technically, :math:`2 \epsilon` should be added to the
            image bounds, where :math:`\epsilon = 2^{-3}` `according to ITK
            source code <https://github.com/InsightSoftwareConsortium/ITK/blob/633f84548311600845d54ab2463d3412194690a8/Modules/Core/Transform/include/itkBSplineTransformInitializer.hxx#L116-L138>`_.
    """  # noqa: B950

    def __init__(
        self,
        num_control_points: Union[int, Tuple[int, int, int]] = 7,
        max_displacement: Union[float, Tuple[float, float, float]] = 7.5,
        locked_borders: int = 2,
        image_interpolation: str = 'linear',
        label_interpolation: str = 'nearest',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._bspline_transformation = None
        # Scalars are expanded to one value per spatial dimension before
        # being validated.
        self.num_control_points = to_tuple(num_control_points, length=3)
        _parse_num_control_points(self.num_control_points)  # type: ignore[arg-type]  # noqa: B950
        self.max_displacement = to_tuple(max_displacement, length=3)
        _parse_max_displacement(self.max_displacement)  # type: ignore[arg-type]  # noqa: B950
        self.num_locked_borders = locked_borders
        if locked_borders not in (0, 1, 2):
            raise ValueError('locked_borders must be 0, 1, or 2')
        # With 4 control points along an axis, locking 2 layers at each end
        # of that axis zeroes every control point of the grid.
        if locked_borders == 2 and 4 in self.num_control_points:
            message = (
                'Setting locked_borders to 2 and using less than 5 control'
                ' points results in an identity transform. Lock fewer borders'
                ' or use more control points.'
            )
            raise ValueError(message)
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )
        self.label_interpolation = self.parse_interpolation(
            label_interpolation,
        )

    @staticmethod
    def get_params(
        num_control_points: TypeTripletInt,
        max_displacement: Tuple[float, float, float],
        num_locked_borders: int,
    ) -> np.ndarray:
        """Sample the random displacement of every control point.

        Args:
            num_control_points: Number of control points along each of the
                three axes of the coarse grid.
            max_displacement: Maximum absolute displacement along each of the
                three dimensions.
            num_locked_borders: Number of outer layers of control points, on
                every side of the grid, whose displacement is set to zero.

        Returns:
            Array of shape ``(*num_control_points, 3)`` where the last axis
            holds the displacement components of each control point.
        """
        grid_shape = num_control_points
        num_dimensions = 3
        coarse_field = torch.rand(*grid_shape, num_dimensions)  # [0, 1)
        coarse_field -= 0.5  # [-0.5, 0.5)
        coarse_field *= 2  # [-1, 1)
        for dimension in range(num_dimensions):
            # [-max_displacement, max_displacement)
            coarse_field[..., dimension] *= max_displacement[dimension]

        # Set displacement to 0 at the borders. The grid has three spatial
        # axes, so both ends of each of them must be locked.
        for i in range(num_locked_borders):
            coarse_field[i, :] = 0
            coarse_field[-1 - i, :] = 0
            coarse_field[:, i] = 0
            coarse_field[:, -1 - i] = 0
            coarse_field[:, :, i] = 0
            coarse_field[:, :, -1 - i] = 0
        return coarse_field.numpy()

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample a displacement grid and apply it to the subject.

        The sampled control points are handed to an
        :class:`ElasticDeformation` instance, which performs the resampling.
        """
        subject.check_consistent_spatial_shape()
        control_points = self.get_params(
            self.num_control_points,  # type: ignore[arg-type]
            self.max_displacement,  # type: ignore[arg-type]
            self.num_locked_borders,
        )
        arguments = {
            'control_points': control_points,
            'max_displacement': self.max_displacement,
            'image_interpolation': self.image_interpolation,
            'label_interpolation': self.label_interpolation,
        }
        transform = ElasticDeformation(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed
class ElasticDeformation(SpatialTransform):
    r"""Apply dense elastic deformation.

    Args:
        control_points: Array with the displacement vector assigned to each
            control point of the coarse grid. All axes but the last one index
            the control points along each grid dimension; the last axis holds
            the displacement components.
        max_displacement: Maximum displacement along each dimension. If all
            values are ``0`` the subject is returned untouched. The values are
            also compared against half the coarse grid spacing to warn about
            potential folding.
        image_interpolation: See :ref:`Interpolation`.
        label_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        control_points: np.ndarray,
        max_displacement: TypeTripletFloat,
        image_interpolation: str = 'linear',
        label_interpolation: str = 'nearest',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.control_points = control_points
        self.max_displacement = max_displacement
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )
        self.label_interpolation = self.parse_interpolation(
            label_interpolation,
        )
        self.invert_transform = False
        self.args_names = [
            'control_points',
            'image_interpolation',
            'label_interpolation',
            'max_displacement',
        ]

    def get_bspline_transform(
        self,
        image: sitk.Image,
    ) -> sitk.BSplineTransform:
        """Build the B-spline transform defined by the control points.

        Args:
            image: Image whose geometry is used to initialize the grid of
                control points.

        Returns:
            B-spline transform whose parameters are the stored displacements
            (negated if :attr:`invert_transform` is set).
        """
        # Work on a copy so that the stored control points are never modified
        control_points = self.control_points.copy()
        if self.invert_transform:
            control_points *= -1
        is_2d = image.GetSize()[2] == 1
        if is_2d:
            control_points[..., -1] = 0  # no displacement in IS axis
        num_control_points = control_points.shape[:-1]
        # The initializer takes the mesh size, which is the number of control
        # points minus the spline order
        mesh_shape = [n - SPLINE_ORDER for n in num_control_points]
        bspline_transform = sitk.BSplineTransformInitializer(image, mesh_shape)
        parameters = control_points.flatten(order='F').tolist()
        bspline_transform.SetParameters(parameters)
        return bspline_transform

    @staticmethod
    def parse_free_form_transform(
        transform: sitk.BSplineTransform,
        max_displacement: TypeTripletFloat,
    ) -> None:
        """Issue a warning if possible folding is detected.

        Folding is considered possible along a dimension when the maximum
        displacement is larger than half the spacing of the coarse grid.

        Args:
            transform: B-spline transform from which the grid spacing is read.
            max_displacement: Maximum displacement along each dimension.
        """
        coefficient_images = transform.GetCoefficientImages()
        grid_spacing = coefficient_images[0].GetSpacing()
        conflicts = np.array(max_displacement) > np.array(grid_spacing) / 2
        if np.any(conflicts):
            (where,) = np.where(conflicts)
            message = (
                'The maximum displacement is larger than half the coarse grid'
                f' spacing for dimensions: {where.tolist()}, so folding may'
                ' occur. Choose fewer control points or a smaller'
                ' maximum displacement'
            )
            warnings.warn(message, RuntimeWarning, stacklevel=2)

    def apply_transform(self, subject: Subject) -> Subject:
        """Deform every image of the subject selected by ``get_images``.

        Instances of :class:`ScalarImage` are resampled with the image
        interpolation; any other image uses the label interpolation.
        """
        no_displacement = not any(self.max_displacement)
        if no_displacement:
            return subject
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if isinstance(image, ScalarImage):
                interpolation = self.image_interpolation
            else:
                interpolation = self.label_interpolation
            transformed = self.apply_bspline_transform(
                image.data,
                image.affine,
                interpolation,
            )
            image.set_data(transformed)
        return subject

    def apply_bspline_transform(
        self,
        tensor: torch.Tensor,
        affine: np.ndarray,
        interpolation: str,
    ) -> torch.Tensor:
        """Resample each channel of a 4D tensor with the B-spline transform.

        Args:
            tensor: 4D tensor; it is iterated along its first dimension and
                each component is resampled independently.
            affine: Affine matrix passed to ``nib_to_sitk`` to build the
                SimpleITK image of each component.
            interpolation: Interpolation passed to ``get_sitk_interpolator``.

        Returns:
            Concatenation of the resampled components.
        """
        assert tensor.dim() == 4
        # The interpolator does not depend on the channel being processed
        interpolator = self.get_sitk_interpolator(interpolation)
        results = []
        for component in tensor:
            image = nib_to_sitk(component[np.newaxis], affine, force_3d=True)
            floating = reference = image
            bspline_transform = self.get_bspline_transform(image)
            self.parse_free_form_transform(
                bspline_transform,
                self.max_displacement,  # type: ignore[arg-type]
            )
            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(reference)
            resampler.SetTransform(bspline_transform)
            resampler.SetInterpolator(interpolator)
            # Voxels sampled from outside the image get the channel minimum
            resampler.SetDefaultPixelValue(component.min().item())
            resampler.SetOutputPixelType(sitk.sitkFloat32)
            resampled = resampler.Execute(floating)
            result, _ = self.sitk_to_nib(resampled)
            results.append(torch.as_tensor(result))
        tensor = torch.cat(results)
        return tensor
def _parse_num_control_points(
num_control_points: TypeTripletInt,
) -> None:
for axis, number in enumerate(num_control_points):
if not isinstance(number, int) or number < 4:
message = (
f'The number of control points for axis {axis} must be'
f' an integer greater than 3, not {number}'
)
raise ValueError(message)
def _parse_max_displacement(
max_displacement: Tuple[float, float, float],
) -> None:
for axis, number in enumerate(max_displacement):
if not isinstance(number, Number) or number < 0:
message = (
'The maximum displacement at each control point'
f' for axis {axis} must be'
f' a number greater or equal to 0, not {number}'
)
raise ValueError(message)