Source code for torchio.data.sampler.weighted

from typing import Optional, Tuple, Generator

import torch
import numpy as np

from ...constants import MIN_FLOAT_32
from ...typing import TypeSpatialShape
from ..image import Image
from ..subject import Subject
from .sampler import RandomSampler


class WeightedSampler(RandomSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the input subject that will be
            used as a sampling probability map.

    Raises:
        RuntimeError: If the probability map is empty.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     t1=tio.ScalarImage('t1_mri.nii.gz'),
        ...     sampling_map=tio.Image('sampling.nii.gz', type=tio.SAMPLING_MAP),
        ... )
        >>> patch_size = 64
        >>> sampler = tio.data.WeightedSampler(patch_size, 'sampling_map')
        >>> for patch in sampler(subject):
        ...     print(patch[tio.LOCATION])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0 as
        the center of the patch cannot be at the border (unless the patch has
        size 1 or 2 along that axis).
    """  # noqa: E501

    def __init__(
            self,
            patch_size: TypeSpatialShape,
            probability_map: str,
            ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        # Not used by any method of this class; kept because external code
        # may read the attribute
        self.cdf = None

    def _generate_patches(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches sampled according to the subject's probability map.

        Args:
            subject: Subject containing the probability map image.
            num_patches: Number of patches to yield. If ``None``, patches are
                yielded indefinitely. Non-positive values yield nothing.
        """
        probability_map = self.get_probability_map(subject)
        probability_map = self.process_probability_map(
            probability_map, subject)
        cdf = self.get_cumulative_distribution_function(probability_map)

        patches_left = num_patches
        # An explicit "> 0" check (instead of truthiness) guarantees
        # termination when a negative number of patches is requested
        while patches_left is None or patches_left > 0:
            yield self.extract_patch(subject, probability_map, cdf)
            if patches_left is not None:
                patches_left -= 1

    def get_probability_map_image(self, subject: Subject) -> Image:
        """Return the subject's image whose name was given at construction.

        Raises:
            KeyError: If the subject has no image with that name.
        """
        if self.probability_map_name not in subject:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject: {subject}'
            )
            raise KeyError(message)
        return subject[self.probability_map_name]

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return the data of the probability map image.

        Raises:
            ValueError: If the data contains negative values.
        """
        data = self.get_probability_map_image(subject).data
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(
            self,
            probability_map: torch.Tensor,
            subject: Subject,
            ) -> np.ndarray:
        """Convert the map to a normalized 3D array with cleared borders.

        Only the first entry along the first dimension of the tensor is used.

        Args:
            probability_map: Tensor whose first entry is a 3D volume.
            subject: Subject the map belongs to; only used to build the error
                message.

        Returns:
            A new 3D ``float64`` array that sums to 1.

        Raises:
            RuntimeError: If the map sums to 0 after clearing the borders.
        """
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
        # Note that astype() returns a copy, so the input tensor is not
        # modified by the in-place operations below
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            # Plain ints so that the suggestion is rendered as e.g.
            # (32, 32, 32) regardless of how NumPy prints its scalars
            half_patch_size = tuple(int(n) // 2 for n in self.patch_size)
            message = (
                'Empty probability map found:'
                f' {self.get_probability_map_image(subject).path}'
                '\nVoxels with positive probability might be near the image'
                ' border.\nIf you suspect that this is the case, try adding a'
                ' padding transform\nwith half the patch size:'
                f' torchio.Pad({half_patch_size})'
            )
            raise RuntimeError(message)
        data /= total  # normalize probabilities
        return data

    @staticmethod
    def clear_probability_borders(
            probability_map: np.ndarray,
            patch_size: TypeSpatialShape,
            ) -> None:
        """Zero, in place, the voxels that cannot be the center of a patch.

        Args:
            probability_map: 3D array, modified in place.
            patch_size: Patch size as a single integer or as a sequence or
                array of three integers.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled
        # given the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #         Patch         center
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #
        #    Prob. map      After preprocessing
        #
        #  x x x x x x x       . . . . . . .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x       . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations

        # Converting to Python ints is very important. Using unsigned NumPy
        # integers as negative indices will not work because e.g.
        # -np.uint16(2) == 65534. It also lets plain ints and tuples be used.
        size_i, size_j, size_k = (
            int(n) for n in np.broadcast_to(np.asarray(patch_size), (3,))
        )

        crop_i, crop_j, crop_k = size_i // 2, size_j // 2, size_k // 2
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0

        crop_i = (size_i - 1) // 2
        crop_j = (size_j - 1) // 2
        crop_k = (size_k - 1) // 2
        # A slice [-0:] would select the whole axis, hence the checks
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(
            probability_map: np.ndarray,
            ) -> np.ndarray:
        """Return the cumulative distribution function of a probability map.

        The map is flattened and normalized by its sum before accumulating,
        so the result is a 1D array with one entry per voxel.
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        cdf = np.cumsum(flat_map_normalized)
        return cdf

    def extract_patch(
            self,
            subject: Subject,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> Subject:
        """Sample a patch location and return the cropped subject.

        The cropping itself is delegated to ``self.crop``, which is not
        defined in this class (presumably inherited).
        """
        index_ini = self.get_random_index_ini(probability_map, cdf)
        cropped_subject = self.crop(subject, index_ini, self.patch_size)
        return cropped_subject

    def get_random_index_ini(
            self,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> np.ndarray:
        """Sample a patch center and return the index of the patch corner.

        The corner is the center minus half the patch size (floor division).
        """
        center = self.sample_probability_map(probability_map, cdf)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @classmethod
    def sample_probability_map(
            cls,
            probability_map: np.ndarray,
            cdf: np.ndarray
            ) -> np.ndarray:
        """Inverse transform sampling.

        Args:
            probability_map: Array of non-negative sampling weights.
            cdf: Cumulative distribution function of the flattened map, as
                returned by :meth:`get_cumulative_distribution_function`.

        Returns:
            Array with the index of the sampled voxel along each dimension
            of ``probability_map``.

        Raises:
            RuntimeError: If the sampled voxel has a non-positive probability.

        Example:
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> cdf = WeightedSampler.get_cumulative_distribution_function(probability_map)
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):
            ...     center = WeightedSampler.sample_probability_map(probability_map, cdf)  # doctest:+SKIP
            ...     histogram[tuple(center)] += 1
            ...
            >>> histogram  # doctest:+SKIP
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])

        """  # noqa: E501
        # Get first value larger than random number ensuring the random number
        # is not exactly 0 (see https://github.com/fepegar/torchio/issues/510)
        random_number = max(MIN_FLOAT_32, torch.rand(1).item()) * cdf[-1]
        random_location_index = np.searchsorted(cdf, random_number)
        center = np.unravel_index(
            random_location_index,
            probability_map.shape
        )
        probability = probability_map[center]
        if probability <= 0:
            message = (
                'Error retrieving probability in weighted sampler.'
                ' Please report this issue at'
                ' https://github.com/fepegar/torchio/issues/new?labels=bug&template=bug_report.md'  # noqa: E501
            )
            raise RuntimeError(message)
        return np.array(center)