from __future__ import annotations
from typing import Callable
import numpy as np
from ....data.subject import Subject
from ....types import TypeTripletInt
from ....utils import to_tuple
from ...spatial_transform import SpatialTransform
from .crop_or_pad import CropOrPad
class EnsureShapeMultiple(SpatialTransform):
"""Ensure that all values in the image shape are divisible by :math:`n`.
Some convolutional neural network architectures need that the size of the
input across all spatial dimensions is a power of :math:`2`.
For example, the canonical 3D U-Net from
`Çiçek et al. <https://link.springer.com/chapter/10.1007/978-3-319-46723-8_49>`_
    includes three downsampling (pooling) and upsampling operations:

    .. image:: https://www.researchgate.net/profile/Olaf-Ronneberger/publication/304226155/figure/fig1/AS:375619658502144@1466566113191/The-3D-u-net-architecture-Blue-boxes-represent-feature-maps-The-number-of-channels-is.png
        :alt: 3D U-Net

    Pooling operations in PyTorch round down the output size:

>>> import torch
>>> x = torch.rand(3, 10, 20, 31)
>>> x_down = torch.nn.functional.max_pool3d(x, 2)
>>> x_down.shape
torch.Size([3, 5, 10, 15])
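
    Along each spatial axis, the output size is :math:`\lfloor n / k \rfloor`
    for input size :math:`n` and kernel size :math:`k`, which floor division
    reproduces:

    >>> 31 // 2
    15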
    If we upsample this tensor, the original shape is lost (a batch dimension
    is added first so that :func:`torch.nn.functional.interpolate` treats the
    three trailing dimensions as spatial):

    >>> x_down_up = torch.nn.functional.interpolate(x_down.unsqueeze(0), scale_factor=2)[0]
>>> x_down_up.shape
torch.Size([3, 10, 20, 30])
>>> x.shape
torch.Size([3, 10, 20, 31])
    If we try to concatenate ``x`` and ``x_down_up`` (to create a skip
    connection), we will get an error, because their last dimensions differ
    (:math:`31` vs. :math:`30`). It is therefore good practice to ensure
    that the size of our images is such that concatenations will be safe.
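
    For instance (shown as a sketch, as the exact error message varies across
    PyTorch versions):

    >>> torch.cat((x, x_down_up), dim=0)  # doctest: +SKIP
    Traceback (most recent call last):
    ...
    RuntimeError: Sizes of tensors must match except in dimension 0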
.. note:: In these examples, it's assumed that all convolutions in the
U-Net use padding so that the output size is the same as the input
size.
The image above shows :math:`3` downsampling operations, so the input size
along all dimensions should be a multiple of :math:`2^3 = 8`.
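
    The resulting target shape can be previewed by applying the same rounding
    the transform uses internally (a NumPy sketch, padding an image of shape
    :math:`(181, 217, 181)`):

    >>> import numpy as np
    >>> np.ceil(np.array((181, 217, 181)) / 8) * 8
    array([184., 224., 184.])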
Example (assuming ``pip install unet`` has been run before):
>>> import torchio as tio
>>> import unet
>>> net = unet.UNet3D(padding=1)
>>> t1 = tio.datasets.Colin27().t1
>>> tensor_bad = t1.data.unsqueeze(0)
>>> tensor_bad.shape
torch.Size([1, 1, 181, 217, 181])
>>> net(tensor_bad).shape
    Traceback (most recent call last):
      ...
      File ".../unet/decoding.py", line 131, in forward
        x = torch.cat((skip_connection, x), dim=CHANNELS_DIMENSION)
    RuntimeError: Sizes of tensors must match except in dimension 1. Got 45 and 44 in dimension 2 (The offending index is 1)
>>> num_poolings = 3
>>> fix_shape_unet = tio.EnsureShapeMultiple(2**num_poolings)
>>> t1_fixed = fix_shape_unet(t1)
>>> tensor_ok = t1_fixed.data.unsqueeze(0)
>>> tensor_ok.shape
torch.Size([1, 1, 184, 224, 184]) # as expected
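
    With the fixed shape, the tensor traverses the network without shape
    errors (the number of output channels depends on the ``unet`` package
    defaults, so the output shape is not shown here):

    >>> output = net(tensor_ok)  # doctest: +SKIP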
Args:
target_multiple: Tuple :math:`(n_w, n_h, n_d)`, so that the size of the
output along axis :math:`i` is a multiple of :math:`n_i`. If a
single value :math:`n` is provided, then
:math:`n_w = n_h = n_d = n`.
method: Either ``'crop'`` or ``'pad'``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
Example:
>>> import torchio as tio
>>> image = tio.datasets.Colin27().t1
>>> image.shape
(1, 181, 217, 181)
>>> transform = tio.EnsureShapeMultiple(8, method='pad')
>>> transformed = transform(image)
>>> transformed.shape
(1, 184, 224, 184)
>>> transform = tio.EnsureShapeMultiple(8, method='crop')
>>> transformed = transform(image)
>>> transformed.shape
(1, 176, 216, 176)
>>> image_2d = image.data[..., :1]
>>> image_2d.shape
torch.Size([1, 181, 217, 1])
>>> transformed = transform(image_2d)
>>> transformed.shape
torch.Size([1, 176, 216, 1])
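
    A different multiple may also be given per axis, following the
    ``target_multiple`` convention described above (a sketch):

    >>> transform = tio.EnsureShapeMultiple((4, 8, 2), method='pad')
    >>> transformed = transform(image)
    >>> transformed.shape
    (1, 184, 224, 182)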
"""
def __init__(
self,
target_multiple: int | TypeTripletInt,
*,
method: str = 'pad',
**kwargs,
):
super().__init__(**kwargs)
self.target_multiple = np.array(to_tuple(target_multiple, 3))
if method not in ('crop', 'pad'):
raise ValueError('Method must be "crop" or "pad"')
self.method = method

    def apply_transform(self, subject: Subject) -> Subject:
        source_shape = np.array(subject.spatial_shape, np.uint16)
        # Round each spatial size down to the nearest number of multiples
        # when cropping, and up when padding
        function: Callable = np.floor if self.method == 'crop' else np.ceil  # type: ignore[assignment]
        integer_ratio = function(source_shape / self.target_multiple)
        target_shape = integer_ratio * self.target_multiple
        # Keep at least one voxel per axis, e.g. for 2D images stored as (W, H, 1)
        target_shape = np.maximum(target_shape, 1)
        transform = CropOrPad(target_shape.astype(int), **self.get_base_args())
        subject = transform(subject)  # type: ignore[assignment]
        return subject