# Source code for torchio.datasets.ixi

"""The `Information eXtraction from Images (IXI) <>`_
dataset contains "nearly 600 MR images from normal, healthy subjects",
including "T1, T2 and PD-weighted images, MRA images and Diffusion-weighted
images (15 directions)".

.. note ::
    This data is made available under the
    Creative Commons CC BY-SA 3.0 license.
    If you use it please acknowledge the source of the IXI data, e.g.
    `the IXI website <>`_.

# Adapted from
import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Optional
from typing import Sequence

from .. import LabelMap
from .. import ScalarImage
from .. import Subject
from .. import SubjectsDataset
from ..download import download_and_extract_archive
from ..transforms import Transform
from ..typing import TypePath

class IXI(SubjectsDataset):
    """Full IXI dataset.

    Args:
        root: Root directory to which the dataset will be downloaded.
        transform: An instance of
            :class:`~torchio.transforms.transform.Transform`.
        download: If set to ``True``, will download the data into
            :attr:`root`.
        modalities: List of modalities to be downloaded. They must be in
            ``('T1', 'T2', 'PD', 'MRA', 'DTI')``. A single modality name
            (e.g. ``'T1'``) is also accepted. At least one is required.

    Raises:
        ValueError: If no modality is given or a modality is not supported.
        RuntimeError: If the data is not found in :attr:`root` (and
            :attr:`download` is ``False`` or downloading did not create it).

    .. warning:: The size of this dataset is multiple GB.
        If you set :attr:`download` to ``True``, it will take some time
        to be downloaded if it is not already present.

    Example:
        >>> import torchio as tio
        >>> transforms = [
        ...     tio.ToCanonical(),  # to RAS
        ...     tio.Resample((1, 1, 1)),  # to 1 mm iso
        ... ]
        >>> ixi_dataset = tio.datasets.IXI(
        ...     'path/to/ixi_root/',
        ...     modalities=('T1', 'T2'),
        ...     transform=tio.Compose(transforms),
        ...     download=True,
        ... )
        >>> print('Number of subjects in dataset:', len(ixi_dataset))  # 577
        >>> sample_subject = ixi_dataset[0]
        >>> print('Keys in subject:', tuple(sample_subject.keys()))  # ('T1', 'T2')
        >>> print('Shape of T1 data:', sample_subject['T1'].shape)  # [1, 180, 268, 268]
        >>> print('Shape of T2 data:', sample_subject['T2'].shape)  # [1, 241, 257, 188]
    """  # noqa: B950

    # URL template of each modality's archive; ``{modality}`` is filled in
    # with a key of ``md5_dict``.
    base_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-{modality}.tar'  # noqa: FS003,B950
    # Expected MD5 of each modality's archive. Its keys also define the set
    # of supported modalities.
    md5_dict = {
        'T1': '34901a0593b41dd19c1a1f746eac2d58',
        'T2': 'e3140d78730ecdd32ba92da48c0a9aaa',
        'PD': '88ecd9d1fa33cb4a2278183b42ffd749',
        'MRA': '29be7d2fee3998f978a55a9bdaf3407e',
        'DTI': '636573825b1c8b9e8c78f1877df3ee66',
    }

    def __init__(
        self,
        root: TypePath,
        transform: Optional[Transform] = None,
        download: bool = False,
        modalities: Sequence[str] = ('T1', 'T2'),
        **kwargs,
    ):
        root = Path(root)
        # A bare string would otherwise be iterated character by character.
        if isinstance(modalities, str):
            modalities = (modalities,)
        # Materialize so that it can be indexed and iterated more than once.
        modalities = tuple(modalities)
        if not modalities:
            raise ValueError('At least one modality must be specified')
        for modality in modalities:
            if modality not in self.md5_dict:
                message = (
                    f'Modality "{modality}" must be'
                    f' one of {tuple(self.md5_dict.keys())}'
                )
                raise ValueError(message)
        if download:
            self._download(root, modalities)
        if not self._check_exists(root, modalities):
            message = 'Dataset not found. You can use download=True to download it'
            raise RuntimeError(message)
        subjects_list = self._get_subjects_list(root, modalities)
        super().__init__(subjects_list, transform=transform, **kwargs)

    @staticmethod
    def _check_exists(root, modalities):
        """Return ``True`` if every modality has a directory under ``root``."""
        return all((root / modality).is_dir() for modality in modalities)

    @staticmethod
    def _get_subjects_list(root, modalities):
        """Build one :class:`Subject` per subject having all ``modalities``.

        The number of files differs between modalities (e.g. 581 for T1,
        578 for T2). The first modality is therefore used as the reference,
        and a subject is kept only if a matching image exists for every
        other requested modality.
        """
        # NOTE(review): assumes one file per subject and modality, named
        # '<subject_id>-<modality>.nii.gz'. Verify this holds for DTI.
        reference, *others = modalities
        subjects = []
        for reference_path in sglob(root / reference, '*.nii.gz'):
            subject_id = get_subject_id(reference_path)
            other_paths = {
                modality: root / modality / f'{subject_id}-{modality}.nii.gz'
                for modality in others
            }
            if not all(path.is_file() for path in other_paths.values()):
                continue
            images = {reference: ScalarImage(reference_path)}
            images.update(
                {modality: ScalarImage(path) for modality, path in other_paths.items()},
            )
            subjects.append(Subject(subject_id=subject_id, **images))
        return subjects

    def _download(self, root, modalities):
        """Download the IXI data if it does not exist already.

        Modalities whose directory already exists under ``root`` are skipped.
        If a download or extraction fails, the modality directory created
        here is removed so that a later call retries instead of treating it
        as complete. The exception is then re-raised.
        """
        for modality in modalities:
            modality_dir = root / modality
            if modality_dir.is_dir():
                continue
            modality_dir.mkdir(exist_ok=True, parents=True)
            url = self.base_url.format(modality=modality)
            md5 = self.md5_dict[modality]
            # Reserve a temporary file name for the archive. The file is
            # closed here; the helper writes the download to that path.
            with NamedTemporaryFile(suffix='.tar', delete=False) as f:
                archive_path = Path(f.name)
            try:
                download_and_extract_archive(
                    url,
                    download_root=modality_dir,
                    filename=str(archive_path),
                    md5=md5,
                )
            except BaseException:
                shutil.rmtree(modality_dir, ignore_errors=True)
                raise
            finally:
                # delete=False means the archive (several GB) would
                # otherwise be left behind in the temporary directory.
                try:
                    archive_path.unlink()
                except FileNotFoundError:
                    pass
class IXITiny(SubjectsDataset):
    r"""This is the dataset used in the main TorchIO tutorial notebook.

    It is a tiny version of IXI, containing 566 :math:`T_1`-weighted brain
    MR images and their corresponding brain segmentations, all with size
    :math:`83 \times 44 \times 55`.

    It can be used as a medical image MNIST.

    Args:
        root: Root directory to which the dataset will be downloaded.
        transform: An instance of
            :class:`~torchio.transforms.transform.Transform`.
        download: If set to ``True``, will download the data into
            :attr:`root`.

    Raises:
        RuntimeError: If :attr:`root` is not a directory after the optional
            download.
        FileNotFoundError: If :attr:`root` contains no images or no labels.
    """  # noqa: B950

    # NOTE(review): URL restored from upstream TorchIO; confirm it is live.
    url = 'https://www.dropbox.com/s/ogxjwjxdv5mieah/ixi_tiny.zip?dl=1'
    md5 = 'bfb60f4074283d78622760230bfa1f98'

    def __init__(
        self,
        root: TypePath,
        transform: Optional[Transform] = None,
        download: bool = False,
        **kwargs,
    ):
        root = Path(root)
        if download:
            self._download(root)
        if not root.is_dir():
            message = 'Dataset not found. You can use download=True to download it'
            raise RuntimeError(message)
        subjects_list = self._get_subjects_list(root)
        super().__init__(subjects_list, transform=transform, **kwargs)

    @staticmethod
    def _get_subjects_list(root):
        """Pair the sorted images in ``root/image`` with ``root/label``.

        Images and labels are matched by their position in sorted order.
        """
        image_paths = sglob(root / 'image', '*.nii.gz')
        label_paths = sglob(root / 'label', '*.nii.gz')
        if not (image_paths and label_paths):
            message = (
                f'Images not found. Remove the root directory ({root})'
                ' and try again'
            )
            raise FileNotFoundError(message)

        subjects = []
        for image_path, label_path in zip(image_paths, label_paths):
            subject_id = get_subject_id(image_path)
            subject_dict = {}
            subject_dict['image'] = ScalarImage(image_path)
            subject_dict['label'] = LabelMap(label_path)
            subject_dict['subject_id'] = subject_id
            subjects.append(Subject(**subject_dict))
        return subjects

    def _download(self, root):
        """Download the tiny IXI data if it doesn't exist already.

        An existing ``root`` directory is assumed to hold the data. Otherwise
        the archive is downloaded and extracted, and its ``image`` and
        ``label`` folders are moved directly under ``root``. If this fails,
        the partially created ``root`` is removed and the exception re-raised.
        """
        if root.is_dir():  # assume it's been downloaded
            print('Root directory for IXITiny found:', root)  # noqa: T201
            return
        print('Root directory for IXITiny not found:', root)  # noqa: T201
        print('Downloading...')  # noqa: T201
        # Reserve a temporary file name for the archive; see the cleanup below.
        with NamedTemporaryFile(suffix='.zip', delete=False) as f:
            archive_path = Path(f.name)
        try:
            download_and_extract_archive(
                self.url,
                download_root=root,
                filename=str(archive_path),
                md5=self.md5,
            )
            ixi_tiny_dir = root / 'ixi_tiny'
            (ixi_tiny_dir / 'image').rename(root / 'image')
            (ixi_tiny_dir / 'label').rename(root / 'label')
            shutil.rmtree(ixi_tiny_dir)
        except BaseException:
            # root did not exist before this call (checked above). Removing
            # it prevents a half-populated directory from being mistaken for
            # a complete download next time.
            shutil.rmtree(root, ignore_errors=True)
            raise
        finally:
            # delete=False keeps the archive on disk unless removed here.
            try:
                archive_path.unlink()
            except FileNotFoundError:
                pass
def sglob(directory, pattern):
    """Return the paths in ``directory`` matching glob ``pattern``, sorted."""
    return sorted(Path(directory).glob(pattern))


def get_subject_id(path):
    """Return the subject ID encoded in an IXI-style file name.

    The ID is everything before the last ``'-'`` of the file name, e.g.
    ``'IXI002-Guys-0828-T1.nii.gz'`` -> ``'IXI002-Guys-0828'``.

    Args:
        path: A :class:`~pathlib.Path` or string path to the file.
    """
    return '-'.join(Path(path).name.split('-')[:-1])