Spaces:
Sleeping
Sleeping
| """Loads MUStARD ground truth + cached prosody features for all 690 clips.""" | |
| from __future__ import annotations | |
| import json | |
| import random | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
# Fallback data location: the "data" folder beside the directory containing
# this module. load_scenarios() uses it when no explicit root is supplied,
# so local development can point at a different directory instead.
DEFAULT_DATA_ROOT = Path(__file__).resolve().parent.parent.joinpath("data")
def _read_json(path: Path) -> Any:
    """Parse the JSON document at *path*, decoding the file as UTF-8."""
    # Explicit encoding: without it open() uses the platform locale encoding
    # (e.g. cp1252 on Windows), which mis-decodes non-ASCII text.
    with path.open(encoding="utf-8") as f:
        return json.load(f)


def load_scenarios(data_root: Optional[Path] = None) -> Dict[str, Dict[str, Any]]:
    """Load every clip in ``sarcasm_data.json`` as a scenario dict.

    Reads the ground-truth file, attaches cached prosody features from
    ``prosody_cache/utterances/<clip_id>.json`` when such a file exists, and
    flags clips listed in the optional ``pivot_set.json``.

    Args:
        data_root: Directory holding the data files. A falsy value (``None``)
            falls back to ``DEFAULT_DATA_ROOT``. A ``str`` or other path-like
            value is accepted and converted to ``Path``.

    Returns:
        Mapping of clip id to a dict with keys ``utterance``, ``speaker``,
        ``context``, ``context_speakers``, ``sarcasm`` (bool), ``show``,
        ``prosody`` (parsed JSON, or ``None`` when no cache file exists) and
        ``is_pivot`` (bool).

    Raises:
        FileNotFoundError: If ``sarcasm_data.json`` is missing under the root.
    """
    # Coerce to Path so plain-string roots work with the `/` joins below.
    root = Path(data_root) if data_root else DEFAULT_DATA_ROOT
    sarcasm_path = root / "sarcasm_data.json"
    prosody_dir = root / "prosody_cache" / "utterances"

    if not sarcasm_path.exists():
        raise FileNotFoundError(
            f"Missing sarcasm ground truth at {sarcasm_path}. "
            "Copy from MUStARD/data/sarcasm_data.json before launching env."
        )
    sarcasm_data = _read_json(sarcasm_path)

    # The curated pivot set is optional; if absent, no clip is marked pivot.
    pivot_ids = set()
    pivot_path = root / "pivot_set.json"
    if pivot_path.exists():
        pivot_doc = _read_json(pivot_path)
        # Accept either {"clip_ids": [...]} or a bare JSON list of clip ids.
        if isinstance(pivot_doc, dict):
            pivot_ids = set(pivot_doc.get("clip_ids", []))
        else:
            pivot_ids = set(pivot_doc)

    scenarios: Dict[str, Dict[str, Any]] = {}
    for clip_id, entry in sarcasm_data.items():
        prosody_path = prosody_dir / f"{clip_id}.json"
        # Prosody is cached per clip; clips without a cache file get None.
        prosody = _read_json(prosody_path) if prosody_path.exists() else None
        scenarios[clip_id] = {
            "utterance": entry.get("utterance", ""),
            "speaker": entry.get("speaker", ""),
            "context": entry.get("context", []),
            "context_speakers": entry.get("context_speakers", []),
            "sarcasm": bool(entry.get("sarcasm", False)),
            "show": entry.get("show", ""),
            "prosody": prosody,
            "is_pivot": clip_id in pivot_ids,
        }
    return scenarios
def sample_clip(scenarios, rng, pivot_oversample_factor=3):
    """Pick one clip id at random, over-weighting pivot clips.

    A flat pool is built in ``scenarios`` iteration order. Each clip id
    appears once, or ``pivot_oversample_factor`` times when its entry's
    ``is_pivot`` flag is truthy. A single ``rng.choice`` call then draws
    uniformly from that pool.

    Args:
        scenarios: Mapping of clip id to a scenario dict containing an
            ``is_pivot`` key.
        rng: Object exposing ``choice`` (e.g. a ``random.Random`` instance).
        pivot_oversample_factor: Number of pool copies for each pivot clip.

    Returns:
        The selected clip id.
    """
    weighted_pool = [
        clip_id
        for clip_id, scenario in scenarios.items()
        for _copy in range(pivot_oversample_factor if scenario["is_pivot"] else 1)
    ]
    # rng.choice raises IndexError when the pool is empty.
    return rng.choice(weighted_pool)