Inference

Here’s an example that uses a grid sampler and aggregator to perform dense inference across a 3D image using small patches:

>>> import torch
>>> import torch.nn as nn
>>> import torchio as tio
>>> patch_overlap = 4, 4, 4  # or just 4
>>> patch_size = 88, 88, 60
>>> subject = tio.datasets.Colin27()
>>> subject
Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
>>> grid_sampler = tio.inference.GridSampler(
...     subject,
...     patch_size,
...     patch_overlap,
... )
>>> patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
>>> aggregator = tio.inference.GridAggregator(grid_sampler)
>>> model = nn.Identity().eval()
>>> with torch.no_grad():
...     for patches_batch in patch_loader:
...         input_tensor = patches_batch['t1'][tio.DATA]
...         locations = patches_batch[tio.LOCATION]
...         logits = model(input_tensor)
...         labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
...         outputs = labels
...         aggregator.add_batch(outputs, locations)
>>> output_tensor = aggregator.get_output_tensor()

Grid sampler

GridSampler

class torchio.data.GridSampler(subject: torchio.data.subject.Subject, patch_size: Union[int, Tuple[int, int, int]], patch_overlap: Union[int, Tuple[int, int, int]] = (0, 0, 0), padding_mode: Optional[Union[str, float]] = None)

Bases: Generic[torch.utils.data.dataset.T_co]

Extract patches across a whole volume.

Grid samplers are useful for performing inference using all patches from a volume. They are often used with a GridAggregator.

Parameters
  • subject – Instance of Subject from which patches will be extracted.

  • patch_size – Tuple of integers \((w, h, d)\) to generate patches of size \(w \times h \times d\). If a single number \(n\) is provided, \(w = h = d = n\).

  • patch_overlap – Tuple of even integers \((w_o, h_o, d_o)\) specifying the overlap between patches for dense inference. If a single number \(n\) is provided, \(w_o = h_o = d_o = n\).

  • padding_mode – Same as padding_mode in Pad. If None, the volume will not be padded before sampling, and patches at the border will not be cropped by the aggregator. Otherwise, the volume will be padded with \(\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)\) voxels on each side before sampling. If the sampler is passed to a GridAggregator, the aggregator will crop the output back to the original size of the volume.
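The grid implied by patch_size and patch_overlap can be sketched in NumPy. The helpers to_triplet and grid_locations below are hypothetical and not necessarily how TorchIO computes its grid: along each axis, patch origins advance by patch_size - patch_overlap, and one extra patch flush with the far border is added when the stride does not reach it. Each row holds the start and end indices of one patch, so the result has shape (N, 6). Padding is not applied in this sketch.

```python
import itertools

import numpy as np


def to_triplet(value):
    """Expand an int n to (n, n, n); pass 3-tuples through unchanged."""
    if isinstance(value, int):
        return (value, value, value)
    return tuple(value)


def grid_locations(image_size, patch_size, patch_overlap=0):
    """Return an (N, 6) array of [i0, j0, k0, i1, j1, k1] patch locations.

    Hypothetical helper for illustration; not TorchIO's implementation.
    """
    image_size = to_triplet(image_size)
    patch_size = to_triplet(patch_size)
    patch_overlap = to_triplet(patch_overlap)
    starts_per_axis = []
    for size, patch, overlap in zip(image_size, patch_size, patch_overlap):
        if overlap % 2 or overlap >= patch or patch > size:
            raise ValueError('need even overlap < patch size <= image size')
        step = patch - overlap
        starts = list(range(0, size - patch + 1, step))
        # Add one last patch flush with the far border if the stride missed it
        if starts[-1] != size - patch:
            starts.append(size - patch)
        starts_per_axis.append(starts)
    # Cartesian product of per-axis origins gives every patch origin
    ini = np.array(list(itertools.product(*starts_per_axis)))
    fin = ini + np.array(patch_size)
    return np.hstack((ini, fin))


locations = grid_locations((10, 10, 10), patch_size=4, patch_overlap=2)
print(locations.shape)  # (64, 6)
print(locations[0].tolist())  # [0, 0, 0, 4, 4, 4]
print(locations[-1].tolist())  # [6, 6, 6, 10, 10, 10]
```

With patch_overlap=0 and a size that is not a multiple of the patch size (for example 10 and 4), the extra border patch makes the last two patches overlap anyway, so the whole volume is still covered.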

Note

Adapted from NiftyNet. See this NiftyNet tutorial for more information about patch-based sampling. Note that patch_overlap is twice the border parameter used in the NiftyNet tutorial.
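The relation between patch_overlap and NiftyNet's border, and the padding described for padding_mode, can be checked with a small NumPy sketch. It assumes a single-channel (W, H, D) array and uses numpy.pad in place of TorchIO's Pad transform; the helpers pad_for_grid and crop_to_original are hypothetical.

```python
import numpy as np


def pad_for_grid(volume, patch_overlap, mode='constant'):
    """Pad a (W, H, D) array by half the overlap on each side.

    Sketch only: ``mode`` is forwarded to numpy.pad (an assumption),
    not to TorchIO's Pad transform.
    """
    border = np.asarray(patch_overlap) // 2  # NiftyNet's "border"
    pad_width = [(int(b), int(b)) for b in border]
    return np.pad(volume, pad_width, mode=mode), border


def crop_to_original(padded, border):
    """Undo pad_for_grid, as the aggregator does with its output."""
    slices = tuple(
        slice(int(b), size - int(b)) for b, size in zip(border, padded.shape)
    )
    return padded[slices]


volume = np.arange(5 * 6 * 7, dtype=float).reshape(5, 6, 7)
padded, border = pad_for_grid(volume, patch_overlap=(4, 4, 2), mode='edge')
print(border.tolist(), padded.shape)  # [2, 2, 1] (9, 10, 9)
restored = crop_to_original(padded, border)
print(np.array_equal(restored, volume))  # True
```

Each spatial size grows by exactly the overlap, because half of it is added on each side.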

Grid aggregator

GridAggregator

class torchio.data.GridAggregator(sampler: torchio.data.inference.grid_sampler.GridSampler, overlap_mode: str = 'crop')

Aggregate patches for dense inference.

This class is typically used to reconstruct a volume from patches after running inference on batches extracted by a GridSampler.

Parameters
  • sampler – Instance of GridSampler used to extract the patches.

  • overlap_mode – If 'crop', the overlapping predictions will be cropped. If 'average', the predictions in the overlapping areas will be averaged with equal weights. See the grid aggregator tests for a raw visualization of both modes.
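The difference between the two overlap modes is easiest to see in one dimension. The function aggregate_1d below is a hypothetical toy, not TorchIO's code: in 'average' mode every patch contributes with equal weight to the samples it covers; in 'crop' mode each patch drops overlap // 2 samples from its inner edges (patches touching the volume border are not cropped on that side), so each output sample is taken from a single patch. It assumes patch origins come from a regular grid like the one GridSampler produces.

```python
import numpy as np


def aggregate_1d(patches, starts, length, overlap, mode):
    """Rebuild a 1D signal of ``length`` from overlapping patch predictions.

    Hypothetical toy version of the 'average' and 'crop' overlap modes.
    """
    output = np.zeros(length)
    counts = np.zeros(length)
    half = overlap // 2
    for patch, start in zip(patches, starts):
        end = start + len(patch)
        if mode == 'average':
            # Sum predictions and count how many patches cover each sample
            output[start:end] += patch
            counts[start:end] += 1
        elif mode == 'crop':
            # Keep only the central part; later patches overwrite earlier ones
            left = 0 if start == 0 else half
            right = 0 if end == length else half
            output[start + left:end - right] = patch[left:len(patch) - right]
            counts[start + left:end - right] = 1
        else:
            raise ValueError(f'unknown mode: {mode}')
    return output / counts


signal = np.linspace(0.0, 1.0, 10)
starts = [0, 2, 4, 6]  # patch_size=4, patch_overlap=2
# Simulate patch-dependent predictions by adding the patch index
patches = [signal[s:s + 4] + i for i, s in enumerate(starts)]
averaged = aggregate_1d(patches, starts, 10, overlap=2, mode='average')
cropped = aggregate_1d(patches, starts, 10, overlap=2, mode='crop')
print((averaged - signal).round(2).tolist())
# [0.0, 0.0, 0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.0, 3.0]
print((cropped - signal).round(2).tolist())
# [0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0]
```

Averaging blends neighbouring predictions in the overlap, whereas cropping keeps the prediction from the patch whose centre is closest, which avoids values computed near patch edges.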

Note

Adapted from NiftyNet. See this NiftyNet tutorial for more information about patch-based sampling.

add_batch(batch_tensor: torch.Tensor, locations: torch.Tensor) → None

Add a batch processed by a CNN to the output prediction volume.

Parameters
  • batch_tensor – 5D tensor with shape \((B, C, w, h, d)\), typically the output of a convolutional neural network applied to an input such as batch['image'][torchio.DATA].

  • locations – 2D tensor with shape \((B, 6)\) containing the start and end indices of each patch in the original image. It is typically extracted using batch[torchio.LOCATION].
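How a batch and its locations map back onto the volume can be sketched with NumPy slicing. The helper add_batch_average is hypothetical and mimics only the 'average' bookkeeping; it assumes each row of locations is [i0, j0, k0, i1, j1, k1] and that batches have shape (B, C, w, h, d) while the output has shape (C, W, H, D).

```python
import numpy as np


def add_batch_average(output, counts, batch, locations):
    """Accumulate a (B, C, w, h, d) batch into a (C, W, H, D) volume.

    Hypothetical helper: each location row holds start and end indices.
    """
    for patch, location in zip(batch, locations):
        i0, j0, k0, i1, j1, k1 = (int(v) for v in location)
        output[:, i0:i1, j0:j1, k0:k1] += patch
        counts[:, i0:i1, j0:j1, k0:k1] += 1


rng = np.random.default_rng(0)
image = rng.random((1, 6, 4, 4))  # (C, W, H, D)
locations = np.array([[0, 0, 0, 4, 4, 4],  # two 4x4x4 patches that
                      [2, 0, 0, 6, 4, 4]])  # overlap by 2 along W
batch = np.stack([image[:, i0:i1, j0:j1, k0:k1]
                  for i0, j0, k0, i1, j1, k1 in locations])  # (B, C, w, h, d)
print(batch.shape)  # (2, 1, 4, 4, 4)

output = np.zeros_like(image)
counts = np.zeros_like(image)
add_batch_average(output, counts, batch, locations)
print(np.allclose(output / counts, image))  # True
```

Dividing the accumulated sums by the per-voxel counts at the end is what turns the running totals into the averaged volume.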

get_output_tensor() → torch.Tensor

Get the aggregated volume after dense inference.