# Source code for torchio.transforms.preprocessing.spatial.crop_or_pad

import warnings
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np

from ... import SpatialTransform
from ....data.subject import Subject
from ....utils import parse_spatial_shape
from ...transform import TypeSixBounds
from ...transform import TypeTripletInt
from .crop import Crop
from .pad import Pad


class CropOrPad(SpatialTransform):
    """Modify the field of view by cropping or padding to match a target shape.

    This transform modifies the affine matrix associated to the volume so that
    physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(W, H, D)`. If a single value :math:`N` is
            provided, then :math:`W = H = D = N`. If ``None``, the shape will
            be computed from the :attr:`mask_name` (and the :attr:`labels`, if
            :attr:`labels` is not ``None``).
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`.
        mask_name: If ``None``, the centers of the input and output volumes
            will be the same.
            If a string is given, the output volume center will be the center
            of the bounding box of non-zero values (in any channel) in the
            image named :attr:`mask_name`. If that image is not in the subject
            or contains only zeros, a ``RuntimeWarning`` is issued and the
            volume center is used instead; in that case, if
            :attr:`target_shape` is ``None``, the shape is left unchanged.
        labels: If a label map is used to generate the mask, sequence of labels
            to consider.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     chest_ct=tio.ScalarImage('subject_a_ct.nii.gz'),
        ...     heart_mask=tio.LabelMap('subject_a_heart_seg.nii.gz'),
        ... )
        >>> subject.chest_ct.shape
        torch.Size([1, 512, 512, 289])
        >>> transform = tio.CropOrPad(
        ...     (120, 80, 180),
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(subject)
        >>> transformed.chest_ct.shape
        torch.Size([1, 120, 80, 180])

    .. warning:: If :attr:`target_shape` is ``None``, subjects in the dataset
        will probably have different shapes. This is probably fine if you are
        using `patch-based training <https://torchio.readthedocs.io/patches/index.html>`_.
        If you are using full volumes for training and a batch size larger than
        one, an error will be raised by the :class:`~torch.utils.data.DataLoader`
        while trying to collate the batches.

    .. plot::

        import torchio as tio
        t1 = tio.datasets.Colin27().t1
        crop_pad = tio.CropOrPad((512, 512, 32))
        t1_pad_crop = crop_pad(t1)
        subject = tio.Subject(t1=t1, crop_pad=t1_pad_crop)
        subject.plot()
    """  # noqa: B950

    def __init__(
        self,
        target_shape: Union[int, TypeTripletInt, None] = None,
        padding_mode: Union[str, float] = 0,
        mask_name: Optional[str] = None,
        labels: Optional[Sequence[int]] = None,
        **kwargs,
    ):
        """Validate the arguments and select the bounds-computation strategy.

        Raises:
            ValueError: If both ``target_shape`` and ``mask_name`` are
                ``None``, if ``mask_name`` is not a string, or if ``labels``
                is given without a ``mask_name``.
        """
        if target_shape is None and mask_name is None:
            message = 'If mask_name is None, a target shape must be passed'
            raise ValueError(message)
        super().__init__(**kwargs)
        self.target_shape: Optional[TypeTripletInt]
        if target_shape is None:
            self.target_shape = None
        else:
            self.target_shape = parse_spatial_shape(target_shape)
        self.padding_mode = padding_mode
        if mask_name is not None and not isinstance(mask_name, str):
            message = (
                f'If mask_name is not None, it must be a string, not {type(mask_name)}'
            )
            raise ValueError(message)
        if mask_name is None:
            if labels is not None:
                message = (
                    'If mask_name is None, labels should be None,'
                    f' but "{labels}" was passed'
                )
                raise ValueError(message)
            self.compute_crop_or_pad = self._compute_center_crop_or_pad
        else:
            # mask_name is already known to be a string here (checked above),
            # so no second type check is needed.
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad
        self.mask_name = mask_name
        self.labels = labels

    @staticmethod
    def _bbox_mask(mask_volume: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return 6 coordinates of a 3D bounding box from a given mask.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.

        Args:
            mask_volume: 3D NumPy array. It must contain at least one
                non-zero value, otherwise an ``IndexError`` is raised.

        Returns:
            Tuple ``(bb_min, bb_max)`` of 3-element arrays. ``bb_min`` is
            inclusive and ``bb_max`` is exclusive (one past the last non-zero
            index along each axis).
        """  # noqa: B950
        i_any = np.any(mask_volume, axis=(1, 2))
        j_any = np.any(mask_volume, axis=(0, 2))
        k_any = np.any(mask_volume, axis=(0, 1))
        i_min, i_max = np.where(i_any)[0][[0, -1]]
        j_min, j_max = np.where(j_any)[0][[0, -1]]
        k_min, k_max = np.where(k_any)[0][[0, -1]]
        bb_min = np.array([i_min, j_min, k_min])
        # +1 so that bb_max is exclusive, i.e. usable directly as a slice end
        bb_max = np.array([i_max, j_max, k_max]) + 1
        return bb_min, bb_max

    @staticmethod
    def _get_six_bounds_parameters(
        parameters: np.ndarray,
    ) -> TypeSixBounds:
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Tuple :math:`(w, h, d)` with the number of voxels to be
                cropped or padded.

        Returns:
            Tuple :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> CropOrPad._get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)
        """  # noqa: B950
        parameters = parameters / 2
        result = []
        for number in parameters:
            # Odd totals put the extra voxel at the beginning of the axis
            ini, fin = int(np.ceil(number)), int(np.floor(number))
            result.extend([ini, fin])
        i1, i2, j1, j2, k1, k2 = result
        return i1, i2, j1, j2, k1, k2

    def _compute_cropping_padding_from_shapes(
        self,
        source_shape: TypeTripletInt,
    ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute centered padding and cropping to go from ``source_shape``
        to :attr:`target_shape`.

        Args:
            source_shape: Current spatial shape :math:`(W, H, D)`.

        Returns:
            Tuple ``(padding_params, cropping_params)``. Each element is a
            six-bounds tuple, or ``None`` when no padding or cropping is
            needed. If :attr:`target_shape` is ``None`` (only possible when a
            mask was requested but could not be used), both are ``None`` and
            the shape is kept.
        """
        if self.target_shape is None:
            # Without an explicit target shape there is nothing to match when
            # falling back to the volume center; previously this path raised
            # a TypeError from ``np.array(None) - source_shape``.
            return None, None
        diff_shape = np.array(self.target_shape) - source_shape

        # Negative differences mean the source is too large along that axis
        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        # Positive differences mean the source is too small along that axis
        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

    def _compute_center_crop_or_pad(
        self,
        subject: Subject,
    ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute padding and cropping that keep the volume center fixed.

        Returns:
            Tuple ``(padding_params, cropping_params)``; see
            :meth:`_compute_cropping_padding_from_shapes`.
        """
        return self._compute_cropping_padding_from_shapes(subject.spatial_shape)

    def _compute_mask_center_crop_or_pad(
        self,
        subject: Subject,
    ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute padding and cropping centered on the mask bounding box.

        If the mask image is not in the subject, or all its values are zero,
        a ``RuntimeWarning`` is issued and
        :meth:`_compute_center_crop_or_pad` is used instead.

        Returns:
            Tuple ``(padding_params, cropping_params)``. Each element is a
            six-bounds tuple, or ``None`` when no padding or cropping is
            needed.
        """
        if self.mask_name not in subject:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in subject keys "{tuple(subject.keys())}".'
                ' Using volume center instead'
            )
            warnings.warn(message, RuntimeWarning, stacklevel=2)
            return self._compute_center_crop_or_pad(subject=subject)

        mask_data = self.get_mask_from_masking_method(
            self.mask_name,
            subject,
            subject[self.mask_name].data,
            self.labels,
        ).numpy()

        if not np.any(mask_data):
            message = (
                f'All values found in the mask "{self.mask_name}"'
                ' are zero. Using volume center instead'
            )
            warnings.warn(message, RuntimeWarning, stacklevel=2)
            return self._compute_center_crop_or_pad(subject=subject)

        # Let's assume that the center of first voxel is at coordinate 0.5
        # (which is typically not the case)
        subject_shape = subject.spatial_shape
        # mask_data has a leading channel axis (the original indexed [0] to
        # get a 3D volume). Take the union over channels so the bounding box
        # agrees with the emptiness check above; using channel 0 alone raised
        # IndexError when that channel was empty but others were not.
        mask_volume = np.any(mask_data, axis=0)
        bb_min, bb_max = self._bbox_mask(mask_volume)
        center_mask = np.mean((bb_min, bb_max), axis=0)
        padding = []
        cropping = []

        if self.target_shape is None:
            # Use the size of the mask bounding box as the output shape
            target_shape = bb_max - bb_min
        else:
            target_shape = self.target_shape

        for target_dim, center_dim, subject_dim in zip(
            target_shape,
            center_mask,
            subject_shape,
        ):
            center_on_index = not (center_dim % 1)
            target_even = not (target_dim % 2)

            # Approximation when the center cannot be computed exactly
            # The output will be off by half a voxel, but this is just an
            # implementation detail. After this adjustment, begin and end
            # below are always whole numbers.
            if target_even ^ center_on_index:
                center_dim -= 0.5

            begin = center_dim - target_dim / 2
            if begin >= 0:
                crop_ini = begin
                pad_ini = 0
            else:
                crop_ini = 0
                pad_ini = -begin

            end = center_dim + target_dim / 2
            if end <= subject_dim:
                crop_fin = subject_dim - end
                pad_fin = 0
            else:
                crop_fin = 0
                pad_fin = end - subject_dim

            padding.extend([pad_ini, pad_fin])
            cropping.extend([crop_ini, crop_fin])

        # Conversion for SimpleITK compatibility
        padding_array = np.asarray(padding, dtype=int)
        cropping_array = np.asarray(cropping, dtype=int)
        if padding_array.any():
            padding_params = tuple(padding_array.tolist())
        else:
            padding_params = None
        if cropping_array.any():
            cropping_params = tuple(cropping_array.tolist())
        else:
            cropping_params = None
        return padding_params, cropping_params  # type: ignore[return-value]

    def apply_transform(self, subject: Subject) -> Subject:
        """Pad and then crop ``subject`` using the selected strategy.

        Raises whatever :meth:`Subject.check_consistent_space` raises if the
        images in the subject do not share the same space.
        """
        subject.check_consistent_space()
        padding_params, cropping_params = self.compute_crop_or_pad(subject)
        padding_kwargs = {'padding_mode': self.padding_mode}
        if padding_params is not None:
            pad = Pad(padding_params, **padding_kwargs)
            subject = pad(subject)  # type: ignore[assignment]
        if cropping_params is not None:
            crop = Crop(cropping_params)
            subject = crop(subject)  # type: ignore[assignment]
        return subject