import torch
import torch.nn as nn
import torch.nn.functional as F
import urllib.request
import time
import sys
import os

# Make the repository root (two levels up) importable so `arbitor` resolves
# when this script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.main import (
    VOCAB,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    FFN_HIDDEN,
    CTX,
    THRESHOLD,
    SPECIAL_VOCAB,
    ARBModel,
    StickyZoneSTE,
    save_model,
    load_model,
)

TRAIN_PARAMS = {
    "batch_size": 32,
    "ctx": 66,
    "lr": 3e-4,
    "weight_decay": 0.01,
    "max_steps": 5000,
    "eval_interval": 500,
    "eval_steps": 100,
}


def download_data():
    """Return ``(train, val)`` byte-id tensors for the Tiny Shakespeare corpus.

    The corpus is downloaded once and cached next to this script as
    ``tinyshakespeare.txt``. The text is UTF-8 encoded and every byte becomes
    one ``torch.long`` token. The first 90% of bytes form the training split
    and the remaining 10% the validation split.
    """
    path = os.path.join(os.path.dirname(__file__), "tinyshakespeare.txt")
    if not os.path.exists(path):
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        # Download under a temporary name and rename only on success. An
        # interrupted download then never leaves a truncated file that later
        # runs would silently treat as the cached corpus.
        tmp_path = path + ".part"
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    n = int(0.9 * len(byte_data))
    return byte_data[:n], byte_data[n:]


def get_batch(data, batch_size, ctx, device):
    """Sample random byte windows and their next-byte targets.

    Args:
        data: 1-D tensor of byte ids.
        batch_size: number of windows to sample.
        ctx: window length in bytes.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, targets)``. ``x`` has shape ``(batch_size, ctx)``.
        ``targets = x[:, 3:]`` has shape ``(batch_size, ctx - 3)``: for every
        trigram starting at position ``i`` it holds byte ``i + 3``. The final
        trigram is excluded because its successor lies outside the window.
    """
    starts = torch.randint(0, len(data) - ctx - 1, (batch_size,))
    x = torch.stack([data[i : i + ctx] for i in starts])
    targets = x[:, 3:]
    return x.to(device), targets.to(device)


class FP32Linear(nn.Module):
    """Plain full-precision affine layer.

    Weights are initialised from N(0, 0.1^2) and the bias from zeros.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)


class FP32RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature scale."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # eps is added inside the sqrt so an all-zero row does not divide by zero.
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.scale * (x / rms)


class FP32TrigramModel(nn.Module):
    """Full-precision baseline that predicts the next byte from each byte trigram.

    Pipeline: byte embedding -> RMSNorm -> concatenate each sliding window of
    3 consecutive embeddings -> linear projection -> RMSNorm -> RMSNorm/ReLU
    FFN -> RMSNorm -> vocabulary head. The FFN output is not added back to
    its input; there is no residual connection.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB, EMBEDDING_DIM)
        self.embed_norm = FP32RMSNorm(EMBEDDING_DIM)
        self.projection = FP32Linear(EMBEDDING_DIM * 3, HIDDEN_DIM)
        self.proj_norm = FP32RMSNorm(HIDDEN_DIM)
        self.ffn_norm1 = FP32RMSNorm(HIDDEN_DIM)
        self.fc1 = FP32Linear(HIDDEN_DIM, FFN_HIDDEN)
        self.ffn_norm2 = FP32RMSNorm(FFN_HIDDEN)
        self.fc2 = FP32Linear(FFN_HIDDEN, HIDDEN_DIM)
        self.head_norm = FP32RMSNorm(HIDDEN_DIM)
        self.head = FP32Linear(HIDDEN_DIM, VOCAB)

    def forward(self, x, targets=None):
        """Compute per-trigram logits and, optionally, the next-byte loss.

        Args:
            x: ``(batch, T)`` tensor of byte ids.
            targets: optional ``(batch, T - 3)`` tensor, as produced by
                ``get_batch``. Positions equal to ``SPECIAL_VOCAB["PAD"]`` are
                ignored in the loss.

        Returns:
            ``(logits, loss)``. ``logits`` has shape ``(batch, T - 2, VOCAB)``
            with one row per trigram. ``loss`` is ``None`` when no targets are
            given.
        """
        embedded = self.embed_norm(self.embedding(x))
        # Sliding windows of 3 consecutive positions: (batch, T-2, E, 3).
        trigrams = embedded.unfold(dimension=1, size=3, step=1)
        # Merge each window into one (E*3) vector, with the same element order
        # as einops "b t d w -> b t (d w)" but without a per-call import.
        trigrams = trigrams.flatten(start_dim=2)
        relational = self.proj_norm(self.projection(trigrams))
        h = self.ffn_norm1(relational)
        h = torch.relu(self.fc1(h))
        h = self.ffn_norm2(h)
        h = self.fc2(h)
        logits = self.head(self.head_norm(h))
        loss = None
        if targets is not None:
            # The last trigram's successor lies outside the window, so drop it
            # to align with targets = x[:, 3:].
            next_byte_logits = logits[:, :-1, :].contiguous()
            loss = F.cross_entropy(
                next_byte_logits.view(-1, VOCAB),
                targets.contiguous().view(-1),
                ignore_index=SPECIAL_VOCAB["PAD"],
            )
        return logits, loss


