Source code for torchio.transforms.preprocessing.label.one_hot

import torch.nn.functional as F  # noqa: N812

from ....data.image import Image
from .label_transform import LabelTransform


class OneHot(LabelTransform):
    r"""Convert single-channel label maps to a one-hot representation.

    Each label image yielded by ``get_images`` is replaced in place: its
    single channel is expanded with :func:`~torch.nn.functional.one_hot`
    and the resulting class axis is moved to the channel position.

    When ``invert_transform`` is truthy, the opposite operation is applied
    instead: the channels are collapsed with an ``argmax`` along the
    channel axis.

    Args:
        num_classes: See :func:`~torch.nn.functional.one_hot`. ``None`` is
            treated the same as ``-1``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(self, num_classes: int = -1, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.args_names = ['num_classes']
        # Encoding is the default direction; this flag is flipped to decode.
        # NOTE(review): presumably toggled by the base class machinery when
        # the transform is inverted -- confirm against Transform.
        self.invert_transform = False

    def apply_transform(self, subject):
        """Encode (or decode) every image of ``subject`` and return it.

        Args:
            subject: Object passed to ``self.get_images`` to obtain the
                images to process.

        Returns:
            The same ``subject``, whose images were modified in place.
        """
        for image in self.get_images(subject):
            # The direction is re-read for every image, as in a plain branch.
            operation = self.argmax if self.invert_transform else self.one_hot
            operation(image)
        return subject

    @staticmethod
    def argmax(image: Image) -> None:
        """Collapse the channels of ``image`` into one channel of indices.

        The index of the maximum along dimension 0 is taken and that
        dimension is kept with size 1. The image data is replaced in place.
        """
        image.set_data(image.data.argmax(dim=0, keepdim=True))

    def one_hot(self, image: Image) -> None:
        """Replace the data of ``image`` with its one-hot encoding.

        Args:
            image: Image with exactly one channel.

        Raises:
            RuntimeError: If ``image`` has more than one channel.
        """
        if image.num_channels > 1:
            raise RuntimeError(
                'The number of input channels must be 1,'
                f' but it is {image.num_channels}'
            )

        labels = image.data[0]

        # -1 lets F.one_hot infer the number of classes from the data.
        num_classes = self.num_classes
        if num_classes is None:
            num_classes = -1

        # F.one_hot needs integer input and appends the class axis last.
        encoded = F.one_hot(labels.long(), num_classes=num_classes)
        # Move the class axis to the front so that it becomes the channels.
        channels_first = encoded.permute(3, 0, 1, 2)
        # Cast back to the tensor type the labels originally had.
        image.set_data(channels_first.type(labels.type()))