Spaces:
Running
Running
| """Adapter catalog for old task labels over the canonical PolyGuard env.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| from app.common.constants import REWARD_MAX, REWARD_MIN, REWARD_PRECISION | |
| from app.common.enums import Difficulty, SubEnvironment | |
# Legacy task labels and the canonical (difficulty, sub-environment) pair each
# one expands to. Rows are (id, label, Difficulty member, SubEnvironment member);
# each becomes a dict whose difficulty/sub_environment hold the enum ``.value``.
TASK_PRESETS: tuple[dict[str, str], ...] = tuple(
    {
        "id": preset_id,
        "label": preset_label,
        "difficulty": preset_difficulty.value,
        "sub_environment": preset_sub_env.value,
    }
    for preset_id, preset_label, preset_difficulty, preset_sub_env in (
        ("easy_screening", "Easy Screening", Difficulty.EASY, SubEnvironment.DDI),
        ("budgeted_screening", "Budgeted Screening", Difficulty.MEDIUM, SubEnvironment.REGIMEN_RISK),
        ("complex_tradeoff", "Complex Tradeoff", Difficulty.HARD, SubEnvironment.REGIMEN_RISK),
        ("bandit_mining", "Bandit Mining", Difficulty.HARD, SubEnvironment.BANDIT_MINING),
    )
)

# Index of the same preset dicts keyed by their "id" (objects are shared, not copied).
TASK_PRESET_BY_ID = {preset["id"]: preset for preset in TASK_PRESETS}
def env_catalog() -> dict[str, Any]:
    """Describe the environment: reward bounds, task presets and sub-environments.

    Returns:
        A freshly built dict with keys, in order:
        ``reward_range`` (``[REWARD_MIN, REWARD_MAX]``), ``reward_precision``,
        ``task_presets`` (a shallow ``dict`` copy of every entry in
        ``TASK_PRESETS``, so mutating the result does not touch the module
        tables) and ``sub_environments`` (the ``.value`` of each
        ``SubEnvironment`` member, in enum order).
    """
    catalog: dict[str, Any] = {}
    catalog["reward_range"] = [REWARD_MIN, REWARD_MAX]
    catalog["reward_precision"] = REWARD_PRECISION
    catalog["task_presets"] = list(map(dict, TASK_PRESETS))
    catalog["sub_environments"] = [member.value for member in SubEnvironment]
    return catalog
def apply_task_preset(payload: dict[str, Any]) -> dict[str, Any]:
    """Expand an old task label into canonical difficulty/sub-environment fields.

    ``payload`` is modified in place and also returned. Its ``task_id`` key is
    removed. If that id was truthy, ``difficulty`` and ``sub_environment`` are
    filled from the matching entry in ``TASK_PRESET_BY_ID``, but only where the
    payload does not already hold a non-None value (explicit fields win).

    Args:
        payload: Request fields, possibly containing a legacy ``task_id``.

    Returns:
        The same ``payload`` object.

    Raises:
        ValueError: If ``task_id`` is truthy but not a known preset id. In that
            case ``payload`` is left unmodified (``task_id`` is still present).
    """
    task_id = payload.get("task_id")
    if not task_id:
        # Missing, None or empty: nothing to expand, just drop the key.
        payload.pop("task_id", None)
        return payload

    preset = TASK_PRESET_BY_ID.get(str(task_id))
    if preset is None:
        # Validate before mutating so a failed call leaves the caller's dict intact.
        raise ValueError(f"Unknown task_id {task_id!r}")

    del payload["task_id"]
    for field in ("difficulty", "sub_environment"):
        if payload.get(field) is None:
            payload[field] = preset[field]
    return payload