Radianis committed on
Commit
ac62897
·
1 Parent(s): ff99487

Add notebook-based Easy and Ablation runners

Browse files
Files changed (4) hide show
  1. README.md +12 -4
  2. _demo_runtime.py +0 -1441
  3. app.py +1222 -242
  4. requirements.txt +5 -4
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: LBW Guard Direct Runner
3
  emoji: 🚀
4
  colorFrom: green
5
  colorTo: blue
@@ -8,6 +8,7 @@ python_version: "3.12"
8
  app_file: app.py
9
  suggested_hardware: t4-medium
10
  models:
 
11
  - Qwen/Qwen2.5-0.5B
12
  datasets:
13
  - Salesforce/wikitext
@@ -22,10 +23,17 @@ Copyright (c) Qluon Inc. All rights reserved.
22
 
23
  Provided for Learn-By-Wire Guard evaluation and customer testing under the applicable Qluon license terms.
24
 
25
- # LBW Guard Direct Runner
26
 
27
- This Space runs a compact AdamW vs `lbw_guard` WikiText LoRA smoke test directly on Hugging Face Spaces.
28
 
29
- Use GPU hardware for meaningful runtime. CPU can load the app, but model training may be slow or fail on memory.
 
 
 
 
 
 
 
30
 
31
  The app writes run artifacts to the Space working directory. Add persistent storage if you need outputs to survive Space restarts.
 
1
  ---
2
+ title: LBW Guard Colab Tests
3
  emoji: 🚀
4
  colorFrom: green
5
  colorTo: blue
 
8
  app_file: app.py
9
  suggested_hardware: t4-medium
10
  models:
11
+ - TinyLlama/TinyLlama_v1.1
12
  - Qwen/Qwen2.5-0.5B
13
  datasets:
14
  - Salesforce/wikitext
 
23
 
24
  Provided for Learn-By-Wire Guard evaluation and customer testing under the applicable Qluon license terms.
25
 
26
+ # LBW Guard Colab Tests
27
 
28
+ This private Space runs notebook-faithful Hugging Face versions of:
29
 
30
+ - `LBW_Guard_Easy_Test_COLAB.ipynb`
31
+ - `LBW_Guard_Ablation_Test_COLAB.ipynb`
32
+
33
+ It installs `lbw-guard` from PyPI and does not vendor the local `lbw/` source folder.
34
+
35
+ Paper: https://arxiv.org/abs/2605.19008
36
+
37
+ Use GPU hardware for meaningful runtime. CPU can load the app, but training is intentionally capped to tiny smoke settings.
38
 
39
  The app writes run artifacts to the Space working directory. Add persistent storage if you need outputs to survive Space restarts.
_demo_runtime.py DELETED
@@ -1,1441 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Standalone customer demo runtime decoupled from the internal benchmark harness."""
3
-
4
- from __future__ import annotations
5
-
6
- import importlib.util
7
- import json
8
- import math
9
- import os
10
- import random
11
- import shutil
12
- import statistics
13
- import subprocess
14
- import sys
15
- import tempfile
16
- import time
17
- import warnings
18
- from array import array
19
- from collections import Counter, deque
20
- from dataclasses import dataclass, field
21
- from pathlib import Path
22
- from typing import Any, Dict, List, Optional, Sequence
23
-
24
# Directory anchors, all derived from this file's resolved location.
AUTOMATION_DIR = Path(__file__).resolve().parent
TEST_ROOT = AUTOMATION_DIR.parent
LBW_ROOT = TEST_ROOT.parent

# Default the Hugging Face cache locations to a `.hf_cache` folder under
# LBW_ROOT. `setdefault` leaves any value already set in the environment
# untouched, and `_hf_home` holds whichever value ends up in effect.
_hf_home = os.environ.setdefault("HF_HOME", str((LBW_ROOT / ".hf_cache").resolve()))
os.environ.setdefault("HF_DATASETS_CACHE", str((Path(_hf_home) / "datasets").resolve()))
os.environ.setdefault("TRANSFORMERS_CACHE", str((Path(_hf_home) / "transformers").resolve()))
# Prevent background safetensors conversion threads from calling the HF
# conversion Space during demo loads. This keeps repo-hosted .bin models
# usable without noisy thread crashes when the service/network misbehaves.
os.environ.setdefault("DISABLE_SAFETENSORS_CONVERSION", "1")

# Default the Weights & Biases directories to a `.wandb` folder under LBW_ROOT,
# again without overriding values that are already set.
_wandb_home = LBW_ROOT / ".wandb"
os.environ.setdefault("WANDB_DIR", str(_wandb_home.resolve()))
os.environ.setdefault("WANDB_CACHE_DIR", str((_wandb_home / "cache").resolve()))
os.environ.setdefault("WANDB_CONFIG_DIR", str((_wandb_home / "config").resolve()))
40
-
41
- import torch
42
- from datasets import load_dataset
43
- from peft import LoraConfig, TaskType, get_peft_model
44
- from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
45
-
46
- try:
47
- import wandb
48
- except Exception:
49
- wandb = None
50
-
51
- try:
52
- from lbw import Guard
53
- except Exception:
54
- Guard = None
55
-
56
-
57
@dataclass
class BenchmarkConfig:
    """Settings for one standalone demo run.

    The defaults live here; the comments below note which helpers in this
    module read each group of fields where that is visible in this file.
    """

    # Model and device selection.
    model_name: str = "Qwen/Qwen2.5-3B"
    device: str = "cuda"
    # LoRA adapter settings.
    enable_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    # Training-loop and evaluation sizing. evaluate_perplexity reads
    # batch_size and eval_batches; the other fields are consumed outside the
    # portion of the file visible here.
    seq_len: int = 256
    batch_size: int = 2
    grad_accum: int = 2
    max_steps: int = 100
    warmup_steps: int = 50
    eval_every: int = 50
    eval_batches: int = 50
    schedule_mode: str = "all_cosine"
    # Dataset sizing switches.
    max_chars: int = 4_000_000
    eval_chars: int = 1_000_000
    full_wikitext_train: bool = False
    full_wikitext_eval: bool = False
    full_validation_ppl: bool = False
    # Optimizer hyperparameters; get_optimizer passes lr, betas and
    # weight_decay to both AdamW and Guard.
    lr: float = 5e-4
    weight_decay: float = 0.01
    betas: tuple[float, float] = (0.9, 0.999)
    # Guard-only settings; get_optimizer forwards them as stats_freq,
    # stress_threshold, spike_threshold, recovery_fast and ema_decay.
    lbw_stats_freq: int = 50
    lbw_stress_th: float = 1.1 #1.1
    lbw_spike_th: float = 1.5 #1.5
    lbw_rec_fast: float = 0.05 #0.01
    lbw_ema_decay: float = 0.95
    # Logging switches. _demo_log prints for a non-None config only when
    # print_all_metrics is truthy.
    use_wandb: bool = False
    use_lbwgov: bool = False
    print_all_metrics: bool = False
    lbwgov_experiment_name: str = "LBW-Customer-Demo"
    # Default artifact directory: "demo_outputs" next to this file.
    output_dir: str = str((AUTOMATION_DIR / "demo_outputs").resolve())
    # lm_eval options; the *_limit, *_fewshot and batch-size fields are read
    # by the _lm_eval_* helpers in this module.
    enable_benchmarks: bool = False
    use_lm_eval: bool = False
    lm_eval_ppl: bool = False
    lm_eval_ppl_task: str = "wikitext_103_raw"
    lm_eval_ppl_limit: Optional[float] = None
    lm_eval_acc: bool = False
    lm_eval_acc_tasks: str = "mmlu,arc_challenge"
    lm_eval_acc_limit: Optional[float] = None
    lm_eval_mmlu_limit: Optional[float] = None
    lm_eval_arc_challenge_limit: Optional[float] = None
    lm_eval_mmlu_fewshot: int = 5
    lm_eval_arc_challenge_fewshot: int = 25
    lm_eval_batch_size: str = "1"
115
-
116
-
117
@dataclass
class ChunkedTokens:
    """Tokenized text cut into fixed-length rows, plus provenance metadata.

    build_wikitext_chunks fills this with a ``(n, seq_len)`` long tensor and
    passes the very same tensor as both ``input_ids`` and ``labels``.
    """

    # Token id rows.
    input_ids: torch.Tensor
    # Training targets; build_wikitext_chunks reuses the input_ids tensor here.
    labels: torch.Tensor
    # Dataset split name the tokens were drawn from.
    split: str = ""
    # Number of characters consumed while building the rows.
    char_count: int = 0
    # Character cap that was applied, or None when no cap was requested.
    cap_chars: Optional[int] = None
124
-
125
-
126
def _demo_log(config: Optional["BenchmarkConfig"], message: str) -> None:
    """Print a ``[DemoRuntime]`` progress line unless the config silences it.

    A ``None`` config always prints. Otherwise the line is emitted only when
    ``config.print_all_metrics`` is truthy (a missing attribute counts as
    falsy).
    """
    verbose = config is None or bool(getattr(config, "print_all_metrics", False))
    if verbose:
        print(f"[DemoRuntime] {message}", flush=True)
130
-
131
-
132
def _safe_float(value: Any) -> Optional[float]:
    """Convert ``value`` to a finite float.

    Returns ``None`` when ``value`` is ``None``, cannot be converted by
    ``float()``, or converts to NaN or an infinity.
    """
    if value is not None:
        try:
            number = float(value)
        except Exception:
            return None
        if math.isfinite(number):
            return number
    return None
142
-
143
-
144
@dataclass
class GovernanceMetricConfig:
    """Thresholds and window sizes used by GovernanceMetricsTracker."""

    # Decay factor for the tracker's exponential moving averages.
    ema_decay: float = 0.95
    # Number of most recent grad-RMS samples used for the short-window std.
    short_window: int = 20
    # Maximum length of the tracker's history deques.
    long_window: int = 100
    # |scale - 1| above this counts as an intervention.
    intervention_eps: float = 1e-3
    # Ratio band that, together with |scale - 1| <= 0.05, classifies "stable".
    stable_ratio_low: float = 0.90
    stable_ratio_high: float = 1.10
    # Ratio at or above this classifies "unstable".
    unstable_ratio_high: float = 1.35
    # Ratio at or below this, with near-zero loss velocity, classifies "stagnation".
    stagnation_ratio_low: float = 0.70
    # Flip rate at or above this classifies "oscillatory".
    oscillation_flip_high: float = 0.30
155
-
156
-
157
class GovernanceMetricsTracker:
    """Accumulate per-step and per-eval governance metrics for one training run.

    The tracker keeps bounded histories (``cfg.long_window`` entries), running
    EMAs, intervention/regime counters and the best evaluation values seen so
    far. ``update_step`` and ``update_eval`` return flat ``{metric: float}``
    dictionaries; ``snapshot`` returns the cumulative state.
    """

    def __init__(self, cfg: Optional[GovernanceMetricConfig] = None):
        """Create an empty tracker; a falsy ``cfg`` selects the default config."""
        self.cfg = cfg or GovernanceMetricConfig()
        self.grad_rms_history = deque(maxlen=self.cfg.long_window)
        self.ratio_history = deque(maxlen=self.cfg.long_window)
        self.regime_history = deque(maxlen=self.cfg.long_window)
        self.flip_rate_history = deque(maxlen=self.cfg.long_window)
        self.loss_ema: Optional[float] = None
        self.grad_norm_ema: Optional[float] = None
        self.grad_rms_ema: Optional[float] = None
        self.prev_loss: Optional[float] = None
        self.prev_prev_loss: Optional[float] = None
        self.prev_regime: Optional[str] = None
        self.prev_grad_sign_summary: Optional[tuple[int, int, int]] = None
        self.total_logged_steps = 0
        self.intervention_count = 0
        self.regime_switch_count = 0
        self.stress_entries = 0
        self.total_control_energy = 0.0
        self.max_control_energy = 0.0
        self.open_recovery_start_step: Optional[int] = None
        self.completed_recovery_latencies: List[int] = []
        self.best_eval_loss: Optional[float] = None
        self.best_eval_perplexity: Optional[float] = None

    def _safe_float(self, value: Any, default: float = 0.0) -> float:
        """Return ``value`` as a finite float, else ``float(default)``."""
        try:
            out = float(value)
            if math.isfinite(out):
                return out
        except Exception:
            pass
        return float(default)

    @staticmethod
    def _finite_or_none(value: Any) -> Optional[float]:
        """Return ``value`` as a finite float, or ``None`` if it is missing or non-finite."""
        if value is None:
            return None
        try:
            out = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return out if math.isfinite(out) else None

    def _ema_update(self, old: Optional[float], new: Optional[float]) -> Optional[float]:
        """Blend ``new`` into ``old`` with ``cfg.ema_decay``; the first sample seeds the EMA."""
        if new is None:
            return old
        if old is None:
            return float(new)
        d = self.cfg.ema_decay
        return float(d * old + (1.0 - d) * float(new))

    def _std(self, values) -> float:
        """Population standard deviation of the finite entries (0.0 for fewer than two)."""
        vals = [float(v) for v in values if v is not None and math.isfinite(float(v))]
        if len(vals) < 2:
            return 0.0
        return float(statistics.pstdev(vals))

    def _mean(self, values) -> float:
        """Arithmetic mean of the finite entries (0.0 when there are none)."""
        vals = [float(v) for v in values if v is not None and math.isfinite(float(v))]
        if not vals:
            return 0.0
        return float(sum(vals) / len(vals))

    def _compute_grad_sign_flip_rate(self, params) -> float:
        """Return the change in the positive-gradient fraction since the last call.

        This is an aggregate proxy, not a per-element flip count: it compares
        ``pos / (pos + neg)`` for the current gradients against the same
        fraction from the previous call. The first call only records the
        summary and returns 0.0.
        """
        pos = 0
        neg = 0
        zero = 0
        for param in params:
            if param.grad is None:
                continue
            grad = param.grad.detach()
            if grad.numel() == 0:
                continue
            signs = torch.sign(grad)
            pos += int((signs > 0).sum().item())
            neg += int((signs < 0).sum().item())
            zero += int((signs == 0).sum().item())
        summary = (pos, neg, zero)
        if self.prev_grad_sign_summary is None:
            self.prev_grad_sign_summary = summary
            return 0.0
        prev_pos, prev_neg, _ = self.prev_grad_sign_summary
        # max(..., 1) guards against division by zero when no element is signed.
        prev_total = max(prev_pos + prev_neg, 1)
        cur_total = max(pos + neg, 1)
        flip_rate = abs((pos / cur_total) - (prev_pos / prev_total))
        self.prev_grad_sign_summary = summary
        return float(flip_rate)

    def classify_regime(self, *, ratio: float, flip_rate: float, loss_velocity: float, scale: float, stress_mode: str) -> str:
        """Map the current signals to one regime label.

        Checks run in priority order: ``"unstable"``, ``"oscillatory"``,
        ``"stagnation"``, ``"stable"``, and finally ``"transitional"`` when
        nothing else matches. Non-finite inputs fall back to neutral values.
        """
        ratio = self._safe_float(ratio, 1.0)
        flip_rate = self._safe_float(flip_rate, 0.0)
        loss_velocity = self._safe_float(loss_velocity, 0.0)
        scale = self._safe_float(scale, 1.0)
        stress_mode = str(stress_mode or "unknown").lower()
        if "stress" in stress_mode or ratio >= self.cfg.unstable_ratio_high:
            return "unstable"
        if flip_rate >= self.cfg.oscillation_flip_high:
            return "oscillatory"
        if ratio <= self.cfg.stagnation_ratio_low and abs(loss_velocity) < 1e-4:
            return "stagnation"
        if (self.cfg.stable_ratio_low <= ratio <= self.cfg.stable_ratio_high) and abs(scale - 1.0) <= 0.05:
            return "stable"
        return "transitional"

    def update_step(
        self,
        *,
        step: int,
        trainable_params,
        loss_val: float,
        grad_norm: float,
        grad_rms: float,
        ema_grad_rms: float,
        ratio: float,
        scale: float,
        stress_mode: str,
        current_lr: float,
    ) -> Dict[str, float]:
        """Record one optimizer step and return that step's metric dictionary.

        Non-finite numeric inputs are replaced by neutral defaults (0.0, or
        1.0 for ``ratio`` and ``scale``) before use. ``trainable_params`` must
        be iterable objects exposing ``.grad``.
        """
        self.total_logged_steps += 1
        loss_val = self._safe_float(loss_val)
        grad_norm = self._safe_float(grad_norm)
        grad_rms = self._safe_float(grad_rms)
        ema_grad_rms = self._safe_float(ema_grad_rms)
        ratio = self._safe_float(ratio, 1.0)
        scale = self._safe_float(scale, 1.0)
        current_lr = self._safe_float(current_lr)

        self.loss_ema = self._ema_update(self.loss_ema, loss_val)
        self.grad_norm_ema = self._ema_update(self.grad_norm_ema, grad_norm)
        self.grad_rms_ema = self._ema_update(self.grad_rms_ema, grad_rms)

        # First and second finite differences of the loss across logged steps.
        loss_velocity = 0.0 if self.prev_loss is None else (loss_val - self.prev_loss)
        loss_acceleration = 0.0 if self.prev_loss is None or self.prev_prev_loss is None else (
            loss_val - 2.0 * self.prev_loss + self.prev_prev_loss
        )
        flip_rate = self._compute_grad_sign_flip_rate(trainable_params)
        grad_deviation = 0.0
        if ema_grad_rms > 0:
            grad_deviation = (grad_rms - ema_grad_rms) / max(ema_grad_rms, 1e-12)
        # "Control energy" is how far the applied scale departs from 1.0.
        control_energy = abs(scale - 1.0)
        intervention_flag = 1.0 if control_energy > self.cfg.intervention_eps else 0.0
        if intervention_flag > 0:
            self.intervention_count += 1
        self.total_control_energy += control_energy
        self.max_control_energy = max(self.max_control_energy, control_energy)

        regime = self.classify_regime(
            ratio=ratio,
            flip_rate=flip_rate,
            loss_velocity=loss_velocity,
            scale=scale,
            stress_mode=stress_mode,
        )
        if self.prev_regime is not None and regime != self.prev_regime:
            self.regime_switch_count += 1
        # A recovery window opens on the first unstable/oscillatory step and
        # closes on the next stable step; its length is the recovery latency.
        if regime in {"unstable", "oscillatory"} and self.open_recovery_start_step is None:
            self.open_recovery_start_step = step
            self.stress_entries += 1
        if regime == "stable" and self.open_recovery_start_step is not None:
            self.completed_recovery_latencies.append(step - self.open_recovery_start_step)
            self.open_recovery_start_step = None

        self.grad_rms_history.append(grad_rms)
        self.ratio_history.append(ratio)
        self.regime_history.append(regime)
        self.flip_rate_history.append(flip_rate)

        short_grad_std = self._std(list(self.grad_rms_history)[-self.cfg.short_window :])
        long_grad_std = self._std(self.grad_rms_history)
        grad_variance_reduction = 0.0
        if long_grad_std > 1e-12:
            grad_variance_reduction = 1.0 - (short_grad_std / long_grad_std)

        out = {
            "obs/grad_direction_change_rate": flip_rate,
            "obs/loss_velocity": loss_velocity,
            "obs/loss_acceleration": loss_acceleration,
            "obs/update_magnitude_proxy": scale * current_lr,
            "state/grad_ratio": ratio,
            "state/grad_deviation_score": grad_deviation,
            "state/regime_stable": 1.0 if regime == "stable" else 0.0,
            "state/regime_unstable": 1.0 if regime == "unstable" else 0.0,
            "state/regime_oscillatory": 1.0 if regime == "oscillatory" else 0.0,
            "state/regime_stagnation": 1.0 if regime == "stagnation" else 0.0,
            "state/regime_transitional": 1.0 if regime == "transitional" else 0.0,
            "control/action_strength": control_energy,
            "control/intervention_flag": intervention_flag,
            "loop/intervention_rate": self.intervention_count / max(self.total_logged_steps, 1),
            "loop/regime_switch_count": float(self.regime_switch_count),
            "loop/avg_control_energy": self.total_control_energy / max(self.total_logged_steps, 1),
            "loop/max_control_energy": self.max_control_energy,
            "effect/grad_variance_reduction": grad_variance_reduction,
            "effect/recovery_latency_mean_steps": self._mean(self.completed_recovery_latencies),
            "effect/recovery_events": float(len(self.completed_recovery_latencies)),
        }

        self.prev_prev_loss = self.prev_loss
        self.prev_loss = loss_val
        self.prev_regime = regime
        return out

    def update_eval(
        self,
        *,
        eval_loss: Optional[float] = None,
        eval_perplexity: Optional[float] = None,
        avg_tps_wall: Optional[float] = None,
    ) -> Dict[str, float]:
        """Record an evaluation result and return best-so-far metrics.

        Non-finite ``eval_loss`` / ``eval_perplexity`` values are ignored and
        contribute no keys. Otherwise a NaN would be stored as the "best"
        value and could never be replaced, because every later ``<``
        comparison against NaN is false.
        """
        out: Dict[str, float] = {}
        loss = self._finite_or_none(eval_loss)
        if loss is not None:
            if self.best_eval_loss is None or loss < self.best_eval_loss:
                self.best_eval_loss = loss
            out["effect/best_eval_loss"] = float(self.best_eval_loss)
            out["effect/eval_loss_gap_to_best"] = float(loss - self.best_eval_loss)
        perplexity = self._finite_or_none(eval_perplexity)
        if perplexity is not None:
            if self.best_eval_perplexity is None or perplexity < self.best_eval_perplexity:
                self.best_eval_perplexity = perplexity
            out["effect/best_eval_perplexity"] = float(self.best_eval_perplexity)
            out["effect/eval_perplexity_gap_to_best"] = float(perplexity - self.best_eval_perplexity)
        if avg_tps_wall is not None:
            out["effect/efficiency_wall_tps"] = float(avg_tps_wall)
        return out

    def snapshot(self) -> Dict[str, Any]:
        """Return the cumulative counters, recent regimes and best eval values."""
        return {
            "total_logged_steps": self.total_logged_steps,
            "intervention_count": self.intervention_count,
            "regime_switch_count": self.regime_switch_count,
            "stress_entries": self.stress_entries,
            "avg_control_energy": self.total_control_energy / max(self.total_logged_steps, 1),
            "max_control_energy": self.max_control_energy,
            "completed_recovery_latencies": list(self.completed_recovery_latencies),
            "recent_regimes": list(self.regime_history),
            "best_eval_loss": self.best_eval_loss,
            "best_eval_perplexity": self.best_eval_perplexity,
        }
384
-
385
-
386
def _wants_cuda(device: Optional[str] = None) -> bool:
    """Return True when the device string names a CUDA device (case-insensitive)."""
    normalized = str(device or "").strip().lower()
    return normalized.startswith("cuda")
388
-
389
-
390
def set_seed(seed: int, device: Optional[str] = None):
    """Seed Python's ``random`` and torch; also seed all CUDA devices when applicable.

    CUDA generators are seeded only when ``device`` requests CUDA and CUDA is
    actually available.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not _wants_cuda(device):
        return
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
395
-
396
-
397
def normalize_optimizer_name(name: str) -> str:
    """Lower-case and trim an optimizer name, collapsing known aliases.

    ``guard``, ``lbw`` and ``lbw-guard`` become ``lbw_guard``; ``adam``
    becomes ``adamw``. Any other name is returned trimmed and lower-cased.
    """
    key = str(name or "").strip().lower()
    if key in ("guard", "lbw", "lbw-guard"):
        return "lbw_guard"
    if key == "adam":
        return "adamw"
    return key
