class OneHot(LabelTransform):
    r"""Reencode label maps using one-hot encoding.

    Every label map selected by the transform must have a single channel;
    its integer labels are expanded into one channel per class. When
    ``invert_transform`` is set, the channel-wise argmax is taken instead,
    collapsing a multi-channel image back into a single-channel label map.

    Args:
        num_classes: See :func:`~torch.nn.functional.one_hot`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(self, num_classes: int = -1, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.args_names = ['num_classes']
        # Forward (encoding) direction by default. NOTE(review): presumably
        # flipped to True elsewhere when building the inverse — not visible here.
        self.invert_transform = False

    def apply_transform(self, subject):
        """Encode (or, when inverting, decode) every label map in ``subject``.

        Each image is updated through ``set_data`` and the same ``subject``
        object is returned.
        """
        # Choose the direction once; nothing in the loop changes it.
        if self.invert_transform:
            convert = self.argmax
        else:
            convert = self.one_hot
        for label_map in self.get_images(subject):
            convert(label_map)
        return subject

    @staticmethod
    def argmax(image: Image) -> None:
        """Collapse the channels of ``image`` to the index of the largest one.

        The channel axis (dim 0) is kept with size 1, so the result is a
        single-channel image.
        """
        labels = image.data.argmax(dim=0, keepdim=True)
        image.set_data(labels)

    def one_hot(self, image: Image) -> None:
        """Replace the single channel of ``image`` with one channel per class.

        Raises:
            RuntimeError: If ``image`` has more than one channel.
        """
        if image.num_channels > 1:
            message = (
                'The number of input channels must be 1,'
                f' but it is {image.num_channels}'
            )
            raise RuntimeError(message)
        labels = image.data[0]
        # A ``None`` setting is treated the same as the default of -1, which
        # torch's ``one_hot`` uses to infer the class count from the data.
        if self.num_classes is None:
            class_count = -1
        else:
            class_count = self.num_classes
        encoded = F.one_hot(labels.long(), num_classes=class_count)
        # ``one_hot`` appends the class axis last; move it to the channel
        # position and cast back to the input's original tensor type.
        channels_first = encoded.permute(3, 0, 1, 2)
        image.set_data(channels_first.type(labels.type()))