from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from .. import RandomTransform
from ... import SpatialTransform
from ....data.subject import Subject
from ....utils import to_tuple
# [docs]
class RandomFlip(RandomTransform, SpatialTransform):
    """Randomly mirror an image along a subset of the given spatial axes.

    Each candidate axis is decided independently: one random draw per
    spatial dimension is compared against ``flip_probability``.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image might be flipped. If they are integers, they must be in
            ``(0, 1, 2)``. Anatomical labels may also be used, such as
            ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
            ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width)
            or ``'i'`` (inferior). Only the first letter of the string will
            be used. If the image is 2D, ``'Height'`` and ``'Width'`` may be
            used.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only

    .. tip:: It is handy to specify the axes as anatomical labels when the
        image orientation is not known.
    """

    def __init__(
        self,
        axes: Union[int, Tuple[int, ...]] = 0,
        flip_probability: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Validation only; anatomical labels are resolved to indices later,
        # once a subject is available.
        self.axes = _parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip ``subject`` along the randomly selected axes.

        The subject is returned as is when no axis ends up selected;
        otherwise the work is delegated to a :class:`Flip` instance built
        with the selected axes.
        """
        candidate_axes = _ensure_axes_indices(subject, self.axes)
        drawn = self.get_params(self.flip_probability)
        # Keep a dimension only if its draw succeeded AND the user listed it.
        selected = [
            axis for axis in range(3) if drawn[axis] and axis in candidate_axes
        ]
        if not selected:
            return subject
        flip = Flip(**self.add_include_exclude({'axes': selected}))
        flipped = flip(subject)
        assert isinstance(flipped, Subject)
        return flipped

    @staticmethod
    def get_params(probability: float) -> List[bool]:
        """Return three booleans, one per spatial dimension.

        Entry ``i`` is ``True`` when ``probability`` exceeds the ``i``-th of
        three samples produced by ``torch.rand``.
        """
        samples = torch.rand(3)
        return (probability > samples).tolist()
class Flip(SpatialTransform):
    """Deterministically reverse the element order of an image along axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image will be flipped. See
            :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for more information.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is handy to specify the axes as anatomical labels when the
        image orientation is not known.
    """

    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = _parse_axes(axes)
        # Names of the constructor arguments stored on the instance.
        self.args_names = ('axes',)

    def apply_transform(self, subject: Subject) -> Subject:
        """Flip, in place, every image selected by ``get_images``."""
        flip_dims = _ensure_axes_indices(subject, self.axes)
        selected_images = self.get_images(subject)
        for current in selected_images:
            _flip_image(current, flip_dims)
        return subject

    @staticmethod
    def is_invertible():
        """Report that this transform can be undone."""
        return True

    def inverse(self):
        """Return the inverse transform, which is this same instance."""
        # Flipping twice along the same axes restores the original order.
        return self
def _parse_axes(axes: Union[int, Tuple[int, ...]]):
    """Validate an axes specification and return it as a tuple.

    Every entry must be either a string (its content is not inspected here)
    or an integer in ``(0, 1, 2)``.

    Raises:
        ValueError: If an entry is neither a string nor a valid integer.
    """
    candidates = to_tuple(axes)
    for candidate in candidates:
        if isinstance(candidate, str):
            continue  # labels are resolved later, against an actual image
        if isinstance(candidate, int) and candidate in (0, 1, 2):
            continue
        message = (
            f'All axes must be 0, 1 or 2, but found "{candidate}" with type {type(candidate)}'
        )
        raise ValueError(message)
    return candidates
def _ensure_axes_indices(subject, axes):
    """Resolve an axes specification into spatial dimension indices.

    Args:
        subject: Object providing ``check_consistent_orientation()`` and
            ``get_first_image()``; only used when ``axes`` contains strings.
        axes: Iterable of integer indices and/or axis-name strings.

    Returns:
        ``axes`` itself, untouched, when it contains no strings. Otherwise a
        sorted list of unique integer indices. String entries are translated
        through the first image of ``subject`` and integer entries are kept
        as they are.
    """
    if not any(isinstance(axis, str) for axis in axes):
        return axes

    # Name lookups go through a single image, so all images of the subject
    # are required to share the same orientation first.
    subject.check_consistent_orientation()
    image = subject.get_first_image()

    # A set is used because several labels may name the same dimension, and
    # a repeated index would later be rejected by ``np.flip``.
    # NOTE(review): the offset of 3 presumably converts an index counted
    # from the end (-3..-1) into 0..2 — confirm against
    # ``axis_name_to_index``.
    indices = {
        3 + image.axis_name_to_index(axis) if isinstance(axis, str) else axis
        for axis in axes
    }
    return sorted(indices)
def _flip_image(image, axes):
    """Reverse the data of ``image`` along ``axes`` and store the result.

    Args:
        image: Object exposing ``numpy()`` and ``set_data()``.
        axes: Spatial dimension indices; each is shifted by one before being
            applied to the array returned by ``image.numpy()``.
    """
    # Shift by one to skip the leading (presumably channel) dimension.
    array_dims = np.array(axes, int) + 1
    reversed_view = np.flip(image.numpy(), axis=array_dims)
    # Copy to get rid of the negative strides of the flipped view before
    # handing the array over to torch.
    materialised = reversed_view.copy()
    image.set_data(torch.as_tensor(materialised))