Source code for torchio.data.sampler.label

from typing import Dict
from typing import Optional

import numpy as np
import torch

from ...constants import LABEL
from ...constants import TYPE
from ...data.image import Image
from ...data.subject import Subject
from ...typing import TypeSpatialShape
from .weighted import WeightedSampler


class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center value is greater than 0 in the
    :attr:`label_name`.

    Args:
        patch_size: See :class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the subject that will be used
            to generate the sampling probability map. If ``None``, the first
            image of type :attr:`torchio.LABEL` found in the subject will be
            used.
        label_probabilities: Dictionary containing the probability that each
            class will be sampled. Probabilities do not need to be normalized,
            but they must add up to a positive value.
            For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
            sampler whose patch centers will have 50% probability of being
            labeled as ``1``, 25% of being ``2`` and 25% of being ``3``.
            If ``None``, the label map is binarized and the value is set to
            ``{0: 0, 1: 1}``.
            If the input has multiple channels, a value of
            ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a sampler whose patch
            centers will have 50% probability of being taken from a non-zero
            value of channel ``1``, 25% from channel ``2`` and 25% from
            channel ``3``.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> probabilities = {0: 0.5, 1: 0.5}
        >>> sampler = tio.data.LabelSampler(
        ...     patch_size=64,
        ...     label_name='brain',
        ...     label_probabilities=probabilities,
        ... )
        >>> generator = sampler(subject)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(subject, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)
    """

    def __init__(
        self,
        patch_size: TypeSpatialShape,
        label_name: Optional[str] = None,
        label_probabilities: Optional[Dict[int, float]] = None,
    ):
        # ``label_name`` is handed to the parent as ``probability_map``; the
        # parent is expected to expose it as ``self.probability_map_name``,
        # which ``get_probability_map_image`` reads below.
        super().__init__(patch_size, probability_map=label_name)
        self.label_probabilities_dict = label_probabilities

    def get_probability_map_image(self, subject: Subject) -> Image:
        """Return the label image used to build the probability map.

        If no label name was given, the first image whose type is
        ``LABEL`` is used.

        Raises:
            RuntimeError: If no name was given and the subject contains no
                label image.
            KeyError: If the requested name is not in the subject.
        """
        if self.probability_map_name is None:
            for image in subject.get_images(intensity_only=False):
                if image[TYPE] == LABEL:
                    label_map = image
                    break
            else:  # loop finished without ``break``: no label image found
                images = subject.get_images(intensity_only=False)
                message = (
                    f'No label maps found in subject {subject} with image'
                    f' paths {[image.path for image in images]}'
                )
                raise RuntimeError(message)
        elif self.probability_map_name in subject:
            label_map = subject[self.probability_map_name]
        else:
            message = (
                f'Image "{self.probability_map_name}"'
                f' not found in subject: {subject}'
            )
            raise KeyError(message)
        return label_map

    def get_probability_map(self, subject: Subject) -> torch.Tensor:
        """Return the sampling probability map for ``subject``.

        Without ``label_probabilities``, this is the boolean mask
        ``label_map > 0``. Otherwise, it is the float map built by
        :meth:`get_probabilities_from_label_map`.
        """
        label_map_tensor = self.get_probability_map_image(subject).data.float()

        if self.label_probabilities_dict is None:
            return label_map_tensor > 0
        probability_map = self.get_probabilities_from_label_map(
            label_map_tensor,
            self.label_probabilities_dict,
            self.patch_size,
        )
        return probability_map

    @staticmethod
    def get_probabilities_from_label_map(
        label_map: torch.Tensor,
        label_probabilities_dict: Dict[int, float],
        patch_size: np.ndarray,
    ) -> torch.Tensor:
        """Create probability map according to label map probabilities.

        Each label's normalized probability is spread uniformly over the
        voxels carrying that label, or over the nonzero voxels of that
        channel for multichannel maps. Voxels closer to the border than half
        a patch are cropped out before counting and padded back with zeros,
        so they never receive probability.

        Args:
            label_map: Tensor indexed as ``(channels, i, j, k)``, as implied
                by the slicing below.
            label_probabilities_dict: Mapping from label value (or channel
                index, for multichannel maps) to an unnormalized probability.
            patch_size: Spatial patch size with three elements.

        Returns:
            A tensor with the same spatial shape as ``label_map`` and either
            the same channels (single-channel input) or one channel
            (multichannel input).

        Raises:
            RuntimeError: If the patch is larger than the label map.
            ValueError: If the label probabilities do not sum to a positive
                value.
        """
        patch_size = patch_size.astype(int)
        ini_i, ini_j, ini_k = patch_size // 2
        spatial_shape = np.array(label_map.shape[1:])
        if np.any(patch_size > spatial_shape):
            message = (
                f'Patch size {patch_size} larger than label map {spatial_shape}'
            )
            raise RuntimeError(message)
        crop_fin_i, crop_fin_j, crop_fin_k = crop_fin = (patch_size - 1) // 2
        fin_i, fin_j, fin_k = spatial_shape - crop_fin
        # Keep only voxels that can be patch centers.
        # See https://github.com/fepegar/torchio/issues/458
        label_map = label_map[:, ini_i:fin_i, ini_j:fin_j, ini_k:fin_k]

        multichannel = label_map.shape[0] > 1
        probability_map = torch.zeros_like(label_map)
        label_probs = torch.Tensor(list(label_probabilities_dict.values()))
        total_probability = label_probs.sum()
        # A zero (or empty) total would make the normalization below divide
        # by zero and silently fill the map with NaNs.
        if not total_probability > 0:
            message = (
                'The label probabilities must sum to a positive value, but got'
                f' {label_probabilities_dict}'
            )
            raise ValueError(message)
        normalized_probs = label_probs / total_probability
        iterable = zip(label_probabilities_dict, normalized_probs)
        for label, label_probability in iterable:
            if multichannel:
                mask = label_map[label]
            else:
                mask = label_map == label
            label_size = mask.sum()
            if not label_size:
                # Label absent from the (cropped) map: nothing to weight.
                continue
            # Distribute the label's mass uniformly over its voxels.
            prob_voxels = label_probability / label_size
            if multichannel:
                probability_map[label] = prob_voxels * mask
            else:
                probability_map[mask] = prob_voxels
        if multichannel:
            # Collapse per-channel maps into a single-channel map.
            probability_map = probability_map.sum(dim=0, keepdim=True)

        # Restore the original spatial shape with zeros at the borders.
        # ``pad`` lists the last dimension first: (k, j, i).
        # See https://github.com/fepegar/torchio/issues/458
        padding = ini_k, crop_fin_k, ini_j, crop_fin_j, ini_i, crop_fin_i
        probability_map = torch.nn.functional.pad(
            probability_map,
            padding,
        )
        return probability_map