Spaces:
Running
Running
| from __future__ import annotations | |
| import csv | |
| import gc | |
| import json | |
| import math | |
| import os | |
| import random | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| from typing import Any | |
# Directory that contains this file; every cache path below is anchored to it.
ROOT = Path(__file__).resolve().parent

# Default the Hugging Face and Weights & Biases cache/config locations to
# folders next to this file. ``setdefault`` leaves any value already present
# in the environment untouched.
for _env_name, _relative_parts in (
    ("HF_HOME", (".hf_cache",)),
    ("HF_DATASETS_CACHE", (".hf_cache", "datasets")),
    ("TRANSFORMERS_CACHE", (".hf_cache", "transformers")),
    ("WANDB_DIR", (".wandb",)),
    ("WANDB_CACHE_DIR", (".wandb", "cache")),
    ("WANDB_CONFIG_DIR", (".wandb", "config")),
):
    os.environ.setdefault(_env_name, str(ROOT.joinpath(*_relative_parts).resolve()))
del _env_name, _relative_parts

os.environ.setdefault("DISABLE_SAFETENSORS_CONVERSION", "1")
| import gradio as gr | |
| import torch | |
| from datasets import load_dataset | |
| from peft import LoraConfig, TaskType, get_peft_model | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# The LBW Guard package is optional at import time: a failed import is kept in
# LBW_IMPORT_ERROR (and ``lbw`` is set to None) instead of crashing the module.
LBW_IMPORT_ERROR = None
try:
    import lbw
except Exception as import_failure:  # pragma: no cover - shown in the Space UI.
    lbw = None
    LBW_IMPORT_ERROR = import_failure

# Parent directory for run output folders.
RUNS_DIR = ROOT / "runs"
def _device_default() -> str:
    """Return ``"cuda"`` when torch reports CUDA as available, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def _set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch (plus all CUDA devices when available)."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _safe_float(value: Any) -> float | None:
    """Convert ``value`` to a finite float, or return ``None``.

    ``None`` is returned for a ``None`` input, for anything ``float()`` rejects,
    and for NaN or infinite results.
    """
    try:
        number = None if value is None else float(value)
    except Exception:
        return None
    if number is None or not math.isfinite(number):
        return None
    return number
def _fmt_float(value: Any, digits: int = 4) -> str:
    """Format ``value`` with ``digits`` decimals, or ``"-"`` if it is not a finite number."""
    parsed = _safe_float(value)
    if parsed is None:
        return "-"
    return format(parsed, f".{digits}f")
def _append_log(logs: list[str], message: str) -> None:
    """Record ``message`` in the in-memory ``logs`` list and echo it to stdout.

    The print is flushed immediately so the line shows up in the process
    output without buffering delay.
    """
    logs.append(message)
    print(message, flush=True)
def _build_wikitext_chunks(
    tokenizer,
    *,
    split: str,
    max_chars: int | None,
    seq_len: int,
    logs: list[str],
) -> dict[str, Any]:
    """Load a WikiText-103 (raw) split and cut it into fixed-length token rows.

    Non-blank rows are concatenated, separated by a single space, until
    ``max_chars`` characters have been collected (``None`` means the whole
    split). The text is tokenized in one call without special tokens and the
    ids are reshaped to ``(sequence_count, seq_len)``; trailing tokens that do
    not fill a complete row are dropped.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry.
        split: Dataset split name passed to ``load_dataset``.
        max_chars: Character budget, or ``None`` for no cap.
        seq_len: Tokens per sequence; must be positive.
        logs: Log list that receives progress messages.

    Returns:
        Dict with ``input_ids`` (long tensor), ``chars`` and ``rows`` actually
        used, and the resolved ``cap``.

    Raises:
        ValueError: If ``seq_len`` is not a positive integer.
        RuntimeError: If the collected text yields fewer than ``seq_len`` tokens.
    """
    # Validate first: seq_len == 0 would otherwise surface as a bare
    # ZeroDivisionError, and a negative value as a misleading "Not enough tokens".
    tokens_per_row = int(seq_len)
    if tokens_per_row <= 0:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}.")

    cap = None if max_chars is None else int(max_chars)
    _append_log(
        logs,
        f"Preparing WikiText split={split!r}" + (f" with char cap {cap:,}" if cap is not None else " with full split"),
    )
    ds = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

    pieces: list[str] = []
    chars_used = 0
    for row in ds:
        text = str(row.get("text", "") or "")
        if not text.strip():
            continue
        # Every piece after the first is joined with a single leading space.
        piece = " " + text if pieces else text
        if cap is not None:
            remain = cap - chars_used
            if remain <= 0:
                break
            # Truncate the last piece so the total never exceeds the cap.
            piece = piece[:remain]
        pieces.append(piece)
        chars_used += len(piece)
        if cap is not None and chars_used >= cap:
            break
    rows_used = len(pieces)

    token_ids = tokenizer("".join(pieces), add_special_tokens=False)["input_ids"]
    ids = torch.tensor(token_ids, dtype=torch.long)
    sequence_count = ids.numel() // tokens_per_row
    if sequence_count <= 0:
        raise RuntimeError("Not enough tokens. Increase the train/eval char cap or reduce sequence length.")
    ids = ids[: sequence_count * tokens_per_row].view(sequence_count, tokens_per_row).contiguous()
    _append_log(
        logs,
        f"Prepared split={split!r}: {chars_used:,} chars across {rows_used:,} rows -> {ids.size(0):,} sequences",
    )
    return {"input_ids": ids, "chars": chars_used, "rows": rows_used, "cap": cap}
def _batch_iter(chunks: dict[str, Any], *, batch_size: int, device: torch.device):
    """Yield training batches from ``chunks["input_ids"]`` forever.

    Consecutive slices of ``batch_size`` rows are moved to ``device`` and
    yielded. When the next full slice would run past the end, iteration wraps
    back to row 0, so a trailing partial batch is never produced (unless the
    data holds fewer than ``batch_size`` rows, in which case every batch is
    the whole tensor).

    Raises:
        ValueError: On first iteration if ``batch_size`` is not positive; a
            non-positive size would otherwise yield empty batches endlessly.
    """
    ids = chunks["input_ids"]
    rows_per_batch = int(batch_size)
    if rows_per_batch <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}.")
    start = 0
    while True:
        # Wrap around instead of emitting a short final batch.
        if start + rows_per_batch > ids.size(0):
            start = 0
        batch = ids[start : start + rows_per_batch].to(device, non_blocking=True)
        start += rows_per_batch
        yield batch
def _load_lora_model(
    *,
    model_name: str,
    device: torch.device,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
):
    """Load a causal LM, move it to ``device`` and wrap it with a LoRA adapter.

    Weights are loaded as float16 on CUDA and float32 otherwise. If the model
    config exposes ``use_cache`` it is switched off. The adapter targets the
    attention and MLP projection module names listed below.

    Returns:
        The object returned by ``peft.get_peft_model``.
    """
    weight_dtype = torch.float16 if device.type == "cuda" else torch.float32
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=True,
    )
    # Only touch use_cache when the config actually defines it.
    if getattr(base_model.config, "use_cache", None) is not None:
        base_model.config.use_cache = False
    base_model.to(device)

    projection_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    adapter_config = LoraConfig(
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        target_modules=projection_modules,
        task_type=TaskType.CAUSAL_LM,
        bias="none",
    )
    return get_peft_model(base_model, adapter_config)
def _make_optimizer(
    name: str,
    model,
    *,
    lr: float,
    betas: tuple[float, float],
    weight_decay: float,
    lbw_stats_freq: int,
    lbw_stress_th: float,
    lbw_spike_th: float,
    lbw_rec_fast: float,
    lbw_ema_decay: float,
):
    """Build the optimizer called ``name`` over the model's trainable parameters.

    ``"adamw"`` returns ``torch.optim.AdamW``; ``"lbw_guard"`` returns
    ``lbw.Guard`` with the ``lbw_*`` arguments forwarded to it.

    Raises:
        RuntimeError: ``"lbw_guard"`` was requested but the ``lbw`` import failed.
        ValueError: ``name`` is neither ``"adamw"`` nor ``"lbw_guard"``.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]

    if name == "adamw":
        return torch.optim.AdamW(trainable, lr=float(lr), betas=betas, weight_decay=float(weight_decay))
    if name != "lbw_guard":
        raise ValueError(f"Unknown optimizer: {name}")
    if lbw is None:
        raise RuntimeError(f"LBW Guard package import failed: {LBW_IMPORT_ERROR}")

    guard_options = dict(
        lr=float(lr),
        betas=betas,
        weight_decay=float(weight_decay),
        mode="eval",
        auto_enabled=True,
        stats_freq=int(lbw_stats_freq),
        stress_threshold=float(lbw_stress_th),
        spike_threshold=float(lbw_spike_th),
        recovery_fast=float(lbw_rec_fast),
        ema_decay=float(lbw_ema_decay),
        use_max_rms=True,
    )
    return lbw.Guard(trainable, **guard_options)
