Radianis commited on
Commit
9e047c8
·
0 Parent(s):

Add LBW Guard direct runner Space

Browse files
Files changed (5) hide show
  1. .gitignore +7 -0
  2. README.md +31 -0
  3. _demo_runtime.py +1441 -0
  4. app.py +329 -0
  5. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ .DS_Store
4
+ .env
5
+ runs/
6
+ .hf_cache/
7
+ .wandb/
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LBW Guard Direct Runner
3
+ emoji: 🚀
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ python_version: "3.10"
8
+ app_file: app.py
9
+ suggested_hardware: t4-medium
10
+ models:
11
+ - Qwen/Qwen2.5-0.5B
12
+ datasets:
13
+ - Salesforce/wikitext
14
+ tags:
15
+ - optimizer
16
+ - training
17
+ - gradio
18
+ - gpu
19
+ ---
20
+
21
+ Copyright (c) Qluon Inc. All rights reserved.
22
+
23
+ Provided for Learn-By-Wire Guard evaluation and customer testing under the applicable Qluon license terms.
24
+
25
+ # LBW Guard Direct Runner
26
+
27
+ This Space runs a compact AdamW vs `lbw_guard` WikiText LoRA smoke test directly on Hugging Face Spaces.
28
+
29
+ Use GPU hardware for meaningful runtime. CPU can load the app, but model training may be slow or fail on memory.
30
+
31
+ The app writes run artifacts to the Space working directory. Add persistent storage if you need outputs to survive Space restarts.
_demo_runtime.py ADDED
@@ -0,0 +1,1441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Standalone customer demo runtime decoupled from the internal benchmark harness."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import importlib.util
7
+ import json
8
+ import math
9
+ import os
10
+ import random
11
+ import shutil
12
+ import statistics
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+ import time
17
+ import warnings
18
+ from array import array
19
+ from collections import Counter, deque
20
+ from dataclasses import dataclass, field
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Optional, Sequence
23
+
24
+ AUTOMATION_DIR = Path(__file__).resolve().parent
25
+ TEST_ROOT = AUTOMATION_DIR.parent
26
+ LBW_ROOT = TEST_ROOT.parent
27
+
28
+ _hf_home = os.environ.setdefault("HF_HOME", str((LBW_ROOT / ".hf_cache").resolve()))
29
+ os.environ.setdefault("HF_DATASETS_CACHE", str((Path(_hf_home) / "datasets").resolve()))
30
+ os.environ.setdefault("TRANSFORMERS_CACHE", str((Path(_hf_home) / "transformers").resolve()))
31
+ # Prevent background safetensors conversion threads from calling the HF
32
+ # conversion Space during demo loads. This keeps repo-hosted .bin models
33
+ # usable without noisy thread crashes when the service/network misbehaves.
34
+ os.environ.setdefault("DISABLE_SAFETENSORS_CONVERSION", "1")
35
+
36
+ _wandb_home = LBW_ROOT / ".wandb"
37
+ os.environ.setdefault("WANDB_DIR", str(_wandb_home.resolve()))
38
+ os.environ.setdefault("WANDB_CACHE_DIR", str((_wandb_home / "cache").resolve()))
39
+ os.environ.setdefault("WANDB_CONFIG_DIR", str((_wandb_home / "config").resolve()))
40
+
41
+ import torch
42
+ from datasets import load_dataset
43
+ from peft import LoraConfig, TaskType, get_peft_model
44
+ from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
45
+
46
+ try:
47
+ import wandb
48
+ except Exception:
49
+ wandb = None
50
+
51
+ try:
52
+ from lbw import Guard
53
+ except Exception:
54
+ Guard = None
55
+
56
+
57
+ @dataclass
58
+ class BenchmarkConfig:
59
+ model_name: str = "Qwen/Qwen2.5-3B"
60
+ device: str = "cuda"
61
+ enable_lora: bool = True
62
+ lora_r: int = 16
63
+ lora_alpha: int = 64
64
+ lora_dropout: float = 0.05
65
+ lora_target_modules: List[str] = field(
66
+ default_factory=lambda: [
67
+ "q_proj",
68
+ "k_proj",
69
+ "v_proj",
70
+ "o_proj",
71
+ "gate_proj",
72
+ "up_proj",
73
+ "down_proj",
74
+ ]
75
+ )
76
+ seq_len: int = 256
77
+ batch_size: int = 2
78
+ grad_accum: int = 2
79
+ max_steps: int = 100
80
+ warmup_steps: int = 50
81
+ eval_every: int = 50
82
+ eval_batches: int = 50
83
+ schedule_mode: str = "all_cosine"
84
+ max_chars: int = 4_000_000
85
+ eval_chars: int = 1_000_000
86
+ full_wikitext_train: bool = False
87
+ full_wikitext_eval: bool = False
88
+ full_validation_ppl: bool = False
89
+ lr: float = 5e-4
90
+ weight_decay: float = 0.01
91
+ betas: tuple[float, float] = (0.9, 0.999)
92
+ lbw_stats_freq: int = 50
93
+ lbw_stress_th: float = 1.1 #1.1
94
+ lbw_spike_th: float = 1.5 #1.5
95
+ lbw_rec_fast: float = 0.05 #0.01
96
+ lbw_ema_decay: float = 0.95
97
+ use_wandb: bool = False
98
+ use_lbwgov: bool = False
99
+ print_all_metrics: bool = False
100
+ lbwgov_experiment_name: str = "LBW-Customer-Demo"
101
+ output_dir: str = str((AUTOMATION_DIR / "demo_outputs").resolve())
102
+ enable_benchmarks: bool = False
103
+ use_lm_eval: bool = False
104
+ lm_eval_ppl: bool = False
105
+ lm_eval_ppl_task: str = "wikitext_103_raw"
106
+ lm_eval_ppl_limit: Optional[float] = None
107
+ lm_eval_acc: bool = False
108
+ lm_eval_acc_tasks: str = "mmlu,arc_challenge"
109
+ lm_eval_acc_limit: Optional[float] = None
110
+ lm_eval_mmlu_limit: Optional[float] = None
111
+ lm_eval_arc_challenge_limit: Optional[float] = None
112
+ lm_eval_mmlu_fewshot: int = 5
113
+ lm_eval_arc_challenge_fewshot: int = 25
114
+ lm_eval_batch_size: str = "1"
115
+
116
+
117
+ @dataclass
118
+ class ChunkedTokens:
119
+ input_ids: torch.Tensor
120
+ labels: torch.Tensor
121
+ split: str = ""
122
+ char_count: int = 0
123
+ cap_chars: Optional[int] = None
124
+
125
+
126
+ def _demo_log(config: Optional["BenchmarkConfig"], message: str) -> None:
127
+ if config is not None and not bool(getattr(config, "print_all_metrics", False)):
128
+ return
129
+ print(f"[DemoRuntime] {message}", flush=True)
130
+
131
+
132
+ def _safe_float(value: Any) -> Optional[float]:
133
+ if value is None:
134
+ return None
135
+ try:
136
+ out = float(value)
137
+ except Exception:
138
+ return None
139
+ if math.isnan(out) or math.isinf(out):
140
+ return None
141
+ return out
142
+
143
+
144
+ @dataclass
145
+ class GovernanceMetricConfig:
146
+ ema_decay: float = 0.95
147
+ short_window: int = 20
148
+ long_window: int = 100
149
+ intervention_eps: float = 1e-3
150
+ stable_ratio_low: float = 0.90
151
+ stable_ratio_high: float = 1.10
152
+ unstable_ratio_high: float = 1.35
153
+ stagnation_ratio_low: float = 0.70
154
+ oscillation_flip_high: float = 0.30
155
+
156
+
157
+ class GovernanceMetricsTracker:
158
+ def __init__(self, cfg: Optional[GovernanceMetricConfig] = None):
159
+ self.cfg = cfg or GovernanceMetricConfig()
160
+ self.grad_rms_history = deque(maxlen=self.cfg.long_window)
161
+ self.ratio_history = deque(maxlen=self.cfg.long_window)
162
+ self.regime_history = deque(maxlen=self.cfg.long_window)
163
+ self.flip_rate_history = deque(maxlen=self.cfg.long_window)
164
+ self.loss_ema: Optional[float] = None
165
+ self.grad_norm_ema: Optional[float] = None
166
+ self.grad_rms_ema: Optional[float] = None
167
+ self.prev_loss: Optional[float] = None
168
+ self.prev_prev_loss: Optional[float] = None
169
+ self.prev_regime: Optional[str] = None
170
+ self.prev_grad_sign_summary: Optional[tuple[int, int, int]] = None
171
+ self.total_logged_steps = 0
172
+ self.intervention_count = 0
173
+ self.regime_switch_count = 0
174
+ self.stress_entries = 0
175
+ self.total_control_energy = 0.0
176
+ self.max_control_energy = 0.0
177
+ self.open_recovery_start_step: Optional[int] = None
178
+ self.completed_recovery_latencies: List[int] = []
179
+ self.best_eval_loss: Optional[float] = None
180
+ self.best_eval_perplexity: Optional[float] = None
181
+
182
+ def _safe_float(self, value: Any, default: float = 0.0) -> float:
183
+ try:
184
+ out = float(value)
185
+ if math.isfinite(out):
186
+ return out
187
+ except Exception:
188
+ pass
189
+ return float(default)
190
+
191
+ def _ema_update(self, old: Optional[float], new: Optional[float]) -> Optional[float]:
192
+ if new is None:
193
+ return old
194
+ if old is None:
195
+ return float(new)
196
+ d = self.cfg.ema_decay
197
+ return float(d * old + (1.0 - d) * float(new))
198
+
199
+ def _std(self, values) -> float:
200
+ vals = [float(v) for v in values if v is not None and math.isfinite(float(v))]
201
+ if len(vals) < 2:
202
+ return 0.0
203
+ return float(statistics.pstdev(vals))
204
+
205
+ def _mean(self, values) -> float:
206
+ vals = [float(v) for v in values if v is not None and math.isfinite(float(v))]
207
+ if not vals:
208
+ return 0.0
209
+ return float(sum(vals) / len(vals))
210
+
211
+ def _compute_grad_sign_flip_rate(self, params) -> float:
212
+ pos = 0
213
+ neg = 0
214
+ zero = 0
215
+ for param in params:
216
+ if param.grad is None:
217
+ continue
218
+ grad = param.grad.detach()
219
+ if grad.numel() == 0:
220
+ continue
221
+ signs = torch.sign(grad)
222
+ pos += int((signs > 0).sum().item())
223
+ neg += int((signs < 0).sum().item())
224
+ zero += int((signs == 0).sum().item())
225
+ summary = (pos, neg, zero)
226
+ if self.prev_grad_sign_summary is None:
227
+ self.prev_grad_sign_summary = summary
228
+ return 0.0
229
+ prev_pos, prev_neg, _ = self.prev_grad_sign_summary
230
+ prev_total = max(prev_pos + prev_neg, 1)
231
+ cur_total = max(pos + neg, 1)
232
+ flip_rate = abs((pos / cur_total) - (prev_pos / prev_total))
233
+ self.prev_grad_sign_summary = summary
234
+ return float(flip_rate)
235
+
236
+ def classify_regime(self, *, ratio: float, flip_rate: float, loss_velocity: float, scale: float, stress_mode: str) -> str:
237
+ ratio = self._safe_float(ratio, 1.0)
238
+ flip_rate = self._safe_float(flip_rate, 0.0)
239
+ loss_velocity = self._safe_float(loss_velocity, 0.0)
240
+ scale = self._safe_float(scale, 1.0)
241
+ stress_mode = str(stress_mode or "unknown").lower()
242
+ if "stress" in stress_mode or ratio >= self.cfg.unstable_ratio_high:
243
+ return "unstable"
244
+ if flip_rate >= self.cfg.oscillation_flip_high:
245
+ return "oscillatory"
246
+ if ratio <= self.cfg.stagnation_ratio_low and abs(loss_velocity) < 1e-4:
247
+ return "stagnation"
248
+ if (self.cfg.stable_ratio_low <= ratio <= self.cfg.stable_ratio_high) and abs(scale - 1.0) <= 0.05:
249
+ return "stable"
250
+ return "transitional"
251
+
252
+ def update_step(
253
+ self,
254
+ *,
255
+ step: int,
256
+ trainable_params,
257
+ loss_val: float,
258
+ grad_norm: float,
259
+ grad_rms: float,
260
+ ema_grad_rms: float,
261
+ ratio: float,
262
+ scale: float,
263
+ stress_mode: str,
264
+ current_lr: float,
265
+ ) -> Dict[str, float]:
266
+ self.total_logged_steps += 1
267
+ loss_val = self._safe_float(loss_val)
268
+ grad_norm = self._safe_float(grad_norm)
269
+ grad_rms = self._safe_float(grad_rms)
270
+ ema_grad_rms = self._safe_float(ema_grad_rms)
271
+ ratio = self._safe_float(ratio, 1.0)
272
+ scale = self._safe_float(scale, 1.0)
273
+ current_lr = self._safe_float(current_lr)
274
+
275
+ self.loss_ema = self._ema_update(self.loss_ema, loss_val)
276
+ self.grad_norm_ema = self._ema_update(self.grad_norm_ema, grad_norm)
277
+ self.grad_rms_ema = self._ema_update(self.grad_rms_ema, grad_rms)
278
+
279
+ loss_velocity = 0.0 if self.prev_loss is None else (loss_val - self.prev_loss)
280
+ loss_acceleration = 0.0 if self.prev_loss is None or self.prev_prev_loss is None else (
281
+ loss_val - 2.0 * self.prev_loss + self.prev_prev_loss
282
+ )
283
+ flip_rate = self._compute_grad_sign_flip_rate(trainable_params)
284
+ grad_deviation = 0.0
285
+ if ema_grad_rms > 0:
286
+ grad_deviation = (grad_rms - ema_grad_rms) / max(ema_grad_rms, 1e-12)
287
+ control_energy = abs(scale - 1.0)
288
+ intervention_flag = 1.0 if control_energy > self.cfg.intervention_eps else 0.0
289
+ if intervention_flag > 0:
290
+ self.intervention_count += 1
291
+ self.total_control_energy += control_energy
292
+ self.max_control_energy = max(self.max_control_energy, control_energy)
293
+
294
+ regime = self.classify_regime(
295
+ ratio=ratio,
296
+ flip_rate=flip_rate,
297
+ loss_velocity=loss_velocity,
298
+ scale=scale,
299
+ stress_mode=stress_mode,
300
+ )
301
+ if self.prev_regime is not None and regime != self.prev_regime:
302
+ self.regime_switch_count += 1
303
+ if regime in {"unstable", "oscillatory"} and self.open_recovery_start_step is None:
304
+ self.open_recovery_start_step = step
305
+ self.stress_entries += 1
306
+ if regime == "stable" and self.open_recovery_start_step is not None:
307
+ self.completed_recovery_latencies.append(step - self.open_recovery_start_step)
308
+ self.open_recovery_start_step = None
309
+
310
+ self.grad_rms_history.append(grad_rms)
311
+ self.ratio_history.append(ratio)
312
+ self.regime_history.append(regime)
313
+ self.flip_rate_history.append(flip_rate)
314
+
315
+ short_grad_std = self._std(list(self.grad_rms_history)[-self.cfg.short_window :])
316
+ long_grad_std = self._std(self.grad_rms_history)
317
+ grad_variance_reduction = 0.0
318
+ if long_grad_std > 1e-12:
319
+ grad_variance_reduction = 1.0 - (short_grad_std / long_grad_std)
320
+
321
+ out = {
322
+ "obs/grad_direction_change_rate": flip_rate,
323
+ "obs/loss_velocity": loss_velocity,
324
+ "obs/loss_acceleration": loss_acceleration,
325
+ "obs/update_magnitude_proxy": scale * current_lr,
326
+ "state/grad_ratio": ratio,
327
+ "state/grad_deviation_score": grad_deviation,
328
+ "state/regime_stable": 1.0 if regime == "stable" else 0.0,
329
+ "state/regime_unstable": 1.0 if regime == "unstable" else 0.0,
330
+ "state/regime_oscillatory": 1.0 if regime == "oscillatory" else 0.0,
331
+ "state/regime_stagnation": 1.0 if regime == "stagnation" else 0.0,
332
+ "state/regime_transitional": 1.0 if regime == "transitional" else 0.0,
333
+ "control/action_strength": control_energy,
334
+ "control/intervention_flag": intervention_flag,
335
+ "loop/intervention_rate": self.intervention_count / max(self.total_logged_steps, 1),
336
+ "loop/regime_switch_count": float(self.regime_switch_count),
337
+ "loop/avg_control_energy": self.total_control_energy / max(self.total_logged_steps, 1),
338
+ "loop/max_control_energy": self.max_control_energy,
339
+ "effect/grad_variance_reduction": grad_variance_reduction,
340
+ "effect/recovery_latency_mean_steps": self._mean(self.completed_recovery_latencies),
341
+ "effect/recovery_events": float(len(self.completed_recovery_latencies)),
342
+ }
343
+
344
+ self.prev_prev_loss = self.prev_loss
345
+ self.prev_loss = loss_val
346
+ self.prev_regime = regime
347
+ return out
348
+
349
+ def update_eval(
350
+ self,
351
+ *,
352
+ eval_loss: Optional[float] = None,
353
+ eval_perplexity: Optional[float] = None,
354
+ avg_tps_wall: Optional[float] = None,
355
+ ) -> Dict[str, float]:
356
+ out: Dict[str, float] = {}
357
+ if eval_loss is not None:
358
+ if self.best_eval_loss is None or eval_loss < self.best_eval_loss:
359
+ self.best_eval_loss = float(eval_loss)
360
+ out["effect/best_eval_loss"] = float(self.best_eval_loss)
361
+ out["effect/eval_loss_gap_to_best"] = float(eval_loss - self.best_eval_loss)
362
+ if eval_perplexity is not None:
363
+ if self.best_eval_perplexity is None or eval_perplexity < self.best_eval_perplexity:
364
+ self.best_eval_perplexity = float(eval_perplexity)
365
+ out["effect/best_eval_perplexity"] = float(self.best_eval_perplexity)
366
+ out["effect/eval_perplexity_gap_to_best"] = float(eval_perplexity - self.best_eval_perplexity)
367
+ if avg_tps_wall is not None:
368
+ out["effect/efficiency_wall_tps"] = float(avg_tps_wall)
369
+ return out
370
+
371
+ def snapshot(self) -> Dict[str, Any]:
372
+ return {
373
+ "total_logged_steps": self.total_logged_steps,
374
+ "intervention_count": self.intervention_count,
375
+ "regime_switch_count": self.regime_switch_count,
376
+ "stress_entries": self.stress_entries,
377
+ "avg_control_energy": self.total_control_energy / max(self.total_logged_steps, 1),
378
+ "max_control_energy": self.max_control_energy,
379
+ "completed_recovery_latencies": list(self.completed_recovery_latencies),
380
+ "recent_regimes": list(self.regime_history),
381
+ "best_eval_loss": self.best_eval_loss,
382
+ "best_eval_perplexity": self.best_eval_perplexity,
383
+ }
384
+
385
+
386
+ def _wants_cuda(device: Optional[str] = None) -> bool:
387
+ return str(device or "").strip().lower().startswith("cuda")
388
+
389
+
390
+ def set_seed(seed: int, device: Optional[str] = None):
391
+ random.seed(seed)
392
+ torch.manual_seed(seed)
393
+ if _wants_cuda(device) and torch.cuda.is_available():
394
+ torch.cuda.manual_seed_all(seed)
395
+
396
+
397
+ def normalize_optimizer_name(name: str) -> str:
398
+ aliases = {
399
+ "guard": "lbw_guard",
400
+ "lbw": "lbw_guard",
401
+ "lbw-guard": "lbw_guard",
402
+ "adam": "adamw",
403
+ }
404
+ key = str(name or "").strip().lower()
405
+ return aliases.get(key, key)
406
+
407
+
408
+ def check_optimizer_support(name: str, device: Optional[str] = None) -> tuple[bool, str]:
409
+ normalized = normalize_optimizer_name(name)
410
+ if normalized not in {"adamw", "lbw_guard"}:
411
+ return False, "Standalone customer demo runtime supports only adamw and lbw_guard."
412
+ if normalized == "lbw_guard" and Guard is None:
413
+ return False, "LBW_Guard package not found. Install the standard LBW_Guard package in the active Python environment."
414
+ if normalized == "lbw_guard" and _wants_cuda(device) and torch.cuda.is_available() and int(torch.cuda.device_count()) > 1:
415
+ return False, "lbw_guard supports at most 1 visible GPU. Restrict CUDA_VISIBLE_DEVICES to one GPU."
416
+ return True, ""
417
+
418
+
419
+ def _hf_offline_mode() -> bool:
420
+ return (
421
+ os.environ.get("HF_HUB_OFFLINE", "").lower() in {"1", "true", "yes"}
422
+ or os.environ.get("TRANSFORMERS_OFFLINE", "").lower() in {"1", "true", "yes"}
423
+ )
424
+
425
+
426
+ def _hf_pretrained_kwargs() -> Dict[str, Any]:
427
+ kwargs: Dict[str, Any] = {"trust_remote_code": True}
428
+ if _hf_offline_mode():
429
+ kwargs["local_files_only"] = True
430
+ return kwargs
431
+
432
+
433
+ def _resolve_model_dtype(device: torch.device):
434
+ return torch.bfloat16 if device.type == "cuda" else torch.float32
435
+
436
+
437
+ def _resolve_model_device_map(device: torch.device):
438
+ if device.type != "cuda":
439
+ return None
440
+ return {"": (device.index if device.index is not None else 0)}
441
+
442
+
443
+ def _load_tokenizer_and_model(model_name: str, device: torch.device):
444
+ hf_kwargs = _hf_pretrained_kwargs()
445
+ tokenizer = AutoTokenizer.from_pretrained(model_name, **hf_kwargs)
446
+ if tokenizer.pad_token is None:
447
+ tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
448
+ model_kwargs: Dict[str, Any] = {
449
+ "torch_dtype": _resolve_model_dtype(device),
450
+ **hf_kwargs,
451
+ }
452
+ device_map = _resolve_model_device_map(device)
453
+ if device_map is not None:
454
+ model_kwargs["device_map"] = device_map
455
+ model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
456
+ if device_map is None:
457
+ model.to(device)
458
+ return tokenizer, model
459
+
460
+
461
+ def build_wikitext_chunks(
462
+ tokenizer,
463
+ seq_len: int,
464
+ max_chars: Optional[int],
465
+ split: str,
466
+ *,
467
+ config: Optional[BenchmarkConfig] = None,
468
+ ) -> ChunkedTokens:
469
+ cap = None if max_chars is None else int(max_chars)
470
+ _demo_log(
471
+ config,
472
+ f"Preparing WikiText split='{split}'"
473
+ + (f" with char cap {cap:,}" if cap is not None else " with full split"),
474
+ )
475
+ ds = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)
476
+ token_buf = array("I")
477
+ chars_used = 0
478
+ first_piece = True
479
+ rows_used = 0
480
+ next_report_chars = 500_000 if cap is None else max(250_000, cap // 4)
481
+ for row in ds:
482
+ text = str(row.get("text", "") or "")
483
+ if not text.strip():
484
+ continue
485
+ piece = text if first_piece else (" " + text)
486
+ if cap is not None:
487
+ remain = cap - chars_used
488
+ if remain <= 0:
489
+ break
490
+ if len(piece) > remain:
491
+ piece = piece[:remain]
492
+ chars_used += len(piece)
493
+ first_piece = False
494
+ rows_used += 1
495
+ ids_piece = tokenizer(piece, add_special_tokens=False)["input_ids"]
496
+ if ids_piece:
497
+ token_buf.extend(ids_piece)
498
+ if config is not None and bool(getattr(config, "print_all_metrics", False)) and chars_used >= next_report_chars:
499
+ target = f"/{cap:,}" if cap is not None else ""
500
+ _demo_log(config, f"Tokenizing split='{split}': {chars_used:,}{target} chars")
501
+ next_report_chars += 500_000 if cap is None else max(250_000, cap // 4)
502
+ if cap is not None and chars_used >= cap:
503
+ break
504
+ if len(token_buf) == 0:
505
+ raise RuntimeError(f"No tokens built for split '{split}'.")
506
+ ids = torch.tensor(token_buf, dtype=torch.long)
507
+ n = ids.numel() // seq_len
508
+ if n <= 0:
509
+ raise RuntimeError(f"Not enough tokens for seq_len {seq_len}. Increase max_chars.")
510
+ ids = ids[: n * seq_len].view(n, seq_len).contiguous()
511
+ _demo_log(
512
+ config,
513
+ f"Prepared split='{split}': {chars_used:,} chars across {rows_used:,} rows -> {ids.size(0):,} sequences of len {seq_len}",
514
+ )
515
+ return ChunkedTokens(input_ids=ids, labels=ids, split=split, char_count=int(chars_used), cap_chars=cap)
516
+
517
+
518
+ def batch_iter(chunks: ChunkedTokens, batch_size: int, device: torch.device):
519
+ x, y = chunks.input_ids, chunks.labels
520
+ i, n = 0, x.size(0)
521
+ while True:
522
+ if i + batch_size > n:
523
+ i = 0
524
+ yield (
525
+ x[i : i + batch_size].to(device, non_blocking=True),
526
+ y[i : i + batch_size].to(device, non_blocking=True),
527
+ )
528
+ i += batch_size
529
+
530
+
531
+ def evaluate_perplexity(model, eval_chunks: ChunkedTokens, config: BenchmarkConfig, device: torch.device, *, full_pass: bool = False):
532
+ model.eval()
533
+ total_nll = 0.0
534
+ total_tokens = 0
535
+ with torch.no_grad():
536
+ if full_pass:
537
+ x, y = eval_chunks.input_ids, eval_chunks.labels
538
+ n = x.size(0)
539
+ for i in range(0, n, config.batch_size):
540
+ ex = x[i : i + config.batch_size].to(device, non_blocking=True)
541
+ ey = y[i : i + config.batch_size].to(device, non_blocking=True)
542
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")):
543
+ out = model(input_ids=ex, labels=ey)
544
+ tok = int(ey[:, 1:].numel())
545
+ if tok > 0 and math.isfinite(float(out.loss.item())):
546
+ total_nll += float(out.loss.item()) * float(tok)
547
+ total_tokens += tok
548
+ else:
549
+ eval_iter = batch_iter(eval_chunks, config.batch_size, device)
550
+ for _ in range(max(1, int(config.eval_batches))):
551
+ ex, ey = next(eval_iter)
552
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")):
553
+ out = model(input_ids=ex, labels=ey)
554
+ tok = int(ey[:, 1:].numel())
555
+ if tok > 0 and math.isfinite(float(out.loss.item())):
556
+ total_nll += float(out.loss.item()) * float(tok)
557
+ total_tokens += tok
558
+ if total_tokens <= 0:
559
+ raise RuntimeError("Validation produced no batches.")
560
+ avg_eval_loss = float(total_nll / float(total_tokens))
561
+ return avg_eval_loss, math.exp(avg_eval_loss)
562
+
563
+
564
+ def _compute_grad_rms(params) -> float:
565
+ sq_sum = 0.0
566
+ count = 0
567
+ for param in params:
568
+ if param.grad is None:
569
+ continue
570
+ grad = param.grad.detach()
571
+ sq_sum += float(torch.sum(grad.float() * grad.float()).item())
572
+ count += int(grad.numel())
573
+ if count <= 0:
574
+ return 0.0
575
+ return float(math.sqrt(sq_sum / float(count)))
576
+
577
+
578
+ def get_optimizer(name: str, model, config: BenchmarkConfig):
579
+ params = [p for p in model.parameters() if p.requires_grad]
580
+ normalized = normalize_optimizer_name(name)
581
+ if normalized == "adamw":
582
+ return torch.optim.AdamW(
583
+ params,
584
+ lr=config.lr,
585
+ betas=config.betas,
586
+ weight_decay=config.weight_decay,
587
+ )
588
+ if normalized == "lbw_guard":
589
+ if Guard is None:
590
+ raise RuntimeError("LBW Guard package not available for lbw_guard.")
591
+ return Guard(
592
+ params,
593
+ lr=config.lr,
594
+ betas=config.betas,
595
+ weight_decay=config.weight_decay,
596
+ mode="eval",
597
+ auto_enabled=True,
598
+ stats_freq=int(config.lbw_stats_freq),
599
+ stress_threshold=config.lbw_stress_th,
600
+ spike_threshold=config.lbw_spike_th,
601
+ recovery_fast=config.lbw_rec_fast,
602
+ ema_decay=config.lbw_ema_decay,
603
+ use_max_rms=True,
604
+ )
605
+ raise ValueError(f"Unsupported optimizer for standalone demo runtime: {name}")
606
+
607
+
608
+ class _SchedulerProxyOptimizer(torch.optim.Optimizer):
609
+ def __init__(self, param_groups: List[Dict[str, Any]]):
610
+ proxy_groups = []
611
+ for group in list(param_groups or []):
612
+ proxy_groups.append({"params": list(group.get("params", []) or []), "lr": float(group.get("lr", 0.0))})
613
+ super().__init__(proxy_groups, defaults={})
614
+ self.param_groups = param_groups
615
+
616
+ def step(self, closure=None):
617
+ del closure
618
+ return None
619
+
620
+
621
+ def _build_scheduler_proxy_for_optimizer_like(opt: Any) -> Optional[torch.optim.Optimizer]:
622
+ param_groups = getattr(opt, "param_groups", None)
623
+ if not isinstance(param_groups, list) or not param_groups:
624
+ return None
625
+ try:
626
+ return _SchedulerProxyOptimizer(param_groups)
627
+ except Exception:
628
+ return None
629
+
630
+
631
+ def _pick_lm_eval_perplexity_metrics(task_result: Dict[str, Any]) -> Dict[str, float]:
632
+ out: Dict[str, float] = {}
633
+ if not isinstance(task_result, dict):
634
+ return out
635
+ key_map = (
636
+ ("word_perplexity,none", "word_perplexity"),
637
+ ("word_perplexity", "word_perplexity"),
638
+ ("perplexity,none", "perplexity"),
639
+ ("perplexity", "perplexity"),
640
+ ("bits_per_byte,none", "bits_per_byte"),
641
+ ("bits_per_byte", "bits_per_byte"),
642
+ )
643
+ for src, dst in key_map:
644
+ value = task_result.get(src, None)
645
+ if isinstance(value, (int, float)):
646
+ out[dst] = float(value)
647
+ return out
648
+
649
+
650
+ def _pick_lm_eval_accuracy_metrics(task_result: Dict[str, Any]) -> Dict[str, float]:
651
+ out: Dict[str, float] = {}
652
+ if not isinstance(task_result, dict):
653
+ return out
654
+ key_map = (
655
+ ("acc_norm,none", "acc_norm"),
656
+ ("acc_norm", "acc_norm"),
657
+ ("acc,none", "acc"),
658
+ ("acc", "acc"),
659
+ ("exact_match,none", "exact_match"),
660
+ ("exact_match", "exact_match"),
661
+ )
662
+ for src, dst in key_map:
663
+ value = task_result.get(src, None)
664
+ if isinstance(value, (int, float)):
665
+ out[dst] = float(value)
666
+ return out
667
+
668
+
669
+ def _normalize_lm_eval_task_name(task_name: str) -> str:
670
+ task = str(task_name or "").strip()
671
+ if not task:
672
+ return task
673
+ normalized = task.lower().replace("-", "_")
674
+ aliases = {
675
+ "wikitext_103_raw": "wikitext_103_raw",
676
+ "wikitext103_raw": "wikitext_103_raw",
677
+ "wikitext103raw": "wikitext_103_raw",
678
+ "wikitext_103": "wikitext_103_raw",
679
+ "wikitext103": "wikitext_103_raw",
680
+ "wikitext_103_raw_v1": "wikitext_103_raw",
681
+ "wikitext": "wikitext",
682
+ "paloma_wikitext_103": "paloma_wikitext_103",
683
+ "mmlu": "mmlu",
684
+ "hendrycks_test": "mmlu",
685
+ "arc": "arc_challenge",
686
+ "arc_challenge": "arc_challenge",
687
+ "arcchallenge": "arc_challenge",
688
+ }
689
+ return aliases.get(normalized, normalized)
690
+
691
+
692
+ def _parse_lm_eval_task_list(raw_tasks: Any) -> List[str]:
693
+ if raw_tasks is None:
694
+ return []
695
+ if isinstance(raw_tasks, str):
696
+ items = [part.strip() for part in raw_tasks.split(",")]
697
+ else:
698
+ items = [str(part).strip() for part in raw_tasks]
699
+ out: List[str] = []
700
+ seen = set()
701
+ for item in items:
702
+ if not item:
703
+ continue
704
+ normalized = _normalize_lm_eval_task_name(item)
705
+ if normalized and normalized not in seen:
706
+ seen.add(normalized)
707
+ out.append(normalized)
708
+ return out
709
+
710
+
711
+ def _lm_eval_num_fewshot_for_task(config: BenchmarkConfig, task_name: str) -> int:
712
+ normalized = _normalize_lm_eval_task_name(task_name)
713
+ if normalized == "mmlu":
714
+ return int(getattr(config, "lm_eval_mmlu_fewshot", 5))
715
+ if normalized == "arc_challenge":
716
+ return int(getattr(config, "lm_eval_arc_challenge_fewshot", 25))
717
+ return 0
718
+
719
+
720
+ def _lm_eval_limit_for_task(config: BenchmarkConfig, task_name: str) -> Optional[float]:
721
+ normalized = _normalize_lm_eval_task_name(task_name)
722
+ legacy_limit = getattr(config, "lm_eval_acc_limit", None)
723
+ if normalized == "mmlu":
724
+ value = getattr(config, "lm_eval_mmlu_limit", legacy_limit)
725
+ elif normalized == "arc_challenge":
726
+ value = getattr(config, "lm_eval_arc_challenge_limit", None)
727
+ else:
728
+ value = legacy_limit
729
+ if value is None:
730
+ return None
731
+ return float(value)
732
+
733
+
734
+ def _load_lm_eval_results(output_path: Path) -> Dict[str, Any]:
735
+ candidates: List[Path] = []
736
+ if output_path.is_file():
737
+ candidates = [output_path]
738
+ elif output_path.is_dir():
739
+ candidates = sorted([p for p in output_path.rglob("*.json") if p.is_file()], key=lambda p: p.stat().st_mtime, reverse=True)
740
+ for cand in candidates:
741
+ try:
742
+ payload = json.loads(cand.read_text())
743
+ if isinstance(payload, dict) and isinstance(payload.get("results", None), dict):
744
+ return payload["results"]
745
+ except Exception:
746
+ continue
747
+ raise RuntimeError(f"Unable to parse lm_eval output from {output_path}")
748
+
749
+
750
+ def _find_lm_eval_task_result(raw_results: Dict[str, Any], task_name: str) -> Dict[str, Any]:
751
+ task_key = task_name if task_name in raw_results else next(
752
+ (k for k in raw_results if k == task_name or k.startswith(task_name)),
753
+ None,
754
+ )
755
+ if task_key is None:
756
+ raise RuntimeError(f"lm_eval returned no results for '{task_name}'.")
757
+ task_result = raw_results.get(task_key, {})
758
+ if not isinstance(task_result, dict):
759
+ raise RuntimeError(f"lm_eval returned malformed results for '{task_name}'.")
760
+ return task_result
761
+
762
+
763
+ def _resolve_lm_eval_command() -> List[str]:
764
+ lm_eval_bin = shutil.which("lm_eval")
765
+ if lm_eval_bin:
766
+ return [lm_eval_bin, "run"]
767
+ if importlib.util.find_spec("lm_eval") is not None:
768
+ return [sys.executable, "-m", "lm_eval", "run"]
769
+ raise RuntimeError("lm_eval not found. Install EleutherAI lm-evaluation-harness in your venv.")
770
+
771
+
772
+ def _resolve_lm_eval_include_path() -> Optional[str]:
773
+ paths: List[str] = []
774
+ for local_tasks in (AUTOMATION_DIR / "lm_eval_tasks", TEST_ROOT / "lm_eval_tasks"):
775
+ if local_tasks.exists():
776
+ paths.append(str(local_tasks))
777
+ env_paths = str(os.environ.get("LM_EVAL_INCLUDE_PATH", "") or "").strip()
778
+ if env_paths:
779
+ for path in env_paths.split(":"):
780
+ path = path.strip()
781
+ if path:
782
+ paths.append(path)
783
+ return ":".join(paths) if paths else None
784
+
785
+
786
+ def _run_lm_eval_with_retry(cmd: List[str], batch_size_value: str) -> None:
787
+ try:
788
+ subprocess.run(cmd, check=True)
789
+ return
790
+ except subprocess.CalledProcessError as exc:
791
+ bs = str(batch_size_value or "").strip().lower()
792
+ is_auto = bs.startswith("auto")
793
+ if (not is_auto) or ("out of memory" not in str(exc).lower() and "oom" not in str(exc).lower()):
794
+ raise
795
+ retry_cmd = list(cmd)
796
+ idx = retry_cmd.index("--batch_size")
797
+ retry_cmd[idx + 1] = "1"
798
+ subprocess.run(retry_cmd, check=True)
799
+
800
+
801
+ def _prepare_adapter_dir(*, model=None, tokenizer=None) -> tuple[Path, tempfile.TemporaryDirectory]:
802
+ if model is None or tokenizer is None:
803
+ raise RuntimeError("model and tokenizer are required for lm_eval PPL.")
804
+ tmp_ctx = tempfile.TemporaryDirectory(prefix="lbw_demo_lmeval_")
805
+ out_dir = Path(tmp_ctx.name) / "peft_adapter"
806
+ model.save_pretrained(str(out_dir))
807
+ tokenizer.save_pretrained(str(out_dir))
808
+ return out_dir, tmp_ctx
809
+
810
+
811
+ def _run_lm_eval_tasks_with_adapter(
812
+ adapter_path: Path,
813
+ *,
814
+ config: BenchmarkConfig,
815
+ device: torch.device,
816
+ tasks: Sequence[str],
817
+ limit: Optional[float],
818
+ output_name: str,
819
+ num_fewshot: int = 0,
820
+ ) -> Dict[str, Any]:
821
+ lm_eval_cmd = _resolve_lm_eval_command()
822
+ include_path = _resolve_lm_eval_include_path()
823
+ normalized_tasks = [_normalize_lm_eval_task_name(task) for task in tasks if str(task).strip()]
824
+ if not normalized_tasks:
825
+ raise RuntimeError("No lm_eval tasks were provided.")
826
+ out_dir = adapter_path.parent / output_name
827
+ out_dir.mkdir(parents=True, exist_ok=True)
828
+ lm_eval_dtype = "bfloat16" if device.type == "cuda" else "float32"
829
+ model_args = [
830
+ f"pretrained={config.model_name}",
831
+ f"peft={adapter_path}",
832
+ f"dtype={lm_eval_dtype}",
833
+ "trust_remote_code=True",
834
+ ]
835
+ batch_size_value = str(getattr(config, "lm_eval_batch_size", "1"))
836
+ cmd = [
837
+ *lm_eval_cmd,
838
+ "--model",
839
+ "hf",
840
+ "--model_args",
841
+ ",".join(model_args),
842
+ "--tasks",
843
+ ",".join(normalized_tasks),
844
+ "--num_fewshot",
845
+ str(int(num_fewshot)),
846
+ "--batch_size",
847
+ batch_size_value,
848
+ "--device",
849
+ ("cuda" if device.type == "cuda" else "cpu"),
850
+ "--output_path",
851
+ str(out_dir),
852
+ ]
853
+ if include_path is not None:
854
+ cmd.extend(["--include_path", include_path])
855
+ if limit is not None:
856
+ cmd.extend(["--limit", str(float(limit))])
857
+ _run_lm_eval_with_retry(cmd, batch_size_value)
858
+ return _load_lm_eval_results(out_dir)
859
+
860
+
861
+ def _summarize_lm_eval_status(
862
+ requested: Dict[str, bool],
863
+ statuses: Dict[str, str],
864
+ errors: Dict[str, Optional[str]],
865
+ ) -> tuple[str, Optional[str]]:
866
+ active = [name for name, enabled in requested.items() if enabled]
867
+ if not active:
868
+ return "disabled", None
869
+ active_statuses = [str(statuses.get(name, "disabled")) for name in active]
870
+ joined_errors = "; ".join(
871
+ f"{name}: {errors[name]}"
872
+ for name in active
873
+ if str(errors.get(name) or "").strip()
874
+ ) or None
875
+ if all(status == "ok" for status in active_statuses):
876
+ return "ok", None
877
+ if any(status == "ok" for status in active_statuses):
878
+ return "partial", joined_errors
879
+ return "skipped", joined_errors
880
+
881
+
882
+ def run_lm_eval_suite(model, tokenizer, config: BenchmarkConfig, device: torch.device) -> Dict[str, Any]:
883
+ requested = {
884
+ "ppl": bool(getattr(config, "lm_eval_ppl", False)),
885
+ "acc": bool(getattr(config, "lm_eval_acc", False)),
886
+ }
887
+ statuses: Dict[str, str] = {
888
+ "ppl": "disabled",
889
+ "acc": "disabled",
890
+ }
891
+ errors: Dict[str, Optional[str]] = {
892
+ "ppl": None,
893
+ "acc": None,
894
+ }
895
+ out: Dict[str, Any] = {}
896
+ adapter_path, tmp_ctx = _prepare_adapter_dir(model=model, tokenizer=tokenizer)
897
+ try:
898
+ if requested["ppl"]:
899
+ statuses["ppl"] = "requested"
900
+ ppl_task = _normalize_lm_eval_task_name(getattr(config, "lm_eval_ppl_task", "wikitext_103_raw"))
901
+ try:
902
+ raw_results = _run_lm_eval_tasks_with_adapter(
903
+ adapter_path,
904
+ config=config,
905
+ device=device,
906
+ tasks=[ppl_task],
907
+ limit=getattr(config, "lm_eval_ppl_limit", None),
908
+ output_name="lm_eval_ppl_out",
909
+ num_fewshot=0,
910
+ )
911
+ ppl_metrics = _pick_lm_eval_perplexity_metrics(_find_lm_eval_task_result(raw_results, ppl_task))
912
+ if not ppl_metrics:
913
+ raise RuntimeError(f"No perplexity-like metrics found for '{ppl_task}'.")
914
+ if "word_perplexity" in ppl_metrics:
915
+ out["lm_eval/final_word_perplexity"] = float(ppl_metrics["word_perplexity"])
916
+ out["final_eval/perplexity_lm_eval"] = float(ppl_metrics["word_perplexity"])
917
+ if "perplexity" in ppl_metrics:
918
+ out["lm_eval/final_perplexity"] = float(ppl_metrics["perplexity"])
919
+ out.setdefault("final_eval/perplexity_lm_eval", float(ppl_metrics["perplexity"]))
920
+ if "bits_per_byte" in ppl_metrics:
921
+ out["lm_eval/final_bits_per_byte"] = float(ppl_metrics["bits_per_byte"])
922
+ statuses["ppl"] = "ok"
923
+ except Exception as exc:
924
+ statuses["ppl"] = "skipped"
925
+ errors["ppl"] = str(exc).strip() or type(exc).__name__
926
+
927
+ if requested["acc"]:
928
+ statuses["acc"] = "requested"
929
+ acc_tasks = _parse_lm_eval_task_list(getattr(config, "lm_eval_acc_tasks", "mmlu,arc_challenge"))
930
+ if not acc_tasks:
931
+ statuses["acc"] = "skipped"
932
+ errors["acc"] = "No lm_eval accuracy tasks configured."
933
+ else:
934
+ try:
935
+ acc_out: Dict[str, Any] = {}
936
+ for task_name in acc_tasks:
937
+ raw_results = _run_lm_eval_tasks_with_adapter(
938
+ adapter_path,
939
+ config=config,
940
+ device=device,
941
+ tasks=[task_name],
942
+ limit=_lm_eval_limit_for_task(config, task_name),
943
+ output_name=f"lm_eval_acc_{task_name}_out",
944
+ num_fewshot=_lm_eval_num_fewshot_for_task(config, task_name),
945
+ )
946
+ metrics = _pick_lm_eval_accuracy_metrics(_find_lm_eval_task_result(raw_results, task_name))
947
+ if task_name == "mmlu":
948
+ value = metrics.get("acc")
949
+ if value is None:
950
+ raise RuntimeError("No `acc` metric found for `mmlu`.")
951
+ acc_out["lm_eval/final_mmlu_acc"] = float(value)
952
+ acc_out["final_eval/mmlu_acc_lm_eval"] = float(value)
953
+ elif task_name == "arc_challenge":
954
+ value = metrics.get("acc_norm")
955
+ if value is None:
956
+ value = metrics.get("acc")
957
+ if value is None:
958
+ value = metrics.get("exact_match")
959
+ if value is None:
960
+ raise RuntimeError("No accuracy-like metric found for `arc_challenge`.")
961
+ acc_out["lm_eval/final_arc_challenge_acc"] = float(value)
962
+ acc_out["final_eval/arc_challenge_acc_lm_eval"] = float(value)
963
+ out.update(acc_out)
964
+ statuses["acc"] = "ok"
965
+ except Exception as exc:
966
+ statuses["acc"] = "skipped"
967
+ errors["acc"] = str(exc).strip() or type(exc).__name__
968
+
969
+ overall_status, overall_error = _summarize_lm_eval_status(requested, statuses, errors)
970
+ out.update(
971
+ {
972
+ "lm_eval_status": overall_status,
973
+ "lm_eval_error": overall_error,
974
+ "lm_eval_ppl_status": statuses["ppl"],
975
+ "lm_eval_ppl_error": errors["ppl"],
976
+ "lm_eval_acc_status": statuses["acc"],
977
+ "lm_eval_acc_error": errors["acc"],
978
+ }
979
+ )
980
+ return out
981
+ finally:
982
+ tmp_ctx.cleanup()
983
+
984
+
985
+ def _current_learning_rate(opt: Any, scheduler: Optional[Any], config: BenchmarkConfig) -> float:
986
+ if scheduler is not None:
987
+ try:
988
+ return float(scheduler.get_last_lr()[0])
989
+ except Exception:
990
+ pass
991
+ try:
992
+ return float(opt.param_groups[0]["lr"])
993
+ except Exception:
994
+ return float(config.lr)
995
+
996
+
997
+ def _optimizer_group_learning_rate(opt: Any, config: BenchmarkConfig) -> float:
998
+ try:
999
+ return float(opt.param_groups[0]["lr"])
1000
+ except Exception:
1001
+ return float(config.lr)
1002
+
1003
+
1004
+ def _build_scheduler(opt: Any, optimizer_name: str, config: BenchmarkConfig):
1005
+ schedule_mode = str(getattr(config, "schedule_mode", "all_cosine") or "all_cosine").strip().lower()
1006
+ if schedule_mode not in {"native", "all_cosine", "all_constant"}:
1007
+ schedule_mode = "all_cosine"
1008
+ if isinstance(opt, torch.optim.Optimizer):
1009
+ target = opt
1010
+ else:
1011
+ target = _build_scheduler_proxy_for_optimizer_like(opt)
1012
+ if target is None:
1013
+ return None
1014
+ if schedule_mode == "all_constant":
1015
+ return torch.optim.lr_scheduler.LambdaLR(target, lr_lambda=lambda step: 1.0)
1016
+ if schedule_mode == "all_cosine":
1017
+ return get_cosine_schedule_with_warmup(target, num_warmup_steps=config.warmup_steps, num_training_steps=config.max_steps)
1018
+ if optimizer_name == "lbw_guard":
1019
+ return torch.optim.lr_scheduler.LambdaLR(target, lr_lambda=lambda step: 1.0)
1020
+ return get_cosine_schedule_with_warmup(target, num_warmup_steps=config.warmup_steps, num_training_steps=config.max_steps)
1021
+
1022
+
1023
+ def _init_wandb_if_enabled(config: BenchmarkConfig, *, group_name: Optional[str], run_name: Optional[str]):
1024
+ if not bool(getattr(config, "use_wandb", False)):
1025
+ return None
1026
+ if wandb is None:
1027
+ print("[W&B] Disabled (wandb not installed)")
1028
+ return None
1029
+ try:
1030
+ wandb.init(
1031
+ project="LBW-Customer-Demo",
1032
+ group=group_name,
1033
+ name=run_name,
1034
+ config=config.__dict__,
1035
+ reinit=True,
1036
+ settings=wandb.Settings(start_method="thread"),
1037
+ )
1038
+ return wandb
1039
+ except Exception as exc:
1040
+ print(f"[W&B] Disabled (wandb.init failed: {exc})")
1041
+ return None
1042
+
1043
+
1044
+ def train_one_run(
1045
+ optimizer_name: str,
1046
+ config: BenchmarkConfig,
1047
+ *,
1048
+ group_name: Optional[str] = None,
1049
+ run_name: Optional[str] = None,
1050
+ shared_pre_bench_results=None,
1051
+ shared_bench_dataset_bundle=None,
1052
+ ) -> Dict[str, Any]:
1053
+ del shared_pre_bench_results, shared_bench_dataset_bundle
1054
+
1055
+ normalized = normalize_optimizer_name(optimizer_name)
1056
+ ok, reason = check_optimizer_support(normalized, device=config.device)
1057
+ if not ok:
1058
+ raise RuntimeError(f"{normalized}: {reason}")
1059
+
1060
+ device = torch.device(config.device)
1061
+ if device.type != "cuda":
1062
+ warnings.filterwarnings(
1063
+ "ignore",
1064
+ message="CUDA initialization: The NVIDIA driver on your system is too old.*",
1065
+ category=UserWarning,
1066
+ )
1067
+ wb = _init_wandb_if_enabled(config, group_name=group_name, run_name=run_name or normalized)
1068
+
1069
+ _demo_log(config, f"Loading model and tokenizer: {config.model_name} on {device}")
1070
+ tokenizer, model = _load_tokenizer_and_model(config.model_name, device)
1071
+ _demo_log(config, "Model load complete")
1072
+ train_cap = None if config.full_wikitext_train else config.max_chars
1073
+ eval_cap = None if config.full_wikitext_eval else config.eval_chars
1074
+ train_chunks = build_wikitext_chunks(tokenizer, config.seq_len, train_cap, "train", config=config)
1075
+ eval_chunks = build_wikitext_chunks(tokenizer, config.seq_len, eval_cap, "validation", config=config)
1076
+ train_iter = batch_iter(train_chunks, config.batch_size, device)
1077
+ train_sequence_count = int(train_chunks.input_ids.size(0))
1078
+ train_token_count = int(train_chunks.input_ids.numel())
1079
+ eval_sequence_count = int(eval_chunks.input_ids.size(0))
1080
+ eval_token_count = int(eval_chunks.input_ids.numel())
1081
+ sequences_per_optimizer_step = max(int(config.batch_size * config.grad_accum), 1)
1082
+ tokens_per_optimizer_step = max(int(config.batch_size * config.seq_len * config.grad_accum), 1)
1083
+ steps_per_train_pass = int(math.ceil(train_sequence_count / float(sequences_per_optimizer_step)))
1084
+
1085
+ if bool(getattr(config, "enable_lora", True)):
1086
+ _demo_log(config, "Attaching LoRA adapters")
1087
+ lora_cfg = LoraConfig(
1088
+ r=config.lora_r,
1089
+ lora_alpha=config.lora_alpha,
1090
+ lora_dropout=config.lora_dropout,
1091
+ target_modules=config.lora_target_modules,
1092
+ task_type=TaskType.CAUSAL_LM,
1093
+ bias="none",
1094
+ )
1095
+ model = get_peft_model(model, lora_cfg)
1096
+ else:
1097
+ _demo_log(config, "Training without LoRA adapters")
1098
+ model.train()
1099
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
1100
+ if not trainable_params:
1101
+ raise RuntimeError("No trainable parameters found for training.")
1102
+
1103
+ _demo_log(config, f"Creating optimizer: {normalized}")
1104
+ opt = get_optimizer(normalized, model, config)
1105
+ scheduler = _build_scheduler(opt, normalized, config)
1106
+ governance_tracker = GovernanceMetricsTracker()
1107
+ _demo_log(config, f"Starting training for {config.max_steps} optimizer steps")
1108
+
1109
+ train_start = time.time()
1110
+ step_wall_start = train_start
1111
+ step_compute_start = train_start
1112
+ train_losses: List[float] = []
1113
+ pure_tps_history: List[float] = []
1114
+ wall_tps_history: List[float] = []
1115
+ pure_step_time_history: List[float] = []
1116
+ wall_step_time_history: List[float] = []
1117
+ runtime_snapshot: Dict[str, Any] = {
1118
+ "stress_mode": "none",
1119
+ "scale": 1.0,
1120
+ "ratio": 1.0,
1121
+ "grad_rms": 0.0,
1122
+ "scheduled_lr_used": float(config.lr),
1123
+ "scheduled_lr_next": float(config.lr),
1124
+ "effective_lr_main_used": float(config.lr),
1125
+ "effective_lr_weight_decay_used": float(config.lr),
1126
+ "train_sequences": train_sequence_count,
1127
+ "train_tokens": train_token_count,
1128
+ "train_chars": int(train_chunks.char_count),
1129
+ "train_cap_chars": train_chunks.cap_chars,
1130
+ "eval_sequences": eval_sequence_count,
1131
+ "eval_tokens": eval_token_count,
1132
+ "eval_chars": int(eval_chunks.char_count),
1133
+ "eval_cap_chars": eval_chunks.cap_chars,
1134
+ "sequences_per_optimizer_step": sequences_per_optimizer_step,
1135
+ "tokens_per_optimizer_step": tokens_per_optimizer_step,
1136
+ "steps_per_train_pass": steps_per_train_pass,
1137
+ "epochs_completed": 0.0,
1138
+ }
1139
+
1140
+ global_step = 0
1141
+ accumulation_step = 0
1142
+
1143
+ while global_step < config.max_steps:
1144
+ xb, yb = next(train_iter)
1145
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")):
1146
+ outputs = model(input_ids=xb, labels=yb)
1147
+ loss = outputs.loss / config.grad_accum
1148
+ loss.backward()
1149
+ accumulation_step += 1
1150
+ if accumulation_step % config.grad_accum != 0:
1151
+ continue
1152
+
1153
+ step_number = global_step + 1
1154
+ grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
1155
+ grad_norm_value = grad_norm.detach().item() if torch.is_tensor(grad_norm) else float(grad_norm)
1156
+ grad_rms = 0.0 if normalized == "lbw_guard" else _compute_grad_rms(trainable_params)
1157
+ loss_val = float(loss.item() * config.grad_accum)
1158
+ scheduled_lr_used = _optimizer_group_learning_rate(opt, config)
1159
+
1160
+ if normalized == "lbw_guard":
1161
+ opt.step()
1162
+ else:
1163
+ opt.step()
1164
+ if scheduler is not None:
1165
+ scheduler.step()
1166
+ opt.zero_grad()
1167
+
1168
+ compute_end = time.time()
1169
+ pure_step_time = max(compute_end - step_compute_start, 1e-12)
1170
+ tokens_per_step = int(config.batch_size * config.seq_len * config.grad_accum)
1171
+ pure_tps = tokens_per_step / pure_step_time
1172
+
1173
+ scheduled_lr_next = _current_learning_rate(opt, scheduler, config)
1174
+ scale = 1.0
1175
+ ratio = 1.0
1176
+ ema_grad_rms = grad_rms
1177
+ stress_mode = "none"
1178
+ edition = normalized
1179
+ effective_lr_main_used = scheduled_lr_used
1180
+ effective_lr_weight_decay_used = scheduled_lr_used
1181
+ if normalized == "lbw_guard":
1182
+ lbw_state = dict(getattr(opt, "state", {}).get("lbw", {}) or {})
1183
+ scale = float(lbw_state.get("scale", lbw_state.get("lbw_scale", 1.0)))
1184
+ ratio = float(lbw_state.get("ratio", 1.0))
1185
+ grad_rms = float(lbw_state.get("grad_rms", grad_rms))
1186
+ ema_grad_rms = grad_rms / ratio if ratio > 0 else grad_rms
1187
+ stress_mode = str(lbw_state.get("stress_mode", "unknown"))
1188
+ edition = str(lbw_state.get("edition", lbw_state.get("mode", normalized)))
1189
+ effective_lr_main_used = scheduled_lr_used * scale
1190
+ effective_lr_weight_decay_used = scheduled_lr_used * scale
1191
+
1192
+ derived_gov_metrics = governance_tracker.update_step(
1193
+ step=global_step,
1194
+ trainable_params=trainable_params,
1195
+ loss_val=loss_val,
1196
+ grad_norm=grad_norm_value,
1197
+ grad_rms=grad_rms,
1198
+ ema_grad_rms=ema_grad_rms,
1199
+ ratio=ratio,
1200
+ scale=scale,
1201
+ stress_mode=stress_mode,
1202
+ current_lr=scheduled_lr_used,
1203
+ )
1204
+
1205
+ train_losses.append(loss_val)
1206
+ pure_tps_history.append(pure_tps)
1207
+ pure_step_time_history.append(pure_step_time)
1208
+ epochs_completed = (step_number * sequences_per_optimizer_step) / float(train_sequence_count)
1209
+
1210
+ eval_log: Dict[str, float] = {}
1211
+ progress_every = max(
1212
+ 1,
1213
+ min(
1214
+ int(config.eval_every),
1215
+ 5 if int(config.max_steps) <= 50 else 10,
1216
+ ),
1217
+ )
1218
+ if step_number % config.eval_every == 0:
1219
+ avg_eval_loss, perp = evaluate_perplexity(model, eval_chunks, config, device)
1220
+ eval_log = {
1221
+ "eval/loss": avg_eval_loss,
1222
+ "eval/perplexity": perp,
1223
+ }
1224
+ eval_log.update(
1225
+ governance_tracker.update_eval(
1226
+ eval_loss=avg_eval_loss,
1227
+ eval_perplexity=perp,
1228
+ avg_tps_wall=(wall_tps_history[-1] if wall_tps_history else None),
1229
+ )
1230
+ )
1231
+ _demo_log(
1232
+ config,
1233
+ f"step {step_number}/{config.max_steps}: loss={loss_val:.4f}, "
1234
+ f"sampled_eval_loss={avg_eval_loss:.4f}, sampled_eval_ppl={perp:.4f}, "
1235
+ f"scale={scale:.4f}, ratio={ratio:.4f}",
1236
+ )
1237
+ model.train()
1238
+ elif bool(getattr(config, "print_all_metrics", False)) and (
1239
+ step_number == 1
1240
+ or step_number == config.max_steps
1241
+ or step_number % progress_every == 0
1242
+ ):
1243
+ _demo_log(
1244
+ config,
1245
+ f"step {step_number}/{config.max_steps}: loss={loss_val:.4f}, scale={scale:.4f}, ratio={ratio:.4f}",
1246
+ )
1247
+
1248
+ wall_end = time.time()
1249
+ wall_step_time = max(wall_end - step_wall_start, 1e-12)
1250
+ wall_tps = tokens_per_step / wall_step_time
1251
+ wall_tps_history.append(wall_tps)
1252
+ wall_step_time_history.append(wall_step_time)
1253
+
1254
+ train_log = {
1255
+ "train/loss": loss_val,
1256
+ "train/grad_norm": grad_norm_value,
1257
+ "train/tokens_per_sec_pure": pure_tps,
1258
+ "train/tokens_per_sec_wall": wall_tps,
1259
+ "train/step_time_pure_sec": pure_step_time,
1260
+ "train/step_time_wall_sec": wall_step_time,
1261
+ "train/lr": scheduled_lr_used,
1262
+ "train/lr_used": scheduled_lr_used,
1263
+ "train/lr_next": scheduled_lr_next,
1264
+ "train/effective_lr_main": effective_lr_main_used,
1265
+ "train/effective_lr_weight_decay": effective_lr_weight_decay_used,
1266
+ "train/steps_per_train_pass": float(steps_per_train_pass),
1267
+ "train/epochs_completed": float(epochs_completed),
1268
+ "lbw/scale": scale,
1269
+ "lbw/ratio": ratio,
1270
+ "lbw/grad_rms": grad_rms,
1271
+ "lbw/ema_grad_rms": ema_grad_rms,
1272
+ "lbw/stress_mode": stress_mode,
1273
+ "lbw/edition": edition,
1274
+ }
1275
+ train_log.update(derived_gov_metrics)
1276
+ if wb is not None:
1277
+ wb.log({**train_log, **eval_log}, step=step_number)
1278
+
1279
+ runtime_snapshot = {
1280
+ "stress_mode": stress_mode,
1281
+ "scale": scale,
1282
+ "ratio": ratio,
1283
+ "grad_rms": grad_rms,
1284
+ "scheduled_lr_used": scheduled_lr_used,
1285
+ "scheduled_lr_next": scheduled_lr_next,
1286
+ "effective_lr_main_used": effective_lr_main_used,
1287
+ "effective_lr_weight_decay_used": effective_lr_weight_decay_used,
1288
+ "train_sequences": train_sequence_count,
1289
+ "train_tokens": train_token_count,
1290
+ "train_chars": int(train_chunks.char_count),
1291
+ "train_cap_chars": train_chunks.cap_chars,
1292
+ "eval_sequences": eval_sequence_count,
1293
+ "eval_tokens": eval_token_count,
1294
+ "eval_chars": int(eval_chunks.char_count),
1295
+ "eval_cap_chars": eval_chunks.cap_chars,
1296
+ "sequences_per_optimizer_step": sequences_per_optimizer_step,
1297
+ "tokens_per_optimizer_step": tokens_per_optimizer_step,
1298
+ "steps_per_train_pass": steps_per_train_pass,
1299
+ "epochs_completed": float(epochs_completed),
1300
+ }
1301
+
1302
+ global_step += 1
1303
+ step_wall_start = time.time()
1304
+ step_compute_start = step_wall_start
1305
+
1306
+ training_wall_time = max(time.time() - train_start, 1e-12)
1307
+ final_eval_is_full = bool(config.full_validation_ppl)
1308
+ if final_eval_is_full:
1309
+ final_eval_scope = "full_wikitext" if eval_chunks.cap_chars is None else "full_loaded_subset"
1310
+ final_eval_scope_text = (
1311
+ "over the full WikiText validation split"
1312
+ if eval_chunks.cap_chars is None
1313
+ else f"over the full loaded validation subset ({int(eval_chunks.char_count):,} chars; --eval-chars cap)"
1314
+ )
1315
+ else:
1316
+ final_eval_scope = "sampled"
1317
+ final_eval_scope_text = f"over {int(config.eval_batches)} sampled batches"
1318
+ _demo_log(
1319
+ config,
1320
+ "Running final validation PPL " + final_eval_scope_text,
1321
+ )
1322
+ final_eval_start = time.time()
1323
+ final_eval_loss, final_eval_perp = evaluate_perplexity(
1324
+ model,
1325
+ eval_chunks,
1326
+ config,
1327
+ device,
1328
+ full_pass=final_eval_is_full,
1329
+ )
1330
+ final_eval_time_sec = max(time.time() - final_eval_start, 0.0)
1331
+ final_eval_perp_lm_eval = None
1332
+ final_eval_mmlu_acc_lm_eval = None
1333
+ final_eval_arc_challenge_acc_lm_eval = None
1334
+ lm_eval_status = "disabled"
1335
+ lm_eval_error = None
1336
+ lm_eval_ppl_status = "disabled"
1337
+ lm_eval_ppl_error = None
1338
+ lm_eval_acc_status = "disabled"
1339
+ lm_eval_acc_error = None
1340
+ lm_eval_time_sec = 0.0
1341
+ if bool(config.use_lm_eval) and (bool(config.lm_eval_ppl) or bool(getattr(config, "lm_eval_acc", False))):
1342
+ try:
1343
+ lm_eval_start = time.time()
1344
+ final_lm_eval_metrics = run_lm_eval_suite(model, tokenizer, config, device)
1345
+ lm_eval_time_sec = max(time.time() - lm_eval_start, 0.0)
1346
+ final_eval_perp_lm_eval = _safe_float(final_lm_eval_metrics.get("final_eval/perplexity_lm_eval"))
1347
+ final_eval_mmlu_acc_lm_eval = _safe_float(final_lm_eval_metrics.get("final_eval/mmlu_acc_lm_eval"))
1348
+ final_eval_arc_challenge_acc_lm_eval = _safe_float(
1349
+ final_lm_eval_metrics.get("final_eval/arc_challenge_acc_lm_eval")
1350
+ )
1351
+ lm_eval_status = str(final_lm_eval_metrics.get("lm_eval_status") or "ok")
1352
+ lm_eval_error = str(final_lm_eval_metrics.get("lm_eval_error") or "").strip() or None
1353
+ lm_eval_ppl_status = str(final_lm_eval_metrics.get("lm_eval_ppl_status") or "disabled")
1354
+ lm_eval_ppl_error = str(final_lm_eval_metrics.get("lm_eval_ppl_error") or "").strip() or None
1355
+ lm_eval_acc_status = str(final_lm_eval_metrics.get("lm_eval_acc_status") or "disabled")
1356
+ lm_eval_acc_error = str(final_lm_eval_metrics.get("lm_eval_acc_error") or "").strip() or None
1357
+ if lm_eval_status in {"skipped", "partial"} and lm_eval_error:
1358
+ print(f"[DemoRuntime] lm_eval issues: {lm_eval_error}")
1359
+ except Exception as exc:
1360
+ lm_eval_time_sec = max(time.time() - lm_eval_start, 0.0)
1361
+ lm_eval_status = "skipped"
1362
+ lm_eval_error = str(exc).strip() or type(exc).__name__
1363
+ print(f"[DemoRuntime] lm_eval skipped: {lm_eval_error}")
1364
+
1365
+ wall_time = max(time.time() - train_start, 1e-12)
1366
+ post_training_benchmark_time_sec = max(wall_time - training_wall_time, 0.0)
1367
+ avg_tps_wall = float(sum(wall_tps_history) / len(wall_tps_history)) if wall_tps_history else 0.0
1368
+ final_effect_metrics = governance_tracker.update_eval(
1369
+ eval_loss=final_eval_loss,
1370
+ eval_perplexity=final_eval_perp,
1371
+ avg_tps_wall=avg_tps_wall,
1372
+ )
1373
+ governance_snapshot = governance_tracker.snapshot()
1374
+ _demo_log(
1375
+ config,
1376
+ f"Finished: "
1377
+ f"{'final_full_eval_loss' if final_eval_is_full else 'final_eval_loss'}={final_eval_loss:.4f}, "
1378
+ f"{'final_full_eval_ppl' if final_eval_is_full else 'final_eval_ppl'}={final_eval_perp:.4f}, "
1379
+ f"wall_time={wall_time:.1f}s",
1380
+ )
1381
+
1382
+ if wb is not None:
1383
+ wb.log(
1384
+ {
1385
+ "final/eval_loss": final_eval_loss,
1386
+ "final/eval_perplexity": final_eval_perp,
1387
+ **final_effect_metrics,
1388
+ },
1389
+ step=config.max_steps,
1390
+ )
1391
+ wb.finish()
1392
+
1393
+ return {
1394
+ "optimizer": normalized,
1395
+ "group_name": group_name,
1396
+ "run_name": run_name,
1397
+ "model_name": config.model_name,
1398
+ "final_eval_loss": float(final_eval_loss),
1399
+ "final_eval_perp": float(final_eval_perp),
1400
+ "final_eval_perp_lm_eval": final_eval_perp_lm_eval,
1401
+ "final_eval_mmlu_acc_lm_eval": final_eval_mmlu_acc_lm_eval,
1402
+ "final_eval_arc_challenge_acc_lm_eval": final_eval_arc_challenge_acc_lm_eval,
1403
+ "lm_eval_status": lm_eval_status,
1404
+ "lm_eval_error": lm_eval_error,
1405
+ "lm_eval_ppl_status": lm_eval_ppl_status,
1406
+ "lm_eval_ppl_error": lm_eval_ppl_error,
1407
+ "lm_eval_acc_status": lm_eval_acc_status,
1408
+ "lm_eval_acc_error": lm_eval_acc_error,
1409
+ "avg_tokens_per_sec_pure": float(sum(pure_tps_history) / len(pure_tps_history)) if pure_tps_history else 0.0,
1410
+ "avg_tokens_per_sec_wall": avg_tps_wall,
1411
+ "avg_step_time_pure_sec": float(sum(pure_step_time_history) / len(pure_step_time_history)) if pure_step_time_history else 0.0,
1412
+ "avg_step_time_wall_sec": float(sum(wall_step_time_history) / len(wall_step_time_history)) if wall_step_time_history else 0.0,
1413
+ "training_wall_time_sec": float(training_wall_time),
1414
+ "final_eval_time_sec": float(final_eval_time_sec),
1415
+ "lm_eval_time_sec": float(lm_eval_time_sec),
1416
+ "post_training_benchmark_time_sec": float(post_training_benchmark_time_sec),
1417
+ "wall_time_sec": float(wall_time),
1418
+ "train_sequence_count": int(train_sequence_count),
1419
+ "train_token_count": int(train_token_count),
1420
+ "train_char_count": int(train_chunks.char_count),
1421
+ "train_cap_chars": train_chunks.cap_chars,
1422
+ "eval_sequence_count": int(eval_sequence_count),
1423
+ "eval_token_count": int(eval_token_count),
1424
+ "eval_char_count": int(eval_chunks.char_count),
1425
+ "eval_cap_chars": eval_chunks.cap_chars,
1426
+ "full_wikitext_train": bool(config.full_wikitext_train),
1427
+ "full_wikitext_eval": bool(config.full_wikitext_eval),
1428
+ "full_validation_ppl": bool(config.full_validation_ppl),
1429
+ "final_eval_full_pass": bool(final_eval_is_full),
1430
+ "final_eval_scope": final_eval_scope,
1431
+ "sequences_per_optimizer_step": int(sequences_per_optimizer_step),
1432
+ "tokens_per_optimizer_step": int(tokens_per_optimizer_step),
1433
+ "steps_per_train_pass": int(steps_per_train_pass),
1434
+ "epochs_completed": float((global_step * sequences_per_optimizer_step) / float(train_sequence_count)),
1435
+ "runtime_snapshot": runtime_snapshot,
1436
+ "governance_snapshot": governance_snapshot,
1437
+ "final_effect_metrics": final_effect_metrics,
1438
+ "train_loss_last": (float(train_losses[-1]) if train_losses else None),
1439
+ "schedule_mode": config.schedule_mode,
1440
+ "max_steps": int(config.max_steps),
1441
+ }
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import gc
5
+ import io
6
+ import json
7
+ import os
8
+ import time
9
+ import traceback
10
+ from contextlib import redirect_stdout
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+
15
+ ROOT = Path(__file__).resolve().parent
16
+ os.environ.setdefault("HF_HOME", str((ROOT / ".hf_cache").resolve()))
17
+ os.environ.setdefault("HF_DATASETS_CACHE", str((ROOT / ".hf_cache" / "datasets").resolve()))
18
+ os.environ.setdefault("TRANSFORMERS_CACHE", str((ROOT / ".hf_cache" / "transformers").resolve()))
19
+ os.environ.setdefault("WANDB_DIR", str((ROOT / ".wandb").resolve()))
20
+ os.environ.setdefault("WANDB_CACHE_DIR", str((ROOT / ".wandb" / "cache").resolve()))
21
+ os.environ.setdefault("WANDB_CONFIG_DIR", str((ROOT / ".wandb" / "config").resolve()))
22
+ os.environ.setdefault("DISABLE_SAFETENSORS_CONVERSION", "1")
23
+
24
+ import gradio as gr
25
+ import torch
26
+
27
+ import _demo_runtime as runtime
28
+
29
+
30
+ RUNS_DIR = ROOT / "runs"
31
+
32
+
33
+ def _device_default() -> str:
34
+ return "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+
37
+ def _safe_float(value: Any) -> float | None:
38
+ if value is None:
39
+ return None
40
+ try:
41
+ return float(value)
42
+ except Exception:
43
+ return None
44
+
45
+
46
+ def _build_config(
47
+ *,
48
+ model_name: str,
49
+ steps: int,
50
+ lr: float,
51
+ seq_len: int,
52
+ train_chars: int,
53
+ eval_chars: int,
54
+ eval_batches: int,
55
+ batch_size: int,
56
+ grad_accum: int,
57
+ seed: int,
58
+ device: str,
59
+ ) -> runtime.BenchmarkConfig:
60
+ config = runtime.BenchmarkConfig()
61
+ config.model_name = str(model_name).strip() or "Qwen/Qwen2.5-0.5B"
62
+ config.device = str(device or _device_default())
63
+ config.max_steps = int(steps)
64
+ config.lr = float(lr)
65
+ config.seq_len = int(seq_len)
66
+ config.batch_size = int(batch_size)
67
+ config.grad_accum = int(grad_accum)
68
+ config.warmup_steps = min(5, max(0, int(steps) // 5))
69
+ config.eval_every = max(1, min(int(steps), 10))
70
+ config.eval_batches = int(eval_batches)
71
+ config.max_chars = int(train_chars)
72
+ config.eval_chars = int(eval_chars)
73
+ config.full_wikitext_train = False
74
+ config.full_wikitext_eval = False
75
+ config.full_validation_ppl = False
76
+ config.schedule_mode = "all_cosine"
77
+ config.lora_r = 8
78
+ config.lora_alpha = 32
79
+ config.lora_dropout = 0.05
80
+ config.lbw_stats_freq = 5
81
+ config.lbw_stress_th = 1.1
82
+ config.lbw_spike_th = 1.5
83
+ config.lbw_rec_fast = 0.01
84
+ config.lbw_ema_decay = 0.95
85
+ config.use_wandb = False
86
+ config.use_lbwgov = False
87
+ config.print_all_metrics = True
88
+ config.output_dir = str((RUNS_DIR / f"run_{int(time.time())}").resolve())
89
+ config.use_lm_eval = False
90
+ config.lm_eval_ppl = False
91
+ config.lm_eval_acc = False
92
+ runtime.set_seed(int(seed), device=config.device)
93
+ return config
94
+
95
+
96
+ def _result_row(result: dict[str, Any]) -> dict[str, Any]:
97
+ runtime_snapshot = dict(result.get("runtime_snapshot") or {})
98
+ governance_snapshot = dict(result.get("governance_snapshot") or {})
99
+ return {
100
+ "optimizer": result.get("optimizer"),
101
+ "final_eval_perplexity": _safe_float(result.get("final_eval_perp")),
102
+ "final_eval_loss": _safe_float(result.get("final_eval_loss")),
103
+ "tokens_per_sec_wall": _safe_float(result.get("avg_tokens_per_sec_wall")),
104
+ "training_wall_time_sec": _safe_float(result.get("training_wall_time_sec")),
105
+ "wall_time_sec": _safe_float(result.get("wall_time_sec")),
106
+ "scale": _safe_float(runtime_snapshot.get("scale")),
107
+ "ratio": _safe_float(runtime_snapshot.get("ratio")),
108
+ "stress_mode": runtime_snapshot.get("stress_mode"),
109
+ "intervention_count": governance_snapshot.get("intervention_count"),
110
+ "regime_switch_count": governance_snapshot.get("regime_switch_count"),
111
+ }
112
+
113
+
114
+ def _gain_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
115
+ by_optimizer = {str(row.get("optimizer")): row for row in rows}
116
+ baseline = by_optimizer.get("adamw")
117
+ if baseline is None:
118
+ return []
119
+ gains = []
120
+ for row in rows:
121
+ if row.get("optimizer") == "adamw":
122
+ continue
123
+ baseline_ppl = _safe_float(baseline.get("final_eval_perplexity"))
124
+ candidate_ppl = _safe_float(row.get("final_eval_perplexity"))
125
+ baseline_tps = _safe_float(baseline.get("tokens_per_sec_wall"))
126
+ candidate_tps = _safe_float(row.get("tokens_per_sec_wall"))
127
+ ppl_gain = None if baseline_ppl is None or candidate_ppl is None else baseline_ppl - candidate_ppl
128
+ speedup = None if baseline_tps in (None, 0.0) or candidate_tps is None else candidate_tps / baseline_tps
129
+ gains.append(
130
+ {
131
+ "optimizer": row.get("optimizer"),
132
+ "eval_perplexity_gain_vs_adamw": ppl_gain,
133
+ "eval_perplexity_pct_gain_vs_adamw": (
134
+ None if baseline_ppl in (None, 0.0) or candidate_ppl is None else (baseline_ppl - candidate_ppl) / baseline_ppl
135
+ ),
136
+ "wall_tokens_per_sec_speedup_vs_adamw": speedup,
137
+ }
138
+ )
139
+ return gains
140
+
141
+
142
+ def _write_csv(path: Path, rows: list[dict[str, Any]]) -> None:
143
+ if not rows:
144
+ path.write_text("", encoding="utf-8")
145
+ return
146
+ with path.open("w", encoding="utf-8", newline="") as handle:
147
+ writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
148
+ writer.writeheader()
149
+ writer.writerows(rows)
150
+
151
+
152
+ def run_demo(
153
+ model_name: str,
154
+ steps: int,
155
+ lr: float,
156
+ seq_len: int,
157
+ train_chars: int,
158
+ eval_chars: int,
159
+ eval_batches: int,
160
+ batch_size: int,
161
+ grad_accum: int,
162
+ seed: int,
163
+ run_lbw_guard: bool,
164
+ ) -> tuple[str, str | None, str | None]:
165
+ if not run_lbw_guard:
166
+ optimizers = ["adamw"]
167
+ else:
168
+ optimizers = ["adamw", "lbw_guard"]
169
+ device = _device_default()
170
+ config = _build_config(
171
+ model_name=model_name,
172
+ steps=steps,
173
+ lr=lr,
174
+ seq_len=seq_len,
175
+ train_chars=train_chars,
176
+ eval_chars=eval_chars,
177
+ eval_batches=eval_batches,
178
+ batch_size=batch_size,
179
+ grad_accum=grad_accum,
180
+ seed=seed,
181
+ device=device,
182
+ )
183
+ run_dir = Path(config.output_dir)
184
+ run_dir.mkdir(parents=True, exist_ok=True)
185
+
186
+ log_buffer = io.StringIO()
187
+ try:
188
+ results = []
189
+ with redirect_stdout(log_buffer):
190
+ for optimizer_name in optimizers:
191
+ normalized = runtime.normalize_optimizer_name(optimizer_name)
192
+ ok, reason = runtime.check_optimizer_support(normalized, device=config.device)
193
+ if not ok:
194
+ raise RuntimeError(f"{normalized}: {reason}")
195
+ runtime.set_seed(int(seed), device=config.device)
196
+ run_config = runtime.BenchmarkConfig(**config.__dict__)
197
+ run_name = f"{normalized}_{int(time.time())}"
198
+ result = runtime.train_one_run(
199
+ normalized,
200
+ run_config,
201
+ group_name="LBW-Guard-HF-Direct-Runner",
202
+ run_name=run_name,
203
+ )
204
+ result["optimizer"] = normalized
205
+ results.append(result)
206
+ gc.collect()
207
+ if torch.cuda.is_available():
208
+ torch.cuda.empty_cache()
209
+
210
+ rows = [_result_row(result) for result in results]
211
+ gains = _gain_rows(rows)
212
+ payload = {
213
+ "config": {
214
+ "model_name": model_name,
215
+ "device": device,
216
+ "steps": int(steps),
217
+ "lr": float(lr),
218
+ "seq_len": int(seq_len),
219
+ "train_chars": int(train_chars),
220
+ "eval_chars": int(eval_chars),
221
+ "eval_batches": int(eval_batches),
222
+ "batch_size": int(batch_size),
223
+ "grad_accum": int(grad_accum),
224
+ "seed": int(seed),
225
+ },
226
+ "results": results,
227
+ "rows": rows,
228
+ "gains": gains,
229
+ }
230
+ json_path = run_dir / "lbw_guard_direct_runner_results.json"
231
+ csv_path = run_dir / "lbw_guard_direct_runner_metrics.csv"
232
+ gains_path = run_dir / "lbw_guard_direct_runner_gains.csv"
233
+ json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
234
+ _write_csv(csv_path, rows)
235
+ _write_csv(gains_path, gains)
236
+
237
+ summary = [
238
+ f"Device: `{device}`",
239
+ "",
240
+ "## Metrics",
241
+ "",
242
+ "| Optimizer | Final Eval PPL | Final Eval Loss | Wall Tokens/s | Wall Time (s) | Scale | Ratio | Stress Mode |",
243
+ "| --- | --- | --- | --- | --- | --- | --- | --- |",
244
+ ]
245
+ for row in rows:
246
+ summary.append(
247
+ "| {optimizer} | {ppl:.4f} | {loss:.4f} | {tps:.2f} | {wall:.2f} | {scale:.4f} | {ratio:.4f} | {stress} |".format(
248
+ optimizer=row.get("optimizer"),
249
+ ppl=float(row.get("final_eval_perplexity") or 0.0),
250
+ loss=float(row.get("final_eval_loss") or 0.0),
251
+ tps=float(row.get("tokens_per_sec_wall") or 0.0),
252
+ wall=float(row.get("wall_time_sec") or 0.0),
253
+ scale=float(row.get("scale") or 0.0),
254
+ ratio=float(row.get("ratio") or 0.0),
255
+ stress=row.get("stress_mode") or "-",
256
+ )
257
+ )
258
+ if gains:
259
+ summary.extend(["", "## Gains vs AdamW", ""])
260
+ for gain in gains:
261
+ pct = _safe_float(gain.get("eval_perplexity_pct_gain_vs_adamw"))
262
+ speedup = _safe_float(gain.get("wall_tokens_per_sec_speedup_vs_adamw"))
263
+ summary.append(
264
+ f"- `{gain.get('optimizer')}` PPL gain: `{_safe_float(gain.get('eval_perplexity_gain_vs_adamw'))}`, "
265
+ f"PPL pct gain: `{pct * 100.0:.2f}%`" if pct is not None else f"- `{gain.get('optimizer')}` PPL pct gain unavailable."
266
+ )
267
+ if speedup is not None:
268
+ summary.append(f"- `{gain.get('optimizer')}` wall tokens/s speedup: `{speedup:.3f}x`.")
269
+ summary.extend(["", "## Runtime Log", "", "```text", log_buffer.getvalue()[-8000:], "```"])
270
+ return "\n".join(summary), str(json_path), str(csv_path)
271
+ except Exception:
272
+ error_text = traceback.format_exc()
273
+ error_path = run_dir / "error.txt"
274
+ error_path.write_text(error_text + "\n\n" + log_buffer.getvalue(), encoding="utf-8")
275
+ return f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None
276
+
277
+
278
+ INTRO = """
279
+ # LBW Guard Direct Runner
280
+
281
+ Run a compact AdamW vs `lbw_guard` LoRA smoke test directly inside this Hugging Face Space.
282
+
283
+ Use GPU hardware for real runs. CPU mode is best treated as an import/build check.
284
+ """
285
+
286
+
287
+ with gr.Blocks(title="LBW Guard Direct Runner") as demo:
288
+ gr.Markdown(INTRO)
289
+ with gr.Row():
290
+ model_name = gr.Textbox(value="Qwen/Qwen2.5-0.5B", label="Model")
291
+ run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison")
292
+ with gr.Row():
293
+ steps = gr.Slider(1, 100, value=5, step=1, label="Optimizer steps")
294
+ lr = gr.Number(value=5e-4, label="Learning rate")
295
+ seed = gr.Number(value=42, precision=0, label="Seed")
296
+ with gr.Row():
297
+ seq_len = gr.Dropdown([64, 128, 256], value=64, label="Sequence length")
298
+ batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")
299
+ grad_accum = gr.Slider(1, 8, value=2, step=1, label="Gradient accumulation")
300
+ with gr.Row():
301
+ train_chars = gr.Slider(10_000, 500_000, value=50_000, step=10_000, label="Train char cap")
302
+ eval_chars = gr.Slider(5_000, 200_000, value=20_000, step=5_000, label="Eval char cap")
303
+ eval_batches = gr.Slider(1, 20, value=4, step=1, label="Eval batches")
304
+ run_button = gr.Button("Run Direct Smoke Test", variant="primary")
305
+ summary = gr.Markdown()
306
+ json_file = gr.File(label="Raw JSON")
307
+ metrics_file = gr.File(label="Metrics CSV")
308
+
309
+ run_button.click(
310
+ fn=run_demo,
311
+ inputs=[
312
+ model_name,
313
+ steps,
314
+ lr,
315
+ seq_len,
316
+ train_chars,
317
+ eval_chars,
318
+ eval_batches,
319
+ batch_size,
320
+ grad_accum,
321
+ seed,
322
+ run_lbw_guard,
323
+ ],
324
+ outputs=[summary, json_file, metrics_file],
325
+ )
326
+
327
+
328
+ if __name__ == "__main__":
329
+ demo.queue(default_concurrency_limit=1).launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ datasets
4
+ peft
5
+ accelerate
6
+ LBW-Guard