from __future__ import annotations
import warnings
from typing import Dict
from typing import Sequence
from typing import Union
import numpy as np
import torch
from . import RandomTransform
from .. import Transform
from ...data.subject import Subject
# Accepted forms of the ``transforms`` argument of ``OneOf``: either a
# mapping from transform to (unnormalized) probability, or a plain sequence
# of transforms that will be given equal probabilities.
TypeTransformsDict = Union[
    Dict[Transform, float],
    Sequence[Transform],
]
class Compose(Transform):
    """Compose several transforms together.

    The transforms are applied sequentially, in the order in which they were
    given; the output of each one is passed as input to the next.

    Args:
        transforms: Sequence of callables, typically instances of
            :class:`~torchio.transforms.Transform`. The argument is
            materialized into a list once, so one-shot iterables such as
            generators are also accepted.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        TypeError: If any of the given objects is not callable.
    """

    def __init__(self, transforms: Sequence[Transform], **kwargs):
        super().__init__(parse_input=False, **kwargs)
        # Materialize before validating: iterating a generator to validate it
        # and then copying it would leave ``self.transforms`` empty.
        transforms = list(transforms)
        for transform in transforms:
            if not callable(transform):
                message = (
                    'One or more of the objects passed to the Compose'
                    f' transform are not callable: "{transform}"'
                )
                raise TypeError(message)
        self.transforms = transforms

    def __len__(self):
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        return self.transforms[index]

    def __repr__(self) -> str:
        return f'{self.name}({self.transforms})'

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply every transform in order, chaining their outputs."""
        for transform in self.transforms:
            subject = transform(subject)  # type: ignore[assignment]
        return subject

    @staticmethod
    def _transform_is_invertible(transform) -> bool:
        """Return whether a single composed element can be inverted.

        ``__init__`` accepts any callable, and plain callables may not have
        an ``is_invertible`` method. Such elements are treated as not
        invertible instead of raising ``AttributeError``.
        """
        is_invertible = getattr(transform, 'is_invertible', None)
        return callable(is_invertible) and bool(is_invertible())

    def is_invertible(self) -> bool:
        """Return ``True`` only if every composed transform is invertible."""
        return all(self._transform_is_invertible(t) for t in self.transforms)

    def inverse(self, warn: bool = True) -> Compose:
        """Return a composed transform with inverted order and transforms.

        Non-invertible elements are skipped.

        Args:
            warn: Issue a :class:`RuntimeWarning` for each skipped element,
                and another one if no element could be inverted.

        Returns:
            A new :class:`Compose` containing the inverses of the invertible
            transforms, in reverse order.
        """
        transforms = []
        for transform in self.transforms:
            if self._transform_is_invertible(transform):
                transforms.append(transform.inverse())
            elif warn:
                # Plain callables have no ``name`` attribute; fall back to repr
                name = getattr(transform, 'name', repr(transform))
                message = f'Skipping {name} as it is not invertible'
                warnings.warn(message, RuntimeWarning, stacklevel=2)
        # Undo the transforms in the opposite order in which they were applied
        transforms.reverse()
        result = Compose(transforms)
        if not transforms and warn:
            warnings.warn(
                'No invertible transforms found',
                RuntimeWarning,
                stacklevel=2,
            )
        return result
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :class:`~torchio.transforms.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``transforms`` is empty, has no length, contains
            objects that are not transforms, or (for a dictionary) has
            negative, non-finite or all-zero probabilities.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.OneOf(transforms_dict)
        >>> transformed = transform(colin)
    """

    def __init__(self, transforms: TypeTransformsDict, **kwargs):
        super().__init__(parse_input=False, **kwargs)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample one transform according to its weight and apply it."""
        # torch.multinomial draws indices with probability proportional to
        # the weights, so the weights do not need to sum exactly to one.
        weights = torch.Tensor(list(self.transforms_dict.values()))
        index = int(torch.multinomial(weights, 1).item())
        transform = list(self.transforms_dict)[index]
        transformed = transform(subject)
        return transformed  # type: ignore[return-value]

    def _get_transforms_dict(
        self,
        transforms: TypeTransformsDict,
    ) -> Dict[Transform, float]:
        """Build a ``{transform: probability}`` dict from the user input.

        Raises:
            ValueError: If the input is empty, has no length, or contains
                objects that are not instances of ``Transform``.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            # Without this check an empty sequence raised ZeroDivisionError,
            # unlike an empty dict, which raises ValueError.
            if num_transforms == 0:
                raise ValueError('At least one transform must be given')
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
        transforms_dict: Dict[Transform, float],
    ) -> None:
        """Normalize, in place, the values of ``transforms_dict`` to sum to 1.

        Raises:
            ValueError: If any probability is negative or not finite, or if
                all of them are zero.
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                f'Probabilities must be greater or equal to zero, not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        # NaN/inf would otherwise pass the checks above and only fail later,
        # inside torch.multinomial.
        if not np.all(np.isfinite(probabilities)):
            message = (
                f'Probabilities must be finite numbers, not "{probabilities}"'
            )
            raise ValueError(message)
        total = probabilities.sum()
        # Reassigning values of existing keys is safe while iterating
        for transform, probability in transforms_dict.items():
            transforms_dict[transform] = probability / total