Source code for torchio.data.sampler.uniform
import torch
from ...data.subject import Subject
from ...typing import TypePatchSize
from .sampler import RandomSampler
from typing import Generator
import numpy as np
class UniformSampler(RandomSampler):
    """Randomly extract patches from a volume with uniform probability.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
    """

    def __init__(self, patch_size: TypePatchSize):
        super().__init__(patch_size)

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return a tensor of ones with shape ``(1, *subject.spatial_shape)``."""
        return torch.ones(1, *subject.spatial_shape)

    def __call__(
        self,
        subject: Subject,
        num_patches: int = None,
    ) -> Generator[Subject, None, None]:
        """Yield patches of ``subject`` taken at random start indices.

        This is a generator function, so none of the checks below run
        until the first patch is requested.

        Args:
            subject: Subject to extract patches from.
            num_patches: Number of patches to yield. If ``None``, patches
                are yielded indefinitely. Zero or a negative value yields
                no patches.

        Yields:
            The result of ``self.extract_patch`` for each random index.

        Raises:
            RuntimeError: If the patch size is larger than the image size
                along any dimension.
        """
        # Defined on Subject (not visible in this file).
        subject.check_consistent_spatial_shape()

        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)

        # Largest admissible start index along each axis.
        valid_range = subject.spatial_shape - self.patch_size

        # Count upwards instead of decrementing a truthy counter: a negative
        # (or fractional) ``num_patches`` would never reach exactly zero and
        # the generator would run forever.
        produced = 0
        while num_patches is None or produced < num_patches:
            # torch.randint's upper bound is exclusive, hence ``x + 1`` so
            # that the last valid start index can be drawn too. One call per
            # axis keeps the random stream identical for seeded runs.
            index_ini = [
                torch.randint(x + 1, (1,)).item()
                for x in valid_range
            ]
            index_ini_array = np.asarray(index_ini)
            yield self.extract_patch(subject, index_ini_array)
            produced += 1