Spaces:
Sleeping
Sleeping
File size: 14,358 Bytes
26ba066 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 | """
ECHO ULTIMATE — OpenEnv-compliant environment.
EchoOpenEnv extends BOTH openenv.core.Environment AND gymnasium.Env (via EchoEnv),
satisfying the full OpenEnv protocol:
    reset(seed, episode_id, **kwargs) → EchoObservation
    step(action: EchoAction, ...)     → EchoObservation
    state                             → EchoState (property)
    get_metadata()                    → EnvironmentMetadata (or a plain dict
                                        when openenv does not provide it)
Plus OpenEnv task-listing helpers:
    info()       → environment metadata dict
    list_tasks() → all TaskSpec dicts
    get_task(id) → single TaskSpec dict
Gymnasium-style callers (server, training) use the _gym_reset / _gym_step
helpers, which still return (obs_dict, info) and
(obs, reward, terminated, truncated, info) tuples respectively.
"""
from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, List, Tuple
try:
    from openenv.core import Environment

    try:
        from openenv.core.env import EnvironmentMetadata
    except ImportError:
        # This openenv build does not expose EnvironmentMetadata;
        # EchoOpenEnv.get_metadata() then returns a plain dict instead.
        EnvironmentMetadata = None
except ImportError:
    # Fallback: plain base class when openenv is not available.
    class Environment:
        """Minimal stand-in for ``openenv.core.Environment``.

        The constructor accepts (and ignores) the ``transform`` / ``rubric``
        hooks, and the class supports subscription (``Environment[A, O, S]``)
        so generic subclass declarations keep working without openenv.
        """

        def __init__(self, transform=None, rubric=None, **kwargs):
            pass

        def __class_getitem__(cls, params):
            # The real base class is generic. Without this hook,
            # ``Environment[EchoAction, EchoObservation, EchoState]`` raises
            # TypeError at class-definition time. Subscription is a no-op here.
            return cls

    EnvironmentMetadata = None
from env.echo_env import EchoEnv
from env.task_bank import TaskBank
from env.reward import RewardHistory
from models import EchoAction, EchoObservation, EchoState
from core.tasks import TASKS
from config import cfg
# ── OpenEnv task spec ─────────────────────────────────────────────────────────
@dataclass
class TaskSpec:
    """Serializable description of one OpenEnv evaluation task."""

    id: str                  # task identifier (looked up by EchoOpenEnv.get_task)
    name: str                # human-readable task name
    description: str         # free-text task description
    pass_threshold: float    # evaluate() reports passed when score >= this
    metric: str              # name of the metric attached to the task
    n_episodes: int          # episode count attached to the task
    domains: List[str]       # domains the task draws from
    difficulties: List[str]  # difficulty levels the task draws from

    def to_dict(self) -> dict:
        """Return a plain-dict copy of this spec (list fields are new lists)."""
        plain = asdict(self)
        return plain
@dataclass
class EnvInfo:
    """Serializable environment description returned by EchoOpenEnv.info()."""

    name: str                          # environment name
    version: str                       # environment version string
    description: str                   # free-text summary
    observation_format: str            # textual description of observations
    action_format: str                 # textual description of actions
    reward_range: Tuple[float, float]  # (low, high) reward bounds
    domains: List[str]                 # supported domains
    tasks: List[str]                   # ids of the available tasks

    def to_dict(self) -> dict:
        """Return a plain-dict copy of this description (containers are copied)."""
        plain = asdict(self)
        return plain
