Source code for torchio.transforms.augmentation.random_transform
"""
This is the docstring of random transform module
"""
from typing import Tuple
import torch
from ...typing import TypeRangeFloat
from .. import Transform
[docs]class RandomTransform(Transform):
"""Base class for stochastic augmentation transforms.
Args:
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
"""
def __init__(
self,
**kwargs
):
super().__init__(**kwargs)
def add_include_exclude(self, kwargs):
kwargs['include'] = self.include
kwargs['exclude'] = self.exclude
return kwargs
def parse_degrees(
self,
degrees: TypeRangeFloat,
) -> Tuple[float, float]:
return self._parse_range(degrees, 'degrees')
def parse_translation(
self,
translation: TypeRangeFloat,
) -> Tuple[float, float]:
return self._parse_range(translation, 'translation')
@staticmethod
def sample_uniform(a, b):
return torch.FloatTensor(1).uniform_(a, b)
@staticmethod
def _get_random_seed():
"""Generate a random seed.
Returns:
A random seed as an int.
"""
return torch.randint(0, 2**31, (1,)).item()
def sample_uniform_sextet(self, params):
results = []
for (a, b) in zip(params[::2], params[1::2]):
results.append(self.sample_uniform(a, b))
return torch.Tensor(results)