# Source code for torchio.datasets.rsna_spine_fracture

from __future__ import annotations

from pathlib import Path
from types import ModuleType
from typing import Any
from typing import Union

from ..data import LabelMap
from ..data import ScalarImage
from ..data import Subject
from ..data import SubjectsDataset
from ..types import TypePath
from ..utils import normalize_path

TypeBoxes = list[dict[str, Union[str, float, int]]]


class RSNACervicalSpineFracture(SubjectsDataset):
    """RSNA 2022 Cervical Spine Fracture Detection dataset.

    This is a helper class for the dataset used in the
    `RSNA 2022 Cervical Spine Fracture Detection`_ hosted on
    `kaggle <https://www.kaggle.com/>`_. The dataset must be downloaded before
    instantiating this class.

    .. _RSNA 2022 Cervical Spine Fracture Detection: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/overview/evaluation

    Args:
        root_dir: Directory containing ``train.csv`` and ``train_images``.
            ``segmentations`` and ``train_bounding_boxes.csv`` are only read
            from it when the corresponding flag below is enabled.
        add_segmentations: If ``True``, add a :class:`LabelMap` under the key
            ``'seg'`` to every subject for which a segmentation file is found.
        add_bounding_boxes: If ``True``, add the list of bounding boxes (one
            dictionary per CSV row) under the key ``'boxes'`` to every subject
            that has at least one box.
        **kwargs: Additional arguments passed to :class:`SubjectsDataset`.
    """

    # Name of the CSV column that identifies a study (and its image folder).
    UID = 'StudyInstanceUID'

    def __init__(
        self,
        root_dir: TypePath,
        add_segmentations: bool = False,
        add_bounding_boxes: bool = False,
        **kwargs,
    ):
        self.root_dir = normalize_path(root_dir)
        subjects = self._get_subjects(
            add_segmentations,
            add_bounding_boxes,
        )
        super().__init__(subjects, **kwargs)

    @staticmethod
    def _strip_nifti_suffix(name: str) -> str:
        """Remove a trailing ``.nii`` or ``.nii.gz`` from a file name."""
        # Only trailing extensions are removed, so dots and matching
        # substrings inside the identifier itself are left untouched.
        return name.removesuffix('.gz').removesuffix('.nii')

    @staticmethod
    def _get_image_dirs_dict(images_dir: Path) -> dict[str, Path]:
        """Map each entry name in ``images_dir`` to its path."""
        return {
            dicom_dir.name: dicom_dir
            for dicom_dir in sorted(images_dir.iterdir())
        }

    @staticmethod
    def _get_segs_paths_dict(segs_dir: Path) -> dict[str, Path]:
        """Map each segmentation file name (without extension) to its path."""
        strip_suffix = RSNACervicalSpineFracture._strip_nifti_suffix
        return {
            strip_suffix(image_path.name): image_path
            for image_path in sorted(segs_dir.iterdir())
        }

    @staticmethod
    def _get_boxes(grouped_boxes: Any, uid: str) -> TypeBoxes:
        """Return the bounding boxes of study ``uid`` as a list of dicts.

        An empty list is returned when the study has no bounding boxes.
        """
        # Keep the ``try`` body minimal so that only a missing group is
        # treated as "no boxes"; any other error propagates.
        try:
            boxes_df = grouped_boxes.get_group(uid)
        except KeyError:
            return []
        return [dict(box_row) for _, box_row in boxes_df.iterrows()]

    def _get_subjects(
        self,
        add_segmentations: bool,
        add_bounding_boxes: bool,
    ) -> list[Subject]:
        """Build one subject per row of ``train.csv``.

        Args:
            add_segmentations: Whether to look up segmentation files.
            add_bounding_boxes: Whether to look up bounding boxes.

        Raises:
            ImportError: If pandas is not installed.
            KeyError: If a study listed in ``train.csv`` has no image folder.
        """
        pd = get_pandas()
        from tqdm.auto import tqdm

        split_name = 'train'
        images_dir = self.root_dir / f'{split_name}_images'
        image_dirs_dict = self._get_image_dirs_dict(images_dir)

        # Optional data is loaded only on request, so a partial download
        # (e.g., without segmentations) can still be used.
        seg_paths_dict: dict[str, Path] = {}
        if add_segmentations:
            segmentations_dir = self.root_dir / 'segmentations'
            seg_paths_dict = self._get_segs_paths_dict(segmentations_dir)

        grouped_boxes = None
        if add_bounding_boxes:
            bboxes_path = self.root_dir / 'train_bounding_boxes.csv'
            grouped_boxes = pd.read_csv(bboxes_path).groupby(self.UID)

        df = pd.read_csv(self.root_dir / f'{split_name}.csv')
        subjects = []
        # Passing ``total`` keeps the progress bar sized without having to
        # materialize all the rows in a list first.
        for _, row in tqdm(df.iterrows(), total=len(df)):
            uid = row[self.UID]
            image_dir = image_dirs_dict[uid]
            # Empty mapping when segmentations were not requested
            seg_path = seg_paths_dict.get(uid)
            if grouped_boxes is None:
                boxes: TypeBoxes = []
            else:
                boxes = self._get_boxes(grouped_boxes, uid)
            subject = self._get_subject(
                dict(row),
                image_dir,
                seg_path,
                boxes,
            )
            subjects.append(subject)
        return subjects

    @staticmethod
    def _filter_list(iterable: list[Path], target: str) -> Path | None:
        """Return the only path in ``iterable`` matching ``target``, if any.

        Directories match on their name; files match on their name without
        the ``.nii`` or ``.nii.gz`` extension. ``None`` is returned when
        nothing matches.
        """
        strip_suffix = RSNACervicalSpineFracture._strip_nifti_suffix

        def _matches(path: Path) -> bool:
            if path.is_dir():
                return target == path.name
            return target == strip_suffix(path.name)

        found = [path for path in iterable if _matches(path)]
        if not found:
            return None
        assert len(found) == 1
        return found[0]

    def _get_subject(
        self,
        csv_row_dict: dict[str, str | int],
        image_dir: Path,
        seg_path: Path | None,
        boxes: TypeBoxes,
    ) -> Subject:
        """Create a subject from a CSV row and its associated data.

        Args:
            csv_row_dict: Columns of the ``train.csv`` row, stored as-is in
                the subject.
            image_dir: Path passed to :class:`ScalarImage` for the ``'ct'``
                entry.
            seg_path: Path to the segmentation, or ``None`` to skip it.
            boxes: Bounding boxes of the study; skipped if empty.
        """
        subject_dict: dict[str, Any] = dict(csv_row_dict)
        subject_dict['ct'] = ScalarImage(image_dir)
        if seg_path is not None:
            subject_dict['seg'] = LabelMap(seg_path)
        if boxes:
            subject_dict['boxes'] = boxes
        return Subject(**subject_dict)


def get_pandas() -> ModuleType:
    """Import pandas on demand and return the module.

    The import happens inside the function, so pandas is only needed by
    callers that actually use it.

    Raises:
        ImportError: If pandas cannot be imported. The message explains how
            to install it, and the original error is chained as the cause.
    """
    try:
        import pandas as pandas_module
    except ImportError as import_error:
        hint = (
            'Pandas is required for this operation.'
            ' Install pandas with "pip install pandas" and try again'
        )
        raise ImportError(hint) from import_error
    return pandas_module