class Mask(IntensityTransform):
    """Set voxels outside of mask to a constant value.

    Args:
        masking_method: See
            :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`.
        outside_value: Value to set for all voxels outside of the mask.
        labels: If a label map is used to generate the mask, sequence of labels
            to consider. If ``None``, all values larger than zero will be
            used for the mask.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Warns:
        RuntimeWarning: If a multi-channel image is masked with a
            single-channel mask, the mask is expanded along the channels
            (first) dimension and a warning is issued.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> mask = tio.Mask(masking_method='brain')  # Use "brain" image to mask
        >>> transformed = mask(subject)  # Set voxels outside of the brain to 0

    .. plot::

        import torchio as tio
        subject = tio.datasets.Colin27()
        subject.remove_image('head')
        mask = tio.Mask('brain')
        masked = mask(subject)
        subject.add_image(masked.t1, 'Masked')
        subject.plot()

    """  # noqa: B950

    def __init__(
        self,
        masking_method: TypeMaskingMethod,
        outside_value: float = 0,
        labels: Optional[Sequence[int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.masking_method = masking_method
        self.masking_labels = labels
        # Public alias matching the constructor parameter name, so that
        # ``labels`` can be listed in ``args_names`` below.
        self.labels = labels
        self.outside_value = outside_value
        # List every constructor argument. NOTE(review): presumably the base
        # Transform reads these attribute names to record the arguments
        # needed to reproduce this transform, so each name must exist as an
        # attribute on ``self``.
        self.args_names = ['masking_method', 'outside_value', 'labels']

    def apply_transform(self, subject: Subject) -> Subject:
        """Mask every selected image of ``subject`` in place and return it."""
        for image in self.get_images(subject):
            # ``get_images`` is expected to yield scalar images here; check
            # before doing any work (this also narrows the type).
            assert isinstance(image, ScalarImage)
            mask_data = self.get_mask_from_masking_method(
                self.masking_method,
                subject,
                image.data,
                self.masking_labels,
            )
            self.apply_masking(image, mask_data)
        return subject

    def apply_masking(
        self,
        image: ScalarImage,
        mask_data: torch.Tensor,
    ) -> None:
        """Replace ``image``'s data with a copy masked by ``mask_data``."""
        masked = mask(image.data, mask_data, self.outside_value)
        image.set_data(masked)
def mask(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    outside_value: float,
) -> torch.Tensor:
    """Return a copy of ``tensor`` with elements outside ``mask`` replaced.

    Args:
        tensor: Input data whose first dimension is the channels dimension.
        mask: Mask with either the same number of channels as ``tensor`` or
            a single channel. Non-boolean masks are converted to boolean, so
            any nonzero value counts as inside the mask.
        outside_value: Value assigned to every element where the mask is
            ``False``.

    Returns:
        A new tensor; ``tensor`` itself is not modified.

    Raises:
        ValueError: If the channel counts differ and the mask has more than
            one channel.

    Warns:
        RuntimeWarning: If a single-channel mask is expanded to match a
            multi-channel tensor.
    """
    array = tensor.clone()
    # Boolean indexing requires a bool mask: ``~`` on an integer tensor is a
    # bitwise NOT, which would turn the mask into (wrong) integer indices.
    mask = mask.bool()
    num_channels_array = array.shape[0]
    num_channels_mask = mask.shape[0]
    if num_channels_array != num_channels_mask:
        # Only a single-channel mask can be broadcast over the channels.
        if num_channels_mask != 1:
            message = (
                f'Mask with shape {mask.shape} cannot be applied to input'
                f' image with shape {array.shape}: the mask must have either'
                f' 1 or {num_channels_array} channels'
            )
            raise ValueError(message)
        message = (
            f'Expanding mask with shape {mask.shape}'
            f' to match shape {array.shape} of input image'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        mask = mask.expand(*array.shape)
    array[~mask] = outside_value
    return array