# Source code for torchio.data.sampler.grid

```
from typing import Generator
from typing import Optional
from typing import Union
import numpy as np
from ...data.subject import Subject
from ...typing import TypeSpatialShape
from ...typing import TypeTripletInt
from ...utils import to_tuple
from .sampler import PatchSampler
class GridSampler(PatchSampler):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :class:`~torchio.data.GridAggregator`.

    Args:
        subject: Instance of :class:`~torchio.data.Subject`
            from which patches will be extracted. This argument should only be
            used before instantiating a :class:`~torchio.data.GridAggregator`,
            or to precompute the number of patches that would be generated from
            a subject. If ``None``, no subject is stored: patches can still be
            generated by calling the sampler on a subject (see the example
            below), but ``len()`` and indexing raise :class:`TypeError`.
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided,
            :math:`w = h = d = n`.
            This argument is mandatory (it is a keyword argument for backward
            compatibility).
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
            specifying the overlap between patches for dense inference. If a
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`. If ``None``, the volume will not
            be padded before sampling and patches at the border will not be
            cropped by the aggregator.
            Otherwise, the volume will be padded with
            :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)`
            on each side before sampling. If the sampler is passed to a
            :class:`~torchio.data.GridAggregator`, it will crop the output
            to its original size.

    Raises:
        ValueError: If ``patch_size`` is not given, if ``subject`` is neither
            ``None`` nor a :class:`~torchio.data.Subject`, or if a subject is
            given and the image, patch and overlap sizes are inconsistent.

    Example:
        >>> import torchio as tio
        >>> sampler = tio.GridSampler(patch_size=88)
        >>> colin = tio.datasets.Colin27()
        >>> for i, patch in enumerate(sampler(colin)):
        ...     patch.t1.save(f'patch_{i}.nii.gz')
        ...
        >>> # To figure out the number of patches beforehand:
        >>> sampler = tio.GridSampler(subject=colin, patch_size=88)
        >>> len(sampler)
        8

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :attr:`patch_overlap` is twice :attr:`border` in the NiftyNet
        tutorial.
    """

    def __init__(
        self,
        subject: Optional[Subject] = None,
        patch_size: TypeSpatialShape = None,
        patch_overlap: TypeSpatialShape = (0, 0, 0),
        padding_mode: Union[str, float, None] = None,
    ):
        if patch_size is None:
            raise ValueError('A value for patch_size must be given')
        super().__init__(patch_size)
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode
        if subject is not None and not isinstance(subject, Subject):
            raise ValueError('The subject argument must be None or Subject')
        # The subject is optional (see the class docstring example), so only
        # pad it when one was actually given; _pad cannot handle None when a
        # padding mode is set. _compute_locations returns None for None.
        self.subject: Optional[Subject] = (
            None if subject is None else self._pad(subject)
        )
        self.locations = self._compute_locations(self.subject)

    def __len__(self) -> int:
        """Return the number of patches precomputed for the stored subject."""
        return len(self._require_locations())

    def __getitem__(self, index):
        """Return the patch at ``index`` cropped from the stored subject."""
        # Assume 3D: the first three entries of a location are the start
        # indices of the patch, the last three are its end indices.
        location = self._require_locations()[index]
        index_ini = location[:3]
        cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
        return cropped_subject

    def _require_locations(self) -> np.ndarray:
        """Return the precomputed locations or fail if there are none.

        Raises:
            TypeError: If the sampler was instantiated without a subject, so
                no locations were computed. ``TypeError`` is what ``len()`` or
                subscripting ``None`` would raise anyway; this only improves
                the message.
        """
        if self.locations is None:
            raise TypeError(
                'The sampler was instantiated without a subject, so the patch'
                ' locations are unknown. Pass a subject to the constructor to'
                ' use len() or indexing'
            )
        return self.locations

    def _pad(self, subject: Subject) -> Subject:
        """Pad ``subject`` by half the patch overlap on each side.

        The subject is returned unchanged if :attr:`padding_mode` is ``None``.
        """
        if self.padding_mode is not None:
            # Imported locally, presumably to avoid a circular import between
            # the data and transforms subpackages.
            from ...transforms import Pad

            border = self.patch_overlap // 2
            # repeat(2) maps [bw, bh, bd] to [bw, bw, bh, bh, bd, bd]: the
            # same border at both ends of each axis, as the class docstring
            # describes.
            padding = border.repeat(2)
            pad = Pad(padding, padding_mode=self.padding_mode)  # type: ignore[arg-type]  # noqa: E501
            subject = pad(subject)  # type: ignore[assignment]
        return subject

    def _compute_locations(
        self,
        subject: Optional[Subject],
    ) -> Optional[np.ndarray]:
        """Validate the sizes and compute the patch locations for ``subject``.

        Returns:
            ``None`` if ``subject`` is ``None``. Otherwise, an array with one
            row of start and end indices per patch (see
            :meth:`_get_patches_locations`).

        Raises:
            ValueError: If the sizes are inconsistent (see
                :meth:`_parse_sizes`).
        """
        if subject is None:
            return None
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)  # type: ignore[arg-type]
        return self._get_patches_locations(*sizes)  # type: ignore[arg-type]

    def _generate_patches(  # type: ignore[override]
        self,
        subject: Subject,
    ) -> Generator[Subject, None, None]:
        """Pad ``subject`` and yield every patch of the grid, in order."""
        subject = self._pad(subject)
        # Same validation and location computation as in __init__.
        locations = self._compute_locations(subject)
        for location in locations:  # type: ignore[union-attr]
            index_ini = location[:3]
            yield self.extract_patch(subject, index_ini)

    @staticmethod
    def _parse_sizes(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> None:
        """Check that the given sizes define a valid grid.

        Raises:
            ValueError: If the patch is larger than the image along any axis,
                if the overlap is not smaller than the patch size along every
                axis, or if any overlap value is odd.
        """
        image_size_array = np.array(image_size)
        patch_size_array = np.array(patch_size)
        patch_overlap_array = np.array(patch_overlap)
        if np.any(patch_size_array > image_size_array):
            message = (
                f'Patch size {tuple(patch_size_array)} cannot be'
                f' larger than image size {tuple(image_size_array)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap_array >= patch_size_array):
            message = (
                f'Patch overlap {tuple(patch_overlap_array)} must be smaller'
                f' than patch size {tuple(patch_size_array)}'
            )
            raise ValueError(message)
        # Odd overlaps cannot be split evenly into the two half-borders used
        # for padding.
        if np.any(patch_overlap_array % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap_array)}'
            )
            raise ValueError(message)

    @staticmethod
    def _get_patches_locations(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> np.ndarray:
        """Compute the start and end indices of every patch in the grid.

        Along each axis, patches start every ``patch_size - patch_overlap``
        indices. If the last regular start does not place a patch flush with
        the end of the axis, one extra patch ending exactly at the border is
        added.

        The sizes are expected to have been validated with
        :meth:`_parse_sizes`. Otherwise, e.g. a patch larger than the image
        makes ``range`` empty and indexing its last element fails.

        Returns:
            Array with one row per patch: the three start indices followed by
            the three end indices, sorted lexicographically.
        """
        # Example with image_size 10, patch_size 5, overlap 2:
        # [0 1 2 3 4 5 6 7 8 9]
        # [0 0 0 0 0]
        #       [1 1 1 1 1]
        #           [2 2 2 2 2]
        # Locations:
        # [[0, 5],
        #  [3, 8],
        #  [5, 10]]
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Add a final patch flush with the border if the regular steps
            # did not land on it.
            if indices_dim[-1] != im_size_dim - patch_size_dim:
                indices_dim.append(im_size_dim - patch_size_dim)
            indices.append(indices_dim)
        # Cartesian product of the per-axis start indices, one row per patch.
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))
```