# Source code for torchio.data.sampler.grid

from __future__ import annotations

from collections.abc import Generator

import numpy as np

from ...data.subject import Subject
from ...types import TypeSpatialShape
from ...types import TypeTripletInt
from ...utils import to_tuple
from .sampler import PatchSampler


class GridSampler(PatchSampler):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :class:`~torchio.data.GridAggregator`.

    Args:
        subject: Instance of :class:`~torchio.data.Subject`
            from which patches will be extracted.
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided,
            :math:`w = h = d = n`.
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
            specifying the overlap between patches for dense inference. If a
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`. If ``None``, the volume will not
            be padded before sampling and patches at the border will not be
            cropped by the aggregator.
            Otherwise, the volume will be padded with
            :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)`
            on each side before sampling. If the sampler is passed to a
            :class:`~torchio.data.GridAggregator`, it will crop the output
            to its original size.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> sampler = tio.GridSampler(colin, patch_size=88)
        >>> for i, patch in enumerate(sampler()):
        ...     patch.t1.save(f'patch_{i}.nii.gz')
        ...
        >>> # To figure out the number of patches beforehand:
        >>> sampler = tio.GridSampler(colin, patch_size=88)
        >>> len(sampler)
        8

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :attr:`patch_overlap` is twice :attr:`border` in the NiftyNet
        tutorial.
    """

    def __init__(
        self,
        subject: Subject,
        patch_size: TypeSpatialShape,
        patch_overlap: TypeSpatialShape = (0, 0, 0),
        padding_mode: str | float | None = None,
    ):
        super().__init__(patch_size)
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode
        # The stored subject is padded exactly once, here. Everything that
        # reads self.subject later (__getitem__, __call__) relies on it
        # already being padded.
        self.subject = self._pad(subject)
        self.locations = self._compute_locations(self.subject)

    def __len__(self) -> int:
        """Return the number of grid locations of the stored subject."""
        return len(self.locations)

    def __getitem__(self, index):
        """Return the patch of the stored (padded) subject at ``index``."""
        # Assume 3D: the first three values of a location are the start index
        location = self.locations[index]
        index_ini = location[:3]
        return self.crop(self.subject, index_ini, self.patch_size)

    def __call__(
        self,
        subject: Subject | None = None,
        num_patches: int | None = None,
    ) -> Generator[Subject]:
        """Yield grid patches from ``subject`` or, if ``None``, the stored one.

        A subject passed here is padded as configured. The stored subject is
        used as-is, because it was already padded in ``__init__``.
        """
        subject = self.subject if subject is None else subject
        return super().__call__(subject, num_patches=num_patches)

    def _pad(self, subject: Subject) -> Subject:
        """Pad ``subject`` by half the overlap on each side, if configured.

        Returns ``subject`` unchanged when :attr:`padding_mode` is ``None``.
        """
        if self.padding_mode is not None:
            from ...transforms import Pad

            border = self.patch_overlap // 2
            # Pad expects (ini, fin) per axis: [w, w, h, h, d, d]
            padding = border.repeat(2)
            pad = Pad(padding, padding_mode=self.padding_mode)  # type: ignore[arg-type]
            subject = pad(subject)  # type: ignore[assignment]
        return subject

    def _compute_locations(self, subject: Subject):
        """Validate sizes and return the ``(N, 6)`` grid of patch bounds."""
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)  # type: ignore[arg-type]
        return self._get_patches_locations(*sizes)  # type: ignore[arg-type]

    def _generate_patches(  # type: ignore[override]
        self,
        subject: Subject,
    ) -> Generator[Subject]:
        """Yield one patch per grid location of ``subject``.

        ``subject`` is padded first, unless it is the stored
        :attr:`subject`, which ``__init__`` has already padded. Padding it a
        second time would shift every patch by the border. It could also
        change the number of patches relative to ``len(self)`` and
        ``self[index]``.
        """
        if subject is self.subject:
            padded = subject
        else:
            padded = self._pad(subject)
        for location in self._compute_locations(padded):
            index_ini = location[:3]
            yield self.extract_patch(padded, index_ini)

    @staticmethod
    def _parse_sizes(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> None:
        """Check that the patch fits the image and the overlap is valid.

        Raises:
            ValueError: If the patch is larger than the image in some axis. Also
                if the overlap is not smaller than the patch, or if the
                overlap is not even.
        """
        image_size_array = np.array(image_size)
        patch_size_array = np.array(patch_size)
        patch_overlap_array = np.array(patch_overlap)
        if np.any(patch_size_array > image_size_array):
            message = (
                f'Patch size {tuple(patch_size_array)} cannot be'
                f' larger than image size {tuple(image_size_array)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap_array >= patch_size_array):
            message = (
                f'Patch overlap {tuple(patch_overlap_array)} must be smaller'
                f' than patch size {tuple(patch_size_array)}'
            )
            raise ValueError(message)
        # Odd overlaps cannot be split evenly into a border on each side
        if np.any(patch_overlap_array % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap_array)}'
            )
            raise ValueError(message)

    @staticmethod
    def _get_patches_locations(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> np.ndarray:
        """Return sorted rows ``[i_ini, j_ini, k_ini, i_fin, j_fin, k_fin]``.

        Starts are spaced by ``patch_size - patch_overlap``. The last start in
        each axis is always ``image_size - patch_size``, so that the whole
        image is covered.
        """
        # Example with image_size 10, patch_size 5, overlap 2:
        # [0 1 2 3 4 5 6 7 8 9]
        # [0 0 0 0 0]
        #       [1 1 1 1 1]
        #           [2 2 2 2 2]
        # Locations:
        # [[0, 5],
        #  [3, 8],
        #  [5, 10]]
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Make sure the last patch touches the end of the image
            if indices_dim[-1] != im_size_dim - patch_size_dim:
                indices_dim.append(im_size_dim - patch_size_dim)
            indices.append(indices_dim)
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))