| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import urllib.request |
| import time |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.main import ( |
| VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD, |
| SPECIAL_VOCAB, ARBModel, StickyZoneSTE, |
| save_model, load_model, |
| ) |
|
|
# Hyper-parameters shared by the FP32 baseline run and the ternary model run.
TRAIN_PARAMS = dict(
    batch_size=32,      # windows sampled per get_batch call
    ctx=66,             # length of each sampled byte window
    lr=3e-4,            # AdamW learning rate
    weight_decay=0.01,  # AdamW weight decay
    max_steps=5000,     # optimizer steps per training run
    eval_interval=500,  # steps between validation passes
    eval_steps=100,     # validation batches averaged per pass
)
|
|
|
|
def download_data(train_frac=0.9):
    """Load Tiny Shakespeare as UTF-8 bytes and split it into train/val tensors.

    The text is cached next to this script as ``tinyshakespeare.txt`` and is
    downloaded first when missing. The download is written to a ``.part`` file
    and renamed into place only after ``urlretrieve`` succeeds. An interrupted
    download therefore cannot leave a truncated file that later runs would
    silently treat as the whole dataset.

    Args:
        train_frac: fraction of the bytes, taken from the start of the file,
            used for training; the remainder is validation. Defaults to 0.9.

    Returns:
        ``(train, val)``: 1-D ``torch.long`` tensors of byte values.

    Raises:
        ValueError: if ``train_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")

    path = os.path.join(os.path.dirname(__file__), "tinyshakespeare.txt")
    if not os.path.exists(path):
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            # Only a fully downloaded file ever appears under the final name.
            os.replace(tmp_path, path)
        finally:
            # Remove a leftover partial download if anything above failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # Text mode: universal-newline translation applies before re-encoding.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    n = int(train_frac * len(byte_data))
    return byte_data[:n], byte_data[n:]
|
|
|
|
def get_batch(data, batch_size, ctx, device):
    """Sample ``batch_size`` random windows of length ``ctx`` from ``data``.

    Returns ``(x, targets)`` moved to ``device``. ``x`` stacks the windows, and
    ``targets`` is ``x[:, 3:]``: for each 3-byte window ``x[:, i:i+3]``, the
    byte that follows it.
    """
    starts = torch.randint(0, len(data) - ctx - 1, (batch_size,))
    windows = []
    for start in starts.tolist():
        windows.append(data[start : start + ctx])
    batch = torch.stack(windows)
    next_bytes = batch[:, 3:]
    return batch.to(device), next_bytes.to(device)
|
|
|
|
class FP32Linear(nn.Module):
    """Full-precision affine layer computing ``x @ weight.T + bias``.

    ``weight`` has shape ``(out_dim, in_dim)`` and is initialised from a
    standard normal scaled by 0.1. ``bias`` has shape ``(out_dim,)`` and
    starts at zero.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        init_weight = 0.1 * torch.randn(out_dim, in_dim)
        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        return F.linear(input=x, weight=self.weight, bias=self.bias)
