from typing import Generator
from typing import Optional
import numpy as np
import torch
from ...constants import MIN_FLOAT_32
from ...typing import TypeSpatialShape
from ..image import Image
from ..subject import Subject
from .sampler import RandomSampler
class WeightedSampler(RandomSampler):
    r"""Randomly extract patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        probability_map: Name of the image in the input subject that will be
            used as a sampling probability map.

    Raises:
        RuntimeError: If the probability map is empty.
        KeyError: If the probability map image is not found in the subject.
        ValueError: If the probability map contains negative values.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     t1=tio.ScalarImage('t1_mri.nii.gz'),
        ...     sampling_map=tio.Image('sampling.nii.gz', type=tio.SAMPLING_MAP),
        ... )
        >>> patch_size = 64
        >>> sampler = tio.data.WeightedSampler(patch_size, 'sampling_map')
        >>> for patch in sampler(subject):
        ...     print(patch[tio.LOCATION])

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0,
        as the center of the patch cannot be there. Along an axis where the
        patch has size :math:`s`, the first :math:`\lfloor s/2 \rfloor` and
        the last :math:`\lfloor (s-1)/2 \rfloor` voxels are cleared, so
        nothing is cleared if :math:`s = 1` and only the first voxel is
        cleared if :math:`s = 2`.
    """  # noqa: B950

    def __init__(
        self,
        patch_size: TypeSpatialShape,
        probability_map: Optional[str],
    ):
        super().__init__(patch_size)
        self.probability_map_name = probability_map
        # NOTE(review): not read anywhere in this class (the CDF is recomputed
        # for every subject in _generate_patches); kept for compatibility.
        self.cdf = None

    def _generate_patches(
        self,
        subject: Subject,
        num_patches: Optional[int] = None,
    ) -> Generator[Subject, None, None]:
        """Yield patches whose centers are drawn from the probability map.

        Args:
            subject: Subject from which patches are extracted.
            num_patches: Number of patches to yield. If ``None``, patches are
                yielded indefinitely; zero or negative values yield nothing.
        """
        probability_map = self.get_probability_map(subject)
        probability_map_array = self.process_probability_map(
            probability_map,
            subject,
        )
        cdf = self.get_cumulative_distribution_function(probability_map_array)
        patches_left = num_patches
        # Compare explicitly with 0 so that a negative count cannot be
        # mistaken for "truthy" and turn into an endless stream of patches
        while patches_left is None or patches_left > 0:
            yield self.extract_patch(subject, probability_map_array, cdf)
            if patches_left is not None:
                patches_left -= 1

    def get_probability_map_image(self, subject: Subject) -> Image:
        """Return the image in ``subject`` used as the probability map.

        Raises:
            KeyError: If no image with the configured name is in the subject.
        """
        assert self.probability_map_name is not None
        if self.probability_map_name in subject:
            return subject[self.probability_map_name]
        else:
            message = (
                f'Image "{self.probability_map_name}" not found in subject: {subject}'
            )
            raise KeyError(message)

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return the probability map tensor, checking it is non-negative.

        Raises:
            ValueError: If any value in the probability map is negative.
        """
        data = self.get_probability_map_image(subject).data
        if torch.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{self.probability_map_name}"'
            )
            raise ValueError(message)
        return data

    def process_probability_map(
        self,
        probability_map: torch.Tensor,
        subject: Subject,
    ) -> np.ndarray:
        """Return a normalized float64 copy of the first channel of the map.

        Voxels that cannot be patch centers (see
        :meth:`clear_probability_borders`) are set to 0 before normalizing.

        Raises:
            RuntimeError: If no voxel with positive probability remains.
        """
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
        # astype() returns a copy, so the in-place edits below do not modify
        # the subject's tensor
        data = probability_map[0].numpy().astype(np.float64)
        assert data.ndim == 3
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            # Cast to int so the suggested call is printed as plain numbers
            # (NumPy 2 would otherwise render e.g. "np.uint16(32)")
            half_patch_size = tuple(int(n) // 2 for n in self.patch_size)
            message = (
                'Empty probability map found:'
                f' {self.get_probability_map_image(subject).path}'
                '\nVoxels with positive probability might be near the image'
                ' border.\nIf you suspect that this is the case, try adding a'
                ' padding transform\nwith half the patch size:'
                f' torchio.Pad({half_patch_size})'
            )
            raise RuntimeError(message)
        data /= total  # normalize probabilities
        return data

    @staticmethod
    def clear_probability_borders(
        probability_map: np.ndarray,
        patch_size: np.ndarray,
    ) -> None:
        """Zero, in place, voxels that cannot be the center of a full patch.

        Along each axis with patch size ``s``, the first ``s // 2`` and the
        last ``(s - 1) // 2`` voxels are set to 0.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled
        # given the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #   Patch         center
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #
        #    Prob. map      After preprocessing
        #
        #  x x x x x x x       . . . . . . .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x  -->  . . x x x x .
        #  x x x x x x x       . . x x x x .
        #  x x x x x x x       . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations
        crop_ini = patch_size // 2
        crop_fin = (patch_size - 1) // 2
        crop_i, crop_j, crop_k = crop_ini.tolist()
        probability_map[:crop_i, :, :] = 0
        probability_map[:, :crop_j, :] = 0
        probability_map[:, :, :crop_k] = 0
        # The call tolist() is very important. Using np.uint16 as negative
        # index will not work because e.g. -np.uint16(2) == 65534
        crop_i, crop_j, crop_k = crop_fin.tolist()
        # A zero crop must be skipped: [-0:] would select the whole axis
        if crop_i:
            probability_map[-crop_i:, :, :] = 0
        if crop_j:
            probability_map[:, -crop_j:, :] = 0
        if crop_k:
            probability_map[:, :, -crop_k:] = 0

    @staticmethod
    def get_cumulative_distribution_function(
        probability_map: np.ndarray,
    ) -> np.ndarray:
        """Return the cumulative distribution function of a probability map.

        The map is flattened (C order) and normalized to sum to 1 before the
        cumulative sum is taken.
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        cdf = np.cumsum(flat_map_normalized)
        return cdf

    def extract_patch(  # type: ignore[override]
        self,
        subject: Subject,
        probability_map: np.ndarray,
        cdf: np.ndarray,
    ) -> Subject:
        """Crop one patch whose center is sampled from the probability map."""
        i, j, k = self.get_random_index_ini(probability_map, cdf)
        index_ini = i, j, k
        si, sj, sk = self.patch_size
        patch_size = si, sj, sk
        cropped_subject = self.crop(
            subject,
            index_ini,
            patch_size,
        )
        return cropped_subject

    def get_random_index_ini(
        self,
        probability_map: np.ndarray,
        cdf: np.ndarray,
    ) -> np.ndarray:
        """Return the first (lowest) voxel index of a randomly sampled patch."""
        center = self.sample_probability_map(probability_map, cdf)
        assert np.all(center >= 0)
        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @classmethod
    def sample_probability_map(
        cls,
        probability_map: np.ndarray,
        cdf: np.ndarray,
    ) -> np.ndarray:
        """Inverse transform sampling.

        Returns the multi-dimensional index of a voxel drawn with probability
        proportional to its value in ``probability_map``. ``cdf`` must be the
        output of :meth:`get_cumulative_distribution_function` for that map.

        Example:
            >>> probability_map = np.array(
            ...    ((0,0,1,1,5,2,1,1,0),
            ...     (2,2,2,2,2,2,2,2,2)))
            >>> probability_map
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
            >>> cdf = WeightedSampler.get_cumulative_distribution_function(
            ...     probability_map,
            ... )
            >>> histogram = np.zeros_like(probability_map)
            >>> for _ in range(100000):  # doctest:+SKIP
            ...     center = WeightedSampler.sample_probability_map(probability_map, cdf)
            ...     histogram[tuple(center)] += 1
            ...
            >>> histogram  # doctest:+SKIP
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])
        """  # noqa: B950
        # Get first value larger than random number ensuring the random number
        # is not exactly 0 (see https://github.com/fepegar/torchio/issues/510)
        random_number = max(MIN_FLOAT_32, torch.rand(1).item()) * cdf[-1]
        # side='left' returns the first index whose cdf is >= random_number;
        # as random_number > 0, that voxel has a strictly positive probability
        random_location_index = np.searchsorted(cdf, random_number)
        center = np.unravel_index(
            random_location_index,
            probability_map.shape,
        )
        probability = probability_map[center]
        if probability <= 0:
            message = (
                'Error retrieving probability in weighted sampler.'
                ' Please report this issue at'
                ' https://github.com/fepegar/torchio/issues/new?labels=bug&template=bug_report.md'  # noqa: B950
            )
            raise RuntimeError(message)
        return np.array(center)