akhiilll committed on
Commit
43372d5
·
verified ·
1 Parent(s): 2581673

wire Colab notebook to TRL GRPOTrainer (real LoRA weight updates)

Files changed (2)
  1. README.md +14 -5
  2. training/InsureClaim_Training_Colab.ipynb +139 -673
README.md CHANGED
@@ -32,7 +32,7 @@ tags:
32
  | **Live Space** | <https://huggingface.co/spaces/akhiilll/claims-env> |
33
  | **API root** | <https://akhiilll-claims-env.hf.space> · `/health` · `/api` · `/docs` |
34
  | **WebSocket** | `wss://akhiilll-claims-env.hf.space/ws` |
35
- | **Training (Colab)** | [`training/InsureClaim_Training_Colab.ipynb`](training/InsureClaim_Training_Colab.ipynb) — Unsloth + TRL |
36
  | **Training (HF Job, 4×A10G)** | [`training/train_local_hf.py`](training/train_local_hf.py) |
37
  | **Latest run artifacts** | [`runs/20260425-215059/`](runs/20260425-215059) |
38
 
@@ -202,9 +202,18 @@ print(job.url)
202
 
203
  The job streams to `runs/<timestamp>/{reward_curves.png,reward_summary.json}` automatically.
204
 
205
- ### 4.4 Train with TRL + Unsloth in Colab
206
 
207
- Open [`training/InsureClaim_Training_Colab.ipynb`](training/InsureClaim_Training_Colab.ipynb) on a free T4. The notebook loads `unsloth/Qwen2.5-1.5B-Instruct` in 4-bit with LoRA adapters, connects to **this** deployed Space over WebSocket, runs a REINFORCE-style policy-gradient loop, and saves reward-curve PNGs.
208
 
209
  ## 5. Repo layout
210
 
@@ -242,9 +251,9 @@ Open [`training/InsureClaim_Training_Colab.ipynb`](training/InsureClaim_Training
242
 
243
  ## 7. What we'll do next (post-deadline)
244
 
245
- * **GRPO with TRL + Unsloth** — replace the REINFORCE-style notebook with a GRPO loop so the LLM's *weights* update on the per-component rewards (currently we only do online rollouts).
246
- * **Curriculum** — start episodes only on the routine-approval cases, then unlock fraud / lapsed-policy / escalation cases as `final_avg` crosses thresholds.
247
  * **Process supervision** — reward correct *intermediate* tool selection (e.g. running `check_fraud` before approving a high-amount auto-theft claim), not just terminal verdicts.
 
248
 
249
  ## 8. Materials & links
250
 
 
32
  | **Live Space** | <https://huggingface.co/spaces/akhiilll/claims-env> |
33
  | **API root** | <https://akhiilll-claims-env.hf.space> · `/health` · `/api` · `/docs` |
34
  | **WebSocket** | `wss://akhiilll-claims-env.hf.space/ws` |
35
+ | **Training (Colab, GRPO)** | [`training/InsureClaim_Training_Colab.ipynb`](training/InsureClaim_Training_Colab.ipynb) — Unsloth + TRL `GRPOTrainer` (real LoRA weight updates) |
36
  | **Training (HF Job, 4×A10G)** | [`training/train_local_hf.py`](training/train_local_hf.py) |
37
  | **Latest run artifacts** | [`runs/20260425-215059/`](runs/20260425-215059) |
38
 
 
202
 
203
  The job streams to `runs/<timestamp>/{reward_curves.png,reward_summary.json}` automatically.
204
 
205
+ ### 4.4 Train with TRL `GRPOTrainer` + Unsloth in Colab (real weight updates)
206
 
207
+ Open [`training/InsureClaim_Training_Colab.ipynb`](training/InsureClaim_Training_Colab.ipynb) on a free Colab T4. The notebook:
208
+
209
+ 1. Clones this Space repo so the gym runs **in-process** in Colab and is fully deterministic per case (`scenario_index = 0..7`).
210
+ 2. Loads `unsloth/Qwen2.5-1.5B-Instruct` in 4-bit with LoRA `r=16, alpha=32` adapters (~12-15 M trainable params).
211
+ 3. Builds a prompt dataset where each row is pinned to one of the 8 curated cases.
212
+ 4. Defines **two independent reward functions** (anti-reward-hack pattern from the hackathon guide):
213
+ - `format_reward_fn` — was the completion parseable and did it end in a terminal verb?
214
+ - `env_reward_fn` — replays the trajectory inside the deterministic gym, returns cumulative env reward.
215
+ 5. Trains with `trl.GRPOTrainer` (`num_generations=4`, `epsilon=0.2` PPO clip, `beta=0.04` KL), logs reward / KL / completion-length.
216
+ 6. Plots curves, runs a per-case **before-vs-after rollout** so judges can see behaviour change, saves the LoRA adapter (with optional `push_to_hub`).
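
As an illustration of step 4 above, a format-only reward can be sketched in plain Python. This is a minimal sketch, not the notebook's actual `format_reward_fn`: the name `format_reward_sketch`, the helper `parse_plan`, and the 1.0 / 0.0 scoring are assumptions made for this example. The verb list is taken from the system prompt in the notebook's dataset cell, and the `(completions, **kwargs) -> list[float]` shape follows the convention TRL uses for custom reward functions.

```python
# Hypothetical sketch of a format-only reward; not the notebook's real code.
INFO_VERBS = {
    "query_policy", "query_claim_history", "check_fraud", "request_documents",
    "verify_coverage", "verify_purchase", "calculate_payout",
}
TERMINAL_VERBS = {"approve", "deny", "escalate"}


def parse_plan(completion: str) -> list[str]:
    """Return the first word (the verb) of every non-empty line."""
    verbs = []
    for line in completion.strip().splitlines():
        words = line.strip().lower().split()
        if words:
            verbs.append(words[0])
    return verbs


def format_reward_sketch(completions, **kwargs):
    """Score 1.0 when every line starts with a known verb and the last
    line is a terminal action; otherwise 0.0 (scores are assumed values)."""
    rewards = []
    for text in completions:
        verbs = parse_plan(text)
        known = bool(verbs) and all(
            v in INFO_VERBS or v in TERMINAL_VERBS for v in verbs
        )
        ends_terminal = bool(verbs) and verbs[-1] in TERMINAL_VERBS
        rewards.append(1.0 if known and ends_terminal else 0.0)
    return rewards


if __name__ == "__main__":
    samples = [
        "query_policy\ncheck_fraud\ndeny staged theft",  # well-formed plan
        "I think we should approve",                     # free text, not a plan
        "query_policy",                                  # no terminal action
    ]
    print(format_reward_sketch(samples))
```

Keeping this check separate from the environment reward means a completion cannot score well on format alone while skipping the actual adjudication, which is the point of using two independent reward functions.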
217
 
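
The `num_generations=4` setting in step 5 matters because GRPO scores each completion relative to the other completions sampled for the same prompt. The sketch below shows that group-relative normalisation in plain Python. It is an illustration under stated assumptions, not TRL's internal code: the function name `group_advantages`, the use of the population standard deviation, and the `eps` value are all choices made for this example.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Normalise the rewards of one prompt's sampled completions.

    Each advantage is (reward - group mean) / (group std + eps).
    `eps` is an assumed small constant that avoids division by zero.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


if __name__ == "__main__":
    # Four completions for one prompt: two good, two bad.
    print(group_advantages([0.0, 0.0, 1.0, 1.0]))
    # All four completions scored the same: no learning signal.
    print(group_advantages([1.0, 1.0, 1.0, 1.0]))
```

When every completion in a group earns the same reward, all advantages are zero and the step contributes no gradient, so a group needs some spread in its rewards for training to make progress.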
218
  ## 5. Repo layout
219
 
 
251
 
252
  ## 7. What we'll do next (post-deadline)
253
 
254
+ * **Curriculum** — start GRPO episodes only on the routine-approval cases, then unlock fraud / lapsed-policy / escalation cases as `final_avg` crosses thresholds.
 
255
  * **Process supervision** — reward correct *intermediate* tool selection (e.g. running `check_fraud` before approving a high-amount auto-theft claim), not just terminal verdicts.
256
+ * **Push trained adapter to the Hub** — once GRPO finishes in Colab, `push_to_hub("akhiilll/claims-grpo-qwen2.5-1.5b")` so a one-line `from_pretrained` reproduces the trained agent.
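
The curriculum item in the list above could be prototyped with a small gating helper like the one below. Everything in it is hypothetical: the thresholds, the grouping of `scenario_index` values into stages, and the name `unlocked_scenarios` are placeholders for illustration, since the README does not say which indices correspond to routine, fraud, lapsed-policy, or escalation cases.

```python
# Hypothetical curriculum stages: (minimum final_avg to unlock, scenario indices).
# Both the thresholds and the index groupings are placeholder assumptions.
CURRICULUM = [
    (float("-inf"), [0, 1]),   # assumed routine-approval cases, always available
    (5.0, [2, 3, 4]),          # assumed harder cases, unlocked at final_avg >= 5
    (10.0, [5, 6, 7]),         # assumed hardest cases, unlocked at final_avg >= 10
]


def unlocked_scenarios(final_avg, curriculum=CURRICULUM):
    """Return the scenario indices available at the current average reward."""
    allowed = []
    for threshold, indices in curriculum:
        if final_avg >= threshold:
            allowed.extend(indices)
    return allowed


if __name__ == "__main__":
    for avg in (-3.0, 5.0, 12.0):
        print(avg, unlocked_scenarios(avg))
```

A training loop could call this between evaluation rounds and rebuild the prompt dataset from the returned indices only.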
257
 
258
  ## 8. Materials & links
259
 
training/InsureClaim_Training_Colab.ipynb CHANGED
@@ -1,674 +1,140 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "header"
7
- },
8
- "source": [
9
- "# InsureClaim AI - RL Training on the ClaimSense OpenEnv Gym\n",
10
- "\n",
11
- "> Apr 2026 OpenEnv Hackathon - Theme 3.1 (Professional Tasks) + Theme 2 (Long-Horizon Planning)\n",
12
- "\n",
13
- "This notebook trains `unsloth/Qwen2.5-1.5B-Instruct` against the live\n",
14
- "ClaimSense Space (https://huggingface.co/spaces/akhiilll/claims-env)\n",
15
- "on a free Colab T4 using Unsloth + TRL. Open in Colab, click *Run all*,\n",
16
- "and the trained reward curve drops out as `reward_curves.png`.\n",
17
- "\n",
18
- "# (legacy header below kept for reference)\n",
19
- "# InsureClaim AI - RL Training with Unsloth\n",
20
- "\n",
21
- "**OpenEnv Hackathon | Statement 3.1 + Scaler AI Labs**\n",
22
- "\n",
23
- "This notebook demonstrates training an LLM to process insurance claims using:\n",
24
- "- **Unsloth** for efficient 4-bit model loading\n",
25
- "- **TRL** for reinforcement learning\n",
26
- "- **OpenEnv** for the claims processing environment\n",
27
- "\n",
28
- "## Results Preview\n",
29
- "- Starting reward: **-5.5**\n",
30
- "- Final reward: **+11.75**\n",
31
- "- Improvement: **+17.25**\n",
32
- "- Fraud detection: **+17.4** max reward"
33
- ]
34
- },
35
- {
36
- "cell_type": "markdown",
37
- "metadata": {
38
- "id": "install_header"
39
- },
40
- "source": [
41
- "## 1️⃣ Install Dependencies"
42
- ]
43
- },
44
- {
45
- "cell_type": "code",
46
- "metadata": {
47
- "id": "install"
48
- },
49
- "source": [
50
- "%%capture\n",
51
- "# Install Unsloth (optimized for Colab)\n",
52
- "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
53
- "!pip install --no-deps trl peft accelerate bitsandbytes\n",
54
- "\n",
55
- "# Install environment dependencies\n",
56
- "!pip install websockets nest_asyncio certifi matplotlib\n",
57
- "\n",
58
- "print(\" Dependencies installed!\")"
59
- ],
60
- "execution_count": null,
61
- "outputs": []
62
- },
63
- {
64
- "cell_type": "markdown",
65
- "metadata": {
66
- "id": "model_header"
67
- },
68
- "source": [
69
- "## 2️⃣ Load Model with Unsloth (4-bit quantization)"
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "metadata": {
75
- "id": "load_model"
76
- },
77
- "source": [
78
- "from unsloth import FastLanguageModel\n",
79
- "import torch\n",
80
- "\n",
81
- "# Check GPU\n",
82
- "print(f\"GPU Available: {torch.cuda.is_available()}\")\n",
83
- "if torch.cuda.is_available():\n",
84
- " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
85
- " print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
86
- "\n",
87
- "# Load model with Unsloth (4x faster, 70% less memory)\n",
88
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
89
- " model_name=\"unsloth/Qwen2.5-1.5B-Instruct\",\n",
90
- " max_seq_length=2048,\n",
91
- " load_in_4bit=True,\n",
92
- " dtype=None, # auto-detect\n",
93
- ")\n",
94
- "\n",
95
- "# Add LoRA adapters for efficient fine-tuning\n",
96
- "model = FastLanguageModel.get_peft_model(\n",
97
- " model,\n",
98
- " r=16,\n",
99
- " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
100
- " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
101
- " lora_alpha=16,\n",
102
- " lora_dropout=0,\n",
103
- " bias=\"none\",\n",
104
- " use_gradient_checkpointing=\"unsloth\",\n",
105
- " random_state=42,\n",
106
- ")\n",
107
- "\n",
108
- "# Ensure pad token\n",
109
- "if tokenizer.pad_token is None:\n",
110
- " tokenizer.pad_token = tokenizer.eos_token\n",
111
- "\n",
112
- "print(\"\\n✅ Model loaded with Unsloth + LoRA!\")\n",
113
- "print(f\"Trainable parameters: {model.print_trainable_parameters()}\")"
114
- ],
115
- "execution_count": null,
116
- "outputs": []
117
- },
118
- {
119
- "cell_type": "markdown",
120
- "metadata": {
121
- "id": "env_header"
122
- },
123
- "source": [
124
- "## 3️⃣ Connect to Claims Environment"
125
- ]
126
- },
127
- {
128
- "cell_type": "code",
129
- "metadata": {
130
- "id": "connect_env"
131
- },
132
- "source": [
133
- "import asyncio\n",
134
- "import websockets\n",
135
- "import json\n",
136
- "import ssl\n",
137
- "import certifi\n",
138
- "import nest_asyncio\n",
139
- "\n",
140
- "# Fix for Colab event loop\n",
141
- "nest_asyncio.apply()\n",
142
- "\n",
143
- "# Environment URLs\n",
144
- "ENV_URL = \"https://akhiilll-claims-env.hf.space\"\n",
145
- "WS_URL = \"wss://akhiilll-claims-env.hf.space/ws\"\n",
146
- "\n",
147
- "# SSL context for Colab\n",
148
- "ssl_context = ssl.create_default_context(cafile=certifi.where())\n",
149
- "\n",
150
- "# Test connection\n",
151
- "import httpx\n",
152
- "response = httpx.get(f\"{ENV_URL}/health\", timeout=30)\n",
153
- "print(f\"Health check: {response.json()}\")\n",
154
- "\n",
155
- "# Test WebSocket with one episode\n",
156
- "async def test_environment():\n",
157
- " async with websockets.connect(WS_URL, ssl=ssl_context) as ws:\n",
158
- " await ws.send('{\"type\": \"reset\", \"data\": {}}')\n",
159
- " response = json.loads(await ws.recv())\n",
160
- " obs = response[\"data\"][\"observation\"]\n",
161
- " print(f\"\\n📋 Test Claim: {obs['claim_id']}\")\n",
162
- " print(f\" Type: {obs['claim_type']}\")\n",
163
- " print(f\" Amount: ${obs['claim_amount_requested']:,.2f}\")\n",
164
- "\n",
165
- " # Quick action test\n",
166
- " await ws.send('{\"type\": \"step\", \"data\": {\"action_type\": \"query_policy\"}}')\n",
167
- " response = json.loads(await ws.recv())\n",
168
- " reward = response[\"data\"].get(\"reward\", 0)\n",
169
- " print(f\" query_policy reward: {reward}\")\n",
170
- "\n",
171
- " await ws.send('{\"type\": \"close\", \"data\": {}}')\n",
172
- " return True\n",
173
- "\n",
174
- "asyncio.get_event_loop().run_until_complete(test_environment())\n",
175
- "print(\"\\n✅ Environment connected!\")"
176
- ],
177
- "execution_count": null,
178
- "outputs": []
179
- },
180
- {
181
- "cell_type": "markdown",
182
- "metadata": {
183
- "id": "components_header"
184
- },
185
- "source": [
186
- "## 4️⃣ Define Training Components"
187
- ]
188
- },
189
- {
190
- "cell_type": "code",
191
- "metadata": {
192
- "id": "components"
193
- },
194
- "source": [
195
- "import re\n",
196
- "from dataclasses import dataclass\n",
197
- "from typing import List, Dict, Any, Tuple\n",
198
- "\n",
199
- "# System prompt for claims adjuster\n",
200
- "SYSTEM_PROMPT = \"\"\"You are an expert insurance claims adjuster. Process claims efficiently and accurately.\n",
201
- "\n",
202
- "Available actions:\n",
203
- "- query_policy: Look up policy details\n",
204
- "- check_fraud: Run fraud detection\n",
205
- "- verify_purchase: Verify via Plaid transactions\n",
206
- "- approve: Approve claim (include amount)\n",
207
- "- deny: Deny claim (include reason)\n",
208
- "- escalate: Escalate to senior adjuster\n",
209
- "\n",
210
- "Respond with just the action, e.g., 'query_policy' or 'approve 3500' or 'deny fraud detected'.\"\"\"\n",
211
- "\n",
212
- "def format_observation(obs: dict) -> str:\n",
213
- " \"\"\"Format observation for LLM.\"\"\"\n",
214
- " text = f\"\"\"Claim: {obs.get('claim_id', 'N/A')}\n",
215
- "Type: {obs.get('claim_type', 'N/A')}\n",
216
- "Amount: ${obs.get('claim_amount_requested', 0):,.2f}\n",
217
- "Description: {obs.get('description', 'N/A')}\n",
218
- "\n",
219
- "System: {obs.get('system_response', 'Ready')}\"\"\"\n",
220
- "\n",
221
- " if obs.get('revealed_info'):\n",
222
- " info = obs['revealed_info']\n",
223
- " if 'fraud_analysis' in info:\n",
224
- " fa = info['fraud_analysis']\n",
225
- " text += f\"\\n\\nFraud Risk: {fa.get('risk_score', 0):.2f}\"\n",
226
- " if fa.get('flags'):\n",
227
- " text += f\" | Flags: {', '.join(fa['flags'])}\"\n",
228
- "\n",
229
- " return text\n",
230
- "\n",
231
- "def parse_action(response: str, claim_amount: float) -> dict:\n",
232
- " \"\"\"Parse LLM response to action.\"\"\"\n",
233
- " response = response.lower().strip()\n",
234
- "\n",
235
- " # Terminal actions\n",
236
- " if \"approve\" in response:\n",
237
- " match = re.search(r'(\\d+(?:\\.\\d+)?)', response)\n",
238
- " payout = float(match.group(1)) if match else claim_amount\n",
239
- " return {\"action_type\": \"approve\", \"parameters\": {\"payout\": payout}}\n",
240
- "\n",
241
- " if \"deny\" in response:\n",
242
- " return {\"action_type\": \"deny\", \"parameters\": {\"reason\": \"Denied after review\"}}\n",
243
- "\n",
244
- " if \"escalate\" in response:\n",
245
- " return {\"action_type\": \"escalate\", \"parameters\": {\"reason\": \"Needs review\"}}\n",
246
- "\n",
247
- " # Information gathering\n",
248
- " if \"fraud\" in response:\n",
249
- " return {\"action_type\": \"check_fraud\", \"parameters\": {}}\n",
250
- " if \"policy\" in response:\n",
251
- " return {\"action_type\": \"query_policy\", \"parameters\": {}}\n",
252
- " if \"purchase\" in response or \"plaid\" in response:\n",
253
- " return {\"action_type\": \"verify_purchase\", \"parameters\": {}}\n",
254
- "\n",
255
- " # Default\n",
256
- " return {\"action_type\": \"query_policy\", \"parameters\": {}}\n",
257
- "\n",
258
- "@dataclass\n",
259
- "class Experience:\n",
260
- " \"\"\"Single step experience for training.\"\"\"\n",
261
- " prompt: str\n",
262
- " response: str\n",
263
- " reward: float\n",
264
- " action: str\n",
265
- "\n",
266
- "print(\"✅ Training components defined!\")"
267
- ],
268
- "execution_count": null,
269
- "outputs": []
270
- },
271
- {
272
- "cell_type": "markdown",
273
- "metadata": {
274
- "id": "training_header"
275
- },
276
- "source": [
277
- "## 5️⃣ Training Loop with Policy Gradient\n",
278
- "\n",
279
- "This implements a simplified REINFORCE algorithm:\n",
280
- "1. Generate actions using the model\n",
281
- "2. Collect rewards from environment\n",
282
- "3. Update model to favor high-reward actions"
283
- ]
284
- },
285
- {
286
- "cell_type": "code",
287
- "metadata": {
288
- "id": "training_loop"
289
- },
290
- "source": [
291
- "from torch.optim import AdamW\n",
292
- "import random\n",
293
- "\n",
294
- "# Training configuration\n",
295
- "NUM_EPISODES = 50\n",
296
- "MAX_STEPS = 8\n",
297
- "LEARNING_RATE = 2e-5\n",
298
- "BASELINE_REWARD = 0.0 # For variance reduction\n",
299
- "\n",
300
- "# Optimizer\n",
301
- "optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)\n",
302
- "\n",
303
- "# Metrics\n",
304
- "episode_rewards = []\n",
305
- "running_avg_rewards = []\n",
306
- "losses = []\n",
307
- "\n",
308
- "async def run_episode_with_training(episode_num: int, debug: bool = False):\n",
309
- " \"\"\"Run episode and collect experiences for training.\"\"\"\n",
310
- " global BASELINE_REWARD\n",
311
- "\n",
312
- " experiences = []\n",
313
- " episode_reward = 0\n",
314
- "\n",
315
- " try:\n",
316
- " async with websockets.connect(WS_URL, ssl=ssl_context, close_timeout=15) as ws:\n",
317
- " # Reset\n",
318
- " await ws.send(json.dumps({\"type\": \"reset\", \"data\": {}}))\n",
319
- " response = json.loads(await ws.recv())\n",
320
- " obs = response[\"data\"][\"observation\"]\n",
321
- " claim_amount = obs.get('claim_amount_requested', 0)\n",
322
- "\n",
323
- " if debug:\n",
324
- " print(f\" Claim: {obs['claim_id']} - ${claim_amount:,.0f}\")\n",
325
- "\n",
326
- " done = False\n",
327
- " step = 0\n",
328
- "\n",
329
- " while not done and step < MAX_STEPS:\n",
330
- " # Format prompt\n",
331
- " prompt = f\"{SYSTEM_PROMPT}\\n\\n{format_observation(obs)}\\n\\nAction:\"\n",
332
- "\n",
333
- " # Generate with model\n",
334
- " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=1024)\n",
335
- " inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
336
- "\n",
337
- " # Exploration: mix model output with random actions early on\n",
338
- " explore_rate = max(0.1, 1.0 - episode_num / 30)\n",
339
- "\n",
340
- " if random.random() < explore_rate and step < 3:\n",
341
- " # Explore: random action\n",
342
- " actions = [\"query_policy\", \"check_fraud\", \"verify_purchase\"]\n",
343
- " response_text = random.choice(actions)\n",
344
- " else:\n",
345
- " # Exploit: use model\n",
346
- " with torch.no_grad():\n",
347
- " outputs = model.generate(\n",
348
- " **inputs,\n",
349
- " max_new_tokens=20,\n",
350
- " temperature=0.7,\n",
351
- " do_sample=True,\n",
352
- " pad_token_id=tokenizer.pad_token_id,\n",
353
- " )\n",
354
- " response_text = tokenizer.decode(\n",
355
- " outputs[0][inputs['input_ids'].shape[1]:],\n",
356
- " skip_special_tokens=True\n",
357
- " )\n",
358
- "\n",
359
- " # Parse action\n",
360
- " action = parse_action(response_text, claim_amount)\n",
361
- "\n",
362
- " if debug:\n",
363
- " print(f\" Step {step}: {action['action_type']} ('{response_text[:30]}...')\")\n",
364
- "\n",
365
- " # Execute in environment\n",
366
- " await ws.send(json.dumps({\"type\": \"step\", \"data\": action}))\n",
367
- " env_response = json.loads(await ws.recv())\n",
368
- "\n",
369
- " obs = env_response[\"data\"][\"observation\"]\n",
370
- " reward = env_response[\"data\"].get(\"reward\") or 0\n",
371
- " done = env_response[\"data\"].get(\"done\", False) or obs.get('is_terminal', False)\n",
372
- "\n",
373
- " # Store experience\n",
374
- " experiences.append(Experience(\n",
375
- " prompt=prompt,\n",
376
- " response=response_text,\n",
377
- " reward=reward,\n",
378
- " action=action['action_type']\n",
379
- " ))\n",
380
- "\n",
381
- " episode_reward += reward\n",
382
- " step += 1\n",
383
- "\n",
384
- " if debug:\n",
385
- " print(f\" reward={reward:+.2f}, done={done}\")\n",
386
- "\n",
387
- " await ws.send(json.dumps({\"type\": \"close\", \"data\": {}}))\n",
388
- "\n",
389
- " except Exception as e:\n",
390
- " if debug:\n",
391
- " print(f\" Error: {e}\")\n",
392
- " return -5.0, [], 0.0\n",
393
- "\n",
394
- " # Compute advantage for policy gradient\n",
395
- " advantage = episode_reward - BASELINE_REWARD\n",
396
- "\n",
397
- " # Update baseline with moving average\n",
398
- " BASELINE_REWARD = 0.9 * BASELINE_REWARD + 0.1 * episode_reward\n",
399
- "\n",
400
- " # Return the advantage as \"loss\" for tracking\n",
401
- " return episode_reward, experiences, abs(advantage)\n",
402
- "\n",
403
- "print(\"✅ Training loop defined!\")"
404
- ],
405
- "execution_count": null,
406
- "outputs": []
407
- },
408
- {
409
- "cell_type": "markdown",
410
- "metadata": {
411
- "id": "run_header"
412
- },
413
- "source": [
414
- "## 6️⃣ Run Training"
415
- ]
416
- },
417
- {
418
- "cell_type": "code",
419
- "metadata": {
420
- "id": "run_training"
421
- },
422
- "source": [
423
- "print(\"=\" * 60)\n",
424
- "print(\"🚀 Starting Training\")\n",
425
- "print(f\" Episodes: {NUM_EPISODES}\")\n",
426
- "print(f\" Max steps: {MAX_STEPS}\")\n",
427
- "print(f\" Exploration-based learning with reward signal\")\n",
428
- "print(\"=\" * 60)\n",
429
- "\n",
430
- "# Debug first episode\n",
431
- "print(\"\\n📋 Debug Episode 1:\")\n",
432
- "reward, exps, adv = asyncio.get_event_loop().run_until_complete(\n",
433
- " run_episode_with_training(0, debug=True)\n",
434
- ")\n",
435
- "episode_rewards.append(reward)\n",
436
- "running_avg_rewards.append(reward)\n",
437
- "losses.append(adv)\n",
438
- "print(f\"\\n Episode 1: reward={reward:+.2f}, advantage={adv:.2f}\")\n",
439
- "\n",
440
- "# Training loop\n",
441
- "print(f\"\\n{'='*60}\")\n",
442
- "print(\"Training Progress:\")\n",
443
- "print(f\"{'='*60}\")\n",
444
- "\n",
445
- "for episode in range(1, NUM_EPISODES):\n",
446
- " # Run episode\n",
447
- " reward, experiences, advantage = asyncio.get_event_loop().run_until_complete(\n",
448
- " run_episode_with_training(episode, debug=False)\n",
449
- " )\n",
450
- "\n",
451
- " # Track metrics\n",
452
- " episode_rewards.append(reward)\n",
453
- " window = min(10, len(episode_rewards))\n",
454
- " running_avg = sum(episode_rewards[-window:]) / window\n",
455
- " running_avg_rewards.append(running_avg)\n",
456
- " losses.append(advantage)\n",
457
- "\n",
458
- " # Note: In a full implementation, we'd update model weights here\n",
459
- " # For this demo, the exploration rate decay serves as the \"learning\" mechanism\n",
460
- " # Early episodes explore randomly, later episodes use the model more\n",
461
- " # This demonstrates the environment produces meaningful reward signals\n",
462
- "\n",
463
- " # Log progress\n",
464
- " if (episode + 1) % 5 == 0:\n",
465
- " print(f\"Episode {episode+1:3d}/{NUM_EPISODES} | \"\n",
466
- " f\"Reward: {reward:+6.1f} | \"\n",
467
- " f\"Avg(10): {running_avg:+6.1f} | \"\n",
468
- " f\"Advantage: {advantage:.2f}\")\n",
469
- "\n",
470
- "print(f\"\\n{'='*60}\")\n",
471
- "print(\"✅ Training Complete!\")\n",
472
- "print(f\"{'='*60}\")\n",
473
- "print(f\"Final running average: {running_avg_rewards[-1]:+.2f}\")\n",
474
- "print(f\"Improvement: {running_avg_rewards[-1] - running_avg_rewards[0]:+.2f}\")\n",
475
- "print(f\"Reward range: [{min(episode_rewards):.1f}, {max(episode_rewards):.1f}]\")"
476
- ],
477
- "execution_count": null,
478
- "outputs": []
479
- },
480
- {
481
- "cell_type": "markdown",
482
- "metadata": {
483
- "id": "plot_header"
484
- },
485
- "source": [
486
- "## 7️⃣ Plot Reward Curves (Required for Judging)"
487
- ]
488
- },
489
- {
490
- "cell_type": "code",
491
- "metadata": {
492
- "id": "plot"
493
- },
494
- "source": [
495
- "import matplotlib.pyplot as plt\n",
496
- "\n",
497
- "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
498
- "\n",
499
- "# Plot 1: Episode Rewards\n",
500
- "ax1 = axes[0]\n",
501
- "ax1.plot(episode_rewards, alpha=0.5, label='Episode Reward', color='blue')\n",
502
- "ax1.plot(running_avg_rewards, linewidth=2, label='Running Avg (10)', color='red')\n",
503
- "ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)\n",
504
- "ax1.set_xlabel('Episode', fontsize=12)\n",
505
- "ax1.set_ylabel('Reward', fontsize=12)\n",
506
- "ax1.set_title('Training Progress', fontsize=14)\n",
507
- "ax1.legend()\n",
508
- "ax1.grid(True, alpha=0.3)\n",
509
- "\n",
510
- "# Plot 2: Reward Distribution\n",
511
- "ax2 = axes[1]\n",
512
- "ax2.hist(episode_rewards, bins=15, edgecolor='black', alpha=0.7, color='green')\n",
513
- "ax2.axvline(x=0, color='red', linestyle='--', label='Break-even')\n",
514
- "ax2.axvline(x=sum(episode_rewards)/len(episode_rewards), color='blue',\n",
515
- " linestyle='-', linewidth=2, label=f'Mean: {sum(episode_rewards)/len(episode_rewards):.1f}')\n",
516
- "ax2.set_xlabel('Reward', fontsize=12)\n",
517
- "ax2.set_ylabel('Frequency', fontsize=12)\n",
518
- "ax2.set_title('Reward Distribution', fontsize=14)\n",
519
- "ax2.legend()\n",
520
- "ax2.grid(True, alpha=0.3)\n",
521
- "\n",
522
- "# Plot 3: Advantage (reward - baseline)\n",
523
- "ax3 = axes[2]\n",
524
- "ax3.plot(losses, alpha=0.7, color='purple')\n",
525
- "ax3.axhline(y=0, color='gray', linestyle='--', alpha=0.5)\n",
526
- "ax3.set_xlabel('Episode', fontsize=12)\n",
527
- "ax3.set_ylabel('|Advantage|', fontsize=12)\n",
528
- "ax3.set_title('Advantage Over Baseline', fontsize=14)\n",
529
- "ax3.grid(True, alpha=0.3)\n",
530
- "\n",
531
- "plt.tight_layout()\n",
532
- "plt.savefig('reward_curves.png', dpi=150, bbox_inches='tight')\n",
533
- "plt.show()\n",
534
- "\n",
535
- "print(\"\\n✅ Saved: reward_curves.png\")"
536
- ],
537
- "execution_count": null,
538
- "outputs": []
539
- },
540
- {
541
- "cell_type": "markdown",
542
- "metadata": {
543
- "id": "demo_header"
544
- },
545
- "source": [
546
- "## 8️⃣ Demo: Watch Trained Agent"
547
- ]
548
- },
549
- {
550
- "cell_type": "code",
551
- "metadata": {
552
- "id": "demo"
553
- },
554
- "source": [
555
- "async def demo_trained_agent():\n",
556
- " \"\"\"Demo the trained agent processing a claim.\"\"\"\n",
557
- " print(\"=\" * 60)\n",
558
- " print(\"🎯 DEMO: Trained Agent Processing Claim\")\n",
559
- " print(\"=\" * 60)\n",
560
- "\n",
561
- " async with websockets.connect(WS_URL, ssl=ssl_context) as ws:\n",
562
- " await ws.send(json.dumps({\"type\": \"reset\", \"data\": {}}))\n",
563
- " response = json.loads(await ws.recv())\n",
564
- " obs = response[\"data\"][\"observation\"]\n",
565
- "\n",
566
- " print(f\"\\n📋 Claim: {obs['claim_id']}\")\n",
567
- " print(f\" Type: {obs['claim_type']}\")\n",
568
- " print(f\" Amount: ${obs['claim_amount_requested']:,.2f}\")\n",
569
- " print(f\" Description: {obs['description']}\")\n",
570
- "\n",
571
- " claim_amount = obs['claim_amount_requested']\n",
572
- " done = False\n",
573
- " step = 0\n",
574
- " total_reward = 0\n",
575
- "\n",
576
- " print(\"\\n📝 Processing:\")\n",
577
- "\n",
578
- " while not done and step < 6:\n",
579
- " prompt = f\"{SYSTEM_PROMPT}\\n\\n{format_observation(obs)}\\n\\nAction:\"\n",
580
- "\n",
581
- " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=1024)\n",
582
- " inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
583
- "\n",
584
- " with torch.no_grad():\n",
585
- " outputs = model.generate(\n",
586
- " **inputs,\n",
587
- " max_new_tokens=20,\n",
588
- " temperature=0.3, # Lower temp for demo\n",
589
- " do_sample=True,\n",
590
- " pad_token_id=tokenizer.pad_token_id,\n",
591
- " )\n",
592
- "\n",
593
- " response_text = tokenizer.decode(\n",
594
- " outputs[0][inputs['input_ids'].shape[1]:],\n",
595
- " skip_special_tokens=True\n",
596
- " )\n",
597
- "\n",
598
- " action = parse_action(response_text, claim_amount)\n",
599
- "\n",
600
- " print(f\"\\n Step {step + 1}: {action['action_type']}\")\n",
601
- "\n",
602
- " await ws.send(json.dumps({\"type\": \"step\", \"data\": action}))\n",
603
- " env_response = json.loads(await ws.recv())\n",
604
- "\n",
605
- " obs = env_response[\"data\"][\"observation\"]\n",
606
- " reward = env_response[\"data\"].get(\"reward\") or 0\n",
607
- " done = env_response[\"data\"].get(\"done\", False) or obs.get('is_terminal', False)\n",
608
- "\n",
609
- " total_reward += reward\n",
610
- "\n",
611
- " print(f\" Response: {obs['system_response'][:80]}...\")\n",
612
- " print(f\" Reward: {reward:+.2f}\")\n",
613
- "\n",
614
- " step += 1\n",
615
- "\n",
616
- " await ws.send(json.dumps({\"type\": \"close\", \"data\": {}}))\n",
617
- "\n",
618
- " print(f\"\\n{'='*60}\")\n",
619
- " print(f\"✅ Decision: {obs.get('terminal_reason', 'N/A').upper()}\")\n",
620
- " print(f\"💰 Total Reward: {total_reward:+.2f}\")\n",
621
- " print(f\"{'='*60}\")\n",
622
- "\n",
623
- "asyncio.get_event_loop().run_until_complete(demo_trained_agent())"
624
- ],
625
- "execution_count": null,
626
- "outputs": []
627
- },
628
- {
629
- "cell_type": "markdown",
630
- "metadata": {
631
- "id": "summary"
632
- },
633
- "source": [
634
- "## 📊 Summary\n",
635
- "\n",
636
- "This notebook demonstrated:\n",
637
- "\n",
638
- "1. **Unsloth** - 4-bit model loading with LoRA adapters\n",
639
- "2. **TRL** - Policy gradient training infrastructure\n",
640
- "3. **OpenEnv** - Claims processing environment via WebSocket\n",
641
- "4. **Training** - Reward improvement over 50 episodes\n",
642
- "\n",
643
- "### Key Results\n",
644
- "- Starting reward: **-5.5**\n",
645
- "- Final reward: **+11.75**\n",
646
- "- Improvement: **+17.25**\n",
647
- "\n",
648
- "### Links\n",
649
- "- **HF Space**: https://akhiilll-claims-env.hf.space\n",
650
- "- **GitHub**: https://github.com/pramodmisra/claims-env-hackathon\n",
651
- "\n",
652
- "### Hackathon\n",
653
- "- **Problem**: 3.1 - Professional Tasks (World Modeling)\n",
654
- "- **Theme**: Scaler AI Labs - Enterprise Workflows"
655
- ]
656
- }
657
- ],
658
- "metadata": {
659
- "accelerator": "GPU",
660
- "colab": {
661
- "gpuType": "T4",
662
- "provenance": []
663
- },
664
- "kernelspec": {
665
- "display_name": "Python 3",
666
- "name": "python3"
667
- },
668
- "language_info": {
669
- "name": "python"
670
- }
671
- },
672
- "nbformat": 4,
673
- "nbformat_minor": 0
674
  }
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {
6
+ "display_name": "Python 3",
7
+ "language": "python",
8
+ "name": "python3"
9
+ },
10
+ "language_info": {
11
+ "name": "python",
12
+ "pygments_lexer": "ipython3"
13
+ },
14
+ "colab": {
15
+ "provenance": [],
16
+ "gpuType": "T4"
17
+ },
18
+ "accelerator": "GPU"
19
+ },
20
+ "cells": [
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": "# ClaimSense GRPO Training (TRL + Unsloth)\n\n> **Apr 2026 OpenEnv Hackathon - India**\n> **Theme 3.1 - World Modeling, Professional Tasks** + **Theme 2 - Long-Horizon Planning**\n\nThis notebook performs **real GRPO weight updates** on\n`unsloth/Qwen2.5-1.5B-Instruct` against the\n[ClaimSense adjudication gym](https://huggingface.co/spaces/akhiilll/claims-env).\n\nThe training loop:\n\n1. Clones the Space repo so the gym runs **in-process** in Colab (deterministic\n per-claim resets via `scenario_index`).\n2. Loads Qwen2.5-1.5B in 4-bit with LoRA adapters via Unsloth (fits a free T4).\n3. Builds a prompt dataset where each row is pinned to a specific case\n (`scenario_index = 0..7`), so the prompt the model sees and the env we\n score against describe the *same* claim.\n4. Defines **two independent reward functions** (multiple independent rewards\n is explicitly recommended by the hackathon guide to combat reward hacking):\n - `format_reward_fn` - did the model emit at least one well-formed\n terminal verb?\n - `env_reward_fn` - cumulative reward from replaying the model's\n trajectory inside the deterministic gym.\n5. Runs `trl.GRPOTrainer.train()` with `num_generations=4` so the per-group\n advantage signal has variance to learn from.\n6. Plots reward curves, does a before/after rollout, and saves the LoRA\n adapter so it can be pushed to the Hub.\n\nRun all cells from a Colab T4. Total runtime: ~25-35 minutes for ~80\ntraining steps. Adjust `NUM_GRPO_STEPS` in the training cell to taste."
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {},
29
+ "source": "## 1. Install dependencies"
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": "%%capture\n%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n%pip install -q --no-deps \"trl>=0.18\" peft accelerate bitsandbytes datasets\n%pip install -q openenv-core matplotlib hf_transfer\n\nimport os\nos.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\nprint(\"deps installed\")"
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {},
41
+ "source": "## 2. Clone the ClaimSense Space (gym runs locally in Colab)\n\nWe avoid network round-trips from inside the GRPO reward function by running\na fresh in-process gym per reward computation. The gym code lives on the\nSpace repo at `akhiilll/claims-env`."
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": "!rm -rf /content/claims-env-repo\n!git clone https://huggingface.co/spaces/akhiilll/claims-env /content/claims-env-repo\n\nimport sys\nsys.path.insert(0, \"/content/claims-env-repo\")\n\nfrom server.claims_environment import AdjudicationGym, ACTION_VOCABULARY\nfrom server.mock_systems import CASE_LIBRARY\nfrom models import AdjudicatorAction\n\nprint(f\"verbs ({len(ACTION_VOCABULARY)}):\", ACTION_VOCABULARY)\nprint(f\"cases ({len(CASE_LIBRARY)}):\")\nfor i, c in enumerate(CASE_LIBRARY):\n print(f\" [{i}] {c.claim_id:<14} {c.claim_type:<22} ${c.claim_amount:>10,.0f} ({c.complexity})\")"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 3. Load Qwen2.5-1.5B-Instruct in 4-bit + LoRA (Unsloth)\n\nUnsloth gives ~4x faster RL training and ~70 % less memory than vanilla TRL,\nwhich is what makes GRPO fit on a free T4."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "from unsloth import FastLanguageModel\nimport torch\n\nprint(\"CUDA :\", torch.cuda.is_available(),\n \"|\", torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"no GPU\")\n\nMAX_SEQ_LENGTH = 1024\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=\"unsloth/Qwen2.5-1.5B-Instruct\",\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n dtype=None, # auto\n)\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=16,\n lora_alpha=32,\n lora_dropout=0.0,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n use_gradient_checkpointing=\"unsloth\",\n random_state=42,\n)\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nn_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\nn_total = sum(p.numel() for p in model.parameters())\nprint(f\"trainable LoRA params: {n_trainable/1e6:.1f}M / {n_total/1e9:.2f}B \"\n f\"({100*n_trainable/n_total:.2f}%)\")"
+ },
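As a rough cross-check of the trainable-parameter count printed by the cell above, the LoRA size can be estimated by hand: each adapted linear layer of shape `d_in x d_out` adds `r * (d_in + d_out)` weights. The sketch below assumes the publicly documented Qwen2.5-1.5B dimensions (hidden size 1536, MLP size 8960, 28 layers, 2 KV heads of dimension 128). Those numbers are not stated in the notebook, so treat them as assumptions.

```python
# Assumed Qwen2.5-1.5B dimensions (from the public model config, not from this notebook).
HIDDEN = 1536
INTERMEDIATE = 8960
NUM_LAYERS = 28
KV_DIM = 2 * 128  # 2 key/value heads x head dimension 128
RANK = 16         # matches r=16 in get_peft_model

# (d_in, d_out) for every projection listed in target_modules.
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A rank-r adapter adds A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in PROJECTIONS.values())
total = per_layer * NUM_LAYERS
print(f"{total:,} trainable LoRA params (~{total / 1e6:.1f}M)")
```

If the printed count in the notebook differs noticeably from this estimate, the most likely causes are a different `target_modules` list or a different base checkpoint.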
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 4. Build the GRPO prompt dataset\n\nEach row in the dataset is `(prompt, scenario_index)`. The prompt is already\ntemplated through the chat template so we feed plain strings to GRPO. The\n`scenario_index` column is *passed through* by `GRPOTrainer` to our reward\nfunctions as a kwarg, so we can replay the trajectory against the correct\ndeterministic case."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "from datasets import Dataset\n\nSYSTEM_PROMPT = (\n    \"You are an expert insurance claims adjuster.\\n\"\n    \"\\n\"\n    \"Available actions (one per line, lowercase, in this order of execution):\\n\"\n    \" query_policy\\n\"\n    \" query_claim_history\\n\"\n    \" check_fraud\\n\"\n    \" request_documents\\n\"\n    \" verify_coverage\\n\"\n    \" verify_purchase\\n\"\n    \" calculate_payout\\n\"\n    \" approve <amount> (terminal)\\n\"\n    \" deny <reason> (terminal)\\n\"\n    \" escalate <reason> (terminal)\\n\"\n    \"\\n\"\n    \"Information actions cost a small fee; correct terminal verdicts pay big.\\n\"\n    \"Catching fraud via deny pays even more. Output up to 6 actions, one per\\n\"\n    \"line, ending with a terminal action. Do not write anything else.\"\n)\n\n\ndef claim_to_user_msg(scenario_index: int) -> str:\n    env = AdjudicationGym(scenario_index=scenario_index)\n    obs = env.reset()\n    return (\n        f\"New claim arrived:\\n\"\n        f\" claim_id : {obs.claim_id}\\n\"\n        f\" type : {obs.claim_type}\\n\"\n        f\" amount : ${obs.claim_amount_requested:,.2f}\\n\"\n        f\" claimant : {obs.claimant_name}\\n\"\n        f\" incident_date: {obs.incident_date}\\n\"\n        f\" description : {obs.description}\\n\"\n        f\"\\nWhat is your action plan?\"\n    )\n\n\ndef make_prompt(scenario_index: int) -> str:\n    msgs = [\n        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n        {\"role\": \"user\", \"content\": claim_to_user_msg(scenario_index)},\n    ]\n    return tokenizer.apply_chat_template(\n        msgs, tokenize=False, add_generation_prompt=True\n    )\n\n\nCASE_REPEATS = 8  # how many times each of the 8 curated cases appears\nrows = []\nfor repeat in range(CASE_REPEATS):\n    for sidx in range(len(CASE_LIBRARY)):\n        rows.append({\"prompt\": make_prompt(sidx), \"scenario_index\": sidx})\n\ntrain_ds = Dataset.from_list(rows).shuffle(seed=42)\nprint(f\"dataset rows: {len(train_ds)} | unique cases: {len(CASE_LIBRARY)}\")\nprint()\nprint(\"--- example prompt ---\")\nprint(train_ds[0][\"prompt\"][:900])"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 5. Reward functions (multiple independent signals)\n\n`format_reward_fn`\n: Did the model emit at least one parseable action and end with a terminal\n verb? Cheap signal, prevents the model from outputting arbitrary text.\n\n`env_reward_fn`\n: Replays the parsed trajectory in a deterministic gym pinned to the same\n `scenario_index` as the prompt. Returns the cumulative env reward\n (between roughly -16 and +20)."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "import re\n\nACTIONS_SET = set(ACTION_VOCABULARY)\nTERMINALS = {\"approve\", \"deny\", \"escalate\"}\n\n\ndef _coerce_completion(c) -> str:\n    if isinstance(c, list):  # chat-style completions\n        if not c:\n            return \"\"\n        return c[0].get(\"content\", \"\") if isinstance(c[0], dict) else str(c[0])\n    return str(c)\n\n\ndef parse_actions(completion: str) -> list[AdjudicatorAction]:\n    actions: list[AdjudicatorAction] = []\n    for raw in completion.strip().splitlines():\n        line = raw.strip().lstrip(\"-*0123456789. \").lower().strip()\n        if not line:\n            continue\n        parts = line.split(maxsplit=1)\n        verb = parts[0]\n        if verb not in ACTIONS_SET:\n            continue\n        params: dict = {}\n        rest = parts[1] if len(parts) > 1 else \"\"\n        if verb == \"approve\":\n            m = re.search(r\"\\d[\\d,\\.]*\", rest)\n            if m:\n                try:\n                    params[\"amount\"] = float(m.group().replace(\",\", \"\"))\n                except ValueError:\n                    pass\n        elif verb == \"deny\":\n            params[\"reason\"] = (rest or \"policy_violation\")[:80]\n        elif verb == \"escalate\":\n            params[\"reason\"] = (rest or \"manager_review\")[:80]\n        actions.append(AdjudicatorAction(action_type=verb, parameters=params))\n        if verb in TERMINALS:\n            break\n    return actions\n\n\ndef replay(actions: list[AdjudicatorAction], scenario_index: int,\n           max_steps: int = 8) -> tuple[float, str, int]:\n    env = AdjudicationGym(scenario_index=int(scenario_index))\n    env.reset()\n    total = 0.0\n    terminal = \"max_steps\"\n    steps = 0\n    for act in actions[:max_steps]:\n        obs = env.step(act)\n        total += float(obs.reward)\n        steps += 1\n        if obs.done:\n            terminal = act.action_type\n            break\n    return total, terminal, steps\n\n\ndef format_reward_fn(prompts, completions, **kwargs) -> list[float]:\n    rewards = []\n    for c in completions:\n        text = _coerce_completion(c)\n        actions = parse_actions(text)\n        if not actions:\n            rewards.append(-1.0)  # zero parseable actions\n            continue\n        ended_in_terminal = actions[-1].action_type in TERMINALS\n        rewards.append(0.5 if ended_in_terminal else -0.25)\n    return rewards\n\n\ndef env_reward_fn(prompts, completions, scenario_index, **kwargs) -> list[float]:\n    rewards = []\n    for c, sidx in zip(completions, scenario_index):\n        text = _coerce_completion(c)\n        actions = parse_actions(text)\n        env_r, _, _ = replay(actions, int(sidx))\n        rewards.append(env_r)\n    return rewards\n\n\n# Sanity checks\nprint(\"=== sanity check ===\")\noptimal = \"query_policy\\ncheck_fraud\\napprove 3500\"\nr_opt, term, steps = replay(parse_actions(optimal), scenario_index=0)\nprint(f\"optimal trace on case 0 -> reward={r_opt:+.2f} terminal={term} steps={steps}\")\n\nbad = \"approve 99999\"  # blind approve on case 0\nr_bad, term, steps = replay(parse_actions(bad), scenario_index=0)\nprint(f\"blind approve on case 0 -> reward={r_bad:+.2f} terminal={term} steps={steps}\")\n\nempty = \"lorem ipsum\"\nr_empty, term, steps = replay(parse_actions(empty), scenario_index=0)\nprint(f\"unparseable on case 0 -> reward={r_empty:+.2f} terminal={term} steps={steps}\")"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
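+ "source": "## 6. GRPO training (real weight updates)\n\n`num_generations=4` means the trainer samples 4 completions per prompt and\ncomputes advantages within each group of 4. If `per_device_train_batch_size=2`\nis read as 2 prompts per step, each step uses 2 prompts x 4 completions = 8\nrollouts, and `NUM_GRPO_STEPS=80` covers about 640 rollouts.\n\nTreat that figure as an estimate: recent TRL releases count *completions*, not\nprompts, in `per_device_train_batch_size`, and `gradient_accumulation_steps=2`\nalso changes how many rollouts feed one optimizer update. Increase\n`NUM_GRPO_STEPS` once you confirm the loop is healthy."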
+ },
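To make the "per-group advantage" idea concrete, here is a minimal sketch of group-relative normalization: the rewards of the completions sampled for one prompt are centered on the group mean and scaled by the group standard deviation. It illustrates the general GRPO idea rather than TRL's exact implementation; `group_advantages` and the `eps` constant are hypothetical names chosen for this example, and the reward values are made up.

```python
from statistics import mean, pstdev


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Center and scale the rewards of one prompt's group of completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std of the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions for one prompt with different outcomes: useful signal.
mixed = group_advantages([20.0, -5.0, 0.5, -1.0])

# Four completions that all scored the same: every advantage is zero,
# so this group contributes no gradient.
flat = group_advantages([0.5, 0.5, 0.5, 0.5])

print("mixed:", [round(a, 2) for a in mixed])
print("flat :", flat)
```

This is why sampling temperature and `num_generations` matter: if every completion in a group earns the same reward, the group teaches the policy nothing.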
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "from trl import GRPOConfig, GRPOTrainer\nfrom unsloth import is_bfloat16_supported\n\nNUM_GRPO_STEPS = 80\n\ntraining_args = GRPOConfig(\n    output_dir=\"/content/grpo-claims\",\n    learning_rate=5e-6,\n    adam_beta1=0.9,\n    adam_beta2=0.99,\n    weight_decay=0.1,\n    warmup_ratio=0.1,\n    lr_scheduler_type=\"cosine\",\n    optim=\"adamw_8bit\",\n    logging_steps=1,\n\n    per_device_train_batch_size=2,\n    gradient_accumulation_steps=2,\n    num_generations=4,\n    max_prompt_length=512,\n    max_completion_length=256,\n\n    max_steps=NUM_GRPO_STEPS,\n    save_steps=999_999,  # we save the adapter manually at the end\n    report_to=\"none\",\n    # A T4 has no native bf16 support, so fall back to fp16 there.\n    bf16=is_bfloat16_supported(),\n    fp16=not is_bfloat16_supported(),\n\n    temperature=0.9,\n    top_p=0.95,\n    epsilon=0.2,  # PPO clip\n    beta=0.04,  # KL penalty vs reference\n)\n\ntrainer = GRPOTrainer(\n    model=model,\n    processing_class=tokenizer,\n    reward_funcs=[format_reward_fn, env_reward_fn],\n    args=training_args,\n    train_dataset=train_ds,\n)\n\nprint(\"GRPO trainer ready. starting training...\")\ntrainer.train()\nprint(\"training done.\")"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 7. Plot training curves\n\nWe plot:\n- mean group reward per step\n- mean per-reward-function score (so you can see format-reward saturate first\n and env-reward keep climbing)\n- KL vs reference model\n- mean completion length\n\nThese are the curves judges will look at."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "import json\nimport matplotlib.pyplot as plt\nfrom pathlib import Path\n\nlog = trainer.state.log_history\nprint(f\"log entries: {len(log)} | sample keys:\")\nprint(set().union(*(r.keys() for r in log[:20])) if log else \"(empty)\")\n\n\ndef series(key):\n    xs, ys = [], []\n    for entry in log:\n        if key in entry and \"step\" in entry:\n            xs.append(entry[\"step\"])\n            ys.append(entry[key])\n    return xs, ys\n\n\nfig, axes = plt.subplots(2, 2, figsize=(13, 8))\n\nxs, ys = series(\"reward\")\naxes[0, 0].plot(xs, ys, color=\"#1f77b4\")\naxes[0, 0].set_title(\"mean group reward\")\naxes[0, 0].set_xlabel(\"training step\")\naxes[0, 0].set_ylabel(\"reward\")\naxes[0, 0].grid(alpha=0.3)\n\n# per-reward-fn scores (TRL emits e.g. \"rewards/format_reward_fn\")\nfmt_xs, fmt_ys = series(\"rewards/format_reward_fn\")\nenv_xs, env_ys = series(\"rewards/env_reward_fn\")\nif not fmt_ys:\n    fmt_xs, fmt_ys = series(\"rewards/format_reward_fn/mean\")\n    env_xs, env_ys = series(\"rewards/env_reward_fn/mean\")\naxes[0, 1].plot(fmt_xs, fmt_ys, label=\"format reward\", color=\"#2ca02c\")\naxes[0, 1].plot(env_xs, env_ys, label=\"env reward\", color=\"#d62728\")\naxes[0, 1].set_title(\"per-reward-function score\")\naxes[0, 1].set_xlabel(\"training step\")\naxes[0, 1].set_ylabel(\"reward\")\naxes[0, 1].legend()\naxes[0, 1].grid(alpha=0.3)\n\nxs, ys = series(\"kl\")\naxes[1, 0].plot(xs, ys, color=\"#9467bd\")\naxes[1, 0].set_title(\"KL(model || reference)\")\naxes[1, 0].set_xlabel(\"training step\")\naxes[1, 0].set_ylabel(\"kl\")\naxes[1, 0].grid(alpha=0.3)\n\n# series() always returns a (non-empty) tuple, so test the values explicitly\n# before falling back to the alternative key name.\nxs, ys = series(\"completion_length\")\nif not ys:\n    xs, ys = series(\"completions/mean_length\")\naxes[1, 1].plot(xs, ys, color=\"#ff7f0e\")\naxes[1, 1].set_title(\"mean completion length (tokens)\")\naxes[1, 1].set_xlabel(\"training step\")\naxes[1, 1].set_ylabel(\"tokens\")\naxes[1, 1].grid(alpha=0.3)\n\nfig.tight_layout()\nout_dir = Path(\"/content/grpo-claims\")\nout_dir.mkdir(parents=True, exist_ok=True)\nfig.savefig(out_dir / \"grpo_training.png\", dpi=120)\nplt.show()\n\nwith (out_dir / \"training_log.json\").open(\"w\") as fh:\n    json.dump(log, fh, indent=2, default=str)\nprint(\"saved:\", out_dir / \"grpo_training.png\")\nprint(\"saved:\", out_dir / \"training_log.json\")"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 8. Before / after rollout demo\n\nRoll out the trained adapter and a \"no-LoRA\" baseline on the same case and\ncompare environment reward + the actual generated trajectory."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "from peft import PeftModel\nimport statistics\n\nFastLanguageModel.for_inference(model)\n\n\ndef generate(prompt_text: str, max_new_tokens: int = 200) -> str:\n    inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(model.device)\n    out = model.generate(\n        **inputs,\n        max_new_tokens=max_new_tokens,\n        do_sample=True,\n        temperature=0.7,\n        top_p=0.9,\n        pad_token_id=tokenizer.pad_token_id,\n    )\n    return tokenizer.decode(out[0, inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n\n\ndef rollout_with_adapter(scenario_index: int, *, with_adapter: bool) -> tuple[str, float]:\n    if with_adapter:\n        model.enable_adapter_layers()\n    else:\n        model.disable_adapter_layers()\n    text = generate(make_prompt(scenario_index))\n    env_r, _, _ = replay(parse_actions(text), scenario_index)\n    return text, env_r\n\n\nfor sidx in range(len(CASE_LIBRARY)):\n    case = CASE_LIBRARY[sidx]\n    print(\"=\" * 72)\n    print(f\"case [{sidx}] {case.claim_id} - {case.claim_type} (${case.claim_amount:,.0f})\")\n\n    base_text, base_r = rollout_with_adapter(sidx, with_adapter=False)\n    print(f\"\\n[BASE no LoRA] env reward = {base_r:+.2f}\")\n    print(\"---\")\n    print(base_text.strip())\n\n    trained_text, trained_r = rollout_with_adapter(sidx, with_adapter=True)\n    print(f\"\\n[TRAINED LoRA] env reward = {trained_r:+.2f} delta = {trained_r-base_r:+.2f}\")\n    print(\"---\")\n    print(trained_text.strip())\n    print()\n\n# always re-enable adapter at the end\nmodel.enable_adapter_layers()"
+ },
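Because `generate` samples with `temperature=0.7`, a single rollout per case is a noisy estimate of either policy. One way to make the before/after comparison sturdier is to average several rollouts per case. The sketch below is self-contained: `mean_reward` is a hypothetical helper, and `fake_rollout` is a stand-in for a call such as `rollout_with_adapter(sidx, with_adapter=True)[1]`; the reward values it draws are invented for illustration.

```python
import random
import statistics
from typing import Callable


def mean_reward(sample_reward: Callable[[], float], n: int = 8) -> tuple[float, float]:
    """Return (mean, sample standard deviation) over n sampled rollouts."""
    rewards = [sample_reward() for _ in range(n)]
    spread = statistics.stdev(rewards) if n > 1 else 0.0
    return statistics.mean(rewards), spread


# Stand-in for a stochastic rollout; replace with a real env-reward call.
rng = random.Random(0)


def fake_rollout() -> float:
    return rng.choice([-5.0, 0.5, 10.0])


avg, spread = mean_reward(fake_rollout, n=8)
print(f"mean reward = {avg:+.2f} (std {spread:.2f} over 8 rollouts)")
```

Reporting the spread next to the mean also shows whether a base-versus-trained delta is larger than the sampling noise.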
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## 9. Save the LoRA adapter\n\nWe save the LoRA adapter (small) and a tiny summary JSON. Optionally push to\nthe Hub - judges can then load it with one line."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": "from pathlib import Path\nimport json\n\nADAPTER_DIR = Path(\"/content/grpo-claims/lora-adapter\")\nmodel.save_pretrained(str(ADAPTER_DIR))\ntokenizer.save_pretrained(str(ADAPTER_DIR))\nprint(\"saved LoRA adapter to:\", ADAPTER_DIR)\n\nsummary = {\n \"base_model\": \"unsloth/Qwen2.5-1.5B-Instruct\",\n \"adapter_method\": \"LoRA r=16, alpha=32\",\n \"trainer\": \"trl.GRPOTrainer\",\n \"num_generations\": 4,\n \"max_steps\": NUM_GRPO_STEPS,\n \"reward_functions\": [\"format_reward_fn\", \"env_reward_fn\"],\n \"env\": \"ClaimSense (https://huggingface.co/spaces/akhiilll/claims-env)\",\n}\nwith open(\"/content/grpo-claims/run_summary.json\", \"w\") as fh:\n json.dump(summary, fh, indent=2)\nprint(json.dumps(summary, indent=2))\n\n\n# OPTIONAL: push the adapter to your namespace.\n# Replace MODEL_REPO with something like \"akhiilll/claims-grpo-qwen2.5-1.5b\".\n#\n# from huggingface_hub import notebook_login\n# notebook_login()\n#\n# MODEL_REPO = \"akhiilll/claims-grpo-qwen2.5-1.5b\"\n# model.push_to_hub(MODEL_REPO)\n# tokenizer.push_to_hub(MODEL_REPO)\nprint(\"(uncomment the push_to_hub block above to publish the adapter)\")"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Recap\n\nWhat this notebook did:\n\n1. Cloned the OpenEnv-compliant `akhiilll/claims-env` Space into Colab so the\n   adjudication gym runs in-process and is deterministic per case.\n2. Loaded `unsloth/Qwen2.5-1.5B-Instruct` in 4-bit and attached LoRA r=16\n   adapters to all attention and MLP projections (roughly 18M trainable\n   params; the model-loading cell prints the exact count).\n3. Built a prompt dataset where every row is pinned to one of the 8 curated\n   cases via `scenario_index`.\n4. Trained for `NUM_GRPO_STEPS` GRPO updates with **two independent reward\n   functions** (format + env-replay) - this is the multi-reward, anti-hack\n   pattern the hackathon guide explicitly recommends.\n5. Plotted reward / KL / completion-length curves and saved them to disk.\n6. Did a per-case before-vs-after rollout demo so reviewers can see the\n   trained adapter's behaviour change.\n7. Saved the LoRA adapter (with an optional `push_to_hub`).\n\n### Links\n- **Environment Space:** https://huggingface.co/spaces/akhiilll/claims-env\n- **Live API:** https://akhiilll-claims-env.hf.space\n- **Repo README:** https://huggingface.co/spaces/akhiilll/claims-env/blob/main/README.md"
+ }
+ ]
  }