Source code for torchio.transforms.augmentation.random_transform
from __future__ import annotations
from typing import Tuple
import torch
from .. import Transform
from ...typing import TypeRangeFloat
[docs]
class RandomTransform(Transform):
    """Base class for stochastic augmentation transforms.

    Collects small helpers for parsing range arguments and for drawing
    random numbers through ``torch``.

    Args:
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(self, **kwargs):
        # Nothing is added here: every keyword argument goes to the parent.
        super().__init__(**kwargs)

    def add_include_exclude(self, kwargs):
        """Store this transform's ``include``/``exclude`` values in ``kwargs``.

        Args:
            kwargs: Mapping to fill in. It is modified in place.

        Returns:
            The same ``kwargs`` object that was passed in.
        """
        for key in ('include', 'exclude'):
            kwargs[key] = getattr(self, key)
        return kwargs

    def parse_degrees(
        self,
        degrees: TypeRangeFloat,
    ) -> Tuple[float, float]:
        """Parse ``degrees`` with ``self._parse_range``.

        ``_parse_range`` is not defined in this class (presumably inherited
        from the parent transform); the label ``'degrees'`` is passed along
        as its second argument.
        """
        parsed = self._parse_range(degrees, 'degrees')
        return parsed

    def parse_translation(
        self,
        translation: TypeRangeFloat,
    ) -> Tuple[float, float]:
        """Parse ``translation`` with ``self._parse_range``.

        Same delegation as :meth:`parse_degrees`, using the label
        ``'translation'``.
        """
        parsed = self._parse_range(translation, 'translation')
        return parsed

    @staticmethod
    def sample_uniform(a: float, b: float) -> float:
        """Draw a single number uniformly between ``a`` and ``b``.

        Args:
            a: First bound handed to ``Tensor.uniform_``.
            b: Second bound handed to ``Tensor.uniform_``.

        Returns:
            The sampled value as a Python number.
        """
        # One-element tensor; its initial content is overwritten in place.
        draw = torch.FloatTensor(1)
        draw.uniform_(a, b)
        return draw.item()

    @staticmethod
    def _get_random_seed() -> int:
        """Generate a random seed.

        Returns:
            A random seed as an int, drawn with ``torch.randint`` using
            ``0`` as the lower bound and ``2**31`` as the upper bound.
        """
        upper = 2**31
        seed = torch.randint(0, upper, (1,))
        return int(seed.item())

    def sample_uniform_sextet(self, params):
        """Sample one uniform value per consecutive pair in ``params``.

        Even-indexed entries are used as the first bound and odd-indexed
        entries as the second bound of each draw. A trailing entry without
        a partner is ignored, because ``zip`` stops at the shorter slice.

        Args:
            params: Sliceable sequence of bounds laid out as
                ``(a0, b0, a1, b1, ...)``.

        Returns:
            A ``torch.Tensor`` holding one sample per pair.
        """
        lows = params[::2]
        highs = params[1::2]
        samples = [self.sample_uniform(lo, hi) for lo, hi in zip(lows, highs)]
        return torch.Tensor(samples)