Source code for torchio.transforms.preprocessing.spatial.pad

import warnings
from numbers import Number
from typing import Dict
from typing import Union

import nibabel as nib
import numpy as np
import torch

from ....data.image import LabelMap
from ....data.subject import Subject
from .bounds_transform import BoundsTransform
from .bounds_transform import TypeBounds


class Pad(BoundsTransform):
    r"""Pad an image.

    Args:
        padding: Tuple
            :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`
            defining the number of values padded to the edges of each axis.
            If the initial shape of the image is
            :math:`W \times H \times D`, the final shape will be
            :math:`(w_{ini} + W + w_{fin}) \times (h_{ini} + H + h_{fin})
            \times (d_{ini} + D + d_{fin})`.
            If only three values :math:`(w, h, d)` are provided, then
            :math:`w_{ini} = w_{fin} = w`,
            :math:`h_{ini} = h_{fin} = h` and
            :math:`d_{ini} = d_{fin} = d`.
            If only one value :math:`n` is provided, then
            :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} =
            d_{ini} = d_{fin} = n`.
        padding_mode: See possible modes in `NumPy docs`_. If it is a number,
            the mode will be set to ``'constant'`` and the number is used as
            the constant padding value. A callable is also accepted; it is
            forwarded to :func:`numpy.pad` as its ``mode``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        KeyError: If ``padding_mode`` is not one of :attr:`PADDING_MODES`,
            a number or a callable.

    .. note:: When a :class:`~torchio.LabelMap` is padded with a mode that
        derives new values from the image intensities (``'mean'``,
        ``'median'`` or ``'linear_ramp'``), a :class:`RuntimeWarning` is
        emitted, as these modes might create non-integer values.

    .. seealso:: If you want to pass the output shape instead, please use
        :class:`~torchio.transforms.CropOrPad`.

    .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    """  # noqa: B950

    PADDING_MODES = (
        'empty',
        'edge',
        'wrap',
        'constant',
        'linear_ramp',
        'maximum',
        'mean',
        'median',
        'minimum',
        'reflect',
        'symmetric',
    )

    # Modes whose padded values are computed from the image intensities
    # (statistics or interpolation), so they might introduce non-integer
    # values when applied to a label map.
    _LABEL_UNSAFE_MODES = ('mean', 'median', 'linear_ramp')

    def __init__(
        self,
        padding: TypeBounds,
        padding_mode: Union[str, float] = 0,
        **kwargs,
    ):
        super().__init__(padding, **kwargs)
        self.padding = padding
        # Validate before storing so an invalid mode never reaches np.pad
        self.check_padding_mode(padding_mode)
        self.padding_mode = padding_mode
        self.args_names = ['padding', 'padding_mode']

    @classmethod
    def check_padding_mode(cls, padding_mode):
        """Raise if ``padding_mode`` is not a known mode, number or callable.

        Raises:
            KeyError: If the padding mode is not valid.
        """
        is_number = isinstance(padding_mode, Number)
        is_callable = callable(padding_mode)
        if not (padding_mode in cls.PADDING_MODES or is_number or is_callable):
            message = (
                f'Padding mode "{padding_mode}" not valid. Valid options are'
                f' {list(cls.PADDING_MODES)}, a number or a function'
            )
            raise KeyError(message)

    def _get_numpy_pad_kwargs(self) -> Dict[str, Union[str, float]]:
        """Return the keyword arguments passed to :func:`numpy.pad`.

        A numeric ``padding_mode`` becomes constant padding with that value;
        anything else (mode name or callable) is passed through as ``mode``.
        """
        if isinstance(self.padding_mode, Number):
            return {
                'mode': 'constant',
                'constant_values': self.padding_mode,
            }
        return {'mode': self.padding_mode}

    def apply_transform(self, subject: Subject) -> Subject:
        """Pad every image in ``subject`` in place and update its affine.

        Args:
            subject: Subject whose images (as returned by ``get_images``)
                are padded.

        Returns:
            The same ``subject`` instance, with padded images.
        """
        assert self.bounds_parameters is not None
        pad_params = self.bounds_parameters
        # Everything below is the same for every image, so compute it once.
        # Even indices hold the padding added at the start of each axis.
        low = pad_params[::2]
        negative_low = -np.array(low)
        # The first (channels) dimension is never padded
        paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]
        kwargs = self._get_numpy_pad_kwargs()
        # isinstance guard: numeric and callable modes are never label-unsafe
        # and should not be compared against strings
        unsafe_for_labels = (
            isinstance(self.padding_mode, str)
            and self.padding_mode in self._LABEL_UNSAFE_MODES
        )
        for image in self.get_images(subject):
            if unsafe_for_labels and isinstance(image, LabelMap):
                message = (
                    f'Padding mode "{self.padding_mode}" might create'
                    ' non-integer values in label maps'
                )
                warnings.warn(message, RuntimeWarning, stacklevel=2)
            # Move the origin back by the low padding (in voxel units) so
            # the original voxels keep their world coordinates
            new_origin = nib.affines.apply_affine(image.affine, negative_low)
            new_affine = image.affine.copy()
            new_affine[:3, 3] = new_origin
            padded = np.pad(image.data, paddings, **kwargs)  # type: ignore[call-overload]  # noqa: B950
            image.set_data(torch.as_tensor(padded))
            image.affine = new_affine
        return subject

    def inverse(self):
        """Return a :class:`Crop` transform that removes the same padding."""
        from .crop import Crop

        return Crop(self.padding)