Source code for torchio.data.sampler.grid

from collections.abc import Generator
from typing import Optional
from typing import Union

import numpy as np

from ...data.subject import Subject
from ...typing import TypeSpatialShape
from ...typing import TypeTripletInt
from ...utils import to_tuple
from .sampler import PatchSampler


[docs] class GridSampler(PatchSampler): r"""Extract patches across a whole volume. Grid samplers are useful to perform inference using all patches from a volume. It is often used with a :class:`~torchio.data.GridAggregator`. Args: subject: Instance of :class:`~torchio.data.Subject` from which patches will be extracted. patch_size: Tuple of integers :math:`(w, h, d)` to generate patches of size :math:`w \times h \times d`. If a single number :math:`n` is provided, :math:`w = h = d = n`. patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)` specifying the overlap between patches for dense inference. If a single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`. padding_mode: Same as :attr:`padding_mode` in :class:`~torchio.transforms.Pad`. If ``None``, the volume will not be padded before sampling and patches at the border will not be cropped by the aggregator. Otherwise, the volume will be padded with :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)` on each side before sampling. If the sampler is passed to a :class:`~torchio.data.GridAggregator`, it will crop the output to its original size. Example: >>> import torchio as tio >>> colin = tio.datasets.Colin27() >>> sampler = tio.GridSampler(colin, patch_size=88) >>> for i, patch in enumerate(sampler()): ... patch.t1.save(f'patch_{i}.nii.gz') ... >>> # To figure out the number of patches beforehand: >>> sampler = tio.GridSampler(colin, patch_size=88) >>> len(sampler) 8 .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more information about patch based sampling. Note that :attr:`patch_overlap` is twice :attr:`border` in NiftyNet tutorial. """ def __init__( self, subject: Subject, patch_size: TypeSpatialShape, patch_overlap: TypeSpatialShape = (0, 0, 0), padding_mode: Union[str, float, None] = None, ): super().__init__(patch_size) self.patch_overlap = np.array(to_tuple(patch_overlap, length=3)) self.padding_mode = padding_mode self.subject = self._pad(subject) self.locations = self._compute_locations(self.subject) def __len__(self): return len(self.locations) def __getitem__(self, index): # Assume 3D location = self.locations[index] index_ini = location[:3] cropped_subject = self.crop(self.subject, index_ini, self.patch_size) return cropped_subject def __call__( self, subject: Optional[Subject] = None, num_patches: Optional[int] = None, ) -> Generator[Subject, None, None]: subject = self.subject if subject is None else subject return super().__call__(subject, num_patches=num_patches) def _pad(self, subject: Subject) -> Subject: if self.padding_mode is not None: from ...transforms import Pad border = self.patch_overlap // 2 padding = border.repeat(2) pad = Pad(padding, padding_mode=self.padding_mode) # type: ignore[arg-type] subject = pad(subject) # type: ignore[assignment] return subject def _compute_locations(self, subject: Subject): sizes = subject.spatial_shape, self.patch_size, self.patch_overlap self._parse_sizes(*sizes) # type: ignore[arg-type] return self._get_patches_locations(*sizes) # type: ignore[arg-type] def _generate_patches( # type: ignore[override] self, subject: Subject, ) -> Generator[Subject, None, None]: subject = self._pad(subject) sizes = subject.spatial_shape, self.patch_size, self.patch_overlap self._parse_sizes(*sizes) # type: ignore[arg-type] locations = self._get_patches_locations(*sizes) # type: ignore[arg-type] for location in locations: index_ini = location[:3] yield self.extract_patch(subject, index_ini) @staticmethod def _parse_sizes( image_size: TypeTripletInt, patch_size: TypeTripletInt, patch_overlap: TypeTripletInt, ) -> None: image_size_array = np.array(image_size) patch_size_array = np.array(patch_size) patch_overlap_array = np.array(patch_overlap) if np.any(patch_size_array > image_size_array): message = ( f'Patch size {tuple(patch_size_array)} cannot be' f' larger than image size {tuple(image_size_array)}' ) raise ValueError(message) if np.any(patch_overlap_array >= patch_size_array): message = ( f'Patch overlap {tuple(patch_overlap_array)} must be smaller' f' than patch size {tuple(patch_size_array)}' ) raise ValueError(message) if np.any(patch_overlap_array % 2): message = ( 'Patch overlap must be a tuple of even integers,' f' not {tuple(patch_overlap_array)}' ) raise ValueError(message) @staticmethod def _get_patches_locations( image_size: TypeTripletInt, patch_size: TypeTripletInt, patch_overlap: TypeTripletInt, ) -> np.ndarray: # Example with image_size 10, patch_size 5, overlap 2: # [0 1 2 3 4 5 6 7 8 9] # [0 0 0 0 0] # [1 1 1 1 1] # [2 2 2 2 2] # Locations: # [[0, 5], # [3, 8], # [5, 10]] indices = [] zipped = zip(image_size, patch_size, patch_overlap) for im_size_dim, patch_size_dim, patch_overlap_dim in zipped: end = im_size_dim + 1 - patch_size_dim step = patch_size_dim - patch_overlap_dim indices_dim = list(range(0, end, step)) if indices_dim[-1] != im_size_dim - patch_size_dim: indices_dim.append(im_size_dim - patch_size_dim) indices.append(indices_dim) indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T indices_ini = np.unique(indices_ini, axis=0) indices_fin = indices_ini + np.array(patch_size) locations = np.hstack((indices_ini, indices_fin)) return np.array(sorted(locations.tolist()))