from __future__ import annotations
from collections.abc import Generator
import numpy as np
from ...data.subject import Subject
from ...types import TypeSpatialShape
from ...types import TypeTripletInt
from ...utils import to_tuple
from .sampler import PatchSampler
[docs]
class GridSampler(PatchSampler):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :class:`~torchio.data.GridAggregator`.

    Args:
        subject: Instance of :class:`~torchio.data.Subject`
            from which patches will be extracted.
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided,
            :math:`w = h = d = n`.
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
            specifying the overlap between patches for dense inference. If a
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`. If ``None``, the volume will not
            be padded before sampling and patches at the border will not be
            cropped by the aggregator.
            Otherwise, the volume will be padded with
            :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)`
            on each side before sampling. If the sampler is passed to a
            :class:`~torchio.data.GridAggregator`, it will crop the output
            to its original size.

    Example:

        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> sampler = tio.GridSampler(colin, patch_size=88)
        >>> for i, patch in enumerate(sampler()):
        ...     patch.t1.save(f'patch_{i}.nii.gz')
        ...
        >>> # To figure out the number of patches beforehand:
        >>> sampler = tio.GridSampler(colin, patch_size=88)
        >>> len(sampler)
        8

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :attr:`patch_overlap` is twice :attr:`border` in the NiftyNet
        tutorial.
    """

    def __init__(
        self,
        subject: Subject,
        patch_size: TypeSpatialShape,
        patch_overlap: TypeSpatialShape = (0, 0, 0),
        padding_mode: str | float | None = None,
    ):
        super().__init__(patch_size)
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode
        # The stored subject is already padded (when ``padding_mode`` is set),
        # and the stored locations refer to that padded volume.
        self.subject = self._pad(subject)
        self.locations = self._compute_locations(self.subject)

    def __len__(self) -> int:
        """Return the number of patch locations computed at construction."""
        return len(self.locations)

    def __getitem__(self, index):
        """Return the patch of the stored subject at position ``index``.

        ``index`` selects a row of ``self.locations``; the first three values
        of that row are the initial indices of the patch.
        """
        # Assume 3D
        location = self.locations[index]
        index_ini = location[:3]
        # ``crop`` is not defined in this class (presumably inherited from
        # ``PatchSampler``).
        cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
        return cropped_subject

    def __call__(
        self,
        subject: Subject | None = None,
        num_patches: int | None = None,
    ) -> Generator[Subject]:
        """Generate patches from ``subject``.

        Args:
            subject: Subject to sample from. If ``None``, the (possibly
                padded) subject stored at construction is used.
            num_patches: Forwarded unchanged to the parent ``__call__``.
        """
        subject = self.subject if subject is None else subject
        return super().__call__(subject, num_patches=num_patches)

    def _pad(self, subject: Subject) -> Subject:
        """Pad ``subject`` by half the patch overlap on each side.

        The subject is returned unchanged when ``self.padding_mode`` is
        ``None``.
        """
        if self.padding_mode is not None:
            from ...transforms import Pad

            border = self.patch_overlap // 2
            # Each border value is duplicated so that it is applied on both
            # sides of its axis: (w, w, h, h, d, d).
            padding = border.repeat(2)
            pad = Pad(padding, padding_mode=self.padding_mode)  # type: ignore[arg-type]
            subject = pad(subject)  # type: ignore[assignment]
        return subject

    def _compute_locations(self, subject: Subject) -> np.ndarray:
        """Validate the sizes and return the patch locations for ``subject``.

        Raises:
            ValueError: If the patch size, patch overlap and spatial shape of
                ``subject`` are not compatible (see :meth:`_parse_sizes`).
        """
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)  # type: ignore[arg-type]
        return self._get_patches_locations(*sizes)  # type: ignore[arg-type]

    def _generate_patches(  # type: ignore[override]
        self,
        subject: Subject,
    ) -> Generator[Subject]:
        """Yield one patch per grid location of ``subject``.

        ``self.subject`` was already padded in ``__init__``. Padding it again
        here (which is what happens when the sampler is called without
        arguments) would enlarge the volume twice and make the generated
        patches disagree with ``self.locations`` and ``len(self)``, so padding
        is only applied to subjects other than the stored one.
        """
        if subject is not self.subject:
            subject = self._pad(subject)
        locations = self._compute_locations(subject)
        for location in locations:
            index_ini = location[:3]
            # ``extract_patch`` is not defined in this class (presumably
            # inherited from ``PatchSampler``).
            yield self.extract_patch(subject, index_ini)

    @staticmethod
    def _parse_sizes(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> None:
        """Check that the three sizes are compatible with each other.

        Raises:
            ValueError: If the patch is larger than the image along any axis,
                if the overlap is not smaller than the patch size along any
                axis, or if any overlap value is odd.
        """
        image_size_array = np.array(image_size)
        patch_size_array = np.array(patch_size)
        patch_overlap_array = np.array(patch_overlap)
        if np.any(patch_size_array > image_size_array):
            message = (
                f'Patch size {tuple(patch_size_array)} cannot be'
                f' larger than image size {tuple(image_size_array)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap_array >= patch_size_array):
            message = (
                f'Patch overlap {tuple(patch_overlap_array)} must be smaller'
                f' than patch size {tuple(patch_size_array)}'
            )
            raise ValueError(message)
        # The overlap is halved to compute the padding border, hence it must
        # be even.
        if np.any(patch_overlap_array % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap_array)}'
            )
            raise ValueError(message)

    @staticmethod
    def _get_patches_locations(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> np.ndarray:
        """Compute the grid of patch locations covering the whole image.

        The sizes are expected to have been validated with
        :meth:`_parse_sizes` beforehand.

        Returns:
            Array with one row per patch, sorted lexicographically. Each row
            contains the three initial indices followed by the three final
            indices (initial index plus patch size) of the patch.
        """
        # Example with image_size 10, patch_size 5, overlap 2:
        # [0 1 2 3 4 5 6 7 8 9]
        # [0 0 0 0 0]
        #       [1 1 1 1 1]
        #           [2 2 2 2 2]
        # Locations:
        # [[0, 5],
        #  [3, 8],
        #  [5, 10]]
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Add a last patch flush with the end of the image if the regular
            # grid does not reach it.
            if indices_dim[-1] != im_size_dim - patch_size_dim:
                indices_dim.append(im_size_dim - patch_size_dim)
            indices.append(indices_dim)
        # All combinations of the per-axis start indices, one row per patch
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))