| import os |
| import glob |
| import re |
| import numpy as np |
| import pandas as pd |
| import torch |
| from torch.utils.data import Dataset |
| from typing import List, Tuple, Union, Literal |
| import torch.nn.functional as F |
| from .pretrain_dataset import fMRIDataset |
| import io |
| import nibabel as nib |
|
|
class fMRITaskDataset(fMRIDataset):
    """Downstream task dataset pairing fMRI recordings with per-subject labels.

    Extends ``fMRIDataset`` by loading a label CSV, filtering the inherited
    ``self.file_paths`` down to subjects that have a label, and returning
    ``(data, label)`` pairs from ``__getitem__``.
    """

    def __init__(
        self,
        data_root: str,
        datasets: List[str],
        split_suffixes: List[str],
        crop_length: int,
        label_csv_path: str,
        task_type: Literal['classification', 'regression'] = 'classification',
        downstream=True,
    ):
        """Build the task dataset.

        Args:
            data_root: root directory passed through to ``fMRIDataset``.
            datasets: dataset names passed through to ``fMRIDataset``.
            split_suffixes: split suffixes passed through to ``fMRIDataset``.
            crop_length: temporal crop length passed through to ``fMRIDataset``.
            label_csv_path: path to a CSV with a 'Subject' column plus a label
                column ('Gender'/'gender'/'age_group' for classification,
                'age' for regression).
            task_type: 'classification' (long labels) or 'regression'
                (float32 labels).
            downstream: forwarded to ``fMRIDataset``.
        """
        super().__init__(data_root, datasets, split_suffixes, crop_length, downstream)

        self.task_type = task_type
        self.labels_map = self._load_and_process_labels(label_csv_path)

        # Keep only files whose subject ID has a label; report what was dropped.
        initial_file_count = len(self.file_paths)
        self.file_paths = [
            path for path in self.file_paths
            if self._extract_subject_id(path) in self.labels_map
        ]
        dropped = initial_file_count - len(self.file_paths)
        if dropped > 0:
            print(f"Warning: Dropped {dropped} files due to missing labels in CSV.")

        print(f"Task Dataset ready for {self.task_type}. Usable files: {len(self.file_paths)}")

    def _extract_subject_id(self, file_path: str) -> str:
        """Extract a subject ID from a file name.

        Takes the first run of exactly six digits in the basename (HCP-style
        IDs — TODO confirm for all configured datasets) and strips leading
        zeros so it matches the CSV's integer-derived string IDs.

        Returns:
            The subject ID string, or "" when no six-digit run is found.
        """
        match = re.search(r'(\d{6})', os.path.basename(file_path))
        if match:
            # lstrip('0') aligns "000123" (filename) with "123" (CSV int->str).
            return match.group(1).lstrip('0')
        return ""

    def _load_and_process_labels(self, csv_path: str) -> dict:
        """Load the label CSV and build ``{subject_id: label_tensor}``.

        Classification reads 'Gender'/'gender'/'age_group' and stores scalar
        long tensors; regression reads 'age' and stores float32 tensors of
        shape ``(1,)``.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
            ValueError: if the required label column is missing, or
                ``self.task_type`` is unsupported.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Label CSV file not found at: {csv_path}")

        print(f"Loading labels from {csv_path}...")
        df = pd.read_csv(csv_path)

        # Drop missing subjects BEFORE casting to str: astype(str) would turn
        # NaN into the literal string 'nan' and defeat the dropna.
        df = df.dropna(subset=['Subject'])
        df['Subject'] = df['Subject'].astype(str)

        labels_map = {}

        if self.task_type == 'classification':
            label_col = next(
                (col for col in ('Gender', 'gender', 'age_group') if col in df.columns),
                None,
            )
            if label_col is None:
                raise ValueError(
                    "CSV must contain 'Gender', 'gender' or 'age_group' column for classification."
                )

            print(f"Using column '{label_col}' as label.")

            sex_mapping = {'F': 0, 'M': 1, 'f': 0, 'm': 1}

            # Detect an F/M string encoding by scanning the whole column, not
            # just row 0 (which could be NaN and crash .upper()).
            if df[label_col].dtype == object and df[label_col].isin(sex_mapping).any():
                print(f"Encoding {label_col} (F/M) to Integers (0/1)...")
                # .copy() avoids SettingWithCopyWarning on the filtered frame.
                df = df[df[label_col].isin(sex_mapping.keys())].copy()
                df[label_col] = df[label_col].map(sex_mapping)
            else:
                # Coerce to numeric and drop unparseable rows; astype(int)
                # directly on a coerced column would raise on NaN.
                df[label_col] = pd.to_numeric(df[label_col], errors='coerce')
                df = df.dropna(subset=[label_col])
                df[label_col] = df[label_col].astype(int)

            for _, row in df.iterrows():
                labels_map[row['Subject']] = torch.tensor(row[label_col], dtype=torch.long)

        elif self.task_type == 'regression':
            label_col = 'age'
            if label_col not in df.columns:
                raise ValueError(f"Regression task requires '{label_col}' column.")
            df[label_col] = pd.to_numeric(df[label_col], errors='coerce')
            df = df.dropna(subset=[label_col])

            for _, row in df.iterrows():
                labels_map[row['Subject']] = torch.tensor(
                    row[label_col], dtype=torch.float32
                ).view(1)

        else:
            raise ValueError(f"Unsupported task_type: {self.task_type}")

        print(f"Successfully loaded {len(labels_map)} subjects' labels.")
        return labels_map

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a ``(data, label)`` pair, resampling a random index on failure.

        Returns:
            data_tensor: parent-class tensor with a leading channel dimension
                added via ``unsqueeze(0)``.
            label_tensor: scalar long tensor (classification) or float32
                ``(1,)`` tensor (regression).

        Raises:
            RuntimeError: if no valid sample could be loaded after 100 retries.
        """
        retries = 0
        max_retries = 100
        while retries < max_retries:
            try:
                data_tensor = super().__getitem__(idx)
                if data_tensor is None:
                    raise ValueError(f"Failed to load data at index {idx} (super returned None)")

                subject_id = self._extract_subject_id(self.file_paths[idx])
                # Filtering in __init__ should guarantee a label exists; fail
                # loudly if the maps have gone out of sync.
                if subject_id not in self.labels_map:
                    raise KeyError(f"Label not found for subject ID: {subject_id}")

                # Add the channel dimension expected downstream.
                return data_tensor.unsqueeze(0), self.labels_map[subject_id]

            except Exception as e:
                # Best-effort resampling: surface the failure instead of
                # silently swallowing it, then try a different random index.
                print(f"Warning: failed to load sample {idx} ({e}); resampling.")
                idx = np.random.randint(0, len(self))
                retries += 1

        raise RuntimeError(f"Failed to load any valid data after {max_retries} retries.")
|
|