"""Subtext Arena — TRL GRPO training (Option A: single-step CoT classification).

Architecture:
    The training task is single-step. For each rollout:
      1. The training script builds a FULL prompt for one MUStARD clip:
         system prompt + transcript + prosody features + pitch contour.
      2. The model generates ONE completion: a <think>...</think> reasoning
         block followed by <final>{"label": ..., "confidence": ...}</final>.
      3. The reward function parses the <final> block and scores it against
         the gold label.
      4. GRPO advantage update.

    The env (server/) still supports multi-step tool calling — that's our
    inference-time interface and what judges interact with on the HF Space.
    But for TRAINING, we sidestep TRL's single-shot completion constraint by
    pre-rendering all tool outputs into the prompt. The model learns the
    conditional reasoning step (prosody features -> sarcasm label) that a
    one-shot rollout of the multi-step env would otherwise fail to teach.

Stack (deck requirement #2 — training uses HF TRL):
    Hugging Face transformers + bitsandbytes 4-bit NF4 quantization +
    PEFT LoRA (r=16 by default) + TRL GRPOTrainer. Unsloth was dropped
    because of a Half/Float dtype mismatch in its fast LoRA kernel (see the
    note in main()). Training runs in bf16, so it needs a bf16-capable GPU
    (Ampere or newer, e.g. an L4); a T4 does not support bf16.

Usage:
    python train/train_grpo.py \\
        --model unsloth/Qwen2.5-3B-Instruct \\
        --max-steps 200 \\
        --output-dir ./checkpoints/run1

    An ``unsloth/`` model name is mapped to the upstream repo before loading.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

# Dual import path: works whether this script is run locally (with
# subtext_arena/ on sys.path) or after `pip install` (subtext_arena.* package).
try:
    from subtext_arena.server.scenarios import load_scenarios
    from subtext_arena.server.audio_tools import (
        render_transcript,
        render_prosody_features,
        render_pitch_contour,
    )
except ImportError:
    ROOT = Path(__file__).resolve().parent.parent
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    from server.scenarios import load_scenarios  # type: ignore[no-redef]
    from server.audio_tools import (  # type: ignore[no-redef]
        render_transcript,
        render_prosody_features,
        render_pitch_contour,
    )


# ---------------------------------------------------------------------------
# System prompt (defines the answer grammar)
# ---------------------------------------------------------------------------

# The answer grammar below (<think>...</think> then <final>{json}</final>) must
# stay in sync with THINK_RE / FINAL_RE, which the reward function parses.
SYSTEM_PROMPT = """You are an expert at detecting sarcasm in spoken dialogue.

You will be given:
- the literal transcript of a line from a TV show, plus its conversational context
- acoustic prosody features (pitch, energy, pause patterns)
- the pitch contour over the line

Your job is to decide whether the line is SARCASTIC or SINCERE, by reading
between the lines: the prosodic delivery often flips the meaning of the
literal words.

Format your response EXACTLY like this:
<think>
your reasoning over the prosody and lexical cues, 2-6 sentences
</think>
<final>
{"label":"sarcastic"|"sincere","confidence":0.0..1.0}
</final>

