Spaces:
Sleeping
Sleeping
| """Subtext Arena environment. | |
| Episode flow: | |
| 1. reset() picks a random MUStARD clip (Pivot Set oversampled 3x). | |
| Returns: clip_id + speaker + duration, no transcript yet — the agent | |
| must call get_transcript and/or audio tools to investigate. | |
| 2. step(SubtextArenaAction) executes one tool call: | |
| - get_transcript -> literal text + conversational context | |
| - get_prosody_features -> pitch, energy, pause text summary | |
| - get_pitch_contour -> ASCII contour | |
| - submit_belief -> terminates episode with label + confidence | |
| Reward = per-step delta (small + for tool use, penalties for malformed | |
| actions) + final composite reward when submit_belief fires. | |
| 3. After max_steps (default 6) without a submission, the episode is force- | |
| terminated with the no_submission penalty. | |
| The trained policy is a TEXT LLM (Path A). Audio is processed by the env's | |
| frozen prosody-feature pipeline; the agent only ever sees text. Audio is | |
| load-bearing because the Pivot Set explicitly contains clips where the literal | |
| transcript alone leads to the wrong answer — the agent must consult prosody | |
| to score on those. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import random | |
| from typing import Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import SubtextArenaAction, SubtextArenaObservation | |
| except ImportError: | |
| from models import SubtextArenaAction, SubtextArenaObservation # type: ignore[no-redef] | |
| try: | |
| from .scenarios import load_scenarios, sample_clip | |
| from .audio_tools import ( | |
| render_transcript, | |
| render_prosody_features, | |
| render_pitch_contour, | |
| ) | |
| from .grader import step_reward, final_reward | |
| except ImportError: | |
| from server.scenarios import load_scenarios, sample_clip # type: ignore[no-redef] | |
| from server.audio_tools import ( # type: ignore[no-redef] | |
| render_transcript, | |
| render_prosody_features, | |
| render_pitch_contour, | |
| ) | |
| from server.grader import step_reward, final_reward # type: ignore[no-redef] | |
# Tools that consult the env's frozen audio/prosody pipeline.
AUDIO_TOOLS = {"get_prosody_features", "get_pitch_contour"}

# Full action space accepted by step(): the two audio tools plus the
# transcript lookup and the terminal belief submission.
VALID_TOOLS = AUDIO_TOOLS | {"get_transcript", "submit_belief"}
class SubtextArenaEnvironment(Environment):
    """OpenEnv environment for sarcasm-vs-sincere classification on MUStARD.

    One episode == one MUStARD clip. The agent investigates via text-returning
    tools (transcript, prosody summary, pitch contour) and ends the episode by
    calling ``submit_belief`` with a label and confidence. If it never submits,
    the episode is force-terminated after ``max_steps`` tool calls with the
    no-submission penalty.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, max_steps: int = 6, seed: Optional[int] = None):
        """Build the environment.

        Args:
            max_steps: Tool-call budget per episode before forced termination.
            seed: Optional RNG seed for reproducible clip sampling. When None,
                a fresh 4-byte ``os.urandom`` seed is drawn so concurrent
                instances do not share a sampling stream.
        """
        self._scenarios = load_scenarios()
        self._max_steps = max_steps
        self._rng = random.Random(seed if seed is not None else os.urandom(4))
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_clip_id: Optional[str] = None
        self._n_audio_calls = 0  # prosody/pitch tool calls this episode
        self._n_total_calls = 0  # all tool calls this episode
        self._terminated = False
        # When set (via force_next_reset), reset() picks this clip instead
        # of sampling. Used by eval_pivot_set.py to walk specific clips.
        self._force_next_clip_id: Optional[str] = None

    def force_next_reset(self, clip_id: str) -> None:
        """Force the next reset() to pick the given clip ID.

        Called by eval scripts that need to evaluate on specific clips
        (e.g. all 50 Prosody-Pivot clips) rather than random sampling.
        Auto-clears after one reset.

        Raises:
            ValueError: If ``clip_id`` is not a known MUStARD scenario.
        """
        if clip_id not in self._scenarios:
            raise ValueError(
                f"Unknown clip_id {clip_id!r}; not in MUStARD scenarios."
            )
        self._force_next_clip_id = clip_id

    # ------------------------------------------------------------------
    # Reset
    # ------------------------------------------------------------------
    def _pick_clip_id(self) -> str:
        """Resolve which clip the next episode uses.

        Priority: one-shot force_next_reset() value > FORCE_CLIP_ID env var
        (works through HTTP deployments too) > random sampling.
        """
        if self._force_next_clip_id is not None:
            clip_id = self._force_next_clip_id
            self._force_next_clip_id = None  # one-shot: auto-clear
            return clip_id
        forced = os.environ.get("FORCE_CLIP_ID", "").strip()
        if forced and forced in self._scenarios:
            return forced
        return sample_clip(self._scenarios, self._rng)

    def reset(self) -> SubtextArenaObservation:
        """Start a new episode on a forced or freshly sampled clip."""
        clip_id = self._pick_clip_id()
        clip = self._scenarios[clip_id]
        prosody = clip.get("prosody") or {}
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_clip_id = clip_id
        self._n_audio_calls = 0
        self._n_total_calls = 0
        self._terminated = False
        return SubtextArenaObservation(
            clip_id=clip_id,
            speaker=clip.get("speaker", ""),
            duration_s=float(prosody.get("duration_s", 0.0)),
            is_pivot=bool(clip.get("is_pivot", False)),
            tool_used="reset",
            tool_output=(
                f"Episode started. Clip {clip_id}, speaker {clip.get('speaker', '?')}, "
                f"duration {prosody.get('duration_s', 0.0):.2f}s. "
                f"You have {self._max_steps} tool calls before forced submission. "
                f"Available tools: get_transcript, get_prosody_features, get_pitch_contour, submit_belief."
            ),
            step=0,
            max_steps=self._max_steps,
            audio_calls_so_far=0,
            done=False,
            reward=0.0,
        )

    # ------------------------------------------------------------------
    # Step
    # ------------------------------------------------------------------
    def step(self, action: SubtextArenaAction) -> SubtextArenaObservation:  # type: ignore[override]
        """Execute one tool call and return the resulting observation.

        Non-terminal calls earn a per-step delta from step_reward().
        submit_belief — or exhausting the step budget — ends the episode
        via _submit_and_terminate().
        """
        if self._current_clip_id is None or self._terminated:
            # Episode ended — return done=True with no reward.
            return SubtextArenaObservation(
                clip_id=self._current_clip_id or "",
                tool_used="",
                tool_output="Episode terminated. Call reset() to start a new episode.",
                step=self._state.step_count,
                max_steps=self._max_steps,
                audio_calls_so_far=self._n_audio_calls,
                done=True,
                reward=0.0,
                error="episode_terminated",
            )
        clip = self._scenarios[self._current_clip_id]
        prosody = clip.get("prosody") or {}
        self._state.step_count += 1
        self._n_total_calls += 1
        tool = (action.tool or "").strip()
        args = action.tool_args or {}

        if tool == "submit_belief":
            return self._submit_and_terminate(args, clip, prosody)

        error: Optional[str] = None
        if tool == "get_transcript":
            tool_output = render_transcript(self._current_clip_id, self._scenarios)
        elif tool == "get_prosody_features":
            tool_output = render_prosody_features(self._current_clip_id, prosody, args)
            self._n_audio_calls += 1
        elif tool == "get_pitch_contour":
            tool_output = render_pitch_contour(self._current_clip_id, prosody, args)
            self._n_audio_calls += 1
        else:
            error = f"unknown tool '{tool}'. Valid: {sorted(VALID_TOOLS)}"
            tool_output = f"[error] {error}"

        # Per-step delta for any non-terminal action (penalized when malformed).
        per_step = step_reward(tool, error)

        # BUG FIX: the budget check now runs on *every* non-terminal path,
        # including unknown tools. Previously the unknown-tool branch returned
        # early with done=False, so an agent issuing only malformed actions
        # could keep the episode alive past max_steps indefinitely.
        if self._state.step_count >= self._max_steps:
            # Force a submission with no label -> apply no_submission penalty.
            return self._submit_and_terminate(
                {"label": None, "confidence": 0.0},
                clip,
                prosody,
                forced=True,
                preceding_reward=per_step,
            )
        return SubtextArenaObservation(
            clip_id=self._current_clip_id,
            speaker=clip.get("speaker", ""),
            duration_s=float(prosody.get("duration_s", 0.0)),
            is_pivot=bool(clip.get("is_pivot", False)),
            tool_used=tool,
            tool_output=tool_output,
            step=self._state.step_count,
            max_steps=self._max_steps,
            audio_calls_so_far=self._n_audio_calls,
            done=False,
            reward=per_step,
            error=error,
        )

    def _submit_and_terminate(
        self,
        args: dict,
        clip: dict,
        prosody: dict,
        forced: bool = False,
        preceding_reward: float = 0.0,
    ) -> SubtextArenaObservation:
        """Grade the (possibly forced) submission and close the episode.

        Args:
            args: submit_belief tool_args; reads "label" and "confidence".
            clip: Scenario dict for the current clip.
            prosody: Prosody sub-dict of the clip (may be empty).
            forced: True when the step budget ran out without a submission.
            preceding_reward: Per-step delta from the tool call that triggered
                a forced termination; folded into the final total.
        """
        label = args.get("label")
        if isinstance(label, str):
            label = label.strip().lower()
            if label not in {"sarcastic", "sincere"}:
                label = None
        else:
            label = None
        # Falsy confidence (missing, None, 0) falls back to 0.5, matching the
        # original behavior. ROBUSTNESS FIX: non-numeric values no longer
        # crash the episode mid-submit, and the value is clamped to the
        # grader's expected [0, 1] range.
        try:
            confidence = float(args.get("confidence", 0.5) or 0.5)
        except (TypeError, ValueError):
            confidence = 0.5
        confidence = max(0.0, min(1.0, confidence))
        gold = "sarcastic" if clip.get("sarcasm") else "sincere"
        components = final_reward(
            submitted_label=label,
            submitted_confidence=confidence,
            gold_label=gold,
            is_pivot=bool(clip.get("is_pivot", False)),
            n_audio_calls=self._n_audio_calls,
            n_total_calls=self._n_total_calls,
        )
        total_reward = components["_total"] + preceding_reward
        if forced:
            tool_output = (
                f"[forced termination after {self._max_steps} steps without submit_belief]\n"
                f"Gold label: {gold}. Reward components: {components}"
            )
        else:
            verdict = "CORRECT" if (label == gold) else "WRONG"
            tool_output = (
                f"Submitted: label={label}, confidence={confidence:.2f}. "
                f"Gold: {gold}. {verdict}. Reward components: {components}"
            )
        self._terminated = True
        return SubtextArenaObservation(
            clip_id=self._current_clip_id or "",
            speaker=clip.get("speaker", ""),
            duration_s=float(prosody.get("duration_s", 0.0)),
            is_pivot=bool(clip.get("is_pivot", False)),
            tool_used="submit_belief",
            tool_output=tool_output,
            step=self._state.step_count,
            max_steps=self._max_steps,
            audio_calls_so_far=self._n_audio_calls,
            done=True,
            reward=round(total_reward, 4),
            metadata={
                "gold": gold,
                "submitted_label": label,
                "submitted_confidence": confidence,
                "n_audio_calls": self._n_audio_calls,
                "n_total_calls": self._n_total_calls,
                "is_pivot": bool(clip.get("is_pivot", False)),
                "reward_components": components,
            },
        )

    # ------------------------------------------------------------------
    # State
    # ------------------------------------------------------------------
    def state(self) -> State:
        """Return the current episode State (episode_id + step_count)."""
        return self._state