| def _evaluate_ppl( | |
| model, | |
| eval_chunks: dict[str, Any], | |
| *, | |
| batch_size: int, | |
| eval_batches: int, | |
| device: torch.device, | |
| full_pass: bool, | |
| ) -> tuple[float, float]: | |
| model.eval() | |
| ids = eval_chunks["input_ids"] | |
| max_sequences = ids.size(0) if full_pass else min(ids.size(0), int(eval_batches) * int(batch_size)) | |
| losses: list[float] = [] | |
| for start in range(0, max_sequences, int(batch_size)): | |
| xb = ids[start : start + int(batch_size)].to(device, non_blocking=True) | |
| with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")): | |
| loss = model(input_ids=xb, labels=xb).loss | |
| losses.append(float(loss.detach().cpu())) | |
| avg_loss = sum(losses) / max(len(losses), 1) | |
| return avg_loss, math.exp(min(avg_loss, 20.0)) | |
def _optimizer_state(opt) -> dict[str, Any]:
    """Read the ``"lbw"`` entry of ``opt.state`` into a plain summary dict.

    Missing attributes or keys fall back to ``scale=1.0`` (``"lbw_scale"`` is
    accepted as an alternative key), ``ratio=1.0`` and ``stress_mode="none"``.
    NOTE(review): presumably ``state["lbw"]`` is where LBW Guard publishes its
    diagnostics - confirm against the lbw package.
    """
    optimizer_state = getattr(opt, "state", {})
    lbw_state = dict(optimizer_state.get("lbw", {}) or {})
    raw_scale = lbw_state.get("scale", lbw_state.get("lbw_scale", 1.0))
    raw_ratio = lbw_state.get("ratio", 1.0)
    raw_mode = lbw_state.get("stress_mode", "none")
    return {
        "scale": float(raw_scale),
        "ratio": float(raw_ratio),
        "stress_mode": str(raw_mode),
    }
def _status_markdown(
    *,
    device_name: str,
    rows: list[dict[str, Any]],
    logs: list[str],
    phase: str,
) -> str:
    """Render the quick-comparison status panel as a Markdown string.

    The panel shows the device, the current ``phase``, one results-table line
    per entry in ``rows`` (or a placeholder line when empty), an
    "LBW vs AdamW" section when gain rows exist, and the last 80 log lines in
    a fenced text block.
    """
    lines = [
        f"Device: `{device_name}`",
        "",
        f"Status: {phase}",
        "",
        "## Results",
        "",
        "| Optimizer | Final Eval PPL | Final Eval Loss | Scope | Scale | Ratio | Stress Mode | Wall Time (s) |",
        "| --- | --- | --- | --- | --- | --- | --- | --- |",
    ]

    if not rows:
        lines.append("| - | - | - | - | - | - | - | - |")
    for row in rows:
        cells = [
            row.get("optimizer"),
            _fmt_float(row.get("final_eval_ppl")),
            _fmt_float(row.get("final_eval_loss")),
            row.get("final_eval_scope") or "-",
            _fmt_float(row.get("scale")),
            _fmt_float(row.get("ratio")),
            row.get("stress_mode") or "-",
            _fmt_float(row.get("wall_time_sec"), digits=2),
        ]
        lines.append("| " + " | ".join(str(cell) for cell in cells) + " |")

    gains = _gain_rows(rows)
    if gains:
        lines += ["", "## LBW vs AdamW", ""]
    for gain in gains:
        pct = _safe_float(gain.get("eval_perplexity_pct_gain_vs_adamw"))
        wall_speedup = _safe_float(gain.get("wall_time_speedup_vs_adamw"))
        ppl_line = f"- `{gain.get('optimizer')}` PPL gain vs AdamW: `{_fmt_float(gain.get('eval_perplexity_gain_vs_adamw'))}`"
        # The percentage is appended only when it could be computed.
        ppl_line += f" (`{pct * 100.0:.2f}%`)." if pct is not None else "."
        lines.append(ppl_line)
        if wall_speedup is not None:
            lines.append(f"- `{gain.get('optimizer')}` wall-time speedup vs AdamW: `{wall_speedup:.3f}x`.")

    lines += ["", "## Runtime Log", "", "```text", "\n".join(logs[-80:]), "```"]
    return "\n".join(lines)
