Source code for torchio.transforms.preprocessing.intensity.mask

import warnings
from typing import Optional
from typing import Sequence

import torch

from ... import IntensityTransform
from ....data.image import ScalarImage
from ....data.subject import Subject
from ....transforms.transform import TypeMaskingMethod


[docs] class Mask(IntensityTransform): """Set voxels outside of mask to a constant value. Args: masking_method: See :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`. outside_value: Value to set for all voxels outside of the mask. labels: If a label map is used to generate the mask, sequence of labels to consider. If ``None``, all values larger than zero will be used for the mask. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. Raises: RuntimeWarning: If a 4D image is masked with a 3D mask, the mask will be expanded along the channels (first) dimension, and a warning will be raised. Example: >>> import torchio as tio >>> subject = tio.datasets.Colin27() >>> subject Colin27(Keys: ('t1', 'head', 'brain'); images: 3) >>> mask = tio.Mask(masking_method='brain') # Use "brain" image to mask >>> transformed = mask(subject) # Set voxels outside of the brain to 0 .. plot:: import torchio as tio subject = tio.datasets.Colin27() subject.remove_image('head') mask = tio.Mask('brain') masked = mask(subject) subject.add_image(masked.t1, 'Masked') subject.plot() """ # noqa: B950 def __init__( self, masking_method: TypeMaskingMethod, outside_value: float = 0, labels: Optional[Sequence[int]] = None, **kwargs, ): super().__init__(**kwargs) self.masking_method = masking_method self.masking_labels = labels self.outside_value = outside_value self.args_names = ['masking_method'] def apply_transform(self, subject: Subject) -> Subject: for image in self.get_images(subject): mask_data = self.get_mask_from_masking_method( self.masking_method, subject, image.data, self.masking_labels, ) assert isinstance(image, ScalarImage) self.apply_masking(image, mask_data) return subject def apply_masking( self, image: ScalarImage, mask_data: torch.Tensor, ) -> None: masked = mask(image.data, mask_data, self.outside_value) image.set_data(masked)
def mask( tensor: torch.Tensor, mask: torch.Tensor, outside_value: float, ) -> torch.Tensor: array = tensor.clone() num_channels_array = array.shape[0] num_channels_mask = mask.shape[0] if num_channels_array != num_channels_mask: assert num_channels_mask == 1 message = ( f'Expanding mask with shape {mask.shape}' f' to match shape {array.shape} of input image' ) warnings.warn(message, RuntimeWarning, stacklevel=2) mask = mask.expand(*array.shape) array[~mask] = outside_value return array