Rules:
1. Use BOTH the transcript and the prosody features in your reasoning.
2. High pitch variability + emphasis pause + positive words often = sarcastic.
3. Flat affect + no internal pauses + neutral content often = sincere.
4. Confidence should reflect how strongly the cues agree.
"""


# ---------------------------------------------------------------------------
# Prompt construction (the env's tool outputs, all bundled)
# ---------------------------------------------------------------------------

def build_full_observation(clip_id: str, scenarios: Dict[str, Dict[str, Any]]) -> str:
    """Render the complete observation for one clip as a single user message.

    This is the same information an interactive agent would see after calling
    get_transcript + get_prosody_features + get_pitch_contour in sequence,
    pre-rendered so that training can stay single-step.

    Args:
        clip_id: Key into ``scenarios``.
        scenarios: Mapping of clip id -> scenario entry (as returned by
            ``load_scenarios``). The entry's optional ``"prosody"`` dict is
            passed to the prosody / pitch-contour renderers.

    Returns:
        The user-message body, ending with "Decide: sarcastic or sincere?".

    Raises:
        KeyError: if ``clip_id`` is not in ``scenarios``.
    """
    clip = scenarios[clip_id]
    prosody = clip.get("prosody") or {}
    # A missing/null duration renders as 0.00s instead of breaking the format spec.
    duration = float(prosody.get("duration_s") or 0.0)

    transcript_block = render_transcript(clip_id, scenarios)
    prosody_block = render_prosody_features(clip_id, prosody)
    contour_block = render_pitch_contour(clip_id, prosody)

    return (
        f"Clip {clip_id} (speaker {clip.get('speaker', '?')}, "
        f"duration {duration:.2f}s)\n\n"
        f"=== Transcript ===\n{transcript_block}\n\n"
        f"=== Prosody features ===\n{prosody_block}\n\n"
        f"=== Pitch contour ===\n{contour_block}\n\n"
        f"Decide: sarcastic or sincere?"
    )


def split_clip_ids(scenarios: Dict[str, Dict[str, Any]], eval_ratio: float = 0.2,
                   seed: int = 42):
    """Deterministic train/eval split at the clip level.

    Clips in the eval split are NEVER shown to the model during training; this
    is what lets us claim "the trained model generalizes, not memorizes."
    Pivot Set clips are split too (training pivots are used for oversampling,
    eval pivots for the held-out audio-mattering test).

    Args:
        scenarios: Mapping of clip id -> scenario entry (``"is_pivot"`` is read).
        eval_ratio: Fraction of clips held out, in [0, 1]. The eval size is
            ``int(len(scenarios) * eval_ratio)`` (rounded down).
        seed: Seed for the shuffle; the same seed always yields the same split.

    Returns:
        ``(train_ids, eval_ids)`` as two disjoint sets of clip ids.

    Raises:
        ValueError: if ``eval_ratio`` is outside [0, 1] (a negative ratio
            would otherwise silently put most clips into the eval split).
    """
    if not 0.0 <= eval_ratio <= 1.0:
        raise ValueError(f"eval_ratio must be in [0, 1], got {eval_ratio!r}")

    rng = random.Random(seed)
    ordered = sorted(scenarios)  # sort first so the split is independent of dict order
    rng.shuffle(ordered)
    n_eval = int(len(ordered) * eval_ratio)
    eval_ids = set(ordered[:n_eval])
    train_ids = set(ordered[n_eval:])

    n_train_pivot = sum(1 for cid in train_ids if scenarios[cid].get("is_pivot"))
    n_eval_pivot = sum(1 for cid in eval_ids if scenarios[cid].get("is_pivot"))
    print(f"[split] train={len(train_ids)} (pivot={n_train_pivot}) | "
          f"eval={len(eval_ids)} (pivot={n_eval_pivot})")
    return train_ids, eval_ids


def build_dataset(scenarios: Dict[str, Dict[str, Any]], n_rows: int, seed: int = 0,
                  allowed_clip_ids=None, pivot_oversample: int = 2,
                  class_balance: bool = True):
    """Build a HF ``Dataset`` of chat prompts, by default with strict per-class balance.

    Why class-balance: Run #2 showed extreme prediction skew (16% sarcastic,
    81% sincere) on held-out, because the Pivot Set is 25 sinc + 7 sarc and the
    model learned to default to "sincere" when uncertain. We fix this by:
      - building separate sarc and sinc pools, each weighted by Pivot oversample
      - interleaving them 50/50 in the final dataset
      - reducing pivot_oversample from 3 -> 2 (still emphasized, less dominant)

    Args:
        scenarios: Mapping of clip id -> scenario entry; ``"sarcasm"`` (truthy
            means sarcastic) and ``"is_pivot"`` are read.
        n_rows: Number of rows to emit; the pools are cycled as needed.
        seed: Seed for shuffling the pools.
        allowed_clip_ids: If set, only these clips are used (lets the caller
            restrict to the train split — eval clips never touch training).
        pivot_oversample: How many times each Pivot clip appears in its pool.
        class_balance: If True, rows strictly alternate sarcastic / sincere.
            If False, both pools are merged and shuffled together, so that
            truncating to ``n_rows`` does not favor one class.

    Returns:
        A ``datasets.Dataset`` with columns ``prompt`` (system + user
        messages), ``clip_id``, ``gold`` ("sarcastic"/"sincere") and
        ``is_pivot``.

    Raises:
        ValueError: if the allowed clips do not contain both classes.
    """
    from datasets import Dataset

    rng = random.Random(seed)
    sarc_pool: List[str] = []
    sinc_pool: List[str] = []
    for cid, entry in scenarios.items():
        if allowed_clip_ids is not None and cid not in allowed_clip_ids:
            continue
        weight = pivot_oversample if entry.get("is_pivot") else 1
        (sarc_pool if entry.get("sarcasm") else sinc_pool).extend([cid] * weight)

    if not sarc_pool or not sinc_pool:
        raise ValueError(
            "Need both sarcastic and sincere clips. "
            f"got sarc={len(sarc_pool)}, sinc={len(sinc_pool)}"
        )

    # Shuffle order is kept (sarc, then sinc) so a given seed reproduces the
    # same balanced dataset as earlier runs.
    rng.shuffle(sarc_pool)
    rng.shuffle(sinc_pool)
    print(f"[build_dataset] balanced pool: sarc={len(sarc_pool)}, sinc={len(sinc_pool)}, "
          f"pivot_oversample={pivot_oversample}, class_balance={class_balance}")

    # Rendering an observation is pure per clip, and clips repeat across rows.
    observation_cache: Dict[str, str] = {}

    def make_row(cid: str) -> Dict[str, Any]:
        if cid not in observation_cache:
            observation_cache[cid] = build_full_observation(cid, scenarios)
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": observation_cache[cid]},
            ],
            "clip_id": cid,
            "gold": "sarcastic" if scenarios[cid]["sarcasm"] else "sincere",
            "is_pivot": bool(scenarios[cid].get("is_pivot", False)),
        }

    rows: List[Dict[str, Any]] = []
    if class_balance:
        # Interleave sarc, sinc, sarc, sinc, ... so adjacent rows cover both classes.
        k = 0
        while len(rows) < n_rows:
            rows.append(make_row(sarc_pool[k % len(sarc_pool)]))
            if len(rows) < n_rows:
                rows.append(make_row(sinc_pool[k % len(sinc_pool)]))
            k += 1
    else:
        mixed = sarc_pool + sinc_pool
        rng.shuffle(mixed)
        while len(rows) < n_rows:
            rows.append(make_row(mixed[len(rows) % len(mixed)]))

    return Dataset.from_list(rows)


# ---------------------------------------------------------------------------
# Reward function (single-step: parse <final>, score against gold)
# ---------------------------------------------------------------------------

# Tag delimiters are what make these patterns meaningful: without them the lazy
# group would match the empty string at position 0 of every completion.
FINAL_RE = re.compile(r"<final>\s*(\{.*?\})\s*</final>", re.S)
THINK_RE = re.compile(r"<think>\s*(.*?)\s*</think>", re.S)

# Reward weights. They sum to 1.0, and every component lies in [0, 1], so the
# total reward lies in [0, 1].
_W_CORRECTNESS = 0.70
_W_REASONING = 0.15
_W_FORMAT = 0.15


def parse_final(text: str):
    """Parse the first ``<final>{...}</final>`` block of a completion.

    Returns:
        ``(label, confidence, has_well_formed_final)``. ``label`` is
        "sarcastic" or "sincere". ``confidence`` is clamped to [0, 1]; a
        missing, non-numeric or NaN value falls back to 0.5. When the block is
        missing, is not valid JSON, or has an unknown label, the result is
        ``(None, 0.5, False)``.
    """
    import math

    match = FINAL_RE.search(text)
    if not match:
        return None, 0.5, False
    try:
        payload = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None, 0.5, False

    label = str(payload.get("label", "")).strip().lower()
    if label not in {"sarcastic", "sincere"}:
        return None, 0.5, False

    try:
        conf = float(payload.get("confidence", 0.5))
    except (TypeError, ValueError):
        conf = 0.5
    if math.isnan(conf):
        # json accepts NaN; without this the clamp below would turn it into 1.0.
        conf = 0.5
    conf = max(0.0, min(1.0, conf))
    return label, conf, True


def reasoning_length_score(text: str) -> float:
    """Score the word count of the ``<think>`` block, in [0, 1].

    Without a ``<think>`` block the score is 0.0. Otherwise:
      - fewer than 50 words: ramps linearly 0 -> 1 (penalizes stub reasoning)
      - 50-150 words: 1.0 (sweet spot)
      - 151-300 words: decays linearly from 1.0 to 0.5 (penalizes rambling)
      - more than 300 words: 0.5

    Forces the model to actually reason about prosody, not just emit a
    token-stub ``<think>``.
    """
    match = THINK_RE.search(text)
    if not match:
        return 0.0
    words = len(match.group(1).split())
    if words < 50:
        return max(0.0, words / 50.0)
    if words <= 150:
        return 1.0
    if words <= 300:
        return 1.0 - (words - 150) / 300.0
    return 0.5


def reward_decomposition(text: str, gold: str) -> Dict[str, Any]:
    """Compute the reward for one completion, returning every component.

    Reward = 0.70 * correctness + 0.15 * reasoning_length + 0.15 * format, where
      - correctness is confidence-weighted: ``0.5 + 0.5*conf`` if the label is
        right, ``0.5 - 0.5*conf`` if wrong, and 0.0 if no valid ``<final>``
        block was emitted;
      - reasoning_length comes from ``reasoning_length_score``;
      - format is 1.0 for a well-formed ``<final>`` block, else 0.0.

    Keys starting with "_" are diagnostics: ``_total`` is the scalar reward;
    the others record correctness, parse status, prediction and confidence.
    """
    label, conf, well_formed = parse_final(text)
    if well_formed:
        correct = label == gold.lower()
        correctness = (0.5 + 0.5 * conf) if correct else (0.5 - 0.5 * conf)
    else:
        correct = False
        correctness = 0.0
    r_reasoning = reasoning_length_score(text)
    r_format = 1.0 if well_formed else 0.0
    total = _W_CORRECTNESS * correctness + _W_REASONING * r_reasoning + _W_FORMAT * r_format
    return {
        "correctness": correctness,
        "reasoning_length": r_reasoning,
        "format": r_format,
        "_total": total,
        "_correct": bool(correct),
        "_well_formed": well_formed,
        "_predicted": label,
        "_confidence": conf,
    }


def _completion_text(completion: Any) -> str:
    """Extract generated text from a TRL completion (plain string or message list)."""
    if isinstance(completion, list):
        return completion[0]["content"] if completion else ""
    return str(completion)


def make_reward_fn():
    """Return the TRL reward function.

    The reward is exactly ``reward_decomposition(...)["_total"]``, so the
    training reward and the logged or evaluated breakdown cannot drift apart.
    """
    # The inner name "reward_fn" is kept: TRL presumably uses the function's
    # __name__ when logging per-reward metrics — verify against the TRL version.
    def reward_fn(prompts, completions, **kwargs) -> List[float]:
        # TRL passes extra dataset columns as kwargs (lists aligned with completions).
        gold_labels = kwargs.get("gold")
        if gold_labels is None:
            raise ValueError("Reward fn requires a 'gold' column in the dataset")
        return [
            float(reward_decomposition(_completion_text(completion), gold)["_total"])
            for completion, gold in zip(completions, gold_labels)
        ]

    return reward_fn


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--output-dir", default="./checkpoints/run1")
    parser.add_argument("--max-steps", type=int, default=200)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--per-device-batch-size", type=int, default=1,
                        help="Rounded up if needed so the generation batch is divisible "
                             "by --num-generations (a TRL GRPO requirement)")
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=768)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--seq-length", type=int, default=4096,
                        help="Total token budget; prompts are capped at "
                             "seq-length minus max-completion-length where TRL supports it")
    parser.add_argument("--n-train-rows", type=int, default=600)
    parser.add_argument("--data-root", default=None,
                        help="Override path to data/ (defaults to subtext_arena/data/)")
    parser.add_argument("--push-to-hub", default=None,
                        help="If set, e.g. 'aamrinder/subtext-arena-grpo', push the trained LoRA there at the end")
    parser.add_argument("--save-trainer-state-to-hub-space", default=None,
                        help="If set, e.g. 'aamrinder/subtext-arena', upload trainer_state.json + eval_results.json to that Space's data/ dir")
    parser.add_argument("--eval-ratio", type=float, default=0.2,
                        help="Fraction of MUStARD clips held out for generalization eval")
    parser.add_argument("--n-eval-clips", type=int, default=80,
                        help="How many held-out clips to evaluate at the end")
    parser.add_argument("--lora-dropout", type=float, default=0.05,
                        help="LoRA dropout for regularization (helps prevent memorization)")
    return parser.parse_args()


def _resolve_model_name(name: str) -> str:
    """Map an Unsloth-hosted model name to the repo loaded with plain transformers.

    "unsloth/Qwen..." names (optionally ending in "-bnb-4bit" or
    "-unsloth-bnb-4bit") map to the upstream "Qwen/..." repo, since we quantize
    ourselves. Non-Qwen "unsloth/..." names keep the unsloth org, with only the
    pre-quantized suffix removed. Names without the "unsloth/" prefix are
    returned unchanged.
    """
    prefix = "unsloth/"
    if not name.startswith(prefix):
        return name
    repo = name[len(prefix):]
    for suffix in ("-unsloth-bnb-4bit", "-bnb-4bit"):
        if repo.endswith(suffix):
            repo = repo[: -len(suffix)]
            break
    if repo.lower().startswith("qwen"):
        return "Qwen/" + repo
    return prefix + repo


def _aligned_batch_size(requested: int, num_generations: int, world_size: int = 1) -> int:
    """Return the smallest batch >= ``requested`` with ``batch * world_size``
    divisible by ``num_generations``.

    GRPO groups ``num_generations`` completions per prompt, and TRL's
    GRPOConfig rejects a generation batch that is not a multiple of it.
    """
    batch = max(1, requested)
    if num_generations < 1:
        return batch  # let TRL report the invalid num_generations itself
    world = max(1, world_size)
    while (batch * world) % num_generations:
        batch += 1
    return batch


def _build_grpo_config(args: argparse.Namespace):
    """Create the GRPOConfig, adapting the batch size and prompt budget to TRL."""
    import dataclasses

    from trl import GRPOConfig

    world_size = int(os.environ.get("WORLD_SIZE", "1") or 1)
    batch_size = _aligned_batch_size(args.per_device_batch_size, args.num_generations, world_size)
    if batch_size != args.per_device_batch_size:
        print(f"[config] per_device_train_batch_size {args.per_device_batch_size} -> {batch_size} "
              f"(generation batch must be divisible by num_generations={args.num_generations})")

    kwargs: Dict[str, Any] = dict(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        per_device_train_batch_size=batch_size,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        logging_steps=1,
        save_steps=50,
        save_total_limit=4,
        bf16=True,
        report_to=("wandb" if os.environ.get("WANDB_API_KEY") else "none"),
        gradient_checkpointing=True,
    )

    # TRL left-truncates prompts to max_prompt_length (512 by default in many
    # versions), which would cut the system prompt off these long prompts.
    # The field is only passed when this TRL version defines it.
    prompt_budget = args.seq_length - args.max_completion_length
    supported = {f.name for f in dataclasses.fields(GRPOConfig)}
    if prompt_budget <= 0:
        print(f"[warn] --seq-length {args.seq_length} <= --max-completion-length "
              f"{args.max_completion_length}; leaving TRL's default prompt limit")
    elif "max_prompt_length" in supported:
        kwargs["max_prompt_length"] = prompt_budget

    return GRPOConfig(**kwargs)


def _load_model(args: argparse.Namespace):
    """Load the 4-bit NF4 base model, wrap it with LoRA, and return (model, tokenizer)."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    print(f"[load] {args.model}, 4-bit, max_seq_length={args.seq_length}")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    model_name = _resolve_model_name(args.model)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb,
        dtype=torch.bfloat16,
        device_map="auto",
    )
    base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)

    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_r,
        lora_dropout=args.lora_dropout,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    return get_peft_model(base, peft_config), tokenizer


