Spaces:
Sleeping
Sleeping
| """Task setup utilities: select episodes and configure budgets per difficulty.""" | |
| from __future__ import annotations | |
| import random | |
| from typing import Optional | |
| from .config import DEFAULT_TASK, TASK_CONFIGS, TaskConfig | |
| from .data_loader import PatientEpisode, load_patients | |
# Translate each OpenEnv task label into the difficulty tag used by the
# patient CSV, so that load_patients() can filter episodes by difficulty.
_DIFFICULTY_MAP: dict[str, str] = dict(
    easy_screening="easy",
    budgeted_screening="medium",
    complex_tradeoff="hard",
)
def get_task_config(task_id: Optional[str] = None) -> TaskConfig:
    """Return the ``TaskConfig`` registered under *task_id*.

    A falsy *task_id* (``None`` or ``""``) resolves to ``DEFAULT_TASK``.

    Raises:
        ValueError: if the resolved id has no entry in ``TASK_CONFIGS``. The
            message lists the registered ids.
    """
    resolved = task_id if task_id else DEFAULT_TASK
    config = TASK_CONFIGS.get(resolved)
    if config is not None:
        return config
    raise ValueError(f"Unknown task_id {resolved!r}. Choose from {list(TASK_CONFIGS)}")
def sample_episode(
    task_id: Optional[str] = None,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
) -> PatientEpisode:
    """Return a single patient episode from the pool matching *task_id*.

    Args:
        task_id: OpenEnv task label. A falsy value means ``DEFAULT_TASK``.
            Labels without an entry in ``_DIFFICULTY_MAP`` use the
            ``"medium"`` pool.
        seed: Seed for a private ``random.Random`` used to pick the episode.
            ``None`` gives an unseeded draw.
        episode_id: When truthy, return the episode with this id from the
            pool instead of sampling one.

    Raises:
        RuntimeError: if ``load_patients`` yields no episodes for the
            resolved difficulty.
        ValueError: if *episode_id* is given but no episode in the pool has it.
    """
    # NOTE(review): an unrecognised task label silently falls back to the
    # "medium" pool, whereas get_task_config() raises for unknown ids. Confirm
    # this fallback is intentional.
    difficulty = _DIFFICULTY_MAP.get(task_id or DEFAULT_TASK, "medium")
    pool = load_patients(difficulty=difficulty)
    if not pool:
        raise RuntimeError(f"No episodes found for difficulty={difficulty!r}")

    if not episode_id:
        # A dedicated Random instance keeps seeding local and leaves the
        # module-level ``random`` state untouched.
        return random.Random(seed).choice(pool)

    # Explicit lookup: the first episode in pool order whose id matches.
    found = next((ep for ep in pool if ep.episode_id == episode_id), None)
    if found is None:
        raise ValueError(f"Episode {episode_id!r} not found for difficulty={difficulty!r}")
    return found