# Source code for torchio.transforms.augmentation.spatial.random_affine_elastic_deformation

from __future__ import annotations

from typing import Any

import numpy as np
import SimpleITK as sitk
import torch

from ....constants import INTENSITY
from ....constants import TYPE
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ...spatial_transform import SpatialTransform
from .. import RandomTransform
from .random_affine import Affine
from .random_elastic_deformation import ElasticDeformation


class RandomAffineElasticDeformation(RandomTransform, SpatialTransform):
    r"""Apply a RandomAffine and RandomElasticDeformation simultaneously.

    Optimization to use only a single SimpleITK resampling. For additional
    details on the transformations, see
    :class:`~torchio.transforms.RandomAffine` and
    :class:`~torchio.transforms.RandomElasticDeformation`.

    Args:
        affine_first: Apply affine before elastic deformation.
        affine_kwargs: See :class:`~torchio.transforms.RandomAffine` for
            kwargs.
        elastic_kwargs: See
            :class:`~torchio.transforms.RandomElasticDeformation` for kwargs.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Note:
        Interpolation and padding for the combined resampling are taken from
        the affine settings (``affine_kwargs``). Interpolation options given
        in ``elastic_kwargs`` are not used by the combined resampling.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.Colin27().t1
        >>> affine_kwargs = {'scales': (0.9, 1.2), 'degrees': 15}
        >>> elastic_kwargs = {'max_displacement': (17, 12, 2)}
        >>> transform = tio.RandomAffineElasticDeformation(
        ...     affine_kwargs=affine_kwargs,
        ...     elastic_kwargs=elastic_kwargs,
        ... )
        >>> transformed = transform(image)

    .. plot::

        import torchio as tio
        subject = tio.datasets.Slicer('CTChest')
        ct = subject.CT_chest
        elastic_kwargs = {'max_displacement': (17, 12, 2)}
        transform = tio.RandomAffineElasticDeformation(
            elastic_kwargs=elastic_kwargs,
        )
        ct_transformed = transform(ct)
        subject.add_image(ct_transformed, 'Transformed')
        subject.plot()
    """

    def __init__(
        self,
        affine_first: bool = True,
        affine_kwargs: dict[str, Any] | None = None,
        elastic_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.affine_first = affine_first
        # Avoid circular imports
        from .random_affine import RandomAffine
        from .random_elastic_deformation import RandomElasticDeformation

        # The wrapped random transforms are used only to sample parameters
        # and to hold their configuration; they are never applied directly.
        self.affine_kwargs = affine_kwargs or {}
        self.random_affine = RandomAffine(**self.affine_kwargs)
        self.elastic_kwargs = elastic_kwargs or {}
        self.random_elastic = RandomElasticDeformation(**self.elastic_kwargs)

    def get_params(self):
        """Sample random affine and elastic parameters.

        Returns:
            A tuple ``(affine_params, elastic_params)``. ``affine_params`` is
            the ``(scaling, rotation, translation)`` triplet returned by
            ``RandomAffine.get_params``. ``elastic_params`` is the value
            returned by ``RandomElasticDeformation.get_params``, which is
            used as the elastic ``control_points``.
        """
        affine_params = self.random_affine.get_params(
            self.random_affine.scales,
            self.random_affine.degrees,
            self.random_affine.translation,
            self.random_affine.isotropic,
        )
        elastic_params = self.random_elastic.get_params(
            self.random_elastic.num_control_points,
            self.random_elastic.max_displacement,
            self.random_elastic.num_locked_borders,
        )
        return affine_params, elastic_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample parameters and apply them with a single resampling.

        A deterministic :class:`AffineElasticDeformation` is built from the
        sampled parameters and this transform's base arguments, then applied
        to ``subject``.

        Args:
            subject: Subject to transform.

        Returns:
            The transformed subject.
        """
        sampled_affine, control_points = self.get_params()
        scaling_params, rotation_params, translation_params = sampled_affine
        affine_params = {
            'scales': scaling_params.tolist(),
            'degrees': rotation_params.tolist(),
            'translation': translation_params.tolist(),
            'center': self.random_affine.center,
            'default_pad_value': self.random_affine.default_pad_value,
            'image_interpolation': self.random_affine.image_interpolation,
            'label_interpolation': self.random_affine.label_interpolation,
            'check_shape': self.random_affine.check_shape,
        }
        elastic_params = {
            'control_points': control_points,
            'max_displacement': self.random_elastic.max_displacement,
            'image_interpolation': self.random_elastic.image_interpolation,
            'label_interpolation': self.random_elastic.label_interpolation,
        }
        arguments = {
            'affine_first': self.affine_first,
            'affine_params': affine_params,
            'elastic_params': elastic_params,
        }
        transform = AffineElasticDeformation(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed
class AffineElasticDeformation(SpatialTransform):
    r"""Apply an Affine and ElasticDeformation simultaneously.

    Optimization to use only a single SimpleITK resampling. For additional
    details on the transformations, see
    :class:`~torchio.transforms.augmentation.Affine` and
    :class:`~torchio.transforms.augmentation.ElasticDeformation`.

    Args:
        affine_first: Apply affine before elastic deformation.
        affine_params: Keyword arguments used to instantiate
            :class:`~torchio.transforms.augmentation.Affine` (e.g.
            ``scales``, ``degrees``, ``translation``, ``center``,
            ``default_pad_value``, ``image_interpolation``,
            ``label_interpolation``, ``check_shape``).
        elastic_params: Keyword arguments used to instantiate
            :class:`~torchio.transforms.augmentation.ElasticDeformation`
            (e.g. ``control_points``, ``max_displacement``).
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments. They are also forwarded to the internal
            ``Affine`` and ``ElasticDeformation`` instances.

    Note:
        Interpolation and padding are taken from the affine parameters. The
        interpolation settings in ``elastic_params`` are not used by the
        combined resampling.
    """

    def __init__(
        self,
        affine_first: bool,
        affine_params: dict[str, Any],
        elastic_params: dict[str, Any],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.affine_first = affine_first
        self.affine_params = affine_params
        self._affine = Affine(
            **self.affine_params,
            **kwargs,
        )
        self.elastic_params = elastic_params
        self._elastic = ElasticDeformation(
            **self.elastic_params,
            **kwargs,
        )
        # Presumably consumed by the Transform base class to record this
        # transform's arguments (not visible here; verify in Transform).
        self.args_names = ['affine_first', 'affine_params', 'elastic_params']

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image once with the composed transform.

        For each image, the affine transform is computed once. For each
        channel, a B-spline transform is built and composed with the affine
        transform, and the channel is resampled in a single pass. The
        resampled channels are stacked back into the image.

        Args:
            subject: Subject whose images are transformed in place.

        Returns:
            The same ``subject`` instance.
        """
        if self._affine.check_shape:
            subject.check_consistent_spatial_shape()
        default_value: float
        for image in self.get_images(subject):
            affine_transform = self._affine.get_affine_transform(image)
            # The image type does not change across channels, so decide
            # the interpolation once per image.
            is_intensity = image[TYPE] == INTENSITY
            if is_intensity:
                interpolation = self._affine.image_interpolation
            else:
                interpolation = self._affine.label_interpolation
            transformed_tensors = []
            for tensor in image.data:
                sitk_image = nib_to_sitk(
                    tensor[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                if is_intensity:
                    default_value = self._affine.get_default_pad_value(
                        tensor,
                        sitk_image,
                    )
                else:
                    default_value = 0
                bspline_transform = self._elastic.get_bspline_transform(
                    sitk_image,
                )
                self._elastic.parse_free_form_transform(
                    bspline_transform,
                    self._elastic.max_displacement,
                )
                # stack: LIFO. SimpleITK applies the last transform in the
                # list first to each output-space point, which corresponds
                # to the first listed transform acting first on the image.
                if self.affine_first:
                    combined_transforms = [affine_transform, bspline_transform]
                else:
                    combined_transforms = [bspline_transform, affine_transform]
                composite_transform = sitk.CompositeTransform(
                    combined_transforms,
                )
                transformed_tensor = self.apply_composite_transform(
                    sitk_image,
                    composite_transform,
                    interpolation,
                    default_value,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def apply_composite_transform(
        self,
        sitk_image: sitk.Image,
        transform: sitk.Transform,
        interpolation: str,
        default_value: float,
    ) -> torch.Tensor:
        """Resample ``sitk_image`` with ``transform`` onto its own grid.

        Args:
            sitk_image: Image used both as the floating (moving) image and as
                the reference grid for the output.
            transform: Transform passed to the resampler (maps output points
                to input points, per ITK resampling conventions).
            interpolation: Interpolation name, converted with
                ``get_sitk_interpolator``.
            default_value: Value assigned to output voxels that map outside
                the input image.

        Returns:
            The resampled image as a float32 tensor.
        """
        floating = reference = sitk_image
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)
        np_array = sitk.GetArrayFromImage(resampled)
        # GetArrayFromImage returns the axes in reverse order; transpose back
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.as_tensor(np_array)
        return tensor