Source code for torchio.data.sampler.sampler

from typing import Optional, Generator

import numpy as np

from ...typing import TypePatchSize, TypeTripletInt
from ...data.subject import Subject
from ...utils import to_tuple


[docs]class PatchSampler: r"""Base class for TorchIO samplers. Args: patch_size: Tuple of integers :math:`(w, h, d)` to generate patches of size :math:`w \times h \times d`. If a single number :math:`n` is provided, :math:`w = h = d = n`. .. warning:: This is an abstract class that should only be instantiated using child classes such as :class:`~torchio.data.UniformSampler` and :class:`~torchio.data.WeightedSampler`. """ def __init__(self, patch_size: TypePatchSize): patch_size_array = np.array(to_tuple(patch_size, length=3)) for n in patch_size_array: if n < 1 or not isinstance(n, (int, np.integer)): message = ( 'Patch dimensions must be positive integers,' f' not {patch_size_array}' ) raise ValueError(message) self.patch_size = patch_size_array.astype(np.uint16) def extract_patch( self, subject: Subject, index_ini: TypeTripletInt, ) -> Subject: cropped_subject = self.crop(subject, index_ini, self.patch_size) cropped_subject['index_ini'] = np.array(index_ini).astype(int) return cropped_subject def crop( self, subject: Subject, index_ini: TypeTripletInt, patch_size: TypeTripletInt, ) -> Subject: transform = self.get_crop_transform(subject, index_ini, patch_size) return transform(subject) @staticmethod def get_crop_transform( subject, index_ini, patch_size: TypePatchSize, ): from ...transforms.preprocessing.spatial.crop import Crop shape = np.array(subject.spatial_shape, dtype=np.uint16) index_ini = np.array(index_ini, dtype=np.uint16) patch_size = np.array(patch_size, dtype=np.uint16) index_fin = index_ini + patch_size crop_ini = index_ini.tolist() crop_fin = (shape - index_fin).tolist() start = () cropping = sum(zip(crop_ini, crop_fin), start) return Crop(cropping)
class RandomSampler(PatchSampler): r"""Base class for random samplers. Args: patch_size: Tuple of integers :math:`(w, h, d)` to generate patches of size :math:`w \times h \times d`. If a single number :math:`n` is provided, :math:`w = h = d = n`. """ def __call__( self, subject: Subject, num_patches: Optional[int] = None, ) -> Generator[Subject, None, None]: raise NotImplementedError def get_probability_map(self, subject: Subject): raise NotImplementedError