def evaluate(model, val_data, device, eval_steps=100):
    """Return the mean loss over ``eval_steps`` random validation batches.

    Gradients are disabled during evaluation. The model's previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    losses = []
    with torch.no_grad():
        for _ in range(eval_steps):
            x, targets = get_batch(val_data, TRAIN_PARAMS["batch_size"], TRAIN_PARAMS["ctx"], device)
            _, loss = model(x, targets=targets)
            losses.append(loss.item())
    model.train(was_training)
    return sum(losses) / len(losses)


def count_params(model):
    """Split a model's parameter count into ``(total, ternary, fp32)``.

    A parameter counts as "ternary" when its name contains ``"weight"``, it
    has at least 2 dimensions, and its name does not contain ``"embed"``.
    Everything else (biases, norm scales, embeddings) counts as fp32. This is
    a naming heuristic; it does not inspect how the layer actually quantises.
    """
    total = sum(p.numel() for p in model.parameters())
    ternary = 0
    fp32 = 0
    for n, p in model.named_parameters():
        if "weight" in n and p.ndim >= 2 and "embed" not in n:
            ternary += p.numel()
        else:
            fp32 += p.numel()
    return total, ternary, fp32


def log_diagnostics(model, step, train_loss, val_loss, config_name, is_ternary=False):
    """Print losses and, for ternary runs, per-weight sign distribution and grad norms.

    For every parameter selected by the same heuristic as ``count_params``,
    ``StickyZoneSTE.apply(param, THRESHOLD)`` is evaluated and the fractions
    of positive, negative and zero entries are printed. (Presumably this is
    the ternary quantiser used by ``ARBModel`` — confirm in arbitor.main.)
    A warning is printed when more than 95% of entries are zero.
    """
    print(f"[{config_name}] step {step} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")
    if not is_ternary:
        return
    for name, param in model.named_parameters():
        if "weight" in name and param.ndim >= 2 and "embed" not in name:
            with torch.no_grad():
                T = StickyZoneSTE.apply(param, THRESHOLD)
                frac_zero = (T == 0).float().mean().item()
                frac_pos = (T > 0).float().mean().item()
                frac_neg = (T < 0).float().mean().item()
            grad_norm = param.grad.norm().item() if param.grad is not None else 0.0
            print(f"  {name}: +{frac_pos:.2%} -{frac_neg:.2%} 0{frac_zero:.2%} | grad={grad_norm:.6f}")
            if frac_zero > 0.95:
                print(f"  ⚠ COLLAPSE: {name} is >95% zeros")


def train_model(model, train_data, val_data, config_name, device, steps=5000):
    """Train ``model`` with AdamW for ``steps`` steps and return a results dict.

    Every ``TRAIN_PARAMS["eval_interval"]`` steps the current training loss
    and a validation loss are recorded and logged. Gradients are clipped to a
    global norm of 1.0. Per-step wall-clock time covers batch sampling,
    forward, backward and the optimizer step.

    Returns:
        dict with keys ``config``, ``final_train_loss``, ``final_val_loss``,
        ``train_losses``, ``val_losses`` (lists aligned with ``steps``),
        ``steps``, ``avg_step_ms``, ``steps_per_sec``, ``total_s`` and
        ``param_count``.

    Raises:
        ValueError: if ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=TRAIN_PARAMS["lr"], weight_decay=TRAIN_PARAMS["weight_decay"]
    )
    # The config name decides whether per-weight ternary diagnostics are printed.
    is_ternary = "Ternary" in config_name
    train_losses = []
    val_losses = []
    step_list = []
    step_times = []
    for step in range(steps):
        t0 = time.perf_counter()
        x, targets = get_batch(train_data, TRAIN_PARAMS["batch_size"], TRAIN_PARAMS["ctx"], device)
        _, loss = model(x, targets=targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if device == "cuda":
            # CUDA work is asynchronous; wait so the timing covers the whole step.
            torch.cuda.synchronize()
        step_times.append(time.perf_counter() - t0)

        if (step + 1) % TRAIN_PARAMS["eval_interval"] == 0:
            train_loss = loss.item()
            val_loss = evaluate(model, val_data, device, TRAIN_PARAMS["eval_steps"])
            log_diagnostics(model, step + 1, train_loss, val_loss, config_name, is_ternary)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            step_list.append(step + 1)

    final_val = evaluate(model, val_data, device, TRAIN_PARAMS["eval_steps"])
    total_s = sum(step_times)
    avg_step_ms = total_s / len(step_times) * 1000
    steps_per_sec = steps / total_s
    print(f"[{config_name}] speed: {avg_step_ms:.2f} ms/step | {steps_per_sec:.1f} steps/s | total {total_s:.1f}s")

    return {
        "config": config_name,
        "final_train_loss": train_losses[-1] if train_losses else loss.item(),
        "final_val_loss": final_val,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "steps": step_list,
        "avg_step_ms": avg_step_ms,
        "steps_per_sec": steps_per_sec,
        "total_s": total_s,
        "param_count": sum(p.numel() for p in model.parameters()),
    }


def analyze_results(results_fp32, results_ternary):
    """Print a side-by-side comparison of the FP32 and ternary runs.

    The report covers a summary table, a pass/fail check that the ternary
    validation loss is within 1.25x of FP32, a speed comparison, and an
    effective bits-per-weight estimate. The estimate uses ``count_params`` on
    a freshly constructed ``ARBModel``, assuming 1.58 bits per "ternary"
    parameter and 32 bits for the rest.
    """
    print("\n" + "=" * 80)
    print("MORPH TERNARY vs FP32 — REAL MODEL BENCHMARK")
    print("=" * 80)
    print(f"{'Config':<30} {'Val Loss':>9} {'Params':>10} {'ms/step':>9} {'stps/s':>8}")
    print("-" * 70)
    for r in [results_fp32, results_ternary]:
        print(
            f"{r['config']:<30} {r['final_val_loss']:>9.4f} "
            f"{r['param_count']:>10} {r['avg_step_ms']:>9.2f} {r['steps_per_sec']:>8.1f}"
        )

    fp32_loss = results_fp32["final_val_loss"]
    ternary_loss = results_ternary["final_val_loss"]
    fp32_speed = results_fp32["avg_step_ms"]
    ternary_speed = results_ternary["avg_step_ms"]
    ratio = ternary_loss / fp32_loss
    speed_ratio = ternary_speed / fp32_speed

    print(f"\n--- Precision ---")
    print(f"Ternary / FP32 loss ratio = {ratio:.3f}x")
    if ratio <= 1.25:
        print("✅ Ternary within 1.25x FP32 — PASS")
    else:
        print("❌ Ternary exceeds 1.25x FP32 — FAIL")

    print(f"\n--- Speed ---")
    print(f"Ternary vs FP32: {speed_ratio:.2f}x ({ternary_speed:.2f}ms vs {fp32_speed:.2f}ms)")
    if speed_ratio <= 1.0:
        print("✅ Ternary is faster or equal to FP32")
    elif speed_ratio <= 1.5:
        print("⚡ Ternary is slower but within 1.5x — expected for unoptimized path")
    else:
        print("⚠ Ternary significantly slower — investigate")

    total, ternary_p, fp32_p = count_params(ARBModel())
    eff_bpw = (fp32_p * 32 + ternary_p * 1.58) / total
    print(f"\n--- Effective BPW ---")
    print(f"Ternary params: {ternary_p:,} (1.58 BPW)")
    print(f"FP32 params: {fp32_p:,} (32 BPW)")
    print(f"Effective BPW: {eff_bpw:.2f} bits/weight (avg across all params)")
    print(f"Memory savings: {32.0/eff_bpw:.1f}x vs pure FP32")


def run_all():
    """Run the full benchmark end to end.

    Steps: download the data, train the FP32 baseline, train the ternary
    ``ARBModel``, save it, reload it and run a smoke-test forward pass, then
    print the comparison report.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print("Downloading data...")
    train_data, val_data = download_data()
    print(f"Train: {len(train_data):,} bytes, Val: {len(val_data):,} bytes")

    print("\n--- FP32 Baseline ---")
    model_fp32 = FP32TrigramModel().to(device)
    total_fp = sum(p.numel() for p in model_fp32.parameters())
    print(f"Params: {total_fp:,}")
    results_fp32 = train_model(
        model_fp32, train_data, val_data, "FP32-Trigram", device, TRAIN_PARAMS["max_steps"]
    )
    print(f"FP32 final val loss: {results_fp32['final_val_loss']:.4f}")
    del model_fp32
    if device == "cuda":
        torch.cuda.empty_cache()

    print("\n--- Ternary (Config E: Factorized Scaled) ---")
    model_ternary = ARBModel().to(device)
    total_t = sum(p.numel() for p in model_ternary.parameters())
    print(f"Params: {total_t:,}")
    results_ternary = train_model(
        model_ternary, train_data, val_data, "Ternary-ConfigE", device, TRAIN_PARAMS["max_steps"]
    )
    print(f"Ternary final val loss: {results_ternary['final_val_loss']:.4f}")

    save_dir = os.path.join(os.path.dirname(__file__), "..", "..", "models", "conversions")
    # Create the output directory so saving does not depend on it already existing.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "arb-model.pt")
    save_model(model_ternary, save_path)

    # Smoke-test the saved checkpoint with a random batch of the training context length.
    test_model = load_model(save_path, device)
    test_x = torch.randint(0, VOCAB, (2, TRAIN_PARAMS["ctx"]), device=device)
    with torch.no_grad():
        logits, _ = test_model(test_x)
    print(f"\nLoaded model test: logits shape = {logits.shape} ✅")
    del model_ternary
    if device == "cuda":
        torch.cuda.empty_cache()

    analyze_results(results_fp32, results_ternary)


if __name__ == "__main__":
    run_all()