# Source code for torchio.datasets.rsna_spine_fracture

from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from types import ModuleType

from .. import LabelMap
from .. import ScalarImage
from .. import Subject
from .. import SubjectsDataset
from ..typing import TypePath
from ..utils import normalize_path


# Bounding boxes attached to one subject: a list with one dictionary per row
# of the bounding boxes CSV belonging to that study (column name -> value).
TypeBoxes = List[Dict[str, Union[str, float, int]]]


class RSNACervicalSpineFracture(SubjectsDataset):
    """RSNA 2022 Cervical Spine Fracture Detection dataset.

    This is a helper class for the dataset used in the
    `RSNA 2022 Cervical Spine Fracture Detection`_ hosted on
    `kaggle <https://www.kaggle.com/>`_. The dataset must be downloaded before
    instantiating this class.

    The following entries are read from ``root_dir``:

    - ``train.csv``: one row per study; every column is stored in the subject.
    - ``train_images/``: one DICOM directory per study, named after its UID.
    - ``segmentations/``: only read if ``add_segmentations`` is ``True``.
    - ``train_bounding_boxes.csv``: only read if ``add_bounding_boxes`` is
      ``True``.

    Args:
        root_dir: Directory containing the downloaded dataset.
        add_segmentations: If ``True``, add a ``'seg'`` label map to the
            subjects for which a segmentation file is found.
        add_bounding_boxes: If ``True``, add a ``'boxes'`` list to the
            subjects that have rows in the bounding boxes CSV file.
        **kwargs: Additional arguments passed to the parent class.

    .. _RSNA 2022 Cervical Spine Fracture Detection: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/overview/evaluation
    """  # noqa: B950

    # Name of the CSV column that identifies a study
    UID = 'StudyInstanceUID'

    def __init__(
        self,
        root_dir: TypePath,
        add_segmentations: bool = False,
        add_bounding_boxes: bool = False,
        **kwargs,
    ):
        self.root_dir = normalize_path(root_dir)
        subjects = self._get_subjects(
            add_segmentations,
            add_bounding_boxes,
        )
        super().__init__(subjects, **kwargs)

    @staticmethod
    def _strip_nifti_suffix(name: str) -> str:
        """Remove a trailing ``.gz`` and then a trailing ``.nii`` from a name.

        Only the end of the name is modified, so names that happen to contain
        ``.gz`` or ``.nii`` elsewhere (e.g., inside a UID) are left intact.
        """
        for suffix in ('.gz', '.nii'):
            if name.endswith(suffix):
                name = name[: -len(suffix)]
        return name

    @staticmethod
    def _get_image_dirs_dict(images_dir: Path) -> Dict[str, Path]:
        """Map the name of each entry in ``images_dir`` to its path."""
        return {path.name: path for path in sorted(images_dir.iterdir())}

    @staticmethod
    def _get_segs_paths_dict(segs_dir: Path) -> Dict[str, Path]:
        """Map each file name in ``segs_dir``, without extension, to its path."""
        strip = RSNACervicalSpineFracture._strip_nifti_suffix
        return {strip(path.name): path for path in sorted(segs_dir.iterdir())}

    def _get_subjects(
        self,
        add_segmentations: bool,
        add_bounding_boxes: bool,
    ) -> List[Subject]:
        """Create one subject per row of the training CSV file.

        Args:
            add_segmentations: Whether to look for segmentation files.
            add_bounding_boxes: Whether to read the bounding boxes CSV file.

        Returns:
            List of subjects, in the order of the rows in the CSV file.

        Raises:
            KeyError: If a UID in the CSV has no directory in the images dir.
        """
        pd = get_pandas()
        from tqdm.auto import tqdm

        split_name = 'train'
        images_dir = self.root_dir / f'{split_name}_images'
        image_dirs_dict = self._get_image_dirs_dict(images_dir)

        # Optional data is only touched when requested, so that it does not
        # need to exist on disk otherwise
        seg_paths_dict: Dict[str, Path] = {}
        if add_segmentations:
            segmentations_dir = self.root_dir / 'segmentations'
            seg_paths_dict = self._get_segs_paths_dict(segmentations_dir)

        grouped_boxes = None
        if add_bounding_boxes:
            bboxes_path = self.root_dir / 'train_bounding_boxes.csv'
            grouped_boxes = pd.read_csv(bboxes_path).groupby(self.UID)

        df = pd.read_csv(self.root_dir / f'{split_name}.csv')

        subjects = []
        # Pass the total explicitly instead of materializing all the rows
        for _, row in tqdm(df.iterrows(), total=len(df)):
            uid = row[self.UID]
            image_dir = image_dirs_dict[uid]
            # The dictionary is empty if segmentations were not requested
            seg_path = seg_paths_dict.get(uid)

            boxes: TypeBoxes = []
            if grouped_boxes is not None:
                try:
                    boxes_df = grouped_boxes.get_group(uid)
                except KeyError:
                    # This study has no rows in the bounding boxes file
                    pass
                else:
                    boxes = [
                        dict(box_row) for _, box_row in boxes_df.iterrows()
                    ]

            subject = self._get_subject(
                dict(row),
                image_dir,
                seg_path,
                boxes,
            )
            subjects.append(subject)
        return subjects

    @staticmethod
    def _filter_list(iterable: List[Path], target: str) -> Optional[Path]:
        """Return the path in ``iterable`` whose name matches ``target``.

        Directories are matched by their name. Any other path is matched by
        its name without the ``.nii`` or ``.nii.gz`` extension.

        Returns:
            The matching path, or ``None`` if no path matches. More than one
            match trips an assertion.
        """
        strip = RSNACervicalSpineFracture._strip_nifti_suffix

        def _matches(path: Path) -> bool:
            if path.is_dir():
                return target == path.name
            return target == strip(path.name)

        found = [path for path in iterable if _matches(path)]
        if not found:
            return None
        assert len(found) == 1
        return found[0]

    def _get_subject(
        self,
        csv_row_dict: Dict[str, Union[str, int]],
        image_dir: Path,
        seg_path: Optional[Path],
        boxes: TypeBoxes,
    ) -> Subject:
        """Build a subject from a CSV row and the paths of its images.

        Args:
            csv_row_dict: Columns of the CSV row, stored as-is in the subject.
            image_dir: Path used to create the ``'ct'`` scalar image.
            seg_path: Path used to create the ``'seg'`` label map. The entry
                is not added if this is ``None``.
            boxes: Bounding boxes stored under ``'boxes'``. The entry is not
                added if the list is empty.
        """
        subject_dict: Dict[str, Any] = {}
        subject_dict.update(csv_row_dict)
        subject_dict['ct'] = ScalarImage(image_dir)
        if seg_path is not None:
            subject_dict['seg'] = LabelMap(seg_path)
        if boxes:
            subject_dict['boxes'] = boxes
        return Subject(**subject_dict)


def get_pandas() -> ModuleType:
    """Import pandas on demand and return the module.

    Returns:
        The imported ``pandas`` module.

    Raises:
        ImportError: If pandas cannot be imported. The original error is
            chained as the cause.
    """
    try:
        import pandas
    except ImportError as error:
        raise ImportError(
            'Pandas is required for this operation.'
            ' Install pandas with "pip install pandas" and try again'
        ) from error
    else:
        return pandas