406
-
407
-
408
def check_optimizer_support(name: str, device: Optional[str] = None) -> tuple[bool, str]:
    """Report whether this runtime can run the named optimizer.

    Returns ``(True, "")`` when supported, otherwise ``(False, reason)``.
    ``lbw_guard`` additionally requires the Guard import to have succeeded
    and, on CUDA, at most one visible GPU.
    """
    normalized = normalize_optimizer_name(name)
    if normalized == "adamw":
        return True, ""
    if normalized != "lbw_guard":
        return False, "Standalone customer demo runtime supports only adamw and lbw_guard."
    if Guard is None:
        return False, "LBW_Guard package not found. Install the standard LBW_Guard package in the active Python environment."
    multi_gpu = (
        _wants_cuda(device)
        and torch.cuda.is_available()
        and int(torch.cuda.device_count()) > 1
    )
    if multi_gpu:
        return False, "lbw_guard supports at most 1 visible GPU. Restrict CUDA_VISIBLE_DEVICES to one GPU."
    return True, ""
417
-
418
-
419
def _hf_offline_mode() -> bool:
    """Return True when HF_HUB_OFFLINE or TRANSFORMERS_OFFLINE is set to 1/true/yes."""
    enabled_values = {"1", "true", "yes"}
    return any(
        os.environ.get(variable, "").lower() in enabled_values
        for variable in ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE")
    )
424
-
425
-
426
def _hf_pretrained_kwargs() -> Dict[str, Any]:
    """Build the keyword arguments shared by tokenizer and model loading.

    ``trust_remote_code`` is always enabled; ``local_files_only`` is added
    when offline mode is detected.
    """
    if _hf_offline_mode():
        return {"trust_remote_code": True, "local_files_only": True}
    return {"trust_remote_code": True}
431
-
432
-
433
def _resolve_model_dtype(device: torch.device):
    """Pick bfloat16 for CUDA devices and float32 for everything else."""
    if device.type == "cuda":
        return torch.bfloat16
    return torch.float32
435
-
436
-
437
def _resolve_model_device_map(device: torch.device):
    """Return a single-entry device map for CUDA devices, else ``None``.

    The map sends the whole model (key ``""``) to the device index, using
    index 0 when the device carries no explicit index.
    """
    if device.type == "cuda":
        gpu_index = 0 if device.index is None else device.index
        return {"": gpu_index}
    return None
441
-
442
-
443
def _load_tokenizer_and_model(model_name: str, device: torch.device):
    """Load the tokenizer and causal LM for ``model_name`` onto ``device``.

    The tokenizer gets a pad token (EOS, falling back to UNK) when it has
    none. On CUDA the model is placed through ``device_map``; on other
    devices it is loaded first and then moved with ``.to(device)``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    shared_kwargs = _hf_pretrained_kwargs()
    tokenizer = AutoTokenizer.from_pretrained(model_name, **shared_kwargs)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token

    load_kwargs: Dict[str, Any] = {
        "torch_dtype": _resolve_model_dtype(device),
        **shared_kwargs,
    }
    placement = _resolve_model_device_map(device)
    if placement is None:
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        model.to(device)
    else:
        load_kwargs["device_map"] = placement
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    return tokenizer, model
459
-
460
-
461
def build_wikitext_chunks(
    tokenizer,
    seq_len: int,
    max_chars: Optional[int],
    split: str,
    *,
    config: Optional[BenchmarkConfig] = None,
) -> ChunkedTokens:
    """Tokenize a WikiText-103 (raw) split into fixed-length rows.

    Non-blank rows are joined with a single space, optionally truncated to
    ``max_chars`` characters in total, tokenized without special tokens, and
    reshaped into ``(n, seq_len)``; leftover tokens that do not fill a row
    are dropped.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list.
        seq_len: Number of tokens per output row.
        max_chars: Character cap, or ``None`` to consume the whole split.
        split: Dataset split name passed to ``load_dataset``.
        config: Optional config; only used to gate progress logging.

    Returns:
        A ``ChunkedTokens`` whose ``input_ids`` and ``labels`` are the same tensor.

    Raises:
        RuntimeError: If no tokens were produced or fewer than ``seq_len`` exist.
    """
    cap = None if max_chars is None else int(max_chars)
    _demo_log(
        config,
        f"Preparing WikiText split='{split}'"
        + (f" with char cap {cap:,}" if cap is not None else " with full split"),
    )
    ds = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)
    # Unsigned-int array keeps the token buffer compact while it grows.
    token_buf = array("I")
    chars_used = 0
    first_piece = True
    rows_used = 0
    next_report_chars = 500_000 if cap is None else max(250_000, cap // 4)
    for row in ds:
        text = str(row.get("text", "") or "")
        if not text.strip():
            continue
        # Every piece after the first is prefixed with a separating space.
        piece = text if first_piece else (" " + text)
        if cap is not None:
            remain = cap - chars_used
            if remain <= 0:
                break
            # Truncate the final piece so the cap is never exceeded.
            if len(piece) > remain:
                piece = piece[:remain]
        chars_used += len(piece)
        first_piece = False
        rows_used += 1
        ids_piece = tokenizer(piece, add_special_tokens=False)["input_ids"]
        if ids_piece:
            token_buf.extend(ids_piece)
        # Progress reporting is only active when print_all_metrics is truthy.
        if config is not None and bool(getattr(config, "print_all_metrics", False)) and chars_used >= next_report_chars:
            target = f"/{cap:,}" if cap is not None else ""
            _demo_log(config, f"Tokenizing split='{split}': {chars_used:,}{target} chars")
            next_report_chars += 500_000 if cap is None else max(250_000, cap // 4)
        if cap is not None and chars_used >= cap:
            break
    if len(token_buf) == 0:
        raise RuntimeError(f"No tokens built for split '{split}'.")
    ids = torch.tensor(token_buf, dtype=torch.long)
    n = ids.numel() // seq_len
    if n <= 0:
        raise RuntimeError(f"Not enough tokens for seq_len {seq_len}. Increase max_chars.")
    # Drop the trailing partial row, then reshape to (n, seq_len).
    ids = ids[: n * seq_len].view(n, seq_len).contiguous()
    _demo_log(
        config,
        f"Prepared split='{split}': {chars_used:,} chars across {rows_used:,} rows -> {ids.size(0):,} sequences of len {seq_len}",
    )
    return ChunkedTokens(input_ids=ids, labels=ids, split=split, char_count=int(chars_used), cap_chars=cap)
516
-
517
-
518
def batch_iter(chunks: ChunkedTokens, batch_size: int, device: torch.device):
    """Yield ``(input_ids, labels)`` batches forever, cycling through ``chunks``.

    Batches are taken in order. When the next batch would run past the last
    row, iteration restarts from row 0, so trailing rows that do not fill a
    whole batch are never yielded. Each batch is moved to ``device``.
    """
    inputs = chunks.input_ids
    targets = chunks.labels
    total_rows = inputs.size(0)
    start = 0
    while True:
        stop = start + batch_size
        if stop > total_rows:
            start, stop = 0, batch_size
        batch_inputs = inputs[start:stop].to(device, non_blocking=True)
        batch_targets = targets[start:stop].to(device, non_blocking=True)
        yield (batch_inputs, batch_targets)
        start = stop
529
-
530
-
531
def _iter_eval_batches(eval_chunks, config, device, full_pass):
    """Yield the evaluation batches: one ordered pass, or ``eval_batches`` cycled batches."""
    if full_pass:
        inputs, targets = eval_chunks.input_ids, eval_chunks.labels
        for start in range(0, inputs.size(0), config.batch_size):
            stop = start + config.batch_size
            yield (
                inputs[start:stop].to(device, non_blocking=True),
                targets[start:stop].to(device, non_blocking=True),
            )
        return
    cycling = batch_iter(eval_chunks, config.batch_size, device)
    for _ in range(max(1, int(config.eval_batches))):
        yield next(cycling)


def evaluate_perplexity(model, eval_chunks: ChunkedTokens, config: BenchmarkConfig, device: torch.device, *, full_pass: bool = False):
    """Compute token-weighted average loss and perplexity on ``eval_chunks``.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)`` returning an
            object with a scalar ``loss``; ``model.eval()`` is called first
            and the model is left in eval mode.
        eval_chunks: Object exposing ``input_ids`` and ``labels`` tensors.
        config: Supplies ``batch_size`` and, when ``full_pass`` is false,
            ``eval_batches`` (at least one batch is always evaluated).
        device: Target device; bfloat16 autocast is enabled only on CUDA.
        full_pass: Evaluate every row once instead of a fixed batch count.

    Returns:
        ``(avg_eval_loss, exp(avg_eval_loss))``.

    Raises:
        RuntimeError: If no batch contributed a finite loss.
    """
    model.eval()
    total_nll = 0.0
    total_tokens = 0
    use_autocast = device.type == "cuda"
    with torch.no_grad():
        for ex, ey in _iter_eval_batches(eval_chunks, config, device, full_pass):
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_autocast):
                out = model(input_ids=ex, labels=ey)
            # Each batch's mean loss is weighted by its number of label
            # positions after the first column.
            tok = int(ey[:, 1:].numel())
            # Read the loss once: every .item() forces a device-to-host sync.
            batch_loss = float(out.loss.item())
            if tok > 0 and math.isfinite(batch_loss):
                total_nll += batch_loss * float(tok)
                total_tokens += tok
    if total_tokens <= 0:
        raise RuntimeError("Validation produced no batches.")
    avg_eval_loss = float(total_nll / float(total_tokens))
    return avg_eval_loss, math.exp(avg_eval_loss)
562
-
563
-
564
def _compute_grad_rms(params) -> float:
    """Return the root-mean-square over all gradient elements of ``params``.

    Parameters without a gradient are skipped; the result is 0.0 when no
    gradient elements exist.
    """
    squared_total = 0.0
    element_count = 0
    for param in params:
        if param.grad is None:
            continue
        grad = param.grad.detach()
        squared_total += float(torch.sum(grad.float() * grad.float()).item())
        element_count += int(grad.numel())
    if element_count > 0:
        return float(math.sqrt(squared_total / float(element_count)))
    return 0.0
576
-
577
-
578
def get_optimizer(name: str, model, config: BenchmarkConfig):
    """Build the requested optimizer over the model's trainable parameters.

    Args:
        name: Optimizer name or alias (see ``normalize_optimizer_name``).
        model: Module whose ``requires_grad`` parameters are optimized.
        config: Source of learning rate, betas, weight decay and Guard settings.

    Raises:
        RuntimeError: ``lbw_guard`` requested but the Guard import failed.
        ValueError: The name is neither ``adamw`` nor ``lbw_guard``.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    normalized = normalize_optimizer_name(name)

    if normalized == "adamw":
        return torch.optim.AdamW(
            trainable,
            lr=config.lr,
            betas=config.betas,
            weight_decay=config.weight_decay,
        )

    if normalized != "lbw_guard":
        raise ValueError(f"Unsupported optimizer for standalone demo runtime: {name}")

    if Guard is None:
        raise RuntimeError("LBW Guard package not available for lbw_guard.")
    guard_settings = {
        "lr": config.lr,
        "betas": config.betas,
        "weight_decay": config.weight_decay,
        "mode": "eval",
        "auto_enabled": True,
        "stats_freq": int(config.lbw_stats_freq),
        "stress_threshold": config.lbw_stress_th,
        "spike_threshold": config.lbw_spike_th,
        "recovery_fast": config.lbw_rec_fast,
        "ema_decay": config.lbw_ema_decay,
        "use_max_rms": True,
    }
    return Guard(trainable, **guard_settings)
606
-
607
-
608
class _SchedulerProxyOptimizer(torch.optim.Optimizer):
    """No-op optimizer that exposes another object's ``param_groups``.

    The base class is initialized with copies holding only ``params`` and
    ``lr``. ``param_groups`` is then rebound to the caller's own list, so
    anything that writes to the proxy's groups writes to the original ones.
    """

    def __init__(self, param_groups: List[Dict[str, Any]]):
        shadow_groups = [
            {"params": list(group.get("params", []) or []), "lr": float(group.get("lr", 0.0))}
            for group in list(param_groups or [])
        ]
        super().__init__(shadow_groups, defaults={})
        # Share the live groups rather than the copies built above.
        self.param_groups = param_groups

    def step(self, closure=None):
        """Do nothing; the proxy never updates parameters itself."""
        del closure
        return None
619
-
620
-
621
def _build_scheduler_proxy_for_optimizer_like(opt: Any) -> Optional[torch.optim.Optimizer]:
    """Wrap ``opt.param_groups`` in a scheduler proxy, or return ``None``.

    ``None`` is returned when ``opt`` has no non-empty ``param_groups`` list
    or when constructing the proxy fails for any reason.
    """
    groups = getattr(opt, "param_groups", None)
    if isinstance(groups, list) and groups:
        try:
            return _SchedulerProxyOptimizer(groups)
        except Exception:
            pass
    return None
629
-
630
-
631
def _pick_lm_eval_perplexity_metrics(task_result: Dict[str, Any]) -> Dict[str, float]:
    """Extract perplexity-style metrics from one lm_eval task result.

    For each metric both the ``"<name>,none"`` and plain ``"<name>"`` keys are
    checked, in that order; when both hold numbers the plain key wins.
    Non-dict input yields an empty dict.
    """
    picked: Dict[str, float] = {}
    if isinstance(task_result, dict):
        for metric in ("word_perplexity", "perplexity", "bits_per_byte"):
            for source_key in (f"{metric},none", metric):
                value = task_result.get(source_key, None)
                if isinstance(value, (int, float)):
                    picked[metric] = float(value)
    return picked
648
-
649
-
650
def _pick_lm_eval_accuracy_metrics(task_result: Dict[str, Any]) -> Dict[str, float]:
    """Extract accuracy-style metrics from one lm_eval task result.

    For each metric both the ``"<name>,none"`` and plain ``"<name>"`` keys are
    checked, in that order; when both hold numbers the plain key wins.
    Non-dict input yields an empty dict.
    """
    picked: Dict[str, float] = {}
    if isinstance(task_result, dict):
        for metric in ("acc_norm", "acc", "exact_match"):
            for source_key in (f"{metric},none", metric):
                value = task_result.get(source_key, None)
                if isinstance(value, (int, float)):
                    picked[metric] = float(value)
    return picked
