from __future__ import annotations import csv import gc import json import math import os import random import time import traceback from pathlib import Path from typing import Any ROOT = Path(__file__).resolve().parent os.environ.setdefault("HF_HOME", str((ROOT / ".hf_cache").resolve())) os.environ.setdefault("HF_DATASETS_CACHE", str((ROOT / ".hf_cache" / "datasets").resolve())) os.environ.setdefault("TRANSFORMERS_CACHE", str((ROOT / ".hf_cache" / "transformers").resolve())) os.environ.setdefault("WANDB_DIR", str((ROOT / ".wandb").resolve())) os.environ.setdefault("WANDB_CACHE_DIR", str((ROOT / ".wandb" / "cache").resolve())) os.environ.setdefault("WANDB_CONFIG_DIR", str((ROOT / ".wandb" / "config").resolve())) os.environ.setdefault("DISABLE_SAFETENSORS_CONVERSION", "1") import gradio as gr import torch from datasets import load_dataset from peft import LoraConfig, TaskType, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer try: import lbw except Exception as exc: # pragma: no cover - shown in the Space UI. lbw = None LBW_IMPORT_ERROR = exc else: LBW_IMPORT_ERROR = None RUNS_DIR = ROOT / "runs" def _device_default() -> str: return "cuda" if torch.cuda.is_available() else "cpu" def _set_seed(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def _safe_float(value: Any) -> float | None: if value is None: return None try: out = float(value) except Exception: return None if not math.isfinite(out): return None return out def _fmt_float(value: Any, digits: int = 4) -> str: number = _safe_float(value) return "-" if number is None else f"{number:.{digits}f}" def _append_log(logs: list[str], message: str) -> None: logs.append(message) print(message, flush=True) def _build_wikitext_chunks( tokenizer, *, split: str, max_chars: int | None, seq_len: int, logs: list[str], ) -> dict[str, Any]: cap = None if max_chars is None else int(max_chars) _append_log( logs, f"Preparing WikiText split={split!r}" + (f" with char cap {cap:,}" if cap is not None else " with full split"), ) ds = load_dataset("wikitext", "wikitext-103-raw-v1", split=split) pieces: list[str] = [] chars_used = 0 rows_used = 0 first_piece = True for row in ds: text = str(row.get("text", "") or "") if not text.strip(): continue piece = text if first_piece else " " + text if cap is not None: remain = cap - chars_used if remain <= 0: break if len(piece) > remain: piece = piece[:remain] pieces.append(piece) chars_used += len(piece) rows_used += 1 first_piece = False if cap is not None and chars_used >= cap: break token_ids = tokenizer("".join(pieces), add_special_tokens=False)["input_ids"] ids = torch.tensor(token_ids, dtype=torch.long) sequence_count = ids.numel() // int(seq_len) if sequence_count <= 0: raise RuntimeError("Not enough tokens. Increase the train/eval char cap or reduce sequence length.") ids = ids[: sequence_count * int(seq_len)].view(sequence_count, int(seq_len)).contiguous() _append_log( logs, f"Prepared split={split!r}: {chars_used:,} chars across {rows_used:,} rows -> {ids.size(0):,} sequences", ) return {"input_ids": ids, "chars": chars_used, "rows": rows_used, "cap": cap} def _batch_iter(chunks: dict[str, Any], *, batch_size: int, device: torch.device): ids = chunks["input_ids"] i = 0 while True: if i + int(batch_size) > ids.size(0): i = 0 batch = ids[i : i + int(batch_size)].to(device, non_blocking=True) i += int(batch_size) yield batch def _load_lora_model( *, model_name: str, device: torch.device, lora_r: int, lora_alpha: int, lora_dropout: float, ): dtype = torch.float16 if device.type == "cuda" else torch.float32 model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=dtype, low_cpu_mem_usage=True, ) if getattr(model.config, "use_cache", None) is not None: model.config.use_cache = False model.to(device) lora_cfg = LoraConfig( r=int(lora_r), lora_alpha=int(lora_alpha), lora_dropout=float(lora_dropout), target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], task_type=TaskType.CAUSAL_LM, bias="none", ) return get_peft_model(model, lora_cfg) def _make_optimizer( name: str, model, *, lr: float, betas: tuple[float, float], weight_decay: float, lbw_stats_freq: int, lbw_stress_th: float, lbw_spike_th: float, lbw_rec_fast: float, lbw_ema_decay: float, ): params = [param for param in model.parameters() if param.requires_grad] if name == "adamw": return torch.optim.AdamW(params, lr=float(lr), betas=betas, weight_decay=float(weight_decay)) if name == "lbw_guard": if lbw is None: raise RuntimeError(f"LBW Guard package import failed: {LBW_IMPORT_ERROR}") return lbw.Guard( params, lr=float(lr), betas=betas, weight_decay=float(weight_decay), mode="eval", auto_enabled=True, stats_freq=int(lbw_stats_freq), stress_threshold=float(lbw_stress_th), spike_threshold=float(lbw_spike_th), recovery_fast=float(lbw_rec_fast), ema_decay=float(lbw_ema_decay), use_max_rms=True, ) raise ValueError(f"Unknown optimizer: {name}") @torch.no_grad() def _evaluate_ppl( model, eval_chunks: dict[str, Any], *, batch_size: int, eval_batches: int, device: torch.device, full_pass: bool, ) -> tuple[float, float]: model.eval() ids = eval_chunks["input_ids"] max_sequences = ids.size(0) if full_pass else min(ids.size(0), int(eval_batches) * int(batch_size)) losses: list[float] = [] for start in range(0, max_sequences, int(batch_size)): xb = ids[start : start + int(batch_size)].to(device, non_blocking=True) with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")): loss = model(input_ids=xb, labels=xb).loss losses.append(float(loss.detach().cpu())) avg_loss = sum(losses) / max(len(losses), 1) return avg_loss, math.exp(min(avg_loss, 20.0)) def _optimizer_state(opt) -> dict[str, Any]: state = dict(getattr(opt, "state", {}).get("lbw", {}) or {}) return { "scale": float(state.get("scale", state.get("lbw_scale", 1.0))), "ratio": float(state.get("ratio", 1.0)), "stress_mode": str(state.get("stress_mode", "none")), } def _status_markdown( *, device_name: str, rows: list[dict[str, Any]], logs: list[str], phase: str, ) -> str: summary = [ f"Device: `{device_name}`", "", f"Status: {phase}", "", "## Results", "", "| Optimizer | Final Eval PPL | Final Eval Loss | Scope | Scale | Ratio | Stress Mode | Wall Time (s) |", "| --- | --- | --- | --- | --- | --- | --- | --- |", ] if rows: for row in rows: summary.append( "| {optimizer} | {ppl} | {loss} | {scope} | {scale} | {ratio} | {stress} | {wall} |".format( optimizer=row.get("optimizer"), ppl=_fmt_float(row.get("final_eval_ppl")), loss=_fmt_float(row.get("final_eval_loss")), scope=row.get("final_eval_scope") or "-", scale=_fmt_float(row.get("scale")), ratio=_fmt_float(row.get("ratio")), stress=row.get("stress_mode") or "-", wall=_fmt_float(row.get("wall_time_sec"), digits=2), ) ) else: summary.append("| - | - | - | - | - | - | - | - |") gains = _gain_rows(rows) if gains: summary.extend(["", "## LBW vs AdamW", ""]) for gain in gains: pct = _safe_float(gain.get("eval_perplexity_pct_gain_vs_adamw")) wall_speedup = _safe_float(gain.get("wall_time_speedup_vs_adamw")) summary.append( f"- `{gain.get('optimizer')}` PPL gain vs AdamW: `{_fmt_float(gain.get('eval_perplexity_gain_vs_adamw'))}`" + (f" (`{pct * 100.0:.2f}%`)." if pct is not None else ".") ) if wall_speedup is not None: summary.append(f"- `{gain.get('optimizer')}` wall-time speedup vs AdamW: `{wall_speedup:.3f}x`.") summary.extend(["", "## Runtime Log", "", "```text", "\n".join(logs[-80:]), "```"]) return "\n".join(summary) def _run_one_optimizer_events( *, optimizer_name: str, model_name: str, train_chunks: dict[str, Any], eval_chunks: dict[str, Any], device: torch.device, seed: int, max_steps: int, eval_every: int, eval_batches: int, seq_len: int, batch_size: int, lr: float, betas: tuple[float, float], weight_decay: float, full_validation_ppl: bool, lora_r: int, lora_alpha: int, lora_dropout: float, lbw_stats_freq: int, lbw_stress_th: float, lbw_spike_th: float, lbw_rec_fast: float, lbw_ema_decay: float, logs: list[str], ): _set_seed(int(seed)) _append_log(logs, f"Loading {model_name} with LoRA for {optimizer_name}.") model = _load_lora_model( model_name=model_name, device=device, lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, ) model.train() opt = _make_optimizer( optimizer_name, model, lr=lr, betas=betas, weight_decay=weight_decay, lbw_stats_freq=lbw_stats_freq, lbw_stress_th=lbw_stress_th, lbw_spike_th=lbw_spike_th, lbw_rec_fast=lbw_rec_fast, lbw_ema_decay=lbw_ema_decay, ) train_batches = _batch_iter(train_chunks, batch_size=batch_size, device=device) start_time = time.time() last_loss = None last_eval_loss = None last_eval_ppl = None state = _optimizer_state(opt) trainable_params = [param for param in model.parameters() if param.requires_grad] for step in range(1, int(max_steps) + 1): xb = next(train_batches) with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")): loss = model(input_ids=xb, labels=xb).loss loss.backward() torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) opt.step() opt.zero_grad(set_to_none=True) last_loss = float(loss.detach().cpu()) state = _optimizer_state(opt) if step == 1 or step == int(max_steps) or step % int(eval_every) == 0: last_eval_loss, last_eval_ppl = _evaluate_ppl( model, eval_chunks, batch_size=batch_size, eval_batches=eval_batches, device=device, full_pass=False, ) message = ( f"{optimizer_name} step {step}/{int(max_steps)}: " f"loss={last_loss:.4f}, sampled_eval_ppl={last_eval_ppl:.4f}, " f"scale={state['scale']:.4f}, ratio={state['ratio']:.4f}" ) _append_log(logs, message) yield {"type": "progress", "message": message} model.train() final_full_pass = bool(full_validation_ppl) if final_full_pass and eval_chunks["cap"] is None: final_scope = "full_wikitext" elif final_full_pass: final_scope = "full_loaded_subset" else: final_scope = "sampled" _append_log(logs, f"Running final {final_scope} validation PPL for {optimizer_name}.") final_loss, final_ppl = _evaluate_ppl( model, eval_chunks, batch_size=batch_size, eval_batches=eval_batches, device=device, full_pass=final_full_pass, ) state = _optimizer_state(opt) wall_time = time.time() - start_time result = { "optimizer": optimizer_name, "final_eval_ppl": final_ppl, "final_eval_loss": final_loss, "final_eval_scope": final_scope, "train_chars": train_chunks["chars"], "eval_chars": eval_chunks["chars"], "train_sequences": int(train_chunks["input_ids"].size(0)), "eval_sequences": int(eval_chunks["input_ids"].size(0)), "tokens_per_step": int(batch_size) * int(seq_len), "last_train_loss": last_loss, "last_sampled_eval_loss": last_eval_loss, "last_sampled_eval_ppl": last_eval_ppl, "scale": state["scale"], "ratio": state["ratio"], "stress_mode": state["stress_mode"], "wall_time_sec": wall_time, } del model, opt gc.collect() if device.type == "cuda": torch.cuda.empty_cache() yield {"type": "result", "result": result} def _gain_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: by_optimizer = {str(row.get("optimizer")): row for row in rows} baseline = by_optimizer.get("adamw") if baseline is None: return [] baseline_ppl = _safe_float(baseline.get("final_eval_ppl")) baseline_wall = _safe_float(baseline.get("wall_time_sec")) gains: list[dict[str, Any]] = [] for row in rows: if row.get("optimizer") == "adamw": continue candidate_ppl = _safe_float(row.get("final_eval_ppl")) candidate_wall = _safe_float(row.get("wall_time_sec")) gains.append( { "optimizer": row.get("optimizer"), "eval_perplexity_gain_vs_adamw": ( None if baseline_ppl is None or candidate_ppl is None else baseline_ppl - candidate_ppl ), "eval_perplexity_pct_gain_vs_adamw": ( None if baseline_ppl in (None, 0.0) or candidate_ppl is None else (baseline_ppl - candidate_ppl) / baseline_ppl ), "wall_time_speedup_vs_adamw": ( None if baseline_wall in (None, 0.0) or candidate_wall in (None, 0.0) else baseline_wall / candidate_wall ), } ) return gains def _write_csv(path: Path, rows: list[dict[str, Any]]) -> None: if not rows: path.write_text("", encoding="utf-8") return with path.open("w", encoding="utf-8", newline="") as handle: writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) def _set_lr(opt, value: float) -> None: for group in getattr(opt, "param_groups", []) or []: group["lr"] = float(value) def _scheduled_lr(cfg: dict[str, Any], step: int) -> float: base_lr = float(cfg["lr"]) warmup = max(int(cfg.get("warmup_steps", 0)), 0) max_steps = max(int(cfg["max_steps"]), 1) if warmup > 0 and int(step) <= warmup: return base_lr * float(step) / float(warmup) mode = str(cfg.get("schedule_mode", "constant")).strip().lower() if mode == "cosine": progress = (int(step) - warmup) / max(max_steps - warmup, 1) progress = min(max(progress, 0.0), 1.0) return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress)) return base_lr def _parse_float_sweep(text: str, default: list[float]) -> list[float]: raw = str(text or "").replace("\n", ",").replace(";", ",").split(",") values: list[float] = [] for item in raw: item = item.strip() if not item: continue values.append(float(item)) return values or list(default) def _parse_int_sweep(text: str, default: list[int]) -> list[int]: return [int(value) for value in _parse_float_sweep(text, [float(item) for item in default])] def run_easy_test( model_name: str, run_lbw_guard: bool, max_steps: int, eval_every: int, eval_batches: int, seq_len: int, batch_size: int, train_chars: int, eval_chars: int, full_wikitext_train: bool, full_wikitext_eval: bool, full_validation_ppl: bool, lr: float, seed: int, ): logs: list[str] = [] rows: list[dict[str, Any]] = [] run_dir = RUNS_DIR / f"easy_test_{int(time.time())}" run_dir.mkdir(parents=True, exist_ok=True) device_name = _device_default() device = torch.device(device_name) optimizers = ["adamw", "lbw_guard"] if bool(run_lbw_guard) else ["adamw"] try: if device.type == "cpu" and ( int(max_steps) > 1 or int(train_chars) > 20_000 or int(eval_chars) > 8_000 or bool(full_wikitext_train) or bool(full_wikitext_eval) or bool(full_validation_ppl) ): yield ( "This Space is currently on `cpu-basic`. CPU mode is capped to 1 step, 20k train chars, " "8k eval chars, and sampled validation. Switch the Space hardware to GPU for the Quick Comparison defaults.", None, None, None, ) return if device.type == "cuda" and bool(run_lbw_guard) and torch.cuda.device_count() > 1: yield ( "LBW Guard should run with one visible GPU. Set the Space to single-GPU hardware or restrict CUDA_VISIBLE_DEVICES.", None, None, None, ) return _append_log(logs, f"Device: {device_name}") if device.type == "cuda": _append_log(logs, f"GPU: {torch.cuda.get_device_name(0)}") _append_log(logs, f"Optimizers: {', '.join(optimizers)}") yield _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Loading tokenizer"), None, None, None _set_seed(int(seed)) resolved_model = str(model_name).strip() or "TinyLlama/TinyLlama_v1.1" tokenizer = AutoTokenizer.from_pretrained(resolved_model, use_fast=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token train_cap = None if bool(full_wikitext_train) else int(train_chars) eval_cap = None if bool(full_wikitext_eval) else int(eval_chars) train_chunks = _build_wikitext_chunks( tokenizer, split="train", max_chars=train_cap, seq_len=int(seq_len), logs=logs, ) yield _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Prepared train split"), None, None, None eval_chunks = _build_wikitext_chunks( tokenizer, split="validation", max_chars=eval_cap, seq_len=int(seq_len), logs=logs, ) yield _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Prepared validation split"), None, None, None for optimizer_name in optimizers: _append_log(logs, f"=== {optimizer_name} ===") yield _status_markdown( device_name=device_name, rows=rows, logs=logs, phase=f"Running {optimizer_name}", ), None, None, None for event in _run_one_optimizer_events( optimizer_name=optimizer_name, model_name=resolved_model, train_chunks=train_chunks, eval_chunks=eval_chunks, device=device, seed=int(seed), max_steps=int(max_steps), eval_every=max(1, int(eval_every)), eval_batches=int(eval_batches), seq_len=int(seq_len), batch_size=int(batch_size), lr=float(lr), betas=(0.9, 0.999), weight_decay=0.01, full_validation_ppl=bool(full_validation_ppl), lora_r=8, lora_alpha=16, lora_dropout=0.05, lbw_stats_freq=10, lbw_stress_th=1.1, lbw_spike_th=1.5, lbw_rec_fast=0.01, lbw_ema_decay=0.95, logs=logs, ): if event.get("type") == "result": rows.append(event["result"]) yield _status_markdown( device_name=device_name, rows=rows, logs=logs, phase=f"Running {optimizer_name}", ), None, None, None gains = _gain_rows(rows) payload = { "source": "HF Quick Comparison Runner", "based_on_colab": "LBW_Guard_Easy_Test_COLAB.ipynb", "config": { "model_name": resolved_model, "device": device_name, "optimizers": optimizers, "seed": int(seed), "max_steps": int(max_steps), "eval_every": int(eval_every), "eval_batches": int(eval_batches), "seq_len": int(seq_len), "batch_size": int(batch_size), "max_chars": train_cap, "eval_chars": eval_cap, "full_wikitext_train": bool(full_wikitext_train), "full_wikitext_eval": bool(full_wikitext_eval), "full_validation_ppl": bool(full_validation_ppl), "lr": float(lr), "betas": [0.9, 0.999], "weight_decay": 0.01, "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05, "lbw_stats_freq": 10, "lbw_stress_th": 1.1, "lbw_spike_th": 1.5, "lbw_rec_fast": 0.01, "lbw_ema_decay": 0.95, }, "results": rows, "gains": gains, "logs": logs, } json_path = run_dir / "lbw_guard_hf_quick_comparison_results.json" csv_path = run_dir / "lbw_guard_hf_quick_comparison_results.csv" gains_path = run_dir / "lbw_guard_hf_quick_comparison_gains.csv" json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") _write_csv(csv_path, rows) _write_csv(gains_path, gains) _append_log(logs, f"Wrote {csv_path}") yield ( _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Complete"), str(json_path), str(csv_path), str(gains_path), ) except Exception: error_text = traceback.format_exc() error_path = run_dir / "error.txt" error_path.write_text(error_text + "\n\n" + "\n".join(logs), encoding="utf-8") yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None, None def _make_ablation_scenario(slug: str, label: str, note: str, base_config: dict[str, Any], overrides=None): cfg = dict(base_config) if overrides: cfg.update(overrides) return { "slug": slug, "label": label, "note": note, "config": cfg, } def _build_ablation_scenarios( *, selected_ablations: list[str], base_config: dict[str, Any], lr_sweep: list[float], step_sweep: list[int], lora_r_sweep: list[int], ) -> list[dict[str, Any]]: selected = {str(item).strip().lower() for item in selected_ablations if str(item).strip()} if not selected: selected = {"optimizer"} scenarios: list[dict[str, Any]] = [] if "optimizer" in selected: scenarios.append( _make_ablation_scenario( "optimizer-adamw-vs-lbw-guard", "Optimizer: AdamW vs lbw_guard", "Direct optimizer comparison with the base config.", base_config, ) ) if "lr" in selected: for lr in lr_sweep: scenarios.append( _make_ablation_scenario( f"lr-{lr:g}", f"Learning Rate: {lr:g}", "Learning-rate sensitivity check.", base_config, {"lr": float(lr)}, ) ) if "schedule" in selected: for mode in ["constant", "cosine"]: scenarios.append( _make_ablation_scenario( f"schedule-{mode}", f"Schedule: {mode}", "Scheduler-shape sensitivity check.", base_config, {"schedule_mode": mode}, ) ) if "steps" in selected: for steps in step_sweep: scenarios.append( _make_ablation_scenario( f"steps-{steps}", f"Steps: {steps}", "Training-length sensitivity check.", base_config, {"max_steps": int(steps), "eval_every": max(1, int(steps) // 4)}, ) ) if "data" in selected: for item in [ {"max_chars": 20_000, "eval_chars": 8_000, "label": "small-data"}, {"max_chars": 80_000, "eval_chars": 20_000, "label": "larger-data"}, ]: scenarios.append( _make_ablation_scenario( item["label"], f"Data Slice: {item['label']}", "WikiText slice-size sensitivity check.", base_config, {"max_chars": int(item["max_chars"]), "eval_chars": int(item["eval_chars"])}, ) ) if "lora" in selected: for rank in lora_r_sweep: scenarios.append( _make_ablation_scenario( f"lora-r{rank}", f"LoRA Rank: {rank}", "Adapter-capacity sensitivity check.", base_config, {"lora_r": int(rank), "lora_alpha": int(rank) * 2}, ) ) if not scenarios: raise ValueError("No scenarios selected. Choose optimizer, lr, schedule, steps, data, or lora.") return scenarios def _ablation_status_markdown( *, device_name: str, rows: list[dict[str, Any]], logs: list[str], phase: str, plan: list[dict[str, Any]], ) -> str: summary = [ f"Device: `{device_name}`", "", f"Status: {phase}", "", "## Plan", "", "| Scenario | Steps | LR | Schedule | Train Chars | Eval Chars | LoRA r |", "| --- | --- | --- | --- | --- | --- | --- |", ] for item in plan: cfg = item["config"] summary.append( "| {label} | {steps} | {lr:g} | {schedule} | {train_chars} | {eval_chars} | {lora_r} |".format( label=item["label"], steps=int(cfg["max_steps"]), lr=float(cfg["lr"]), schedule=cfg["schedule_mode"], train_chars="FULL" if cfg["full_wikitext_train"] else int(cfg["max_chars"]), eval_chars="FULL" if cfg["full_wikitext_eval"] else int(cfg["eval_chars"]), lora_r=int(cfg["lora_r"]), ) ) summary.extend( [ "", "## Metrics", "", "| Scenario | Optimizer | Final Eval PPL | Final Eval Loss | Tokens/s | Scale | Ratio | Stress Mode |", "| --- | --- | --- | --- | --- | --- | --- | --- |", ] ) if rows: for row in rows: summary.append( "| {scenario} | {optimizer} | {ppl} | {loss} | {tps} | {scale} | {ratio} | {stress} |".format( scenario=row.get("scenario"), optimizer=row.get("optimizer"), ppl=_fmt_float(row.get("final_eval_ppl")), loss=_fmt_float(row.get("final_eval_loss")), tps=_fmt_float(row.get("tokens_per_sec_wall"), digits=2), scale=_fmt_float(row.get("scale")), ratio=_fmt_float(row.get("ratio")), stress=row.get("stress_mode") or "-", ) ) else: summary.append("| - | - | - | - | - | - | - | - |") gains = _build_ablation_gain_rows(rows) if gains: summary.extend(["", "## LBW vs AdamW", ""]) for gain in gains: summary.append( f"- `{gain.get('scenario')}`: `{gain.get('optimizer')}` " f"PPL gain `{_fmt_float(gain.get('ppl_gain_pct_vs_adamw'))}%`, " f"loss gain `{_fmt_float(gain.get('loss_gain_pct_vs_adamw'))}%`, " f"speed gain `{_fmt_float(gain.get('speed_gain_pct_vs_adamw'))}%`." ) summary.extend(["", "## Runtime Log", "", "```text", "\n".join(logs[-100:]), "```"]) return "\n".join(summary) def _run_ablation_optimizer_events( *, scenario_item: dict[str, Any], optimizer_name: str, model_name: str, train_chunks: dict[str, Any], eval_chunks: dict[str, Any], device: torch.device, logs: list[str], ): cfg = scenario_item["config"] _set_seed(int(cfg["seed"])) _append_log(logs, f"Loading {model_name} with LoRA for {scenario_item['slug']} / {optimizer_name}.") model = _load_lora_model( model_name=model_name, device=device, lora_r=int(cfg["lora_r"]), lora_alpha=int(cfg["lora_alpha"]), lora_dropout=float(cfg["lora_dropout"]), ) model.train() opt = _make_optimizer( optimizer_name, model, lr=float(cfg["lr"]), betas=tuple(cfg["betas"]), weight_decay=float(cfg["weight_decay"]), lbw_stats_freq=int(cfg["lbw_stats_freq"]), lbw_stress_th=float(cfg["lbw_stress_th"]), lbw_spike_th=float(cfg["lbw_spike_th"]), lbw_rec_fast=float(cfg["lbw_rec_fast"]), lbw_ema_decay=float(cfg["lbw_ema_decay"]), ) train_batches = _batch_iter(train_chunks, batch_size=int(cfg["batch_size"]), device=device) trainable_params = [param for param in model.parameters() if param.requires_grad] start_time = time.time() losses: list[float] = [] eval_loss = None eval_ppl = None last_lr = float(cfg["lr"]) state = _optimizer_state(opt) for step in range(1, int(cfg["max_steps"]) + 1): last_lr = _scheduled_lr(cfg, step) _set_lr(opt, last_lr) xb = next(train_batches) with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")): loss = model(input_ids=xb, labels=xb).loss loss.backward() torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) opt.step() opt.zero_grad(set_to_none=True) loss_value = float(loss.detach().cpu()) losses.append(loss_value) if step == 1 or step == int(cfg["max_steps"]) or step % int(cfg["eval_every"]) == 0: eval_loss, eval_ppl = _evaluate_ppl( model, eval_chunks, batch_size=int(cfg["batch_size"]), eval_batches=int(cfg["eval_batches"]), device=device, full_pass=False, ) state = _optimizer_state(opt) message = ( f"[{scenario_item['slug']}] {optimizer_name} step {step}/{cfg['max_steps']}: " f"loss={loss_value:.4f}, sampled_eval_ppl={eval_ppl:.4f}, " f"lr={last_lr:.2e}, scale={state['scale']:.4f}, ratio={state['ratio']:.4f}" ) _append_log(logs, message) yield {"type": "progress", "message": message} model.train() final_full_pass = bool(cfg["full_validation_ppl"]) if final_full_pass and eval_chunks["cap"] is None: final_scope = "full_wikitext" elif final_full_pass: final_scope = "full_loaded_subset" else: final_scope = "sampled" _append_log(logs, f"Running final {final_scope} validation PPL for {scenario_item['slug']} / {optimizer_name}.") final_loss, final_ppl = _evaluate_ppl( model, eval_chunks, batch_size=int(cfg["batch_size"]), eval_batches=int(cfg["eval_batches"]), device=device, full_pass=final_full_pass, ) state = _optimizer_state(opt) wall_time = max(time.time() - start_time, 1e-9) trained_tokens = int(cfg["max_steps"]) * int(cfg["batch_size"]) * int(cfg["seq_len"]) result = { "scenario_slug": scenario_item["slug"], "scenario": scenario_item["label"], "optimizer": optimizer_name, "final_eval_ppl": final_ppl, "final_eval_loss": final_loss, "train_loss_last": losses[-1] if losses else None, "last_sampled_eval_loss": eval_loss, "last_sampled_eval_ppl": eval_ppl, "final_eval_scope": final_scope, "max_steps": int(cfg["max_steps"]), "lr": float(cfg["lr"]), "scheduled_lr_last": float(last_lr), "schedule_mode": str(cfg["schedule_mode"]), "batch_size": int(cfg["batch_size"]), "seq_len": int(cfg["seq_len"]), "lora_r": int(cfg["lora_r"]), "train_chars": int(train_chunks["chars"]), "eval_chars": int(eval_chunks["chars"]), "train_sequences": int(train_chunks["input_ids"].size(0)), "eval_sequences": int(eval_chunks["input_ids"].size(0)), "scale": state["scale"], "ratio": state["ratio"], "stress_mode": state["stress_mode"], "wall_time_sec": wall_time, "tokens_per_sec_wall": trained_tokens / wall_time, } del model, opt gc.collect() if device.type == "cuda": torch.cuda.empty_cache() yield {"type": "result", "result": result} def _build_ablation_gain_rows(metrics: list[dict[str, Any]]) -> list[dict[str, Any]]: grouped: dict[str, list[dict[str, Any]]] = {} for row in metrics: grouped.setdefault(str(row.get("scenario_slug")), []).append(row) gain_rows: list[dict[str, Any]] = [] for scenario_slug, rows in grouped.items(): baseline = next((row for row in rows if row.get("optimizer") == "adamw"), None) if baseline is None: continue baseline_ppl = _safe_float(baseline.get("final_eval_ppl")) baseline_loss = _safe_float(baseline.get("final_eval_loss")) baseline_tps = _safe_float(baseline.get("tokens_per_sec_wall")) for row in rows: if row.get("optimizer") == "adamw": continue candidate_ppl = _safe_float(row.get("final_eval_ppl")) candidate_loss = _safe_float(row.get("final_eval_loss")) candidate_tps = _safe_float(row.get("tokens_per_sec_wall")) gain_rows.append( { "scenario_slug": scenario_slug, "scenario": row.get("scenario"), "optimizer": row.get("optimizer"), "adamw_final_eval_ppl": baseline_ppl, "optimizer_final_eval_ppl": candidate_ppl, "ppl_gain_pct_vs_adamw": ( None if baseline_ppl in (None, 0.0) or candidate_ppl is None else (baseline_ppl - candidate_ppl) / baseline_ppl * 100.0 ), "loss_gain_pct_vs_adamw": ( None if baseline_loss in (None, 0.0) or candidate_loss is None else (baseline_loss - candidate_loss) / baseline_loss * 100.0 ), "speed_gain_pct_vs_adamw": ( None if baseline_tps in (None, 0.0) or candidate_tps is None else (candidate_tps - baseline_tps) / baseline_tps * 100.0 ), "adamw_tokens_per_sec_wall": baseline_tps, "optimizer_tokens_per_sec_wall": candidate_tps, "lbw_scale": row.get("scale"), "lbw_ratio": row.get("ratio"), "lbw_stress_mode": row.get("stress_mode"), } ) return gain_rows def run_ablation_test( model_name: str, selected_ablations: list[str], run_lbw_guard: bool, max_steps: int, eval_every: int, eval_batches: int, seq_len: int, batch_size: int, train_chars: int, eval_chars: int, full_wikitext_train: bool, full_wikitext_eval: bool, full_validation_ppl: bool, lr: float, schedule_mode: str, warmup_steps: int, seed: int, lr_sweep_text: str, step_sweep_text: str, lora_r_sweep_text: str, ): logs: list[str] = [] rows: list[dict[str, Any]] = [] run_dir = RUNS_DIR / f"ablation_test_{int(time.time())}" run_dir.mkdir(parents=True, exist_ok=True) device_name = _device_default() device = torch.device(device_name) optimizers = ["adamw", "lbw_guard"] if bool(run_lbw_guard) else ["adamw"] try: base_config = { "seed": int(seed), "max_steps": int(max_steps), "eval_every": max(1, int(eval_every)), "eval_batches": int(eval_batches), "seq_len": int(seq_len), "batch_size": int(batch_size), "max_chars": int(train_chars), "eval_chars": int(eval_chars), "full_wikitext_train": bool(full_wikitext_train), "full_wikitext_eval": bool(full_wikitext_eval), "full_validation_ppl": bool(full_validation_ppl), "lr": float(lr), "betas": (0.9, 0.999), "weight_decay": 0.01, "warmup_steps": int(warmup_steps), "schedule_mode": str(schedule_mode or "constant").strip().lower(), "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05, "lbw_stats_freq": 10, "lbw_stress_th": 1.1, "lbw_spike_th": 1.5, "lbw_rec_fast": 0.01, "lbw_ema_decay": 0.95, } lr_sweep = _parse_float_sweep(lr_sweep_text, [1e-3, 5e-4]) step_sweep = _parse_int_sweep(step_sweep_text, [100, 200]) lora_r_sweep = _parse_int_sweep(lora_r_sweep_text, [4, 8, 16]) scenarios = _build_ablation_scenarios( selected_ablations=list(selected_ablations or ["optimizer"]), base_config=base_config, lr_sweep=lr_sweep, step_sweep=step_sweep, lora_r_sweep=lora_r_sweep, ) if device.type == "cpu" and ( len(scenarios) > 1 or int(max_steps) > 1 or int(train_chars) > 20_000 or int(eval_chars) > 8_000 or bool(full_wikitext_train) or bool(full_wikitext_eval) or bool(full_validation_ppl) ): yield ( "This Space is currently on `cpu-basic`. CPU ablation mode is capped to one optimizer scenario, " "1 step, 20k train chars, 8k eval chars, and sampled validation. Switch the Space hardware to GPU for ablations.", None, None, None, ) return if device.type == "cuda" and bool(run_lbw_guard) and torch.cuda.device_count() > 1: yield ( "LBW Guard should run with one visible GPU. Set the Space to single-GPU hardware or restrict CUDA_VISIBLE_DEVICES.", None, None, None, ) return resolved_model = str(model_name).strip() or "Qwen/Qwen2.5-0.5B" _append_log(logs, f"Device: {device_name}") if device.type == "cuda": _append_log(logs, f"GPU: {torch.cuda.get_device_name(0)}") _append_log(logs, f"Selected ablations: {', '.join(selected_ablations or ['optimizer'])}") _append_log(logs, f"Optimizers: {', '.join(optimizers)}") yield _ablation_status_markdown( device_name=device_name, rows=rows, logs=logs, phase="Loading tokenizer", plan=scenarios, ), None, None, None tokenizer = AutoTokenizer.from_pretrained(resolved_model, use_fast=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token data_cache: dict[tuple[int, int | None, int | None], dict[str, dict[str, Any]]] = {} for scenario_item in scenarios: cfg = scenario_item["config"] train_cap = None if cfg["full_wikitext_train"] else int(cfg["max_chars"]) eval_cap = None if cfg["full_wikitext_eval"] else int(cfg["eval_chars"]) cache_key = (int(cfg["seq_len"]), train_cap, eval_cap) if cache_key not in data_cache: data_cache[cache_key] = { "train": _build_wikitext_chunks( tokenizer, split="train", max_chars=train_cap, seq_len=int(cfg["seq_len"]), logs=logs, ), "eval": _build_wikitext_chunks( tokenizer, split="validation", max_chars=eval_cap, seq_len=int(cfg["seq_len"]), logs=logs, ), } _append_log(logs, f"=== Scenario: {scenario_item['label']} ===") for optimizer_name in optimizers: _append_log(logs, f"--- {optimizer_name} ---") yield _ablation_status_markdown( device_name=device_name, rows=rows, logs=logs, phase=f"Running {scenario_item['label']} / {optimizer_name}", plan=scenarios, ), None, None, None for event in _run_ablation_optimizer_events( scenario_item=scenario_item, optimizer_name=optimizer_name, model_name=resolved_model, train_chunks=data_cache[cache_key]["train"], eval_chunks=data_cache[cache_key]["eval"], device=device, logs=logs, ): if event.get("type") == "result": rows.append(event["result"]) yield _ablation_status_markdown( device_name=device_name, rows=rows, logs=logs, phase=f"Running {scenario_item['label']} / {optimizer_name}", plan=scenarios, ), None, None, None gains = _build_ablation_gain_rows(rows) payload = { "source": "HF Ablation Matrix Runner", "based_on_colab": "LBW_Guard_Ablation_Test_COLAB.ipynb", "model_name": resolved_model, "device": device_name, "optimizers": optimizers, "selected_ablations": list(selected_ablations or ["optimizer"]), "base_config": base_config, "scenarios": scenarios, "results": rows, "gains": gains, "logs": logs, } json_path = run_dir / "lbw_guard_hf_ablation_matrix_results.json" metrics_path = run_dir / "lbw_guard_hf_ablation_matrix_metrics.csv" gains_path = run_dir / "lbw_guard_hf_ablation_matrix_gains.csv" json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") _write_csv(metrics_path, rows) _write_csv(gains_path, gains) _append_log(logs, f"Wrote {metrics_path}") _append_log(logs, f"Wrote {gains_path}") yield ( _ablation_status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Complete", plan=scenarios), str(json_path), str(metrics_path), str(gains_path), ) except Exception: error_text = traceback.format_exc() error_path = run_dir / "error.txt" error_path.write_text(error_text + "\n\n" + "\n".join(logs), encoding="utf-8") yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None, None INTRO = """ # LBW Guard HF Evaluation Runner This private Space has two HF-specific runners based on the customer Colab notebooks: - **Quick Comparison**: a short AdamW vs `lbw_guard` WikiText LoRA run. - **Ablation Matrix**: a selectable scenario sweep across optimizer, LR, schedule, steps, data, and LoRA rank. Each run writes JSON and CSV artifacts. GPU hardware is recommended; CPU mode is only for tiny smoke checks. """ with gr.Blocks(title="LBW Guard HF Evaluation Runner") as demo: gr.Markdown(INTRO) with gr.Tabs(): with gr.Tab("Quick Comparison"): gr.Markdown( "Runs the HF version of the Easy Test flow: one shared WikiText slice, then AdamW and optional " "`lbw_guard` LoRA training with final perplexity and gain CSV output." ) with gr.Row(): easy_model_name = gr.Textbox(value="TinyLlama/TinyLlama_v1.1", label="Model") easy_run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison") with gr.Row(): easy_max_steps = gr.Slider(1, 1000, value=5, step=1, label="Optimizer steps") easy_eval_every = gr.Slider(1, 200, value=5, step=1, label="Eval every") easy_eval_batches = gr.Slider(1, 128, value=8, step=1, label="Eval batches") with gr.Row(): easy_seq_len = gr.Dropdown([64, 128, 256, 512], value=64, label="Sequence length") easy_batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch size") easy_lr = gr.Number(value=5e-4, label="Learning rate") with gr.Row(): easy_train_chars = gr.Slider(5_000, 2_000_000, value=20_000, step=5_000, label="Train char cap") easy_eval_chars = gr.Slider(1_000, 500_000, value=8_000, step=1_000, label="Eval char cap") easy_seed = gr.Number(value=42, precision=0, label="Seed") with gr.Row(): easy_full_wikitext_train = gr.Checkbox(value=False, label="Full WikiText train") easy_full_wikitext_eval = gr.Checkbox(value=False, label="Full WikiText eval") easy_full_validation_ppl = gr.Checkbox(value=False, label="Full validation PPL") easy_run_button = gr.Button("Run Quick Comparison", variant="primary") easy_summary = gr.Markdown() easy_json_file = gr.File(label="Raw JSON") easy_results_file = gr.File(label="Results CSV") easy_gains_file = gr.File(label="Gains CSV") easy_run_button.click( fn=run_easy_test, inputs=[ easy_model_name, easy_run_lbw_guard, easy_max_steps, easy_eval_every, easy_eval_batches, easy_seq_len, easy_batch_size, easy_train_chars, easy_eval_chars, easy_full_wikitext_train, easy_full_wikitext_eval, easy_full_validation_ppl, easy_lr, easy_seed, ], outputs=[easy_summary, easy_json_file, easy_results_file, easy_gains_file], ) with gr.Tab("Ablation Matrix"): gr.Markdown( "Runs the HF version of the ablation flow: build selected scenarios, run AdamW and optional " "`lbw_guard` for each scenario, then export metrics and LBW-vs-AdamW gains." ) with gr.Row(): ablation_model_name = gr.Textbox(value="Qwen/Qwen2.5-0.5B", label="Model") ablation_run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison") selected_ablations = gr.CheckboxGroup( choices=["optimizer", "lr", "schedule", "steps", "data", "lora"], value=["optimizer"], label="Ablations", ) with gr.Row(): ablation_max_steps = gr.Slider(1, 1000, value=200, step=1, label="Base optimizer steps") ablation_eval_every = gr.Slider(1, 200, value=50, step=1, label="Eval every") ablation_eval_batches = gr.Slider(1, 128, value=8, step=1, label="Eval batches") with gr.Row(): ablation_seq_len = gr.Dropdown([64, 128, 256, 512], value=64, label="Sequence length") ablation_batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch size") ablation_lr = gr.Number(value=5e-4, label="Base learning rate") with gr.Row(): ablation_train_chars = gr.Slider(5_000, 2_000_000, value=20_000, step=5_000, label="Train char cap") ablation_eval_chars = gr.Slider(1_000, 500_000, value=8_000, step=1_000, label="Eval char cap") ablation_seed = gr.Number(value=42, precision=0, label="Seed") with gr.Row(): ablation_schedule_mode = gr.Dropdown(["constant", "cosine"], value="constant", label="Base schedule") ablation_warmup_steps = gr.Slider(0, 100, value=10, step=1, label="Warmup steps") with gr.Row(): ablation_full_wikitext_train = gr.Checkbox(value=False, label="Full WikiText train") ablation_full_wikitext_eval = gr.Checkbox(value=False, label="Full WikiText eval") ablation_full_validation_ppl = gr.Checkbox(value=False, label="Full validation PPL") with gr.Row(): lr_sweep_text = gr.Textbox(value="1e-3, 5e-4", label="LR sweep") step_sweep_text = gr.Textbox(value="100, 200", label="Step sweep") lora_r_sweep_text = gr.Textbox(value="4, 8, 16", label="LoRA r sweep") ablation_run_button = gr.Button("Run Ablation Matrix", variant="primary") ablation_summary = gr.Markdown() ablation_json_file = gr.File(label="Raw JSON") ablation_metrics_file = gr.File(label="Metrics CSV") ablation_gains_file = gr.File(label="Gains CSV") ablation_run_button.click( fn=run_ablation_test, inputs=[ ablation_model_name, selected_ablations, ablation_run_lbw_guard, ablation_max_steps, ablation_eval_every, ablation_eval_batches, ablation_seq_len, ablation_batch_size, ablation_train_chars, ablation_eval_chars, ablation_full_wikitext_train, ablation_full_wikitext_eval, ablation_full_validation_ppl, ablation_lr, ablation_schedule_mode, ablation_warmup_steps, ablation_seed, lr_sweep_text, step_sweep_text, lora_r_sweep_text, ], outputs=[ablation_summary, ablation_json_file, ablation_metrics_file, ablation_gains_file], ) if __name__ == "__main__": demo.queue(default_concurrency_limit=1).launch()