adithya9903's picture
Flatten project to root for OpenEnv submission readiness.
fa51dd9
"""Task setup utilities: select episodes and configure budgets per difficulty."""
from __future__ import annotations
import random
from typing import Optional
from .config import DEFAULT_TASK, TASK_CONFIGS, TaskConfig
from .data_loader import PatientEpisode, load_patients
# Map OpenEnv difficulty labels to the CSV difficulty tags
_DIFFICULTY_MAP = {
"easy_screening": "easy",
"budgeted_screening": "medium",
"complex_tradeoff": "hard",
}
def get_task_config(task_id: Optional[str] = None) -> TaskConfig:
    """Look up the ``TaskConfig`` registered for *task_id*.

    A falsy *task_id* (``None`` or ``""``) selects ``DEFAULT_TASK``.

    Raises:
        ValueError: if ``TASK_CONFIGS`` has no non-``None`` entry for the
            resolved id; the message lists the available task ids.
    """
    resolved = task_id if task_id else DEFAULT_TASK
    config = TASK_CONFIGS.get(resolved)
    if config is not None:
        return config
    raise ValueError(f"Unknown task_id {resolved!r}. Choose from {list(TASK_CONFIGS)}")
def sample_episode(
    task_id: Optional[str] = None,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
) -> PatientEpisode:
    """Pick one patient episode whose difficulty matches *task_id*.

    Args:
        task_id: Task identifier; a falsy value selects ``DEFAULT_TASK``.
            Ids missing from ``_DIFFICULTY_MAP`` silently fall back to the
            ``"medium"`` difficulty tag (unlike ``get_task_config``, which
            raises for unknown ids).
        seed: Seed for a private ``random.Random`` used only when no
            *episode_id* is requested; ``None`` makes the draw
            non-deterministic.
        episode_id: If truthy, return the first episode in the difficulty
            pool with this id instead of sampling.

    Raises:
        RuntimeError: if ``load_patients`` returns no episodes for the
            resolved difficulty.
        ValueError: if *episode_id* is given but not present in that pool.
    """
    difficulty = _DIFFICULTY_MAP.get(task_id or DEFAULT_TASK, "medium")
    pool = load_patients(difficulty=difficulty)
    if not pool:
        raise RuntimeError(f"No episodes found for difficulty={difficulty!r}")

    if not episode_id:
        # A dedicated generator keeps seeded draws reproducible without
        # disturbing the module-level random state.
        return random.Random(seed).choice(pool)

    found = next((ep for ep in pool if ep.episode_id == episode_id), None)
    if found is None:
        raise ValueError(f"Episode {episode_id!r} not found for difficulty={difficulty!r}")
    return found