# ── Main environment ──────────────────────────────────────────────────────────
class EchoOpenEnv(Environment[EchoAction, EchoObservation, EchoState], EchoEnv):
    """
    ECHO ULTIMATE: OpenEnv-compliant RL environment for LLM calibration.

    Extends openenv.core.Environment (OpenEnv protocol) AND EchoEnv (gymnasium.Env).

    OpenEnv usage — stateless per-request:
        env = EchoOpenEnv()
        obs = env.reset()                             # EchoObservation
        obs = env.step(EchoAction(response="..."))    # EchoObservation
        s = env.state                                 # EchoState

    Gymnasium usage — stateful episodes:
        obs_dict, info = env._gym_reset()
        obs_dict, r, done, _, info = env._gym_step(
            "<confidence>72</confidence><answer>Paris</answer>"
        )

    Training loop:
        env = EchoOpenEnv(phase=1)
        for _ in range(n_steps):
            obs_dict, info = env._gym_reset()
            prompt = info["formatted_prompt"]
            response = model.generate(prompt)
            _, reward, _, _, _ = env._gym_step(response)
    """

    # OpenEnv class attributes
    SUPPORTS_CONCURRENT_SESSIONS: bool = False
    OPENENV_PROTOCOL_VERSION: str = "1.0"
    N_TASKS: int = 3
    OBSERVATION_TYPE: str = "dict"
    ACTION_TYPE: str = "text"

    # Identity strings shared by get_metadata() and info(), kept in one place so
    # the two metadata views cannot drift apart.
    _ENV_NAME: str = "ECHO-ULTIMATE"
    _ENV_VERSION: str = "2.0.0"
    _METADATA_DESCRIPTION: str = (
        "RL environment for LLM metacognitive calibration. "
        "Trains models to accurately predict their own probability of "
        "being correct across 7 domains via GRPO with Brier-score rewards."
    )

    # evaluate() answers every question with this fixed 50%-confidence response.
    _EVAL_RESPONSE: str = "<confidence>50</confidence><answer>unknown</answer>"
    # Accuracy at/above which evaluate()'s score is no longer scaled down.
    _EVAL_ACCURACY_TARGET: float = 0.55
    # Pass threshold used by evaluate() when no task, or an unknown task, is given.
    _EVAL_DEFAULT_THRESHOLD: float = 0.5

    def __init__(
        self,
        task_id: Optional[str] = None,
        task_bank: Optional[TaskBank] = None,
        reward_history: Optional[RewardHistory] = None,
        phase: int = 1,
        self_consistency: bool = False,
        generate_fn=None,
        render_mode: Optional[str] = None,
    ) -> None:
        """Create the environment.

        Args:
            task_id: Default task used by reset()/_gym_reset() when the caller
                does not choose one.
            task_bank, reward_history, phase, self_consistency, generate_fn,
            render_mode: Forwarded unchanged to EchoEnv.__init__.
        """
        # Init gymnasium env (EchoEnv sets up task_bank, reward_history, spaces, etc.)
        EchoEnv.__init__(
            self,
            task_bank=task_bank,
            reward_history=reward_history,
            phase=phase,
            self_consistency=self_consistency,
            generate_fn=generate_fn,
            render_mode=render_mode,
        )
        # Init openenv.core.Environment (sets transform=None, rubric=None)
        Environment.__init__(self, transform=None, rubric=None)
        self._default_task_id = task_id

    # ── OpenEnv abstract method: reset ────────────────────────────────────────
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> EchoObservation:
        """OpenEnv reset: start a new episode and return its first observation.

        Args:
            seed: Forwarded to EchoEnv.reset.
            episode_id: Accepted for OpenEnv protocol compatibility; not used here.
            **kwargs: ``options={"task_id": "task_hard"}`` is forwarded as-is.
                Otherwise ``task_id="task_easy"`` (or the constructor default)
                is wrapped into ``options``. An explicit ``options`` takes
                precedence over ``task_id``.
        """
        options = kwargs.get("options")
        task_id = kwargs.get("task_id") or self._default_task_id
        if options is None and task_id:
            options = {"task_id": task_id}
        obs_dict, _ = EchoEnv.reset(self, seed=seed, options=options)
        return self._obs_from_dict(obs_dict, done=False)

    # ── OpenEnv abstract method: step ─────────────────────────────────────────
    def step(
        self,
        action: EchoAction | str,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> EchoObservation:
        """OpenEnv step: score one response and return the resulting observation.

        Args:
            action: An EchoAction, or any value whose ``str()`` is the raw
                response (``<confidence>N</confidence><answer>TEXT</answer>``).
            timeout_s: Accepted for OpenEnv protocol compatibility; not used here.
        """
        response = action.response if isinstance(action, EchoAction) else str(action)
        obs_dict, reward, terminated, truncated, info = EchoEnv.step(self, response)
        return self._obs_from_step(obs_dict, reward, terminated or truncated, info)

    # ── OpenEnv abstract property: state ──────────────────────────────────────
    @property
    def state(self) -> EchoState:
        """OpenEnv state: snapshot of the current episode plus per-domain stats."""
        task = self._current_task or {}
        profiles = self.reward_history.get_domain_profiles()
        return EchoState(
            current_question=task.get("question", ""),
            domain=task.get("domain", ""),
            difficulty=task.get("difficulty", ""),
            phase=self.phase,
            step_count=self._episode_step,
            total_reward=self._episode_reward,
            # Only domains that have at least one recorded sample are reported.
            domain_stats={
                d: {"ece": round(p.ece, 3), "accuracy": round(p.accuracy, 3)}
                for d, p in profiles.items()
                if p.n_samples > 0
            },
        )

    # ── OpenEnv metadata ──────────────────────────────────────────────────────
    def get_metadata(self):
        """Return OpenEnv environment metadata.

        Returns an ``EnvironmentMetadata`` when openenv provides that class,
        otherwise a plain dict with the same name/version/description values.
        """
        meta = {
            "name": self._ENV_NAME,
            "version": self._ENV_VERSION,
            "description": self._METADATA_DESCRIPTION,
        }
        if EnvironmentMetadata is not None:
            return EnvironmentMetadata(**meta)
        return meta

    # ── Gymnasium-compatible helpers (for server + training) ──────────────────
    def _gym_reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[dict, dict]:
        """Gymnasium-style reset returning an (obs_dict, info) tuple.

        When ``options`` is None, the constructor's default task_id (if any) is used.
        """
        if options is None and self._default_task_id:
            options = {"task_id": self._default_task_id}
        return EchoEnv.reset(self, seed=seed, options=options)

    def _gym_step(self, response: str) -> Tuple[dict, float, bool, bool, dict]:
        """Gymnasium-style step returning (obs, reward, terminated, truncated, info)."""
        return EchoEnv.step(self, response)

    # ── Task-listing helpers (OpenEnv task bank protocol) ─────────────────────
    def info(self) -> dict:
        """Return the environment description as a plain dict (see EnvInfo)."""
        return EnvInfo(
            name=self._ENV_NAME,
            version=self._ENV_VERSION,
            description=(
                "RL environment for LLM metacognitive calibration. "
                "Teaches models to accurately predict their own probability of being correct "
                "across 7 domains via GRPO with Brier-score calibration rewards."
            ),
            observation_format=(
                "EchoObservation: {question, domain, difficulty, reward, done, "
                "ece, accuracy, confidence, brier_score, is_correct, feedback}"
            ),
            action_format="EchoAction: {response='<confidence>N</confidence><answer>TEXT</answer>'}",
            reward_range=(cfg.REWARD_CLIP_LOW, cfg.REWARD_CLIP_HIGH),
            domains=cfg.DOMAINS,
            tasks=[t.id for t in TASKS],
        ).to_dict()

    @staticmethod
    def _task_spec_dict(task) -> dict:
        """Build the TaskSpec dict for one entry of TASKS (shared by list/get)."""
        return TaskSpec(
            id=task.id,
            name=task.name,
            description=task.description,
            pass_threshold=task.pass_threshold,
            metric=task.metric,
            n_episodes=task.n_episodes,
            domains=cfg.DOMAINS,
            difficulties=cfg.DIFFICULTIES,
        ).to_dict()

    def list_tasks(self) -> List[dict]:
        """Return all task specifications, in TASKS order."""
        return [self._task_spec_dict(t) for t in TASKS]

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task spec whose id equals ``task_id``, or None if unknown."""
        for t in TASKS:
            if t.id == task_id:
                return self._task_spec_dict(t)
        return None

    # ── Evaluation helper ─────────────────────────────────────────────────────
    def evaluate(
        self,
        n_episodes: int = 30,
        task_id: Optional[str] = None,
    ) -> dict:
        """Run a fixed-response baseline and return OpenEnv-style evaluation results.

        Every episode is answered with the same 50%-confidence "unknown"
        response, so this exercises the environment/reward pipeline rather
        than any model.

        Args:
            n_episodes: Number of reset + step rounds to run; must be >= 1.
            task_id: Task to sample from. None falls back to the constructor
                default (via _gym_reset).

        Returns:
            Dict with n_episodes, ece, accuracy, brier_score,
            overconfidence_rate, mean_reward, score, pass_threshold and passed.

        Raises:
            ValueError: If ``n_episodes`` is less than 1.
        """
        if n_episodes < 1:
            # Without this guard the mean below divides by zero.
            raise ValueError(f"n_episodes must be >= 1, got {n_episodes!r}")

        rewards: List[float] = []
        for _ in range(n_episodes):
            self._gym_reset(options={"task_id": task_id} if task_id else None)
            _, reward, _, _, _ = self._gym_step(self._EVAL_RESPONSE)
            rewards.append(reward)

        # NOTE(review): get_metrics() is provided by EchoEnv; confirm whether it
        # covers only these episodes or the whole accumulated reward history.
        metrics = self.get_metrics()
        task_spec = self.get_task(task_id) if task_id else None
        threshold = (
            task_spec["pass_threshold"] if task_spec else self._EVAL_DEFAULT_THRESHOLD
        )
        # Calibration term (1 - ECE, floored at 0) times an accuracy term that
        # saturates at 1.0 once accuracy reaches _EVAL_ACCURACY_TARGET.
        score = max(0.0, 1.0 - metrics.ece) * min(
            1.0, metrics.accuracy / self._EVAL_ACCURACY_TARGET
        )
        return {
            "n_episodes": n_episodes,
            "ece": round(metrics.ece, 4),
            "accuracy": round(metrics.accuracy, 4),
            "brier_score": round(metrics.brier, 4),
            "overconfidence_rate": round(metrics.overconfidence_rate, 4),
            "mean_reward": round(sum(rewards) / len(rewards), 4),
            "score": round(score, 4),
            "pass_threshold": threshold,
            "passed": score >= threshold,
        }

    # ── Internal helpers ──────────────────────────────────────────────────────
    def _obs_from_dict(self, obs_dict: dict, done: bool = False) -> EchoObservation:
        """Convert a reset() observation dict into an EchoObservation.

        The question text prefers the current task over the observation dict.
        Running ECE/accuracy/mean confidence come from the observation dict.
        """
        task = self._current_task or {}
        return EchoObservation(
            question=task.get("question", obs_dict.get("question", "")),
            domain=obs_dict.get("domain", ""),
            difficulty=obs_dict.get("difficulty", ""),
            ece=float(obs_dict.get("running_ece", 0.0)),
            accuracy=float(obs_dict.get("running_accuracy", 0.0)),
            confidence=int(obs_dict.get("running_mean_confidence", 50)),
            done=done,
        )

    def _obs_from_step(
        self,
        obs_dict: dict,
        reward: float,
        done: bool,
        info: dict,
    ) -> EchoObservation:
        """Convert step() outputs into an EchoObservation.

        Per-step fields (confidence, correctness, Brier reward, feedback) come
        from ``info``. ECE comes from the running value in ``obs_dict``.
        """
        return EchoObservation(
            question=(self._current_task or {}).get("question", ""),
            domain=info.get("domain", obs_dict.get("domain", "")),
            difficulty=info.get("difficulty", obs_dict.get("difficulty", "")),
            reward=float(reward),
            done=done,
            ece=float(obs_dict.get("running_ece", 0.0)),
            # NOTE(review): unlike _obs_from_dict, accuracy is read from info
            # (default 0.0), not from obs_dict["running_accuracy"]; confirm intent.
            accuracy=float(info.get("accuracy", 0.0)),
            confidence=int(info.get("parsed_confidence", 50)),
            brier_score=float(info.get("brier_reward", 0.0)),
            is_correct=bool(info.get("was_correct", False)),
            feedback=info.get("breakdown", ""),
        )
# ── Convenience factory ───────────────────────────────────────────────────────
def make_echo_env(
    task_id: Optional[str] = None,
    phase: int = 1,
    **kwargs,
) -> EchoOpenEnv:
    """Build an EchoOpenEnv from keyword arguments.

    Args:
        task_id: Default task for reset(); None means no default task.
        phase: Forwarded to the EchoOpenEnv constructor.
        **kwargs: Any other EchoOpenEnv constructor arguments, passed through.
    """
    ctor_args = dict(kwargs, task_id=task_id, phase=phase)
    return EchoOpenEnv(**ctor_args)
|