class PatchSampler:
    r"""Base class for TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.

    .. warning:: This is an abstract class that should only be instantiated
        using child classes such as :class:`~torchio.data.UniformSampler` and
        :class:`~torchio.data.WeightedSampler`.
    """

    def __init__(self, patch_size: TypeSpatialShape):
        patch_size_array = np.array(to_tuple(patch_size, length=3))
        for n in patch_size_array:
            # Check the type first so that non-numeric entries produce the
            # intended ValueError instead of a TypeError from ``n < 1``.
            if not isinstance(n, (int, np.integer)) or n < 1:
                message = (
                    'Patch dimensions must be positive integers,'
                    f' not {patch_size_array}'
                )
                raise ValueError(message)
        self.patch_size = patch_size_array.astype(np.uint16)

    def extract_patch(
        self,
        subject: Subject,
        index_ini: TypeTripletInt,
    ) -> Subject:
        """Crop a patch of size ``self.patch_size`` starting at ``index_ini``.

        Args:
            subject: Subject to crop.
            index_ini: Spatial index of the first voxel of the patch.

        Returns:
            The cropped subject, as returned by :meth:`crop`.
        """
        cropped_subject = self.crop(subject, index_ini, self.patch_size)  # type: ignore[arg-type]  # noqa: B950
        return cropped_subject

    def crop(
        self,
        subject: Subject,
        index_ini: TypeTripletInt,
        patch_size: TypeTripletInt,
    ) -> Subject:
        """Crop ``subject`` and record where the patch came from.

        Args:
            subject: Subject to crop.
            index_ini: Spatial index of the first voxel of the patch.
            patch_size: Spatial size of the patch.

        Returns:
            The cropped subject. Its ``LOCATION`` entry holds a tensor with
            the start indices followed by the (exclusive) end indices.

        Raises:
            ValueError: If the patch does not fit inside the subject.
        """
        transform = self._get_crop_transform(subject, index_ini, patch_size)
        cropped_subject = transform(subject)
        # Signed 64-bit arithmetic so that unsigned inputs (e.g. uint16
        # patch sizes) cannot wrap around when computing the end index.
        index_ini_array = np.asarray(index_ini, dtype=np.int64)
        patch_size_array = np.asarray(patch_size, dtype=np.int64)
        index_fin = index_ini_array + patch_size_array
        location = index_ini_array.tolist() + index_fin.tolist()
        cropped_subject[LOCATION] = torch.as_tensor(location)
        cropped_subject.update_attributes()
        return cropped_subject

    @staticmethod
    def _get_crop_transform(
        subject,
        index_ini: TypeTripletInt,
        patch_size: TypeSpatialShape,
    ):
        """Build a ``Crop`` transform that keeps only the requested patch.

        Args:
            subject: Object exposing ``spatial_shape``.
            index_ini: Spatial index of the first voxel of the patch.
            patch_size: Spatial size of the patch.

        Returns:
            A ``Crop`` transform whose bounds are
            ``(ini_0, fin_0, ini_1, fin_1, ini_2, fin_2)``, i.e. the number of
            voxels to remove before and after the patch along each axis.

        Raises:
            ValueError: If the patch starts at a negative index or extends
                beyond the subject's spatial shape.
        """
        from ...transforms.preprocessing.spatial.crop import Crop

        # Use a signed 64-bit dtype: uint16 cannot represent sizes above
        # 65535, lets ``index_ini + patch_size`` wrap around, and turns a
        # negative "after" crop amount into a huge positive one.
        shape = np.array(subject.spatial_shape, dtype=np.int64)
        index_ini_array = np.array(index_ini, dtype=np.int64)
        patch_size_array = np.array(patch_size, dtype=np.int64)
        assert len(index_ini_array) == 3
        assert len(patch_size_array) == 3
        index_fin = index_ini_array + patch_size_array
        if np.any(index_ini_array < 0) or np.any(index_fin > shape):
            message = (
                f'Patch with start index {tuple(index_ini_array.tolist())}'
                f' and size {tuple(patch_size_array.tolist())} does not fit'
                f' in spatial shape {tuple(shape.tolist())}'
            )
            raise ValueError(message)
        crop_ini = index_ini_array.tolist()
        crop_fin = (shape - index_fin).tolist()
        # Interleave into (ini_0, fin_0, ini_1, fin_1, ini_2, fin_2).
        cropping = tuple(
            value for pair in zip(crop_ini, crop_fin) for value in pair
        )
        return Crop(cropping)  # type: ignore[arg-type]

    def __call__(
        self,
        subject: Subject,
        num_patches: Optional[int] = None,
    ) -> Generator[Subject, None, None]:
        """Return a generator of patches extracted from ``subject``.

        Args:
            subject: Subject to sample from.
            num_patches: Forwarded to :meth:`_generate_patches` only when
                not ``None``.

        Raises:
            RuntimeError: If the patch size exceeds the subject's spatial
                shape along any axis.
        """
        subject.check_consistent_space()
        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)
        kwargs = {} if num_patches is None else {'num_patches': num_patches}
        return self._generate_patches(subject, **kwargs)

    def _generate_patches(
        self,
        subject: Subject,
        num_patches: Optional[int] = None,
    ) -> Generator[Subject, None, None]:
        """Yield patches from ``subject``; must be implemented by subclasses."""
        raise NotImplementedError
class RandomSampler(PatchSampler):
    r"""Abstract parent shared by the samplers that pick patches at random.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.
    """

    def get_probability_map(self, subject: Subject):
        """Abstract hook that concrete random samplers must override.

        Judging by its name, it presumably returns a sampling probability map
        for ``subject``; the contract is not defined here. TODO: confirm
        against the subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()