Source code for torchio.transforms.preprocessing.intensity.z_normalization

from typing import Optional

import torch

from ....data.subject import Subject
from .normalization_transform import NormalizationTransform
from .normalization_transform import TypeMaskingMethod


class ZNormalization(NormalizationTransform):
    """Subtract mean and divide by standard deviation.

    The mean and standard deviation are computed only from the values
    selected by the mask, but the normalization is applied to the whole
    image tensor.

    Args:
        masking_method: See
            :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(self, masking_method: TypeMaskingMethod = None, **kwargs):
        super().__init__(masking_method=masking_method, **kwargs)
        # NOTE(review): presumably read by the parent transform machinery to
        # know which constructor arguments describe this transform; the
        # consumer is not visible in this file.
        self.args_names = ['masking_method']

    def apply_normalization(
        self,
        subject: Subject,
        image_name: str,
        mask: torch.Tensor,
    ) -> None:
        """Standardize one image of the subject and store the result.

        Args:
            subject: Subject containing the image to normalize.
            image_name: Key of the image within ``subject``.
            mask: Tensor used to index the image data; only the selected
                values contribute to the mean and standard deviation.

        Raises:
            RuntimeError: If the standard deviation of the masked values is
                zero or undefined (NaN), so the image cannot be standardized.
        """
        image = subject[image_name]
        standardized = self.znorm(image.data, mask)
        if standardized is None:
            message = (
                'Standard deviation is 0 or undefined for masked values'
                f' in image "{image_name}" ({image.path})'
            )
            raise RuntimeError(message)
        image.set_data(standardized)

    @staticmethod
    def znorm(
        tensor: torch.Tensor,
        mask: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Return a standardized float copy of ``tensor``.

        Args:
            tensor: Input data. It is cloned, so the argument is not modified.
            mask: Tensor used to index ``tensor``; the mean and standard
                deviation are computed from the selected values only.

        Returns:
            The standardized tensor, or ``None`` if the standard deviation of
            the masked values is zero or NaN.
        """
        tensor = tensor.clone().float()
        values = tensor[mask]
        mean, std = values.mean(), values.std()
        # A NaN standard deviation (empty mask, a single masked value with
        # the default Bessel-corrected estimator, or NaN among the masked
        # values) would slip past a plain ``std == 0`` check and silently
        # turn the whole image into NaN, so it is rejected as well.
        if std == 0 or torch.isnan(std):
            return None
        tensor -= mean
        tensor /= std
        return tensor