from __future__ import annotations
import copy
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from torch.utils.data import Dataset
from ..utils import get_subjects_from_batch
from .subject import Subject
[docs]
class SubjectsDataset(Dataset):
"""Base TorchIO dataset.
Reader of 3D medical images that directly inherits from the PyTorch
:class:`~torch.utils.data.Dataset`. It can be used with a PyTorch
:class:`~torch.utils.data.DataLoader` for efficient loading and
augmentation. It receives a list of instances of :class:`~torchio.Subject`
and an optional transform applied to the volumes after loading.
Args:
subjects: List of instances of :class:`~torchio.Subject`.
transform: An instance of :class:`~torchio.transforms.Transform`
that will be applied to each subject.
load_getitem: Load all subject images before returning it in
:meth:`__getitem__`. Set it to ``False`` if some of the images will
not be needed during training.
Example:
>>> import torchio as tio
>>> subject_a = tio.Subject(
... t1=tio.ScalarImage('t1.nrrd',),
... t2=tio.ScalarImage('t2.mha',),
... label=tio.LabelMap('t1_seg.nii.gz'),
... age=31,
... name='Fernando Perez',
... )
>>> subject_b = tio.Subject(
... t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
... t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
... label=tio.LabelMap('colin27_seg1.nii.gz'),
... age=56,
... name='Colin Holmes',
... )
>>> subjects_list = [subject_a, subject_b]
>>> transforms = [
... tio.RescaleIntensity(out_min_max=(0, 1)),
... tio.RandomAffine(),
... ]
>>> transform = tio.Compose(transforms)
>>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
>>> subject = subjects_dataset[0]
.. _NiBabel: https://nipy.org/nibabel/#nibabel
.. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
.. _DICOM: https://www.dicomstandard.org/
.. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html
.. tip:: To quickly iterate over the subjects without loading the images,
use :meth:`dry_iter()`.
""" # noqa: B950
def __init__(
self,
subjects: Sequence[Subject],
transform: Optional[Callable] = None,
load_getitem: bool = True,
):
self._parse_subjects_list(subjects)
self._subjects = subjects
self._transform: Optional[Callable]
self.set_transform(transform)
self.load_getitem = load_getitem
def __len__(self):
return len(self._subjects)
def __getitem__(self, index: int) -> Subject:
try:
index = int(index)
except (RuntimeError, TypeError):
message = (
f'Index "{index}" must be int or compatible dtype,'
f' but an object of type "{type(index)}" was passed'
)
raise ValueError(message)
subject = self._subjects[index]
subject = copy.deepcopy(subject) # cheap since images not loaded yet
if self.load_getitem:
subject.load()
# Apply transform (this is usually the bottleneck)
if self._transform is not None:
subject = self._transform(subject)
return subject
[docs]
@classmethod
def from_batch(cls, batch: dict) -> SubjectsDataset:
"""Instantiate a dataset from a batch generated by a data loader.
Args:
batch: Dictionary generated by a data loader, containing data that
can be converted to instances of :class:`~.torchio.Subject`.
"""
subjects: List[Subject] = get_subjects_from_batch(batch)
return cls(subjects)
[docs]
def dry_iter(self):
"""Return the internal list of subjects.
This can be used to iterate over the subjects without loading the data
and applying any transforms::
>>> names = [subject.name for subject in dataset.dry_iter()]
"""
return self._subjects
@staticmethod
def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None:
# Check that it's an iterable
try:
iter(subjects_list)
except TypeError as e:
message = f'Subject list must be an iterable, not {type(subjects_list)}'
raise TypeError(message) from e
# Check that it's not empty
if not subjects_list:
raise ValueError('Subjects list is empty')
# Check each element
for subject in subjects_list:
if not isinstance(subject, Subject):
message = (
'Subjects list must contain instances of torchio.Subject,'
f' not "{type(subject)}"'
)
raise TypeError(message)