|
|
|
|
class FP32RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature gain.

    Computes ``scale * x / sqrt(mean(x**2, dim=-1) + eps)``. There is no
    mean-centring and no bias term.
    """

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denominator = (mean_square + self.eps).sqrt()
        return self.scale * (x / denominator)
|
|
|
|
class FP32TrigramModel(nn.Module):
    """Full-precision byte-trigram baseline used for comparison against ``ARBModel``.

    Each position's normalised embedding is concatenated with those of its two
    successors. The result is projected to ``HIDDEN_DIM``, passed through an
    RMS-normalised ReLU feed-forward stack, and mapped to logits over ``VOCAB``.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB, EMBEDDING_DIM)
        self.embed_norm = FP32RMSNorm(EMBEDDING_DIM)
        self.projection = FP32Linear(EMBEDDING_DIM * 3, HIDDEN_DIM)
        self.proj_norm = FP32RMSNorm(HIDDEN_DIM)
        self.ffn_norm1 = FP32RMSNorm(HIDDEN_DIM)
        self.fc1 = FP32Linear(HIDDEN_DIM, FFN_HIDDEN)
        self.ffn_norm2 = FP32RMSNorm(FFN_HIDDEN)
        self.fc2 = FP32Linear(FFN_HIDDEN, HIDDEN_DIM)
        self.head_norm = FP32RMSNorm(HIDDEN_DIM)
        self.head = FP32Linear(HIDDEN_DIM, VOCAB)

    def forward(self, x, targets=None):
        """Return ``(logits, loss)`` for a batch of byte ids.

        Args:
            x: integer ids, assumed ``(batch, seq)`` with ``seq >= 3``.
            targets: optional ids aligned with ``logits[:, :-1]``, i.e.
                ``x[:, 3:]`` as produced by ``get_batch``. Positions equal to
                ``SPECIAL_VOCAB["PAD"]`` are ignored by the loss.

        Returns:
            ``logits`` of shape ``(batch, seq - 2, VOCAB)`` (one per 3-byte
            window), and the mean cross-entropy, or ``None`` without targets.
        """
        embedded = self.embed_norm(self.embedding(x))
        # Sliding 3-position windows: (batch, seq - 2, EMBEDDING_DIM, 3).
        windows = embedded.unfold(dimension=1, size=3, step=1)
        # Merge the last two dims (feature-major, window-minor); this matches the
        # former einops 'b t d w -> b t (d w)' without a per-call import.
        trigrams = windows.flatten(start_dim=2)
        relational = self.proj_norm(self.projection(trigrams))
        h = self.ffn_norm1(relational)
        h = torch.relu(self.fc1(h))
        h = self.ffn_norm2(h)
        h = self.fc2(h)
        logits = self.head(self.head_norm(h))

        loss = None
        if targets is not None:
            # The last window has no following byte inside x, so it is dropped.
            next_byte_logits = logits[:, :-1, :]
            loss = F.cross_entropy(
                next_byte_logits.reshape(-1, VOCAB),
                targets.reshape(-1),
                ignore_index=SPECIAL_VOCAB["PAD"],
            )
        return logits, loss
|
|
|
|
def evaluate(model, val_data, device, eval_steps=100):
    """Return the mean loss of ``model`` over ``eval_steps`` random validation batches.

    Batches come from ``get_batch`` using ``TRAIN_PARAMS["batch_size"]`` and
    ``TRAIN_PARAMS["ctx"]``. The model runs in eval mode with gradients
    disabled. Its previous train/eval mode is restored afterwards, even if an
    exception is raised; it is no longer forced into train mode.

    Raises:
        ValueError: if ``eval_steps`` is not positive, since the mean would be undefined.
    """
    if eval_steps <= 0:
        raise ValueError(f"eval_steps must be positive, got {eval_steps}")

    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for _ in range(eval_steps):
                x, targets = get_batch(val_data, TRAIN_PARAMS["batch_size"], TRAIN_PARAMS["ctx"], device)
                _, loss = model(x, targets=targets)
                total_loss += loss.item()
    finally:
        model.train(was_training)
    return total_loss / eval_steps
|
|
|
|
def count_params(model):
    """Split ``model``'s parameter count into ternary-eligible and FP32 parts.

    A parameter counts as ternary when its name contains ``"weight"``, it has
    at least two dimensions, and its name does not contain ``"embed"``.
    Everything else (biases, norm gains, embeddings) counts as FP32.

    Returns:
        ``(total, ternary, fp32)`` element counts.
    """
    buckets = {"ternary": 0, "fp32": 0}
    for pname, tensor in model.named_parameters():
        eligible = "weight" in pname and tensor.ndim >= 2 and "embed" not in pname
        buckets["ternary" if eligible else "fp32"] += tensor.numel()
    grand_total = sum(t.numel() for t in model.parameters())
    return grand_total, buckets["ternary"], buckets["fp32"]
|
|
|
|
def log_diagnostics(model, step, train_loss, val_loss, config_name, is_ternary=False):
    """Print the current losses and, for ternary models, per-layer weight statistics.

    When ``is_ternary`` is set, each parameter whose name contains ``"weight"``,
    has at least two dimensions, and is not an embedding is passed through
    ``StickyZoneSTE.apply(param, THRESHOLD)`` (presumably a ternary quantiser;
    confirm in ``arbitor.main``). The fractions of positive, negative and zero
    entries are printed with the parameter's current gradient norm (0 if no
    gradient). Layers with more than 95% zeros are flagged as collapsed.
    """
    print(f"[{config_name}] step {step} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")
    if not is_ternary:
        return
    for name, param in model.named_parameters():
        if "weight" not in name or param.ndim < 2 or "embed" in name:
            continue
        with torch.no_grad():
            T = StickyZoneSTE.apply(param, THRESHOLD)
            frac_zero = (T == 0).float().mean().item()
            frac_pos = (T > 0).float().mean().item()
            frac_neg = (T < 0).float().mean().item()
        grad_norm = param.grad.norm().item() if param.grad is not None else 0.0
        print(f" {name}: +{frac_pos:.2%} -{frac_neg:.2%} 0{frac_zero:.2%} | grad={grad_norm:.6f}")
        if frac_zero > 0.95:
            # The warning glyph was mis-encoded (UTF-8 read as ISO-8859-7) in the original.
            print(f" ⚠ COLLAPSE: {name} is >95% zeros")