667
-
668
-
669
def _normalize_lm_eval_task_name(task_name: str) -> str:
    """Map a user-supplied lm_eval task name to its canonical spelling.

    The name is trimmed, lower-cased and has ``-`` replaced by ``_``. Known
    alternative spellings of ``wikitext_103_raw``, ``mmlu`` and
    ``arc_challenge`` collapse to those names; anything else is returned in
    its normalized form. Empty input returns an empty string.
    """
    task = str(task_name or "").strip()
    if not task:
        return task
    normalized = task.lower().replace("-", "_")
    alias_groups = {
        "wikitext_103_raw": (
            "wikitext103_raw",
            "wikitext103raw",
            "wikitext_103",
            "wikitext103",
            "wikitext_103_raw_v1",
        ),
        "mmlu": ("hendrycks_test",),
        "arc_challenge": ("arc", "arcchallenge"),
    }
    for canonical, spellings in alias_groups.items():
        if normalized in spellings:
            return canonical
    return normalized
690
-
691
-
692
def _parse_lm_eval_task_list(raw_tasks: Any) -> List[str]:
    """Turn a comma-separated string or iterable into unique canonical task names.

    Blank entries are dropped, each name is canonicalized, and duplicates are
    removed while keeping first-seen order. ``None`` yields an empty list.
    """
    if raw_tasks is None:
        return []
    if isinstance(raw_tasks, str):
        parts = raw_tasks.split(",")
    else:
        parts = [str(part) for part in raw_tasks]

    # A dict serves as an insertion-ordered set of canonical names.
    unique: Dict[str, None] = {}
    for part in parts:
        cleaned = part.strip()
        if not cleaned:
            continue
        canonical = _normalize_lm_eval_task_name(cleaned)
        if canonical:
            unique.setdefault(canonical, None)
    return list(unique)
709
-
710
-
711
def _lm_eval_num_fewshot_for_task(config: BenchmarkConfig, task_name: str) -> int:
    """Return the few-shot count configured for ``task_name``.

    ``mmlu`` and ``arc_challenge`` read their dedicated config attributes
    (falling back to 5 and 25 when the attribute is missing); every other
    task uses 0 shots.
    """
    fewshot_sources = {
        "mmlu": ("lm_eval_mmlu_fewshot", 5),
        "arc_challenge": ("lm_eval_arc_challenge_fewshot", 25),
    }
    source = fewshot_sources.get(_normalize_lm_eval_task_name(task_name))
    if source is None:
        return 0
    attribute, fallback = source
    return int(getattr(config, attribute, fallback))
718
-
719
-
720
def _lm_eval_limit_for_task(config: BenchmarkConfig, task_name: str) -> Optional[float]:
    """Return the example limit configured for ``task_name``, or ``None``.

    ``mmlu`` reads ``lm_eval_mmlu_limit`` and uses ``lm_eval_acc_limit`` only
    when that attribute is absent from the config. ``arc_challenge`` reads
    ``lm_eval_arc_challenge_limit`` with no fallback. All other tasks use
    ``lm_eval_acc_limit``.
    """
    normalized = _normalize_lm_eval_task_name(task_name)
    legacy_limit = getattr(config, "lm_eval_acc_limit", None)
    if normalized == "arc_challenge":
        value = getattr(config, "lm_eval_arc_challenge_limit", None)
    elif normalized == "mmlu":
        value = getattr(config, "lm_eval_mmlu_limit", legacy_limit)
    else:
        value = legacy_limit
    return None if value is None else float(value)
