File size: 2,010 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Adapter catalog for old task labels over the canonical PolyGuard env."""

from __future__ import annotations

from typing import Any

from app.common.constants import REWARD_MAX, REWARD_MIN, REWARD_PRECISION
from app.common.enums import Difficulty, SubEnvironment

# Old task labels and the canonical difficulty / sub-environment values each
# one stands for. Every preset is a plain dict with the keys "id", "label",
# "difficulty" and "sub_environment" (in that order).
TASK_PRESETS: tuple[dict[str, str], ...] = tuple(
    {
        "id": preset_id,
        "label": preset_label,
        "difficulty": preset_difficulty.value,
        "sub_environment": preset_sub_env.value,
    }
    for preset_id, preset_label, preset_difficulty, preset_sub_env in (
        (
            "easy_screening",
            "Easy Screening",
            Difficulty.EASY,
            SubEnvironment.DDI,
        ),
        (
            "budgeted_screening",
            "Budgeted Screening",
            Difficulty.MEDIUM,
            SubEnvironment.REGIMEN_RISK,
        ),
        (
            "complex_tradeoff",
            "Complex Tradeoff",
            Difficulty.HARD,
            SubEnvironment.REGIMEN_RISK,
        ),
        (
            "bandit_mining",
            "Bandit Mining",
            Difficulty.HARD,
            SubEnvironment.BANDIT_MINING,
        ),
    )
)

# Index of the presets above keyed by their "id"; the values are the very same
# dict objects stored in TASK_PRESETS, not copies.
TASK_PRESET_BY_ID = {preset["id"]: preset for preset in TASK_PRESETS}


def env_catalog() -> dict[str, Any]:
    """Build a fresh description of the environment's rewards and presets.

    Returns:
        A new dict with four entries, in this order:

        * ``"reward_range"`` -- ``[REWARD_MIN, REWARD_MAX]``.
        * ``"reward_precision"`` -- ``REWARD_PRECISION``.
        * ``"task_presets"`` -- a list holding a shallow copy of every dict in
          ``TASK_PRESETS``, so callers can edit the result without touching the
          module-level table.
        * ``"sub_environments"`` -- the ``.value`` of each ``SubEnvironment``
          member, in iteration order.
    """
    # Copy each preset so the shared module-level dicts are never handed out.
    preset_copies = [dict(preset) for preset in TASK_PRESETS]
    sub_env_values = [member.value for member in SubEnvironment]

    catalog: dict[str, Any] = {}
    catalog["reward_range"] = [REWARD_MIN, REWARD_MAX]
    catalog["reward_precision"] = REWARD_PRECISION
    catalog["task_presets"] = preset_copies
    catalog["sub_environments"] = sub_env_values
    return catalog


def apply_task_preset(payload: dict[str, Any]) -> dict[str, Any]:
    """Translate an old ``task_id`` label into canonical payload fields.

    The ``"task_id"`` key is always removed from ``payload``. When it held a
    truthy value, the matching preset supplies ``"difficulty"`` and
    ``"sub_environment"`` for whichever of those fields is missing or ``None``;
    values the caller already set are kept.

    Args:
        payload: Request fields. This dict is modified in place.

    Returns:
        The same ``payload`` object that was passed in.

    Raises:
        ValueError: If ``task_id`` is truthy but ``str(task_id)`` is not a key
            of ``TASK_PRESET_BY_ID``. ``"task_id"`` has already been removed
            from ``payload`` by then.
    """
    legacy_id = payload.pop("task_id", None)
    if legacy_id:
        # Ids are looked up by their string form, so non-str ids are coerced.
        defaults = TASK_PRESET_BY_ID.get(str(legacy_id))
        if defaults is None:
            raise ValueError(f"Unknown task_id {legacy_id!r}")
        # Only fill gaps: an explicit value from the caller wins over the preset.
        for field in ("difficulty", "sub_environment"):
            if payload.get(field) is None:
                payload[field] = defaults[field]
    return payload