Source code for torchio.transforms.augmentation.random_transform
from __future__ import annotations
import torch
from ...types import TypeRangeFloat
from ...types import TypeSextetFloat
from ...types import TypeTripletFloat
from ..transform import Transform
[docs]
class RandomTransform(Transform):
"""Base class for stochastic augmentation transforms.
Args:
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def parse_degrees(
self,
degrees: TypeRangeFloat,
) -> tuple[float, float]:
return self._parse_range(degrees, 'degrees')
def parse_translation(
self,
translation: TypeRangeFloat,
) -> tuple[float, float]:
return self._parse_range(translation, 'translation')
@staticmethod
def sample_uniform(a: float, b: float) -> float:
return torch.FloatTensor(1).uniform_(a, b).item()
@staticmethod
def _get_random_seed() -> int:
"""Generate a random seed.
Returns:
A random seed as an int.
"""
return int(torch.randint(0, 2**31, (1,)).item())
@staticmethod
def sample_uniform_sextet(params: TypeSextetFloat) -> TypeTripletFloat:
results = []
for a, b in zip(params[::2], params[1::2]):
results.append(RandomTransform.sample_uniform(a, b))
sx, sy, sz = results
return sx, sy, sz