Source code for torchio.transforms.preprocessing.label.remap_labels

from typing import Dict

from ...transform import TypeMaskingMethod
from .label_transform import LabelTransform


class RemapLabels(LabelTransform):
    r"""Modify labels in a label map.

    Masking can be used to split the label into two during the
    `inverse transformation <invertibility>`_.

    Labels that are not keys of ``remapping`` are left unchanged. Every voxel
    is compared against the *original* data, so mappings such as
    ``{1: 2, 2: 1}`` swap labels instead of cascading.

    Args:
        remapping: Dictionary that specifies how labels should be remapped.
            The keys are the old labels, and the corresponding values replace
            them.
        masking_method: Defines a mask for where the label remapping is
            applied. It can be one of:

            - ``None``: the mask image is all ones, i.e. all values in the
              image are used.

            - A string: key to a :class:`torchio.LabelMap` in the subject
              which is used as a mask, OR an anatomical label: ``'Left'``,
              ``'Right'``, ``'Anterior'``, ``'Posterior'``, ``'Inferior'``,
              ``'Superior'`` which specifies a half of the mask volume to be
              ones.

            - A function: the mask image is computed as a function of the
              intensity image. The function must receive and return a 4D
              :class:`torch.Tensor`. A non-boolean result is converted with
              :meth:`torch.Tensor.bool`, i.e. nonzero values are inside the
              mask.

        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('t1')
        background_labels = (0, 1, 2, 3, 4)
        csf_labels = (5, 12, 16, 47, 52, 53)
        white_matter_labels = (
            45, 46, 66, 67, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 93, 94,
        )
        not_gray_matter_labels = (
            background_labels
            + csf_labels
            + white_matter_labels
        )
        gray_matter_labels = [
            label for label in subject.GIF_COLORS
            if label not in not_gray_matter_labels
        ]
        labels_groups = (
            background_labels,
            gray_matter_labels,
            white_matter_labels,
            csf_labels,
        )
        remapping = {}
        for target, labels in enumerate(labels_groups):
            for label in labels:
                remapping[label] = target
        parcellation_to_tissues = tio.RemapLabels(remapping)
        tissues = parcellation_to_tissues(subject).seg
        subject.add_image(tissues, 'remapped')
        subject.plot()

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> def get_image(*labels):
        ...     tensor = torch.as_tensor(labels).reshape(1, 1, 1, -1)
        ...     image = tio.LabelMap(tensor=tensor)
        ...     return image
        ...
        >>> image = get_image(0, 1, 2, 3, 4)
        >>> remapping = {1: 2, 2: 1, 3: 1, 4: 7}
        >>> transform = tio.RemapLabels(remapping)
        >>> transform(image).data
        tensor([[[[0, 2, 1, 1, 7]]]])

    .. warning:: The transform will not be correctly inverted if one of the
        values in ``remapping`` is also in the input image::

            >>> tensor = torch.as_tensor([0, 1]).reshape(1, 1, 1, -1)
            >>> subject = tio.Subject(label=tio.LabelMap(tensor=tensor))
            >>> mapping = {3: 1}  # the value 1 is in the input image
            >>> transform = tio.RemapLabels(mapping)
            >>> transformed = transform(subject)
            >>> back = transformed.apply_inverse_transform()
            >>> original_label_set = set(subject.label.data.unique().tolist())
            >>> back_label_set = set(back.label.data.unique().tolist())
            >>> original_label_set
            {0, 1}
            >>> back_label_set
            {0, 3}

    Example:
        >>> import torchio as tio
        >>> # Target label map has the following labels:
        >>> # {
        >>> #     'left_ventricle': 1, 'right_ventricle': 2,
        >>> #     'left_caudate': 3,   'right_caudate': 4,
        >>> #     'left_putamen': 5,   'right_putamen': 6,
        >>> #     'left_thalamus': 7,  'right_thalamus': 8,
        >>> # }
        >>> transform = tio.RemapLabels({2:1, 4:3, 6:5, 8:7})
        >>> # Merge right side labels with left side labels
        >>> transformed = transform(subject)
        >>> # Undesired behavior: The inverse transform will remap ALL left side labels to right side labels
        >>> # so the label map only has right side labels.
        >>> inverse_transformed = transformed.apply_inverse_transform()
        >>> # Here's the *right* way to do it with masking:
        >>> transform = tio.RemapLabels({2:1, 4:3, 6:5, 8:7}, masking_method="Right")
        >>> # Remap the labels on the right side only (no difference yet).
        >>> transformed = transform(subject)
        >>> # Apply the inverse on the right side only. The labels are correctly split into left/right.
        >>> inverse_transformed = transformed.apply_inverse_transform()
    """  # noqa: B950

    def __init__(
        self,
        remapping: Dict[int, int],
        masking_method: TypeMaskingMethod = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Kept so that inverse() can build a transform with the same options.
        self.kwargs = kwargs
        self.remapping = remapping
        self.masking_method = masking_method
        self.args_names = ['remapping', 'masking_method']

    def apply_transform(self, subject):
        """Remap the labels of every label map in ``subject`` in place.

        Images that contain none of the keys of ``remapping`` are skipped.
        Returns the same ``subject`` object.
        """
        # Invariant across images, so build it once.
        source_label_set = set(self.remapping)
        for image in self.get_images(subject):
            original_label_set = set(image.data.unique().tolist())
            # Do nothing if no keys in the mapping are found in the image
            if source_label_set.isdisjoint(original_label_set):
                continue
            new_data = image.data.clone()
            mask = self.get_mask_from_masking_method(
                self.masking_method,
                subject,
                new_data,
            )
            # A user-supplied masking function may return a non-boolean
            # tensor. Combined with ``&`` below, an integer mask would yield
            # an integer tensor, which the assignment would interpret as
            # index positions rather than a selection mask; a floating-point
            # mask cannot be combined with ``&`` at all. ``.bool()`` is a
            # no-op for masks that are already boolean.
            mask = mask.bool()
            for old_id, new_id in self.remapping.items():
                # Compare against the untouched original data so chained
                # mappings (e.g. {1: 2, 2: 1}) do not cascade.
                new_data[mask & (image.data == old_id)] = new_id
            image.set_data(new_data)
        return subject

    def is_invertible(self):
        # Not always, as explained in the docstring: the inverse is only
        # exact if no value of ``remapping`` was already present in the
        # input image, and inverse() raises if the values are not unique.
        return True

    def inverse(self):
        """Return a ``RemapLabels`` with keys and values of ``remapping`` swapped.

        The inverse reuses the same ``masking_method`` and extra keyword
        arguments as this transform.

        Raises:
            RuntimeError: If several labels are mapped to the same target
                label, because the inverse mapping would be ambiguous.
        """
        targets = self.remapping.values()
        unique_targets = set(targets)
        if len(unique_targets) < len(targets):
            message = (
                'Labels mapping cannot be inverted because original values'
                f' are not unique: {self.remapping}'
            )
            raise RuntimeError(message)
        inverse_remapping = {v: k for k, v in self.remapping.items()}
        inverse_transform = RemapLabels(
            inverse_remapping,
            masking_method=self.masking_method,
            **self.kwargs,
        )
        return inverse_transform