# subtext-arena/server/subtext_arena_environment.py
"""Subtext Arena environment.
Episode flow:
1. reset() picks a random MUStARD clip (Pivot Set oversampled 3x).
Returns: clip_id + speaker + duration, no transcript yet — the agent
must call get_transcript and/or audio tools to investigate.
2. step(SubtextArenaAction) executes one tool call:
- get_transcript -> literal text + conversational context
- get_prosody_features -> pitch, energy, pause text summary
- get_pitch_contour -> ASCII contour
- submit_belief -> terminates episode with label + confidence
Reward = per-step delta (small + for tool use, penalties for malformed
actions) + final composite reward when submit_belief fires.
3. After max_steps (default 6) without a submission, the episode is force-
terminated with the no_submission penalty.
The trained policy is a TEXT LLM (Path A). Audio is processed by the env's
frozen prosody-feature pipeline; the agent only ever sees text. Audio is
load-bearing because the Pivot Set explicitly contains clips where the literal
transcript alone leads to the wrong answer — the agent must consult prosody
to score on those.
"""
from __future__ import annotations
import os
import random
from typing import Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import SubtextArenaAction, SubtextArenaObservation
except ImportError:
from models import SubtextArenaAction, SubtextArenaObservation # type: ignore[no-redef]
try:
from .scenarios import load_scenarios, sample_clip
from .audio_tools import (
render_transcript,
render_prosody_features,
render_pitch_contour,
)
from .grader import step_reward, final_reward
except ImportError:
from server.scenarios import load_scenarios, sample_clip # type: ignore[no-redef]
from server.audio_tools import ( # type: ignore[no-redef]
render_transcript,
render_prosody_features,
render_pitch_contour,
)
from server.grader import step_reward, final_reward # type: ignore[no-redef]
# Tool names the agent may request via SubtextArenaAction.tool; anything else
# is rejected in step() with an "unknown tool" error.
VALID_TOOLS = {
    "get_transcript",
    "get_prosody_features",
    "get_pitch_contour",
    "submit_belief",
}
# Tools that consume the frozen prosody pipeline. NOTE(review): step() counts
# audio calls with explicit elif branches rather than this set, and nothing
# else in this module reads it — presumably consumed by the grader or eval
# scripts; verify before removing.
AUDIO_TOOLS = {"get_prosody_features", "get_pitch_contour"}
class SubtextArenaEnvironment(Environment):
    """OpenEnv environment for sarcasm-vs-sincere classification on MUStARD.

    Each episode presents one clip; the agent investigates it through text
    tools (transcript, prosody summary, pitch contour) and terminates by
    calling ``submit_belief`` with a label and confidence. Episodes are
    force-terminated after ``max_steps`` tool calls without a submission.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, max_steps: int = 6, seed: Optional[int] = None):
        """Initialize the environment.

        Args:
            max_steps: Maximum tool calls per episode before forced
                termination with the no-submission penalty.
            seed: Optional RNG seed for reproducible clip sampling. When
                None, fresh OS entropy is used so concurrent instances do
                not replay the same clip sequence.
        """
        self._scenarios = load_scenarios()
        self._max_steps = max_steps
        self._rng = random.Random(seed if seed is not None else os.urandom(4))
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_clip_id: Optional[str] = None
        self._n_audio_calls = 0
        self._n_total_calls = 0
        self._terminated = False
        # When set (via force_next_reset()), reset() picks this clip instead
        # of sampling. Used by eval_pivot_set.py to walk specific clips.
        # Auto-clears after one reset.
        self._force_next_clip_id: Optional[str] = None

    def force_next_reset(self, clip_id: str) -> None:
        """Force the next reset() to pick the given clip ID.

        Called by eval scripts that need to evaluate on specific clips
        (e.g. all 50 Prosody-Pivot clips) rather than random sampling.
        Auto-clears after one reset.

        Raises:
            ValueError: If ``clip_id`` is not a known MUStARD scenario.
        """
        if clip_id not in self._scenarios:
            raise ValueError(
                f"Unknown clip_id {clip_id!r}; not in MUStARD scenarios."
            )
        self._force_next_clip_id = clip_id

    # ------------------------------------------------------------------
    # Reset
    # ------------------------------------------------------------------
    def reset(self) -> SubtextArenaObservation:
        """Start a new episode and return the opening observation.

        Clip selection precedence: force_next_reset() > FORCE_CLIP_ID env
        var (silently ignored if unknown) > random sampling. The opening
        observation deliberately omits the transcript — the agent must
        call tools to investigate.
        """
        if self._force_next_clip_id is not None:
            clip_id = self._force_next_clip_id
            self._force_next_clip_id = None
        else:
            # Honor FORCE_CLIP_ID env var as a fallback (works through HTTP too)
            forced = os.environ.get("FORCE_CLIP_ID", "").strip()
            if forced and forced in self._scenarios:
                clip_id = forced
            else:
                clip_id = sample_clip(self._scenarios, self._rng)
        clip = self._scenarios[clip_id]
        prosody = clip.get("prosody") or {}
        # Fresh per-episode state and counters.
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_clip_id = clip_id
        self._n_audio_calls = 0
        self._n_total_calls = 0
        self._terminated = False
        return SubtextArenaObservation(
            clip_id=clip_id,
            speaker=clip.get("speaker", ""),
            duration_s=float(prosody.get("duration_s", 0.0)),
            is_pivot=bool(clip.get("is_pivot", False)),
            tool_used="reset",
            tool_output=(
                f"Episode started. Clip {clip_id}, speaker {clip.get('speaker', '?')}, "
                f"duration {prosody.get('duration_s', 0.0):.2f}s. "
                f"You have {self._max_steps} tool calls before forced submission. "
                f"Available tools: get_transcript, get_prosody_features, get_pitch_contour, submit_belief."
            ),
            step=0,
            max_steps=self._max_steps,
            audio_calls_so_far=0,
            done=False,
            reward=0.0,
        )

    # ------------------------------------------------------------------
    # Step
    # ------------------------------------------------------------------
    def step(self, action: SubtextArenaAction) -> SubtextArenaObservation:  # type: ignore[override]
        """Execute one tool call and return the resulting observation.

        Dispatches on ``action.tool``; unknown tools yield an error
        observation with a malformed-action penalty. Any step (valid or
        not) that reaches ``max_steps`` without a submission triggers a
        forced termination with the no-submission penalty.
        """
        if self._current_clip_id is None or self._terminated:
            # Episode ended — return done=True with no reward
            return SubtextArenaObservation(
                clip_id=self._current_clip_id or "",
                tool_used="",
                tool_output="Episode terminated. Call reset() to start a new episode.",
                step=self._state.step_count,
                max_steps=self._max_steps,
                audio_calls_so_far=self._n_audio_calls,
                done=True,
                reward=0.0,
                error="episode_terminated",
            )
        clip = self._scenarios[self._current_clip_id]
        prosody = clip.get("prosody") or {}
        self._state.step_count += 1
        self._n_total_calls += 1
        tool = (action.tool or "").strip()
        args = action.tool_args or {}
        error: Optional[str] = None
        tool_output: str = ""
        if tool not in VALID_TOOLS:
            error = f"unknown tool '{tool}'. Valid: {sorted(VALID_TOOLS)}"
            tool_output = f"[error] {error}"
            reward = step_reward(tool, error)
            # BUGFIX: previously an unknown tool never triggered forced
            # termination, so an agent emitting only malformed actions could
            # run an episode forever. Enforce the max_steps budget here too.
            if self._state.step_count >= self._max_steps:
                return self._submit_and_terminate(
                    {"label": None, "confidence": 0.0},
                    clip,
                    prosody,
                    forced=True,
                    preceding_reward=reward,
                )
            return SubtextArenaObservation(
                clip_id=self._current_clip_id,
                speaker=clip.get("speaker", ""),
                duration_s=float(prosody.get("duration_s", 0.0)),
                is_pivot=bool(clip.get("is_pivot", False)),
                tool_used=tool,
                tool_output=tool_output,
                step=self._state.step_count,
                max_steps=self._max_steps,
                audio_calls_so_far=self._n_audio_calls,
                done=False,
                reward=reward,
                error=error,
            )
        if tool == "get_transcript":
            tool_output = render_transcript(self._current_clip_id, self._scenarios)
        elif tool == "get_prosody_features":
            tool_output = render_prosody_features(self._current_clip_id, prosody, args)
            self._n_audio_calls += 1
        elif tool == "get_pitch_contour":
            tool_output = render_pitch_contour(self._current_clip_id, prosody, args)
            self._n_audio_calls += 1
        elif tool == "submit_belief":
            return self._submit_and_terminate(args, clip, prosody)
        # Per-step delta for non-terminal actions
        per_step = step_reward(tool, error)
        forced_terminate = self._state.step_count >= self._max_steps
        if forced_terminate:
            # Force a submission with no label -> apply no_submission penalty
            return self._submit_and_terminate(
                {"label": None, "confidence": 0.0},
                clip,
                prosody,
                forced=True,
                preceding_reward=per_step,
            )
        return SubtextArenaObservation(
            clip_id=self._current_clip_id,
            speaker=clip.get("speaker", ""),
            duration_s=float(prosody.get("duration_s", 0.0)),
            is_pivot=bool(clip.get("is_pivot", False)),
            tool_used=tool,
            tool_output=tool_output,
            step=self._state.step_count,
            max_steps=self._max_steps,
            audio_calls_so_far=self._n_audio_calls,
            done=False,
            reward=per_step,
            error=error,
        )

    def _submit_and_terminate(
        self,
        args: dict,
        clip: dict,
        prosody: dict,
        forced: bool = False,
        preceding_reward: float = 0.0,
    ) -> SubtextArenaObservation:
        """Grade a submission (or forced non-submission) and end the episode.

        Args:
            args: Submission payload; ``label`` must normalize to
                "sarcastic" or "sincere" (anything else becomes None),
                ``confidence`` must be numeric.
            clip: Scenario dict for the current clip.
            prosody: Prosody feature dict for the current clip.
            forced: True when terminating due to step-budget exhaustion.
            preceding_reward: Per-step delta of the action that triggered
                the forced termination, folded into the total.

        Returns:
            Terminal observation with ``done=True``, the composite reward,
            and grading details in ``metadata``.
        """
        label = args.get("label")
        if isinstance(label, str):
            label = label.strip().lower()
            if label not in {"sarcastic", "sincere"}:
                label = None
        else:
            label = None
        # BUGFIX: the old `float(args.get("confidence", 0.5) or 0.5)` coerced
        # an explicit 0.0 to 0.5 via the falsy `or` (including this env's own
        # forced-termination call, which passes 0.0), and raised an uncaught
        # ValueError on non-numeric strings. Honor 0.0; fall back to 0.5 only
        # for missing/None/non-numeric values.
        raw_confidence = args.get("confidence")
        try:
            confidence = 0.5 if raw_confidence is None else float(raw_confidence)
        except (TypeError, ValueError):
            confidence = 0.5
        gold = "sarcastic" if clip.get("sarcasm") else "sincere"
        components = final_reward(
            submitted_label=label,
            submitted_confidence=confidence,
            gold_label=gold,
            is_pivot=bool(clip.get("is_pivot", False)),
            n_audio_calls=self._n_audio_calls,
            n_total_calls=self._n_total_calls,
        )
        total_reward = components["_total"] + preceding_reward
        if forced:
            tool_output = (
                f"[forced termination after {self._max_steps} steps without submit_belief]\n"
                f"Gold label: {gold}. Reward components: {components}"
            )
        else:
            verdict = "CORRECT" if (label == gold) else "WRONG"
            tool_output = (
                f"Submitted: label={label}, confidence={confidence:.2f}. "
                f"Gold: {gold}. {verdict}. Reward components: {components}"
            )
        self._terminated = True
        return SubtextArenaObservation(
            clip_id=self._current_clip_id or "",
            speaker=clip.get("speaker", ""),
            duration_s=float(prosody.get("duration_s", 0.0)),
            is_pivot=bool(clip.get("is_pivot", False)),
            tool_used="submit_belief",
            tool_output=tool_output,
            step=self._state.step_count,
            max_steps=self._max_steps,
            audio_calls_so_far=self._n_audio_calls,
            done=True,
            reward=round(total_reward, 4),
            metadata={
                "gold": gold,
                "submitted_label": label,
                "submitted_confidence": confidence,
                "n_audio_calls": self._n_audio_calls,
                "n_total_calls": self._n_total_calls,
                "is_pivot": bool(clip.get("is_pivot", False)),
                "reward_components": components,
            },
        )

    # ------------------------------------------------------------------
    # State
    # ------------------------------------------------------------------
    @property
    def state(self) -> State:
        """Current episode State (episode_id + step_count)."""
        return self._state