def _write_log_history(trainer, output_dir: str) -> None:
    """Best effort: dump trainer.state.log_history to <output_dir>/log_history.json."""
    try:
        log_path = Path(output_dir) / "log_history.json"
        log_path.write_text(json.dumps(trainer.state.log_history, indent=2))
        print(f"[done] log_history saved to {log_path} ({len(trainer.state.log_history)} entries)")
    except Exception as e:
        print(f"[warn] couldn't write log_history.json: {e}")


def _run_held_out_eval(model, tokenizer, scenarios: Dict[str, Dict[str, Any]],
                       eval_clip_ids: List[str], max_new_tokens: int) -> Dict[str, Any]:
    """Greedy-decode every held-out clip and score it with ``reward_decomposition``.

    Best effort: an exception stops the loop and is printed with its
    traceback. The summary then covers only the completed clips, and
    ``eval_failed`` is set. ``accuracy`` and ``well_formed_rate`` are divided
    by the number of planned clips, so a crash lowers them; ``n_completed``
    reports how many clips actually ran.
    """
    import traceback

    import torch

    model.eval()
    if hasattr(model, "gradient_checkpointing_disable"):
        try:
            model.gradient_checkpointing_disable()
        except Exception:
            pass  # best effort; generation still works with checkpointing on

    results: List[Dict[str, Any]] = []
    n_correct = 0
    reward_sum = 0.0
    eval_failed = False
    position, current = -1, None
    try:
        for position, current in enumerate(eval_clip_ids):
            sc = scenarios[current]
            gold = "sarcastic" if sc["sarcasm"] else "sincere"
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_full_observation(current, scenarios)},
            ]
            encoded = tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True,
            )
            # Depending on the transformers version this is a tensor or a BatchEncoding.
            if hasattr(encoded, "input_ids"):
                input_ids = encoded.input_ids
                attention_mask = getattr(encoded, "attention_mask", None)
            else:
                input_ids, attention_mask = encoded, None
            input_ids = input_ids.to(model.device)
            # pad == eos here, so generate() cannot infer the mask; pass it explicitly.
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            else:
                attention_mask = attention_mask.to(model.device)
            prompt_len = input_ids.shape[1]

            with torch.no_grad():
                out = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                    use_cache=True,
                )
            text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
            decomp = reward_decomposition(text, gold)

            results.append({
                "clip_id": current,
                "gold": gold,
                "is_pivot": bool(sc.get("is_pivot")),
                "predicted": decomp["_predicted"],
                "confidence": decomp["_confidence"],
                "correct": decomp["_correct"],
                "well_formed": decomp["_well_formed"],
                "reward_total": decomp["_total"],
                "completion_text": text[:1500],
            })
            reward_sum += decomp["_total"]
            n_correct += int(decomp["_correct"])

            done = len(results)
            if done % 20 == 0:
                print(f"  [{done}/{len(eval_clip_ids)}] running mean reward = {reward_sum / done:.3f}, "
                      f"correct so far = {n_correct}/{done}", flush=True)
    except Exception as e:
        print(f"[error] held-out eval crashed at clip {position} ({current}): {e}")
        traceback.print_exc()
        eval_failed = True

    denom = max(1, len(eval_clip_ids))
    return {
        "n_eval_clips": len(eval_clip_ids),
        "mean_reward": reward_sum / max(1, len(results)),
        "well_formed_rate": sum(1 for r in results if r["well_formed"]) / denom,
        "accuracy": n_correct / denom,
        "pivot_in_eval": sum(1 for r in results if r["is_pivot"]),
        "pivot_correct": sum(1 for r in results if r["is_pivot"] and r["correct"]),
        "n_completed": len(results),
        "eval_failed": eval_failed,
        "results": results,
    }


