File size: 1,441 Bytes
2043afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Task setup utilities: select episodes and configure budgets per difficulty."""

from __future__ import annotations

import random
from typing import Optional

from .config import DEFAULT_TASK, TASK_CONFIGS, TaskConfig
from .data_loader import PatientEpisode, load_patients


# Translate each OpenEnv task ID into the difficulty tag used in the patient CSV.
# Task IDs absent from this table are treated as "medium" by sample_episode().
_DIFFICULTY_MAP: dict[str, str] = {
    "easy_screening": "easy",
    "budgeted_screening": "medium",
    "complex_tradeoff": "hard",
}


def get_task_config(task_id: Optional[str] = None) -> TaskConfig:
    """Look up the ``TaskConfig`` registered for *task_id*.

    Args:
        task_id: Key into ``TASK_CONFIGS``. A falsy value (``None`` or an
            empty string) selects ``DEFAULT_TASK`` instead.

    Returns:
        The configuration stored under the resolved task ID.

    Raises:
        ValueError: If the resolved ID has no (non-``None``) entry in
            ``TASK_CONFIGS``; the message lists the known IDs.
    """
    resolved = task_id if task_id else DEFAULT_TASK
    config = TASK_CONFIGS.get(resolved)
    if config is not None:
        return config
    raise ValueError(f"Unknown task_id {resolved!r}. Choose from {list(TASK_CONFIGS)}")


def sample_episode(
    task_id: Optional[str] = None,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
) -> PatientEpisode:
    """Pick one patient episode whose difficulty matches *task_id*.

    The task ID is translated to a CSV difficulty tag through
    ``_DIFFICULTY_MAP`` (unmapped IDs use ``"medium"``), and every episode
    with that tag is fetched via ``load_patients``.

    Args:
        task_id: OpenEnv task identifier; a falsy value selects
            ``DEFAULT_TASK``.
        seed: Seed for a private ``random.Random`` instance, so a fixed seed
            over the same episode list always yields the same pick. ``None``
            gives a non-reproducible pick.
        episode_id: When truthy, return the first episode whose
            ``episode_id`` equals this value instead of sampling.

    Returns:
        The selected ``PatientEpisode``.

    Raises:
        RuntimeError: If ``load_patients`` returns no episodes for the tag.
        ValueError: If *episode_id* is given but matches none of them.
    """
    difficulty = _DIFFICULTY_MAP.get(task_id or DEFAULT_TASK, "medium")
    pool = load_patients(difficulty=difficulty)
    if not pool:
        raise RuntimeError(f"No episodes found for difficulty={difficulty!r}")

    if not episode_id:
        # Use a dedicated RNG so the global ``random`` state is not consumed
        # and the same *seed* reproduces the same choice.
        return random.Random(seed).choice(pool)

    found = next((ep for ep in pool if ep.episode_id == episode_id), None)
    if found is None:
        raise ValueError(f"Episode {episode_id!r} not found for difficulty={difficulty!r}")
    return found