| """Adapter catalog for old task labels over the canonical PolyGuard env.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from app.common.constants import REWARD_MAX, REWARD_MIN, REWARD_PRECISION |
| from app.common.enums import Difficulty, SubEnvironment |
|
|
# Legacy task labels, each pinned to a canonical difficulty and sub-environment.
# Rows are (id, label, difficulty, sub_environment); enum members are stored as
# their ``.value`` so the presets stay plain string dicts.
TASK_PRESETS: tuple[dict[str, str], ...] = tuple(
    {
        "id": preset_id,
        "label": preset_label,
        "difficulty": level.value,
        "sub_environment": sub_env.value,
    }
    for preset_id, preset_label, level, sub_env in (
        ("easy_screening", "Easy Screening", Difficulty.EASY, SubEnvironment.DDI),
        (
            "budgeted_screening",
            "Budgeted Screening",
            Difficulty.MEDIUM,
            SubEnvironment.REGIMEN_RISK,
        ),
        (
            "complex_tradeoff",
            "Complex Tradeoff",
            Difficulty.HARD,
            SubEnvironment.REGIMEN_RISK,
        ),
        (
            "bandit_mining",
            "Bandit Mining",
            Difficulty.HARD,
            SubEnvironment.BANDIT_MINING,
        ),
    )
)


# Index of the presets above keyed by their "id". Values are the same dict
# objects held in TASK_PRESETS (no copies are made).
TASK_PRESET_BY_ID = {preset["id"]: preset for preset in TASK_PRESETS}
|
|
|
|
def env_catalog() -> dict[str, Any]:
    """Describe the environment's reward bounds, task presets and sub-environments.

    Returns a freshly built dict on every call. Each preset entry is a shallow
    ``dict`` copy of the corresponding ``TASK_PRESETS`` item, and the
    sub-environment list holds the ``.value`` of every ``SubEnvironment`` member.
    """
    preset_copies = list(map(dict, TASK_PRESETS))
    sub_env_values = [member.value for member in SubEnvironment]
    return {
        "reward_range": [REWARD_MIN, REWARD_MAX],
        "reward_precision": REWARD_PRECISION,
        "task_presets": preset_copies,
        "sub_environments": sub_env_values,
    }
|
|
|
|
def apply_task_preset(
    payload: dict[str, Any],
    *,
    presets: dict[str, dict[str, str]] | None = None,
) -> dict[str, Any]:
    """Expand an old task label into canonical difficulty/sub-environment fields.

    ``payload`` is updated in place and also returned. Its ``task_id`` key is
    removed. ``difficulty`` and ``sub_environment`` are copied from the matching
    preset only when they are absent or ``None``, so values the caller set
    explicitly take precedence over the preset.

    Args:
        payload: Request fields, possibly carrying a legacy ``task_id``.
        presets: Optional preset table keyed by task id (each entry must have
            ``"difficulty"`` and ``"sub_environment"`` keys). Defaults to
            ``TASK_PRESET_BY_ID``.

    Returns:
        The same ``payload`` object.

    Raises:
        ValueError: If ``task_id`` is truthy but matches no preset. In that
            case ``payload`` is left untouched, including its ``task_id``.
    """
    task_id = payload.get("task_id")
    if not task_id:
        # Falsy labels (missing, None, "") mean "no preset", but the key is
        # still dropped so it never reaches the canonical env.
        payload.pop("task_id", None)
        return payload

    table = TASK_PRESET_BY_ID if presets is None else presets
    preset = table.get(str(task_id))
    if preset is None:
        # Validate before mutating: a rejected payload keeps its original task_id.
        raise ValueError(f"Unknown task_id {task_id!r}")

    del payload["task_id"]
    for field in ("difficulty", "sub_environment"):
        if payload.get(field) is None:
            payload[field] = preset[field]
    return payload
|
|