import numpy as np
import torch

from .. import ScalarImage
from .. import Subject
from .. import SubjectsDataset
from ..download import download_url
from ..utils import get_torchio_cache_dir


class MedMNIST(SubjectsDataset):
    """3D MedMNIST v2 datasets.

    Datasets from `MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and
    3D Biomedical Image Classification <https://arxiv.org/abs/2110.14795>`_.

    Please check the `MedMNIST website <https://medmnist.com/>`_ for more
    information, including the license.

    The archive to download is chosen from the lowercased name of the
    concrete class (see :attr:`filename`).

    Args:
        split: Dataset split. Should be ``'train'``, ``'val'`` or ``'test'``.
            The long forms ``'training'``, ``'validation'`` and
            ``'testing'`` are accepted as aliases.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Raises:
        ValueError: If ``split`` is not one of :attr:`SPLITS`.
    """

    BASE_URL = 'https://zenodo.org/record/5208230/files'
    SPLITS = 'train', 'training', 'val', 'validation', 'test', 'testing'

    # Long-form split names mapped to the prefixes used for the array keys
    # inside the ``.npz`` archive (``'<prefix>_images'``, ``'<prefix>_labels'``).
    _SPLIT_ALIASES = {
        'training': 'train',
        'validation': 'val',
        'testing': 'test',
    }

    def __init__(self, split, **kwargs):
        if split not in self.SPLITS:
            raise ValueError(f'The split must be one of {self.SPLITS}')
        # Canonical names are not in the alias table and pass through as-is.
        split = self._SPLIT_ALIASES.get(split, split)

        url = f'{self.BASE_URL}/{self.filename}?download=1'
        download_root = get_torchio_cache_dir() / 'MedMNIST'
        download_url(
            url,
            download_root,
            filename=self.filename,
        )
        path = download_root / self.filename

        # ``np.load`` on an ``.npz`` archive returns a lazy ``NpzFile`` that
        # keeps the underlying file open; read what is needed inside a
        # ``with`` block so the handle is released deterministically.
        with np.load(path) as npz_file:
            images = npz_file[f'{split}_images']
            labels = npz_file[f'{split}_labels']

        # ``array[np.newaxis]`` prepends a singleton axis to each volume
        # (presumably the channel axis expected by ``ScalarImage`` --
        # confirm against its documentation).
        subjects = [
            Subject(
                image=ScalarImage(tensor=array[np.newaxis]),
                labels=torch.from_numpy(label),
            )
            for array, label in zip(images, labels)
        ]
        super().__init__(subjects, **kwargs)

    @property
    def filename(self):
        """Name of the ``.npz`` archive: the lowercased class name."""
        return f'{self.__class__.__name__.lower()}.npz'