def _run_one_optimizer_events(
    *,
    optimizer_name: str,
    model_name: str,
    train_chunks: dict[str, Any],
    eval_chunks: dict[str, Any],
    device: torch.device,
    seed: int,
    max_steps: int,
    eval_every: int,
    eval_batches: int,
    seq_len: int,
    batch_size: int,
    lr: float,
    betas: tuple[float, float],
    weight_decay: float,
    full_validation_ppl: bool,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    lbw_stats_freq: int,
    lbw_stress_th: float,
    lbw_spike_th: float,
    lbw_rec_fast: float,
    lbw_ema_decay: float,
    logs: list[str],
):
    """Train one LoRA model with ``optimizer_name`` and stream events.

    Generator. After step 1, the last step, and every ``eval_every``-th step a
    sampled evaluation runs and ``{"type": "progress", "message": ...}`` is
    yielded. Once training and the final validation are done, the model and
    optimizer are released and a single ``{"type": "result", "result": ...}``
    event carrying the metrics dict is yielded.
    """
    _set_seed(int(seed))
    _append_log(logs, f"Loading {model_name} with LoRA for {optimizer_name}.")
    model = _load_lora_model(
        model_name=model_name,
        device=device,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    model.train()
    opt = _make_optimizer(
        optimizer_name,
        model,
        lr=lr,
        betas=betas,
        weight_decay=weight_decay,
        lbw_stats_freq=lbw_stats_freq,
        lbw_stress_th=lbw_stress_th,
        lbw_spike_th=lbw_spike_th,
        lbw_rec_fast=lbw_rec_fast,
        lbw_ema_decay=lbw_ema_decay,
    )
    train_batches = _batch_iter(train_chunks, batch_size=batch_size, device=device)
    start_time = time.time()

    last_loss = None
    last_eval_loss = None
    last_eval_ppl = None
    state = _optimizer_state(opt)
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    def sampled_eval() -> tuple[float, float]:
        # Periodic checks always use the sampled (non-full) evaluation.
        return _evaluate_ppl(
            model,
            eval_chunks,
            batch_size=batch_size,
            eval_batches=eval_batches,
            device=device,
            full_pass=False,
        )

    total_steps = int(max_steps)
    for step in range(1, total_steps + 1):
        xb = next(train_batches)
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")):
            loss = model(input_ids=xb, labels=xb).loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)
        last_loss = float(loss.detach().cpu())
        state = _optimizer_state(opt)

        if step == 1 or step == total_steps or step % int(eval_every) == 0:
            last_eval_loss, last_eval_ppl = sampled_eval()
            message = (
                f"{optimizer_name} step {step}/{int(max_steps)}: "
                f"loss={last_loss:.4f}, sampled_eval_ppl={last_eval_ppl:.4f}, "
                f"scale={state['scale']:.4f}, ratio={state['ratio']:.4f}"
            )
            _append_log(logs, message)
            yield {"type": "progress", "message": message}
            # Evaluation leaves the model in eval mode; resume training mode.
            model.train()

    # Label what the final number was measured on.
    final_full_pass = bool(full_validation_ppl)
    if not final_full_pass:
        final_scope = "sampled"
    elif eval_chunks["cap"] is None:
        final_scope = "full_wikitext"
    else:
        final_scope = "full_loaded_subset"
    _append_log(logs, f"Running final {final_scope} validation PPL for {optimizer_name}.")
    final_loss, final_ppl = _evaluate_ppl(
        model,
        eval_chunks,
        batch_size=batch_size,
        eval_batches=eval_batches,
        device=device,
        full_pass=final_full_pass,
    )
    state = _optimizer_state(opt)
    wall_time = time.time() - start_time

    result = {
        "optimizer": optimizer_name,
        "final_eval_ppl": final_ppl,
        "final_eval_loss": final_loss,
        "final_eval_scope": final_scope,
        "train_chars": train_chunks["chars"],
        "eval_chars": eval_chunks["chars"],
        "train_sequences": int(train_chunks["input_ids"].size(0)),
        "eval_sequences": int(eval_chunks["input_ids"].size(0)),
        "tokens_per_step": int(batch_size) * int(seq_len),
        "last_train_loss": last_loss,
        "last_sampled_eval_loss": last_eval_loss,
        "last_sampled_eval_ppl": last_eval_ppl,
        "scale": state["scale"],
        "ratio": state["ratio"],
        "stress_mode": state["stress_mode"],
        "wall_time_sec": wall_time,
    }

    # Drop the model and optimizer before handing the result back so the next
    # run starts from freed memory.
    del model, opt
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    yield {"type": "result", "result": result}
def _gain_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Compare every non-AdamW result row against the AdamW row.

    Returns an empty list when no row has ``optimizer == "adamw"``. Otherwise
    each other row yields a dict with the absolute PPL gain, the fractional
    PPL gain, and the wall-time speedup; a metric is ``None`` whenever an
    input is missing, non-finite, or would divide by zero.
    """
    lookup = {str(row.get("optimizer")): row for row in rows}
    baseline = lookup.get("adamw")
    if baseline is None:
        return []

    base_ppl = _safe_float(baseline.get("final_eval_ppl"))
    base_wall = _safe_float(baseline.get("wall_time_sec"))

    def is_missing_or_zero(number: float | None) -> bool:
        return number is None or number == 0.0

    gains: list[dict[str, Any]] = []
    for row in rows:
        if row.get("optimizer") == "adamw":
            continue
        cand_ppl = _safe_float(row.get("final_eval_ppl"))
        cand_wall = _safe_float(row.get("wall_time_sec"))

        if base_ppl is None or cand_ppl is None:
            absolute_gain = None
        else:
            absolute_gain = base_ppl - cand_ppl

        if is_missing_or_zero(base_ppl) or cand_ppl is None:
            fractional_gain = None
        else:
            fractional_gain = (base_ppl - cand_ppl) / base_ppl

        if is_missing_or_zero(base_wall) or is_missing_or_zero(cand_wall):
            speedup = None
        else:
            speedup = base_wall / cand_wall

        gains.append(
            {
                "optimizer": row.get("optimizer"),
                "eval_perplexity_gain_vs_adamw": absolute_gain,
                "eval_perplexity_pct_gain_vs_adamw": fractional_gain,
                "wall_time_speedup_vs_adamw": speedup,
            }
        )
    return gains
def _write_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header line.

    An empty ``rows`` list produces an empty file. The header is the union of
    all row keys in first-seen order, so rows that carry extra keys no longer
    make ``csv.DictWriter`` raise; cells a row does not provide are left blank.
    For rows that all share the same keys the output is unchanged.
    """
    if not rows:
        path.write_text("", encoding="utf-8")
        return
    # dict.fromkeys de-duplicates while preserving first-seen order.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def _set_lr(opt, value: float) -> None:
    """Set ``lr`` to ``float(value)`` on every parameter group of ``opt``.

    Does nothing when ``opt`` has no ``param_groups`` attribute or it is empty/None.
    """
    param_groups = getattr(opt, "param_groups", []) or []
    for param_group in param_groups:
        param_group["lr"] = float(value)
def _scheduled_lr(cfg: dict[str, Any], step: int) -> float:
    """Return the learning rate for ``step`` under the schedule described by ``cfg``.

    During warmup (``step <= warmup_steps``) the rate ramps linearly up to
    ``cfg["lr"]``. Afterwards ``schedule_mode == "cosine"`` decays it with a
    half cosine towards zero at ``max_steps``; any other mode keeps it constant.
    """
    peak_lr = float(cfg["lr"])
    warmup_steps = max(int(cfg.get("warmup_steps", 0)), 0)
    total_steps = max(int(cfg["max_steps"]), 1)

    in_warmup = warmup_steps > 0 and int(step) <= warmup_steps
    if in_warmup:
        return peak_lr * float(step) / float(warmup_steps)

    schedule = str(cfg.get("schedule_mode", "constant")).strip().lower()
    if schedule != "cosine":
        return peak_lr

    # Fraction of the post-warmup span completed, clamped to [0, 1].
    decay_span = max(total_steps - warmup_steps, 1)
    fraction = (int(step) - warmup_steps) / decay_span
    fraction = min(max(fraction, 0.0), 1.0)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * fraction))
def _parse_float_sweep(text: str, default: list[float]) -> list[float]:
    """Parse a list of floats separated by commas, semicolons or newlines.

    Blank entries are ignored. If nothing is parsed, a copy of ``default`` is
    returned. A non-numeric entry raises ``ValueError`` from ``float()``.
    """
    normalized = str(text or "").replace("\n", ",").replace(";", ",")
    tokens = (chunk.strip() for chunk in normalized.split(","))
    parsed = [float(token) for token in tokens if token]
    return parsed or list(default)
def _parse_int_sweep(text: str, default: list[int]) -> list[int]:
    """Parse a sweep like ``_parse_float_sweep`` and convert each value with ``int()``.

    Values are parsed as floats first, so ``"1e2"`` is accepted and fractional
    parts are truncated by ``int()``.
    """
    float_defaults = list(map(float, default))
    return list(map(int, _parse_float_sweep(text, float_defaults)))
