Source code for torchio.transforms.augmentation.spatial.random_flip

from typing import Union

import numpy as np
import torch

from ....data.subject import Subject
from ....utils import to_tuple
from ...spatial_transform import SpatialTransform
from .. import RandomTransform


[docs] class RandomFlip(RandomTransform, SpatialTransform): """Reverse the order of elements in an image along the given axes. Args: axes: Index or tuple of indices of the spatial dimensions along which the image might be flipped. If they are integers, they must be in ``(0, 1, 2)``. Anatomical labels may also be used, such as ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``, ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``, ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or ``'i'`` (inferior). Only the first letter of the string will be used. If the image is 2D, ``'Height'`` and ``'Width'`` may be used. flip_probability: Probability that the image will be flipped. This is computed on a per-axis basis. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. Example: >>> import torchio as tio >>> fpg = tio.datasets.FPG() >>> flip = tio.RandomFlip(axes=('LR',)) # flip along lateral axis only .. tip:: It is handy to specify the axes as anatomical labels when the image orientation is not known. """ def __init__( self, axes: Union[int, tuple[int, ...]] = 0, flip_probability: float = 0.5, **kwargs, ): super().__init__(**kwargs) self.axes = _parse_axes(axes) self.flip_probability = self.parse_probability(flip_probability) def apply_transform(self, subject: Subject) -> Subject: potential_axes = _ensure_axes_indices(subject, self.axes) axes_to_flip_hot = self.get_params(self.flip_probability) for i in range(3): if i not in potential_axes: axes_to_flip_hot[i] = False (axes,) = np.where(axes_to_flip_hot) axes_list = axes.tolist() if not axes_list: return subject arguments = {'axes': axes_list} transform = Flip(**self.add_base_args(arguments)) transformed = transform(subject) assert isinstance(transformed, Subject) return transformed @staticmethod def get_params(probability: float) -> list[bool]: return (probability > torch.rand(3)).tolist()
class Flip(SpatialTransform): """Reverse the order of elements in an image along the given axes. Args: axes: Index or tuple of indices of the spatial dimensions along which the image will be flipped. See :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip` for more information. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. .. tip:: It is handy to specify the axes as anatomical labels when the image orientation is not known. """ def __init__(self, axes, **kwargs): super().__init__(**kwargs) self.axes = _parse_axes(axes) self.args_names = ('axes',) def apply_transform(self, subject: Subject) -> Subject: axes = _ensure_axes_indices(subject, self.axes) for image in self.get_images(subject): _flip_image(image, axes) return subject @staticmethod def is_invertible(): return True def inverse(self): return self def _parse_axes(axes: Union[int, tuple[int, ...]]): axes_tuple = to_tuple(axes) for axis in axes_tuple: is_int = isinstance(axis, int) is_string = isinstance(axis, str) valid_number = is_int and axis in (0, 1, 2) if not is_string and not valid_number: message = ( f'All axes must be 0, 1 or 2, but found "{axis}" with type {type(axis)}' ) raise ValueError(message) return axes_tuple def _ensure_axes_indices(subject, axes): if any(isinstance(n, str) for n in axes): subject.check_consistent_orientation() image = subject.get_first_image() axes = sorted(3 + image.axis_name_to_index(n) for n in axes) return axes def _flip_image(image, axes): spatial_axes = np.array(axes, int) + 1 data = image.numpy() data = np.flip(data, axis=spatial_axes) data = data.copy() # remove negative strides data = torch.as_tensor(data) image.set_data(data)