# Source code for torchio.datasets.itk_snap.itk_snap

import urllib.parse

from ...data import LabelMap
from ...data import ScalarImage
from ...data.subject import _RawSubjectCopySubject
from ...download import download_and_extract_archive
from ...utils import get_torchio_cache_dir


class SubjectITKSNAP(_RawSubjectCopySubject):
    """Base class for subjects built from ITK-SNAP sample data archives.

    Constructing an instance derives the archive URL from ``url_base``, a
    numeric download ``code`` and the dataset ``name``. If the per-dataset
    folder under the directory returned by ``get_torchio_cache_dir()`` does
    not exist yet, the archive is handed to ``download_and_extract_archive``.
    Subclasses implement :meth:`get_kwargs` to describe the images that make
    up the subject.

    See `the ITK-SNAP website`_ for more information.

    .. _the ITK-SNAP website: http://www.itksnap.org/pmwiki/pmwiki.php?n=Downloads.Data
    """  # noqa: B950

    url_base = 'https://www.nitrc.org/frs/download.php/'

    def __init__(self, name, code):
        """Locate (downloading if needed) the dataset and build the subject.

        Args:
            name: Dataset name; also the stem of the ``.zip`` archive and the
                name of the folder used under the cache directory.
            code: Download identifier appended to ``url_base`` to form the
                remote directory of the archive.
        """
        archive_name = f'{name}.zip'
        # The trailing slash makes urljoin treat the code as a directory, so
        # the archive name is appended instead of replacing the last segment.
        remote_dir = urllib.parse.urljoin(self.url_base, f'{code}/')

        self.name = name
        self.url_dir = remote_dir
        self.filename = archive_name
        self.url = urllib.parse.urljoin(remote_dir, archive_name)
        self.download_root = get_torchio_cache_dir() / name

        # Only the presence of the folder is checked; its contents are not.
        already_present = self.download_root.is_dir()
        if not already_present:
            download_and_extract_archive(
                self.url,
                download_root=self.download_root,
                filename=self.filename,
            )

        subject_kwargs = self.get_kwargs()
        super().__init__(**subject_kwargs)

    def get_kwargs(self):
        """Return the keyword arguments forwarded to the parent constructor.

        Must be overridden by subclasses; the base implementation always
        raises.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class BrainTumor(SubjectITKSNAP):
    """Sample subject built from the ``braintumor`` ITK-SNAP archive.

    The subject holds four scalar images (``t1``, ``t1c``, ``t2``,
    ``flair``) and one label map (``seg``).
    """

    def __init__(self):
        # Dataset name and download code of the archive.
        super().__init__('braintumor', '6161')

    def get_kwargs(self):
        """Return the images of this subject, keyed by name.

        All files are ``BRATS_HG0015_<suffix>.mha`` inside
        ``download_root / name``.
        """
        t1, t1c, t2, flair, seg = (
            self.download_root / self.name / f'BRATS_HG0015_{name}.mha'
            for name in ('T1', 'T1C', 'T2', 'FLAIR', 'truth')
        )
        return {
            't1': ScalarImage(t1),
            't1c': ScalarImage(t1c),
            't2': ScalarImage(t2),
            'flair': ScalarImage(flair),
            'seg': LabelMap(seg),
        }


class T1T2(SubjectITKSNAP):
    """Sample subject built from the ``ashs_test`` ITK-SNAP archive.

    The subject holds two scalar images: ``mprage`` and ``tse``.
    """

    def __init__(self):
        # Dataset name and download code of the archive.
        super().__init__('ashs_test', '10983')

    def get_kwargs(self):
        """Return the images of this subject, keyed by name.

        Both files are read from ``download_root / name``.
        """
        mprage = self.download_root / self.name / 'mprage_3T_bet_dr.nii'
        tse = self.download_root / self.name / 'tse_3t_dr.nii'
        return {
            'mprage': ScalarImage(mprage),
            'tse': ScalarImage(tse),
        }


class AorticValve(SubjectITKSNAP):
    """Sample subject built from the ``bav_example`` ITK-SNAP archive.

    The subject holds two scalar images (``b14``, ``b25``) and a label map
    for each of them (``b14_seg``, ``b25_seg``).
    """

    def __init__(self):
        # Dataset name and download code of the archive.
        super().__init__('bav_example', '11021')

    def get_kwargs(self):
        """Return the images of this subject, keyed by name.

        All files are ``bav_frame_<suffix>.nii.gz`` inside
        ``download_root / name``.
        """
        b14, b14_seg, b25, b25_seg = (
            self.download_root / self.name / f'bav_frame_{name}.nii.gz'
            for name in ('14', '14_manseg', '25', '25_manseg')
        )
        return {
            'b14': ScalarImage(b14),
            'b14_seg': LabelMap(b14_seg),
            'b25': ScalarImage(b25),
            'b25_seg': LabelMap(b25_seg),
        }