anugrahhu committed on
Commit 30adf48 · verified · 1 Parent(s): 8f805e2

fix: switch trainer Space to vanilla GRPO path

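The Space now reads a `TRAINING_BACKEND` environment variable (default `vanilla`) and dispatches the launcher, the evaluation flags, and the Hub push on it: the vanilla path runs `training.training_script` on `HuggingFaceTB/SmolLM2-360M-Instruct` with smaller episode and eval budgets and uploads the full model directory, while `TRAINING_BACKEND=unsloth` keeps the previous `training.training_unsloth` + LoRA-adapter flow. A condensed sketch of the new dispatch (flag list trimmed for brevity; see `_build_training_cmd` in the diff below for the full argument set):

```python
# Condensed from the new _build_training_cmd; only a subset of CLI flags shown.
import sys
from pathlib import Path


def build_training_cmd(config: dict) -> list[str]:
    backend = str(config.get("training_backend", "vanilla")).lower()
    if backend == "vanilla":
        # Vanilla GRPO: plain single-process python, no Unsloth/accelerate.
        python_bin = ("/usr/local/bin/python"
                      if Path("/usr/local/bin/python").exists() else sys.executable)
        return [python_bin, "-m", "training.training_script",
                "--model_name", config["model_name"],
                "--output_dir", config["output_dir"]]
    if backend != "unsloth":
        raise ValueError(f"unknown TRAINING_BACKEND={backend!r}")
    # Unsloth GRPO: wrap in `accelerate launch` when more than one GPU is requested.
    base = ["-m", "training.training_unsloth",
            "--model_name", config["model_name"],
            "--output_dir", config["output_dir"]]
    if int(config.get("num_gpus", 1)) > 1:
        return ["accelerate", "launch", "--num_processes", str(config["num_gpus"]),
                "--mixed_precision", "bf16"] + base
    return [sys.executable] + base
```

On the vanilla path the evaluation commands additionally get `--no_unsloth` appended, and the post-training push switches from `scripts.push_to_hub` (adapters) to `huggingface_hub.HfApi.upload_folder` over the whole output directory.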
Files changed (1)
  1. space/training/app.py +747 -673
space/training/app.py CHANGED
@@ -1,673 +1,747 @@
1
- """FastAPI control panel for the CERNenv trainer Space.
2
-
3
- Endpoints:
4
- GET / → status page (HTML)
5
- GET /status → JSON status of the current training run
6
- GET /metrics → JSON snapshot of reward / success rate
7
- GET /logs → tail of the training log
8
- POST /train → start (or restart) a training run
9
- GET /health → liveness probe
10
-
11
- Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training
12
- work runs in a background thread so the HTTP server stays responsive.
13
- """
14
-
15
- from __future__ import annotations
16
-
17
- import json
18
- import logging
19
- import os
20
- import subprocess
21
- import sys
22
- import threading
23
- import time
24
- from datetime import datetime, timezone
25
- from pathlib import Path
26
- from typing import Any, Dict, Optional
27
-
28
- from fastapi import FastAPI, HTTPException
29
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse
30
- from fastapi.staticfiles import StaticFiles
31
-
32
-
33
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
34
- logger = logging.getLogger(__name__)
35
-
36
-
37
- def _resolve_repo_root() -> Path:
38
- env_root = os.environ.get("CERNENV_ROOT")
39
- candidates = []
40
- if env_root:
41
- candidates.append(Path(env_root))
42
- candidates.extend([
43
- Path("/home/user/app"),
44
- Path(__file__).resolve().parent.parent.parent,
45
- ])
46
- for p in candidates:
47
- try:
48
- if p.exists():
49
- return p.resolve()
50
- except OSError:
51
- continue
52
- return candidates[-1].resolve()
53
-
54
-
55
- REPO_ROOT = _resolve_repo_root()
56
- LOG_DIR = REPO_ROOT / "training" / "runs"
57
- try:
58
- LOG_DIR.mkdir(parents=True, exist_ok=True)
59
- except OSError as exc: # pragma: no cover - read-only filesystem fallback
60
- logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc)
61
- LOG_DIR = Path("/tmp/cernenv-runs")
62
- LOG_DIR.mkdir(parents=True, exist_ok=True)
63
- LOG_FILE = LOG_DIR / "training.log"
64
- EVIDENCE_DIR = REPO_ROOT / "evidence"
65
- try:
66
- EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)
67
- except OSError: # pragma: no cover
68
- EVIDENCE_DIR = Path("/tmp/cernenv-evidence")
69
- EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)
70
- METRICS_FILE = EVIDENCE_DIR / "before_after_metrics.json"
71
-
72
-
73
- def _env(name: str, default: str) -> str:
74
- return os.environ.get(name, default)
75
-
76
-
77
- def _detect_gpus() -> int:
78
- try:
79
- import torch # type: ignore
80
- if torch.cuda.is_available():
81
- return torch.cuda.device_count()
82
- except Exception:
83
- pass
84
- try:
85
- out = subprocess.run(
86
- ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
87
- capture_output=True, text=True, timeout=5,
88
- )
89
- return len([l for l in out.stdout.splitlines() if l.strip()])
90
- except Exception:
91
- return 0
92
-
93
-
94
- _NUM_GPUS = _detect_gpus()
95
-
96
-
97
- CONFIG = {
98
- "model_name": _env("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct"),
99
- "difficulty": _env("DIFFICULTY", "easy"),
100
- "curriculum": _env("CURRICULUM", "1") == "1",
101
- "curriculum_promote": float(_env("CURRICULUM_PROMOTE", "0.55")),
102
- "curriculum_demote": float(_env("CURRICULUM_DEMOTE", "0.10")),
103
- "total_episodes": int(_env("TOTAL_EPISODES", "1500")),
104
- "max_steps": int(_env("MAX_STEPS", "18")),
105
- "num_generations": int(_env("NUM_GENERATIONS", "8")),
106
- "checkpoint_eval_steps": int(_env("CHECKPOINT_EVAL_STEPS", "25")),
107
- "checkpoint_eval_episodes": int(_env("CHECKPOINT_EVAL_EPISODES", "8")),
108
- "eval_episodes": int(_env("EVAL_EPISODES", "32")),
109
- "output_dir": _env("OUTPUT_DIR", "runs/unsloth-grpo"),
110
- "evidence_dir": _env("EVIDENCE_DIR", "evidence"),
111
- "num_gpus": int(_env("NUM_GPUS", str(_NUM_GPUS or 1))),
112
- "hf_username": _env("HF_USERNAME", "anugrah55"),
113
- "push_repo": _env(
114
- "PUSH_REPO",
115
- f"{_env('HF_USERNAME', 'anugrah55')}/cernenv-grpo-qwen2.5-3b",
116
- ),
117
- "autostart": _env("AUTOSTART", "0") == "1",
118
- }
119
-
120
-
121
- # ── Run state ────────────────────────────────────────────────────────────
122
-
123
-
124
- class RunState:
125
- def __init__(self) -> None:
126
- self.lock = threading.Lock()
127
- self.thread: Optional[threading.Thread] = None
128
- self.process: Optional[subprocess.Popen] = None
129
- self.status: str = "idle" # idle | running | finished | failed
130
- self.started_at: Optional[str] = None
131
- self.finished_at: Optional[str] = None
132
- self.last_error: Optional[str] = None
133
- self.last_config: Dict[str, Any] = {}
134
-
135
- def to_dict(self) -> Dict[str, Any]:
136
- with self.lock:
137
- return {
138
- "status": self.status,
139
- "started_at": self.started_at,
140
- "finished_at": self.finished_at,
141
- "last_error": self.last_error,
142
- "last_config": self.last_config,
143
- }
144
-
145
-
146
- STATE = RunState()
147
-
148
-
149
- # ── Training pipeline ────────────────────────────────────────────────────
150
-
151
-
152
- def _stream_subprocess(cmd: list[str], log_handle) -> int:
153
- log_handle.write(f"\n$ {' '.join(cmd)}\n")
154
- log_handle.flush()
155
- proc = subprocess.Popen(
156
- cmd,
157
- cwd=str(REPO_ROOT),
158
- stdout=subprocess.PIPE,
159
- stderr=subprocess.STDOUT,
160
- bufsize=1,
161
- universal_newlines=True,
162
- env={**os.environ, "PYTHONPATH": str(REPO_ROOT)},
163
- )
164
- STATE.process = proc
165
- assert proc.stdout is not None
166
- for line in proc.stdout:
167
- log_handle.write(line)
168
- log_handle.flush()
169
- rc = proc.wait()
170
- log_handle.write(f"[exit code {rc}]\n")
171
- log_handle.flush()
172
- STATE.process = None
173
- return rc
174
-
175
-
176
- def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
177
- """Compose the training launcher (single-GPU python or multi-GPU accelerate)."""
178
- base = [
179
- "-m", "training.training_unsloth",
180
- "--model_name", config["model_name"],
181
- "--difficulty", config["difficulty"],
182
- "--total_episodes", str(config["total_episodes"]),
183
- "--max_steps", str(config["max_steps"]),
184
- "--num_generations", str(config["num_generations"]),
185
- "--checkpoint_eval_steps", str(config["checkpoint_eval_steps"]),
186
- "--checkpoint_eval_episodes", str(config["checkpoint_eval_episodes"]),
187
- "--output_dir", config["output_dir"],
188
- "--evidence_dir", config["evidence_dir"],
189
- ]
190
- if config.get("curriculum"):
191
- base.extend([
192
- "--curriculum",
193
- "--curriculum_promote", str(config["curriculum_promote"]),
194
- "--curriculum_demote", str(config["curriculum_demote"]),
195
- ])
196
- n = max(int(config.get("num_gpus", 1)), 1)
197
- if n > 1:
198
- return ["accelerate", "launch", "--num_processes", str(n), "--mixed_precision", "bf16"] + base
199
- return [sys.executable] + base
200
-
201
-
202
- def _push_evidence_to_hub(*, evidence_dir: Path, repo_id: str, log) -> None:
203
- """Upload the entire evidence/ directory to the model repo."""
204
- token = os.environ.get("HF_TOKEN")
205
- if not token:
206
- log.write("\n[skip] HF_TOKEN not set — evidence not pushed\n")
207
- log.flush()
208
- return
209
- try:
210
- from huggingface_hub import HfApi
211
- api = HfApi(token=token)
212
- api.upload_folder(
213
- folder_path=str(evidence_dir),
214
- repo_id=repo_id,
215
- repo_type="model",
216
- path_in_repo="evidence",
217
- commit_message="Upload CERNenv training evidence (curves, evals, plots)",
218
- )
219
- log.write(f"\n[ok] uploaded evidence/ → https://huggingface.co/{repo_id}/tree/main/evidence\n")
220
- log.flush()
221
- except Exception as exc:
222
- log.write(f"\n[warn] evidence push failed: {exc}\n")
223
- log.flush()
224
-
225
-
226
- def _training_pipeline(config: Dict[str, Any]) -> None:
227
- started = datetime.now(timezone.utc).isoformat()
228
- with STATE.lock:
229
- STATE.status = "running"
230
- STATE.started_at = started
231
- STATE.finished_at = None
232
- STATE.last_error = None
233
- STATE.last_config = dict(config)
234
-
235
- evidence_dir = Path(config["evidence_dir"]).resolve()
236
- evidence_dir.mkdir(parents=True, exist_ok=True)
237
-
238
- LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
239
- with open(LOG_FILE, "a") as log:
240
- log.write(f"\n=== Training started {started} ===\n")
241
- log.write(json.dumps(config, indent=2) + "\n")
242
- log.flush()
243
- try:
244
- output_dir = config["output_dir"]
245
- difficulty = config["difficulty"]
246
- max_steps = str(config["max_steps"])
247
- eval_episodes = str(config["eval_episodes"])
248
- model_name = config["model_name"]
249
- push_repo = config["push_repo"]
250
- evidence_str = config["evidence_dir"]
251
- pre_jsonl = f"{evidence_str}/pre_eval.jsonl"
252
- post_jsonl = f"{evidence_str}/post_eval.jsonl"
253
-
254
- log.write("\n--- baseline sanity check (random / heuristic / oracle) ---\n")
255
- log.flush()
256
- for agent in ("random", "heuristic", "oracle"):
257
- _stream_subprocess(
258
- [
259
- sys.executable, "-m", "scripts.run_agent",
260
- "--agent", agent, "--difficulty", difficulty,
261
- "--episodes", "3", "--quiet",
262
- ],
263
- log,
264
- )
265
-
266
- log.write(f"\n--- pre-train evaluation ({eval_episodes} eps) ---\n")
267
- log.flush()
268
- rc = _stream_subprocess(
269
- [
270
- sys.executable, "-m", "training.evaluate",
271
- "--model_name", model_name,
272
- "--difficulty", difficulty,
273
- "--episodes", eval_episodes,
274
- "--max_steps", max_steps,
275
- "--tag", "pre_train",
276
- "--out", pre_jsonl,
277
- ],
278
- log,
279
- )
280
- if rc != 0:
281
- # don't abort — we still want training + post-eval evidence.
282
- log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
283
- log.flush()
284
-
285
- log.write(f"\n--- GRPO training ({config['num_gpus']} GPU process(es)) ---\n")
286
- log.flush()
287
- rc = _stream_subprocess(_build_training_cmd(config), log)
288
- if rc != 0:
289
- raise RuntimeError(f"training failed (rc={rc})")
290
-
291
- # ── LoRA save-and-reload smoke test ─────────────────────
292
- # Hackathon FAQ Q9: "Do not upcast a 4-bit model to 16-bit
293
- # and then merge the LoRA weights naively" — the canonical
294
- # cause of a broken push. Before we burn time on the full
295
- # post-train evaluation (32 eps), do a 2-episode cold-load
296
- # rollout against the saved adapters. If that fails, abort
297
- # immediately so we surface a save problem, not a 30-min
298
- # eval timeout.
299
- log.write(
300
- f"\n--- adapter save/reload smoke test "
301
- f"(loading {output_dir} cold-start, 2 eps) ---\n"
302
- )
303
- log.flush()
304
- rc = _stream_subprocess(
305
- [
306
- sys.executable, "-m", "training.evaluate",
307
- "--model_name", model_name,
308
- "--adapter_dir", output_dir,
309
- "--difficulty", difficulty,
310
- "--episodes", "2",
311
- "--max_steps", max_steps,
312
- "--tag", "smoke",
313
- "--out", f"{evidence_str}/smoke_eval.jsonl",
314
- ],
315
- log,
316
- )
317
- if rc != 0:
318
- raise RuntimeError(
319
- f"adapter smoke test failed (rc={rc}); refusing to push "
320
- f"unloadable adapters to the Hub. Inspect {output_dir} and "
321
- "verify adapter_config.json + adapter_model.safetensors exist."
322
- )
323
-
324
- log.write(f"\n--- post-train evaluation ({eval_episodes} eps) ---\n")
325
- log.flush()
326
- rc = _stream_subprocess(
327
- [
328
- sys.executable, "-m", "training.evaluate",
329
- "--model_name", model_name,
330
- "--adapter_dir", output_dir,
331
- "--difficulty", difficulty,
332
- "--episodes", eval_episodes,
333
- "--max_steps", max_steps,
334
- "--tag", "post_train",
335
- "--out", post_jsonl,
336
- ],
337
- log,
338
- )
339
- if rc != 0:
340
- log.write(f"\n[warn] post-train eval failed (rc={rc}); evidence will be partial\n")
341
- log.flush()
342
-
343
- log.write("\n--- evidence: before/after summary, distribution, trajectories ---\n")
344
- log.flush()
345
- try:
346
- from training.evidence import (
347
- EvidencePaths,
348
- render_before_after,
349
- render_sample_trajectories,
350
- render_training_curve,
351
- render_reward_components,
352
- render_checkpoint_progression,
353
- )
354
- paths = EvidencePaths(root=Path(evidence_str))
355
- paths.ensure()
356
- metrics = render_before_after(
357
- pre_jsonl=Path(pre_jsonl),
358
- post_jsonl=Path(post_jsonl),
359
- summary_png=paths.before_after_summary_png,
360
- distribution_png=paths.reward_distribution_png,
361
- metrics_json=paths.before_after_metrics_json,
362
- )
363
- render_sample_trajectories(
364
- pre_jsonl=Path(pre_jsonl),
365
- post_jsonl=Path(post_jsonl),
366
- md_path=paths.sample_trajectories_md,
367
- )
368
- render_training_curve(paths.training_log_csv, paths.training_curve_png)
369
- render_reward_components(
370
- paths.reward_components_csv, paths.reward_components_png,
371
- )
372
- render_checkpoint_progression(
373
- paths.checkpoint_evals_csv, paths.checkpoint_progression_png,
374
- )
375
- log.write(json.dumps(metrics, indent=2) + "\n")
376
- log.flush()
377
- except Exception as exc:
378
- log.write(f"[warn] evidence rendering failed: {exc}\n")
379
- log.flush()
380
-
381
- if os.environ.get("HF_TOKEN"):
382
- log.write("\n--- push adapters to Hub ---\n")
383
- log.flush()
384
- _stream_subprocess(
385
- [
386
- sys.executable, "-m", "scripts.push_to_hub", "model",
387
- "--adapter_dir", output_dir,
388
- "--repo_id", push_repo,
389
- "--base_model", model_name,
390
- ],
391
- log,
392
- )
393
- _push_evidence_to_hub(
394
- evidence_dir=evidence_dir,
395
- repo_id=push_repo,
396
- log=log,
397
- )
398
- else:
399
- log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n")
400
- log.flush()
401
-
402
- with STATE.lock:
403
- STATE.status = "finished"
404
- except Exception as exc:
405
- logger.exception("training pipeline failed")
406
- with STATE.lock:
407
- STATE.status = "failed"
408
- STATE.last_error = str(exc)
409
- finally:
410
- finished = datetime.now(timezone.utc).isoformat()
411
- log.write(f"\n=== Training ended {finished} ===\n")
412
- log.flush()
413
- with STATE.lock:
414
- STATE.finished_at = finished
415
-
416
-
417
- def _start_training(config: Dict[str, Any]) -> None:
418
- with STATE.lock:
419
- if STATE.status == "running":
420
- raise RuntimeError("a training run is already in progress")
421
- STATE.thread = threading.Thread(
422
- target=_training_pipeline,
423
- args=(config,),
424
- name="cernenv-trainer",
425
- daemon=True,
426
- )
427
- STATE.thread.start()
428
-
429
-
430
- # ── FastAPI app ──────────────────────────────────────────────────────────
431
-
432
-
433
- app = FastAPI(title="CERNenv Trainer", version="0.1.0")
434
-
435
-
436
- _HTML = """\
437
- <!doctype html>
438
- <html lang=en>
439
- <head>
440
- <meta charset=utf-8>
441
- <title>CERNenv Trainer</title>
442
- <style>
443
- body { font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto;
444
- max-width: 1000px; color:#111; padding: 0 1rem; line-height:1.5 }
445
- h1 { margin-bottom: 0 }
446
- h2 { margin-top: 2rem; border-bottom:1px solid #eee; padding-bottom:.25rem }
447
- .muted { color:#666 }
448
- pre { background:#0e1116; color:#e6edf3; padding:1rem; border-radius:6px;
449
- overflow-x:auto; max-height:40vh; font-size:.85em }
450
- button { font-size:1rem; padding:.6rem 1rem; border-radius:6px; border:1px solid #888;
451
- background:#fff; cursor:pointer; margin-right:.4rem }
452
- .pill { display:inline-block; padding:.1rem .55rem; border-radius:999px;
453
- background:#eef; color:#225; font-size:.85em }
454
- .ok { background:#dfd; color:#272 }
455
- .fail { background:#fdd; color:#822 }
456
- .run { background:#fdf6d8; color:#774 }
457
- table { border-collapse:collapse; margin:.5rem 0 }
458
- td, th { padding:.25rem .8rem .25rem 0; vertical-align: top; text-align:left }
459
- th { color:#444; font-weight:600 }
460
- .grid { display:grid; grid-template-columns:1fr 1fr; gap:1rem }
461
- .card { border:1px solid #e5e7eb; border-radius:8px; padding:.75rem; background:#fafafa }
462
- .card img { max-width:100%; border-radius:4px }
463
- .delta-pos { color:#15803d; font-weight:600 }
464
- .delta-neg { color:#b91c1c; font-weight:600 }
465
- code { background:#f4f4f4; padding:.05rem .35rem; border-radius:4px }
466
- a { color:#1d4ed8 }
467
- </style>
468
- </head>
469
- <body>
470
- <h1>⚛️ CERNenv Trainer</h1>
471
- <p class=muted>GRPO + Unsloth + LoRA on the CERNenv LHC discovery environment. Multi-GPU on Hugging Face Spaces.</p>
472
-
473
- <h2>Run status</h2>
474
- <p>Status: <span id=status class=pill>?</span></p>
475
- <table id=meta></table>
476
- <p>
477
- <button onclick="startRun()">▶ Start training</button>
478
- <button onclick="refresh()">↻ Refresh</button>
479
- <a href="/evidence" target=_blank><button>📁 Evidence index</button></a>
480
- <a href="/docs" target=_blank><button>🛠 API</button></a>
481
- </p>
482
-
483
- <h2>Training-progress evidence</h2>
484
- <p class=muted>Auto-updated as training runs. All artifacts are also saved to <code>evidence/</code> and pushed to the model repo on the Hub.</p>
485
- <div class=grid>
486
- <div class=card><b>Per-step training curve</b><br>
487
- <img id=curve src="/evidence/training_curve.png" onerror="this.style.display='none'">
488
- <div id=curve_missing class=muted style="display:none">(not yet — waiting for first GRPO step)</div>
489
- </div>
490
- <div class=card><b>Reward components (terminal vs shaping)</b><br>
491
- <img id=components src="/evidence/reward_components.png" onerror="this.style.display='none'">
492
- <div id=components_missing class=muted style="display:none">(populated after a few rollouts — watches verifier hacks)</div>
493
- </div>
494
- <div class=card><b>Mid-training checkpoint progression</b><br>
495
- <img id=ckpt src="/evidence/checkpoint_progression.png" onerror="this.style.display='none'">
496
- <div id=ckpt_missing class=muted style="display:none">(not yet — waiting for first checkpoint eval)</div>
497
- </div>
498
- <div class=card><b>Before vs after summary</b><br>
499
- <img id=summary src="/evidence/before_after_summary.png" onerror="this.style.display='none'">
500
- <div id=summary_missing class=muted style="display:none">(generated after post-train eval)</div>
501
- </div>
502
- <div class=card><b>Reward distribution: pre vs post</b><br>
503
- <img id=dist src="/evidence/reward_distribution.png" onerror="this.style.display='none'">
504
- <div id=dist_missing class=muted style="display:none">(generated after post-train eval)</div>
505
- </div>
506
- </div>
507
-
508
- <h2>Before / after metrics</h2>
509
- <table id=metrics_table>
510
- <tr><th>metric</th><th>pre</th><th>post</th><th>Δ</th></tr>
511
- </table>
512
-
513
- <h2>Live logs (tail)</h2>
514
- <pre id=logs>loading…</pre>
515
-
516
- <script>
517
- function fmt(v) {
518
- if (v == null) return '–';
519
- if (typeof v === 'number') return v.toFixed(3);
520
- return v;
521
- }
522
- function fmtDelta(d) {
523
- if (d == null || isNaN(d)) return '–';
524
- const sign = d >= 0 ? '+' : '';
525
- const cls = d >= 0 ? 'delta-pos' : 'delta-neg';
526
- return `<span class="${cls}">${sign}${d.toFixed(3)}</span>`;
527
- }
528
-
529
- async function refresh() {
530
- // status
531
- const s = await fetch('/status').then(r => r.json());
532
- const pill = document.getElementById('status');
533
- pill.textContent = s.status;
534
- pill.className = 'pill ' + ({idle:'',running:'run',finished:'ok',failed:'fail'}[s.status] || '');
535
-
536
- const meta = document.getElementById('meta');
537
- meta.innerHTML = '';
538
- const obj = {
539
- started_at: s.started_at, finished_at: s.finished_at, error: s.last_error,
540
- ...(s.last_config || {}),
541
- };
542
- for (const [k, v] of Object.entries(obj)) {
543
- if (v == null || v === '') continue;
544
- const tr = document.createElement('tr');
545
- tr.innerHTML = `<td><b>${k}</b></td><td><code>${v}</code></td>`;
546
- meta.appendChild(tr);
547
- }
548
-
549
- // metrics
550
- const m = await fetch('/metrics').then(r => r.json()).catch(() => ({pre:null, post:null}));
551
- const tbody = document.getElementById('metrics_table');
552
- tbody.innerHTML = '<tr><th>metric</th><th>pre</th><th>post</th><th>Δ</th></tr>';
553
- const fields = ['mean_reward', 'success_rate', 'mass_acc', 'channel_acc', 'median_reward'];
554
- for (const f of fields) {
555
- const pre = m.pre && m.pre[f];
556
- const post = m.post && m.post[f];
557
- const delta = m.delta && m.delta[f];
558
- const tr = document.createElement('tr');
559
- tr.innerHTML = `<td><code>${f}</code></td><td>${fmt(pre)}</td><td>${fmt(post)}</td><td>${fmtDelta(delta)}</td>`;
560
- tbody.appendChild(tr);
561
- }
562
-
563
- // bust caches on plots
564
- const bust = '?t=' + Date.now();
565
- for (const [imgId, missingId] of [
566
- ['curve', 'curve_missing'],
567
- ['components', 'components_missing'],
568
- ['ckpt', 'ckpt_missing'],
569
- ['summary', 'summary_missing'],
570
- ['dist', 'dist_missing'],
571
- ]) {
572
- const img = document.getElementById(imgId);
573
- const miss = document.getElementById(missingId);
574
- const baseSrc = img.getAttribute('src').split('?')[0];
575
- const probe = new Image();
576
- probe.onload = () => { img.src = baseSrc + bust; img.style.display=''; miss.style.display='none'; };
577
- probe.onerror = () => { img.style.display='none'; miss.style.display=''; };
578
- probe.src = baseSrc + bust;
579
- }
580
-
581
- const logs = await fetch('/logs?tail=200').then(r => r.text());
582
- document.getElementById('logs').textContent = logs || '(no logs yet)';
583
- }
584
- async function startRun() {
585
- const r = await fetch('/train', {method:'POST'});
586
- if (!r.ok) alert((await r.json()).detail || 'failed');
587
- setTimeout(refresh, 500);
588
- }
589
- refresh();
590
- setInterval(refresh, 5000);
591
- </script>
592
- </body>
593
- </html>
594
- """
595
-
596
-
597
- @app.get("/", response_class=HTMLResponse)
598
- def index() -> HTMLResponse:
599
- return HTMLResponse(_HTML)
600
-
601
-
602
- @app.get("/health")
603
- def health() -> Dict[str, str]:
604
- return {"status": "ok"}
605
-
606
-
607
- @app.get("/status")
608
- def status() -> JSONResponse:
609
- return JSONResponse(STATE.to_dict())
610
-
611
-
612
- @app.get("/metrics")
613
- def metrics() -> JSONResponse:
614
- if METRICS_FILE.exists():
615
- try:
616
- return JSONResponse(json.loads(METRICS_FILE.read_text()))
617
- except Exception:
618
- return JSONResponse({"error": "metrics file unreadable"}, status_code=500)
619
- return JSONResponse({"pre": None, "post": None, "delta": None})
620
-
621
-
622
- @app.get("/evidence")
623
- def evidence_index() -> JSONResponse:
624
- """List every evidence artifact currently on disk."""
625
- files = []
626
- if EVIDENCE_DIR.exists():
627
- for p in sorted(EVIDENCE_DIR.iterdir()):
628
- if p.is_file():
629
- files.append({
630
- "name": p.name,
631
- "size": p.stat().st_size,
632
- "url": f"/evidence/{p.name}",
633
- })
634
- return JSONResponse({"dir": str(EVIDENCE_DIR), "files": files})
635
-
636
-
637
- @app.get("/evidence/{name}")
638
- def evidence_file(name: str):
639
- """Serve a single evidence artifact (PNG/CSV/JSON/MD) by filename."""
640
- if "/" in name or ".." in name:
641
- raise HTTPException(status_code=400, detail="invalid name")
642
- target = EVIDENCE_DIR / name
643
- if not target.exists() or not target.is_file():
644
- raise HTTPException(status_code=404, detail=f"{name} not found")
645
- return FileResponse(target)
646
-
647
-
648
- @app.get("/logs", response_class=PlainTextResponse)
649
- def logs(tail: int = 400) -> PlainTextResponse:
650
- if not LOG_FILE.exists():
651
- return PlainTextResponse("")
652
- text = LOG_FILE.read_text()
653
- lines = text.splitlines()
654
- return PlainTextResponse("\n".join(lines[-max(tail, 1):]))
655
-
656
-
657
- @app.post("/train")
658
- def train() -> JSONResponse:
659
- try:
660
- _start_training(dict(CONFIG))
661
- except RuntimeError as exc:
662
- raise HTTPException(status_code=409, detail=str(exc))
663
- return JSONResponse({"status": "started", "config": CONFIG})
664
-
665
-
666
- @app.on_event("startup")
667
- def _maybe_autostart() -> None:
668
- if CONFIG["autostart"]:
669
- try:
670
- _start_training(dict(CONFIG))
671
- logger.info("autostarted training run")
672
- except RuntimeError as exc:
673
- logger.warning("autostart skipped: %s", exc)
1
+ """FastAPI control panel for the CERNenv trainer Space.
2
+
3
+ Endpoints:
4
+ GET / → status page (HTML)
5
+ GET /status → JSON status of the current training run
6
+ GET /metrics → JSON snapshot of reward / success rate
7
+ GET /logs → tail of the training log
8
+ POST /train → start (or restart) a training run
9
+ GET /health → liveness probe
10
+
11
+ Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training
12
+ work runs in a background thread so the HTTP server stays responsive.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ import os
20
+ import subprocess
21
+ import sys
22
+ import threading
23
+ import time
24
+ from datetime import datetime, timezone
25
+ from pathlib import Path
26
+ from typing import Any, Dict, Optional
27
+
28
+ from fastapi import FastAPI, HTTPException
29
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse
30
+ from fastapi.staticfiles import StaticFiles
31
+
32
+
33
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ def _resolve_repo_root() -> Path:
38
+ env_root = os.environ.get("CERNENV_ROOT")
39
+ candidates = []
40
+ if env_root:
41
+ candidates.append(Path(env_root))
42
+ candidates.extend([
43
+ Path("/home/user/app"),
44
+ Path(__file__).resolve().parent.parent.parent,
45
+ ])
46
+ for p in candidates:
47
+ try:
48
+ if p.exists():
49
+ return p.resolve()
50
+ except OSError:
51
+ continue
52
+ return candidates[-1].resolve()
53
+
54
+
55
+ REPO_ROOT = _resolve_repo_root()
56
+ LOG_DIR = REPO_ROOT / "training" / "runs"
57
+ try:
58
+ LOG_DIR.mkdir(parents=True, exist_ok=True)
59
+ except OSError as exc: # pragma: no cover - read-only filesystem fallback
60
+ logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc)
61
+ LOG_DIR = Path("/tmp/cernenv-runs")
62
+ LOG_DIR.mkdir(parents=True, exist_ok=True)
63
+ LOG_FILE = LOG_DIR / "training.log"
64
+ EVIDENCE_DIR = REPO_ROOT / "evidence"
65
+ try:
66
+ EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)
67
+ except OSError: # pragma: no cover
68
+ EVIDENCE_DIR = Path("/tmp/cernenv-evidence")
69
+ EVIDENCE_DIR.mkdir(parents=True, exist_ok=True)
70
+ METRICS_FILE = EVIDENCE_DIR / "before_after_metrics.json"
71
+
72
+
73
+ def _env(name: str, default: str) -> str:
74
+ return os.environ.get(name, default)
75
+
76
+
77
+ def _detect_gpus() -> int:
78
+ try:
79
+ import torch # type: ignore
80
+ if torch.cuda.is_available():
81
+ return torch.cuda.device_count()
82
+ except Exception:
83
+ pass
84
+ try:
85
+ out = subprocess.run(
86
+ ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
87
+ capture_output=True, text=True, timeout=5,
88
+ )
89
+ return len([l for l in out.stdout.splitlines() if l.strip()])
90
+ except Exception:
91
+ return 0
92
+
93
+
94
+ _NUM_GPUS = _detect_gpus()
95
+
96
+
97
+ CONFIG = {
98
+ "training_backend": _env("TRAINING_BACKEND", "vanilla"),
99
+ "model_name": _env("MODEL_NAME", "HuggingFaceTB/SmolLM2-360M-Instruct"),
100
+ "difficulty": _env("DIFFICULTY", "easy"),
101
+ "curriculum": _env("CURRICULUM", "0") == "1",
102
+ "curriculum_promote": float(_env("CURRICULUM_PROMOTE", "0.55")),
103
+ "curriculum_demote": float(_env("CURRICULUM_DEMOTE", "0.10")),
104
+ "total_episodes": int(_env("TOTAL_EPISODES", "120")),
105
+ "max_steps": int(_env("MAX_STEPS", "12")),
106
+ "num_generations": int(_env("NUM_GENERATIONS", "4")),
107
+ "checkpoint_eval_steps": int(_env("CHECKPOINT_EVAL_STEPS", "25")),
108
+ "checkpoint_eval_episodes": int(_env("CHECKPOINT_EVAL_EPISODES", "8")),
109
+ "eval_episodes": int(_env("EVAL_EPISODES", "8")),
110
+ "output_dir": _env("OUTPUT_DIR", "runs/vanilla-grpo"),
111
+ "evidence_dir": _env("EVIDENCE_DIR", "evidence"),
112
+ "num_gpus": int(_env("NUM_GPUS", "1")),
113
+ "hf_username": _env("HF_USERNAME", "anugrahhu"),
114
+ "push_repo": _env(
115
+ "PUSH_REPO",
116
+ f"{_env('HF_USERNAME', 'anugrahhu')}/cernenv-grpo-smollm2-360m",
117
+ ),
118
+ "autostart": _env("AUTOSTART", "0") == "1",
119
+ }
120
+
121
+
122
+ # ── Run state ────────────────────────────────────────────────────────────
123
+
124
+
125
+ class RunState:
126
+ def __init__(self) -> None:
127
+ self.lock = threading.Lock()
128
+ self.thread: Optional[threading.Thread] = None
129
+ self.process: Optional[subprocess.Popen] = None
130
+ self.status: str = "idle" # idle | running | finished | failed
131
+ self.started_at: Optional[str] = None
132
+ self.finished_at: Optional[str] = None
133
+ self.last_error: Optional[str] = None
134
+ self.last_config: Dict[str, Any] = {}
135
+
136
+ def to_dict(self) -> Dict[str, Any]:
137
+ with self.lock:
138
+ return {
139
+ "status": self.status,
140
+ "started_at": self.started_at,
141
+ "finished_at": self.finished_at,
142
+ "last_error": self.last_error,
143
+ "last_config": self.last_config,
144
+ }
145
+
146
+
147
+ STATE = RunState()
148
+
149
+
150
+ # ── Training pipeline ────────────────────────────────────────────────────
151
+
152
+
153
+ def _stream_subprocess(cmd: list[str], log_handle) -> int:
154
+ log_handle.write(f"\n$ {' '.join(cmd)}\n")
155
+ log_handle.flush()
156
+ proc = subprocess.Popen(
157
+ cmd,
158
+ cwd=str(REPO_ROOT),
159
+ stdout=subprocess.PIPE,
160
+ stderr=subprocess.STDOUT,
161
+ bufsize=1,
162
+ universal_newlines=True,
163
+ env={**os.environ, "PYTHONPATH": str(REPO_ROOT)},
164
+ )
165
+ STATE.process = proc
166
+ assert proc.stdout is not None
167
+ for line in proc.stdout:
168
+ log_handle.write(line)
169
+ log_handle.flush()
170
+ rc = proc.wait()
171
+ log_handle.write(f"[exit code {rc}]\n")
172
+ log_handle.flush()
173
+ STATE.process = None
174
+ return rc
175
+
176
+
177
+ def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
178
+ """Compose the selected training launcher."""
179
+ backend = str(config.get("training_backend", "vanilla")).lower()
180
+ if backend == "vanilla":
181
+ python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
182
+ return [
183
+ python_bin, "-m", "training.training_script",
184
+ "--model_name", config["model_name"],
185
+ "--difficulty", config["difficulty"],
186
+ "--total_episodes", str(config["total_episodes"]),
187
+ "--max_steps", str(config["max_steps"]),
188
+ "--num_generations", str(config["num_generations"]),
189
+ "--output_dir", config["output_dir"],
190
+ ]
191
+
192
+ if backend != "unsloth":
193
+ raise ValueError(f"unknown TRAINING_BACKEND={backend!r}")
194
+
195
+ base = [
196
+ "-m", "training.training_unsloth",
197
+ "--model_name", config["model_name"],
198
+ "--difficulty", config["difficulty"],
199
+ "--total_episodes", str(config["total_episodes"]),
200
+ "--max_steps", str(config["max_steps"]),
201
+ "--num_generations", str(config["num_generations"]),
202
+ "--checkpoint_eval_steps", str(config["checkpoint_eval_steps"]),
203
+ "--checkpoint_eval_episodes", str(config["checkpoint_eval_episodes"]),
204
+ "--output_dir", config["output_dir"],
205
+ "--evidence_dir", config["evidence_dir"],
206
+ ]
207
+ if config.get("curriculum"):
208
+ base.extend([
209
+ "--curriculum",
210
+ "--curriculum_promote", str(config["curriculum_promote"]),
211
+ "--curriculum_demote", str(config["curriculum_demote"]),
212
+ ])
213
+ n = max(int(config.get("num_gpus", 1)), 1)
214
+ if n > 1:
215
+ return ["accelerate", "launch", "--num_processes", str(n), "--mixed_precision", "bf16"] + base
216
+ return [sys.executable] + base
217
+
218
+
219
+ def _build_eval_cmd(
220
+ *,
221
+ model_name: str,
222
+ difficulty: str,
223
+ episodes: str,
224
+ max_steps: str,
225
+ tag: str,
226
+ out: str,
227
+ backend: str,
228
+ adapter_dir: Optional[str] = None,
229
+ ) -> list[str]:
230
+ cmd = [
231
+ sys.executable, "-m", "training.evaluate",
232
+ "--model_name", model_name,
233
+ "--difficulty", difficulty,
234
+ "--episodes", episodes,
235
+ "--max_steps", max_steps,
236
+ "--tag", tag,
237
+ "--out", out,
238
+ ]
239
+ if adapter_dir:
240
+ cmd.extend(["--adapter_dir", adapter_dir])
241
+ if backend == "vanilla":
242
+ cmd.append("--no_unsloth")
243
+ return cmd
244
+
245
+
246
+ def _push_model_folder_to_hub(*, output_dir: Path, repo_id: str, base_model: str, log) -> None:
247
+ """Upload a vanilla transformers model directory to the Hub."""
248
+ token = os.environ.get("HF_TOKEN")
249
+ if not token:
250
+ log.write("\n[skip] HF_TOKEN not set — model not pushed\n")
251
+ log.flush()
252
+ return
253
+ try:
254
+ from huggingface_hub import HfApi
255
+ api = HfApi(token=token)
256
+ api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
257
+ api.upload_folder(
258
+ folder_path=str(output_dir),
259
+ repo_id=repo_id,
260
+ repo_type="model",
261
+ commit_message=f"Upload vanilla GRPO model based on {base_model}",
262
+ )
263
+ log.write(f"\n[ok] uploaded model → https://huggingface.co/{repo_id}\n")
264
+ log.flush()
265
+ except Exception as exc:
266
+ log.write(f"\n[warn] model push failed: {exc}\n")
267
+ log.flush()
268
+
269
+
270
+ def _push_evidence_to_hub(*, evidence_dir: Path, repo_id: str, log) -> None:
271
+ """Upload the entire evidence/ directory to the model repo."""
272
+ token = os.environ.get("HF_TOKEN")
273
+ if not token:
274
+ log.write("\n[skip] HF_TOKEN not set — evidence not pushed\n")
275
+ log.flush()
276
+ return
277
+ try:
278
+ from huggingface_hub import HfApi
279
+ api = HfApi(token=token)
280
+ api.upload_folder(
281
+ folder_path=str(evidence_dir),
282
+ repo_id=repo_id,
283
+ repo_type="model",
284
+ path_in_repo="evidence",
285
+ commit_message="Upload CERNenv training evidence (curves, evals, plots)",
286
+ )
287
+ log.write(f"\n[ok] uploaded evidence/ → https://huggingface.co/{repo_id}/tree/main/evidence\n")
288
+ log.flush()
289
+ except Exception as exc:
290
+ log.write(f"\n[warn] evidence push failed: {exc}\n")
291
+ log.flush()
292
+
293
+
294
+ def _training_pipeline(config: Dict[str, Any]) -> None:
295
+ started = datetime.now(timezone.utc).isoformat()
296
+ with STATE.lock:
297
+ STATE.status = "running"
298
+ STATE.started_at = started
299
+ STATE.finished_at = None
300
+ STATE.last_error = None
301
+ STATE.last_config = dict(config)
302
+
303
+ evidence_dir = Path(config["evidence_dir"]).resolve()
304
+ evidence_dir.mkdir(parents=True, exist_ok=True)
305
+
306
+ LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
307
+ with open(LOG_FILE, "a") as log:
308
+ log.write(f"\n=== Training started {started} ===\n")
309
+ log.write(json.dumps(config, indent=2) + "\n")
310
+ log.flush()
311
+ try:
312
+ output_dir = config["output_dir"]
313
+ difficulty = config["difficulty"]
314
+ max_steps = str(config["max_steps"])
315
+ eval_episodes = str(config["eval_episodes"])
316
+ model_name = config["model_name"]
317
+ push_repo = config["push_repo"]
318
+ evidence_str = config["evidence_dir"]
319
+ backend = str(config.get("training_backend", "vanilla")).lower()
320
+ pre_jsonl = f"{evidence_str}/pre_eval.jsonl"
321
+ post_jsonl = f"{evidence_str}/post_eval.jsonl"
322
+
323
+ log.write("\n--- baseline sanity check (random / heuristic / oracle) ---\n")
324
+ log.flush()
325
+ for agent in ("random", "heuristic", "oracle"):
326
+ _stream_subprocess(
327
+ [
328
+ sys.executable, "-m", "scripts.run_agent",
329
+ "--agent", agent, "--difficulty", difficulty,
330
+ "--episodes", "3", "--quiet",
331
+ ],
332
+ log,
333
+ )
334
+
335
+ log.write(f"\n--- pre-train evaluation ({eval_episodes} eps) ---\n")
336
+ log.flush()
337
+ rc = _stream_subprocess(
338
+ _build_eval_cmd(
339
+ model_name=model_name,
340
+ difficulty=difficulty,
341
+ episodes=eval_episodes,
342
+ max_steps=max_steps,
343
+ tag="pre_train",
344
+ out=pre_jsonl,
345
+ backend=backend,
346
+ ),
347
+ log,
348
+ )
349
+ if rc != 0:
350
+ # don't abort — we still want training + post-eval evidence.
351
+ log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
352
+ log.flush()
353
+
354
+ log.write(f"\n--- GRPO training ({backend}, {config['num_gpus']} GPU process(es)) ---\n")
355
+ log.flush()
356
+ rc = _stream_subprocess(_build_training_cmd(config), log)
357
+ if rc != 0:
358
+ raise RuntimeError(f"training failed (rc={rc})")
359
+
360
+ # Cold-load the trained artifact before burning time on post-eval.
361
+ log.write(
362
+ f"\n--- trained artifact smoke test "
363
+ f"(loading {output_dir} cold-start, 2 eps) ---\n"
364
+ )
365
+ log.flush()
366
+ smoke_model = output_dir if backend == "vanilla" else model_name
367
+ smoke_adapter = None if backend == "vanilla" else output_dir
368
+ rc = _stream_subprocess(
369
+ _build_eval_cmd(
370
+ model_name=smoke_model,
371
+ adapter_dir=smoke_adapter,
372
+ difficulty=difficulty,
373
+ episodes="2",
374
+ max_steps=max_steps,
375
+ tag="smoke",
376
+ out=f"{evidence_str}/smoke_eval.jsonl",
377
+ backend=backend,
378
+ ),
379
+ log,
380
+ )
381
+ if rc != 0:
382
+ raise RuntimeError(
383
+ f"trained artifact smoke test failed (rc={rc}); refusing to push "
384
+ f"unloadable output to the Hub. Inspect {output_dir}."
385
+ )
386
+
387
+ log.write(f"\n--- post-train evaluation ({eval_episodes} eps) ---\n")
388
+ log.flush()
389
+ post_model = output_dir if backend == "vanilla" else model_name
390
+ post_adapter = None if backend == "vanilla" else output_dir
391
+ rc = _stream_subprocess(
392
+ _build_eval_cmd(
393
+ model_name=post_model,
394
+ adapter_dir=post_adapter,
395
+ difficulty=difficulty,
396
+ episodes=eval_episodes,
397
+ max_steps=max_steps,
398
+ tag="post_train",
399
+ out=post_jsonl,
400
+ backend=backend,
401
+ ),
402
+ log,
403
+ )
404
+ if rc != 0:
405
+ log.write(f"\n[warn] post-train eval failed (rc={rc}); evidence will be partial\n")
406
+ log.flush()
407
+
408
+ log.write("\n--- evidence: before/after summary, distribution, trajectories ---\n")
409
+ log.flush()
410
+ try:
411
+ from training.evidence import (
412
+ EvidencePaths,
413
+ render_before_after,
414
+ render_sample_trajectories,
415
+ render_training_curve,
416
+ render_reward_components,
417
+ render_checkpoint_progression,
418
+ )
419
+ paths = EvidencePaths(root=Path(evidence_str))
420
+ paths.ensure()
421
+ metrics = render_before_after(
422
+ pre_jsonl=Path(pre_jsonl),
423
+ post_jsonl=Path(post_jsonl),
424
+ summary_png=paths.before_after_summary_png,
425
+ distribution_png=paths.reward_distribution_png,
426
+ metrics_json=paths.before_after_metrics_json,
427
+ )
428
+ render_sample_trajectories(
429
+ pre_jsonl=Path(pre_jsonl),
430
+ post_jsonl=Path(post_jsonl),
431
+ md_path=paths.sample_trajectories_md,
432
+ )
433
+ render_training_curve(paths.training_log_csv, paths.training_curve_png)
434
+ render_reward_components(
435
+ paths.reward_components_csv, paths.reward_components_png,
436
+ )
437
+ render_checkpoint_progression(
438
+ paths.checkpoint_evals_csv, paths.checkpoint_progression_png,
439
+ )
440
+ log.write(json.dumps(metrics, indent=2) + "\n")
441
+ log.flush()
442
+ except Exception as exc:
443
+ log.write(f"[warn] evidence rendering failed: {exc}\n")
444
+ log.flush()
445
+
446
+ if os.environ.get("HF_TOKEN"):
447
+ if backend == "vanilla":
448
+ log.write("\n--- push vanilla model to Hub ---\n")
449
+ log.flush()
450
+ _push_model_folder_to_hub(
451
+ output_dir=Path(output_dir),
452
+ repo_id=push_repo,
453
+ base_model=model_name,
454
+ log=log,
455
+ )
456
+ else:
457
+ log.write("\n--- push adapters to Hub ---\n")
458
+ log.flush()
459
+ _stream_subprocess(
460
+ [
461
+ sys.executable, "-m", "scripts.push_to_hub", "model",
462
+ "--adapter_dir", output_dir,
463
+ "--repo_id", push_repo,
464
+ "--base_model", model_name,
465
+ ],
466
+ log,
467
+ )
468
+ _push_evidence_to_hub(
469
+ evidence_dir=evidence_dir,
470
+ repo_id=push_repo,
471
+ log=log,
472
+ )
473
+ else:
474
+ log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n")
475
+ log.flush()
476
+ with STATE.lock:
477
+ STATE.status = "finished"
478
+ except Exception as exc:
479
+ logger.exception("training pipeline failed")
480
+ with STATE.lock:
481
+ STATE.status = "failed"
482
+ STATE.last_error = str(exc)
483
+ finally:
484
+ finished = datetime.now(timezone.utc).isoformat()
485
+ log.write(f"\n=== Training ended {finished} ===\n")
486
+ log.flush()
487
+ with STATE.lock:
488
+ STATE.finished_at = finished
489
+
490
+
491
+ def _start_training(config: Dict[str, Any]) -> None:
492
+ with STATE.lock:
493
+ if STATE.status == "running":
494
+ raise RuntimeError("a training run is already in progress")
495
+ STATE.thread = threading.Thread(
496
+ target=_training_pipeline,
497
+ args=(config,),
498
+ name="cernenv-trainer",
499
+ daemon=True,
500
+ )
501
+ STATE.thread.start()
502
+
503
+
504
+ # ── FastAPI app ──────────────────────────────────────────────────────────
505
+
506
+
507
+ app = FastAPI(title="CERNenv Trainer", version="0.1.0")
508
+
509
+
510
+ _HTML = """\
511
+ <!doctype html>
512
+ <html lang=en>
513
+ <head>
514
+ <meta charset=utf-8>
515
+ <title>CERNenv Trainer</title>
516
+ <style>
517
+ body { font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto;
518
+ max-width: 1000px; color:#111; padding: 0 1rem; line-height:1.5 }
519
+ h1 { margin-bottom: 0 }
520
+ h2 { margin-top: 2rem; border-bottom:1px solid #eee; padding-bottom:.25rem }
521
+ .muted { color:#666 }
522
+ pre { background:#0e1116; color:#e6edf3; padding:1rem; border-radius:6px;
523
+ overflow-x:auto; max-height:40vh; font-size:.85em }
524
+ button { font-size:1rem; padding:.6rem 1rem; border-radius:6px; border:1px solid #888;
525
+ background:#fff; cursor:pointer; margin-right:.4rem }
526
+ .pill { display:inline-block; padding:.1rem .55rem; border-radius:999px;
527
+ background:#eef; color:#225; font-size:.85em }
528
+ .ok { background:#dfd; color:#272 }
529
+ .fail { background:#fdd; color:#822 }
530
+ .run { background:#fdf6d8; color:#774 }
531
+ table { border-collapse:collapse; margin:.5rem 0 }
532
+ td, th { padding:.25rem .8rem .25rem 0; vertical-align: top; text-align:left }
533
+ th { color:#444; font-weight:600 }
534
+ .grid { display:grid; grid-template-columns:1fr 1fr; gap:1rem }
535
+ .card { border:1px solid #e5e7eb; border-radius:8px; padding:.75rem; background:#fafafa }
536
+ .card img { max-width:100%; border-radius:4px }
537
+ .delta-pos { color:#15803d; font-weight:600 }
538
+ .delta-neg { color:#b91c1c; font-weight:600 }
539
+ code { background:#f4f4f4; padding:.05rem .35rem; border-radius:4px }
540
+ a { color:#1d4ed8 }
541
+ </style>
542
+ </head>
543
+ <body>
544
+ <h1>⚛️ CERNenv Trainer</h1>
545
+ <p class=muted>GRPO on the CERNenv LHC discovery environment, with a vanilla (default) or Unsloth + LoRA backend, on Hugging Face Spaces.</p>
546
+
547
+ <h2>Run status</h2>
548
+ <p>Status: <span id=status class=pill>?</span></p>
549
+ <table id=meta></table>
550
+ <p>
551
+ <button onclick="startRun()">▶ Start training</button>
552
+ <button onclick="refresh()">↻ Refresh</button>
553
+ <a href="/evidence" target=_blank><button>📁 Evidence index</button></a>
554
+ <a href="/docs" target=_blank><button>🛠 API</button></a>
555
+ </p>
556
+
557
+ <h2>Training-progress evidence</h2>
558
+ <p class=muted>Auto-updated as training runs. All artifacts are also saved to <code>evidence/</code> and pushed to the model repo on the Hub.</p>
559
+ <div class=grid>
560
+ <div class=card><b>Per-step training curve</b><br>
561
+ <img id=curve src="/evidence/training_curve.png" onerror="this.style.display='none'">
562
+ <div id=curve_missing class=muted style="display:none">(not yet — waiting for first GRPO step)</div>
563
+ </div>
564
+ <div class=card><b>Reward components (terminal vs shaping)</b><br>
565
+ <img id=components src="/evidence/reward_components.png" onerror="this.style.display='none'">
566
+ <div id=components_missing class=muted style="display:none">(populated after a few rollouts — watches verifier hacks)</div>
567
+ </div>
568
+ <div class=card><b>Mid-training checkpoint progression</b><br>
569
+ <img id=ckpt src="/evidence/checkpoint_progression.png" onerror="this.style.display='none'">
570
+ <div id=ckpt_missing class=muted style="display:none">(not yet — waiting for first checkpoint eval)</div>
571
+ </div>
572
+ <div class=card><b>Before vs after summary</b><br>
573
+ <img id=summary src="/evidence/before_after_summary.png" onerror="this.style.display='none'">
574
+ <div id=summary_missing class=muted style="display:none">(generated after post-train eval)</div>
575
+ </div>
576
+ <div class=card><b>Reward distribution: pre vs post</b><br>
577
+ <img id=dist src="/evidence/reward_distribution.png" onerror="this.style.display='none'">
578
+ <div id=dist_missing class=muted style="display:none">(generated after post-train eval)</div>
579
+ </div>
580
+ </div>
581
+
582
+ <h2>Before / after metrics</h2>
583
+ <table id=metrics_table>
584
+ <tr><th>metric</th><th>pre</th><th>post</th><th>Δ</th></tr>
585
+ </table>
586
+
587
+ <h2>Live logs (tail)</h2>
588
+ <pre id=logs>loading…</pre>
589
+
590
+ <script>
591
+ function fmt(v) {
592
+ if (v == null) return '–';
593
+ if (typeof v === 'number') return v.toFixed(3);
594
+ return v;
595
+ }
596
+ function fmtDelta(d) {
597
+ if (d == null || isNaN(d)) return '–';
598
+ const sign = d >= 0 ? '+' : '';
599
+ const cls = d >= 0 ? 'delta-pos' : 'delta-neg';
600
+ return `<span class="${cls}">${sign}${d.toFixed(3)}</span>`;
601
+ }
602
+
603
+ async function refresh() {
604
+ // status
605
+ const s = await fetch('/status').then(r => r.json());
606
+ const pill = document.getElementById('status');
607
+ pill.textContent = s.status;
608
+ pill.className = 'pill ' + ({idle:'',running:'run',finished:'ok',failed:'fail'}[s.status] || '');
609
+
610
+ const meta = document.getElementById('meta');
611
+ meta.innerHTML = '';
612
+ const obj = {
613
+ started_at: s.started_at, finished_at: s.finished_at, error: s.last_error,
614
+ ...(s.last_config || {}),
615
+ };
616
+ for (const [k, v] of Object.entries(obj)) {
617
+ if (v == null || v === '') continue;
618
+ const tr = document.createElement('tr');
619
+ tr.innerHTML = `<td><b>${k}</b></td><td><code>${v}</code></td>`;
620
+ meta.appendChild(tr);
621
+ }
622
+
623
+ // metrics
624
+ const m = await fetch('/metrics').then(r => r.json()).catch(() => ({pre:null, post:null}));
625
+ const tbody = document.getElementById('metrics_table');
626
+ tbody.innerHTML = '<tr><th>metric</th><th>pre</th><th>post</th><th>Δ</th></tr>';
627
+ const fields = ['mean_reward', 'success_rate', 'mass_acc', 'channel_acc', 'median_reward'];
628
+ for (const f of fields) {
629
+ const pre = m.pre && m.pre[f];
630
+ const post = m.post && m.post[f];
631
+ const delta = m.delta && m.delta[f];
632
+ const tr = document.createElement('tr');
633
+ tr.innerHTML = `<td><code>${f}</code></td><td>${fmt(pre)}</td><td>${fmt(post)}</td><td>${fmtDelta(delta)}</td>`;
634
+ tbody.appendChild(tr);
635
+ }
636
+
637
+ // bust caches on plots
638
+ const bust = '?t=' + Date.now();
639
+ for (const [imgId, missingId] of [
640
+ ['curve', 'curve_missing'],
641
+ ['components', 'components_missing'],
642
+ ['ckpt', 'ckpt_missing'],
643
+ ['summary', 'summary_missing'],
644
+ ['dist', 'dist_missing'],
645
+ ]) {
646
+ const img = document.getElementById(imgId);
647
+ const miss = document.getElementById(missingId);
648
+ const baseSrc = img.getAttribute('src').split('?')[0];
649
+ const probe = new Image();
650
+ probe.onload = () => { img.src = baseSrc + bust; img.style.display=''; miss.style.display='none'; };
651
+ probe.onerror = () => { img.style.display='none'; miss.style.display=''; };
652
+ probe.src = baseSrc + bust;
653
+ }
654
+
655
+ const logs = await fetch('/logs?tail=200').then(r => r.text());
656
+ document.getElementById('logs').textContent = logs || '(no logs yet)';
657
+ }
658
+ async function startRun() {
659
+ const r = await fetch('/train', {method:'POST'});
660
+ if (!r.ok) alert((await r.json()).detail || 'failed');
661
+ setTimeout(refresh, 500);
662
+ }
663
+ refresh();
664
+ setInterval(refresh, 5000);
665
+ </script>
666
+ </body>
667
+ </html>
668
+ """
669
+
670
+
671
+ @app.get("/", response_class=HTMLResponse)
672
+ def index() -> HTMLResponse:
673
+ return HTMLResponse(_HTML)
674
+
675
+
676
+ @app.get("/health")
677
+ def health() -> Dict[str, str]:
678
+ return {"status": "ok"}
679
+
680
+
681
+ @app.get("/status")
682
+ def status() -> JSONResponse:
683
+ return JSONResponse(STATE.to_dict())
684
+
685
+
686
+ @app.get("/metrics")
687
+ def metrics() -> JSONResponse:
688
+ if METRICS_FILE.exists():
689
+ try:
690
+ return JSONResponse(json.loads(METRICS_FILE.read_text()))
691
+ except Exception:
692
+ return JSONResponse({"error": "metrics file unreadable"}, status_code=500)
693
+ return JSONResponse({"pre": None, "post": None, "delta": None})
694
+
695
+
696
+ @app.get("/evidence")
697
+ def evidence_index() -> JSONResponse:
698
+ """List every evidence artifact currently on disk."""
699
+ files = []
700
+ if EVIDENCE_DIR.exists():
701
+ for p in sorted(EVIDENCE_DIR.iterdir()):
702
+ if p.is_file():
703
+ files.append({
704
+ "name": p.name,
705
+ "size": p.stat().st_size,
706
+ "url": f"/evidence/{p.name}",
707
+ })
708
+ return JSONResponse({"dir": str(EVIDENCE_DIR), "files": files})
709
+
710
+
711
+ @app.get("/evidence/{name}")
712
+ def evidence_file(name: str):
713
+ """Serve a single evidence artifact (PNG/CSV/JSON/MD) by filename."""
714
+ if "/" in name or ".." in name:
715
+ raise HTTPException(status_code=400, detail="invalid name")
716
+ target = EVIDENCE_DIR / name
717
+ if not target.exists() or not target.is_file():
718
+ raise HTTPException(status_code=404, detail=f"{name} not found")
719
+ return FileResponse(target)
720
+
721
+
722
+ @app.get("/logs", response_class=PlainTextResponse)
723
+ def logs(tail: int = 400) -> PlainTextResponse:
724
+ if not LOG_FILE.exists():
725
+ return PlainTextResponse("")
726
+ text = LOG_FILE.read_text()
727
+ lines = text.splitlines()
728
+ return PlainTextResponse("\n".join(lines[-max(tail, 1):]))
729
+
730
+
731
+ @app.post("/train")
732
+ def train() -> JSONResponse:
733
+ try:
734
+ _start_training(dict(CONFIG))
735
+ except RuntimeError as exc:
736
+ raise HTTPException(status_code=409, detail=str(exc))
737
+ return JSONResponse({"status": "started", "config": CONFIG})
738
+
739
+
740
+ @app.on_event("startup")
741
+ def _maybe_autostart() -> None:
742
+ if CONFIG["autostart"]:
743
+ try:
744
+ _start_training(dict(CONFIG))
745
+ logger.info("autostarted training run")
746
+ except RuntimeError as exc:
747
+ logger.warning("autostart skipped: %s", exc)