def _run_tag(output_dir: str) -> str:
    """Short run label from the output directory name, e.g. "run3"; "run" if it has none."""
    return Path(output_dir).resolve().name or "run"


def _push_adapter(args: argparse.Namespace) -> None:
    """Best effort: upload output_dir (LoRA adapter etc.) to the --push-to-hub model repo."""
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            commit_message=f"GRPO {_run_tag(args.output_dir)} "
                           f"({args.max_steps} steps, lr={args.learning_rate})",
        )
        print(f"[done] LoRA adapter pushed to https://huggingface.co/{args.push_to_hub}")
    except Exception as e:
        print(f"[error] push_to_hub failed: {e}")


def _push_run_artifacts_to_space(args: argparse.Namespace) -> None:
    """Best effort: upload trainer_state / log_history / held_out_eval JSONs to a Space.

    Each file is uploaded separately, so a partial network failure does not
    prevent the others (above all held_out_eval.json) from landing.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as e:
        print(f"[error] huggingface_hub unavailable, skipping Space upload: {e}")
        return

    repo_id = args.save_trainer_state_to_hub_space
    out_dir = Path(args.output_dir)
    run_tag = _run_tag(args.output_dir)
    api = HfApi()
    for local_name, label in (
        ("trainer_state.json", "trainer_state"),
        ("log_history.json", "log_history"),
        ("held_out_eval.json", "held_out_eval"),
    ):
        path = out_dir / local_name
        hub_name = f"data/{label}_{run_tag}.json"
        if not path.exists():
            print(f"[warn] {local_name} not found at {path}, skipping upload")
            continue
        try:
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=hub_name,
                repo_id=repo_id,
                repo_type="space",
                commit_message=f"GRPO {run_tag} {label} ({args.max_steps} steps)",
            )
            print(f"[done] {label} pushed to {repo_id}/{hub_name}")
        except Exception as e:
            print(f"[error] upload {label} failed: {e}")


def main():
    """Run GRPO training, evaluate on held-out clips, save, and optionally push results."""
    args = _parse_args()

    # Load scenarios up front so any data issues fail fast (before model load)
    data_root = Path(args.data_root) if args.data_root else None
    scenarios = load_scenarios(data_root)
    n_pivot = sum(1 for s in scenarios.values() if s.get("is_pivot"))
    print(f"[data] {len(scenarios)} MUStARD clips loaded; {n_pivot} marked Pivot")

    # CRITICAL: train/eval split so we can prove generalization (not memorization).
    # Eval clips are NEVER seen by the model during training.
    train_ids, eval_ids = split_clip_ids(scenarios, eval_ratio=args.eval_ratio, seed=42)
    dataset = build_dataset(scenarios, n_rows=args.n_train_rows, allowed_clip_ids=train_ids)
    print(f"[data] {len(dataset)} train prompt rows from {len(train_ids)} unique train clips")

    # Build the config before loading weights, so TRL argument errors surface
    # without paying for a model download.
    config = _build_grpo_config(args)

    # Model load via plain transformers + PEFT (deck-compliant: training uses HF TRL).
    # We dropped Unsloth because their fast_lora kernel has a Half/Float dtype
    # mismatch on Qwen2.5-3B + 4-bit + bf16 in v2026.4.8 (verified via failed
    # smoke runs on L4). Plain transformers+peft+trl is slower but reliable.
    model, tokenizer = _load_model(args)

    from trl import GRPOTrainer

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=make_reward_fn(),
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    trainer.train()

    # CRITICAL: save_state() writes trainer_state.json to output_dir.
    # save_model() alone only saves the adapter weights, NOT the per-step log.
    # In Run #2, we missed save_state() and lost the reward history that drives the plot.
    trainer.save_state()
    trainer.save_model(args.output_dir)
    # Also explicitly write the log_history to a JSON we know we can find.
    _write_log_history(trainer, args.output_dir)
    print(f"[done] checkpoint saved to {args.output_dir}")

    # ----- HELD-OUT GENERALIZATION EVAL -----
    # Run the trained model on held-out clips that were NEVER in training.
    # Reward much lower than training reward -> memorization; comparable -> real learning.
    eval_clip_ids = sorted(eval_ids)[: max(0, args.n_eval_clips)]
    print(f"\n[held-out-eval] running trained model on {len(eval_clip_ids)} held-out clips")
    eval_summary = _run_held_out_eval(
        model, tokenizer, scenarios, eval_clip_ids, args.max_completion_length,
    )

    n_eval_correct = sum(1 for r in eval_summary["results"] if r["correct"])
    print(f"\n[HELD-OUT EVAL] mean_reward={eval_summary['mean_reward']:.3f}, "
          f"accuracy={eval_summary['accuracy']:.2%} ({n_eval_correct}/{len(eval_clip_ids)}), "
          f"well_formed={eval_summary['well_formed_rate']:.2%}")
    if eval_summary["pivot_in_eval"] > 0:
        print(f"[HELD-OUT EVAL] pivot accuracy: "
              f"{eval_summary['pivot_correct']}/{eval_summary['pivot_in_eval']}")

    # Save eval results to disk so they get pushed with the rest
    eval_path = Path(args.output_dir) / "held_out_eval.json"
    eval_path.write_text(json.dumps(eval_summary, indent=2))
    print(f"[done] held-out eval saved to {eval_path}")

    # Push the LoRA adapter to HF Hub so it survives the ephemeral container
    if args.push_to_hub:
        _push_adapter(args)
    if args.save_trainer_state_to_hub_space:
        _push_run_artifacts_to_space(args)

    if eval_summary["eval_failed"]:
        print("\n[main] subtext-arena GRPO run finished, but the held-out eval crashed "
              "(see [error] above).")
    else:
        print("\n[main] subtext-arena GRPO run finished cleanly.")
    # Exit 0 explicitly: the eval and upload steps above are best effort and
    # report failures in the log rather than through the exit status.
    sys.exit(0)


if __name__ == "__main__":
    main()