Spaces:
Sleeping
Sleeping
| """ | |
| ECHO ULTIMATE — Main Gymnasium Environment. | |
| Each episode = 1 question → 1 answer → 1 reward. | |
| State includes running calibration metrics across all 7 domains. | |
| """ | |
| import logging | |
| from typing import Any, Callable, Optional | |
| import gymnasium as gym | |
| import numpy as np | |
| from gymnasium import spaces | |
| from config import cfg | |
| from env.parser import parse_response, format_prompt, ParseResult | |
| from env.reward import compute_reward, RewardHistory, RewardBreakdown | |
| from env.task_bank import TaskBank | |
# Module-level logger, named after this module so handlers can filter on it.
logger = logging.getLogger(__name__)

# Maps each domain name in cfg.DOMAINS to its position in that sequence.
_DOMAIN_INDEX = {name: position for position, name in enumerate(cfg.DOMAINS)}
class EchoEnv(gym.Env):
    """
    ECHO ULTIMATE Gymnasium environment.

    Observation: dict with task info + running calibration metrics.
    Action: text string in <confidence>N</confidence><answer>X</answer> format.
    Reward: weighted accuracy + Brier calibration + overconfidence penalties.

    Each episode terminates after exactly one step.

    Render modes:
        "human" -- ``step()`` prints an episode summary box after each step.
        "ansi"  -- ``render()`` returns the summary as a string instead of
                   printing it.
    """

    metadata = {"render_modes": ["human", "ansi"]}

    # Geometry of the summary box drawn by render(): number of columns between
    # the two vertical borders, and the usable text width once one column of
    # padding on each side is subtracted.
    _BOX_INNER_WIDTH = 37
    _BOX_TEXT_WIDTH = _BOX_INNER_WIDTH - 2

    def __init__(
        self,
        task_bank: Optional[TaskBank] = None,
        reward_history: Optional[RewardHistory] = None,
        phase: int = 1,
        self_consistency: bool = False,
        generate_fn: Optional[Callable[[str], str]] = None,
        render_mode: Optional[str] = None,
    ) -> None:
        """
        Create the environment.

        Args:
            task_bank: Source of tasks. A fresh ``TaskBank()`` is created when
                omitted (or falsy). ``ensure_loaded()`` is always called on it.
            reward_history: Accumulator for per-step results. A fresh
                ``RewardHistory()`` is created when omitted (or falsy).
            phase: Phase passed to ``task_bank.get_batch``; the
                difficulty is shown in prompts only when it equals 1.
            self_consistency: When True and ``generate_fn`` is set, ``step()``
                replaces the supplied action with one derived from a
                self-consistency check.
            generate_fn: Callable mapping a prompt string to a model response;
                only used by the self-consistency check.
            render_mode: ``"human"``, ``"ansi"`` or ``None``.
        """
        super().__init__()

        self.task_bank = task_bank or TaskBank()
        self.task_bank.ensure_loaded()
        self.reward_history = reward_history or RewardHistory()
        self.phase = phase
        self.self_consistency = self_consistency
        self.generate_fn = generate_fn
        self.render_mode = render_mode

        # Per-episode state; populated by reset() / step().
        self._current_task: Optional[dict] = None
        self._last_result: Optional[RewardBreakdown] = None
        self._last_parsed: Optional[ParseResult] = None
        self._episode_step: int = 0
        self._episode_reward: float = 0.0

        # Gymnasium spaces (informational for text-based env)
        self.action_space = spaces.Text(min_length=1, max_length=1024)
        self.observation_space = spaces.Dict({
            "task_id": spaces.Text(min_length=1, max_length=128),
            "domain": spaces.Text(min_length=1, max_length=32),
            "difficulty": spaces.Text(min_length=1, max_length=16),
            "question": spaces.Text(min_length=1, max_length=4096),
            "phase": spaces.Discrete(4),
            "episode_step": spaces.Discrete(3),
            "running_ece": spaces.Box(0, 1, shape=(1,), dtype=np.float32),
            "running_accuracy": spaces.Box(0, 1, shape=(1,), dtype=np.float32),
            "running_mean_confidence": spaces.Box(0, 100, shape=(1,), dtype=np.float32),
            "domain_ece": spaces.Box(0, 1, shape=(len(cfg.DOMAINS),), dtype=np.float32),
        })

    # ── Gymnasium API ─────────────────────────────────────────────────────────

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[dict, dict]:
        """
        Start a new single-step episode by selecting a task.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Optional dict. Recognised keys:
                ``"task_id"`` -- request a specific task; if the bank does not
                return one, a task for the current phase is drawn instead.
                ``"adversarial"`` -- when truthy (and no ``task_id`` is given),
                draw from ``task_bank.get_adversarial_batch``.

        Returns:
            ``(observation, info)`` where ``info`` holds the raw ``"task"``
            dict and the ``"formatted_prompt"`` for it.
        """
        super().reset(seed=seed)
        options = options or {}

        task_id = options.get("task_id")
        if task_id:
            task = self.task_bank.get_task_by_id(task_id)
            if not task:
                # Keep the original fallback, but make it visible in the logs.
                logger.warning(
                    "reset(): task_id %r not found; using a phase-%s task instead",
                    task_id, self.phase,
                )
                task = self.task_bank.get_batch(1, self.phase)[0]
        elif options.get("adversarial"):
            task = self.task_bank.get_adversarial_batch(1)[0]
        else:
            task = self.task_bank.get_batch(1, self.phase)[0]

        self._current_task = task
        self._episode_step = 0
        self._episode_reward = 0.0
        self._last_result = None
        self._last_parsed = None

        prompt = format_prompt(
            task["question"], task["domain"], task["difficulty"],
            show_difficulty=(self.phase == 1),
        )
        obs = self._build_obs()
        info = {"task": task, "formatted_prompt": prompt}
        return obs, info

    def step(self, action: str) -> tuple[dict, float, bool, bool, dict]:
        """
        Score one response against the current task and end the episode.

        If no episode is active, ``reset()`` is called first (with a warning).
        When self-consistency is enabled and ``generate_fn`` is set, the
        supplied ``action`` is replaced by one built from the checker's final
        confidence and answer.

        Args:
            action: Model response text to be parsed by ``parse_response``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is always True and ``truncated`` always False.
        """
        if self._current_task is None:
            logger.warning("step() called before reset() — auto-resetting")
            self.reset()
        task = self._current_task

        # Self-consistency check (demo mode only)
        if self.self_consistency and self.generate_fn is not None:
            # Imported lazily, as in the original, so the module is only
            # needed when this mode is actually used.
            from env.self_consistency import SelfConsistencyChecker

            checker = SelfConsistencyChecker()
            # Use the same prompt (including the phase-dependent difficulty
            # visibility) that reset() reported for this task.
            prompt = self.get_formatted_prompt()
            result = checker.check(prompt, self.generate_fn)
            # Override confidence from consistency check
            action = cfg.CONFIDENCE_FORMAT.format(
                conf=result.final_confidence, ans=result.final_answer
            )

        parsed = parse_response(action)
        rb = compute_reward(
            confidence=parsed.confidence,
            predicted=parsed.answer,
            ground_truth=task["answer"],
            aliases=task.get("answer_aliases", []),
            domain=task["domain"],
        )
        self.reward_history.append(
            confidence=parsed.confidence,
            was_correct=rb.was_correct,
            domain=task["domain"],
            difficulty=task["difficulty"],
            reward=rb.total,
            is_abstention=parsed.is_abstention,
        )

        self._last_result = rb
        self._last_parsed = parsed
        self._episode_step = 1
        self._episode_reward = rb.total

        obs = self._build_obs()
        info = {
            "accuracy": rb.accuracy_score,
            "brier_reward": rb.brier_reward_val,
            "overconfidence_penalty": rb.overconfidence_penalty_val,
            "underconfidence_penalty": rb.underconfidence_penalty_val,
            "parsed_confidence": parsed.confidence,
            "parsed_answer": parsed.answer,
            "true_answer": task["answer"],
            "was_correct": rb.was_correct,
            "parse_success": parsed.parse_success,
            "is_abstention": parsed.is_abstention,
            "task_id": task["id"],
            "domain": task["domain"],
            "difficulty": task["difficulty"],
            "breakdown": rb.breakdown_str,
        }

        if self.render_mode == "human":
            self.render()

        return obs, rb.total, True, False, info  # terminated=True (single step)

    def render(self) -> Optional[str]:
        """
        Show a summary of the current episode.

        Returns:
            The summary text when ``render_mode == "ansi"``; otherwise the
            text is printed to stdout and ``None`` is returned.
        """
        if self._current_task is None:
            text = "[EchoEnv] No active episode."
        else:
            text = self._format_summary()

        if self.render_mode == "ansi":
            return text
        print(text)
        return None

    @classmethod
    def _box_row(cls, text: str) -> str:
        """Return one box row: *text* padded/truncated to the box width, with borders."""
        width = cls._BOX_TEXT_WIDTH
        return f"│ {text:<{width}.{width}} │"

    def _format_summary(self) -> str:
        """Build the boxed episode summary for the active task as one string."""
        task = self._current_task
        rb = self._last_result
        parsed = self._last_parsed
        snap = self.reward_history.get_training_snapshot(last_n=100)

        icon = "✅" if (rb and rb.was_correct) else "❌"
        conf = parsed.confidence if parsed else "—"
        if not parsed:
            ans = "—"
        elif parsed.answer is None:
            # Guard against a parser result without an answer.
            ans = ""
        else:
            ans = str(parsed.answer)[:40]
        rew = f"{rb.total:+.3f}" if rb else "—"
        ece = f"{snap['ece']:.3f}"

        # The question row uses a 5-column label plus one space; shorten the
        # question (with an ellipsis) only when it does not fit the remainder.
        question = task["question"]
        budget = self._BOX_TEXT_WIDTH - 6
        if len(question) > budget:
            question = question[:budget - 1] + "…"

        rule = "─" * self._BOX_INNER_WIDTH
        lines = [
            f"┌{rule}┐",
            self._box_row("ECHO Episode Summary"),
            f"├{rule}┤",
            self._box_row(f"{'Domain:':<12} {task['domain']} ({task['difficulty']})"),
            self._box_row(f"{'Q:':<5} {question}"),
            self._box_row(f"{'Confidence:':<12} {conf}%"),
            self._box_row(f"{'Answer:':<12} {ans}"),
            self._box_row(f"{'Correct:':<12} {icon}"),
            self._box_row(f"{'Reward:':<12} {rew}"),
            self._box_row(f"{'ECE (100ep):':<12} {ece}"),
            f"└{rule}┘",
        ]
        # Leading newline separates the box from any preceding output.
        return "\n" + "\n".join(lines)

    # ── Metrics helpers ───────────────────────────────────────────────────────

    def get_metrics(self, domain: Optional[str] = None):
        """Return ``reward_history.get_calibration_report`` for *domain* (all domains if None)."""
        return self.reward_history.get_calibration_report(domain=domain)

    def set_phase(self, phase: int) -> None:
        """Set the curriculum phase, clamped to the range 1..3."""
        self.phase = max(1, min(3, phase))

    def get_formatted_prompt(self) -> str:
        """Return the prompt for the active task, or ``""`` when no episode is active."""
        if self._current_task is None:
            return ""
        t = self._current_task
        return format_prompt(t["question"], t["domain"], t["difficulty"],
                             show_difficulty=(self.phase == 1))

    # ── Internal ──────────────────────────────────────────────────────────────

    def _build_obs(self) -> dict:
        """
        Assemble the observation dict for the current state.

        Task fields default to ``""`` when no task is active. Scalar metrics
        are plain Python floats and ``domain_ece`` is a plain list of floats
        (one per entry of ``cfg.DOMAINS``, rounded through float32); domains
        without a profile or without samples report 0.5.
        """
        task = self._current_task or {}
        snap = self.reward_history.get_training_snapshot(last_n=100)
        profiles = self.reward_history.get_domain_profiles()

        per_domain = []
        for domain in cfg.DOMAINS:
            profile = profiles.get(domain)
            if profile and profile.n_samples > 0:
                per_domain.append(profile.ece)
            else:
                per_domain.append(0.5)
        # Round through float32 (as before) and hand back plain Python floats.
        domain_ece = np.asarray(per_domain, dtype=np.float32).tolist()

        return {
            "task_id": task.get("id", ""),
            "domain": task.get("domain", ""),
            "difficulty": task.get("difficulty", ""),
            "question": task.get("question", ""),
            "phase": self.phase,
            "episode_step": self._episode_step,
            "running_ece": float(snap["ece"]),
            "running_accuracy": float(snap["accuracy"]),
            "running_mean_confidence": float(snap["mean_confidence"]),
            "domain_ece": domain_ece,
        }