732
-
733
-
734
def _load_lm_eval_results(output_path: Path) -> Dict[str, Any]:
    """Load the ``results`` mapping from an lm_eval output file or directory.

    A file path is parsed directly. A directory is searched recursively for
    ``*.json`` files, newest modification time first, and the first file
    whose top-level object has a dict under ``"results"`` is used.

    Args:
        output_path: File or directory produced by lm_eval.

    Returns:
        The ``results`` dictionary of the first usable JSON file.

    Raises:
        RuntimeError: If no candidate contains a usable ``results`` mapping.
    """
    candidates: List[Path] = []
    if output_path.is_file():
        candidates = [output_path]
    elif output_path.is_dir():
        candidates = sorted(
            (p for p in output_path.rglob("*.json") if p.is_file()),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
    for cand in candidates:
        try:
            # Decode explicitly as UTF-8 instead of the locale default so that
            # non-ASCII content parses the same on every platform.
            payload = json.loads(cand.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable, undecodable or malformed file: try the next candidate.
            continue
        if isinstance(payload, dict) and isinstance(payload.get("results", None), dict):
            return payload["results"]
    raise RuntimeError(f"Unable to parse lm_eval output from {output_path}")
748
-
749
-
750
def _find_lm_eval_task_result(raw_results: Dict[str, Any], task_name: str) -> Dict[str, Any]:
    """Return the result dict for ``task_name`` from an lm_eval results mapping.

    An exact key match is preferred; otherwise the first key that starts with
    ``task_name`` is used.

    Raises:
        RuntimeError: If no key matches or the matched value is not a dict.
    """
    if task_name in raw_results:
        task_key = task_name
    else:
        task_key = None
        for candidate in raw_results:
            if candidate == task_name or candidate.startswith(task_name):
                task_key = candidate
                break
    if task_key is None:
        raise RuntimeError(f"lm_eval returned no results for '{task_name}'.")
    task_result = raw_results.get(task_key, {})
    if isinstance(task_result, dict):
        return task_result
    raise RuntimeError(f"lm_eval returned malformed results for '{task_name}'.")
761
-
762
-
763
def _resolve_lm_eval_command() -> List[str]:
    """Return the argv prefix that launches ``lm_eval run``.

    An ``lm_eval`` executable on PATH is preferred; otherwise the importable
    ``lm_eval`` module is run with the current interpreter.

    Raises:
        RuntimeError: If neither the executable nor the module can be found.
    """
    executable = shutil.which("lm_eval")
    if executable:
        return [executable, "run"]
    module_available = importlib.util.find_spec("lm_eval") is not None
    if not module_available:
        raise RuntimeError("lm_eval not found. Install EleutherAI lm-evaluation-harness in your venv.")
    return [sys.executable, "-m", "lm_eval", "run"]
770
-
771
-
772
def _resolve_lm_eval_include_path() -> Optional[str]:
    """Build the colon-separated lm_eval include path, or ``None`` if empty.

    Existing ``lm_eval_tasks`` folders under AUTOMATION_DIR and TEST_ROOT come
    first, followed by the non-blank, colon-separated entries of the
    ``LM_EVAL_INCLUDE_PATH`` environment variable.
    """
    local_dirs = (AUTOMATION_DIR / "lm_eval_tasks", TEST_ROOT / "lm_eval_tasks")
    collected: List[str] = [str(folder) for folder in local_dirs if folder.exists()]

    from_env = str(os.environ.get("LM_EVAL_INCLUDE_PATH", "") or "").strip()
    if from_env:
        collected.extend(entry.strip() for entry in from_env.split(":") if entry.strip())

    if not collected:
        return None
    return ":".join(collected)
784
-
785
-
786
def _run_teeing_stderr(cmd: List[str]) -> None:
    """Run ``cmd``, echoing its stderr live while keeping a tail for inspection.

    Raises:
        subprocess.CalledProcessError: On a non-zero exit status; the kept
            stderr tail is attached as ``exc.stderr``.
    """
    tail = deque(maxlen=200)
    with subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True, errors="replace") as proc:
        for line in proc.stderr:
            sys.stderr.write(line)
            tail.append(line)
        returncode = proc.wait()
    if returncode != 0:
        raise subprocess.CalledProcessError(returncode, cmd, stderr="".join(tail))


def _run_lm_eval_with_retry(cmd: List[str], batch_size_value: str) -> None:
    """Run lm_eval once; on an out-of-memory failure with an ``auto`` batch size, retry at 1.

    With a fixed batch size the command runs exactly once with inherited
    stdio and any failure propagates. With an ``auto*`` batch size the child's
    stderr is echoed and inspected, because the ``CalledProcessError`` message
    holds only the command line and exit status, never the child's output.
    If stderr mentions an out-of-memory condition, the command is re-run with
    the value after ``--batch_size`` replaced by ``"1"``.

    Args:
        cmd: Full lm_eval argv.
        batch_size_value: The batch size that was placed in ``cmd``.

    Raises:
        subprocess.CalledProcessError: If the first run fails for a reason
            that is not retried, or if the retry itself fails.
    """
    is_auto = str(batch_size_value or "").strip().lower().startswith("auto")
    if not is_auto:
        subprocess.run(cmd, check=True)
        return
    try:
        _run_teeing_stderr(cmd)
        return
    except subprocess.CalledProcessError as exc:
        stderr_text = str(exc.stderr or "").lower()
        # Substring match is kept loose on purpose: a false positive only
        # costs one extra attempt at batch size 1.
        looks_like_oom = "out of memory" in stderr_text or "oom" in stderr_text
        # Without a "--batch_size <value>" pair there is nothing to lower.
        if not looks_like_oom or "--batch_size" not in cmd[:-1]:
            raise
    retry_cmd = list(cmd)
    idx = retry_cmd.index("--batch_size")
    retry_cmd[idx + 1] = "1"
    subprocess.run(retry_cmd, check=True)
799
-
800
-
801
- def _prepare_adapter_dir(*, model=None, tokenizer=None) -> tuple[Path, tempfile.TemporaryDirectory]:
802
- if model is None or tokenizer is None:
803
- raise RuntimeError("model and tokenizer are required for lm_eval PPL.")
804
- tmp_ctx = tempfile.TemporaryDirectory(prefix="lbw_demo_lmeval_")
805
- out_dir = Path(tmp_ctx.name) / "peft_adapter"
806
- model.save_pretrained(str(out_dir))
807
- tokenizer.save_pretrained(str(out_dir))
808
- return out_dir, tmp_ctx
809
-
810
-
811
- def _run_lm_eval_tasks_with_adapter(
812
- adapter_path: Path,
813
- *,
814
- config: BenchmarkConfig,
815
- device: torch.device,
816
- tasks: Sequence[str],
817
- limit: Optional[float],
818
- output_name: str,
819
- num_fewshot: int = 0,
820
- ) -> Dict[str, Any]:
821
- lm_eval_cmd = _resolve_lm_eval_command()
822
- include_path = _resolve_lm_eval_include_path()
823
- normalized_tasks = [_normalize_lm_eval_task_name(task) for task in tasks if str(task).strip()]
824
- if not normalized_tasks:
825
- raise RuntimeError("No lm_eval tasks were provided.")
826
- out_dir = adapter_path.parent / output_name
827
- out_dir.mkdir(parents=True, exist_ok=True)
828
- lm_eval_dtype = "bfloat16" if device.type == "cuda" else "float32"
829
- model_args = [
830
- f"pretrained={config.model_name}",
831
- f"peft={adapter_path}",
832
- f"dtype={lm_eval_dtype}",
833
- "trust_remote_code=True",
834
- ]
835
- batch_size_value = str(getattr(config, "lm_eval_batch_size", "1"))
836
- cmd = [
837
- *lm_eval_cmd,
838
- "--model",
839
- "hf",
840
- "--model_args",
841
- ",".join(model_args),
842
- "--tasks",
843
- ",".join(normalized_tasks),
844
- "--num_fewshot",
845
- str(int(num_fewshot)),
846
- "--batch_size",
847
- batch_size_value,
848
- "--device",
849
- ("cuda" if device.type == "cuda" else "cpu"),
850
- "--output_path",
851
- str(out_dir),
852
- ]
853
- if include_path is not None:
854
- cmd.extend(["--include_path", include_path])
855
- if limit is not None:
856
- cmd.extend(["--limit", str(float(limit))])
857
- _run_lm_eval_with_retry(cmd, batch_size_value)
858
- return _load_lm_eval_results(out_dir)
859
-
860
-
861
- def _summarize_lm_eval_status(
862
- requested: Dict[str, bool],
863
- statuses: Dict[str, str],
864
- errors: Dict[str, Optional[str]],
865
- ) -> tuple[str, Optional[str]]:
866
- active = [name for name, enabled in requested.items() if enabled]
867
- if not active:
868
- return "disabled", None
869
- active_statuses = [str(statuses.get(name, "disabled")) for name in active]
870
- joined_errors = "; ".join(
871
- f"{name}: {errors[name]}"
872
- for name in active
873
- if str(errors.get(name) or "").strip()
874
- ) or None
875
- if all(status == "ok" for status in active_statuses):
876
- return "ok", None
877
- if any(status == "ok" for status in active_statuses):
878
- return "partial", joined_errors
879
- return "skipped", joined_errors
880
-
881
-
882
- def run_lm_eval_suite(model, tokenizer, config: BenchmarkConfig, device: torch.device) -> Dict[str, Any]:
883
- requested = {
884
- "ppl": bool(getattr(config, "lm_eval_ppl", False)),
885
- "acc": bool(getattr(config, "lm_eval_acc", False)),
886
- }
887
- statuses: Dict[str, str] = {
888
- "ppl": "disabled",
889
- "acc": "disabled",
890
- }
891
- errors: Dict[str, Optional[str]] = {
892
- "ppl": None,
893
- "acc": None,
894
- }
895
- out: Dict[str, Any] = {}
896
- adapter_path, tmp_ctx = _prepare_adapter_dir(model=model, tokenizer=tokenizer)
897
- try:
898
- if requested["ppl"]:
899
- statuses["ppl"] = "requested"
900
- ppl_task = _normalize_lm_eval_task_name(getattr(config, "lm_eval_ppl_task", "wikitext_103_raw"))
901
- try:
902
- raw_results = _run_lm_eval_tasks_with_adapter(
903
- adapter_path,
904
- config=config,
905
- device=device,
906
- tasks=[ppl_task],
907
- limit=getattr(config, "lm_eval_ppl_limit", None),
908
- output_name="lm_eval_ppl_out",
909
- num_fewshot=0,
910
- )
911
- ppl_metrics = _pick_lm_eval_perplexity_metrics(_find_lm_eval_task_result(raw_results, ppl_task))
912
- if not ppl_metrics:
913
- raise RuntimeError(f"No perplexity-like metrics found for '{ppl_task}'.")
914
- if "word_perplexity" in ppl_metrics:
915
- out["lm_eval/final_word_perplexity"] = float(ppl_metrics["word_perplexity"])
916
- out["final_eval/perplexity_lm_eval"] = float(ppl_metrics["word_perplexity"])
917
- if "perplexity" in ppl_metrics:
918
- out["lm_eval/final_perplexity"] = float(ppl_metrics["perplexity"])
919
- out.setdefault("final_eval/perplexity_lm_eval", float(ppl_metrics["perplexity"]))
920
- if "bits_per_byte" in ppl_metrics:
921
- out["lm_eval/final_bits_per_byte"] = float(ppl_metrics["bits_per_byte"])
922
- statuses["ppl"] = "ok"
923
- except Exception as exc:
924
- statuses["ppl"] = "skipped"
925
- errors["ppl"] = str(exc).strip() or type(exc).__name__
926
-
927
- if requested["acc"]:
928
- statuses["acc"] = "requested"
929
- acc_tasks = _parse_lm_eval_task_list(getattr(config, "lm_eval_acc_tasks", "mmlu,arc_challenge"))
930
- if not acc_tasks:
931
- statuses["acc"] = "skipped"
932
- errors["acc"] = "No lm_eval accuracy tasks configured."
933
- else:
934
- try:
935
- acc_out: Dict[str, Any] = {}
936
- for task_name in acc_tasks:
937
- raw_results = _run_lm_eval_tasks_with_adapter(
938
- adapter_path,
939
- config=config,
940
- device=device,
941
- tasks=[task_name],
942
- limit=_lm_eval_limit_for_task(config, task_name),
943
- output_name=f"lm_eval_acc_{task_name}_out",
944
- num_fewshot=_lm_eval_num_fewshot_for_task(config, task_name),
945
- )
946
- metrics = _pick_lm_eval_accuracy_metrics(_find_lm_eval_task_result(raw_results, task_name))
947
- if task_name == "mmlu":
948
- value = metrics.get("acc")
949
- if value is None:
950
- raise RuntimeError("No `acc` metric found for `mmlu`.")
951
- acc_out["lm_eval/final_mmlu_acc"] = float(value)
952
- acc_out["final_eval/mmlu_acc_lm_eval"] = float(value)
953
- elif task_name == "arc_challenge":
954
- value = metrics.get("acc_norm")
955
- if value is None:
956
- value = metrics.get("acc")
957
- if value is None:
958
- value = metrics.get("exact_match")
959
- if value is None:
960
- raise RuntimeError("No accuracy-like metric found for `arc_challenge`.")
961
- acc_out["lm_eval/final_arc_challenge_acc"] = float(value)
962
- acc_out["final_eval/arc_challenge_acc_lm_eval"] = float(value)
963
- out.update(acc_out)
964
- statuses["acc"] = "ok"
965
- except Exception as exc:
966
- statuses["acc"] = "skipped"
967
- errors["acc"] = str(exc).strip() or type(exc).__name__
968
-
969
- overall_status, overall_error = _summarize_lm_eval_status(requested, statuses, errors)
970
- out.update(
971
- {
972
- "lm_eval_status": overall_status,
973
- "lm_eval_error": overall_error,
974
- "lm_eval_ppl_status": statuses["ppl"],
975
- "lm_eval_ppl_error": errors["ppl"],
976
- "lm_eval_acc_status": statuses["acc"],
977
- "lm_eval_acc_error": errors["acc"],
978
- }
979
- )
980
- return out
981
- finally:
982
- tmp_ctx.cleanup()
983
-
984
-
985
- def _current_learning_rate(opt: Any, scheduler: Optional[Any], config: BenchmarkConfig) -> float:
986
- if scheduler is not None:
987
- try:
988
- return float(scheduler.get_last_lr()[0])
989
- except Exception:
990
- pass
991
- try:
992
- return float(opt.param_groups[0]["lr"])
993
- except Exception:
994
- return float(config.lr)
995
-
996
-
997
- def _optimizer_group_learning_rate(opt: Any, config: BenchmarkConfig) -> float:
998
- try:
999
- return float(opt.param_groups[0]["lr"])
1000
- except Exception:
1001
- return float(config.lr)
1002
-
1003
-
1004
- def _build_scheduler(opt: Any, optimizer_name: str, config: BenchmarkConfig):
1005
- schedule_mode = str(getattr(config, "schedule_mode", "all_cosine") or "all_cosine").strip().lower()
1006
- if schedule_mode not in {"native", "all_cosine", "all_constant"}:
1007
- schedule_mode = "all_cosine"
1008
- if isinstance(opt, torch.optim.Optimizer):
1009
- target = opt
1010
- else:
1011
- target = _build_scheduler_proxy_for_optimizer_like(opt)
1012
- if target is None:
1013
- return None
1014
- if schedule_mode == "all_constant":
1015
- return torch.optim.lr_scheduler.LambdaLR(target, lr_lambda=lambda step: 1.0)
1016
- if schedule_mode == "all_cosine":
1017
- return get_cosine_schedule_with_warmup(target, num_warmup_steps=config.warmup_steps, num_training_steps=config.max_steps)
1018
- if optimizer_name == "lbw_guard":
1019
- return torch.optim.lr_scheduler.LambdaLR(target, lr_lambda=lambda step: 1.0)
1020
- return get_cosine_schedule_with_warmup(target, num_warmup_steps=config.warmup_steps, num_training_steps=config.max_steps)
1021
-
1022
-
1023
- def _init_wandb_if_enabled(config: BenchmarkConfig, *, group_name: Optional[str], run_name: Optional[str]):
1024
- if not bool(getattr(config, "use_wandb", False)):
1025
- return None
1026
- if wandb is None:
1027
- print("[W&B] Disabled (wandb not installed)")
1028
- return None
1029
- try:
1030
- wandb.init(
1031
- project="LBW-Customer-Demo",
1032
- group=group_name,
1033
- name=run_name,
1034
- config=config.__dict__,
1035
- reinit=True,
1036
- settings=wandb.Settings(start_method="thread"),
1037
- )
1038
- return wandb
1039
- except Exception as exc:
1040
- print(f"[W&B] Disabled (wandb.init failed: {exc})")
1041
- return None
1042
-
1043
-
1044
- def train_one_run(
1045
- optimizer_name: str,
1046
- config: BenchmarkConfig,
1047
- *,
1048
- group_name: Optional[str] = None,
1049
- run_name: Optional[str] = None,
1050
- shared_pre_bench_results=None,
1051
- shared_bench_dataset_bundle=None,
1052
- ) -> Dict[str, Any]:
1053
- del shared_pre_bench_results, shared_bench_dataset_bundle
1054
-
1055
- normalized = normalize_optimizer_name(optimizer_name)
1056
- ok, reason = check_optimizer_support(normalized, device=config.device)
1057
- if not ok:
1058
- raise RuntimeError(f"{normalized}: {reason}")
1059
-
1060
- device = torch.device(config.device)
1061
- if device.type != "cuda":
1062
- warnings.filterwarnings(
1063
- "ignore",
1064
- message="CUDA initialization: The NVIDIA driver on your system is too old.*",
1065
- category=UserWarning,
1066
- )
1067
- wb = _init_wandb_if_enabled(config, group_name=group_name, run_name=run_name or normalized)
1068
-
1069
- _demo_log(config, f"Loading model and tokenizer: {config.model_name} on {device}")
1070
- tokenizer, model = _load_tokenizer_and_model(config.model_name, device)
1071
- _demo_log(config, "Model load complete")
1072
- train_cap = None if config.full_wikitext_train else config.max_chars
1073
- eval_cap = None if config.full_wikitext_eval else config.eval_chars
1074
- train_chunks = build_wikitext_chunks(tokenizer, config.seq_len, train_cap, "train", config=config)
1075
- eval_chunks = build_wikitext_chunks(tokenizer, config.seq_len, eval_cap, "validation", config=config)
1076
- train_iter = batch_iter(train_chunks, config.batch_size, device)
1077
- train_sequence_count = int(train_chunks.input_ids.size(0))
1078
- train_token_count = int(train_chunks.input_ids.numel())
1079
- eval_sequence_count = int(eval_chunks.input_ids.size(0))
1080
- eval_token_count = int(eval_chunks.input_ids.numel())
1081
- sequences_per_optimizer_step = max(int(config.batch_size * config.grad_accum), 1)
1082
- tokens_per_optimizer_step = max(int(config.batch_size * config.seq_len * config.grad_accum), 1)
1083
- steps_per_train_pass = int(math.ceil(train_sequence_count / float(sequences_per_optimizer_step)))
1084
-
1085
- if bool(getattr(config, "enable_lora", True)):
1086
- _demo_log(config, "Attaching LoRA adapters")
1087
- lora_cfg = LoraConfig(
1088
- r=config.lora_r,
1089
- lora_alpha=config.lora_alpha,
1090
- lora_dropout=config.lora_dropout,
1091
- target_modules=config.lora_target_modules,
1092
- task_type=TaskType.CAUSAL_LM,
1093
- bias="none",
1094
- )
1095
- model = get_peft_model(model, lora_cfg)
1096
- else:
1097
- _demo_log(config, "Training without LoRA adapters")
1098
- model.train()
1099
- trainable_params = [p for p in model.parameters() if p.requires_grad]
1100
- if not trainable_params:
1101
- raise RuntimeError("No trainable parameters found for training.")
1102
-
1103
- _demo_log(config, f"Creating optimizer: {normalized}")
1104
- opt = get_optimizer(normalized, model, config)
1105
- scheduler = _build_scheduler(opt, normalized, config)
1106
- governance_tracker = GovernanceMetricsTracker()
1107
- _demo_log(config, f"Starting training for {config.max_steps} optimizer steps")
1108
-
1109
- train_start = time.time()
1110
- step_wall_start = train_start
1111
- step_compute_start = train_start
1112
- train_losses: List[float] = []
1113
- pure_tps_history: List[float] = []
1114
- wall_tps_history: List[float] = []
1115
- pure_step_time_history: List[float] = []
1116
- wall_step_time_history: List[float] = []
1117
- runtime_snapshot: Dict[str, Any] = {
1118
- "stress_mode": "none",
1119
- "scale": 1.0,
1120
- "ratio": 1.0,
1121
- "grad_rms": 0.0,
1122
- "scheduled_lr_used": float(config.lr),
1123
- "scheduled_lr_next": float(config.lr),
1124
- "effective_lr_main_used": float(config.lr),
1125
- "effective_lr_weight_decay_used": float(config.lr),
1126
- "train_sequences": train_sequence_count,
1127
- "train_tokens": train_token_count,
1128
- "train_chars": int(train_chunks.char_count),
1129
- "train_cap_chars": train_chunks.cap_chars,
1130
- "eval_sequences": eval_sequence_count,
1131
- "eval_tokens": eval_token_count,
1132
- "eval_chars": int(eval_chunks.char_count),
1133
- "eval_cap_chars": eval_chunks.cap_chars,
1134
- "sequences_per_optimizer_step": sequences_per_optimizer_step,
1135
- "tokens_per_optimizer_step": tokens_per_optimizer_step,
1136
- "steps_per_train_pass": steps_per_train_pass,
1137
- "epochs_completed": 0.0,
1138
- }
1139
-
1140
- global_step = 0
1141
- accumulation_step = 0
1142
-
1143
- while global_step < config.max_steps:
1144
- xb, yb = next(train_iter)
1145
- with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")):
1146
- outputs = model(input_ids=xb, labels=yb)
1147
- loss = outputs.loss / config.grad_accum
1148
- loss.backward()
1149
- accumulation_step += 1
1150
- if accumulation_step % config.grad_accum != 0:
1151
- continue
1152
-
1153
- step_number = global_step + 1
1154
- grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
1155
- grad_norm_value = grad_norm.detach().item() if torch.is_tensor(grad_norm) else float(grad_norm)
1156
- grad_rms = 0.0 if normalized == "lbw_guard" else _compute_grad_rms(trainable_params)
1157
- loss_val = float(loss.item() * config.grad_accum)
1158
- scheduled_lr_used = _optimizer_group_learning_rate(opt, config)
1159
-
1160
- if normalized == "lbw_guard":
1161
- opt.step()
1162
- else:
1163
- opt.step()
1164
- if scheduler is not None:
1165
- scheduler.step()
1166
- opt.zero_grad()
1167
-
1168
- compute_end = time.time()
1169
- pure_step_time = max(compute_end - step_compute_start, 1e-12)
1170
- tokens_per_step = int(config.batch_size * config.seq_len * config.grad_accum)
1171
- pure_tps = tokens_per_step / pure_step_time
1172
-
1173
- scheduled_lr_next = _current_learning_rate(opt, scheduler, config)
1174
- scale = 1.0
1175
- ratio = 1.0
1176
- ema_grad_rms = grad_rms
1177
- stress_mode = "none"
1178
- edition = normalized
1179
- effective_lr_main_used = scheduled_lr_used
1180
- effective_lr_weight_decay_used = scheduled_lr_used
1181
- if normalized == "lbw_guard":
1182
- lbw_state = dict(getattr(opt, "state", {}).get("lbw", {}) or {})
1183
- scale = float(lbw_state.get("scale", lbw_state.get("lbw_scale", 1.0)))
1184
- ratio = float(lbw_state.get("ratio", 1.0))
1185
- grad_rms = float(lbw_state.get("grad_rms", grad_rms))
1186
- ema_grad_rms = grad_rms / ratio if ratio > 0 else grad_rms
1187
- stress_mode = str(lbw_state.get("stress_mode", "unknown"))
1188
- edition = str(lbw_state.get("edition", lbw_state.get("mode", normalized)))
1189
- effective_lr_main_used = scheduled_lr_used * scale
1190
- effective_lr_weight_decay_used = scheduled_lr_used * scale
1191
-
1192
- derived_gov_metrics = governance_tracker.update_step(
1193
- step=global_step,
1194
- trainable_params=trainable_params,
1195
- loss_val=loss_val,
1196
- grad_norm=grad_norm_value,
1197
- grad_rms=grad_rms,
1198
- ema_grad_rms=ema_grad_rms,
1199
- ratio=ratio,
1200
- scale=scale,
1201
- stress_mode=stress_mode,
1202
- current_lr=scheduled_lr_used,
1203
- )
1204
-
1205
- train_losses.append(loss_val)
1206
- pure_tps_history.append(pure_tps)
1207
- pure_step_time_history.append(pure_step_time)
1208
- epochs_completed = (step_number * sequences_per_optimizer_step) / float(train_sequence_count)
1209
-
1210
- eval_log: Dict[str, float] = {}
1211
- progress_every = max(
1212
- 1,
1213
- min(
1214
- int(config.eval_every),
1215
- 5 if int(config.max_steps) <= 50 else 10,
1216
- ),
1217
- )
1218
- if step_number % config.eval_every == 0:
1219
- avg_eval_loss, perp = evaluate_perplexity(model, eval_chunks, config, device)
1220
- eval_log = {
1221
- "eval/loss": avg_eval_loss,
1222
- "eval/perplexity": perp,
1223
- }
1224
- eval_log.update(
1225
- governance_tracker.update_eval(
1226
- eval_loss=avg_eval_loss,
1227
- eval_perplexity=perp,
1228
- avg_tps_wall=(wall_tps_history[-1] if wall_tps_history else None),
1229
- )
1230
- )
1231
- _demo_log(
1232
- config,
1233
- f"step {step_number}/{config.max_steps}: loss={loss_val:.4f}, "
1234
- f"sampled_eval_loss={avg_eval_loss:.4f}, sampled_eval_ppl={perp:.4f}, "
1235
- f"scale={scale:.4f}, ratio={ratio:.4f}",
1236
- )
1237
- model.train()
1238
- elif bool(getattr(config, "print_all_metrics", False)) and (
1239
- step_number == 1
1240
- or step_number == config.max_steps
1241
- or step_number % progress_every == 0
1242
- ):
1243
- _demo_log(
1244
- config,
1245
- f"step {step_number}/{config.max_steps}: loss={loss_val:.4f}, scale={scale:.4f}, ratio={ratio:.4f}",
1246
- )
1247
-
1248
- wall_end = time.time()
1249
- wall_step_time = max(wall_end - step_wall_start, 1e-12)
1250
- wall_tps = tokens_per_step / wall_step_time
1251
- wall_tps_history.append(wall_tps)
1252
- wall_step_time_history.append(wall_step_time)
1253
-
1254
- train_log = {
1255
- "train/loss": loss_val,
1256
- "train/grad_norm": grad_norm_value,
1257
- "train/tokens_per_sec_pure": pure_tps,
1258
- "train/tokens_per_sec_wall": wall_tps,
1259
- "train/step_time_pure_sec": pure_step_time,
1260
- "train/step_time_wall_sec": wall_step_time,
1261
- "train/lr": scheduled_lr_used,
1262
- "train/lr_used": scheduled_lr_used,
1263
- "train/lr_next": scheduled_lr_next,
1264
- "train/effective_lr_main": effective_lr_main_used,
1265
- "train/effective_lr_weight_decay": effective_lr_weight_decay_used,
1266
- "train/steps_per_train_pass": float(steps_per_train_pass),
1267
- "train/epochs_completed": float(epochs_completed),
1268
- "lbw/scale": scale,
1269
- "lbw/ratio": ratio,
1270
- "lbw/grad_rms": grad_rms,
1271
- "lbw/ema_grad_rms": ema_grad_rms,
1272
- "lbw/stress_mode": stress_mode,
1273
- "lbw/edition": edition,
1274
- }
1275
- train_log.update(derived_gov_metrics)
1276
- if wb is not None:
1277
- wb.log({**train_log, **eval_log}, step=step_number)
1278
-
1279
- runtime_snapshot = {
1280
- "stress_mode": stress_mode,
1281
- "scale": scale,
1282
- "ratio": ratio,
1283
- "grad_rms": grad_rms,
1284
- "scheduled_lr_used": scheduled_lr_used,
1285
- "scheduled_lr_next": scheduled_lr_next,
1286
- "effective_lr_main_used": effective_lr_main_used,
1287
- "effective_lr_weight_decay_used": effective_lr_weight_decay_used,
1288
- "train_sequences": train_sequence_count,
1289
- "train_tokens": train_token_count,
1290
- "train_chars": int(train_chunks.char_count),
1291
- "train_cap_chars": train_chunks.cap_chars,
1292
- "eval_sequences": eval_sequence_count,
1293
- "eval_tokens": eval_token_count,
1294
- "eval_chars": int(eval_chunks.char_count),
1295
- "eval_cap_chars": eval_chunks.cap_chars,
1296
- "sequences_per_optimizer_step": sequences_per_optimizer_step,
1297
- "tokens_per_optimizer_step": tokens_per_optimizer_step,
1298
- "steps_per_train_pass": steps_per_train_pass,
1299
- "epochs_completed": float(epochs_completed),
1300
- }
1301
-
1302
- global_step += 1
1303
- step_wall_start = time.time()
1304
- step_compute_start = step_wall_start
1305
-
1306
- training_wall_time = max(time.time() - train_start, 1e-12)
1307
- final_eval_is_full = bool(config.full_validation_ppl)
1308
- if final_eval_is_full:
1309
- final_eval_scope = "full_wikitext" if eval_chunks.cap_chars is None else "full_loaded_subset"
1310
- final_eval_scope_text = (
1311
- "over the full WikiText validation split"
1312
- if eval_chunks.cap_chars is None
1313
- else f"over the full loaded validation subset ({int(eval_chunks.char_count):,} chars; --eval-chars cap)"
1314
- )
1315
- else:
1316
- final_eval_scope = "sampled"
1317
- final_eval_scope_text = f"over {int(config.eval_batches)} sampled batches"
1318
- _demo_log(
1319
- config,
1320
- "Running final validation PPL " + final_eval_scope_text,
1321
- )
1322
- final_eval_start = time.time()
1323
- final_eval_loss, final_eval_perp = evaluate_perplexity(
1324
- model,
1325
- eval_chunks,
1326
- config,
1327
- device,
1328
- full_pass=final_eval_is_full,
1329
- )
1330
- final_eval_time_sec = max(time.time() - final_eval_start, 0.0)
1331
- final_eval_perp_lm_eval = None
1332
- final_eval_mmlu_acc_lm_eval = None
1333
- final_eval_arc_challenge_acc_lm_eval = None
1334
- lm_eval_status = "disabled"
1335
- lm_eval_error = None
1336
- lm_eval_ppl_status = "disabled"
1337
- lm_eval_ppl_error = None
1338
- lm_eval_acc_status = "disabled"
1339
- lm_eval_acc_error = None
1340
- lm_eval_time_sec = 0.0
1341
- if bool(config.use_lm_eval) and (bool(config.lm_eval_ppl) or bool(getattr(config, "lm_eval_acc", False))):
1342
- try:
1343
- lm_eval_start = time.time()
1344
- final_lm_eval_metrics = run_lm_eval_suite(model, tokenizer, config, device)
1345
- lm_eval_time_sec = max(time.time() - lm_eval_start, 0.0)
1346
- final_eval_perp_lm_eval = _safe_float(final_lm_eval_metrics.get("final_eval/perplexity_lm_eval"))
1347
- final_eval_mmlu_acc_lm_eval = _safe_float(final_lm_eval_metrics.get("final_eval/mmlu_acc_lm_eval"))
1348
- final_eval_arc_challenge_acc_lm_eval = _safe_float(
1349
- final_lm_eval_metrics.get("final_eval/arc_challenge_acc_lm_eval")
1350
- )
1351
- lm_eval_status = str(final_lm_eval_metrics.get("lm_eval_status") or "ok")
1352
- lm_eval_error = str(final_lm_eval_metrics.get("lm_eval_error") or "").strip() or None
1353
- lm_eval_ppl_status = str(final_lm_eval_metrics.get("lm_eval_ppl_status") or "disabled")
1354
- lm_eval_ppl_error = str(final_lm_eval_metrics.get("lm_eval_ppl_error") or "").strip() or None
1355
- lm_eval_acc_status = str(final_lm_eval_metrics.get("lm_eval_acc_status") or "disabled")
1356
- lm_eval_acc_error = str(final_lm_eval_metrics.get("lm_eval_acc_error") or "").strip() or None
1357
- if lm_eval_status in {"skipped", "partial"} and lm_eval_error:
1358
- print(f"[DemoRuntime] lm_eval issues: {lm_eval_error}")
1359
- except Exception as exc:
1360
- lm_eval_time_sec = max(time.time() - lm_eval_start, 0.0)
1361
- lm_eval_status = "skipped"
1362
- lm_eval_error = str(exc).strip() or type(exc).__name__
1363
- print(f"[DemoRuntime] lm_eval skipped: {lm_eval_error}")
1364
-
1365
- wall_time = max(time.time() - train_start, 1e-12)
1366
- post_training_benchmark_time_sec = max(wall_time - training_wall_time, 0.0)
1367
- avg_tps_wall = float(sum(wall_tps_history) / len(wall_tps_history)) if wall_tps_history else 0.0
1368
- final_effect_metrics = governance_tracker.update_eval(
1369
- eval_loss=final_eval_loss,
1370
- eval_perplexity=final_eval_perp,
1371
- avg_tps_wall=avg_tps_wall,
1372
- )
1373
- governance_snapshot = governance_tracker.snapshot()
1374
- _demo_log(
1375
- config,
1376
- f"Finished: "
1377
- f"{'final_full_eval_loss' if final_eval_is_full else 'final_eval_loss'}={final_eval_loss:.4f}, "
1378
- f"{'final_full_eval_ppl' if final_eval_is_full else 'final_eval_ppl'}={final_eval_perp:.4f}, "
1379
- f"wall_time={wall_time:.1f}s",
1380
- )
1381
-
1382
- if wb is not None:
1383
- wb.log(
1384
- {
1385
- "final/eval_loss": final_eval_loss,
1386
- "final/eval_perplexity": final_eval_perp,
1387
- **final_effect_metrics,
1388
- },
1389
- step=config.max_steps,
1390
- )
1391
- wb.finish()
1392
-
1393
- return {
1394
- "optimizer": normalized,
1395
- "group_name": group_name,
1396
- "run_name": run_name,
1397
- "model_name": config.model_name,
1398
- "final_eval_loss": float(final_eval_loss),
1399
- "final_eval_perp": float(final_eval_perp),
1400
- "final_eval_perp_lm_eval": final_eval_perp_lm_eval,
1401
- "final_eval_mmlu_acc_lm_eval": final_eval_mmlu_acc_lm_eval,
1402
- "final_eval_arc_challenge_acc_lm_eval": final_eval_arc_challenge_acc_lm_eval,
1403
- "lm_eval_status": lm_eval_status,
1404
- "lm_eval_error": lm_eval_error,
1405
- "lm_eval_ppl_status": lm_eval_ppl_status,
1406
- "lm_eval_ppl_error": lm_eval_ppl_error,
1407
- "lm_eval_acc_status": lm_eval_acc_status,
1408
- "lm_eval_acc_error": lm_eval_acc_error,
1409
- "avg_tokens_per_sec_pure": float(sum(pure_tps_history) / len(pure_tps_history)) if pure_tps_history else 0.0,
1410
- "avg_tokens_per_sec_wall": avg_tps_wall,
1411
- "avg_step_time_pure_sec": float(sum(pure_step_time_history) / len(pure_step_time_history)) if pure_step_time_history else 0.0,
1412
- "avg_step_time_wall_sec": float(sum(wall_step_time_history) / len(wall_step_time_history)) if wall_step_time_history else 0.0,
1413
- "training_wall_time_sec": float(training_wall_time),
1414
- "final_eval_time_sec": float(final_eval_time_sec),
1415
- "lm_eval_time_sec": float(lm_eval_time_sec),
1416
- "post_training_benchmark_time_sec": float(post_training_benchmark_time_sec),
1417
- "wall_time_sec": float(wall_time),
1418
- "train_sequence_count": int(train_sequence_count),
1419
- "train_token_count": int(train_token_count),
1420
- "train_char_count": int(train_chunks.char_count),
1421
- "train_cap_chars": train_chunks.cap_chars,
1422
- "eval_sequence_count": int(eval_sequence_count),
1423
- "eval_token_count": int(eval_token_count),
1424
- "eval_char_count": int(eval_chunks.char_count),
1425
- "eval_cap_chars": eval_chunks.cap_chars,
1426
- "full_wikitext_train": bool(config.full_wikitext_train),
1427
- "full_wikitext_eval": bool(config.full_wikitext_eval),
1428
- "full_validation_ppl": bool(config.full_validation_ppl),
1429
- "final_eval_full_pass": bool(final_eval_is_full),
1430
- "final_eval_scope": final_eval_scope,
1431
- "sequences_per_optimizer_step": int(sequences_per_optimizer_step),
1432
- "tokens_per_optimizer_step": int(tokens_per_optimizer_step),
1433
- "steps_per_train_pass": int(steps_per_train_pass),
1434
- "epochs_completed": float((global_step * sequences_per_optimizer_step) / float(train_sequence_count)),
1435
- "runtime_snapshot": runtime_snapshot,
1436
- "governance_snapshot": governance_snapshot,
1437
- "final_effect_metrics": final_effect_metrics,
1438
- "train_loss_last": (float(train_losses[-1]) if train_losses else None),
1439
- "schedule_mode": config.schedule_mode,
1440
- "max_steps": int(config.max_steps),
1441
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
2
 
3
  import csv
4
  import gc
5
- import io
6
  import json
 
7
  import os
 
8
  import time
9
  import traceback
10
- from contextlib import redirect_stdout
11
  from pathlib import Path
12
  from typing import Any
13
 
@@ -23,8 +23,17 @@ os.environ.setdefault("DISABLE_SAFETENSORS_CONVERSION", "1")
23
 
24
  import gradio as gr
25
  import torch
 
 
 
26
 
27
- import _demo_runtime as runtime
 
 
 
 
 
 
28
 
29
 
30
  RUNS_DIR = ROOT / "runs"
@@ -34,81 +43,369 @@ def _device_default() -> str:
34
  return "cuda" if torch.cuda.is_available() else "cpu"
35
 
36
 
 
 
 
 
 
 
 
37
  def _safe_float(value: Any) -> float | None:
38
  if value is None:
39
  return None
40
  try:
41
- return float(value)
42
  except Exception:
43
  return None
 
 
 
 
44
 
 
 
 
45
 
46
- def _build_config(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  *,
48
  model_name: str,
49
- steps: int,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  lr: float,
51
- seq_len: int,
52
- train_chars: int,
53
- eval_chars: int,
54
- eval_batches: int,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  batch_size: int,
56
- grad_accum: int,
57
- seed: int,
58
- device: str,
59
- ) -> runtime.BenchmarkConfig:
60
- config = runtime.BenchmarkConfig()
61
- config.model_name = str(model_name).strip() or "Qwen/Qwen2.5-0.5B"
62
- config.device = str(device or _device_default())
63
- config.max_steps = int(steps)
64
- config.lr = float(lr)
65
- config.seq_len = int(seq_len)
66
- config.batch_size = int(batch_size)
67
- config.grad_accum = int(grad_accum)
68
- config.warmup_steps = min(5, max(0, int(steps) // 5))
69
- config.eval_every = max(1, min(int(steps), 10))
70
- config.eval_batches = int(eval_batches)
71
- config.max_chars = int(train_chars)
72
- config.eval_chars = int(eval_chars)
73
- config.full_wikitext_train = False
74
- config.full_wikitext_eval = False
75
- config.full_validation_ppl = False
76
- config.schedule_mode = "all_cosine"
77
- config.lora_r = 8
78
- config.lora_alpha = 32
79
- config.lora_dropout = 0.05
80
- config.lbw_stats_freq = 5
81
- config.lbw_stress_th = 1.1
82
- config.lbw_spike_th = 1.5
83
- config.lbw_rec_fast = 0.01
84
- config.lbw_ema_decay = 0.95
85
- config.use_wandb = False
86
- config.use_lbwgov = False
87
- config.print_all_metrics = True
88
- config.output_dir = str((RUNS_DIR / f"run_{int(time.time())}").resolve())
89
- config.use_lm_eval = False
90
- config.lm_eval_ppl = False
91
- config.lm_eval_acc = False
92
- runtime.set_seed(int(seed), device=config.device)
93
- return config
94
-
95
-
96
- def _result_row(result: dict[str, Any]) -> dict[str, Any]:
97
- runtime_snapshot = dict(result.get("runtime_snapshot") or {})
98
- governance_snapshot = dict(result.get("governance_snapshot") or {})
99
  return {
100
- "optimizer": result.get("optimizer"),
101
- "final_eval_perplexity": _safe_float(result.get("final_eval_perp")),
102
- "final_eval_loss": _safe_float(result.get("final_eval_loss")),
103
- "tokens_per_sec_wall": _safe_float(result.get("avg_tokens_per_sec_wall")),
104
- "training_wall_time_sec": _safe_float(result.get("training_wall_time_sec")),
105
- "wall_time_sec": _safe_float(result.get("wall_time_sec")),
106
- "scale": _safe_float(runtime_snapshot.get("scale")),
107
- "ratio": _safe_float(runtime_snapshot.get("ratio")),
108
- "stress_mode": runtime_snapshot.get("stress_mode"),
109
- "intervention_count": governance_snapshot.get("intervention_count"),
110
- "regime_switch_count": governance_snapshot.get("regime_switch_count"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  }
 
 
 
 
 
112
 
113
 
114
  def _gain_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -116,24 +413,30 @@ def _gain_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
116
  baseline = by_optimizer.get("adamw")
117
  if baseline is None:
118
  return []
119
- gains = []
 
 
120
  for row in rows:
121
  if row.get("optimizer") == "adamw":
122
  continue
123
- baseline_ppl = _safe_float(baseline.get("final_eval_perplexity"))
124
- candidate_ppl = _safe_float(row.get("final_eval_perplexity"))
125
- baseline_tps = _safe_float(baseline.get("tokens_per_sec_wall"))
126
- candidate_tps = _safe_float(row.get("tokens_per_sec_wall"))
127
- ppl_gain = None if baseline_ppl is None or candidate_ppl is None else baseline_ppl - candidate_ppl
128
- speedup = None if baseline_tps in (None, 0.0) or candidate_tps is None else candidate_tps / baseline_tps
129
  gains.append(
130
  {
131
  "optimizer": row.get("optimizer"),
132
- "eval_perplexity_gain_vs_adamw": ppl_gain,
 
 
133
  "eval_perplexity_pct_gain_vs_adamw": (
134
- None if baseline_ppl in (None, 0.0) or candidate_ppl is None else (baseline_ppl - candidate_ppl) / baseline_ppl
 
 
 
 
 
 
 
135
  ),
136
- "wall_tokens_per_sec_speedup_vs_adamw": speedup,
137
  }
138
  )
139
  return gains
@@ -149,220 +452,897 @@ def _write_csv(path: Path, rows: list[dict[str, Any]]) -> None:
149
  writer.writerows(rows)
150
 
151
 
152
- def run_demo(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  model_name: str,
154
- steps: int,
155
- lr: float,
 
 
156
  seq_len: int,
 
157
  train_chars: int,
158
  eval_chars: int,
159
- eval_batches: int,
160
- batch_size: int,
161
- grad_accum: int,
 
162
  seed: int,
163
- run_lbw_guard: bool,
164
- ) -> Any:
165
- if not run_lbw_guard:
166
- optimizers = ["adamw"]
167
- else:
168
- optimizers = ["adamw", "lbw_guard"]
169
- device = _device_default()
170
- if device == "cpu" and int(steps) > 3:
171
- yield (
172
- "This Space is currently running on `cpu-basic`. "
173
- "For CPU smoke checks, use `1-3` steps. For larger runs, switch the Space hardware to GPU first.",
174
- None,
175
- None,
176
- )
177
- return
178
- if device == "cpu" and run_lbw_guard and int(steps) > 1:
179
- yield (
180
- "This Space is currently running on `cpu-basic`. "
181
- "An AdamW + LBW comparison runs two full model passes, so CPU mode is capped at `1` step when comparison is enabled.",
182
- None,
183
- None,
184
- )
185
- return
186
-
187
- config = _build_config(
188
- model_name=model_name,
189
- steps=steps,
190
- lr=lr,
191
- seq_len=seq_len,
192
- train_chars=train_chars,
193
- eval_chars=eval_chars,
194
- eval_batches=eval_batches,
195
- batch_size=batch_size,
196
- grad_accum=grad_accum,
197
- seed=seed,
198
- device=device,
199
- )
200
- run_dir = Path(config.output_dir)
201
  run_dir.mkdir(parents=True, exist_ok=True)
 
 
 
202
 
203
- log_buffer = io.StringIO()
204
  try:
205
- results = []
206
- yield (
207
- f"Starting run on `{device}` with `{int(steps)}` optimizer step(s) for `{', '.join(optimizers)}`.\n\n"
208
- "The first run may spend time downloading the model and WikiText dataset.",
209
- None,
210
- None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  )
212
- with redirect_stdout(log_buffer):
213
- for optimizer_name in optimizers:
214
- normalized = runtime.normalize_optimizer_name(optimizer_name)
215
- ok, reason = runtime.check_optimizer_support(normalized, device=config.device)
216
- if not ok:
217
- raise RuntimeError(f"{normalized}: {reason}")
218
- yield (
219
- f"Running `{normalized}` on `{device}`...\n\n"
220
- "Progress inside the optimizer loop is written to the Space logs and will appear here when this phase completes.",
221
- None,
222
- None,
223
- )
224
- runtime.set_seed(int(seed), device=config.device)
225
- run_config = runtime.BenchmarkConfig(**config.__dict__)
226
- run_name = f"{normalized}_{int(time.time())}"
227
- result = runtime.train_one_run(
228
- normalized,
229
- run_config,
230
- group_name="LBW-Guard-HF-Direct-Runner",
231
- run_name=run_name,
232
- )
233
- result["optimizer"] = normalized
234
- results.append(result)
235
- partial_rows = [_result_row(item) for item in results]
236
- next_message = "Preparing the next phase..." if len(results) < len(optimizers) else "Preparing final metrics..."
237
- yield (
238
- f"Completed `{normalized}`.\n\n"
239
- f"Finished phases: `{', '.join(str(row.get('optimizer')) for row in partial_rows)}`\n\n"
240
- f"{next_message}",
241
- None,
242
- None,
243
- )
244
- gc.collect()
245
- if torch.cuda.is_available():
246
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- rows = [_result_row(result) for result in results]
249
  gains = _gain_rows(rows)
250
  payload = {
 
251
  "config": {
252
- "model_name": model_name,
253
- "device": device,
254
- "steps": int(steps),
255
- "lr": float(lr),
256
- "seq_len": int(seq_len),
257
- "train_chars": int(train_chars),
258
- "eval_chars": int(eval_chars),
259
  "eval_batches": int(eval_batches),
 
260
  "batch_size": int(batch_size),
261
- "grad_accum": int(grad_accum),
262
- "seed": int(seed),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  },
264
- "results": results,
265
- "rows": rows,
266
  "gains": gains,
 
267
  }
268
- json_path = run_dir / "lbw_guard_direct_runner_results.json"
269
- csv_path = run_dir / "lbw_guard_direct_runner_metrics.csv"
270
- gains_path = run_dir / "lbw_guard_direct_runner_gains.csv"
271
  json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
272
  _write_csv(csv_path, rows)
273
  _write_csv(gains_path, gains)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- summary = [
276
- f"Device: `{device}`",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  "",
278
  "## Metrics",
279
  "",
280
- "| Optimizer | Final Eval PPL | Final Eval Loss | Wall Tokens/s | Wall Time (s) | Scale | Ratio | Stress Mode |",
281
  "| --- | --- | --- | --- | --- | --- | --- | --- |",
282
  ]
 
 
283
  for row in rows:
284
  summary.append(
285
- "| {optimizer} | {ppl:.4f} | {loss:.4f} | {tps:.2f} | {wall:.2f} | {scale:.4f} | {ratio:.4f} | {stress} |".format(
 
286
  optimizer=row.get("optimizer"),
287
- ppl=float(row.get("final_eval_perplexity") or 0.0),
288
- loss=float(row.get("final_eval_loss") or 0.0),
289
- tps=float(row.get("tokens_per_sec_wall") or 0.0),
290
- wall=float(row.get("wall_time_sec") or 0.0),
291
- scale=float(row.get("scale") or 0.0),
292
- ratio=float(row.get("ratio") or 0.0),
293
  stress=row.get("stress_mode") or "-",
294
  )
295
  )
296
- if gains:
297
- summary.extend(["", "## Gains vs AdamW", ""])
298
- for gain in gains:
299
- pct = _safe_float(gain.get("eval_perplexity_pct_gain_vs_adamw"))
300
- speedup = _safe_float(gain.get("wall_tokens_per_sec_speedup_vs_adamw"))
301
- summary.append(
302
- f"- `{gain.get('optimizer')}` PPL gain: `{_safe_float(gain.get('eval_perplexity_gain_vs_adamw'))}`, "
303
- f"PPL pct gain: `{pct * 100.0:.2f}%`" if pct is not None else f"- `{gain.get('optimizer')}` PPL pct gain unavailable."
304
- )
305
- if speedup is not None:
306
- summary.append(f"- `{gain.get('optimizer')}` wall tokens/s speedup: `{speedup:.3f}x`.")
307
- summary.extend(["", "## Runtime Log", "", "```text", log_buffer.getvalue()[-8000:], "```"])
308
- yield "\n".join(summary), str(json_path), str(csv_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  except Exception:
310
  error_text = traceback.format_exc()
311
  error_path = run_dir / "error.txt"
312
- error_path.write_text(error_text + "\n\n" + log_buffer.getvalue(), encoding="utf-8")
313
- yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None
314
 
315
 
316
  INTRO = """
317
- # LBW Guard Direct Runner
318
 
319
- Run a compact AdamW vs `lbw_guard` LoRA smoke test directly inside this Hugging Face Space.
320
 
321
- Use GPU hardware for real runs. CPU mode is best treated as an import/build check.
 
322
 
323
- If the Space says `cpu-basic`, keep smoke tests to `1` step or change hardware to a GPU before running larger jobs.
324
  """
325
 
326
 
327
- with gr.Blocks(title="LBW Guard Direct Runner") as demo:
328
  gr.Markdown(INTRO)
329
- with gr.Row():
330
- model_name = gr.Textbox(value="Qwen/Qwen2.5-0.5B", label="Model")
331
- run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison")
332
- with gr.Row():
333
- steps = gr.Slider(1, 20, value=1, step=1, label="Optimizer steps")
334
- lr = gr.Number(value=5e-4, label="Learning rate")
335
- seed = gr.Number(value=42, precision=0, label="Seed")
336
- with gr.Row():
337
- seq_len = gr.Dropdown([64, 128, 256], value=64, label="Sequence length")
338
- batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")
339
- grad_accum = gr.Slider(1, 8, value=2, step=1, label="Gradient accumulation")
340
- with gr.Row():
341
- train_chars = gr.Slider(10_000, 500_000, value=50_000, step=10_000, label="Train char cap")
342
- eval_chars = gr.Slider(5_000, 200_000, value=20_000, step=5_000, label="Eval char cap")
343
- eval_batches = gr.Slider(1, 20, value=4, step=1, label="Eval batches")
344
- run_button = gr.Button("Run Direct Smoke Test", variant="primary")
345
- summary = gr.Markdown()
346
- json_file = gr.File(label="Raw JSON")
347
- metrics_file = gr.File(label="Metrics CSV")
348
-
349
- run_button.click(
350
- fn=run_demo,
351
- inputs=[
352
- model_name,
353
- steps,
354
- lr,
355
- seq_len,
356
- train_chars,
357
- eval_chars,
358
- eval_batches,
359
- batch_size,
360
- grad_accum,
361
- seed,
362
- run_lbw_guard,
363
- ],
364
- outputs=[summary, json_file, metrics_file],
365
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
 
368
  if __name__ == "__main__":
 
2
 
3
  import csv
4
  import gc
 
5
  import json
6
+ import math
7
  import os
8
+ import random
9
  import time
10
  import traceback
 
11
  from pathlib import Path
12
  from typing import Any
13
 
 
23
 
24
  import gradio as gr
25
  import torch
26
+ from datasets import load_dataset
27
+ from peft import LoraConfig, TaskType, get_peft_model
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
 
30
+ try:
31
+ import lbw
32
+ except Exception as exc: # pragma: no cover - shown in the Space UI.
33
+ lbw = None
34
+ LBW_IMPORT_ERROR = exc
35
+ else:
36
+ LBW_IMPORT_ERROR = None
37
 
38
 
39
  RUNS_DIR = ROOT / "runs"
 
43
  return "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
 
46
+ def _set_seed(seed: int) -> None:
47
+ random.seed(seed)
48
+ torch.manual_seed(seed)
49
+ if torch.cuda.is_available():
50
+ torch.cuda.manual_seed_all(seed)
51
+
52
+
53
  def _safe_float(value: Any) -> float | None:
54
  if value is None:
55
  return None
56
  try:
57
+ out = float(value)
58
  except Exception:
59
  return None
60
+ if not math.isfinite(out):
61
+ return None
62
+ return out
63
+
64
 
65
+ def _fmt_float(value: Any, digits: int = 4) -> str:
66
+ number = _safe_float(value)
67
+ return "-" if number is None else f"{number:.{digits}f}"
68
 
69
+
70
+ def _append_log(logs: list[str], message: str) -> None:
71
+ logs.append(message)
72
+ print(message, flush=True)
73
+
74
+
75
+ def _build_wikitext_chunks(
76
+ tokenizer,
77
+ *,
78
+ split: str,
79
+ max_chars: int | None,
80
+ seq_len: int,
81
+ logs: list[str],
82
+ ) -> dict[str, Any]:
83
+ cap = None if max_chars is None else int(max_chars)
84
+ _append_log(
85
+ logs,
86
+ f"Preparing WikiText split={split!r}" + (f" with char cap {cap:,}" if cap is not None else " with full split"),
87
+ )
88
+ ds = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)
89
+ pieces: list[str] = []
90
+ chars_used = 0
91
+ rows_used = 0
92
+ first_piece = True
93
+ for row in ds:
94
+ text = str(row.get("text", "") or "")
95
+ if not text.strip():
96
+ continue
97
+ piece = text if first_piece else " " + text
98
+ if cap is not None:
99
+ remain = cap - chars_used
100
+ if remain <= 0:
101
+ break
102
+ if len(piece) > remain:
103
+ piece = piece[:remain]
104
+ pieces.append(piece)
105
+ chars_used += len(piece)
106
+ rows_used += 1
107
+ first_piece = False
108
+ if cap is not None and chars_used >= cap:
109
+ break
110
+
111
+ token_ids = tokenizer("".join(pieces), add_special_tokens=False)["input_ids"]
112
+ ids = torch.tensor(token_ids, dtype=torch.long)
113
+ sequence_count = ids.numel() // int(seq_len)
114
+ if sequence_count <= 0:
115
+ raise RuntimeError("Not enough tokens. Increase the train/eval char cap or reduce sequence length.")
116
+ ids = ids[: sequence_count * int(seq_len)].view(sequence_count, int(seq_len)).contiguous()
117
+ _append_log(
118
+ logs,
119
+ f"Prepared split={split!r}: {chars_used:,} chars across {rows_used:,} rows -> {ids.size(0):,} sequences",
120
+ )
121
+ return {"input_ids": ids, "chars": chars_used, "rows": rows_used, "cap": cap}
122
+
123
+
124
+ def _batch_iter(chunks: dict[str, Any], *, batch_size: int, device: torch.device):
125
+ ids = chunks["input_ids"]
126
+ i = 0
127
+ while True:
128
+ if i + int(batch_size) > ids.size(0):
129
+ i = 0
130
+ batch = ids[i : i + int(batch_size)].to(device, non_blocking=True)
131
+ i += int(batch_size)
132
+ yield batch
133
+
134
+
135
+ def _load_lora_model(
136
  *,
137
  model_name: str,
138
+ device: torch.device,
139
+ lora_r: int,
140
+ lora_alpha: int,
141
+ lora_dropout: float,
142
+ ):
143
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
144
+ model = AutoModelForCausalLM.from_pretrained(
145
+ model_name,
146
+ torch_dtype=dtype,
147
+ low_cpu_mem_usage=True,
148
+ )
149
+ if getattr(model.config, "use_cache", None) is not None:
150
+ model.config.use_cache = False
151
+ model.to(device)
152
+ lora_cfg = LoraConfig(
153
+ r=int(lora_r),
154
+ lora_alpha=int(lora_alpha),
155
+ lora_dropout=float(lora_dropout),
156
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
157
+ task_type=TaskType.CAUSAL_LM,
158
+ bias="none",
159
+ )
160
+ return get_peft_model(model, lora_cfg)
161
+
162
+
163
+ def _make_optimizer(
164
+ name: str,
165
+ model,
166
+ *,
167
  lr: float,
168
+ betas: tuple[float, float],
169
+ weight_decay: float,
170
+ lbw_stats_freq: int,
171
+ lbw_stress_th: float,
172
+ lbw_spike_th: float,
173
+ lbw_rec_fast: float,
174
+ lbw_ema_decay: float,
175
+ ):
176
+ params = [param for param in model.parameters() if param.requires_grad]
177
+ if name == "adamw":
178
+ return torch.optim.AdamW(params, lr=float(lr), betas=betas, weight_decay=float(weight_decay))
179
+ if name == "lbw_guard":
180
+ if lbw is None:
181
+ raise RuntimeError(f"LBW Guard package import failed: {LBW_IMPORT_ERROR}")
182
+ return lbw.Guard(
183
+ params,
184
+ lr=float(lr),
185
+ betas=betas,
186
+ weight_decay=float(weight_decay),
187
+ mode="eval",
188
+ auto_enabled=True,
189
+ stats_freq=int(lbw_stats_freq),
190
+ stress_threshold=float(lbw_stress_th),
191
+ spike_threshold=float(lbw_spike_th),
192
+ recovery_fast=float(lbw_rec_fast),
193
+ ema_decay=float(lbw_ema_decay),
194
+ use_max_rms=True,
195
+ )
196
+ raise ValueError(f"Unknown optimizer: {name}")
197
+
198
+
199
+ @torch.no_grad()
200
+ def _evaluate_ppl(
201
+ model,
202
+ eval_chunks: dict[str, Any],
203
+ *,
204
  batch_size: int,
205
+ eval_batches: int,
206
+ device: torch.device,
207
+ full_pass: bool,
208
+ ) -> tuple[float, float]:
209
+ model.eval()
210
+ ids = eval_chunks["input_ids"]
211
+ max_sequences = ids.size(0) if full_pass else min(ids.size(0), int(eval_batches) * int(batch_size))
212
+ losses: list[float] = []
213
+ for start in range(0, max_sequences, int(batch_size)):
214
+ xb = ids[start : start + int(batch_size)].to(device, non_blocking=True)
215
+ with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")):
216
+ loss = model(input_ids=xb, labels=xb).loss
217
+ losses.append(float(loss.detach().cpu()))
218
+ avg_loss = sum(losses) / max(len(losses), 1)
219
+ return avg_loss, math.exp(min(avg_loss, 20.0))
220
+
221
+
222
+ def _optimizer_state(opt) -> dict[str, Any]:
223
+ state = dict(getattr(opt, "state", {}).get("lbw", {}) or {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  return {
225
+ "scale": float(state.get("scale", state.get("lbw_scale", 1.0))),
226
+ "ratio": float(state.get("ratio", 1.0)),
227
+ "stress_mode": str(state.get("stress_mode", "none")),
228
+ }
229
+
230
+
231
+ def _status_markdown(
232
+ *,
233
+ device_name: str,
234
+ rows: list[dict[str, Any]],
235
+ logs: list[str],
236
+ phase: str,
237
+ ) -> str:
238
+ summary = [
239
+ f"Device: `{device_name}`",
240
+ "",
241
+ f"Status: {phase}",
242
+ "",
243
+ "## Results",
244
+ "",
245
+ "| Optimizer | Final Eval PPL | Final Eval Loss | Scope | Scale | Ratio | Stress Mode | Wall Time (s) |",
246
+ "| --- | --- | --- | --- | --- | --- | --- | --- |",
247
+ ]
248
+ if rows:
249
+ for row in rows:
250
+ summary.append(
251
+ "| {optimizer} | {ppl} | {loss} | {scope} | {scale} | {ratio} | {stress} | {wall} |".format(
252
+ optimizer=row.get("optimizer"),
253
+ ppl=_fmt_float(row.get("final_eval_ppl")),
254
+ loss=_fmt_float(row.get("final_eval_loss")),
255
+ scope=row.get("final_eval_scope") or "-",
256
+ scale=_fmt_float(row.get("scale")),
257
+ ratio=_fmt_float(row.get("ratio")),
258
+ stress=row.get("stress_mode") or "-",
259
+ wall=_fmt_float(row.get("wall_time_sec"), digits=2),
260
+ )
261
+ )
262
+ else:
263
+ summary.append("| - | - | - | - | - | - | - | - |")
264
+
265
+ gains = _gain_rows(rows)
266
+ if gains:
267
+ summary.extend(["", "## LBW vs AdamW", ""])
268
+ for gain in gains:
269
+ pct = _safe_float(gain.get("eval_perplexity_pct_gain_vs_adamw"))
270
+ wall_speedup = _safe_float(gain.get("wall_time_speedup_vs_adamw"))
271
+ summary.append(
272
+ f"- `{gain.get('optimizer')}` PPL gain vs AdamW: `{_fmt_float(gain.get('eval_perplexity_gain_vs_adamw'))}`"
273
+ + (f" (`{pct * 100.0:.2f}%`)." if pct is not None else ".")
274
+ )
275
+ if wall_speedup is not None:
276
+ summary.append(f"- `{gain.get('optimizer')}` wall-time speedup vs AdamW: `{wall_speedup:.3f}x`.")
277
+
278
+ summary.extend(["", "## Runtime Log", "", "```text", "\n".join(logs[-80:]), "```"])
279
+ return "\n".join(summary)
280
+
281
+
282
+ def _run_one_optimizer_events(
283
+ *,
284
+ optimizer_name: str,
285
+ model_name: str,
286
+ train_chunks: dict[str, Any],
287
+ eval_chunks: dict[str, Any],
288
+ device: torch.device,
289
+ seed: int,
290
+ max_steps: int,
291
+ eval_every: int,
292
+ eval_batches: int,
293
+ seq_len: int,
294
+ batch_size: int,
295
+ lr: float,
296
+ betas: tuple[float, float],
297
+ weight_decay: float,
298
+ full_validation_ppl: bool,
299
+ lora_r: int,
300
+ lora_alpha: int,
301
+ lora_dropout: float,
302
+ lbw_stats_freq: int,
303
+ lbw_stress_th: float,
304
+ lbw_spike_th: float,
305
+ lbw_rec_fast: float,
306
+ lbw_ema_decay: float,
307
+ logs: list[str],
308
+ ):
309
+ _set_seed(int(seed))
310
+ _append_log(logs, f"Loading {model_name} with LoRA for {optimizer_name}.")
311
+ model = _load_lora_model(
312
+ model_name=model_name,
313
+ device=device,
314
+ lora_r=lora_r,
315
+ lora_alpha=lora_alpha,
316
+ lora_dropout=lora_dropout,
317
+ )
318
+ model.train()
319
+ opt = _make_optimizer(
320
+ optimizer_name,
321
+ model,
322
+ lr=lr,
323
+ betas=betas,
324
+ weight_decay=weight_decay,
325
+ lbw_stats_freq=lbw_stats_freq,
326
+ lbw_stress_th=lbw_stress_th,
327
+ lbw_spike_th=lbw_spike_th,
328
+ lbw_rec_fast=lbw_rec_fast,
329
+ lbw_ema_decay=lbw_ema_decay,
330
+ )
331
+ train_batches = _batch_iter(train_chunks, batch_size=batch_size, device=device)
332
+ start_time = time.time()
333
+ last_loss = None
334
+ last_eval_loss = None
335
+ last_eval_ppl = None
336
+ state = _optimizer_state(opt)
337
+ trainable_params = [param for param in model.parameters() if param.requires_grad]
338
+
339
+ for step in range(1, int(max_steps) + 1):
340
+ xb = next(train_batches)
341
+ with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")):
342
+ loss = model(input_ids=xb, labels=xb).loss
343
+ loss.backward()
344
+ torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
345
+ opt.step()
346
+ opt.zero_grad(set_to_none=True)
347
+ last_loss = float(loss.detach().cpu())
348
+ state = _optimizer_state(opt)
349
+
350
+ if step == 1 or step == int(max_steps) or step % int(eval_every) == 0:
351
+ last_eval_loss, last_eval_ppl = _evaluate_ppl(
352
+ model,
353
+ eval_chunks,
354
+ batch_size=batch_size,
355
+ eval_batches=eval_batches,
356
+ device=device,
357
+ full_pass=False,
358
+ )
359
+ message = (
360
+ f"{optimizer_name} step {step}/{int(max_steps)}: "
361
+ f"loss={last_loss:.4f}, sampled_eval_ppl={last_eval_ppl:.4f}, "
362
+ f"scale={state['scale']:.4f}, ratio={state['ratio']:.4f}"
363
+ )
364
+ _append_log(logs, message)
365
+ yield {"type": "progress", "message": message}
366
+ model.train()
367
+
368
+ final_full_pass = bool(full_validation_ppl)
369
+ if final_full_pass and eval_chunks["cap"] is None:
370
+ final_scope = "full_wikitext"
371
+ elif final_full_pass:
372
+ final_scope = "full_loaded_subset"
373
+ else:
374
+ final_scope = "sampled"
375
+ _append_log(logs, f"Running final {final_scope} validation PPL for {optimizer_name}.")
376
+ final_loss, final_ppl = _evaluate_ppl(
377
+ model,
378
+ eval_chunks,
379
+ batch_size=batch_size,
380
+ eval_batches=eval_batches,
381
+ device=device,
382
+ full_pass=final_full_pass,
383
+ )
384
+ state = _optimizer_state(opt)
385
+ wall_time = time.time() - start_time
386
+ result = {
387
+ "optimizer": optimizer_name,
388
+ "final_eval_ppl": final_ppl,
389
+ "final_eval_loss": final_loss,
390
+ "final_eval_scope": final_scope,
391
+ "train_chars": train_chunks["chars"],
392
+ "eval_chars": eval_chunks["chars"],
393
+ "train_sequences": int(train_chunks["input_ids"].size(0)),
394
+ "eval_sequences": int(eval_chunks["input_ids"].size(0)),
395
+ "tokens_per_step": int(batch_size) * int(seq_len),
396
+ "last_train_loss": last_loss,
397
+ "last_sampled_eval_loss": last_eval_loss,
398
+ "last_sampled_eval_ppl": last_eval_ppl,
399
+ "scale": state["scale"],
400
+ "ratio": state["ratio"],
401
+ "stress_mode": state["stress_mode"],
402
+ "wall_time_sec": wall_time,
403
  }
404
+ del model, opt
405
+ gc.collect()
406
+ if device.type == "cuda":
407
+ torch.cuda.empty_cache()
408
+ yield {"type": "result", "result": result}
409
 
410
 
411
  def _gain_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
 
413
  baseline = by_optimizer.get("adamw")
414
  if baseline is None:
415
  return []
416
+ baseline_ppl = _safe_float(baseline.get("final_eval_ppl"))
417
+ baseline_wall = _safe_float(baseline.get("wall_time_sec"))
418
+ gains: list[dict[str, Any]] = []
419
  for row in rows:
420
  if row.get("optimizer") == "adamw":
421
  continue
422
+ candidate_ppl = _safe_float(row.get("final_eval_ppl"))
423
+ candidate_wall = _safe_float(row.get("wall_time_sec"))
 
 
 
 
424
  gains.append(
425
  {
426
  "optimizer": row.get("optimizer"),
427
+ "eval_perplexity_gain_vs_adamw": (
428
+ None if baseline_ppl is None or candidate_ppl is None else baseline_ppl - candidate_ppl
429
+ ),
430
  "eval_perplexity_pct_gain_vs_adamw": (
431
+ None
432
+ if baseline_ppl in (None, 0.0) or candidate_ppl is None
433
+ else (baseline_ppl - candidate_ppl) / baseline_ppl
434
+ ),
435
+ "wall_time_speedup_vs_adamw": (
436
+ None
437
+ if baseline_wall in (None, 0.0) or candidate_wall in (None, 0.0)
438
+ else baseline_wall / candidate_wall
439
  ),
 
440
  }
441
  )
442
  return gains
 
452
  writer.writerows(rows)
453
 
454
 
455
+ def _set_lr(opt, value: float) -> None:
456
+ for group in getattr(opt, "param_groups", []) or []:
457
+ group["lr"] = float(value)
458
+
459
+
460
+ def _scheduled_lr(cfg: dict[str, Any], step: int) -> float:
461
+ base_lr = float(cfg["lr"])
462
+ warmup = max(int(cfg.get("warmup_steps", 0)), 0)
463
+ max_steps = max(int(cfg["max_steps"]), 1)
464
+ if warmup > 0 and int(step) <= warmup:
465
+ return base_lr * float(step) / float(warmup)
466
+ mode = str(cfg.get("schedule_mode", "constant")).strip().lower()
467
+ if mode == "cosine":
468
+ progress = (int(step) - warmup) / max(max_steps - warmup, 1)
469
+ progress = min(max(progress, 0.0), 1.0)
470
+ return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
471
+ return base_lr
472
+
473
+
474
+ def _parse_float_sweep(text: str, default: list[float]) -> list[float]:
475
+ raw = str(text or "").replace("\n", ",").replace(";", ",").split(",")
476
+ values: list[float] = []
477
+ for item in raw:
478
+ item = item.strip()
479
+ if not item:
480
+ continue
481
+ values.append(float(item))
482
+ return values or list(default)
483
+
484
+
485
+ def _parse_int_sweep(text: str, default: list[int]) -> list[int]:
486
+ return [int(value) for value in _parse_float_sweep(text, [float(item) for item in default])]
487
+
488
+
489
+ def run_easy_test(
490
  model_name: str,
491
+ run_lbw_guard: bool,
492
+ max_steps: int,
493
+ eval_every: int,
494
+ eval_batches: int,
495
  seq_len: int,
496
+ batch_size: int,
497
  train_chars: int,
498
  eval_chars: int,
499
+ full_wikitext_train: bool,
500
+ full_wikitext_eval: bool,
501
+ full_validation_ppl: bool,
502
+ lr: float,
503
  seed: int,
504
+ ):
505
+ logs: list[str] = []
506
+ rows: list[dict[str, Any]] = []
507
+ run_dir = RUNS_DIR / f"easy_test_{int(time.time())}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  run_dir.mkdir(parents=True, exist_ok=True)
509
+ device_name = _device_default()
510
+ device = torch.device(device_name)
511
+ optimizers = ["adamw", "lbw_guard"] if bool(run_lbw_guard) else ["adamw"]
512
 
 
513
  try:
514
+ if device.type == "cpu" and (
515
+ int(max_steps) > 1
516
+ or int(train_chars) > 20_000
517
+ or int(eval_chars) > 8_000
518
+ or bool(full_wikitext_train)
519
+ or bool(full_wikitext_eval)
520
+ or bool(full_validation_ppl)
521
+ ):
522
+ yield (
523
+ "This Space is currently on `cpu-basic`. CPU mode is capped to 1 step, 20k train chars, "
524
+ "8k eval chars, and sampled validation. Switch the Space hardware to GPU for the Easy Test defaults.",
525
+ None,
526
+ None,
527
+ None,
528
+ )
529
+ return
530
+ if device.type == "cuda" and bool(run_lbw_guard) and torch.cuda.device_count() > 1:
531
+ yield (
532
+ "LBW Guard should run with one visible GPU. Set the Space to single-GPU hardware or restrict CUDA_VISIBLE_DEVICES.",
533
+ None,
534
+ None,
535
+ None,
536
+ )
537
+ return
538
+
539
+ _append_log(logs, f"Device: {device_name}")
540
+ if device.type == "cuda":
541
+ _append_log(logs, f"GPU: {torch.cuda.get_device_name(0)}")
542
+ _append_log(logs, f"Optimizers: {', '.join(optimizers)}")
543
+ yield _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Loading tokenizer"), None, None, None
544
+
545
+ _set_seed(int(seed))
546
+ resolved_model = str(model_name).strip() or "TinyLlama/TinyLlama_v1.1"
547
+ tokenizer = AutoTokenizer.from_pretrained(resolved_model, use_fast=True)
548
+ if tokenizer.pad_token is None:
549
+ tokenizer.pad_token = tokenizer.eos_token
550
+
551
+ train_cap = None if bool(full_wikitext_train) else int(train_chars)
552
+ eval_cap = None if bool(full_wikitext_eval) else int(eval_chars)
553
+ train_chunks = _build_wikitext_chunks(
554
+ tokenizer,
555
+ split="train",
556
+ max_chars=train_cap,
557
+ seq_len=int(seq_len),
558
+ logs=logs,
559
  )
560
+ yield _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Prepared train split"), None, None, None
561
+ eval_chunks = _build_wikitext_chunks(
562
+ tokenizer,
563
+ split="validation",
564
+ max_chars=eval_cap,
565
+ seq_len=int(seq_len),
566
+ logs=logs,
567
+ )
568
+ yield _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Prepared validation split"), None, None, None
569
+
570
+ for optimizer_name in optimizers:
571
+ _append_log(logs, f"=== {optimizer_name} ===")
572
+ yield _status_markdown(
573
+ device_name=device_name,
574
+ rows=rows,
575
+ logs=logs,
576
+ phase=f"Running {optimizer_name}",
577
+ ), None, None, None
578
+ for event in _run_one_optimizer_events(
579
+ optimizer_name=optimizer_name,
580
+ model_name=resolved_model,
581
+ train_chunks=train_chunks,
582
+ eval_chunks=eval_chunks,
583
+ device=device,
584
+ seed=int(seed),
585
+ max_steps=int(max_steps),
586
+ eval_every=max(1, int(eval_every)),
587
+ eval_batches=int(eval_batches),
588
+ seq_len=int(seq_len),
589
+ batch_size=int(batch_size),
590
+ lr=float(lr),
591
+ betas=(0.9, 0.999),
592
+ weight_decay=0.01,
593
+ full_validation_ppl=bool(full_validation_ppl),
594
+ lora_r=8,
595
+ lora_alpha=16,
596
+ lora_dropout=0.05,
597
+ lbw_stats_freq=10,
598
+ lbw_stress_th=1.1,
599
+ lbw_spike_th=1.5,
600
+ lbw_rec_fast=0.01,
601
+ lbw_ema_decay=0.95,
602
+ logs=logs,
603
+ ):
604
+ if event.get("type") == "result":
605
+ rows.append(event["result"])
606
+ yield _status_markdown(
607
+ device_name=device_name,
608
+ rows=rows,
609
+ logs=logs,
610
+ phase=f"Running {optimizer_name}",
611
+ ), None, None, None
612
 
 
613
  gains = _gain_rows(rows)
614
  payload = {
615
+ "source": "LBW_Guard_Easy_Test_COLAB.ipynb",
616
  "config": {
617
+ "model_name": resolved_model,
618
+ "device": device_name,
619
+ "optimizers": optimizers,
620
+ "seed": int(seed),
621
+ "max_steps": int(max_steps),
622
+ "eval_every": int(eval_every),
 
623
  "eval_batches": int(eval_batches),
624
+ "seq_len": int(seq_len),
625
  "batch_size": int(batch_size),
626
+ "max_chars": train_cap,
627
+ "eval_chars": eval_cap,
628
+ "full_wikitext_train": bool(full_wikitext_train),
629
+ "full_wikitext_eval": bool(full_wikitext_eval),
630
+ "full_validation_ppl": bool(full_validation_ppl),
631
+ "lr": float(lr),
632
+ "betas": [0.9, 0.999],
633
+ "weight_decay": 0.01,
634
+ "lora_r": 8,
635
+ "lora_alpha": 16,
636
+ "lora_dropout": 0.05,
637
+ "lbw_stats_freq": 10,
638
+ "lbw_stress_th": 1.1,
639
+ "lbw_spike_th": 1.5,
640
+ "lbw_rec_fast": 0.01,
641
+ "lbw_ema_decay": 0.95,
642
  },
643
+ "results": rows,
 
644
  "gains": gains,
645
+ "logs": logs,
646
  }
647
+ json_path = run_dir / "lbw_guard_easy_test_results.json"
648
+ csv_path = run_dir / "lbw_guard_easy_test_results.csv"
649
+ gains_path = run_dir / "lbw_guard_easy_test_gains.csv"
650
  json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
651
  _write_csv(csv_path, rows)
652
  _write_csv(gains_path, gains)
653
+ _append_log(logs, f"Wrote {csv_path}")
654
+ yield (
655
+ _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Complete"),
656
+ str(json_path),
657
+ str(csv_path),
658
+ str(gains_path),
659
+ )
660
+ except Exception:
661
+ error_text = traceback.format_exc()
662
+ error_path = run_dir / "error.txt"
663
+ error_path.write_text(error_text + "\n\n" + "\n".join(logs), encoding="utf-8")
664
+ yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None, None
665
+
666
+
667
+ def _make_ablation_scenario(slug: str, label: str, note: str, base_config: dict[str, Any], overrides=None):
668
+ cfg = dict(base_config)
669
+ if overrides:
670
+ cfg.update(overrides)
671
+ return {
672
+ "slug": slug,
673
+ "label": label,
674
+ "note": note,
675
+ "config": cfg,
676
+ }
677
+
678
+
679
def _build_ablation_scenarios(
    *,
    selected_ablations: list[str],
    base_config: dict[str, Any],
    lr_sweep: list[float],
    step_sweep: list[int],
    lora_r_sweep: list[int],
) -> list[dict[str, Any]]:
    """Expand the selected ablation families into a flat list of scenarios.

    Family names are matched case-insensitively after stripping whitespace;
    blank entries are ignored and an empty selection falls back to
    ``{"optimizer"}``. Scenarios are emitted in a fixed family order:
    optimizer, lr, schedule, steps, data, lora.

    Raises:
        ValueError: if the selection produced no scenario at all.
    """
    wanted = set()
    for raw in selected_ablations:
        name = str(raw).strip()
        if name:
            wanted.add(name.lower())
    if not wanted:
        wanted = {"optimizer"}

    plan: list[dict[str, Any]] = []

    def add(slug, label, note, overrides=None):
        # Every scenario derives from the same base config plus its overrides.
        plan.append(_make_ablation_scenario(slug, label, note, base_config, overrides))

    if "optimizer" in wanted:
        add(
            "optimizer-adamw-vs-lbw-guard",
            "Optimizer: AdamW vs lbw_guard",
            "Direct optimizer comparison with the base config.",
        )

    if "lr" in wanted:
        for rate in lr_sweep:
            add(
                f"lr-{rate:g}",
                f"Learning Rate: {rate:g}",
                "Learning-rate sensitivity check.",
                {"lr": float(rate)},
            )

    if "schedule" in wanted:
        for shape in ("constant", "cosine"):
            add(
                f"schedule-{shape}",
                f"Schedule: {shape}",
                "Scheduler-shape sensitivity check.",
                {"schedule_mode": shape},
            )

    if "steps" in wanted:
        for count in step_sweep:
            # Evaluate roughly four times per run, but never less often than every step.
            add(
                f"steps-{count}",
                f"Steps: {count}",
                "Training-length sensitivity check.",
                {"max_steps": int(count), "eval_every": max(1, int(count) // 4)},
            )

    if "data" in wanted:
        slices = (
            ("small-data", 20_000, 8_000),
            ("larger-data", 80_000, 20_000),
        )
        for tag, train_cap, eval_cap in slices:
            add(
                tag,
                f"Data Slice: {tag}",
                "WikiText slice-size sensitivity check.",
                {"max_chars": int(train_cap), "eval_chars": int(eval_cap)},
            )

    if "lora" in wanted:
        for rank in lora_r_sweep:
            add(
                f"lora-r{rank}",
                f"LoRA Rank: {rank}",
                "Adapter-capacity sensitivity check.",
                {"lora_r": int(rank), "lora_alpha": int(rank) * 2},
            )

    if plan:
        return plan
    raise ValueError("No scenarios selected. Choose optimizer, lr, schedule, steps, data, or lora.")
768
+
769
+
770
def _ablation_status_markdown(
    *,
    device_name: str,
    rows: list[dict[str, Any]],
    logs: list[str],
    phase: str,
    plan: list[dict[str, Any]],
) -> str:
    """Render the ablation run status as one Markdown document.

    Sections, in order: device/status header, the scenario plan table, the
    per-run metrics table (a placeholder row when ``rows`` is empty), an
    optional "LBW vs AdamW" gain list, and the last 100 log lines.
    """
    lines = [
        f"Device: `{device_name}`",
        "",
        f"Status: {phase}",
        "",
        "## Plan",
        "",
        "| Scenario | Steps | LR | Schedule | Train Chars | Eval Chars | LoRA r |",
        "| --- | --- | --- | --- | --- | --- | --- |",
    ]

    for entry in plan:
        cfg = entry["config"]
        label = entry["label"]
        steps = int(cfg["max_steps"])
        rate = float(cfg["lr"])
        schedule = cfg["schedule_mode"]
        # A full-dataset run has no character cap to report.
        train_cell = "FULL" if cfg["full_wikitext_train"] else int(cfg["max_chars"])
        eval_cell = "FULL" if cfg["full_wikitext_eval"] else int(cfg["eval_chars"])
        rank = int(cfg["lora_r"])
        lines.append(
            f"| {label} | {steps} | {rate:g} | {schedule} | {train_cell} | {eval_cell} | {rank} |"
        )

    lines += [
        "",
        "## Metrics",
        "",
        "| Scenario | Optimizer | Final Eval PPL | Final Eval Loss | Tokens/s | Scale | Ratio | Stress Mode |",
        "| --- | --- | --- | --- | --- | --- | --- | --- |",
    ]
    if not rows:
        lines.append("| - | - | - | - | - | - | - | - |")
    for row in rows:
        cells = (
            row.get("scenario"),
            row.get("optimizer"),
            _fmt_float(row.get("final_eval_ppl")),
            _fmt_float(row.get("final_eval_loss")),
            _fmt_float(row.get("tokens_per_sec_wall"), digits=2),
            _fmt_float(row.get("scale")),
            _fmt_float(row.get("ratio")),
            row.get("stress_mode") or "-",
        )
        lines.append("| " + " | ".join(str(cell) for cell in cells) + " |")

    comparisons = _build_ablation_gain_rows(rows)
    if comparisons:
        lines += ["", "## LBW vs AdamW", ""]
        for gain in comparisons:
            lines.append(
                f"- `{gain.get('scenario')}`: `{gain.get('optimizer')}` "
                f"PPL gain `{_fmt_float(gain.get('ppl_gain_pct_vs_adamw'))}%`, "
                f"loss gain `{_fmt_float(gain.get('loss_gain_pct_vs_adamw'))}%`, "
                f"speed gain `{_fmt_float(gain.get('speed_gain_pct_vs_adamw'))}%`."
            )

    # Only the tail of the log is shown to keep the rendered status bounded.
    lines += ["", "## Runtime Log", "", "```text", "\n".join(logs[-100:]), "```"]
    return "\n".join(lines)
841
+
842
+
843
def _run_ablation_optimizer_events(
    *,
    scenario_item: dict[str, Any],
    optimizer_name: str,
    model_name: str,
    train_chunks: dict[str, Any],
    eval_chunks: dict[str, Any],
    device: torch.device,
    logs: list[str],
):
    """Train one optimizer on one ablation scenario, yielding progress events.

    This is a generator. It yields ``{"type": "progress", "message": str}``
    after every periodic sampled evaluation (step 1, the last step, and every
    ``eval_every`` steps) and finishes with a single
    ``{"type": "result", "result": dict}`` event holding the metrics row.

    Args:
        scenario_item: scenario dict; its ``"config"``, ``"slug"`` and
            ``"label"`` entries are read here.
        optimizer_name: optimizer identifier passed to ``_make_optimizer``.
        model_name: model identifier passed to ``_load_lora_model``.
        train_chunks: training data; ``"chars"`` and ``"input_ids"`` are read.
        eval_chunks: validation data; ``"cap"``, ``"chars"`` and
            ``"input_ids"`` are read.
        device: torch device used for training and evaluation.
        logs: shared log list that progress messages are appended to.

    The model and optimizer references are dropped and the CUDA cache is
    emptied in a ``finally`` block, so the cleanup also runs when a training
    step raises or when the consumer closes the generator early.
    """
    cfg = scenario_item["config"]
    _set_seed(int(cfg["seed"]))
    _append_log(logs, f"Loading {model_name} with LoRA for {scenario_item['slug']} / {optimizer_name}.")

    # Pre-bind so the cleanup in ``finally`` is valid even if loading fails.
    model = None
    opt = None
    try:
        model = _load_lora_model(
            model_name=model_name,
            device=device,
            lora_r=int(cfg["lora_r"]),
            lora_alpha=int(cfg["lora_alpha"]),
            lora_dropout=float(cfg["lora_dropout"]),
        )
        model.train()
        opt = _make_optimizer(
            optimizer_name,
            model,
            lr=float(cfg["lr"]),
            betas=tuple(cfg["betas"]),
            weight_decay=float(cfg["weight_decay"]),
            lbw_stats_freq=int(cfg["lbw_stats_freq"]),
            lbw_stress_th=float(cfg["lbw_stress_th"]),
            lbw_spike_th=float(cfg["lbw_spike_th"]),
            lbw_rec_fast=float(cfg["lbw_rec_fast"]),
            lbw_ema_decay=float(cfg["lbw_ema_decay"]),
        )
        train_batches = _batch_iter(train_chunks, batch_size=int(cfg["batch_size"]), device=device)
        trainable_params = [param for param in model.parameters() if param.requires_grad]
        start_time = time.time()
        losses: list[float] = []
        eval_loss = None
        eval_ppl = None
        last_lr = float(cfg["lr"])
        state = _optimizer_state(opt)

        for step in range(1, int(cfg["max_steps"]) + 1):
            # The learning rate is set explicitly every step from the schedule.
            last_lr = _scheduled_lr(cfg, step)
            _set_lr(opt, last_lr)
            xb = next(train_batches)
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")):
                loss = model(input_ids=xb, labels=xb).loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            opt.step()
            opt.zero_grad(set_to_none=True)
            loss_value = float(loss.detach().cpu())
            losses.append(loss_value)

            if step == 1 or step == int(cfg["max_steps"]) or step % int(cfg["eval_every"]) == 0:
                eval_loss, eval_ppl = _evaluate_ppl(
                    model,
                    eval_chunks,
                    batch_size=int(cfg["batch_size"]),
                    eval_batches=int(cfg["eval_batches"]),
                    device=device,
                    full_pass=False,
                )
                state = _optimizer_state(opt)
                message = (
                    f"[{scenario_item['slug']}] {optimizer_name} step {step}/{cfg['max_steps']}: "
                    f"loss={loss_value:.4f}, sampled_eval_ppl={eval_ppl:.4f}, "
                    f"lr={last_lr:.2e}, scale={state['scale']:.4f}, ratio={state['ratio']:.4f}"
                )
                _append_log(logs, message)
                yield {"type": "progress", "message": message}
                # Switch back to training mode after the evaluation pass.
                model.train()

        final_full_pass = bool(cfg["full_validation_ppl"])
        if final_full_pass and eval_chunks["cap"] is None:
            final_scope = "full_wikitext"
        elif final_full_pass:
            final_scope = "full_loaded_subset"
        else:
            final_scope = "sampled"
        _append_log(logs, f"Running final {final_scope} validation PPL for {scenario_item['slug']} / {optimizer_name}.")
        final_loss, final_ppl = _evaluate_ppl(
            model,
            eval_chunks,
            batch_size=int(cfg["batch_size"]),
            eval_batches=int(cfg["eval_batches"]),
            device=device,
            full_pass=final_full_pass,
        )
        state = _optimizer_state(opt)
        # Floor the elapsed time so the throughput division can never hit zero.
        wall_time = max(time.time() - start_time, 1e-9)
        trained_tokens = int(cfg["max_steps"]) * int(cfg["batch_size"]) * int(cfg["seq_len"])
        result = {
            "scenario_slug": scenario_item["slug"],
            "scenario": scenario_item["label"],
            "optimizer": optimizer_name,
            "final_eval_ppl": final_ppl,
            "final_eval_loss": final_loss,
            "train_loss_last": losses[-1] if losses else None,
            "last_sampled_eval_loss": eval_loss,
            "last_sampled_eval_ppl": eval_ppl,
            "final_eval_scope": final_scope,
            "max_steps": int(cfg["max_steps"]),
            "lr": float(cfg["lr"]),
            "scheduled_lr_last": float(last_lr),
            "schedule_mode": str(cfg["schedule_mode"]),
            "batch_size": int(cfg["batch_size"]),
            "seq_len": int(cfg["seq_len"]),
            "lora_r": int(cfg["lora_r"]),
            "train_chars": int(train_chunks["chars"]),
            "eval_chars": int(eval_chunks["chars"]),
            "train_sequences": int(train_chunks["input_ids"].size(0)),
            "eval_sequences": int(eval_chunks["input_ids"].size(0)),
            "scale": state["scale"],
            "ratio": state["ratio"],
            "stress_mode": state["stress_mode"],
            "wall_time_sec": wall_time,
            "tokens_per_sec_wall": trained_tokens / wall_time,
        }
    finally:
        # Release the model and optimizer on success, failure, and early close,
        # so the next optimizer/scenario starts without this run's references.
        del model, opt
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    yield {"type": "result", "result": result}
969
+
970
+
971
def _build_ablation_gain_rows(metrics: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Compare every non-AdamW row against the AdamW row of the same scenario.

    Rows are grouped by ``scenario_slug``. Groups without an ``adamw`` row are
    skipped. For the remaining groups one output row is produced per
    non-AdamW optimizer; each percentage is ``None`` when the AdamW value is
    missing or zero, or when the candidate value is missing. PPL and loss
    gains are positive when the candidate is lower; the speed gain is positive
    when the candidate is faster.
    """

    def _pct(reference, measured, *, lower_is_better):
        # Percentage change relative to the AdamW reference value.
        if reference in (None, 0.0) or measured is None:
            return None
        delta = (reference - measured) if lower_is_better else (measured - reference)
        return delta / reference * 100.0

    by_slug: dict[str, list[dict[str, Any]]] = {}
    for entry in metrics:
        by_slug.setdefault(str(entry.get("scenario_slug")), []).append(entry)

    comparisons: list[dict[str, Any]] = []
    for slug, members in by_slug.items():
        reference = next((m for m in members if m.get("optimizer") == "adamw"), None)
        if reference is None:
            continue
        ref_ppl = _safe_float(reference.get("final_eval_ppl"))
        ref_loss = _safe_float(reference.get("final_eval_loss"))
        ref_tps = _safe_float(reference.get("tokens_per_sec_wall"))

        for member in members:
            if member.get("optimizer") == "adamw":
                continue
            cand_ppl = _safe_float(member.get("final_eval_ppl"))
            cand_loss = _safe_float(member.get("final_eval_loss"))
            cand_tps = _safe_float(member.get("tokens_per_sec_wall"))
            comparisons.append(
                {
                    "scenario_slug": slug,
                    "scenario": member.get("scenario"),
                    "optimizer": member.get("optimizer"),
                    "adamw_final_eval_ppl": ref_ppl,
                    "optimizer_final_eval_ppl": cand_ppl,
                    "ppl_gain_pct_vs_adamw": _pct(ref_ppl, cand_ppl, lower_is_better=True),
                    "loss_gain_pct_vs_adamw": _pct(ref_loss, cand_loss, lower_is_better=True),
                    "speed_gain_pct_vs_adamw": _pct(ref_tps, cand_tps, lower_is_better=False),
                    "adamw_tokens_per_sec_wall": ref_tps,
                    "optimizer_tokens_per_sec_wall": cand_tps,
                    "lbw_scale": member.get("scale"),
                    "lbw_ratio": member.get("ratio"),
                    "lbw_stress_mode": member.get("stress_mode"),
                }
            )
    return comparisons
1019
+
1020
+
1021
def run_ablation_test(
    model_name: str,
    selected_ablations: list[str],
    run_lbw_guard: bool,
    max_steps: int,
    eval_every: int,
    eval_batches: int,
    seq_len: int,
    batch_size: int,
    train_chars: int,
    eval_chars: int,
    full_wikitext_train: bool,
    full_wikitext_eval: bool,
    full_validation_ppl: bool,
    lr: float,
    schedule_mode: str,
    warmup_steps: int,
    seed: int,
    lr_sweep_text: str,
    step_sweep_text: str,
    lora_r_sweep_text: str,
):
    """Run the selected ablation scenarios and stream status updates.

    This is a generator. Every item it yields is a 4-tuple
    ``(status_markdown, json_path, metrics_csv_path, gains_csv_path)``; the
    three paths are ``None`` until the final "Complete" update. On failure a
    single tuple ``(error_markdown, error_txt_path, None, None)`` is yielded
    after the traceback and the collected logs are written to ``error.txt``.

    Each scenario is run once per optimizer (AdamW, plus ``lbw_guard`` when
    ``run_lbw_guard`` is truthy). Tokenized WikiText chunks are cached per
    ``(seq_len, train_cap, eval_cap)`` so scenarios sharing a data slice do
    not rebuild it.

    On CPU the run is refused unless it is a single scenario within the smoke
    caps (1 step, 20k train chars, 8k eval chars, sampled validation). The
    caps are checked against each scenario's effective config as well as the
    base settings, because sweep overrides can raise them.
    """
    logs: list[str] = []
    rows: list[dict[str, Any]] = []
    run_dir = RUNS_DIR / f"ablation_test_{int(time.time())}"
    run_dir.mkdir(parents=True, exist_ok=True)
    device_name = _device_default()
    device = torch.device(device_name)
    optimizers = ["adamw", "lbw_guard"] if bool(run_lbw_guard) else ["adamw"]

    try:
        base_config = {
            "seed": int(seed),
            "max_steps": int(max_steps),
            "eval_every": max(1, int(eval_every)),
            "eval_batches": int(eval_batches),
            "seq_len": int(seq_len),
            "batch_size": int(batch_size),
            "max_chars": int(train_chars),
            "eval_chars": int(eval_chars),
            "full_wikitext_train": bool(full_wikitext_train),
            "full_wikitext_eval": bool(full_wikitext_eval),
            "full_validation_ppl": bool(full_validation_ppl),
            "lr": float(lr),
            "betas": (0.9, 0.999),
            "weight_decay": 0.01,
            "warmup_steps": int(warmup_steps),
            "schedule_mode": str(schedule_mode or "constant").strip().lower(),
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lbw_stats_freq": 10,
            "lbw_stress_th": 1.1,
            "lbw_spike_th": 1.5,
            "lbw_rec_fast": 0.01,
            "lbw_ema_decay": 0.95,
        }
        lr_sweep = _parse_float_sweep(lr_sweep_text, [1e-3, 5e-4])
        step_sweep = _parse_int_sweep(step_sweep_text, [100, 200])
        lora_r_sweep = _parse_int_sweep(lora_r_sweep_text, [4, 8, 16])
        scenarios = _build_ablation_scenarios(
            selected_ablations=list(selected_ablations or ["optimizer"]),
            base_config=base_config,
            lr_sweep=lr_sweep,
            step_sweep=step_sweep,
            lora_r_sweep=lora_r_sweep,
        )

        # A lone scenario can still exceed the CPU caps through its overrides
        # (e.g. a one-value step sweep), so inspect the effective configs too.
        scenario_over_cap = any(
            int(item["config"]["max_steps"]) > 1
            or int(item["config"]["max_chars"]) > 20_000
            or int(item["config"]["eval_chars"]) > 8_000
            for item in scenarios
        )
        if device.type == "cpu" and (
            len(scenarios) > 1
            or scenario_over_cap
            or int(max_steps) > 1
            or int(train_chars) > 20_000
            or int(eval_chars) > 8_000
            or bool(full_wikitext_train)
            or bool(full_wikitext_eval)
            or bool(full_validation_ppl)
        ):
            yield (
                "This Space is currently on `cpu-basic`. CPU ablation mode is capped to one optimizer scenario, "
                "1 step, 20k train chars, 8k eval chars, and sampled validation. Switch the Space hardware to GPU for ablations.",
                None,
                None,
                None,
            )
            return
        if device.type == "cuda" and bool(run_lbw_guard) and torch.cuda.device_count() > 1:
            yield (
                "LBW Guard should run with one visible GPU. Set the Space to single-GPU hardware or restrict CUDA_VISIBLE_DEVICES.",
                None,
                None,
                None,
            )
            return

        resolved_model = str(model_name).strip() or "Qwen/Qwen2.5-0.5B"
        _append_log(logs, f"Device: {device_name}")
        if device.type == "cuda":
            _append_log(logs, f"GPU: {torch.cuda.get_device_name(0)}")
        _append_log(logs, f"Selected ablations: {', '.join(selected_ablations or ['optimizer'])}")
        _append_log(logs, f"Optimizers: {', '.join(optimizers)}")
        yield _ablation_status_markdown(
            device_name=device_name,
            rows=rows,
            logs=logs,
            phase="Loading tokenizer",
            plan=scenarios,
        ), None, None, None

        tokenizer = AutoTokenizer.from_pretrained(resolved_model, use_fast=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        data_cache: dict[tuple[int, int | None, int | None], dict[str, dict[str, Any]]] = {}

        for scenario_item in scenarios:
            cfg = scenario_item["config"]
            # ``None`` as a cap means the whole split is used.
            train_cap = None if cfg["full_wikitext_train"] else int(cfg["max_chars"])
            eval_cap = None if cfg["full_wikitext_eval"] else int(cfg["eval_chars"])
            cache_key = (int(cfg["seq_len"]), train_cap, eval_cap)
            if cache_key not in data_cache:
                data_cache[cache_key] = {
                    "train": _build_wikitext_chunks(
                        tokenizer,
                        split="train",
                        max_chars=train_cap,
                        seq_len=int(cfg["seq_len"]),
                        logs=logs,
                    ),
                    "eval": _build_wikitext_chunks(
                        tokenizer,
                        split="validation",
                        max_chars=eval_cap,
                        seq_len=int(cfg["seq_len"]),
                        logs=logs,
                    ),
                }

            _append_log(logs, f"=== Scenario: {scenario_item['label']} ===")
            for optimizer_name in optimizers:
                _append_log(logs, f"--- {optimizer_name} ---")
                yield _ablation_status_markdown(
                    device_name=device_name,
                    rows=rows,
                    logs=logs,
                    phase=f"Running {scenario_item['label']} / {optimizer_name}",
                    plan=scenarios,
                ), None, None, None
                for event in _run_ablation_optimizer_events(
                    scenario_item=scenario_item,
                    optimizer_name=optimizer_name,
                    model_name=resolved_model,
                    train_chunks=data_cache[cache_key]["train"],
                    eval_chunks=data_cache[cache_key]["eval"],
                    device=device,
                    logs=logs,
                ):
                    if event.get("type") == "result":
                        rows.append(event["result"])
                    # Re-render after every event so progress shows up live.
                    yield _ablation_status_markdown(
                        device_name=device_name,
                        rows=rows,
                        logs=logs,
                        phase=f"Running {scenario_item['label']} / {optimizer_name}",
                        plan=scenarios,
                    ), None, None, None

        gains = _build_ablation_gain_rows(rows)
        payload = {
            "source": "LBW_Guard_Ablation_Test_COLAB.ipynb",
            "model_name": resolved_model,
            "device": device_name,
            "optimizers": optimizers,
            "selected_ablations": list(selected_ablations or ["optimizer"]),
            "base_config": base_config,
            "scenarios": scenarios,
            "results": rows,
            "gains": gains,
            "logs": logs,
        }
        json_path = run_dir / "lbw_guard_ablation_results.json"
        metrics_path = run_dir / "lbw_guard_ablation_metrics.csv"
        gains_path = run_dir / "lbw_guard_ablation_gains.csv"
        json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        _write_csv(metrics_path, rows)
        _write_csv(gains_path, gains)
        _append_log(logs, f"Wrote {metrics_path}")
        _append_log(logs, f"Wrote {gains_path}")
        yield (
            _ablation_status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Complete", plan=scenarios),
            str(json_path),
            str(metrics_path),
            str(gains_path),
        )
    except Exception:
        # Top-level UI boundary: persist the traceback and logs, then report
        # the failure in the status pane instead of crashing the app.
        error_text = traceback.format_exc()
        error_path = run_dir / "error.txt"
        error_path.write_text(error_text + "\n\n" + "\n".join(logs), encoding="utf-8")
        yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None, None
1218
 
1219
 
1220
# Markdown shown at the top of the app, above the two test tabs.
INTRO = """
# LBW Guard Colab Tests

Runs notebook-faithful Hugging Face Space versions of:

- `LBW_Guard_Easy_Test_COLAB.ipynb`
- `LBW_Guard_Ablation_Test_COLAB.ipynb`

Current hardware is detected at run time. GPU is recommended for the default Easy Test.
"""


# Gradio UI: one tab per runner. Each tab collects its settings, then wires a
# button to the matching generator so status updates stream into the outputs.
# NOTE(review): widget nesting was reconstructed from a whitespace-flattened
# diff; confirm which components sit inside each gr.Row against the real file.
with gr.Blocks(title="LBW Guard Colab Tests") as demo:
    gr.Markdown(INTRO)
    with gr.Tabs():
        with gr.Tab("Easy Test"):
            with gr.Row():
                easy_model_name = gr.Textbox(value="TinyLlama/TinyLlama_v1.1", label="Model")
                easy_run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison")
            with gr.Row():
                easy_max_steps = gr.Slider(1, 1000, value=5, step=1, label="Optimizer steps")
                easy_eval_every = gr.Slider(1, 200, value=5, step=1, label="Eval every")
                easy_eval_batches = gr.Slider(1, 128, value=8, step=1, label="Eval batches")
            with gr.Row():
                easy_seq_len = gr.Dropdown([64, 128, 256, 512], value=64, label="Sequence length")
                easy_batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch size")
                easy_lr = gr.Number(value=5e-4, label="Learning rate")
            with gr.Row():
                easy_train_chars = gr.Slider(5_000, 2_000_000, value=20_000, step=5_000, label="Train char cap")
                easy_eval_chars = gr.Slider(1_000, 500_000, value=8_000, step=1_000, label="Eval char cap")
                easy_seed = gr.Number(value=42, precision=0, label="Seed")
            with gr.Row():
                easy_full_wikitext_train = gr.Checkbox(value=False, label="Full WikiText train")
                easy_full_wikitext_eval = gr.Checkbox(value=False, label="Full WikiText eval")
                easy_full_validation_ppl = gr.Checkbox(value=False, label="Full validation PPL")
            easy_run_button = gr.Button("Run Easy Test", variant="primary")
            easy_summary = gr.Markdown()
            easy_json_file = gr.File(label="Raw JSON")
            easy_results_file = gr.File(label="Results CSV")
            easy_gains_file = gr.File(label="Gains CSV")

            # The input order must match the positional parameters of run_easy_test.
            easy_run_button.click(
                fn=run_easy_test,
                inputs=[
                    easy_model_name,
                    easy_run_lbw_guard,
                    easy_max_steps,
                    easy_eval_every,
                    easy_eval_batches,
                    easy_seq_len,
                    easy_batch_size,
                    easy_train_chars,
                    easy_eval_chars,
                    easy_full_wikitext_train,
                    easy_full_wikitext_eval,
                    easy_full_validation_ppl,
                    easy_lr,
                    easy_seed,
                ],
                outputs=[easy_summary, easy_json_file, easy_results_file, easy_gains_file],
            )

        with gr.Tab("Ablation Test"):
            with gr.Row():
                ablation_model_name = gr.Textbox(value="Qwen/Qwen2.5-0.5B", label="Model")
                ablation_run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison")
            # Ablation families understood by _build_ablation_scenarios.
            selected_ablations = gr.CheckboxGroup(
                choices=["optimizer", "lr", "schedule", "steps", "data", "lora"],
                value=["optimizer"],
                label="Ablations",
            )
            with gr.Row():
                ablation_max_steps = gr.Slider(1, 1000, value=200, step=1, label="Base optimizer steps")
                ablation_eval_every = gr.Slider(1, 200, value=50, step=1, label="Eval every")
                ablation_eval_batches = gr.Slider(1, 128, value=8, step=1, label="Eval batches")
            with gr.Row():
                ablation_seq_len = gr.Dropdown([64, 128, 256, 512], value=64, label="Sequence length")
                ablation_batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch size")
                ablation_lr = gr.Number(value=5e-4, label="Base learning rate")
            with gr.Row():
                ablation_train_chars = gr.Slider(5_000, 2_000_000, value=20_000, step=5_000, label="Train char cap")
                ablation_eval_chars = gr.Slider(1_000, 500_000, value=8_000, step=1_000, label="Eval char cap")
                ablation_seed = gr.Number(value=42, precision=0, label="Seed")
            with gr.Row():
                ablation_schedule_mode = gr.Dropdown(["constant", "cosine"], value="constant", label="Base schedule")
                ablation_warmup_steps = gr.Slider(0, 100, value=10, step=1, label="Warmup steps")
            with gr.Row():
                ablation_full_wikitext_train = gr.Checkbox(value=False, label="Full WikiText train")
                ablation_full_wikitext_eval = gr.Checkbox(value=False, label="Full WikiText eval")
                ablation_full_validation_ppl = gr.Checkbox(value=False, label="Full validation PPL")
            # Sweep values are entered as comma-separated text and parsed by the runner.
            with gr.Row():
                lr_sweep_text = gr.Textbox(value="1e-3, 5e-4", label="LR sweep")
                step_sweep_text = gr.Textbox(value="100, 200", label="Step sweep")
                lora_r_sweep_text = gr.Textbox(value="4, 8, 16", label="LoRA r sweep")
            ablation_run_button = gr.Button("Run Ablation Test", variant="primary")
            ablation_summary = gr.Markdown()
            ablation_json_file = gr.File(label="Raw JSON")
            ablation_metrics_file = gr.File(label="Metrics CSV")
            ablation_gains_file = gr.File(label="Gains CSV")

            # The input order must match the positional parameters of run_ablation_test.
            ablation_run_button.click(
                fn=run_ablation_test,
                inputs=[
                    ablation_model_name,
                    selected_ablations,
                    ablation_run_lbw_guard,
                    ablation_max_steps,
                    ablation_eval_every,
                    ablation_eval_batches,
                    ablation_seq_len,
                    ablation_batch_size,
                    ablation_train_chars,
                    ablation_eval_chars,
                    ablation_full_wikitext_train,
                    ablation_full_wikitext_eval,
                    ablation_full_validation_ppl,
                    ablation_lr,
                    ablation_schedule_mode,
                    ablation_warmup_steps,
                    ablation_seed,
                    lr_sweep_text,
                    step_sweep_text,
                    lora_r_sweep_text,
                ],
                outputs=[ablation_summary, ablation_json_file, ablation_metrics_file, ablation_gains_file],
            )
1346
 
1347
 
1348
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  torch
2
- transformers
3
- datasets
4
- peft
5
- accelerate
 
6
  lbw-guard==1.1.3
 
1
  torch
2
+ transformers>=4.45
3
+ datasets>=2.20
4
+ peft>=0.12
5
+ accelerate>=0.33
6
+ sentencepiece
7
  lbw-guard==1.1.3