Vikaspandey582003 committed (verified)
Commit 26ba066 · 1 Parent(s): ee53375

fix: remove openenv.core dependency — pure FastAPI, graceful fallback

Files changed (1):
  1. env/openenv_env.py  +363 -0
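Note: the commit message mentions a pure-FastAPI server, but the server wiring itself is not part of this diff. As a purely hypothetical sketch of how the file's _gym_reset / _gym_step / info() helpers (all defined below) might be exposed over HTTP: the route names, request model, and module-level env here are illustrative assumptions, not code from this commit.

from fastapi import FastAPI
from pydantic import BaseModel

from env.openenv_env import make_echo_env

app = FastAPI()
env = make_echo_env(phase=1)  # hypothetical: one module-level environment

class StepRequest(BaseModel):
    response: str

@app.get("/info")
def env_info() -> dict:
    # Environment metadata (name, formats, reward range, task IDs).
    return env.info()

@app.post("/reset")
def reset() -> dict:
    obs_dict, info = env._gym_reset()
    return {"observation": obs_dict, "info": info}

@app.post("/step")
def step(req: StepRequest) -> dict:
    obs, reward, terminated, truncated, info = env._gym_step(req.response)
    return {
        "observation": obs,
        "reward": reward,
        "done": terminated or truncated,
        "info": info,
    }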
env/openenv_env.py ADDED
@@ -0,0 +1,363 @@
"""
ECHO ULTIMATE — OpenEnv-compliant environment.

EchoOpenEnv extends BOTH openenv.core.Environment AND gymnasium.Env (via EchoEnv),
satisfying the full OpenEnv protocol:

    reset(seed, episode_id, **kwargs) → EchoObservation
    step(action: EchoAction, ...) → EchoObservation
    state → EchoState (property)
    get_metadata() → EnvironmentMetadata

Plus OpenEnv task-listing helpers:

    info() → environment metadata dict
    list_tasks() → all TaskSpec dicts
    get_task(id) → single TaskSpec dict

Gymnasium-style callers (server, training) use the _gym_reset / _gym_step
helpers, which still return (obs_dict, info) / (obs, reward, done, …) tuples.
"""

from __future__ import annotations

from dataclasses import dataclass, asdict
from typing import Optional, List, Tuple

try:
    from openenv.core import Environment

    try:
        from openenv.core.env import EnvironmentMetadata
    except ImportError:
        EnvironmentMetadata = None
except ImportError:
    # Fallback: plain base class when openenv is not available.
    # __class_getitem__ lets the fallback still be subscripted as
    # Environment[EchoAction, EchoObservation, EchoState] below; without it,
    # the class definition would raise TypeError when openenv is missing.
    class Environment:
        def __init__(self, transform=None, rubric=None, **kwargs):
            pass

        def __class_getitem__(cls, item):
            return cls

    EnvironmentMetadata = None

from env.echo_env import EchoEnv
from env.task_bank import TaskBank
from env.reward import RewardHistory
from models import EchoAction, EchoObservation, EchoState
from core.tasks import TASKS
from config import cfg

# ── OpenEnv task spec ─────────────────────────────────────────────────────────

@dataclass
class TaskSpec:
    id: str
    name: str
    description: str
    pass_threshold: float
    metric: str
    n_episodes: int
    domains: List[str]
    difficulties: List[str]

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class EnvInfo:
    name: str
    version: str
    description: str
    observation_format: str
    action_format: str
    reward_range: Tuple[float, float]
    domains: List[str]
    tasks: List[str]

    def to_dict(self) -> dict:
        return asdict(self)

# ── Main environment ──────────────────────────────────────────────────────────

class EchoOpenEnv(Environment[EchoAction, EchoObservation, EchoState], EchoEnv):
    """
    ECHO ULTIMATE: OpenEnv-compliant RL environment for LLM calibration.

    Extends openenv.core.Environment (OpenEnv protocol) AND EchoEnv (gymnasium.Env).

    OpenEnv usage — stateless per-request:
        env = EchoOpenEnv()
        obs = env.reset()                            # EchoObservation
        obs = env.step(EchoAction(response="..."))   # EchoObservation
        s = env.state                                # EchoState

    Gymnasium usage — stateful episodes:
        obs_dict, info = env._gym_reset()
        obs_dict, r, done, _, info = env._gym_step(
            "<confidence>72</confidence><answer>Paris</answer>"
        )

    Training loop:
        env = EchoOpenEnv(phase=1)
        for _ in range(n_steps):
            obs_dict, info = env._gym_reset()
            prompt = info["formatted_prompt"]
            response = model.generate(prompt)
            _, reward, _, _, _ = env._gym_step(response)
    """

    # OpenEnv class attributes
    SUPPORTS_CONCURRENT_SESSIONS: bool = False
    OPENENV_PROTOCOL_VERSION: str = "1.0"
    N_TASKS: int = 3
    OBSERVATION_TYPE: str = "dict"
    ACTION_TYPE: str = "text"

    def __init__(
        self,
        task_id: Optional[str] = None,
        task_bank: Optional[TaskBank] = None,
        reward_history: Optional[RewardHistory] = None,
        phase: int = 1,
        self_consistency: bool = False,
        generate_fn=None,
        render_mode: Optional[str] = None,
    ) -> None:
        # Init gymnasium env (EchoEnv sets up task_bank, reward_history, spaces, etc.)
        EchoEnv.__init__(
            self,
            task_bank=task_bank,
            reward_history=reward_history,
            phase=phase,
            self_consistency=self_consistency,
            generate_fn=generate_fn,
            render_mode=render_mode,
        )
        # Init openenv.core.Environment (sets transform=None, rubric=None)
        Environment.__init__(self, transform=None, rubric=None)
        self._default_task_id = task_id

    # ── OpenEnv abstract method: reset ────────────────────────────────────────

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> EchoObservation:
        """
        OpenEnv reset — returns EchoObservation.

        Accepts kwargs: options={"task_id": "task_hard"} or task_id="task_easy".
        """
        options = kwargs.get("options")
        task_id = kwargs.get("task_id") or self._default_task_id
        if options is None and task_id:
            options = {"task_id": task_id}

        obs_dict, _ = EchoEnv.reset(self, seed=seed, options=options)
        return self._obs_from_dict(obs_dict, done=False)

    # ── OpenEnv abstract method: step ─────────────────────────────────────────

    def step(
        self,
        action: EchoAction | str,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> EchoObservation:
        """OpenEnv step — accepts EchoAction or raw string, returns EchoObservation."""
        response = action.response if isinstance(action, EchoAction) else str(action)
        obs_dict, reward, terminated, truncated, info = EchoEnv.step(self, response)
        return self._obs_from_step(obs_dict, reward, terminated or truncated, info)

    # ── OpenEnv abstract property: state ──────────────────────────────────────

    @property
    def state(self) -> EchoState:
        """OpenEnv state property — returns full EchoState snapshot."""
        task = self._current_task or {}
        profiles = self.reward_history.get_domain_profiles()
        return EchoState(
            current_question=task.get("question", ""),
            domain=task.get("domain", ""),
            difficulty=task.get("difficulty", ""),
            phase=self.phase,
            step_count=self._episode_step,
            total_reward=self._episode_reward,
            domain_stats={
                d: {"ece": round(p.ece, 3), "accuracy": round(p.accuracy, 3)}
                for d, p in profiles.items()
                if p.n_samples > 0
            },
        )

    # ── OpenEnv metadata ──────────────────────────────────────────────────────

    def get_metadata(self):
        """OpenEnv environment metadata."""
        if EnvironmentMetadata is not None:
            return EnvironmentMetadata(
                name="ECHO-ULTIMATE",
                version="2.0.0",
                description=(
                    "RL environment for LLM metacognitive calibration. "
                    "Trains models to accurately predict their own probability of "
                    "being correct across 7 domains via GRPO with Brier-score rewards."
                ),
            )
        return {
            "name": "ECHO-ULTIMATE",
            "version": "2.0.0",
            "description": "OpenEnv RL environment for LLM metacognitive calibration.",
        }

    # ── Gymnasium-compatible helpers (for server + training) ──────────────────

    def _gym_reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[dict, dict]:
        """Gymnasium-style reset returning (obs_dict, info) tuple."""
        if options is None and self._default_task_id:
            options = {"task_id": self._default_task_id}
        return EchoEnv.reset(self, seed=seed, options=options)

    def _gym_step(self, response: str) -> Tuple[dict, float, bool, bool, dict]:
        """Gymnasium-style step returning (obs, reward, terminated, truncated, info)."""
        return EchoEnv.step(self, response)

    # ── Task-listing helpers (OpenEnv task bank protocol) ─────────────────────

    def info(self) -> dict:
        """Return environment metadata dict."""
        return EnvInfo(
            name="ECHO-ULTIMATE",
            version="2.0.0",
            description=(
                "RL environment for LLM metacognitive calibration. "
                "Teaches models to accurately predict their own probability of being correct "
                "across 7 domains via GRPO with Brier-score calibration rewards."
            ),
            observation_format=(
                "EchoObservation: {question, domain, difficulty, reward, done, "
                "ece, accuracy, confidence, brier_score, is_correct, feedback}"
            ),
            action_format="EchoAction: {response='<confidence>N</confidence><answer>TEXT</answer>'}",
            reward_range=(cfg.REWARD_CLIP_LOW, cfg.REWARD_CLIP_HIGH),
            domains=cfg.DOMAINS,
            tasks=[t.id for t in TASKS],
        ).to_dict()

    @staticmethod
    def _task_spec(t) -> dict:
        """Build a TaskSpec dict from a core task definition."""
        return TaskSpec(
            id=t.id,
            name=t.name,
            description=t.description,
            pass_threshold=t.pass_threshold,
            metric=t.metric,
            n_episodes=t.n_episodes,
            domains=cfg.DOMAINS,
            difficulties=cfg.DIFFICULTIES,
        ).to_dict()

    def list_tasks(self) -> List[dict]:
        """Return all task specifications."""
        return [self._task_spec(t) for t in TASKS]

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return a single task spec by ID."""
        for t in TASKS:
            if t.id == task_id:
                return self._task_spec(t)
        return None

    # ── Evaluation helper ─────────────────────────────────────────────────────

    def evaluate(
        self,
        n_episodes: int = 30,
        task_id: Optional[str] = None,
    ) -> dict:
        """Run n_episodes and return OpenEnv-style evaluation results."""
        rewards = []
        for _ in range(n_episodes):
            self._gym_reset(options={"task_id": task_id} if task_id else None)
            # Fixed baseline response (50% confidence, answer "unknown"):
            # evaluate() exercises the scoring pipeline, not a live model.
            placeholder = "<confidence>50</confidence><answer>unknown</answer>"
            _, reward, _, _, _ = self._gym_step(placeholder)
            rewards.append(reward)

        metrics = self.get_metrics()
        task_spec = self.get_task(task_id) if task_id else None
        threshold = task_spec["pass_threshold"] if task_spec else 0.5
        # Composite score: calibration term (1 - ECE) scaled by accuracy,
        # with full accuracy credit at >= 55%.
        score = max(0.0, 1.0 - metrics.ece) * min(1.0, metrics.accuracy / 0.55)

        return {
            "n_episodes": n_episodes,
            "ece": round(metrics.ece, 4),
            "accuracy": round(metrics.accuracy, 4),
            "brier_score": round(metrics.brier, 4),
            "overconfidence_rate": round(metrics.overconfidence_rate, 4),
            "mean_reward": round(sum(rewards) / len(rewards), 4),
            "score": round(score, 4),
            "pass_threshold": threshold,
            "passed": score >= threshold,
        }

316
+
317
+ # ── Internal helpers ──────────────────────────────────────────────────────
318
+
319
+ def _obs_from_dict(self, obs_dict: dict, done: bool = False) -> EchoObservation:
320
+ """Convert _build_obs() dict → EchoObservation (after reset)."""
321
+ task = self._current_task or {}
322
+ return EchoObservation(
323
+ question=task.get("question", obs_dict.get("question", "")),
324
+ domain=obs_dict.get("domain", ""),
325
+ difficulty=obs_dict.get("difficulty", ""),
326
+ ece=float(obs_dict.get("running_ece", 0.0)),
327
+ accuracy=float(obs_dict.get("running_accuracy", 0.0)),
328
+ confidence=int(obs_dict.get("running_mean_confidence", 50)),
329
+ done=done,
330
+ )
331
+
332
+ def _obs_from_step(
333
+ self,
334
+ obs_dict: dict,
335
+ reward: float,
336
+ done: bool,
337
+ info: dict,
338
+ ) -> EchoObservation:
339
+ """Convert step() outputs → EchoObservation."""
340
+ return EchoObservation(
341
+ question=(self._current_task or {}).get("question", ""),
342
+ domain=info.get("domain", obs_dict.get("domain", "")),
343
+ difficulty=info.get("difficulty", obs_dict.get("difficulty", "")),
344
+ reward=float(reward),
345
+ done=done,
346
+ ece=float(obs_dict.get("running_ece", 0.0)),
347
+ accuracy=float(info.get("accuracy", 0.0)),
348
+ confidence=int(info.get("parsed_confidence", 50)),
349
+ brier_score=float(info.get("brier_reward", 0.0)),
350
+ is_correct=bool(info.get("was_correct", False)),
351
+ feedback=info.get("breakdown", ""),
352
+ )
353
+
354
+
# ── Convenience factory ───────────────────────────────────────────────────────

def make_echo_env(
    task_id: Optional[str] = None,
    phase: int = 1,
    **kwargs,
) -> EchoOpenEnv:
    """Factory function for creating an ECHO OpenEnv environment."""
    return EchoOpenEnv(task_id=task_id, phase=phase, **kwargs)
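A minimal end-to-end usage sketch, assuming the env, models, core, and config packages above are importable; the response string is illustrative.

from env.openenv_env import make_echo_env
from models import EchoAction

env = make_echo_env(phase=1)

# OpenEnv path: typed observations in, typed observations out.
obs = env.reset()
obs = env.step(EchoAction(response="<confidence>60</confidence><answer>Paris</answer>"))
print(obs.reward, obs.is_correct, env.state.phase)

# Task-bank protocol and baseline evaluation.
print([t["id"] for t in env.list_tasks()])
print(env.evaluate(n_episodes=5))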