|
|
|
|
def train_model(model, train_data, val_data, config_name, device, steps=5000):
    """Train ``model`` with AdamW and return loss curves and timing statistics.

    Every ``TRAIN_PARAMS["eval_interval"]`` steps, the model is evaluated on
    ``val_data`` and diagnostics are printed. Ternary per-layer statistics are
    included when ``"Ternary"`` appears in ``config_name``. Step timings cover
    the forward/backward/optimizer work only, not evaluation.

    Returns:
        dict with keys ``config``, ``final_train_loss``, ``final_val_loss``,
        ``train_losses`` and ``val_losses`` (lists, one entry per eval),
        ``steps`` (the step numbers of those evals), ``avg_step_ms``,
        ``steps_per_sec``, ``total_s`` and ``param_count``.

    Raises:
        ValueError: if ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=TRAIN_PARAMS["lr"], weight_decay=TRAIN_PARAMS["weight_decay"]
    )
    is_ternary = "Ternary" in config_name
    # Synchronise for any CUDA device spec ("cuda", "cuda:1", torch.device), so
    # asynchronous kernels are included in the step timing.
    sync_cuda = torch.device(device).type == "cuda"
    train_losses = []
    val_losses = []
    step_list = []
    step_times = []

    for step in range(steps):
        t0 = time.perf_counter()
        x, targets = get_batch(train_data, TRAIN_PARAMS["batch_size"], TRAIN_PARAMS["ctx"], device)
        _, loss = model(x, targets=targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if sync_cuda:
            torch.cuda.synchronize(device)
        step_times.append(time.perf_counter() - t0)

        if (step + 1) % TRAIN_PARAMS["eval_interval"] == 0:
            train_loss = loss.item()
            val_loss = evaluate(model, val_data, device, TRAIN_PARAMS["eval_steps"])
            log_diagnostics(model, step + 1, train_loss, val_loss, config_name, is_ternary)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            step_list.append(step + 1)

    final_val = evaluate(model, val_data, device, TRAIN_PARAMS["eval_steps"])
    total_s = sum(step_times)
    avg_step_ms = total_s / len(step_times) * 1000
    steps_per_sec = steps / total_s
    print(f"[{config_name}] speed: {avg_step_ms:.2f} ms/step | {steps_per_sec:.1f} steps/s | total {total_s:.1f}s")
    return {
        "config": config_name,
        "final_train_loss": train_losses[-1] if train_losses else loss.item(),
        "final_val_loss": final_val,
        "train_losses": train_losses,
        # Was the scalar `val_loss` (and unbound when no eval ran); the list is intended.
        "val_losses": val_losses,
        "steps": step_list,
        "avg_step_ms": avg_step_ms,
        "steps_per_sec": steps_per_sec,
        "total_s": total_s,
        "param_count": sum(p.numel() for p in model.parameters()),
    }
|
|
|
|
def analyze_results(results_fp32, results_ternary):
    """Print a side-by-side report of the FP32 and ternary benchmark runs.

    Both arguments are result dicts as returned by ``train_model``. The report
    covers:
    * a summary table;
    * the ternary/FP32 validation-loss ratio (PASS when <= 1.25x);
    * the step-time ratio;
    * an effective bits-per-weight estimate from ``count_params`` on a fresh
      ``ARBModel`` (presumably the same architecture as the trained ternary
      model).

    The status strings previously contained mis-encoded UTF-8 (read as
    ISO-8859-7). Some of those were split across lines, producing unterminated
    string literals, i.e. a SyntaxError.
    """
    print("\n" + "=" * 80)
    print("MORPH TERNARY vs FP32 — REAL MODEL BENCHMARK")
    print("=" * 80)
    print(f"{'Config':<30} {'Val Loss':>9} {'Params':>10} {'ms/step':>9} {'stps/s':>8}")
    print("-" * 70)
    for r in (results_fp32, results_ternary):
        print(
            f"{r['config']:<30} {r['final_val_loss']:>9.4f} "
            f"{r['param_count']:>10} {r['avg_step_ms']:>9.2f} {r['steps_per_sec']:>8.1f}"
        )

    fp32_loss = results_fp32["final_val_loss"]
    ternary_loss = results_ternary["final_val_loss"]
    fp32_speed = results_fp32["avg_step_ms"]
    ternary_speed = results_ternary["avg_step_ms"]

    ratio = ternary_loss / fp32_loss
    speed_ratio = ternary_speed / fp32_speed
    print("\n--- Precision ---")
    print(f"Ternary / FP32 loss ratio = {ratio:.3f}x")
    if ratio <= 1.25:
        print("✅ Ternary within 1.25x FP32 — PASS")
    else:
        print("❌ Ternary exceeds 1.25x FP32 — FAIL")

    print("\n--- Speed ---")
    print(f"Ternary vs FP32: {speed_ratio:.2f}x ({ternary_speed:.2f}ms vs {fp32_speed:.2f}ms)")
    if speed_ratio <= 1.0:
        print("✅ Ternary is faster or equal to FP32")
    elif speed_ratio <= 1.5:
        print("⚡ Ternary is slower but within 1.5x — expected for unoptimized path")
    else:
        print("❌ Ternary significantly slower — investigate")

    total, ternary_p, fp32_p = count_params(ARBModel())
    # Ternary weights are costed at 1.58 bits (about log2(3)); all others at 32 bits.
    eff_bpw = (fp32_p * 32 + ternary_p * 1.58) / total
    print("\n--- Effective BPW ---")
    print(f"Ternary params: {ternary_p:,} (1.58 BPW)")
    print(f"FP32 params: {fp32_p:,} (32 BPW)")
    print(f"Effective BPW: {eff_bpw:.2f} bits/weight (avg across all params)")
    print(f"Memory savings: {32.0/eff_bpw:.1f}x vs pure FP32")
|
|
|
|
def run_all():
    """Run the full benchmark.

    Steps:
    1. Train the FP32 baseline and the ternary ``ARBModel``.
    2. Save the ternary model and reload it as a smoke test.
    3. Print the comparison report.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print("Downloading data...")
    train_data, val_data = download_data()
    print(f"Train: {len(train_data):,} bytes, Val: {len(val_data):,} bytes")

    print("\n--- FP32 Baseline ---")
    model_fp32 = FP32TrigramModel().to(device)
    total_fp = sum(p.numel() for p in model_fp32.parameters())
    print(f"Params: {total_fp:,}")
    results_fp32 = train_model(
        model_fp32, train_data, val_data, "FP32-Trigram", device, TRAIN_PARAMS["max_steps"]
    )
    print(f"FP32 final val loss: {results_fp32['final_val_loss']:.4f}")
    del model_fp32
    if device == "cuda":
        torch.cuda.empty_cache()

    print("\n--- Ternary (Config E: Factorized Scaled) ---")
    model_ternary = ARBModel().to(device)
    total_t = sum(p.numel() for p in model_ternary.parameters())
    print(f"Params: {total_t:,}")
    results_ternary = train_model(
        model_ternary, train_data, val_data, "Ternary-ConfigE", device, TRAIN_PARAMS["max_steps"]
    )
    print(f"Ternary final val loss: {results_ternary['final_val_loss']:.4f}")

    save_path = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "..", "..", "models", "conversions", "arb-model.pt")
    )
    # The output directory may not exist in a fresh checkout.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_model(model_ternary, save_path)

    # Round-trip smoke test: reload from disk and run one forward pass with
    # the same context length used during training.
    test_model = load_model(save_path, device)
    test_x = torch.randint(0, VOCAB, (2, TRAIN_PARAMS["ctx"]), device=device)
    with torch.no_grad():
        logits, _ = test_model(test_x)
    print(f"\nLoaded model test: logits shape = {logits.shape} ✅")
    del model_ternary, test_model
    if device == "cuda":
        torch.cuda.empty_cache()

    analyze_results(results_fp32, results_ternary)
|
|
|
|
# Script entry point: run the full FP32-vs-ternary benchmark.
if __name__ == "__main__":
    run_all()
|
|