OGrohit commited on
Commit
7a0f038
Β·
verified Β·
1 Parent(s): f191fd4

For Judges To Train And Test Script

Browse files
Files changed (1) hide show
  1. LogTriageEnv_Training.ipynb +352 -0
LogTriageEnv_Training.ipynb ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# LogTriageEnv: Training LLM Agents to Triage Production Incidents\n",
8
+ "\n",
9
+ "**Meta Γ— PyTorch Γ— Scaler OpenEnv Grand Finale 2026**\n",
10
+ "\n",
11
+ "This notebook trains an LLM agent with GRPO to identify root causes in cascading production failures.\n",
12
+ "\n",
13
+ "## Quick Info\n",
14
+ "- **GPU:** T4+ required (15GB+ VRAM)\n",
15
+ "- **Time:** 10-15 minutes\n",
16
+ "- **Model:** Auto-selects 32B→7B→3B based on VRAM\n",
17
+ "- **Output:** Trained model + reward curves"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## Step 1: Check GPU"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "!nvidia-smi"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "import torch\n",
43
+ "\n",
44
+ "print(\"[GPU CHECK]\")\n",
45
+ "if torch.cuda.is_available():\n",
46
+ " vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
47
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
48
+ " print(f\"VRAM: {vram_gb:.1f} GB\")\n",
49
+ " VRAM_GB = vram_gb\n",
50
+ "else:\n",
51
+ " print(\"No GPU found\")\n",
52
+ " VRAM_GB = 0"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {},
58
+ "source": [
59
+ "## Step 2: Install Dependencies"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": "print(\"Installing dependencies in correct order...\")\nprint(\"Step 1: Upgrade pip\")\n!pip install -q -U pip\nprint(\"Step 2: Install Unsloth FIRST (critical for patching)\")\n!pip install -q unsloth\nprint(\"Step 3: Install PyTorch\")\n!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\nprint(\"Step 4: Install remaining packages\")\n!pip install -q bitsandbytes peft trl transformers datasets accelerate matplotlib requests huggingface_hub\nprint(\"βœ“ All dependencies installed successfully\")"
68
+ },
69
+ {
70
+ "cell_type": "markdown",
71
+ "metadata": {},
72
+ "source": [
73
+ "## Step 3: Optional - HuggingFace Login\n",
74
+ "\n",
75
+ "Skip this if you just want local training. Uncomment to push to Hub."
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "# Optional: Uncomment to login\n",
85
+ "# from huggingface_hub import login\n",
86
+ "# login()\n",
87
+ "\n",
88
+ "print(\"HF login: SKIPPED (model will save locally)\")"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "## Step 4: Clone Repository"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": "import os\n\nif not os.path.exists('logtriage-env'):\n !git clone https://github.com/rohitdecodes/logtriage-env.git\n os.chdir('logtriage-env')\nelse:\n os.chdir('logtriage-env')\n\nprint(f\"Working dir: {os.getcwd()}\")"
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {},
108
+ "source": [
109
+ "## Step 5: The Problem\n",
110
+ "\n",
111
+ "### Scenario: Production Incident at 2 AM\n",
112
+ "\n",
113
+ "Six services firing alerts:\n",
114
+ "```\n",
115
+ "api-gateway β†’ ERROR: timeout (most visible)\n",
116
+ "auth-service β†’ WARN: connection pool exhausted\n",
117
+ "user-db β†’ ERROR: slow query\n",
118
+ "payment-db β†’ [no logs yet] (ROOT CAUSE - 3 hops upstream)\n",
119
+ "```\n",
120
+ "\n",
121
+ "**Question:** Which service to page first?\n",
122
+ "\n",
123
+ "**Naive Answer:** api-gateway ❌\n",
124
+ "\n",
125
+ "**Correct Answer:** payment-db βœ…\n",
126
+ "\n",
127
+ "### Why It's Hard\n",
128
+ "- Root cause **never logs first**\n",
129
+ "- Symptoms cascade before causes appear\n",
130
+ "- Agent must reason **backward** through dependencies\n",
131
+ "- LLaMA 3.3 70B baseline: only 0.65 accuracy\n",
132
+ "\n",
133
+ "### How We Train\n",
134
+ "GRPO with dense reward shaping forces causal reasoning:\n",
135
+ "- +0.3 for correct root cause\n",
136
+ "- +0.3 for correct escalation\n",
137
+ "- +0.3 for correct fix\n",
138
+ "- **0 for wrong combinations**"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "metadata": {},
144
+ "source": [
145
+ "## Step 6: Intelligent Model Selection"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "print(\"[MODEL SELECTION]\")\n",
155
+ "\n",
156
+ "if VRAM_GB >= 24:\n",
157
+ " model_id = \"Qwen/Qwen2.5-32B-Instruct\"\n",
158
+ " model_size = \"32B (BEST)\"\n",
159
+ " improvement = \"+0.12 to +0.15\"\n",
160
+ " print(f\"βœ“ {VRAM_GB:.1f} GB VRAM\")\n",
161
+ " print(f\"βœ“ Selected: {model_size}\")\nelif VRAM_GB >= 10:\n",
162
+ " model_id = \"Qwen/Qwen2.5-7B-Instruct\"\n",
163
+ " model_size = \"7B (GOOD)\"\n",
164
+ " improvement = \"+0.04 to +0.06\"\n",
165
+ " print(f\"βœ“ {VRAM_GB:.1f} GB VRAM\")\n",
166
+ " print(f\"βœ“ Selected: {model_size}\")\nelse:\n",
167
+ " model_id = \"Qwen/Qwen2.5-3B-Instruct\"\n",
168
+ " model_size = \"3B (FALLBACK)\"\n",
169
+ " improvement = \"+0.015\"\n",
170
+ " print(f\"⚠ {VRAM_GB:.1f} GB VRAM (limited)\")\n",
171
+ " print(f\"⚠ Selected: {model_size}\")\n",
172
+ "\nprint()\nprint(f\"Model: {model_id}\")\nprint(f\"Expected cascading_failure improvement: {improvement}\")"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {},
178
+ "source": [
179
+ "## Step 7: Launch Training\n",
180
+ "\n",
181
+ "⏱️ This takes ~10-15 minutes"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": [
190
+ "import subprocess\n",
191
+ "\n",
192
+ "print(\"\\n\" + \"=\"*60)\n",
193
+ "print(\"[START] LogTriageEnv Training\")\n",
194
+ "print(\"=\"*60)\n",
195
+ "print(f\"Model: {model_id}\")\n",
196
+ "print(f\"Episodes: 30 per task (90 total)\")\n",
197
+ "print(f\"Algorithm: GRPO + 4-bit Unsloth\")\n",
198
+ "print(\"=\"*60)\nprint()\n",
199
+ "\n",
200
+ "cmd = [\n",
201
+ " \"python\", \"train.py\",\n",
202
+ " \"--model\", model_id,\n",
203
+ " \"--task\", \"all\",\n",
204
+ " \"--episodes\", \"30\",\n",
205
+ " \"--load_in_4bit\",\n",
206
+ " \"--grpo_max_steps\", \"10\",\n",
207
+ " \"--env_url\", \"https://ogrohit-logtriage-env.hf.space\"\n",
208
+ "]\n",
209
+ "\n",
210
+ "try:\n",
211
+ " subprocess.run(cmd, check=True)\n",
212
+ " print(\"\\n\" + \"=\"*60)\n",
213
+ " print(\"βœ“ TRAINING COMPLETE\")\n",
214
+ " print(\"=\"*60)\nexcept Exception as e:\n",
215
+ " print(f\"Error: {e}\")"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {},
221
+ "source": [
222
+ "## Step 8: Analyze Results"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "import json\n",
232
+ "import os\n",
233
+ "\n",
234
+ "print(\"\\n\" + \"=\"*60)\n",
235
+ "print(\"RESULTS\")\n",
236
+ "print(\"=\"*60)\nprint()\n",
237
+ "\n",
238
+ "tasks = [\"single_crash\", \"cascading_failure\", \"silent_degradation\"]\n",
239
+ "\n",
240
+ "for task in tasks:\n",
241
+ " checkpoint_file = f\"./phase2_checkpoints/{task}_ep25.json\"\n",
242
+ " \n",
243
+ " if os.path.exists(checkpoint_file):\n",
244
+ " with open(checkpoint_file, 'r') as f:\n",
245
+ " data = json.load(f)\n",
246
+ " \n",
247
+ " rewards = [ep.get('reward', 0) for ep in data.get('episodes', [])]\n",
248
+ " \n",
249
+ " if rewards:\n",
250
+ " first_10 = sum(rewards[:10]) / 10\n",
251
+ " last_10 = sum(rewards[-10:]) / 10\n",
252
+ " improvement = last_10 - first_10\n",
253
+ " \n",
254
+ " symbol = \"βœ“\" if improvement > 0 else \"↓\"\n",
255
+ " task_name = task.replace(\"_\", \" \").title()\n",
256
+ " \n",
257
+ " print(f\"{symbol} {task_name}\")\n",
258
+ " print(f\" First 10 avg: {first_10:+.3f}\")\n",
259
+ " print(f\" Last 10 avg: {last_10:+.3f}\")\n",
260
+ " print(f\" Improvement: {improvement:+.3f}\")\n",
261
+ " print()\n",
262
+ "\nprint(\"=\"*60)\nprint(\"βœ“ Key metric: Cascading Failure improvement\")\nprint(\" (Shows genuine multi-hop causal learning)\")"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "markdown",
267
+ "metadata": {},
268
+ "source": [
269
+ "## Step 9: Visualize Reward Curves"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "import os\n",
279
+ "\n",
280
+ "if os.path.exists(\"merge_curves.py\"):\n",
281
+ " !python merge_curves.py\n",
282
+ " print(\"βœ“ Curves generated\")\nelse:\n",
283
+ " print(\"[INFO] merge_curves.py not found\")"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": [
292
+ "import matplotlib.pyplot as plt\n",
293
+ "from PIL import Image\n",
294
+ "import os\n",
295
+ "\n",
296
+ "if os.path.exists(\"reward_curve.png\"):\n",
297
+ " img = Image.open(\"reward_curve.png\")\n",
298
+ " plt.figure(figsize=(14, 8))\n",
299
+ " plt.imshow(img)\n",
300
+ " plt.axis('off')\n",
301
+ " plt.title(\"Training Reward Curves\", fontsize=14, fontweight='bold')\n",
302
+ " plt.tight_layout()\n",
303
+ " plt.show()\nelse:\n",
304
+ " print(\"reward_curve.png not found\")"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "markdown",
309
+ "metadata": {},
310
+ "source": [
311
+ "## Step 10: Download Outputs (Colab)"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "import os\n",
321
+ "\n",
322
+ "try:\n",
323
+ " from google.colab import files\n",
324
+ " \n",
325
+ " if os.path.exists(\"reward_curve.png\"):\n",
326
+ " print(\"Downloading reward_curve.png...\")\n",
327
+ " files.download(\"reward_curve.png\")\n",
328
+ " print(\"βœ“ Download started\")\nexcept ImportError:\n",
329
+ " print(\"[INFO] Not in Colab. Files saved locally:\")\n",
330
+ " !ls -lh reward_curve.png logtriage-trained/ 2>/dev/null || echo \"Check ./logtriage-trained/\""
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "metadata": {},
336
+ "source": "## Summary\n\n### What You Just Did\n1. βœ“ Auto-selected best model for your GPU\n2. βœ“ Trained on 3 incident types (90 episodes total)\n3. βœ“ Generated reward curves\n4. βœ“ Produced trained agent ready for deployment\n\n### Outputs\n- `./logtriage-trained/` - Trained model\n- `reward_curve.png` - Learning curves\n- `./phase2_checkpoints/` - Episode data\n\n### Next Steps\n1. **Push to Hub:** `huggingface-cli login` then uncomment `--push_to_hub`\n2. **Use Locally:** Load from `./logtriage-trained/`\n3. **Deploy:** Integrate into on-call system\n\n### Resources\n- Environment: https://huggingface.co/spaces/OGrohit/logtriage-env\n- GitHub: https://github.com/rohitdecodes/logtriage-env\n- Blog: https://github.com/rohitdecodes/logtriage-env/blob/main/BLOG_POST.md"
337
+ }
338
+ ],
339
+ "metadata": {
340
+ "kernelspec": {
341
+ "display_name": "Python 3",
342
+ "language": "python",
343
+ "name": "python3"
344
+ },
345
+ "language_info": {
346
+ "name": "python",
347
+ "version": "3.10.0"
348
+ }
349
+ },
350
+ "nbformat": 4,
351
+ "nbformat_minor": 5
352
+ }