class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    Reader of 3D medical images that directly inherits from the PyTorch
    :class:`~torch.utils.data.Dataset`. It can be used with a PyTorch
    :class:`~torch.utils.data.DataLoader` for efficient loading and
    augmentation. It receives a list of instances of
    :class:`~torchio.Subject` and an optional transform applied to the
    volumes after loading.

    Args:
        subjects: List of instances of :class:`~torchio.Subject`. One-shot
            iterators (e.g. generators) are materialized into a list so that
            validating them does not exhaust them and the dataset supports
            :func:`len` and indexing.
        transform: An instance of :class:`~torchio.transforms.Transform`
            that will be applied to each subject.
        load_getitem: Load all subject images before returning the subject
            in :meth:`__getitem__`. Set it to ``False`` if some of the images
            will not be needed during training.

    Raises:
        TypeError: If ``subjects`` is not iterable or contains an element
            that is not a :class:`~torchio.Subject`.
        ValueError: If ``subjects`` is empty or ``transform`` is not callable.

    Example:
        >>> import torchio as tio
        >>> subject_a = tio.Subject(
        ...     t1=tio.ScalarImage('t1.nrrd',),
        ...     t2=tio.ScalarImage('t2.mha',),
        ...     label=tio.LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = tio.Subject(
        ...     t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=tio.LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     tio.RescaleIntensity(out_min_max=(0, 1)),
        ...     tio.RandomAffine(),
        ... ]
        >>> transform = tio.Compose(transforms)
        >>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
        >>> subject = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html

    .. tip:: To quickly iterate over the subjects without loading the
        images, use :meth:`dry_iter()`.
    """  # noqa: B950

    def __init__(
        self,
        subjects: Sequence[Subject],
        transform: Optional[Callable] = None,
        load_getitem: bool = True,
    ):
        # A one-shot iterator (e.g. a generator) returns itself from iter().
        # Validating it would exhaust it and leave an empty dataset without
        # len() or indexing, so store such inputs as a list instead.
        # Re-iterable containers are kept as given.
        try:
            is_one_shot_iterator = iter(subjects) is subjects
        except TypeError:
            # Not iterable at all: _parse_subjects_list reports this clearly.
            is_one_shot_iterator = False
        if is_one_shot_iterator:
            subjects = list(subjects)
        self._parse_subjects_list(subjects)
        self._subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)
        self.load_getitem = load_getitem

    def __len__(self) -> int:
        """Return the number of stored subjects."""
        return len(self._subjects)

    def __getitem__(self, index: int) -> Subject:
        """Return a deep copy of the subject at ``index``, loaded/transformed.

        Args:
            index: Position of the subject. Anything accepted by :func:`int`
                (e.g. a NumPy integer or a one-element tensor) is allowed.

        Raises:
            ValueError: If ``index`` cannot be converted to :class:`int`.
        """
        try:
            index = int(index)
        except (RuntimeError, TypeError, ValueError) as e:
            # ValueError covers e.g. int('abc') and, in recent PyTorch
            # versions, multi-element tensors; RuntimeError covers older ones.
            message = (
                f'Index "{index}" must be int or compatible dtype,'
                f' but an object of type "{type(index)}" was passed'
            )
            raise ValueError(message) from e

        subject = self._subjects[index]
        # Copy so that loading/transforming never mutates the stored subject;
        # cheap as long as the images have not been loaded yet
        subject = copy.deepcopy(subject)
        if self.load_getitem:
            subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject

    @classmethod
    def from_batch(cls, batch: dict) -> SubjectsDataset:
        """Instantiate a dataset from a batch generated by a data loader.

        Args:
            batch: Dictionary generated by a data loader, containing data that
                can be converted to instances of :class:`~torchio.Subject`.
        """
        subjects: List[Subject] = get_subjects_from_batch(batch)
        return cls(subjects)

    def dry_iter(self):
        """Return the stored subjects without loading or transforming them.

        This can be used to iterate over the subjects without loading the data
        and applying any transforms::

        >>> names = [subject.name for subject in dataset.dry_iter()]
        """
        return self._subjects

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the :attr:`transform` attribute.

        Args:
            transform: Callable object, typically a subclass of
                :class:`torchio.transforms.Transform`, or ``None`` to disable
                transforms.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            message = (
                'The transform must be a callable object,'
                f' but it has type {type(transform)}'
            )
            raise ValueError(message)
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None:
        """Validate that ``subjects_list`` is a non-empty iterable of subjects.

        Raises:
            TypeError: If it is not iterable or an element is not a Subject.
            ValueError: If it contains no elements.
        """
        # Check that it's an iterable
        try:
            iter(subjects_list)
        except TypeError as e:
            message = f'Subject list must be an iterable, not {type(subjects_list)}'
            raise TypeError(message) from e

        # Check each element. Emptiness is detected while iterating because
        # truthiness is unreliable for iterators (always truthy) and
        # array-like containers (ambiguous truth value).
        is_empty = True
        for subject in subjects_list:
            is_empty = False
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)
        if is_empty:
            raise ValueError('Subjects list is empty')