Spaces:
Sleeping
Sleeping
| """Subtext Arena — Unsloth + TRL GRPO training (Option A: single-step CoT classification). | |
| Architecture: | |
| The training task is single-step. For each rollout: | |
| 1. The training script builds a FULL prompt for one MUStARD clip: | |
| system prompt + transcript + prosody features + pitch contour | |
| 2. Model generates ONE completion ending in <final>{label, confidence}</final> | |
| 3. Reward function parses <final>, scores against the gold label | |
| 4. GRPO advantage update | |
| The env (server/) still supports multi-step tool calling — that's our | |
| inference-time interface and what judges interact with on the HF Space. | |
| But for TRAINING, we sidestep TRL's single-shot completion constraint by | |
| pre-rendering all tool outputs into the prompt. The model learns the | |
| conditional reasoning step (prosody features -> sarcasm label) that the | |
| one-shot rollout architecture would otherwise fail to teach. | |
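| Stack (deck-named in requirement #2): | |
| Originally Unsloth FastLanguageModel + TRL GRPOTrainer + LoRA r=16 + 4-bit quant. | |
| main() now loads the model with plain transformers + PEFT instead of Unsloth | |
| (see the model-load comment in main()); TRL GRPOTrainer, LoRA r=16 and 4-bit | |
| quantization are unchanged. | |
| Fits a T4-medium (16 GB) on HF Jobs at $0.60/hr. | |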
| Usage: | |
| python train/train_grpo.py \\ | |
| --model unsloth/Qwen2.5-3B-Instruct \\ | |
| --max-steps 200 \\ | |
| --output-dir ./checkpoints/run1 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import re | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| # Dual import path: works whether this script is run locally (with | |
| # subtext_arena/ on sys.path) or after `pip install` (subtext_arena.* package). | |
| try: | |
| from subtext_arena.server.scenarios import load_scenarios | |
| from subtext_arena.server.audio_tools import ( | |
| render_transcript, | |
| render_prosody_features, | |
| render_pitch_contour, | |
| ) | |
| except ImportError: | |
| ROOT = Path(__file__).resolve().parent.parent | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from server.scenarios import load_scenarios # type: ignore[no-redef] | |
| from server.audio_tools import ( # type: ignore[no-redef] | |
| render_transcript, | |
| render_prosody_features, | |
| render_pitch_contour, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # System prompt (defines the answer grammar) | |
| # --------------------------------------------------------------------------- | |
# System message placed first in every training prompt (build_dataset) and
# every held-out eval prompt (main). The <think>/<final> layout it describes
# is exactly what THINK_RE and FINAL_RE look for when the reward is computed,
# so the prompt and those patterns must be changed together.
SYSTEM_PROMPT = """You are an expert at detecting sarcasm in spoken dialogue.
You will be given:
- the literal transcript of a line from a TV show, plus its conversational context
- acoustic prosody features (pitch, energy, pause patterns)
- the pitch contour over the line
Your job is to decide whether the line is SARCASTIC or SINCERE, by reading
between the lines: the prosodic delivery often flips the meaning of the
literal words.
Format your response EXACTLY like this:
<think>
your reasoning over the prosody and lexical cues, 2-6 sentences
</think>
<final>{"label":"sarcastic"|"sincere","confidence":0.0..1.0}</final>
Rules:
1. Use BOTH the transcript and the prosody features in your reasoning.
2. High pitch variability + emphasis pause + positive words often = sarcastic.
3. Flat affect + no internal pauses + neutral content often = sincere.
4. Confidence should reflect how strongly the cues agree.
"""
| # --------------------------------------------------------------------------- | |
| # Prompt construction (the env's tool outputs, all bundled) | |
| # --------------------------------------------------------------------------- | |
def build_full_observation(clip_id: str, scenarios: Dict[str, Dict[str, Any]]) -> str:
    """Render one clip's complete observation as a single user-message body.

    The text bundles what an interactive agent would gather by calling
    get_transcript, get_prosody_features and get_pitch_contour one after the
    other, so the model receives every tool output up front.
    """
    entry = scenarios[clip_id]
    # A missing or empty "prosody" value falls back to an empty dict.
    features = entry.get("prosody") or {}

    transcript_text = render_transcript(clip_id, scenarios)
    prosody_text = render_prosody_features(clip_id, features)
    contour_text = render_pitch_contour(clip_id, features)

    header = (
        f"Clip {clip_id} (speaker {entry.get('speaker', '?')}, "
        f"duration {features.get('duration_s', 0.0):.2f}s)\n\n"
    )
    sections = (
        f"=== Transcript ===\n{transcript_text}\n\n"
        f"=== Prosody features ===\n{prosody_text}\n\n"
        f"=== Pitch contour ===\n{contour_text}\n\n"
    )
    question = f"Decide: sarcastic or sincere?"
    return header + sections + question
def split_clip_ids(scenarios: Dict[str, Dict[str, Any]], eval_ratio: float = 0.2, seed: int = 42):
    """Split clip ids into disjoint train and eval sets, reproducibly.

    The ids are sorted, shuffled with a ``random.Random(seed)`` generator, and
    the first ``int(len * eval_ratio)`` of them become the eval set; the rest
    form the train set. Pivot clips (entries with a truthy "is_pivot") are not
    treated specially, so they land on both sides; their counts are printed.

    Returns:
        ``(train_ids, eval_ids)`` as two sets of clip ids.
    """
    shuffled = sorted(scenarios)  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(shuffled)

    cut = int(len(shuffled) * eval_ratio)
    eval_ids, train_ids = set(shuffled[:cut]), set(shuffled[cut:])

    def count_pivots(ids):
        return len([cid for cid in ids if scenarios[cid].get("is_pivot")])

    n_train_pivot = count_pivots(train_ids)
    n_eval_pivot = count_pivots(eval_ids)
    print(f"[split] train={len(train_ids)} (pivot={n_train_pivot}) | "
          f"eval={len(eval_ids)} (pivot={n_eval_pivot})")
    return train_ids, eval_ids
def build_dataset(scenarios: Dict[str, Dict[str, Any]], n_rows: int, seed: int = 0,
                  allowed_clip_ids=None, pivot_oversample: int = 2,
                  class_balance: bool = True):
    """Build a HF Dataset of chat prompts, class-balanced by default.

    Why class-balance: Run #2 showed extreme prediction skew (16% sarcastic,
    81% sincere) on held-out, because Pivot Set is 25 sinc + 7 sarc and the
    model learned to default to "sincere" when uncertain. We fix this by:
      - Building separate sarc and sinc pools, each weighted by Pivot oversample
      - Interleaving them 50/50 in the final dataset
      - Reducing pivot_oversample from 3 -> 2 (still emphasized, less dominant)

    Args:
        scenarios: clip id -> clip entry (reads "sarcasm", "is_pivot").
        n_rows: number of prompt rows to emit; pools are cycled to reach it.
        seed: seed for the pool shuffles.
        allowed_clip_ids: if set, only those clips are used (lets the caller
            restrict to the train split so eval clips never touch training).
        pivot_oversample: how many copies of each pivot clip enter its pool.
        class_balance: True alternates sarcastic/sincere rows. False draws
            from one combined pool, which is shuffled so that a short dataset
            is not made of a single class.

    Returns:
        A ``datasets.Dataset`` with columns prompt, clip_id, gold, is_pivot.

    Raises:
        ValueError: if the selected clips lack either class.
    """
    from datasets import Dataset

    rng = random.Random(seed)
    sarc_pool, sinc_pool = [], []
    for cid, entry in scenarios.items():
        if allowed_clip_ids is not None and cid not in allowed_clip_ids:
            continue
        weight = pivot_oversample if entry.get("is_pivot") else 1
        target = sarc_pool if entry.get("sarcasm") else sinc_pool
        target.extend([cid] * weight)
    if not sarc_pool or not sinc_pool:
        raise ValueError(
            f"Need both sarcastic and sincere clips. "
            f"got sarc={len(sarc_pool)}, sinc={len(sinc_pool)}"
        )
    rng.shuffle(sarc_pool)
    rng.shuffle(sinc_pool)
    print(f"[build_dataset] balanced pool: sarc={len(sarc_pool)}, sinc={len(sinc_pool)}, "
          f"pivot_oversample={pivot_oversample}, class_balance={class_balance}")

    if class_balance:
        # Alternate sarc, sinc, sarc, sinc... so the model sees one of each
        # per gradient step (with batch>=2). Row i uses pool slot i // 2.
        picks = [
            sarc_pool[(i // 2) % len(sarc_pool)] if i % 2 == 0
            else sinc_pool[(i // 2) % len(sinc_pool)]
            for i in range(n_rows)
        ]
    else:
        # The plain concatenation is class-sorted (all sarcastic first), so a
        # small n_rows would contain one class only; shuffle before cycling.
        mixed_pool = sarc_pool + sinc_pool
        rng.shuffle(mixed_pool)
        picks = [mixed_pool[i % len(mixed_pool)] for i in range(n_rows)]

    # Clips repeat many times once the pools are cycled; render each once.
    observation_cache: Dict[str, str] = {}
    rows = []
    for cid in picks:
        if cid not in observation_cache:
            observation_cache[cid] = build_full_observation(cid, scenarios)
        rows.append({
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": observation_cache[cid]},
            ],
            "clip_id": cid,
            "gold": "sarcastic" if scenarios[cid]["sarcasm"] else "sincere",
            "is_pivot": bool(scenarios[cid].get("is_pivot", False)),
        })
    return Dataset.from_list(rows)
| # --------------------------------------------------------------------------- | |
| # Reward function (single-step: parse <final>, score against gold) | |
| # --------------------------------------------------------------------------- | |
# Captures the JSON object inside the first <final>...</final> tag; re.S lets
# the payload span several lines.
FINAL_RE = re.compile(r"<final>\s*(\{.*?\})\s*</final>", re.S)
# Captures the reasoning text inside the first <think>...</think> tag.
THINK_RE = re.compile(r"<think>\s*(.*?)\s*</think>", re.S)


def parse_final(text: str):
    """Extract the model's answer from the first ``<final>{...}</final>`` tag.

    Returns:
        ``(label, confidence, has_well_formed_final)``.
        - No tag, unparseable JSON, or a label other than "sarcastic" /
          "sincere" (case-insensitive): ``(None, 0.5, False)``.
        - Otherwise ``(label, confidence, True)`` with the confidence clamped
          to [0.0, 1.0]. A confidence that is missing, non-numeric, NaN, or
          too large to convert falls back to 0.5.

    The text is model output, so this never raises on malformed payloads: an
    exception here would abort the whole training run.
    """
    m = FINAL_RE.search(text)
    if not m:
        return None, 0.5, False
    try:
        payload = json.loads(m.group(1))
    except ValueError:
        # json.JSONDecodeError is a ValueError; so is the int-digit-limit
        # error newer Pythons raise for absurdly long integer literals.
        return None, 0.5, False
    label = str(payload.get("label", "")).strip().lower()
    if label not in {"sarcastic", "sincere"}:
        return None, 0.5, False
    try:
        conf = float(payload.get("confidence", 0.5))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an integer literal too large for a float.
        conf = 0.5
    if conf != conf:
        # NaN (json accepts the bare literal). min()/max() would silently
        # turn it into 1.0, i.e. maximum confidence.
        conf = 0.5
    conf = max(0.0, min(1.0, conf))
    return label, conf, True
def reasoning_length_score(text: str) -> float:
    """Score the length of the ``<think>`` section, in whitespace-split words.

    No ``<think>`` block scores 0.0. Under 50 words the score grows linearly
    (words / 50), which discourages stub reasoning. 50-150 words is the sweet
    spot at 1.0. From 150 to 300 words the score decays linearly down to 0.5,
    and anything longer stays at 0.5 (rambling).
    """
    match = THINK_RE.search(text)
    if match is None:
        return 0.0

    n_words = len(match.group(1).split())
    if n_words > 300:
        return 0.5
    if n_words > 150:
        return 1.0 - (n_words - 150) / 300.0
    if n_words >= 50:
        return 1.0
    return max(0.0, n_words / 50.0)
def make_reward_fn():
    """Return the reward callable handed to GRPOTrainer.

    Reward = 0.70 * correctness + 0.15 * reasoning_length + 0.15 * format.
    correctness component:
      - confidence-weighted: 0.5 + 0.5*conf if correct, 0.5 - 0.5*conf if wrong
      - 0.0 if no valid <final> tag was emitted

    The arithmetic lives in reward_decomposition(); this wrapper only takes
    its "_total", so the training reward and the logged/eval decomposition
    cannot drift apart.
    """
    def reward_fn(prompts, completions, **kwargs) -> List[float]:
        """Score each completion against its row's "gold" label.

        Raises:
            ValueError: if the dataset did not provide a "gold" column.
        """
        # TRL passes any extra dataset columns as kwargs (lists aligned with completions)
        gold_labels = kwargs.get("gold")
        if gold_labels is None:
            raise ValueError("Reward fn requires a 'gold' column in the dataset")
        rewards: List[float] = []
        for completion, gold in zip(completions, gold_labels):
            if isinstance(completion, list):
                # Conversational format: a list of message dicts. An empty
                # list is scored as empty text instead of raising IndexError.
                text = completion[0]["content"] if completion else ""
            else:
                text = str(completion)
            rewards.append(float(reward_decomposition(text, gold)["_total"]))
        return rewards
    return reward_fn
def reward_decomposition(text: str, gold: str) -> Dict[str, float]:
    """Score one completion and return every reward component for logging.

    Uses the same 0.70 / 0.15 / 0.15 weighting as the training reward. Keys
    without an underscore are the weighted components; underscore-prefixed
    keys carry the total and the parsed prediction details.
    """
    label, conf, well_formed = parse_final(text)

    # Only compare against gold when a valid <final> answer was parsed.
    correct = bool(well_formed and label == gold.lower())
    if not well_formed:
        correctness = 0.0
    elif correct:
        correctness = 0.5 + 0.5 * conf
    else:
        correctness = 0.5 - 0.5 * conf

    r_reasoning = reasoning_length_score(text)
    r_format = 1.0 if well_formed else 0.0
    total = 0.70 * correctness + 0.15 * r_reasoning + 0.15 * r_format

    breakdown = {
        "correctness": correctness,
        "reasoning_length": r_reasoning,
        "format": r_format,
    }
    breakdown["_total"] = total
    breakdown["_correct"] = correct
    breakdown["_well_formed"] = well_formed
    breakdown["_predicted"] = label
    breakdown["_confidence"] = conf
    return breakdown
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def _parse_args():
    """Define and parse the command-line options for one training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--output-dir", default="./checkpoints/run1")
    parser.add_argument("--max-steps", type=int, default=200)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--per-device-batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=768)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--seq-length", type=int, default=4096)
    parser.add_argument("--n-train-rows", type=int, default=600)
    parser.add_argument("--data-root", default=None,
                        help="Override path to data/ (defaults to subtext_arena/data/)")
    parser.add_argument("--push-to-hub", default=None,
                        help="If set, e.g. 'aamrinder/subtext-arena-grpo', push the trained LoRA there at the end")
    parser.add_argument("--save-trainer-state-to-hub-space", default=None,
                        help="If set, e.g. 'aamrinder/subtext-arena', upload trainer_state.json + eval_results.json to that Space's data/ dir")
    parser.add_argument("--eval-ratio", type=float, default=0.2,
                        help="Fraction of MUStARD clips held out for generalization eval")
    parser.add_argument("--n-eval-clips", type=int, default=80,
                        help="How many held-out clips to evaluate at the end")
    parser.add_argument("--lora-dropout", type=float, default=0.05,
                        help="LoRA dropout for regularization (helps prevent memorization)")
    return parser.parse_args()


def _run_held_out_eval(model, tokenizer, scenarios, eval_clip_ids, max_new_tokens):
    """Greedy-decode every held-out clip and score it with reward_decomposition.

    A crash part-way through is caught so the caller can still save and upload
    what was collected. All rates are computed over the clips that were
    actually evaluated, and the summary records whether the loop finished.

    Returns:
        The JSON-serializable summary dict (per-clip rows under "results").
    """
    import torch

    results = []
    rewards = []
    n_correct = 0
    n_well_formed = 0
    eval_failed = False
    current = None
    try:
        for i, cid in enumerate(eval_clip_ids):
            current = (i, cid)
            sc = scenarios[cid]
            gold = "sarcastic" if sc["sarcasm"] else "sincere"
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_full_observation(cid, scenarios)},
            ]
            encoded = tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True,
            )
            # apply_chat_template may hand back a bare tensor or an encoding
            # object carrying .input_ids; accept both.
            input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
            input_ids = input_ids.to(model.device)
            prompt_len = input_ids.shape[1]
            with torch.no_grad():
                out = model.generate(
                    input_ids=input_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                    use_cache=True,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
            decomp = reward_decomposition(text, gold)
            results.append({
                "clip_id": cid,
                "gold": gold,
                "is_pivot": bool(sc.get("is_pivot")),
                "predicted": decomp["_predicted"],
                "confidence": decomp["_confidence"],
                "correct": decomp["_correct"],
                "well_formed": decomp["_well_formed"],
                "reward_total": decomp["_total"],
                "completion_text": text[:1500],
            })
            rewards.append(decomp["_total"])
            if decomp["_correct"]:
                n_correct += 1
            if decomp["_well_formed"]:
                n_well_formed += 1
            if (i + 1) % 20 == 0:
                print(f"  [{i+1}/{len(eval_clip_ids)}] running mean reward = {sum(rewards)/len(rewards):.3f}, "
                      f"correct so far = {n_correct}/{i+1}", flush=True)
    except Exception as e:
        # Top-level boundary: keep going so partial results still get saved.
        print(f"[error] held-out eval crashed at clip {current}: {e}")
        eval_failed = True

    n_done = len(results)
    summary = {
        "n_eval_clips": len(eval_clip_ids),
        "n_evaluated": n_done,
        "eval_failed": eval_failed,
        "mean_reward": sum(rewards) / max(1, len(rewards)),
        # Divide by the evaluated count: after a crash, dividing by the
        # planned count would score every unevaluated clip as wrong.
        "well_formed_rate": n_well_formed / max(1, n_done),
        "accuracy": n_correct / max(1, n_done),
        "pivot_in_eval": sum(1 for r in results if r["is_pivot"]),
        "pivot_correct": sum(1 for r in results if r["is_pivot"] and r["correct"]),
        "results": results,
    }
    print(f"\n[HELD-OUT EVAL] mean_reward={summary['mean_reward']:.3f}, "
          f"accuracy={summary['accuracy']:.2%} ({n_correct}/{n_done}), "
          f"well_formed={summary['well_formed_rate']:.2%}")
    if summary["pivot_in_eval"] > 0:
        print(f"[HELD-OUT EVAL] pivot accuracy: {summary['pivot_correct']}/{summary['pivot_in_eval']}")
    if eval_failed:
        print(f"[warn] held-out eval stopped early: {n_done}/{len(eval_clip_ids)} clips evaluated")
    return summary


def _upload_run_artifacts(args):
    """Best-effort upload of the adapter folder and run JSONs to the HF Hub.

    Upload failures are printed, never raised, so a network problem cannot
    discard a finished training run.
    """
    # Push the LoRA adapter to HF Hub so it survives the ephemeral container
    if args.push_to_hub:
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
            api.upload_folder(
                folder_path=args.output_dir,
                repo_id=args.push_to_hub,
                repo_type="model",
                commit_message=f"GRPO Run #1 ({args.max_steps} steps, lr={args.learning_rate})",
            )
            print(f"[done] LoRA adapter pushed to https://huggingface.co/{args.push_to_hub}")
        except Exception as e:
            print(f"[error] push_to_hub failed: {e}")

    # Push trainer_state.json + log_history.json + held_out_eval.json to HF Space.
    # Each upload is wrapped individually so a partial network failure doesn't
    # kill the whole script. We need at least the held_out_eval JSON to land.
    if args.save_trainer_state_to_hub_space:
        from huggingface_hub import HfApi
        repo_id = args.save_trainer_state_to_hub_space
        out_dir = Path(args.output_dir)
        run_tag = out_dir.name  # e.g. "run3"
        api = HfApi()
        for local_name, hub_name, label in [
            ("trainer_state.json", f"data/trainer_state_{run_tag}.json", "trainer_state"),
            ("log_history.json", f"data/log_history_{run_tag}.json", "log_history"),
            ("held_out_eval.json", f"data/held_out_eval_{run_tag}.json", "held_out_eval"),
        ]:
            path = out_dir / local_name
            if not path.exists():
                print(f"[warn] {local_name} not found at {path}, skipping upload")
                continue
            try:
                api.upload_file(
                    path_or_fileobj=str(path),
                    path_in_repo=hub_name,
                    repo_id=repo_id,
                    repo_type="space",
                    commit_message=f"GRPO {run_tag} {label} ({args.max_steps} steps)",
                )
                print(f"[done] {label} pushed to {repo_id}/{hub_name}")
            except Exception as e:
                print(f"[error] upload {label} failed: {e}")


def main():
    """Run one GRPO job end to end: data, training, held-out eval, uploads.

    Steps: parse CLI args, load scenarios and split them into train/eval,
    build the prompt dataset, load the 4-bit model with a LoRA adapter, train
    with TRL's GRPOTrainer, save the adapter and logs, evaluate on held-out
    clips, then optionally upload artifacts. The process always exits with
    status 0 once training has finished, even if the held-out eval stopped
    early (that is reported in held_out_eval.json and in the final message).
    """
    args = _parse_args()

    # Load scenarios up front so any data issues fail fast (before model load)
    data_root = Path(args.data_root) if args.data_root else None
    scenarios = load_scenarios(data_root)
    print(f"[data] {len(scenarios)} MUStARD clips loaded; "
          f"{sum(1 for s in scenarios.values() if s.get('is_pivot'))} marked Pivot")

    # CRITICAL: train/eval split so we can prove generalization (not memorization).
    # Eval clips are NEVER seen by the model during training.
    train_ids, eval_ids = split_clip_ids(scenarios, eval_ratio=args.eval_ratio, seed=42)
    dataset = build_dataset(scenarios, n_rows=args.n_train_rows, allowed_clip_ids=train_ids)
    print(f"[data] {len(dataset)} train prompt rows from {len(train_ids)} unique train clips")

    # Model load via plain transformers + PEFT (deck-compliant: training uses HF TRL).
    # We dropped Unsloth because their fast_lora kernel has a Half/Float dtype
    # mismatch on Qwen2.5-3B + 4-bit + bf16 in v2026.4.8 (verified via failed
    # smoke runs on L4). Plain transformers+peft+trl is slower but reliable.
    import torch as _t
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from trl import GRPOTrainer, GRPOConfig

    print(f"[load] {args.model}, 4-bit, max_seq_length={args.seq_length}")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=_t.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    # Strip the "unsloth/" prefix if the user passed an Unsloth-prefixed name —
    # we now load directly from the upstream Qwen repo.
    model_name = args.model.replace("unsloth/", "Qwen/").replace("-Instruct-bnb-4bit", "-Instruct")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb,
        dtype=_t.bfloat16,
        device_map="auto",
    )
    base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)
    peft_config = LoraConfig(
        r=args.lora_r, lora_alpha=args.lora_r, lora_dropout=args.lora_dropout, bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, peft_config)

    # NOTE(review): TRL's GRPO validates that the generation batch is divisible
    # by num_generations; the defaults here (batch 1, 4 generations) may be
    # rejected depending on the installed TRL version — confirm.
    config = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        per_device_train_batch_size=args.per_device_batch_size,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        logging_steps=1,
        save_steps=50,
        save_total_limit=4,
        bf16=True,
        report_to=("wandb" if os.environ.get("WANDB_API_KEY") else "none"),
        gradient_checkpointing=True,
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=make_reward_fn(),
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    trainer.train()

    # CRITICAL: save_state() writes trainer_state.json to output_dir.
    # save_model() alone only saves the adapter weights, NOT the per-step log.
    # In Run #2, we missed save_state() and lost the reward history that drives the plot.
    trainer.save_state()
    trainer.save_model(args.output_dir)
    # Also explicitly write the log_history to a JSON we know we can find.
    try:
        log_path = Path(args.output_dir) / "log_history.json"
        log_path.write_text(json.dumps(trainer.state.log_history, indent=2))
        print(f"[done] log_history saved to {log_path} ({len(trainer.state.log_history)} entries)")
    except Exception as e:
        print(f"[warn] couldn't write log_history.json: {e}")
    print(f"[done] checkpoint saved to {args.output_dir}")

    # ----- HELD-OUT GENERALIZATION EVAL -----
    # Run trained model on `n_eval_clips` held-out clips that were NEVER in training.
    # If reward on these is much lower than training reward → memorization.
    # If reward is comparable → real learning.
    eval_clip_ids = sorted(eval_ids)[: args.n_eval_clips]
    print(f"\n[held-out-eval] running trained model on {len(eval_clip_ids)} held-out clips")
    model.eval()
    if hasattr(model, "gradient_checkpointing_disable"):
        try:
            model.gradient_checkpointing_disable()
        except Exception as e:
            # Not fatal: generation still runs, so report and continue.
            print(f"[warn] couldn't disable gradient checkpointing: {e}")
    eval_summary = _run_held_out_eval(
        model, tokenizer, scenarios, eval_clip_ids, args.max_completion_length,
    )

    # Save eval results to disk so they get pushed with the rest
    eval_path = Path(args.output_dir) / "held_out_eval.json"
    eval_path.write_text(json.dumps(eval_summary, indent=2))
    print(f"[done] held-out eval saved to {eval_path}")

    _upload_run_artifacts(args)

    if eval_summary["eval_failed"]:
        print("\n[main] subtext-arena GRPO run finished, but the held-out eval did not complete.")
    else:
        print("\n[main] subtext-arena GRPO run finished cleanly.")
    sys.exit(0)
| if __name__ == "__main__": | |
| main() | |