import warnings
from numbers import Number
from typing import Dict
from typing import Union
import nibabel as nib
import numpy as np
import torch
from ....data.image import LabelMap
from ....data.subject import Subject
from .bounds_transform import BoundsTransform
from .bounds_transform import TypeBounds
[docs]
class Pad(BoundsTransform):
    r"""Pad an image.

    Args:
        padding: Tuple
            :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`
            defining the number of values padded to the edges of each axis.
            If the initial shape of the image is
            :math:`W \times H \times D`, the final shape will be
            :math:`(w_{ini} + W + w_{fin}) \times (h_{ini} + H + h_{fin})
            \times (d_{ini} + D + d_{fin})`.
            If only three values :math:`(w, h, d)` are provided, then
            :math:`w_{ini} = w_{fin} = w`,
            :math:`h_{ini} = h_{fin} = h` and
            :math:`d_{ini} = d_{fin} = d`.
            If only one value :math:`n` is provided, then
            :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} =
            d_{ini} = d_{fin} = n`.
        padding_mode: See possible modes in `NumPy docs`_. If it is a number,
            the mode will be set to ``'constant'`` and the number is used as
            the fill value. A callable also passes validation and is
            forwarded to :func:`numpy.pad` as the ``mode`` argument.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. seealso:: If you want to pass the output shape instead, please use
        :class:`~torchio.transforms.CropOrPad` instead.

    .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    """  # noqa: B950

    # String modes accepted by ``check_padding_mode``. The order is kept
    # stable because the tuple is rendered verbatim in the error message.
    PADDING_MODES = (
        'empty',
        'edge',
        'wrap',
        'constant',
        'linear_ramp',
        'maximum',
        'mean',
        'median',
        'minimum',
        'reflect',
        'symmetric',
    )

    def __init__(
        self,
        padding: TypeBounds,
        padding_mode: Union[str, float] = 0,
        **kwargs,
    ):
        """Store the padding specification after validating the mode.

        Args:
            padding: Amount of padding; forwarded to the parent constructor
                and also kept unmodified in ``self.padding``.
            padding_mode: Mode name, fill value or callable. See
                :meth:`check_padding_mode` for what is accepted.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            KeyError: If ``padding_mode`` is rejected by
                :meth:`check_padding_mode`.
        """
        super().__init__(padding, **kwargs)
        # The raw user argument is kept so ``inverse`` can hand it to Crop.
        self.padding = padding
        self.check_padding_mode(padding_mode)
        self.padding_mode = padding_mode
        self.args_names = ['padding', 'padding_mode']

    @classmethod
    def check_padding_mode(cls, padding_mode):
        """Validate a padding mode.

        A mode is valid if it is one of ``PADDING_MODES``, an instance of
        :class:`numbers.Number`, or a callable.

        Args:
            padding_mode: Candidate value to validate.

        Raises:
            KeyError: If ``padding_mode`` is none of the accepted kinds.
        """
        if padding_mode in cls.PADDING_MODES:
            return
        if isinstance(padding_mode, Number) or callable(padding_mode):
            return
        raise KeyError(
            f'Padding mode "{padding_mode}" not valid. Valid options are'
            f' {list(cls.PADDING_MODES)}, a number or a function',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Pad every image returned by ``self.get_images(subject)`` in place.

        Each image gets its data padded with :func:`numpy.pad` and its
        affine translation updated so that the original voxels keep their
        world coordinates.

        Args:
            subject: Subject whose images are padded.

        Returns:
            The same ``subject`` instance, with its images modified.
        """
        bounds = self.bounds_parameters
        assert bounds is not None
        mode = self.padding_mode

        # Everything below depends only on the transform configuration, so
        # it is prepared once rather than once per image.
        pad_kwargs: Dict[str, Union[str, float]]
        if isinstance(mode, Number):
            pad_kwargs = {'mode': 'constant', 'constant_values': mode}
        else:
            pad_kwargs = {'mode': mode}

        # The first axis gets no padding; the remaining three take
        # consecutive (initial, final) pairs from the bounds.
        # NOTE(review): presumably the first axis is channels and ``bounds``
        # holds six values -- confirm against BoundsTransform.
        pad_widths = (0, 0), bounds[:2], bounds[2:4], bounds[4:]

        # Voxel index, relative to the current grid, of the first voxel of
        # the padded grid: minus the initial padding along each axis.
        first_voxel = -np.array(bounds[::2])

        for image in self.get_images(subject):
            if isinstance(image, LabelMap) and mode == 'mean':
                warnings.warn(
                    'Padding mode "mean" might create non-integer values in label maps',  # noqa: B950
                    RuntimeWarning,
                    stacklevel=2,
                )

            # Move the origin to the world position of the new first voxel;
            # the rest of the affine is left untouched.
            affine = image.affine
            shifted_affine = affine.copy()
            shifted_affine[:3, 3] = nib.affines.apply_affine(affine, first_voxel)

            result = np.pad(image.data, pad_widths, **pad_kwargs)  # type: ignore[call-overload] # noqa: B950
            image.set_data(torch.as_tensor(result))
            image.affine = shifted_affine
        return subject

    def inverse(self):
        """Return a ``Crop`` built from the same ``padding`` argument."""
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import with
        # ``.crop`` -- confirm.
        from .crop import Crop

        return Crop(self.padding)