Source code for torchio.transforms.preprocessing.spatial.ensure_shape_multiple

from __future__ import annotations

from typing import Callable
from typing import Union

import numpy as np

from ... import SpatialTransform
from ....data.subject import Subject
from ....typing import TypeTripletInt
from ....utils import to_tuple
from .crop_or_pad import CropOrPad


class EnsureShapeMultiple(SpatialTransform):
    """Ensure that all values in the image shape are divisible by :math:`n`.

    Some convolutional neural network architectures require the size of the
    input across all spatial dimensions to be divisible by a power of
    :math:`2`. For example, the canonical 3D U-Net from `Çiçek et al.
    <https://link.springer.com/chapter/10.1007/978-3-319-46723-8_49>`_
    includes three downsampling (pooling) and upsampling operations:

    .. image:: https://www.researchgate.net/profile/Olaf-Ronneberger/publication/304226155/figure/fig1/AS:375619658502144@1466566113191/The-3D-u-net-architecture-Blue-boxes-represent-feature-maps-The-number-of-channels-is.png
        :alt: 3D U-Net

    Pooling operations in PyTorch round down the output size:

        >>> import torch
        >>> x = torch.rand(3, 10, 20, 31)
        >>> x_down = torch.nn.functional.max_pool3d(x, 2)
        >>> x_down.shape
        torch.Size([3, 5, 10, 15])

    If we upsample this tensor, the original shape is lost:

        >>> x_down_up = torch.nn.functional.interpolate(x_down, scale_factor=2)
        >>> x_down_up.shape
        torch.Size([3, 10, 20, 30])
        >>> x.shape
        torch.Size([3, 10, 20, 31])

    If we try to concatenate ``x_down`` and ``x_down_up`` (to create skip
    connections), we will get an error. It is therefore good practice to
    ensure that the size of our images is such that concatenations will be
    safe.

    .. note:: In these examples, it's assumed that all convolutions in the
        U-Net use padding so that the output size is the same as the input
        size.

    The image above shows :math:`3` downsampling operations, so the input
    size along all dimensions should be a multiple of :math:`2^3 = 8`.

    Example (assuming ``pip install unet`` has been run before):

        >>> import torchio as tio
        >>> import unet
        >>> net = unet.UNet3D(padding=1)
        >>> t1 = tio.datasets.Colin27().t1
        >>> tensor_bad = t1.data.unsqueeze(0)
        >>> tensor_bad.shape
        torch.Size([1, 1, 181, 217, 181])
        >>> net(tensor_bad).shape
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
          File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
          File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/unet/unet.py", line 122, in forward
            x = self.decoder(skip_connections, encoding)
          File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
          File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/unet/decoding.py", line 61, in forward
            x = decoding_block(skip_connection, x)
          File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
          File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/unet/decoding.py", line 131, in forward
            x = torch.cat((skip_connection, x), dim=CHANNELS_DIMENSION)
        RuntimeError: Sizes of tensors must match except in dimension 1. Got 45 and 44 in dimension 2 (The offending index is 1)
        >>> num_poolings = 3
        >>> fix_shape_unet = tio.EnsureShapeMultiple(2**num_poolings)
        >>> t1_fixed = fix_shape_unet(t1)
        >>> tensor_ok = t1_fixed.data.unsqueeze(0)
        >>> tensor_ok.shape
        torch.Size([1, 1, 184, 224, 184])  # as expected

    Args:
        target_multiple: Tuple :math:`(n_w, n_h, n_d)`, so that the size of
            the output along axis :math:`i` is a multiple of :math:`n_i`. If a
            single value :math:`n` is provided, then
            :math:`n_w = n_h = n_d = n`.
        method: Either ``'crop'`` or ``'pad'``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.Colin27().t1
        >>> image.shape
        (1, 181, 217, 181)
        >>> transform = tio.EnsureShapeMultiple(8, method='pad')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 184, 224, 184)
        >>> transform = tio.EnsureShapeMultiple(8, method='crop')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 176, 216, 176)
        >>> image_2d = image.data[..., :1]
        >>> image_2d.shape
        torch.Size([1, 181, 217, 1])
        >>> transformed = transform(image_2d)
        >>> transformed.shape
        torch.Size([1, 176, 216, 1])
    """  # noqa: B950

    def __init__(
        self,
        target_multiple: Union[int, TypeTripletInt],
        *,
        method: str = 'pad',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target_multiple = np.array(to_tuple(target_multiple, 3))
        if method not in ('crop', 'pad'):
            raise ValueError('Method must be "crop" or "pad"')
        self.method = method

    def apply_transform(self, subject: Subject) -> Subject:
        source_shape = np.array(subject.spatial_shape, np.uint16)
        # Round the size ratio down when cropping, up when padding
        function: Callable = np.floor if self.method == 'crop' else np.ceil  # type: ignore[assignment]  # noqa: B950
        integer_ratio = function(source_shape / self.target_multiple)
        target_shape = integer_ratio * self.target_multiple
        target_shape = np.maximum(target_shape, 1)
        transform = CropOrPad(target_shape.astype(int))
        subject = transform(subject)  # type: ignore[assignment]
        return subject
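For reference, the output shapes quoted in the docstring examples can be reproduced with plain NumPy. The snippet below is a standalone sketch, not part of the module, that mirrors the rounding performed in ``apply_transform`` for the Colin27 spatial shape ``(181, 217, 181)`` and a target multiple of 8:

import numpy as np

# Spatial shape and multiple taken from the docstring example above
source_shape = np.array([181, 217, 181])
multiple = 8

# method='pad': round each spatial size up to the next multiple of 8
padded = (np.ceil(source_shape / multiple) * multiple).astype(int)
print(padded)  # [184 224 184]

# method='crop': round each spatial size down, never below 1
cropped = np.maximum(np.floor(source_shape / multiple) * multiple, 1).astype(int)
print(cropped)  # [176 216 176]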