def run_easy_test(
    model_name: str,
    run_lbw_guard: bool,
    max_steps: int,
    eval_every: int,
    eval_batches: int,
    seq_len: int,
    batch_size: int,
    train_chars: int,
    eval_chars: int,
    full_wikitext_train: bool,
    full_wikitext_eval: bool,
    full_validation_ppl: bool,
    lr: float,
    seed: int,
):
    """Run the quick AdamW (and optionally LBW Guard) comparison as a generator.

    Every yielded item is a 4-tuple ``(markdown, json_path, csv_path,
    gains_path)``. Intermediate updates carry ``None`` for the three paths;
    the final update carries the written result files. If anything raises,
    the traceback and logs are written to ``error.txt`` in the run directory
    and a failure message is yielded with that path in the second slot.

    On CPU the run is refused unless it stays within 1 step, 20k train chars,
    8k eval chars and sampled validation. LBW Guard runs are refused when more
    than one CUDA device is visible.
    """
    logs: list[str] = []
    rows: list[dict[str, Any]] = []
    run_dir = RUNS_DIR / f"easy_test_{int(time.time())}"
    run_dir.mkdir(parents=True, exist_ok=True)
    device_name = _device_default()
    device = torch.device(device_name)
    optimizers = ["adamw", "lbw_guard"] if bool(run_lbw_guard) else ["adamw"]

    # Fixed hyperparameters: used for every optimizer run and recorded verbatim
    # in the saved config (dict order is the order they appear in the JSON).
    fixed_hparams = {
        "weight_decay": 0.01,
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lbw_stats_freq": 10,
        "lbw_stress_th": 1.1,
        "lbw_spike_th": 1.5,
        "lbw_rec_fast": 0.01,
        "lbw_ema_decay": 0.95,
    }

    def progress(phase: str):
        # Intermediate UI update: status panel only, no downloadable files yet.
        return _status_markdown(device_name=device_name, rows=rows, logs=logs, phase=phase), None, None, None

    try:
        exceeds_cpu_limits = device.type == "cpu" and (
            int(max_steps) > 1
            or int(train_chars) > 20_000
            or int(eval_chars) > 8_000
            or bool(full_wikitext_train)
            or bool(full_wikitext_eval)
            or bool(full_validation_ppl)
        )
        if exceeds_cpu_limits:
            yield (
                "This Space is currently on `cpu-basic`. CPU mode is capped to 1 step, 20k train chars, "
                "8k eval chars, and sampled validation. Switch the Space hardware to GPU for the Quick Comparison defaults.",
                None,
                None,
                None,
            )
            return

        multi_gpu_lbw = device.type == "cuda" and bool(run_lbw_guard) and torch.cuda.device_count() > 1
        if multi_gpu_lbw:
            yield (
                "LBW Guard should run with one visible GPU. Set the Space to single-GPU hardware or restrict CUDA_VISIBLE_DEVICES.",
                None,
                None,
                None,
            )
            return

        _append_log(logs, f"Device: {device_name}")
        if device.type == "cuda":
            _append_log(logs, f"GPU: {torch.cuda.get_device_name(0)}")
        _append_log(logs, f"Optimizers: {', '.join(optimizers)}")
        yield progress("Loading tokenizer")

        _set_seed(int(seed))
        resolved_model = str(model_name).strip() or "TinyLlama/TinyLlama_v1.1"
        tokenizer = AutoTokenizer.from_pretrained(resolved_model, use_fast=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # A cap of None means "use the whole split".
        train_cap = None if bool(full_wikitext_train) else int(train_chars)
        eval_cap = None if bool(full_wikitext_eval) else int(eval_chars)

        train_chunks = _build_wikitext_chunks(
            tokenizer,
            split="train",
            max_chars=train_cap,
            seq_len=int(seq_len),
            logs=logs,
        )
        yield progress("Prepared train split")
        eval_chunks = _build_wikitext_chunks(
            tokenizer,
            split="validation",
            max_chars=eval_cap,
            seq_len=int(seq_len),
            logs=logs,
        )
        yield progress("Prepared validation split")

        for optimizer_name in optimizers:
            _append_log(logs, f"=== {optimizer_name} ===")
            yield progress(f"Running {optimizer_name}")
            events = _run_one_optimizer_events(
                optimizer_name=optimizer_name,
                model_name=resolved_model,
                train_chunks=train_chunks,
                eval_chunks=eval_chunks,
                device=device,
                seed=int(seed),
                max_steps=int(max_steps),
                eval_every=max(1, int(eval_every)),
                eval_batches=int(eval_batches),
                seq_len=int(seq_len),
                batch_size=int(batch_size),
                lr=float(lr),
                betas=(0.9, 0.999),
                full_validation_ppl=bool(full_validation_ppl),
                logs=logs,
                **fixed_hparams,
            )
            for event in events:
                if event.get("type") == "result":
                    rows.append(event["result"])
                yield progress(f"Running {optimizer_name}")

        gains = _gain_rows(rows)
        config = {
            "model_name": resolved_model,
            "device": device_name,
            "optimizers": optimizers,
            "seed": int(seed),
            "max_steps": int(max_steps),
            "eval_every": int(eval_every),
            "eval_batches": int(eval_batches),
            "seq_len": int(seq_len),
            "batch_size": int(batch_size),
            "max_chars": train_cap,
            "eval_chars": eval_cap,
            "full_wikitext_train": bool(full_wikitext_train),
            "full_wikitext_eval": bool(full_wikitext_eval),
            "full_validation_ppl": bool(full_validation_ppl),
            "lr": float(lr),
            "betas": [0.9, 0.999],
            **fixed_hparams,
        }
        payload = {
            "source": "HF Quick Comparison Runner",
            "based_on_colab": "LBW_Guard_Easy_Test_COLAB.ipynb",
            "config": config,
            "results": rows,
            "gains": gains,
            "logs": logs,
        }

        json_path = run_dir / "lbw_guard_hf_quick_comparison_results.json"
        csv_path = run_dir / "lbw_guard_hf_quick_comparison_results.csv"
        gains_path = run_dir / "lbw_guard_hf_quick_comparison_gains.csv"
        json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        _write_csv(csv_path, rows)
        _write_csv(gains_path, gains)
        _append_log(logs, f"Wrote {csv_path}")
        yield (
            _status_markdown(device_name=device_name, rows=rows, logs=logs, phase="Complete"),
            str(json_path),
            str(csv_path),
            str(gains_path),
        )
    except Exception:
        # Top-level boundary for the UI: persist the traceback with the logs
        # and surface it instead of letting the generator die.
        error_text = traceback.format_exc()
        error_path = run_dir / "error.txt"
        error_path.write_text(error_text + "\n\n" + "\n".join(logs), encoding="utf-8")
        yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None, None
def _make_ablation_scenario(slug: str, label: str, note: str, base_config: dict[str, Any], overrides=None):
    """Build one scenario record: identifiers plus a config copy with ``overrides`` applied.

    ``base_config`` is shallow-copied, so the caller's dict is never modified.
    """
    scenario_config = dict(base_config)
    if overrides:
        scenario_config.update(overrides)
    return dict(slug=slug, label=label, note=note, config=scenario_config)
def _build_ablation_scenarios(
    *,
    selected_ablations: list[str],
    base_config: dict[str, Any],
    lr_sweep: list[float],
    step_sweep: list[int],
    lora_r_sweep: list[int],
) -> list[dict[str, Any]]:
    """Expand the selected ablation families into a list of scenario records.

    Family names are matched case-insensitively after stripping whitespace;
    an empty selection defaults to ``{"optimizer"}``. Scenarios are emitted in
    a fixed family order: optimizer, lr, schedule, steps, data, lora.

    Raises:
        ValueError: If the selection produces no scenario at all.
    """
    cleaned = (str(item).strip() for item in selected_ablations)
    selected = {name.lower() for name in cleaned if name} or {"optimizer"}

    scenarios: list[dict[str, Any]] = []

    def add(slug: str, label: str, note: str, overrides=None) -> None:
        scenarios.append(_make_ablation_scenario(slug, label, note, base_config, overrides))

    if "optimizer" in selected:
        add(
            "optimizer-adamw-vs-lbw-guard",
            "Optimizer: AdamW vs lbw_guard",
            "Direct optimizer comparison with the base config.",
        )

    if "lr" in selected:
        for lr in lr_sweep:
            add(f"lr-{lr:g}", f"Learning Rate: {lr:g}", "Learning-rate sensitivity check.", {"lr": float(lr)})

    if "schedule" in selected:
        for mode in ("constant", "cosine"):
            add(f"schedule-{mode}", f"Schedule: {mode}", "Scheduler-shape sensitivity check.", {"schedule_mode": mode})

    if "steps" in selected:
        for steps in step_sweep:
            # Evaluate roughly four times per run, but at least every step.
            add(
                f"steps-{steps}",
                f"Steps: {steps}",
                "Training-length sensitivity check.",
                {"max_steps": int(steps), "eval_every": max(1, int(steps) // 4)},
            )

    if "data" in selected:
        for slice_label, train_limit, eval_limit in (
            ("small-data", 20_000, 8_000),
            ("larger-data", 80_000, 20_000),
        ):
            add(
                slice_label,
                f"Data Slice: {slice_label}",
                "WikiText slice-size sensitivity check.",
                {"max_chars": int(train_limit), "eval_chars": int(eval_limit)},
            )

    if "lora" in selected:
        for rank in lora_r_sweep:
            # Alpha is kept at twice the rank.
            add(
                f"lora-r{rank}",
                f"LoRA Rank: {rank}",
                "Adapter-capacity sensitivity check.",
                {"lora_r": int(rank), "lora_alpha": int(rank) * 2},
            )

    if not scenarios:
        raise ValueError("No scenarios selected. Choose optimizer, lr, schedule, steps, data, or lora.")
    return scenarios
def _ablation_status_markdown(
    *,
    device_name: str,
    rows: list[dict[str, Any]],
    logs: list[str],
    phase: str,
    plan: list[dict[str, Any]],
) -> str:
    """Render the ablation status panel as a Markdown string.

    Sections, in order: device and ``phase``, a plan table with one line per
    scenario in ``plan``, a metrics table with one line per entry in ``rows``
    (placeholder line when empty), an "LBW vs AdamW" list when gain rows
    exist, and the last 100 log lines in a fenced text block.
    """
    out = [
        f"Device: `{device_name}`",
        "",
        f"Status: {phase}",
        "",
        "## Plan",
        "",
        "| Scenario | Steps | LR | Schedule | Train Chars | Eval Chars | LoRA r |",
        "| --- | --- | --- | --- | --- | --- | --- |",
    ]

    for scenario in plan:
        cfg = scenario["config"]
        label = scenario["label"]
        steps = int(cfg["max_steps"])
        rate = float(cfg["lr"])
        schedule = cfg["schedule_mode"]
        # A full-split run has no meaningful character count to show.
        train_chars = "FULL" if cfg["full_wikitext_train"] else int(cfg["max_chars"])
        eval_chars = "FULL" if cfg["full_wikitext_eval"] else int(cfg["eval_chars"])
        rank = int(cfg["lora_r"])
        out.append(f"| {label} | {steps} | {rate:g} | {schedule} | {train_chars} | {eval_chars} | {rank} |")

    out += [
        "",
        "## Metrics",
        "",
        "| Scenario | Optimizer | Final Eval PPL | Final Eval Loss | Tokens/s | Scale | Ratio | Stress Mode |",
        "| --- | --- | --- | --- | --- | --- | --- | --- |",
    ]
    if not rows:
        out.append("| - | - | - | - | - | - | - | - |")
    for row in rows:
        cells = [
            row.get("scenario"),
            row.get("optimizer"),
            _fmt_float(row.get("final_eval_ppl")),
            _fmt_float(row.get("final_eval_loss")),
            _fmt_float(row.get("tokens_per_sec_wall"), digits=2),
            _fmt_float(row.get("scale")),
            _fmt_float(row.get("ratio")),
            row.get("stress_mode") or "-",
        ]
        out.append("| " + " | ".join(str(cell) for cell in cells) + " |")

    gains = _build_ablation_gain_rows(rows)
    if gains:
        out += ["", "## LBW vs AdamW", ""]
    for gain in gains:
        out.append(
            f"- `{gain.get('scenario')}`: `{gain.get('optimizer')}` "
            f"PPL gain `{_fmt_float(gain.get('ppl_gain_pct_vs_adamw'))}%`, "
            f"loss gain `{_fmt_float(gain.get('loss_gain_pct_vs_adamw'))}%`, "
            f"speed gain `{_fmt_float(gain.get('speed_gain_pct_vs_adamw'))}%`."
        )

    out += ["", "## Runtime Log", "", "```text", "\n".join(logs[-100:]), "```"]
    return "\n".join(out)
def _run_ablation_optimizer_events(
    *,
    scenario_item: dict[str, Any],
    optimizer_name: str,
    model_name: str,
    train_chunks: dict[str, Any],
    eval_chunks: dict[str, Any],
    device: torch.device,
    logs: list[str],
):
    """Train one LoRA model for an ablation scenario and stream events.

    Generator. All hyperparameters come from ``scenario_item["config"]``. The
    learning rate is set from ``_scheduled_lr`` before every step. After step
    1, the last step, and every ``eval_every``-th step a sampled evaluation
    runs and a ``{"type": "progress", ...}`` event is yielded. A single
    ``{"type": "result", "result": ...}`` event with the metrics dict is
    yielded at the end, after the model and optimizer have been released.
    """
    cfg = scenario_item["config"]
    _set_seed(int(cfg["seed"]))
    _append_log(logs, f"Loading {model_name} with LoRA for {scenario_item['slug']} / {optimizer_name}.")
    model = _load_lora_model(
        model_name=model_name,
        device=device,
        lora_r=int(cfg["lora_r"]),
        lora_alpha=int(cfg["lora_alpha"]),
        lora_dropout=float(cfg["lora_dropout"]),
    )
    model.train()
    opt = _make_optimizer(
        optimizer_name,
        model,
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        weight_decay=float(cfg["weight_decay"]),
        lbw_stats_freq=int(cfg["lbw_stats_freq"]),
        lbw_stress_th=float(cfg["lbw_stress_th"]),
        lbw_spike_th=float(cfg["lbw_spike_th"]),
        lbw_rec_fast=float(cfg["lbw_rec_fast"]),
        lbw_ema_decay=float(cfg["lbw_ema_decay"]),
    )
    train_batches = _batch_iter(train_chunks, batch_size=int(cfg["batch_size"]), device=device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    start_time = time.time()

    train_losses: list[float] = []
    eval_loss = None
    eval_ppl = None
    last_lr = float(cfg["lr"])
    state = _optimizer_state(opt)

    def evaluate(*, full_pass: bool) -> tuple[float, float]:
        return _evaluate_ppl(
            model,
            eval_chunks,
            batch_size=int(cfg["batch_size"]),
            eval_batches=int(cfg["eval_batches"]),
            device=device,
            full_pass=full_pass,
        )

    for step in range(1, int(cfg["max_steps"]) + 1):
        last_lr = _scheduled_lr(cfg, step)
        _set_lr(opt, last_lr)
        xb = next(train_batches)
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == "cuda")):
            loss = model(input_ids=xb, labels=xb).loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)
        loss_value = float(loss.detach().cpu())
        train_losses.append(loss_value)

        is_checkpoint = step == 1 or step == int(cfg["max_steps"]) or step % int(cfg["eval_every"]) == 0
        if not is_checkpoint:
            continue
        eval_loss, eval_ppl = evaluate(full_pass=False)
        state = _optimizer_state(opt)
        message = (
            f"[{scenario_item['slug']}] {optimizer_name} step {step}/{cfg['max_steps']}: "
            f"loss={loss_value:.4f}, sampled_eval_ppl={eval_ppl:.4f}, "
            f"lr={last_lr:.2e}, scale={state['scale']:.4f}, ratio={state['ratio']:.4f}"
        )
        _append_log(logs, message)
        yield {"type": "progress", "message": message}
        # Evaluation leaves the model in eval mode; resume training mode.
        model.train()

    # Label what the final number was measured on.
    final_full_pass = bool(cfg["full_validation_ppl"])
    if not final_full_pass:
        final_scope = "sampled"
    elif eval_chunks["cap"] is None:
        final_scope = "full_wikitext"
    else:
        final_scope = "full_loaded_subset"
    _append_log(logs, f"Running final {final_scope} validation PPL for {scenario_item['slug']} / {optimizer_name}.")
    final_loss, final_ppl = evaluate(full_pass=final_full_pass)
    state = _optimizer_state(opt)

    # Floor the elapsed time so the throughput division below cannot hit zero.
    wall_time = max(time.time() - start_time, 1e-9)
    trained_tokens = int(cfg["max_steps"]) * int(cfg["batch_size"]) * int(cfg["seq_len"])

    result = {
        "scenario_slug": scenario_item["slug"],
        "scenario": scenario_item["label"],
        "optimizer": optimizer_name,
        "final_eval_ppl": final_ppl,
        "final_eval_loss": final_loss,
        "train_loss_last": train_losses[-1] if train_losses else None,
        "last_sampled_eval_loss": eval_loss,
        "last_sampled_eval_ppl": eval_ppl,
        "final_eval_scope": final_scope,
        "max_steps": int(cfg["max_steps"]),
        "lr": float(cfg["lr"]),
        "scheduled_lr_last": float(last_lr),
        "schedule_mode": str(cfg["schedule_mode"]),
        "batch_size": int(cfg["batch_size"]),
        "seq_len": int(cfg["seq_len"]),
        "lora_r": int(cfg["lora_r"]),
        "train_chars": int(train_chunks["chars"]),
        "eval_chars": int(eval_chunks["chars"]),
        "train_sequences": int(train_chunks["input_ids"].size(0)),
        "eval_sequences": int(eval_chunks["input_ids"].size(0)),
        "scale": state["scale"],
        "ratio": state["ratio"],
        "stress_mode": state["stress_mode"],
        "wall_time_sec": wall_time,
        "tokens_per_sec_wall": trained_tokens / wall_time,
    }

    # Drop the model and optimizer before handing the result back so the next
    # run starts from freed memory.
    del model, opt
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    yield {"type": "result", "result": result}
def _build_ablation_gain_rows(metrics: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Compare each non-AdamW metric row with the AdamW row of the same scenario.

    Rows are grouped by ``scenario_slug``; scenarios without an AdamW row are
    skipped. Gains are percentages relative to AdamW: positive PPL/loss gain
    means the candidate is lower, positive speed gain means higher tokens/s.
    A gain is ``None`` when an input is missing, non-finite, or the AdamW
    value is zero.
    """
    by_scenario: dict[str, list[dict[str, Any]]] = {}
    for metric_row in metrics:
        by_scenario.setdefault(str(metric_row.get("scenario_slug")), []).append(metric_row)

    def percent_gain(reference: float | None, candidate: float | None, *, lower_is_better: bool) -> float | None:
        if reference is None or reference == 0.0 or candidate is None:
            return None
        delta = reference - candidate if lower_is_better else candidate - reference
        return delta / reference * 100.0

    gain_rows: list[dict[str, Any]] = []
    for scenario_slug, scenario_rows in by_scenario.items():
        baseline = None
        for candidate_row in scenario_rows:
            if candidate_row.get("optimizer") == "adamw":
                baseline = candidate_row
                break
        if baseline is None:
            continue

        base_ppl = _safe_float(baseline.get("final_eval_ppl"))
        base_loss = _safe_float(baseline.get("final_eval_loss"))
        base_tps = _safe_float(baseline.get("tokens_per_sec_wall"))

        for row in scenario_rows:
            if row.get("optimizer") == "adamw":
                continue
            cand_ppl = _safe_float(row.get("final_eval_ppl"))
            cand_loss = _safe_float(row.get("final_eval_loss"))
            cand_tps = _safe_float(row.get("tokens_per_sec_wall"))
            gain_rows.append(
                {
                    "scenario_slug": scenario_slug,
                    "scenario": row.get("scenario"),
                    "optimizer": row.get("optimizer"),
                    "adamw_final_eval_ppl": base_ppl,
                    "optimizer_final_eval_ppl": cand_ppl,
                    "ppl_gain_pct_vs_adamw": percent_gain(base_ppl, cand_ppl, lower_is_better=True),
                    "loss_gain_pct_vs_adamw": percent_gain(base_loss, cand_loss, lower_is_better=True),
                    "speed_gain_pct_vs_adamw": percent_gain(base_tps, cand_tps, lower_is_better=False),
                    "adamw_tokens_per_sec_wall": base_tps,
                    "optimizer_tokens_per_sec_wall": cand_tps,
                    "lbw_scale": row.get("scale"),
                    "lbw_ratio": row.get("ratio"),
                    "lbw_stress_mode": row.get("stress_mode"),
                }
            )
    return gain_rows
def run_ablation_test(
    model_name: str,
    selected_ablations: list[str],
    run_lbw_guard: bool,
    max_steps: int,
    eval_every: int,
    eval_batches: int,
    seq_len: int,
    batch_size: int,
    train_chars: int,
    eval_chars: int,
    full_wikitext_train: bool,
    full_wikitext_eval: bool,
    full_validation_ppl: bool,
    lr: float,
    schedule_mode: str,
    warmup_steps: int,
    seed: int,
    lr_sweep_text: str,
    step_sweep_text: str,
    lora_r_sweep_text: str,
):
    """Run the ablation matrix and stream UI updates as a generator.

    Every yielded value is a 4-tuple ``(status_markdown, json_path,
    metrics_csv_path, gains_csv_path)``. Progress updates carry ``None`` for the
    three paths; the final update carries the written artifact paths. If
    anything inside the run raises, the traceback and logs are written to
    ``error.txt`` in the run directory and that path is yielded in the JSON
    slot.

    The run is refused up front (single message, then return) when the device
    is CPU and the request exceeds the tiny smoke-test budget, or when LBW Guard
    is requested while more than one CUDA device is visible.
    """
    logs: list[str] = []
    rows: list[dict[str, Any]] = []
    run_dir = RUNS_DIR / f"ablation_test_{int(time.time())}"
    run_dir.mkdir(parents=True, exist_ok=True)
    device_name = _device_default()
    device = torch.device(device_name)
    compare_lbw = bool(run_lbw_guard)
    optimizers = ["adamw", "lbw_guard"] if compare_lbw else ["adamw"]

    try:
        # Key order is preserved in the JSON artifact, so keep it stable.
        base_config = {
            "seed": int(seed),
            "max_steps": int(max_steps),
            "eval_every": max(1, int(eval_every)),
            "eval_batches": int(eval_batches),
            "seq_len": int(seq_len),
            "batch_size": int(batch_size),
            "max_chars": int(train_chars),
            "eval_chars": int(eval_chars),
            "full_wikitext_train": bool(full_wikitext_train),
            "full_wikitext_eval": bool(full_wikitext_eval),
            "full_validation_ppl": bool(full_validation_ppl),
            "lr": float(lr),
            "betas": (0.9, 0.999),
            "weight_decay": 0.01,
            "warmup_steps": int(warmup_steps),
            "schedule_mode": str(schedule_mode or "constant").strip().lower(),
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lbw_stats_freq": 10,
            "lbw_stress_th": 1.1,
            "lbw_spike_th": 1.5,
            "lbw_rec_fast": 0.01,
            "lbw_ema_decay": 0.95,
        }
        lr_values = _parse_float_sweep(lr_sweep_text, [1e-3, 5e-4])
        step_values = _parse_int_sweep(step_sweep_text, [100, 200])
        rank_values = _parse_int_sweep(lora_r_sweep_text, [4, 8, 16])
        # An empty selection falls back to the plain optimizer comparison.
        ablation_names = selected_ablations or ["optimizer"]
        scenarios = _build_ablation_scenarios(
            selected_ablations=list(ablation_names),
            base_config=base_config,
            lr_sweep=lr_values,
            step_sweep=step_values,
            lora_r_sweep=rank_values,
        )

        def _progress(phase: str):
            # Progress updates never carry artifact paths.
            markdown = _ablation_status_markdown(
                device_name=device_name,
                rows=rows,
                logs=logs,
                phase=phase,
                plan=scenarios,
            )
            return markdown, None, None, None

        # --- Hardware guards -------------------------------------------------
        on_cpu = device.type == "cpu"
        if on_cpu and any(
            (
                len(scenarios) > 1,
                int(max_steps) > 1,
                int(train_chars) > 20_000,
                int(eval_chars) > 8_000,
                bool(full_wikitext_train),
                bool(full_wikitext_eval),
                bool(full_validation_ppl),
            )
        ):
            cpu_notice = (
                "This Space is currently on `cpu-basic`. CPU ablation mode is capped to one optimizer scenario, "
                "1 step, 20k train chars, 8k eval chars, and sampled validation. Switch the Space hardware to GPU for ablations."
            )
            yield cpu_notice, None, None, None
            return

        on_cuda = device.type == "cuda"
        if on_cuda and compare_lbw and torch.cuda.device_count() > 1:
            multi_gpu_notice = (
                "LBW Guard should run with one visible GPU. Set the Space to single-GPU hardware or restrict CUDA_VISIBLE_DEVICES."
            )
            yield multi_gpu_notice, None, None, None
            return

        # --- Setup -----------------------------------------------------------
        resolved_model = str(model_name).strip() or "Qwen/Qwen2.5-0.5B"
        _append_log(logs, f"Device: {device_name}")
        if on_cuda:
            _append_log(logs, f"GPU: {torch.cuda.get_device_name(0)}")
        _append_log(logs, f"Selected ablations: {', '.join(ablation_names)}")
        _append_log(logs, f"Optimizers: {', '.join(optimizers)}")
        yield _progress("Loading tokenizer")

        tokenizer = AutoTokenizer.from_pretrained(resolved_model, use_fast=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Tokenized splits are shared by scenarios with the same sequence length and caps.
        data_cache: dict[tuple[int, int | None, int | None], dict[str, dict[str, Any]]] = {}

        # --- Scenario x optimizer sweep ---------------------------------------
        for scenario_item in scenarios:
            cfg = scenario_item["config"]
            train_cap = None if cfg["full_wikitext_train"] else int(cfg["max_chars"])
            eval_cap = None if cfg["full_wikitext_eval"] else int(cfg["eval_chars"])
            scenario_seq_len = int(cfg["seq_len"])
            cache_key = (scenario_seq_len, train_cap, eval_cap)

            chunks = data_cache.get(cache_key)
            if chunks is None:
                # Train is built before eval, matching the order of the log lines.
                chunks = {
                    name: _build_wikitext_chunks(
                        tokenizer,
                        split=split_name,
                        max_chars=cap,
                        seq_len=scenario_seq_len,
                        logs=logs,
                    )
                    for name, split_name, cap in (
                        ("train", "train", train_cap),
                        ("eval", "validation", eval_cap),
                    )
                }
                data_cache[cache_key] = chunks

            label = scenario_item["label"]
            _append_log(logs, f"=== Scenario: {label} ===")
            for optimizer_name in optimizers:
                _append_log(logs, f"--- {optimizer_name} ---")
                running_phase = f"Running {label} / {optimizer_name}"
                yield _progress(running_phase)

                events = _run_ablation_optimizer_events(
                    scenario_item=scenario_item,
                    optimizer_name=optimizer_name,
                    model_name=resolved_model,
                    train_chunks=chunks["train"],
                    eval_chunks=chunks["eval"],
                    device=device,
                    logs=logs,
                )
                # NOTE(review): original indentation was lost; the refresh is assumed
                # to fire after every event, not only after "result" events — confirm.
                for event in events:
                    if event.get("type") == "result":
                        rows.append(event["result"])
                    yield _progress(running_phase)

        # --- Artifacts -------------------------------------------------------
        gains = _build_ablation_gain_rows(rows)
        payload = {
            "source": "HF Ablation Matrix Runner",
            "based_on_colab": "LBW_Guard_Ablation_Test_COLAB.ipynb",
            "model_name": resolved_model,
            "device": device_name,
            "optimizers": optimizers,
            "selected_ablations": list(ablation_names),
            "base_config": base_config,
            "scenarios": scenarios,
            "results": rows,
            "gains": gains,
            "logs": logs,
        }
        json_path = run_dir / "lbw_guard_hf_ablation_matrix_results.json"
        metrics_path = run_dir / "lbw_guard_hf_ablation_matrix_metrics.csv"
        gains_path = run_dir / "lbw_guard_hf_ablation_matrix_gains.csv"
        # The JSON snapshot is serialized before the "Wrote ..." log lines are appended.
        json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        _write_csv(metrics_path, rows)
        _write_csv(gains_path, gains)
        _append_log(logs, f"Wrote {metrics_path}")
        _append_log(logs, f"Wrote {gains_path}")

        final_markdown = _ablation_status_markdown(
            device_name=device_name,
            rows=rows,
            logs=logs,
            phase="Complete",
            plan=scenarios,
        )
        yield final_markdown, str(json_path), str(metrics_path), str(gains_path)
    except Exception:
        # UI boundary: persist the failure and show it instead of crashing the Space.
        error_text = traceback.format_exc()
        error_path = run_dir / "error.txt"
        joined_logs = "\n".join(logs)
        error_path.write_text(error_text + "\n\n" + joined_logs, encoding="utf-8")
        yield f"Run failed.\n\n```text\n{error_text}\n```", str(error_path), None, None
# Markdown banner passed to gr.Markdown at the top of the UI. Blank lines keep
# the heading, the intro paragraph, the bullet list, and the closing sentence
# as separate Markdown blocks; without them the closing sentence is parsed as
# a continuation of the last bullet.
INTRO = """
# LBW Guard HF Evaluation Runner

This private Space has two HF-specific runners based on the customer Colab notebooks:

- **Quick Comparison**: a short AdamW vs `lbw_guard` WikiText LoRA run.
- **Ablation Matrix**: a selectable scenario sweep across optimizer, LR, schedule, steps, data, and LoRA rank.

Each run writes JSON and CSV artifacts. GPU hardware is recommended; CPU mode is only for tiny smoke checks.
"""
# Gradio UI: one tab per runner. Each tab declares its controls, four output
# widgets (status markdown + three downloadable files), and wires the button
# to the corresponding generator callback.
# NOTE(review): the original indentation was lost; nesting below is reconstructed.
with gr.Blocks(title="LBW Guard HF Evaluation Runner") as demo:
    gr.Markdown(INTRO)
    with gr.Tabs():
        # ------------------------------------------------------------------
        # Quick Comparison tab
        # ------------------------------------------------------------------
        with gr.Tab("Quick Comparison"):
            gr.Markdown(
                "Runs the HF version of the Easy Test flow: one shared WikiText slice, then AdamW and optional "
                "`lbw_guard` LoRA training with final perplexity and gain CSV output."
            )
            with gr.Row():
                easy_model_name = gr.Textbox(value="TinyLlama/TinyLlama_v1.1", label="Model")
                easy_run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison")
            with gr.Row():
                easy_max_steps = gr.Slider(1, 1000, value=5, step=1, label="Optimizer steps")
                easy_eval_every = gr.Slider(1, 200, value=5, step=1, label="Eval every")
                easy_eval_batches = gr.Slider(1, 128, value=8, step=1, label="Eval batches")
            with gr.Row():
                easy_seq_len = gr.Dropdown([64, 128, 256, 512], value=64, label="Sequence length")
                easy_batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch size")
                easy_lr = gr.Number(value=5e-4, label="Learning rate")
            with gr.Row():
                easy_train_chars = gr.Slider(5_000, 2_000_000, value=20_000, step=5_000, label="Train char cap")
                easy_eval_chars = gr.Slider(1_000, 500_000, value=8_000, step=1_000, label="Eval char cap")
                easy_seed = gr.Number(value=42, precision=0, label="Seed")
            with gr.Row():
                easy_full_wikitext_train = gr.Checkbox(value=False, label="Full WikiText train")
                easy_full_wikitext_eval = gr.Checkbox(value=False, label="Full WikiText eval")
                easy_full_validation_ppl = gr.Checkbox(value=False, label="Full validation PPL")
            easy_run_button = gr.Button("Run Quick Comparison", variant="primary")
            easy_summary = gr.Markdown()
            easy_json_file = gr.File(label="Raw JSON")
            easy_results_file = gr.File(label="Results CSV")
            easy_gains_file = gr.File(label="Gains CSV")

            # Inputs are passed positionally to the callback.
            # NOTE(review): presumably this matches run_easy_test's parameter order — confirm.
            easy_inputs = [
                easy_model_name,
                easy_run_lbw_guard,
                easy_max_steps,
                easy_eval_every,
                easy_eval_batches,
                easy_seq_len,
                easy_batch_size,
                easy_train_chars,
                easy_eval_chars,
                easy_full_wikitext_train,
                easy_full_wikitext_eval,
                easy_full_validation_ppl,
                easy_lr,
                easy_seed,
            ]
            easy_outputs = [easy_summary, easy_json_file, easy_results_file, easy_gains_file]
            easy_run_button.click(fn=run_easy_test, inputs=easy_inputs, outputs=easy_outputs)

        # ------------------------------------------------------------------
        # Ablation Matrix tab
        # ------------------------------------------------------------------
        with gr.Tab("Ablation Matrix"):
            gr.Markdown(
                "Runs the HF version of the ablation flow: build selected scenarios, run AdamW and optional "
                "`lbw_guard` for each scenario, then export metrics and LBW-vs-AdamW gains."
            )
            with gr.Row():
                ablation_model_name = gr.Textbox(value="Qwen/Qwen2.5-0.5B", label="Model")
                ablation_run_lbw_guard = gr.Checkbox(value=True, label="Run LBW Guard comparison")
            # NOTE(review): assumed to sit below the first row rather than inside it — confirm.
            selected_ablations = gr.CheckboxGroup(
                choices=["optimizer", "lr", "schedule", "steps", "data", "lora"],
                value=["optimizer"],
                label="Ablations",
            )
            with gr.Row():
                ablation_max_steps = gr.Slider(1, 1000, value=200, step=1, label="Base optimizer steps")
                ablation_eval_every = gr.Slider(1, 200, value=50, step=1, label="Eval every")
                ablation_eval_batches = gr.Slider(1, 128, value=8, step=1, label="Eval batches")
            with gr.Row():
                ablation_seq_len = gr.Dropdown([64, 128, 256, 512], value=64, label="Sequence length")
                ablation_batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch size")
                ablation_lr = gr.Number(value=5e-4, label="Base learning rate")
            with gr.Row():
                ablation_train_chars = gr.Slider(5_000, 2_000_000, value=20_000, step=5_000, label="Train char cap")
                ablation_eval_chars = gr.Slider(1_000, 500_000, value=8_000, step=1_000, label="Eval char cap")
                ablation_seed = gr.Number(value=42, precision=0, label="Seed")
            with gr.Row():
                ablation_schedule_mode = gr.Dropdown(["constant", "cosine"], value="constant", label="Base schedule")
                ablation_warmup_steps = gr.Slider(0, 100, value=10, step=1, label="Warmup steps")
            with gr.Row():
                ablation_full_wikitext_train = gr.Checkbox(value=False, label="Full WikiText train")
                ablation_full_wikitext_eval = gr.Checkbox(value=False, label="Full WikiText eval")
                ablation_full_validation_ppl = gr.Checkbox(value=False, label="Full validation PPL")
            with gr.Row():
                lr_sweep_text = gr.Textbox(value="1e-3, 5e-4", label="LR sweep")
                step_sweep_text = gr.Textbox(value="100, 200", label="Step sweep")
                lora_r_sweep_text = gr.Textbox(value="4, 8, 16", label="LoRA r sweep")
            ablation_run_button = gr.Button("Run Ablation Matrix", variant="primary")
            ablation_summary = gr.Markdown()
            ablation_json_file = gr.File(label="Raw JSON")
            ablation_metrics_file = gr.File(label="Metrics CSV")
            ablation_gains_file = gr.File(label="Gains CSV")

            # Listed in the same order as run_ablation_test's parameters.
            ablation_inputs = [
                ablation_model_name,
                selected_ablations,
                ablation_run_lbw_guard,
                ablation_max_steps,
                ablation_eval_every,
                ablation_eval_batches,
                ablation_seq_len,
                ablation_batch_size,
                ablation_train_chars,
                ablation_eval_chars,
                ablation_full_wikitext_train,
                ablation_full_wikitext_eval,
                ablation_full_validation_ppl,
                ablation_lr,
                ablation_schedule_mode,
                ablation_warmup_steps,
                ablation_seed,
                lr_sweep_text,
                step_sweep_text,
                lora_r_sweep_text,
            ]
            ablation_outputs = [
                ablation_summary,
                ablation_json_file,
                ablation_metrics_file,
                ablation_gains_file,
            ]
            ablation_run_button.click(
                fn=run_ablation_test,
                inputs=ablation_inputs,
                outputs=ablation_outputs,
            )
if __name__ == "__main__":
    # Enable the request queue with a default concurrency limit of 1, then serve.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch()