# subtext-arena / server/scenarios.py
# Uploaded by aamrinder using huggingface_hub (commit 9f43137, verified).
"""Loads MUStARD ground truth + cached prosody features for all 690 clips."""
from __future__ import annotations
import json
import random
from pathlib import Path
from typing import Any, Dict, List, Optional
# Packaged data location: the ``data`` directory inside the parent of this
# module's own directory. Callers can point ``load_scenarios`` elsewhere via
# its ``data_root`` argument (useful for local development).
DEFAULT_DATA_ROOT = Path(__file__).resolve().parent.parent.joinpath("data")
def _read_json(path: Path) -> Any:
    """Parse the JSON document at *path*, decoding it explicitly as UTF-8."""
    # An explicit encoding avoids depending on the platform locale, which
    # would mis-decode non-ASCII dialogue text on non-UTF-8 systems.
    with path.open(encoding="utf-8") as f:
        return json.load(f)


def load_scenarios(data_root: Optional[Path] = None) -> Dict[str, Dict[str, Any]]:
    """Load every MUStARD clip together with its cached prosody and pivot flag.

    Args:
        data_root: Directory containing ``sarcasm_data.json`` and, optionally,
            ``pivot_set.json`` and ``prosody_cache/utterances/<clip_id>.json``.
            A ``Path`` or a plain string path is accepted. When falsy
            (e.g. ``None``), ``DEFAULT_DATA_ROOT`` is used.

    Returns:
        Mapping from clip id to a scenario dict with keys ``utterance``,
        ``speaker``, ``context``, ``context_speakers``, ``sarcasm`` (bool),
        ``show``, ``prosody`` (the parsed per-clip JSON, or ``None`` when no
        cache file exists for that clip) and ``is_pivot`` (bool).

    Raises:
        FileNotFoundError: If ``sarcasm_data.json`` is missing under the root.
    """
    # Coerce to Path so string paths work; a falsy value keeps the original
    # fallback to the packaged default.
    root = Path(data_root) if data_root else DEFAULT_DATA_ROOT
    sarcasm_path = root / "sarcasm_data.json"
    prosody_dir = root / "prosody_cache" / "utterances"

    if not sarcasm_path.exists():
        raise FileNotFoundError(
            f"Missing sarcasm ground truth at {sarcasm_path}. "
            "Copy from MUStARD/data/sarcasm_data.json before launching env."
        )
    sarcasm_data = _read_json(sarcasm_path)

    # The curated pivot set is optional; without it no clip is marked pivot.
    pivot_path = root / "pivot_set.json"
    pivot_ids = set()
    if pivot_path.exists():
        pivot_ids = set(_read_json(pivot_path).get("clip_ids", []))

    scenarios: Dict[str, Dict[str, Any]] = {}
    for clip_id, entry in sarcasm_data.items():
        # Prosody features are cached one file per clip; absent -> None.
        prosody_path = prosody_dir / f"{clip_id}.json"
        prosody = _read_json(prosody_path) if prosody_path.exists() else None
        scenarios[clip_id] = {
            "utterance": entry.get("utterance", ""),
            "speaker": entry.get("speaker", ""),
            "context": entry.get("context", []),
            "context_speakers": entry.get("context_speakers", []),
            "sarcasm": bool(entry.get("sarcasm", False)),
            "show": entry.get("show", ""),
            "prosody": prosody,
            "is_pivot": clip_id in pivot_ids,
        }
    return scenarios
def sample_clip(
    scenarios: Dict[str, Dict[str, Any]],
    rng: random.Random,
    pivot_oversample_factor: int = 3,
) -> str:
    """Draw one clip id, giving pivot clips extra weight.

    Every clip contributes one ticket to the draw; clips whose ``is_pivot``
    flag is true contribute ``pivot_oversample_factor`` tickets instead, so a
    factor of 0 excludes pivot clips entirely.

    Args:
        scenarios: Mapping of clip id to scenario dict; each entry must carry
            an ``is_pivot`` key (as built by ``load_scenarios``).
        rng: Random source; only its ``choice`` method is called, so a seeded
            ``random.Random`` yields reproducible draws.
        pivot_oversample_factor: Non-negative integer ticket count per pivot.

    Returns:
        The sampled clip id.

    Raises:
        ValueError: If ``pivot_oversample_factor`` is negative (it would
            otherwise silently drop every pivot clip).
        IndexError: From ``rng.choice`` when the weighted pool is empty
            (no scenarios, or only pivot clips with a factor of 0).
    """
    if pivot_oversample_factor < 0:
        raise ValueError(
            f"pivot_oversample_factor must be >= 0, got {pivot_oversample_factor}"
        )
    # An explicit ticket list (rather than ``rng.choices`` with weights) keeps
    # each draw a single ``rng.choice`` over the expanded pool, so seeded runs
    # produce the same clip sequence as this expansion always has.
    pool: List[str] = []
    for clip_id, entry in scenarios.items():
        weight = pivot_oversample_factor if entry["is_pivot"] else 1
        pool.extend([clip_id] * weight)
    return rng.choice(pool)