sync Colab notebook with current train_grpo.py
notebooks/train_grpo_colab.ipynb (CHANGED: +168 -125)
@@ -3,52 +3,24 @@
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "# Subtext Arena — GRPO training (Colab-runnable)\n",
-    "\n",
-    "Re-runnable notebook for judges. Trains a Qwen2.5-3B-Instruct policy with **Unsloth + TRL `GRPOTrainer`** on the Subtext Arena task.\n",
-    "\n",
-    "**Architecture (Option A — single-step CoT classification)**\n",
-    "\n",
-    "Each training rollout:\n",
-    " 1. We build ONE prompt for one MUStARD clip — system + transcript + prosody features + pitch contour, all in the user message.\n",
-    " 2. The model emits ONE completion: `<think>...</think><final>{\"label\":\"sarcastic\"|\"sincere\",\"confidence\":0..1}</final>`\n",
-    " 3. Reward = 0.70 · correctness (confidence-weighted) + 0.15 · reasoning_length + 0.15 · format.\n",
-    " 4. GRPO updates LoRA weights from the group-relative advantage.\n",
-    "\n",
-    "The Subtext Arena env still supports multi-step tool calling at inference time — that's our HF Space demo. But for *training* we sidestep TRL's single-shot generate-then-score constraint by pre-rendering the tool outputs into the prompt. This is the same pattern as the deck's Wordle / Sudoku notebooks.\n",
-    "\n",
-    "**Stack** (deck-named, requirement #2): Unsloth + TRL. T4-medium fits.\n",
-    "**Estimated runtime**: ~12 hours for 200 GRPO steps on T4-medium ($0.60/hr × 12 ≈ $8)."
-   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "## 1. Install dependencies\n",
-    "\n",
-    "Replace `aamrinder` with your HF username after pushing the env to a Space."
-   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "!pip install -q --upgrade unsloth \"trl>=0.11\" \"transformers>=4.46\" peft datasets accelerate matplotlib\n",
-    "!pip install -q git+https://huggingface.co/spaces/aamrinder/subtext-arena\n",
-    "import torch\n",
-    "print('CUDA:', torch.cuda.is_available(), '|', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only')"
-   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "## 2. Load Qwen2.5-3B-Instruct with Unsloth (4-bit + LoRA)"
-   ]
   },
   {
    "cell_type": "code",
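The removed header above describes the Option A pattern: what would be multi-step tool outputs at inference time (transcript, prosody features, pitch contour) are pre-rendered into a single user message so one TRL generate-then-score pass covers a full rollout. A minimal sketch of that flattening, with illustrative field names that are not the repo's actual schema:

```python
# Hypothetical sketch of the "pre-render tool outputs into the prompt" pattern.
# Field names ('transcript', 'prosody', 'pitch_contour') are illustrative only.
def build_single_step_prompt(clip):
    # Flatten every observation the model would otherwise fetch via tools
    # into one user message, so a single completion is a full rollout.
    observation = (
        f"Transcript:\n{clip['transcript']}\n\n"
        f"Prosody features:\n{clip['prosody']}\n\n"
        f"Pitch contour:\n{clip['pitch_contour']}"
    )
    return [
        {"role": "system", "content": "Classify the clip as sarcastic or sincere."},
        {"role": "user", "content": observation},
    ]

msgs = build_single_step_prompt({
    "transcript": "Oh, GREAT. Another meeting.",
    "prosody": "pitch range 180Hz, pre-utterance pause 320ms",
    "pitch_contour": "rise-fall-rise",
})
```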
@@ -56,29 +28,25 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    … (removed line truncated in this view)
-    "\n",
-    … (removed lines truncated in this view)
-    ")"
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "## 3. Build the training dataset\n",
-    "\n",
-    "Each row is one MUStARD clip's full briefing (transcript + prosody summary + pitch contour) wrapped as a chat prompt. The Pivot Set is oversampled 3×."
-   ]
   },
   {
    "cell_type": "code",
@@ -86,36 +54,25 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    … (removed lines truncated in this view)
-    " parse_final, reasoning_length_score, make_reward_fn,\n",
    ")\n",
-    "from server.scenarios import load_scenarios\n",
    "\n",
    "scenarios = load_scenarios()\n",
-    … (removed line truncated in this view)
-    "print(f'Loaded {len(scenarios)} clips ({n_pivot} marked Pivot Set)')\n",
    "\n",
-    … (removed lines truncated in this view)
-    "print( … (truncated in this view)
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "## 4. Reward function — single scalar per completion\n",
-    "\n",
-    "Parses `<final>{label, confidence}</final>` from the completion, scores against the gold label from the dataset row.\n",
-    "\n",
-    "Reward components (all in [0, 1]):\n",
-    "- **correctness** (weight 0.70): `0.5 + 0.5 × confidence` if label matches gold, `0.5 - 0.5 × confidence` if wrong, `0.0` if no valid `<final>` tag.\n",
-    "- **reasoning_length** (weight 0.15): incentivizes 50-150-word `<think>` blocks; penalizes <30 (lazy) and >300 (rambling).\n",
-    "- **format** (weight 0.15): 1.0 if `<final>` tag has parseable JSON with valid label, else 0."
-   ]
   },
   {
    "cell_type": "code",
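The reward components listed in the removed cell above combine into one scalar per completion. The sketch below follows those stated weights and formulas; the parsing helpers and the partial credit given between the 30/50 and 150/300 word bands are assumptions, not the repo's actual `make_reward_fn`:

```python
import json
import re

# Weights as stated in the notebook text.
W_CORRECT, W_REASONING, W_FORMAT = 0.70, 0.15, 0.15

FINAL_RE = re.compile(r"<final>(.*?)</final>", re.DOTALL)
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def reward(completion: str, gold: str) -> float:
    # format: 1.0 iff <final> holds parseable JSON with a valid label
    label, conf, fmt = None, 0.0, 0.0
    m = FINAL_RE.search(completion)
    if m:
        try:
            obj = json.loads(m.group(1))
            if obj.get("label") in ("sarcastic", "sincere"):
                label = obj["label"]
                conf = float(obj.get("confidence", 0.0))
                fmt = 1.0
        except (ValueError, TypeError):
            pass
    # correctness: confidence-weighted, 0.0 without a valid <final> tag
    if label is None:
        correct = 0.0
    elif label == gold:
        correct = 0.5 + 0.5 * conf
    else:
        correct = 0.5 - 0.5 * conf
    # reasoning_length: full credit for 50-150-word <think> blocks,
    # zero below 30 or above 300 words
    t = THINK_RE.search(completion)
    n_words = len(t.group(1).split()) if t else 0
    if 50 <= n_words <= 150:
        reasoning = 1.0
    elif n_words < 30 or n_words > 300:
        reasoning = 0.0
    else:
        reasoning = 0.5  # assumed partial credit in the in-between bands
    return W_CORRECT * correct + W_REASONING * reasoning + W_FORMAT * fmt
```

A correct, well-formed completion with confidence 0.8 and a 60-word think block scores 0.70 * 0.9 + 0.15 + 0.15 = 0.93 under these weights.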
@@ -124,22 +81,18 @@
    "outputs": [],
    "source": [
    "reward_fn = make_reward_fn()\n",
-    "\n",
-    … (removed lines truncated in this view)
-    " [{'role':'assistant','content': '<think>Pitch HIGH, pre-pause 320ms, positive lexical content with exaggerated melody — classic sarcasm signature.</think><final>{\"label\":\"sarcastic\",\"confidence\":0.85}</final>'}],\n",
-    " [{'role':'assistant','content': '<think>Flat affect on neutral content, low pitch variability.</think><final>{\"label\":\"sincere\",\"confidence\":0.65}</final>'}],\n",
    "]\n",
-    "rewards = reward_fn(prompts=None, completions= … (truncated in this view)
-    "print('Synthetic rewards:', rewards) # should be … (truncated in this view)
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "## 5. Run GRPO training (200 steps, ~12 h on T4-medium)"
-   ]
   },
   {
    "cell_type": "code",
@@ -147,41 +100,48 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    … (removed line truncated in this view)
    "\n",
-    … (removed lines truncated in this view)
-    " save_steps=50,\n",
-    " save_total_limit=4,\n",
-    " bf16=True,\n",
-    " report_to='none',\n",
-    " gradient_checkpointing=True,\n",
-    " ),\n",
-    " train_dataset=ds,\n",
-    " processing_class=tokenizer,\n",
    ")\n",
-    … (removed lines truncated in this view)
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    … (removed lines truncated in this view)
   },
   {
    "cell_type": "code",
@@ -189,41 +149,124 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    … (removed lines truncated in this view)
    "\n",
-    … (removed lines truncated in this view)
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    … (removed line truncated in this view)
-    "\n",
-    … (removed line truncated in this view)
   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    … (removed line truncated in this view)
-    "from … (truncated in this view)
-    … (removed line truncated in this view)
   ]
   }
  ],
  "metadata": {
-  "kernelspec": {
-   … (removed lines truncated in this view)
  },
  "nbformat": 4,
  "nbformat_minor": 5
- }
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "# Subtext Arena — GRPO training\n\nThis is the actual training that produced the README numbers — reward 0.33 → 0.97 on training, 51% on the broad held-out set, 5/6 on the Pivot Set.\n\nThe notebook imports functions straight from `subtext_arena.train.train_grpo` so it stays in sync with the script. Set the config below, run all cells. Around 2 hours on a Colab L4, ~$1.60."
   },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Install"
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
+   "source": "!pip install -q \"trl>=0.11\" \"transformers>=4.46\" peft datasets accelerate bitsandbytes huggingface_hub\n!pip install -q git+https://huggingface.co/spaces/aamrinder/subtext-arena\nimport torch\nprint('CUDA:', torch.cuda.is_available(), '|', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only')"
   },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Config\n\nSame args as the CLI script. Tweak `MAX_STEPS` if you're on a smaller GPU."
   },
   {
    "cell_type": "code",
    "metadata": {},
    "outputs": [],
    "source": [
+    "MODEL = 'Qwen/Qwen2.5-3B-Instruct'\n",
+    "OUTPUT_DIR = './checkpoints/run1'\n",
+    "MAX_STEPS = 200\n",
+    "NUM_GENERATIONS = 4\n",
+    "PER_DEVICE_BATCH = 4 # must be divisible by NUM_GENERATIONS\n",
+    "LEARNING_RATE = 5e-6\n",
+    "MAX_COMPLETION_LEN = 768\n",
+    "LORA_R = 16\n",
+    "LORA_DROPOUT = 0.05\n",
+    "N_TRAIN_ROWS = 600\n",
+    "EVAL_RATIO = 0.2\n",
+    "N_EVAL_CLIPS = 80\n",
+    "PUSH_TO_HUB = None # e.g. 'your-username/subtext-arena-grpo' (or None to skip)"
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Load data + train/eval split\n\nEval clips never appear in training. Split is seeded so it's reproducible."
   },
   {
    "cell_type": "code",
    "metadata": {},
    "outputs": [],
    "source": [
+    "from subtext_arena.server.scenarios import load_scenarios\n",
+    "from subtext_arena.train.train_grpo import (\n",
+    " SYSTEM_PROMPT, build_full_observation, split_clip_ids, build_dataset,\n",
+    " parse_final, reasoning_length_score, make_reward_fn, reward_decomposition,\n",
    ")\n",
    "\n",
    "scenarios = load_scenarios()\n",
+    "print(f'Loaded {len(scenarios)} clips ({sum(1 for s in scenarios.values() if s.get(\"is_pivot\"))} marked Pivot)')\n",
    "\n",
+    "train_ids, eval_ids = split_clip_ids(scenarios, eval_ratio=EVAL_RATIO, seed=42)\n",
+    "dataset = build_dataset(scenarios, n_rows=N_TRAIN_ROWS, allowed_clip_ids=train_ids)\n",
+    "print(f'{len(dataset)} train prompt rows from {len(train_ids)} unique train clips')\n",
+    "print('Sample prompt:', dataset[0]['prompt'][1]['content'][:300], '...')"
   ]
   },
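The added cell above calls `split_clip_ids(scenarios, eval_ratio=EVAL_RATIO, seed=42)` but its internals live in the repo. A hypothetical seeded clip-level split with the same contract (disjoint train/eval ids, identical output for a fixed seed) could look like this; the function name `split_ids` and its internals are illustrative:

```python
import random

# Hypothetical sketch of a seeded clip-level split; not the repo's
# actual split_clip_ids implementation.
def split_ids(clip_ids, eval_ratio=0.2, seed=42):
    ids = sorted(clip_ids)            # sort first so the shuffle is deterministic
    random.Random(seed).shuffle(ids)  # seeded shuffle -> reproducible split
    n_eval = max(1, int(len(ids) * eval_ratio))
    # train and eval are disjoint by construction: no clip leaks across the split
    return ids[n_eval:], ids[:n_eval]

train, ev = split_ids([f"clip_{i}" for i in range(100)], eval_ratio=0.2, seed=42)
```

Splitting at the clip level (rather than the prompt-row level) matters because `build_dataset` can emit multiple rows per clip; a row-level split would leak clips into eval.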
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Reward sanity check\n\nQuick check that the reward function scores synthetic completions correctly before burning GPU time."
   },
   {
    "cell_type": "code",
    "outputs": [],
    "source": [
    "reward_fn = make_reward_fn()\n",
+    "fake = [\n",
+    " [{'role':'assistant','content': '<think>Pitch range is 180Hz — wide for a single line. Pre-utterance silence is 320ms which suggests deliberate emphasis. The literal words are positive but the prosodic delivery is exaggerated, classic sarcasm signature in TV dialogue.</think><final>{\"label\":\"sarcastic\",\"confidence\":0.85}</final>'}],\n",
+    " [{'role':'assistant','content': '<think>Flat affect, narrow pitch range, no internal pauses. Content is neutral and matches delivery. No prosodic-lexical mismatch.</think><final>{\"label\":\"sincere\",\"confidence\":0.65}</final>'}],\n",
    "]\n",
+    "rewards = reward_fn(prompts=None, completions=fake, gold=['sarcastic','sincere'])\n",
+    "print('Synthetic rewards:', rewards) # both should be > 0.85"
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Load Qwen2.5-3B in 4-bit + LoRA"
   },
   {
    "cell_type": "code",
    "metadata": {},
    "outputs": [],
    "source": [
+    "import torch as _t\n",
+    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
+    "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "\n",
+    "bnb = BitsAndBytesConfig(\n",
+    " load_in_4bit=True,\n",
+    " bnb_4bit_compute_dtype=_t.bfloat16,\n",
+    " bnb_4bit_quant_type='nf4',\n",
+    " bnb_4bit_use_double_quant=True,\n",
+    ")\n",
+    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
+    "if tokenizer.pad_token is None:\n",
+    " tokenizer.pad_token = tokenizer.eos_token\n",
+    "base = AutoModelForCausalLM.from_pretrained(\n",
+    " MODEL, quantization_config=bnb, dtype=_t.bfloat16, device_map='auto',\n",
    ")\n",
+    "base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)\n",
+    "peft_config = LoraConfig(\n",
+    " r=LORA_R, lora_alpha=LORA_R, lora_dropout=LORA_DROPOUT, bias='none',\n",
+    " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
+    " task_type='CAUSAL_LM',\n",
+    ")\n",
+    "model = get_peft_model(base, peft_config)\n",
+    "model.print_trainable_parameters()"
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Train"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "import os, json\nfrom pathlib import Path\nfrom trl import GRPOTrainer, GRPOConfig\n\nconfig = GRPOConfig(\n output_dir=OUTPUT_DIR,\n num_generations=NUM_GENERATIONS,\n max_completion_length=MAX_COMPLETION_LEN,\n per_device_train_batch_size=PER_DEVICE_BATCH,\n learning_rate=LEARNING_RATE,\n max_steps=MAX_STEPS,\n logging_steps=1,\n save_steps=50,\n save_total_limit=4,\n bf16=True,\n report_to=('wandb' if os.environ.get('WANDB_API_KEY') else 'none'),\n gradient_checkpointing=True,\n)\ntrainer = GRPOTrainer(\n model=model,\n reward_funcs=make_reward_fn(),\n args=config,\n train_dataset=dataset,\n processing_class=tokenizer,\n)\ntrainer.train()\n\ntrainer.save_state()\ntrainer.save_model(OUTPUT_DIR)\nPath(OUTPUT_DIR, 'log_history.json').write_text(json.dumps(trainer.state.log_history, indent=2))\nprint(f'saved to {OUTPUT_DIR}')"
+  },
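`GRPOTrainer` computes the group-relative advantage internally; the old notebook header summarized it as "GRPO updates LoRA weights from the group-relative advantage". The sketch below only illustrates that signal: the `NUM_GENERATIONS` completions sampled for one prompt are scored, and each reward is standardized against its own group's mean and standard deviation, so no separate value model is needed.

```python
# Illustrative sketch of GRPO's group-relative advantage; TRL handles this
# internally, so nothing here is called by the notebook.
def group_advantages(rewards, eps=1e-4):
    # Standardize each reward against its own group's statistics.
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# One prompt, NUM_GENERATIONS=4 sampled completions, scalar rewards:
adv = group_advantages([0.93, 0.15, 0.80, 0.15])
# completions above the group mean get positive advantage, below get negative
```

This is why reward shaping (format and reasoning-length terms) matters for GRPO: if every completion in a group gets the same reward, the group advantages collapse to zero and that prompt contributes no gradient.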
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## Held-out eval\n\nGreedy decoding on 80 unseen clips. Run #3 landed at 51% broad accuracy and 5/6 on the Pivot subset."
   },
   {
    "cell_type": "code",
    "metadata": {},
    "outputs": [],
    "source": [
+    "model.eval()\n",
+    "if hasattr(model, 'gradient_checkpointing_disable'):\n",
+    " try: model.gradient_checkpointing_disable()\n",
+    " except Exception: pass\n",
    "\n",
+    "eval_clip_ids = sorted(eval_ids)[:N_EVAL_CLIPS]\n",
+    "results, eval_rewards = [], []\n",
+    "n_correct = n_well_formed = 0\n",
+    "for i, cid in enumerate(eval_clip_ids):\n",
+    " sc = scenarios[cid]\n",
+    " gold = 'sarcastic' if sc['sarcasm'] else 'sincere'\n",
+    " messages = [\n",
+    " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
+    " {'role': 'user', 'content': build_full_observation(cid, scenarios)},\n",
+    " ]\n",
+    " encoded = tokenizer.apply_chat_template(messages, return_tensors='pt', add_generation_prompt=True)\n",
+    " input_ids = (encoded.input_ids if hasattr(encoded, 'input_ids') else encoded).to(model.device)\n",
+    " prompt_len = input_ids.shape[1]\n",
+    " with _t.no_grad():\n",
+    " out = model.generate(\n",
+    " input_ids=input_ids, max_new_tokens=MAX_COMPLETION_LEN,\n",
+    " do_sample=False, pad_token_id=tokenizer.eos_token_id, use_cache=True,\n",
+    " )\n",
+    " text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)\n",
+    " decomp = reward_decomposition(text, gold)\n",
+    " results.append({\n",
+    " 'clip_id': cid, 'gold': gold, 'is_pivot': bool(sc.get('is_pivot')),\n",
+    " 'predicted': decomp['_predicted'], 'confidence': decomp['_confidence'],\n",
+    " 'correct': decomp['_correct'], 'well_formed': decomp['_well_formed'],\n",
+    " 'reward_total': decomp['_total'], 'completion_text': text[:1500],\n",
+    " })\n",
+    " eval_rewards.append(decomp['_total'])\n",
+    " if decomp['_correct']: n_correct += 1\n",
+    " if decomp['_well_formed']: n_well_formed += 1\n",
+    " if (i + 1) % 20 == 0:\n",
+    " print(f' [{i+1}/{len(eval_clip_ids)}] running mean reward = {sum(eval_rewards)/len(eval_rewards):.3f}, '\n",
+    " f'correct so far = {n_correct}/{i+1}', flush=True)\n",
+    "\n",
+    "n_pivot = sum(1 for r in results if r['is_pivot'])\n",
+    "n_pivot_correct = sum(1 for r in results if r['is_pivot'] and r['correct'])\n",
+    "summary = {\n",
+    " 'n_eval_clips': len(eval_clip_ids),\n",
+    " 'mean_reward': sum(eval_rewards) / max(1, len(eval_rewards)),\n",
+    " 'well_formed_rate': n_well_formed / max(1, len(eval_clip_ids)),\n",
+    " 'accuracy': n_correct / max(1, len(eval_clip_ids)),\n",
+    " 'pivot_in_eval': n_pivot,\n",
+    " 'pivot_correct': n_pivot_correct,\n",
+    " 'results': results,\n",
+    "}\n",
+    "Path(OUTPUT_DIR, 'held_out_eval.json').write_text(json.dumps(summary, indent=2))\n",
+    "print(f\"\\nHELD-OUT: mean_reward={summary['mean_reward']:.3f}, accuracy={summary['accuracy']:.2%} ({n_correct}/{len(eval_clip_ids)})\")\n",
+    "print(f\" pivot accuracy: {n_pivot_correct}/{n_pivot}\")"
   ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": "## Reward curve"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
+    "import matplotlib.pyplot as plt\n",
+    "log = json.loads(Path(OUTPUT_DIR, 'log_history.json').read_text())\n",
+    "steps = [e['step'] for e in log if 'reward' in e]\n",
+    "rewards = [e['reward'] for e in log if 'reward' in e]\n",
+    "plt.figure(figsize=(8, 4))\n",
+    "plt.plot(steps, rewards, alpha=0.4, label='per-step')\n",
+    "if len(rewards) >= 10:\n",
+    " import numpy as np\n",
+    " ema = np.array(rewards)\n",
+    " for i in range(1, len(ema)):\n",
+    " ema[i] = 0.9 * ema[i-1] + 0.1 * ema[i]\n",
+    " plt.plot(steps, ema, linewidth=2, label='EMA(0.9)')\n",
+    "plt.xlabel('GRPO step'); plt.ylabel('reward'); plt.legend(); plt.grid(alpha=0.3)\n",
+    "plt.title('Subtext Arena — GRPO training reward')\n",
+    "plt.tight_layout(); plt.show()"
   ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## Push to Hub (optional)\n\nSet `PUSH_TO_HUB` at the top, then `huggingface-cli login` first or set `HF_TOKEN` in Colab secrets."
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
+    "if PUSH_TO_HUB:\n",
+    " from huggingface_hub import HfApi\n",
+    " api = HfApi()\n",
+    " api.create_repo(repo_id=PUSH_TO_HUB, repo_type='model', exist_ok=True)\n",
+    " api.upload_folder(\n",
+    " folder_path=OUTPUT_DIR, repo_id=PUSH_TO_HUB, repo_type='model',\n",
+    " commit_message=f'GRPO ({MAX_STEPS} steps, lr={LEARNING_RATE})',\n",
+    " )\n",
+    " print(f'LoRA pushed to https://huggingface.co/{PUSH_TO_HUB}')\n",
+    "else:\n",
+    " print('PUSH_TO_HUB is None — skipping')"
   ]
   }
  ],
  "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10"
+  }
  },
  "nbformat": 4,
  "nbformat_minor": 5
+ }