aamrinder committed on
Commit 9bd1f77 · verified · 1 Parent(s): 4cdc991

sync Colab notebook with current train_grpo.py

Files changed (1)
  1. notebooks/train_grpo_colab.ipynb +168 -125
notebooks/train_grpo_colab.ipynb CHANGED
@@ -3,52 +3,24 @@
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "# Subtext Arena — GRPO training (Colab-runnable)\n",
- "\n",
- "Re-runnable notebook for judges. Trains a Qwen2.5-3B-Instruct policy with **Unsloth + TRL `GRPOTrainer`** on the Subtext Arena task.\n",
- "\n",
- "**Architecture (Option A — single-step CoT classification)**\n",
- "\n",
- "Each training rollout:\n",
- " 1. We build ONE prompt for one MUStARD clip — system + transcript + prosody features + pitch contour, all in the user message.\n",
- " 2. The model emits ONE completion: `<think>...</think><final>{\"label\":\"sarcastic\"|\"sincere\",\"confidence\":0..1}</final>`\n",
- " 3. Reward = 0.70 · correctness (confidence-weighted) + 0.15 · reasoning_length + 0.15 · format.\n",
- " 4. GRPO updates LoRA weights from the group-relative advantage.\n",
- "\n",
- "The Subtext Arena env still supports multi-step tool calling at inference time — that's our HF Space demo. But for *training* we sidestep TRL's single-shot generate-then-score constraint by pre-rendering the tool outputs into the prompt. This is the same pattern as the deck's Wordle / Sudoku notebooks.\n",
- "\n",
- "**Stack** (deck-named, requirement #2): Unsloth + TRL. T4-medium fits.\n",
- "**Estimated runtime**: ~12 hours for 200 GRPO steps on T4-medium ($0.60/hr × 12 ≈ $8)."
- ]
  },
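
To make steps 1-2 of the rollout list above concrete, here is what one training example looks like. The transcript and prosody values are invented, and `SYSTEM_PROMPT` here is a placeholder for the real constant in `train/train_grpo.py`:

```python
# One rollout under the single-step setup: everything the agent would have
# gathered via tools is pre-rendered into the user message.
SYSTEM_PROMPT = 'You judge sarcasm from transcript + prosody evidence.'  # placeholder

messages = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': (
        'TRANSCRIPT: "Oh, great. Another meeting."\n'
        'PROSODY: pitch range 180 Hz, pre-utterance pause 320 ms\n'
        'PITCH CONTOUR: exaggerated rise-fall peaking on "great"'
    )},
]

# The single completion the policy emits; only this shape earns format reward:
completion = (
    '<think>Positive words delivered with exaggerated melody and a long '
    'pre-pause: a prosodic-lexical mismatch.</think>'
    '<final>{"label":"sarcastic","confidence":0.85}</final>'
)
```
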
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "## 1. Install dependencies\n",
- "\n",
- "Replace `aamrinder` with your HF username after pushing the env to a Space."
- ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": [
- "!pip install -q --upgrade unsloth \"trl>=0.11\" \"transformers>=4.46\" peft datasets accelerate matplotlib\n",
- "!pip install -q git+https://huggingface.co/spaces/aamrinder/subtext-arena\n",
- "import torch\n",
- "print('CUDA:', torch.cuda.is_available(), '|', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only')"
- ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "## 2. Load Qwen2.5-3B-Instruct with Unsloth (4-bit + LoRA)"
- ]
  },
  {
  "cell_type": "code",
@@ -56,29 +28,25 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "from unsloth import FastLanguageModel\n",
- "\n",
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
- " model_name='unsloth/Qwen2.5-3B-Instruct',\n",
- " max_seq_length=4096,\n",
- " load_in_4bit=True,\n",
- ")\n",
- "model = FastLanguageModel.get_peft_model(\n",
- " model,\n",
- " r=16, lora_alpha=16,\n",
- " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
- " use_gradient_checkpointing='unsloth',\n",
- ")"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "## 3. Build the training dataset\n",
- "\n",
- "Each row is one MUStARD clip's full briefing (transcript + prosody summary + pitch contour) wrapped as a chat prompt. The Pivot Set is oversampled 3×."
- ]
  },
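
A minimal sketch of the 3× Pivot Set oversampling described in the cell above, assuming `scenarios` maps clip ids to dicts with an `is_pivot` flag; the real sampling lives inside `build_dataset`, and this helper name is made up:

```python
import random

def sample_training_clips(scenarios, n_rows=600, seed=0):
    """Sample clip ids for training rows, with Pivot clips weighted 3x."""
    rng = random.Random(seed)
    pool = []
    for clip_id, sc in scenarios.items():
        # Pivot clips enter the sampling pool three times, everything else once.
        pool.extend([clip_id] * (3 if sc.get('is_pivot') else 1))
    return [rng.choice(pool) for _ in range(n_rows)]
```
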
  {
  "cell_type": "code",
@@ -86,36 +54,25 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Reuses the env's audio_tools + scenarios same prompt format that an\n",
90
- "# interactive agent would see if it called all the tools in sequence.\n",
- "from train.train_grpo import (\n",
- " SYSTEM_PROMPT, build_full_observation, build_dataset,\n",
- " parse_final, reasoning_length_score, make_reward_fn,\n",
  ")\n",
- "from server.scenarios import load_scenarios\n",
  "\n",
  "scenarios = load_scenarios()\n",
- "n_pivot = sum(1 for s in scenarios.values() if s.get('is_pivot'))\n",
- "print(f'Loaded {len(scenarios)} clips ({n_pivot} marked Pivot Set)')\n",
  "\n",
- "ds = build_dataset(scenarios, n_rows=600, seed=0)\n",
- "print(f'Built {len(ds)} training prompts. Sample row:')\n",
- "print(ds[0])"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "## 4. Reward function — single scalar per completion\n",
- "\n",
- "Parses `<final>{label, confidence}</final>` from the completion, scores against the gold label from the dataset row.\n",
- "\n",
- "Reward components (all in [0, 1]):\n",
- "- **correctness** (weight 0.70): `0.5 + 0.5 × confidence` if label matches gold, `0.5 - 0.5 × confidence` if wrong, `0.0` if no valid `<final>` tag.\n",
- "- **reasoning_length** (weight 0.15): incentivizes 50-150-word `<think>` blocks; penalizes <30 (lazy) and >300 (rambling).\n",
- "- **format** (weight 0.15): 1.0 if `<final>` tag has parseable JSON with valid label, else 0."
- ]
  },
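
For reference, a hedged re-implementation of the scoring the cell above describes. The real logic is `parse_final`, `reasoning_length_score`, and `make_reward_fn` in `train/train_grpo.py`; the 0.5 shoulder for `<think>` lengths between the ideal band and the hard cutoffs, the confidence clamping, and the missing-confidence default are assumptions of this sketch:

```python
import json, re

def parse_final_sketch(text):
    # Pull the JSON payload out of a <final>...</final> tag, if any.
    m = re.search(r'<final>\s*(\{.*?\})\s*</final>', text, re.DOTALL)
    if not m:
        return None
    try:
        obj = json.loads(m.group(1))
    except json.JSONDecodeError:
        return None
    if obj.get('label') not in ('sarcastic', 'sincere'):
        return None
    conf = float(obj.get('confidence', 0.0))  # missing confidence -> 0.0 (assumption)
    return obj['label'], max(0.0, min(1.0, conf))

def reward_sketch(completion_text, gold):
    parsed = parse_final_sketch(completion_text)
    fmt = 1.0 if parsed else 0.0          # format: parseable JSON with a valid label
    if parsed is None:
        correctness = 0.0                 # no valid <final> tag
    else:
        label, conf = parsed
        correctness = 0.5 + 0.5 * conf if label == gold else 0.5 - 0.5 * conf
    think = re.search(r'<think>(.*?)</think>', completion_text, re.DOTALL)
    n_words = len(think.group(1).split()) if think else 0
    if 50 <= n_words <= 150:
        length = 1.0                      # ideal band
    elif n_words < 30 or n_words > 300:
        length = 0.0                      # lazy or rambling
    else:
        length = 0.5                      # shoulder value: an assumption, not from the script
    return 0.70 * correctness + 0.15 * length + 0.15 * fmt

demo = '<think>' + 'word ' * 60 + '</think><final>{"label":"sarcastic","confidence":0.9}</final>'
print(reward_sketch(demo, 'sarcastic'))  # 0.70*0.95 + 0.15 + 0.15 = 0.965
```
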
  {
  "cell_type": "code",
@@ -124,22 +81,18 @@
  "outputs": [],
  "source": [
  "reward_fn = make_reward_fn()\n",
- "\n",
- "# Sanity check on synthetic completions\n",
- "fake_completions = [\n",
- " [{'role':'assistant','content': '<think>Pitch HIGH, pre-pause 320ms, positive lexical content with exaggerated melody — classic sarcasm signature.</think><final>{\"label\":\"sarcastic\",\"confidence\":0.85}</final>'}],\n",
- " [{'role':'assistant','content': '<think>Flat affect on neutral content, low pitch variability.</think><final>{\"label\":\"sincere\",\"confidence\":0.65}</final>'}],\n",
  "]\n",
- "rewards = reward_fn(prompts=None, completions=fake_completions, gold=['sarcastic','sincere'])\n",
- "print('Synthetic rewards:', rewards) # should be high for both"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "## 5. Run GRPO training (200 steps, ~12 h on T4-medium)"
- ]
  },
  {
  "cell_type": "code",
@@ -147,41 +100,48 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "from trl import GRPOTrainer, GRPOConfig\n",
  "\n",
- "trainer = GRPOTrainer(\n",
- " model=model,\n",
- " reward_funcs=reward_fn,\n",
- " args=GRPOConfig(\n",
- " output_dir='./checkpoints/run1',\n",
- " num_generations=4,\n",
- " max_completion_length=768,\n",
- " per_device_train_batch_size=1,\n",
- " learning_rate=5e-6,\n",
- " max_steps=200,\n",
- " logging_steps=1,\n",
- " save_steps=50,\n",
- " save_total_limit=4,\n",
- " bf16=True,\n",
- " report_to='none',\n",
- " gradient_checkpointing=True,\n",
- " ),\n",
- " train_dataset=ds,\n",
- " processing_class=tokenizer,\n",
  ")\n",
- "trainer.train()\n",
- "trainer.save_model('./checkpoints/run1')\n",
- "print('checkpoint saved')"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
- "source": [
- "## 6. Eval on the Prosody-Pivot Set\n",
- "\n",
- "Headline number: `X / 50` clips correct (per-clip majority across 3 seeds). Run BOTH the trained checkpoint and the base model — the delta is your story."
- ]
  },
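
A sketch of the per-clip majority vote behind the `X / 50` headline, with `predict` standing in for one sampled generation plus `parse_final`:

```python
from collections import Counter

def majority_label(clip_id, predict, seeds=(0, 1, 2)):
    # One sampled prediction per seed; the most common label wins.
    votes = Counter(predict(clip_id, seed=s) for s in seeds)
    return votes.most_common(1)[0][0]  # e.g. 2x 'sarcastic', 1x 'sincere' -> 'sarcastic'
```
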
  {
  "cell_type": "code",
@@ -189,41 +149,124 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# After training:\n",
- "# !python train/eval_pivot_set.py --checkpoint baseline-only --pivot data/pivot_set.json --out docs/plots/pivot_baseline.json\n",
- "# !python train/eval_pivot_set.py --checkpoint ./checkpoints/run1 --pivot data/pivot_set.json --out docs/plots/pivot_trained.json\n",
  "\n",
- "import json\n",
- "pivot = json.load(open('subtext_arena/data/pivot_set.json'))\n",
- "print(f'Pivot Set size: {len(pivot[\"clip_ids\"])} clips')\n",
- "print('Method:', pivot.get('method'))"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "## 7. Plot the reward decomposition\n",
- "\n",
- "The killer chart: 3 colored lines (correctness, reasoning_length, format) climbing at different rates over training steps. This is the visual proof judges look for."
  ]
  },
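
A hypothetical version of that plot, in the spirit of `train/plot_reward_decomp.py`: read the checkpoint's `trainer_state.json` log history and draw one line per component. The per-component key names here are assumptions, not verified TRL log fields:

```python
import json
import matplotlib.pyplot as plt

state = json.load(open('./checkpoints/run1/trainer_state.json'))
for key, label in [('rewards/correctness', 'correctness'),
                   ('rewards/reasoning_length', 'reasoning_length'),
                   ('rewards/format', 'format')]:
    # Collect (step, value) pairs for entries that logged this component.
    points = [(e['step'], e[key]) for e in state['log_history'] if key in e]
    if points:
        steps, vals = zip(*points)
        plt.plot(steps, vals, label=label)
plt.xlabel('GRPO step'); plt.ylabel('component reward'); plt.legend()
plt.savefig('docs/plots/reward_decomposition.png', dpi=150)
```
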
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
- "# !python train/plot_reward_decomp.py --log-jsonl ./checkpoints/run1/trainer_state.json --out docs/plots/reward_decomposition.png\n",
- "from IPython.display import Image\n",
- "# Image('docs/plots/reward_decomposition.png')"
  ]
  }
  ],
  "metadata": {
- "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
- "language_info": {"name": "python", "version": "3.11"}
  },
  "nbformat": 4,
  "nbformat_minor": 5
- }
 
  {
  "cell_type": "markdown",
  "metadata": {},
+ "source": "# Subtext Arena — GRPO training\n\nThis is the actual training that produced the README numbers — reward 0.33 → 0.97 on training, 51% on the broad held-out set, 5/6 on the Pivot Set.\n\nThe notebook imports functions straight from `subtext_arena.train.train_grpo` so it stays in sync with the script. Set the config below, run all cells. Around 2 hours on a Colab L4, ~$1.60."
  },
  {
  "cell_type": "markdown",
  "metadata": {},
+ "source": "## Install"
  },
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
+ "source": "!pip install -q \"trl>=0.11\" \"transformers>=4.46\" peft datasets accelerate bitsandbytes huggingface_hub\n!pip install -q git+https://huggingface.co/spaces/aamrinder/subtext-arena\nimport torch\nprint('CUDA:', torch.cuda.is_available(), '|', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only')"
  },
  {
  "cell_type": "markdown",
  "metadata": {},
+ "source": "## Config\n\nSame args as the CLI script. Tweak `MAX_STEPS` if you're on a smaller GPU."
  },
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
29
  "outputs": [],
30
  "source": [
31
+ "MODEL = 'Qwen/Qwen2.5-3B-Instruct'\n",
32
+ "OUTPUT_DIR = './checkpoints/run1'\n",
33
+ "MAX_STEPS = 200\n",
34
+ "NUM_GENERATIONS = 4\n",
35
+ "PER_DEVICE_BATCH = 4 # must be divisible by NUM_GENERATIONS\n",
36
+ "LEARNING_RATE = 5e-6\n",
37
+ "MAX_COMPLETION_LEN = 768\n",
38
+ "LORA_R = 16\n",
39
+ "LORA_DROPOUT = 0.05\n",
40
+ "N_TRAIN_ROWS = 600\n",
41
+ "EVAL_RATIO = 0.2\n",
42
+ "N_EVAL_CLIPS = 80\n",
43
+ "PUSH_TO_HUB = None # e.g. 'your-username/subtext-arena-grpo' (or None to skip)"
44
  ]
45
  },
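
The `PER_DEVICE_BATCH` comment above matters because GRPO computes its group-relative baseline over the `NUM_GENERATIONS` completions sampled for each prompt, so a device batch must hold whole groups. A tiny guard you could add after the config cell:

```python
# Each group of NUM_GENERATIONS completions shares one prompt and one baseline,
# so the batch must split into full groups.
assert PER_DEVICE_BATCH % NUM_GENERATIONS == 0, 'batch must split into full generation groups'
print(f'{PER_DEVICE_BATCH // NUM_GENERATIONS} prompt group(s) per device batch')
```
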
  {
  "cell_type": "markdown",
  "metadata": {},
+ "source": "## Load data + train/eval split\n\nEval clips never appear in training. Split is seeded so it's reproducible."
  },
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
55
  "outputs": [],
56
  "source": [
57
+ "from subtext_arena.server.scenarios import load_scenarios\n",
58
+ "from subtext_arena.train.train_grpo import (\n",
59
+ " SYSTEM_PROMPT, build_full_observation, split_clip_ids, build_dataset,\n",
60
+ " parse_final, reasoning_length_score, make_reward_fn, reward_decomposition,\n",
 
61
  ")\n",
 
62
  "\n",
63
  "scenarios = load_scenarios()\n",
64
+ "print(f'Loaded {len(scenarios)} clips ({sum(1 for s in scenarios.values() if s.get(\"is_pivot\"))} marked Pivot)')\n",
 
65
  "\n",
66
+ "train_ids, eval_ids = split_clip_ids(scenarios, eval_ratio=EVAL_RATIO, seed=42)\n",
67
+ "dataset = build_dataset(scenarios, n_rows=N_TRAIN_ROWS, allowed_clip_ids=train_ids)\n",
68
+ "print(f'{len(dataset)} train prompt rows from {len(train_ids)} unique train clips')\n",
69
+ "print('Sample prompt:', dataset[0]['prompt'][1]['content'][:300], '...')"
70
  ]
71
  },
72
  {
73
  "cell_type": "markdown",
74
  "metadata": {},
75
+ "source": "## Reward sanity check\n\nQuick check that the reward function scores synthetic completions correctly before burning GPU time."
 
 
 
 
 
 
 
 
 
76
  },
77
  {
78
  "cell_type": "code",
 
81
  "outputs": [],
82
  "source": [
83
  "reward_fn = make_reward_fn()\n",
84
+ "fake = [\n",
85
+ " [{'role':'assistant','content': '<think>Pitch range is 180Hz — wide for a single line. Pre-utterance silence is 320ms which suggests deliberate emphasis. The literal words are positive but the prosodic delivery is exaggerated, classic sarcasm signature in TV dialogue.</think><final>{\"label\":\"sarcastic\",\"confidence\":0.85}</final>'}],\n",
86
+ " [{'role':'assistant','content': '<think>Flat affect, narrow pitch range, no internal pauses. Content is neutral and matches delivery. No prosodic-lexical mismatch.</think><final>{\"label\":\"sincere\",\"confidence\":0.65}</final>'}],\n",
 
 
87
  "]\n",
88
+ "rewards = reward_fn(prompts=None, completions=fake, gold=['sarcastic','sincere'])\n",
89
+ "print('Synthetic rewards:', rewards) # both should be > 0.85"
90
  ]
91
  },
92
  {
93
  "cell_type": "markdown",
94
  "metadata": {},
95
+ "source": "## Load Qwen2.5-3B in 4-bit + LoRA"
 
 
96
  },
97
  {
98
  "cell_type": "code",
 
100
  "metadata": {},
101
  "outputs": [],
102
  "source": [
103
+ "import torch as _t\n",
104
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
105
+ "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
106
  "\n",
107
+ "bnb = BitsAndBytesConfig(\n",
108
+ " load_in_4bit=True,\n",
109
+ " bnb_4bit_compute_dtype=_t.bfloat16,\n",
110
+ " bnb_4bit_quant_type='nf4',\n",
111
+ " bnb_4bit_use_double_quant=True,\n",
112
+ ")\n",
113
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
114
+ "if tokenizer.pad_token is None:\n",
115
+ " tokenizer.pad_token = tokenizer.eos_token\n",
116
+ "base = AutoModelForCausalLM.from_pretrained(\n",
117
+ " MODEL, quantization_config=bnb, dtype=_t.bfloat16, device_map='auto',\n",
 
 
 
 
 
 
 
 
118
  ")\n",
119
+ "base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)\n",
120
+ "peft_config = LoraConfig(\n",
121
+ " r=LORA_R, lora_alpha=LORA_R, lora_dropout=LORA_DROPOUT, bias='none',\n",
122
+ " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
123
+ " task_type='CAUSAL_LM',\n",
124
+ ")\n",
125
+ "model = get_peft_model(base, peft_config)\n",
126
+ "model.print_trainable_parameters()"
127
  ]
128
  },
  {
  "cell_type": "markdown",
  "metadata": {},
+ "source": "## Train"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "import os, json\nfrom pathlib import Path\nfrom trl import GRPOTrainer, GRPOConfig\n\nconfig = GRPOConfig(\n output_dir=OUTPUT_DIR,\n num_generations=NUM_GENERATIONS,\n max_completion_length=MAX_COMPLETION_LEN,\n per_device_train_batch_size=PER_DEVICE_BATCH,\n learning_rate=LEARNING_RATE,\n max_steps=MAX_STEPS,\n logging_steps=1,\n save_steps=50,\n save_total_limit=4,\n bf16=True,\n report_to=('wandb' if os.environ.get('WANDB_API_KEY') else 'none'),\n gradient_checkpointing=True,\n)\ntrainer = GRPOTrainer(\n model=model,\n reward_funcs=make_reward_fn(),\n args=config,\n train_dataset=dataset,\n processing_class=tokenizer,\n)\ntrainer.train()\n\ntrainer.save_state()\ntrainer.save_model(OUTPUT_DIR)\nPath(OUTPUT_DIR, 'log_history.json').write_text(json.dumps(trainer.state.log_history, indent=2))\nprint(f'saved to {OUTPUT_DIR}')"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Held-out eval\n\nGreedy decoding on 80 unseen clips. Run #3 landed at 51% broad accuracy and 5/6 on the Pivot subset."
  },
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
150
  "outputs": [],
151
  "source": [
152
+ "model.eval()\n",
153
+ "if hasattr(model, 'gradient_checkpointing_disable'):\n",
154
+ " try: model.gradient_checkpointing_disable()\n",
155
+ " except Exception: pass\n",
156
  "\n",
157
+ "eval_clip_ids = sorted(eval_ids)[:N_EVAL_CLIPS]\n",
158
+ "results, eval_rewards = [], []\n",
159
+ "n_correct = n_well_formed = 0\n",
160
+ "for i, cid in enumerate(eval_clip_ids):\n",
161
+ " sc = scenarios[cid]\n",
162
+ " gold = 'sarcastic' if sc['sarcasm'] else 'sincere'\n",
163
+ " messages = [\n",
164
+ " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
165
+ " {'role': 'user', 'content': build_full_observation(cid, scenarios)},\n",
166
+ " ]\n",
167
+ " encoded = tokenizer.apply_chat_template(messages, return_tensors='pt', add_generation_prompt=True)\n",
168
+ " input_ids = (encoded.input_ids if hasattr(encoded, 'input_ids') else encoded).to(model.device)\n",
169
+ " prompt_len = input_ids.shape[1]\n",
170
+ " with _t.no_grad():\n",
171
+ " out = model.generate(\n",
172
+ " input_ids=input_ids, max_new_tokens=MAX_COMPLETION_LEN,\n",
173
+ " do_sample=False, pad_token_id=tokenizer.eos_token_id, use_cache=True,\n",
174
+ " )\n",
175
+ " text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)\n",
176
+ " decomp = reward_decomposition(text, gold)\n",
177
+ " results.append({\n",
178
+ " 'clip_id': cid, 'gold': gold, 'is_pivot': bool(sc.get('is_pivot')),\n",
179
+ " 'predicted': decomp['_predicted'], 'confidence': decomp['_confidence'],\n",
180
+ " 'correct': decomp['_correct'], 'well_formed': decomp['_well_formed'],\n",
181
+ " 'reward_total': decomp['_total'], 'completion_text': text[:1500],\n",
182
+ " })\n",
183
+ " eval_rewards.append(decomp['_total'])\n",
184
+ " if decomp['_correct']: n_correct += 1\n",
185
+ " if decomp['_well_formed']: n_well_formed += 1\n",
186
+ " if (i + 1) % 20 == 0:\n",
187
+ " print(f' [{i+1}/{len(eval_clip_ids)}] running mean reward = {sum(eval_rewards)/len(eval_rewards):.3f}, '\n",
188
+ " f'correct so far = {n_correct}/{i+1}', flush=True)\n",
189
+ "\n",
190
+ "n_pivot = sum(1 for r in results if r['is_pivot'])\n",
191
+ "n_pivot_correct = sum(1 for r in results if r['is_pivot'] and r['correct'])\n",
192
+ "summary = {\n",
193
+ " 'n_eval_clips': len(eval_clip_ids),\n",
194
+ " 'mean_reward': sum(eval_rewards) / max(1, len(eval_rewards)),\n",
195
+ " 'well_formed_rate': n_well_formed / max(1, len(eval_clip_ids)),\n",
196
+ " 'accuracy': n_correct / max(1, len(eval_clip_ids)),\n",
197
+ " 'pivot_in_eval': n_pivot,\n",
198
+ " 'pivot_correct': n_pivot_correct,\n",
199
+ " 'results': results,\n",
200
+ "}\n",
201
+ "Path(OUTPUT_DIR, 'held_out_eval.json').write_text(json.dumps(summary, indent=2))\n",
202
+ "print(f\"\\nHELD-OUT: mean_reward={summary['mean_reward']:.3f}, accuracy={summary['accuracy']:.2%} ({n_correct}/{len(eval_clip_ids)})\")\n",
203
+ "print(f\" pivot accuracy: {n_pivot_correct}/{n_pivot}\")"
204
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
+ "source": "## Reward curve"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
  "source": [
+ "import matplotlib.pyplot as plt\n",
+ "log = json.loads(Path(OUTPUT_DIR, 'log_history.json').read_text())\n",
+ "steps = [e['step'] for e in log if 'reward' in e]\n",
+ "rewards = [e['reward'] for e in log if 'reward' in e]\n",
+ "plt.figure(figsize=(8, 4))\n",
+ "plt.plot(steps, rewards, alpha=0.4, label='per-step')\n",
+ "if len(rewards) >= 10:\n",
+ " import numpy as np\n",
+ " ema = np.array(rewards)\n",
+ " for i in range(1, len(ema)):\n",
+ " ema[i] = 0.9 * ema[i-1] + 0.1 * ema[i]\n",
+ " plt.plot(steps, ema, linewidth=2, label='EMA(0.9)')\n",
+ "plt.xlabel('GRPO step'); plt.ylabel('reward'); plt.legend(); plt.grid(alpha=0.3)\n",
+ "plt.title('Subtext Arena — GRPO training reward')\n",
+ "plt.tight_layout(); plt.show()"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Push to Hub (optional)\n\nSet `PUSH_TO_HUB` at the top, then `huggingface-cli login` first or set `HF_TOKEN` in Colab secrets."
+ },
  {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
+ "if PUSH_TO_HUB:\n",
+ " from huggingface_hub import HfApi\n",
+ " api = HfApi()\n",
+ " api.create_repo(repo_id=PUSH_TO_HUB, repo_type='model', exist_ok=True)\n",
+ " api.upload_folder(\n",
+ " folder_path=OUTPUT_DIR, repo_id=PUSH_TO_HUB, repo_type='model',\n",
+ " commit_message=f'GRPO ({MAX_STEPS} steps, lr={LEARNING_RATE})',\n",
+ " )\n",
+ " print(f'LoRA pushed to https://huggingface.co/{PUSH_TO_HUB}')\n",
+ "else:\n",
+ " print('PUSH_TO_HUB is None — skipping')"
  ]
  }
  ],
  "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10"
+ }
  },
  "nbformat": 4,
  "nbformat_minor": 5
+ }