Inference
Here is an example that uses a grid sampler and aggregator to perform dense inference across a 3D image using patches:
>>> import torch
>>> import torch.nn as nn
>>> import torchio as tio
>>> patch_overlap = 4, 4, 4 # or just 4
>>> patch_size = 88, 88, 60
>>> subject = tio.datasets.Colin27()
>>> subject
Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
>>> grid_sampler = tio.inference.GridSampler(
... subject,
... patch_size,
... patch_overlap,
... )
>>> patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=4)
>>> aggregator = tio.inference.GridAggregator(grid_sampler)
>>> model = nn.Identity().eval()
>>> with torch.no_grad():
... for patches_batch in patch_loader:
... input_tensor = patches_batch['t1'][tio.DATA]
... locations = patches_batch[tio.LOCATION]
... logits = model(input_tensor)
... labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
... outputs = labels
... aggregator.add_batch(outputs, locations)
>>> output_tensor = aggregator.get_output_tensor()
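A possible follow-up, not part of the original example: the aggregated tensor can be wrapped in a TorchIO image so the spatial metadata of the input is preserved when saving. This is a minimal sketch, assuming the affine of the input t1 image applies to the prediction:
>>> # Sketch: wrap the aggregated labels in a LabelMap, reusing the input affine
>>> prediction = tio.LabelMap(tensor=output_tensor.short(), affine=subject.t1.affine)
>>> prediction.save('prediction.nii.gz')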
Grid sampler
GridSampler
class torchio.data.GridSampler(subject: Subject, patch_size: int | Tuple[int, int, int], patch_overlap: int | Tuple[int, int, int] = (0, 0, 0), padding_mode: str | float | None = None)
Bases: PatchSampler
Extract patches across a whole volume.
Grid samplers are useful to perform inference using all patches from a volume. They are often used with a GridAggregator.
Parameters:
subject – Instance of Subject from which patches will be extracted.
patch_size – Tuple of integers \((w, h, d)\) to generate patches of size \(w \times h \times d\). If a single number \(n\) is provided, \(w = h = d = n\).
patch_overlap – Tuple of even integers \((w_o, h_o, d_o)\) specifying the overlap between patches for dense inference. If a single number \(n\) is provided, \(w_o = h_o = d_o = n\).
padding_mode – Same as padding_mode in Pad. If None, the volume will not be padded before sampling and patches at the border will not be cropped by the aggregator. Otherwise, the volume will be padded with \(\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2}\right)\) on each side before sampling. If the sampler is passed to a GridAggregator, it will crop the output to its original size.
Example
>>> import torchio as tio
>>> colin = tio.datasets.Colin27()
>>> sampler = tio.GridSampler(colin, patch_size=88)
>>> for i, patch in enumerate(sampler()):
...     patch.t1.save(f'patch_{i}.nii.gz')
...
>>> # To figure out the number of patches beforehand:
>>> sampler = tio.GridSampler(colin, patch_size=88)
>>> len(sampler)
8
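A further sketch, with values chosen only for illustration, showing how patch_overlap and padding_mode can be combined so that border patches are sampled from a zero-padded volume:
>>> # Sketch (illustrative values): overlapping patches with constant zero padding
>>> sampler = tio.GridSampler(
...     colin,
...     patch_size=88,
...     patch_overlap=8,
...     padding_mode=0,
... )
>>> patch = sampler[0]
>>> location = patch[tio.LOCATION]  # start and stop voxel indices of this patch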
Note
Adapted from NiftyNet. See this NiftyNet tutorial for more information about patch-based sampling. Note that patch_overlap is twice the border used in the NiftyNet tutorial.
Grid aggregator
GridAggregator
class torchio.data.GridAggregator(sampler: GridSampler, overlap_mode: str = 'crop')
Bases: object
Aggregate patches for dense inference.
This class is typically used to assemble a full volume from patches after running inference on batches extracted by a GridSampler.
Parameters:
sampler – Instance of GridSampler used to extract the patches.
overlap_mode – If 'crop', the overlapping predictions will be cropped. If 'average', the predictions in the overlapping areas will be averaged with equal weights. If 'hann', the predictions in the overlapping areas will be weighted with a Hann window function. See the grid aggregator tests for a raw visualization of the three modes.
Note
Adapted from NiftyNet. See this NiftyNet tutorial for more information about patch-based sampling.
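When soft outputs (for example probabilities or logits) are aggregated, averaging the overlapping regions usually gives smoother results than cropping them. A minimal sketch, reusing grid_sampler, patch_loader, and model from the inference example above; switching overlap_mode to 'average' is the only change:
>>> # Sketch: aggregate soft outputs and average the overlapping regions
>>> aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')
>>> with torch.no_grad():
...     for patches_batch in patch_loader:
...         input_tensor = patches_batch['t1'][tio.DATA]
...         locations = patches_batch[tio.LOCATION]
...         logits = model(input_tensor)
...         aggregator.add_batch(logits, locations)
...
>>> averaged_logits = aggregator.get_output_tensor()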
add_batch(batch_tensor: Tensor, locations: Tensor) → None
Add a batch processed by a CNN to the output prediction volume.
Parameters:
batch_tensor – 5D tensor, typically the output of a convolutional neural network, e.g. batch['image'][torchio.DATA].
locations – 2D tensor with shape \((B, 6)\) representing the patch indices in the original image. They are typically extracted using batch[torchio.LOCATION].
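For reference, a short sketch of the shapes involved, assuming the patch_loader defined in the inference example at the top of this page (batch size 4, single-channel t1, patch size (88, 88, 60)):
>>> # Sketch: shapes of the tensors passed to add_batch in the inference example
>>> batch = next(iter(patch_loader))
>>> tuple(batch['t1'][tio.DATA].shape)  # 5D batch_tensor: (B, C, W, H, D)
(4, 1, 88, 88, 60)
>>> tuple(batch[tio.LOCATION].shape)  # one row of six indices per patch
(4, 6)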