"""Loads MUStARD ground truth + cached prosody features for all 690 clips."""
from __future__ import annotations

import json
import random
from pathlib import Path
from typing import Any, Dict, List, Optional

# Default data directory inside the env package; overrideable for local dev
DEFAULT_DATA_ROOT = Path(__file__).resolve().parent.parent / "data"


def _read_json(path: Path) -> Any:
    """Parse the JSON file at *path*, always decoding it as UTF-8."""
    # An explicit encoding keeps decoding independent of the platform locale.
    with path.open(encoding="utf-8") as f:
        return json.load(f)


def load_scenarios(data_root: Optional[Path] = None) -> Dict[str, Dict[str, Any]]:
    """Build the per-clip scenario table from the files under *data_root*.

    Files read, relative to the data root:

    * ``sarcasm_data.json`` (required): a JSON object mapping clip id to an
      entry object.
    * ``pivot_set.json`` (optional): a JSON object whose ``"clip_ids"`` list
      names the pivot clips. When the file is absent no clip is a pivot.
    * ``prosody_cache/utterances/<clip_id>.json`` (optional, one per clip):
      stored verbatim under ``"prosody"``; ``None`` when the file is absent.

    Args:
        data_root: Directory holding the files above. A falsy value (the
            default ``None``) selects ``DEFAULT_DATA_ROOT``. Plain string
            paths are accepted as well as ``Path`` objects.

    Returns:
        A dict keyed by clip id. Each value has the keys ``utterance``,
        ``speaker``, ``context``, ``context_speakers``, ``sarcasm``,
        ``show``, ``prosody`` and ``is_pivot``. Fields missing from an entry
        fall back to ``""``, ``[]`` or ``False``.

    Raises:
        FileNotFoundError: If ``sarcasm_data.json`` does not exist.
    """
    root = Path(data_root) if data_root else DEFAULT_DATA_ROOT
    sarcasm_path = root / "sarcasm_data.json"
    prosody_dir = root / "prosody_cache" / "utterances"

    if not sarcasm_path.exists():
        raise FileNotFoundError(
            f"Missing sarcasm ground truth at {sarcasm_path}. "
            "Copy from MUStARD/data/sarcasm_data.json before launching env."
        )

    sarcasm_data = _read_json(sarcasm_path)

    # Try to load curated pivot set; if absent, no clips are marked pivot yet
    pivot_path = root / "pivot_set.json"
    pivot_ids = set()
    if pivot_path.exists():
        pivot_ids = set(_read_json(pivot_path).get("clip_ids", []))

    scenarios: Dict[str, Dict[str, Any]] = {}
    for clip_id, entry in sarcasm_data.items():
        # Prosody features are cached one file per clip and may be missing.
        prosody_path = prosody_dir / f"{clip_id}.json"
        prosody = _read_json(prosody_path) if prosody_path.exists() else None

        scenarios[clip_id] = {
            "utterance": entry.get("utterance", ""),
            "speaker": entry.get("speaker", ""),
            "context": entry.get("context", []),
            "context_speakers": entry.get("context_speakers", []),
            "sarcasm": bool(entry.get("sarcasm", False)),
            "show": entry.get("show", ""),
            "prosody": prosody,
            "is_pivot": clip_id in pivot_ids,
        }

    return scenarios


def sample_clip(
    scenarios: Dict[str, Dict[str, Any]],
    rng: random.Random,
    pivot_oversample_factor: int = 3,
) -> str:
    """Pick one clip id at random, favouring pivot clips.

    Each clip id is placed in the candidate pool once, except pivot clips
    (``entry["is_pivot"]`` truthy), which are placed
    ``pivot_oversample_factor`` times. One pool element is then drawn with
    ``rng.choice``.

    Args:
        scenarios: Mapping of clip id to an entry containing ``"is_pivot"``.
        rng: Object providing ``choice(sequence)``, e.g. ``random.Random``.
        pivot_oversample_factor: Integer repeat count for pivot clips. A
            value of 0 or less removes pivot clips from the pool.

    Returns:
        The selected clip id.

    Raises:
        IndexError: If the pool is empty (no scenarios, or every clip is a
            pivot and the factor is 0 or less).
    """
    pool: List[str] = []
    for clip_id, entry in scenarios.items():
        weight = pivot_oversample_factor if entry["is_pivot"] else 1
        pool.extend([clip_id] * weight)

    if not pool:
        # Same exception type random.Random.choice raises for an empty
        # sequence, but with a message that names the actual cause.
        raise IndexError(
            "Cannot sample a clip: the candidate pool is empty "
            "(no scenarios, or all sampling weights are zero)."
        )
    return rng.choice(pool)