"""Evaluate ternary MORPH checkpoints on byte-level tinyshakespeare.

For each checkpoint listed in ``main`` this script reports validation/train
loss and perplexity, single-sequence inference throughput, the ternary value
distribution of every quantized weight matrix, and a short generated sample,
then prints a comparison table across checkpoints.
"""

import contextlib
import math
import os
import sys
import time
import urllib.request

# Make the directory two levels up importable so ``trigram`` resolves.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
# Set before torch is imported so the CUDA caching allocator picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  (must come after the env var above)
import torch.nn.functional as F  # noqa: E402

from trigram import (  # noqa: E402
    VOCAB,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    FFN_HIDDEN,
    CTX,
    THRESHOLD,
    SPECIAL_VOCAB,
    MORPHTernaryModel,
    StickyZoneSTE,
)

CKPT_DIR = os.path.join(os.path.dirname(__file__) or ".", "runs", "ternary-v1")
BATCH_SIZE = 1024
# NOTE(review): this shadows the CTX imported from trigram above; confirm that
# 66 matches the context length the checkpoints were trained with.
CTX = 66
EVAL_STEPS = 500
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# Storage cost assumed per ternary weight (log2(3) ~= 1.58 bits).
TERNARY_BITS = 1.58
SAMPLE_TEMPERATURE = 0.8
SAMPLE_TOP_K = 40
SAMPLE_MAX_NEW_TOKENS = 150

# Replacement text for the control bytes bytes_to_text special-cases
# (LF kept, CR dropped, TAB kept).
_CONTROL_BYTE_TEXT = {10: "\n", 13: "", 9: "\t"}


def _autocast():
    """Return a bf16 autocast context on CUDA and a no-op context otherwise.

    ``torch.autocast("cuda", ...)`` is disabled (with a warning) when CUDA is
    unavailable, so on CPU a ``nullcontext`` gives the same fp32 behaviour
    without the warning.
    """
    if DEVICE == "cuda":
        return torch.autocast("cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()


def _is_ternary_param(name, param):
    """Return True for parameters this script treats as ternary-quantized.

    Selection rule: tensors whose name contains ``weight``, with at least two
    dimensions, and whose name does not contain ``embed``.
    """
    return "weight" in name and param.ndim >= 2 and "embed" not in name


def download_data(data_dir):
    """Load tinyshakespeare as byte values and split it 90/10 into train/val.

    The text file is downloaded into *data_dir* on first use.

    Returns:
        ``(train, val)``: 1-D ``torch.long`` tensors of UTF-8 byte values.
    """
    path = os.path.join(data_dir, "tinyshakespeare.txt")
    if not os.path.exists(path):
        print("Downloading tinyshakespeare...")
        # Download under a temporary name and rename afterwards, so an
        # interrupted download never leaves a truncated file that later runs
        # would silently use.
        tmp_path = path + ".part"
        urllib.request.urlretrieve(DATA_URL, tmp_path)
        os.replace(tmp_path, path)
    # Text mode is kept on purpose: universal-newline decoding normalizes CRLF
    # to LF before the text is re-encoded to bytes.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    encoded = bytearray(text.encode("utf-8"))
    # frombuffer avoids materializing a Python list with one int per byte.
    byte_data = torch.frombuffer(encoded, dtype=torch.uint8).to(torch.long)
    n = int(0.9 * len(byte_data))
    return byte_data[:n], byte_data[n:]


def get_batch(data, batch_size, ctx, device):
    """Sample *batch_size* random windows of *ctx* consecutive tokens.

    Returns:
        ``(x, targets)`` moved to *device*. ``x`` has shape
        ``(batch_size, ctx)``; ``targets`` is ``x[:, 3:]``. How targets align
        with the model's outputs is decided by the model's forward
        (presumably a 3-token trigram context -- confirm in trigram.py).
    """
    ix = torch.randint(0, len(data) - ctx - 1, (batch_size,))
    # Vectorized gather: row b is data[ix[b] : ix[b] + ctx].
    x = data[ix.unsqueeze(1) + torch.arange(ctx)]
    targets = x[:, 3:]
    return x.to(device, non_blocking=True), targets.to(device, non_blocking=True)


@torch.no_grad()
def _mean_loss(model, data, n_steps):
    """Average the model's loss over *n_steps* random batches drawn from *data*.

    Raises:
        ValueError: if *n_steps* is not positive.
    """
    if n_steps <= 0:
        raise ValueError(f"n_steps must be positive, got {n_steps}")
    model.eval()
    losses = []
    for _ in range(n_steps):
        x, targets = get_batch(data, batch_size=BATCH_SIZE, ctx=CTX, device=DEVICE)
        with _autocast():
            _, loss = model(x, targets=targets)
        losses.append(loss.item())
    return sum(losses) / len(losses)


@torch.no_grad()
def evaluate(model, val_data):
    """Return the mean validation loss over ``EVAL_STEPS`` random batches."""
    return _mean_loss(model, val_data, EVAL_STEPS)


@torch.no_grad()
def evaluate_train(model, train_data, n_steps=200):
    """Return the mean training-set loss over *n_steps* random batches."""
    return _mean_loss(model, train_data, n_steps)


@torch.no_grad()
def ternary_distribution(model):
    """Summarize the ternarized value distribution of each quantized weight.

    Each parameter selected by ``_is_ternary_param`` is passed through
    ``StickyZoneSTE.apply(param, THRESHOLD)``. The fractions of positive,
    negative and zero entries in the result are recorded, together with the
    mean and std of the absolute raw (un-ternarized) parameter values.

    Returns:
        dict mapping parameter name to
        ``{"pos", "neg", "zero", "s_mean", "s_std"}`` floats.
    """
    stats = {}
    for name, param in model.named_parameters():
        if not _is_ternary_param(name, param):
            continue
        ternary = StickyZoneSTE.apply(param, THRESHOLD)
        magnitude = param.abs()
        stats[name] = {
            "pos": (ternary > 0).float().mean().item(),
            "neg": (ternary < 0).float().mean().item(),
            "zero": (ternary == 0).float().mean().item(),
            "s_mean": magnitude.mean().item(),
            "s_std": magnitude.std().item(),
        }
    return stats


@torch.no_grad()
def generate_sample(model, seed_bytes, max_new_tokens=200, temperature=0.8, top_k=40):
    """Autoregressively sample *max_new_tokens* tokens after *seed_bytes*.

    Args:
        model: called as ``model(idx)``; must return ``(logits, _)`` where
            ``logits[:, -1, :]`` holds the next-token scores.
        seed_bytes: list of int token ids used as the prompt.
        max_new_tokens: number of tokens to append.
        temperature: logits are divided by this before sampling.
        top_k: if not None, keep only the ``top_k`` highest logits (clamped to
            the vocabulary size so an oversized value cannot crash ``topk``).

    Returns:
        list of int token ids: the prompt followed by the generated tokens.
    """
    model.eval()
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=DEVICE)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -CTX:]  # feed at most the last CTX tokens
        with _autocast():
            logits, _ = model(idx_cond)
        # Sample in fp32; bf16 logits have limited precision for softmax.
        last_logits = logits[:, -1, :].float() / temperature
        if top_k is not None:
            k = min(top_k, last_logits.size(-1))
            kth_best = torch.topk(last_logits, k).values[:, [-1]]
            last_logits = last_logits.masked_fill(last_logits < kth_best, float("-inf"))
        probs = F.softmax(last_logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx[0].cpu().tolist()


def bytes_to_text(byte_list):
    """Render token ids as printable text.

    Rendering rules:
    - Printable ASCII is emitted as-is.
    - LF and TAB are kept; CR is dropped.
    - Ids >= 256 (beyond the byte range) are shown as ``<id>``.
    - Any other byte is shown as a ``\\xNN`` escape.
    """

    def render(b):
        if 32 <= b < 127:
            return chr(b)
        if b in _CONTROL_BYTE_TEXT:
            return _CONTROL_BYTE_TEXT[b]
        if b >= 256:
            return f"<{b}>"
        return f"\\x{b:02x}"

    return "".join(render(b) for b in byte_list)


@torch.no_grad()
def measure_inference_speed(model, n_steps=100, n_warmup=10):
    """Return forward passes per second for one random sequence of length CTX.

    The *n_warmup* untimed passes are run first. On CUDA the device is
    synchronized before and after the timed loop so queued kernels are
    included in the measurement.
    """
    model.eval()
    x = torch.randint(0, VOCAB, (1, CTX), device=DEVICE)
    with _autocast():
        for _ in range(n_warmup):
            model(x)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(n_steps):
            model(x)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
    return n_steps / elapsed


def perplexity(loss):
    """Return ``exp(loss)``, or ``inf`` if the loss is too large to exponentiate."""
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")


def _load_model(path):
    """Build a model on DEVICE, loading weights from *path* when given.

    Returns:
        The model, or None (after printing a notice) when *path* is given but
        does not exist.
    """
    if path is not None and not os.path.exists(path):
        print(f"MISSING: {path} — skipping")
        return None
    model = MORPHTernaryModel().to(DEVICE)
    if path is None:
        print("Init model (random weights, no training)")
        return model
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoint files you produced yourself.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    print(f"Loaded: {path}")
    return model


def _print_param_summary(model):
    """Print total/ternary/fp32 parameter counts and effective bits per weight."""
    total_params = sum(p.numel() for p in model.parameters())
    ternary_params = sum(
        p.numel() for n, p in model.named_parameters() if _is_ternary_param(n, p)
    )
    fp32_params = total_params - ternary_params
    eff_bpw = (fp32_params * 32 + ternary_params * TERNARY_BITS) / total_params
    print(
        f"Params: {total_params:,} | ternary: {ternary_params:,} | "
        f"fp32: {fp32_params:,} | BPW: {eff_bpw:.2f}"
    )


def _evaluate_checkpoint(model, label, train_data, val_data, seed_bytes):
    """Measure one model, print its report, and return its result record."""
    t0 = time.perf_counter()
    val_loss = evaluate(model, val_data)
    t_val = time.perf_counter() - t0
    val_ppl = perplexity(val_loss)

    t0 = time.perf_counter()
    train_loss = evaluate_train(model, train_data)
    t_train = time.perf_counter() - t0
    train_ppl = perplexity(train_loss)

    gap = train_loss - val_loss
    speed = measure_inference_speed(model)
    stats = ternary_distribution(model)
    sample_tokens = generate_sample(
        model,
        seed_bytes,
        max_new_tokens=SAMPLE_MAX_NEW_TOKENS,
        temperature=SAMPLE_TEMPERATURE,
        top_k=SAMPLE_TOP_K,
    )
    sample_text = bytes_to_text(sample_tokens)

    print("\n--- Metrics ---")
    print(f"  Val loss:      {val_loss:.4f}  (ppl={val_ppl:.2f})  [{t_val:.1f}s]")
    print(f"  Train loss:    {train_loss:.4f}  (ppl={train_ppl:.2f})  [{t_train:.1f}s]")
    print(f"  Train-Val gap: {gap:+.4f}")
    print(f"  Inference:     {speed:.1f} seq/s")

    print("\n--- Ternary Distribution ---")
    for name, s in stats.items():
        short = name.replace(".weight", "")
        print(
            f"  {short:40s} +{s['pos']:.3f} -{s['neg']:.3f} 0={s['zero']:.3f} "
            f"S={s['s_mean']:.4f}±{s['s_std']:.4f}"
        )

    print(f"\n--- Sample (temp={SAMPLE_TEMPERATURE}, top_k={SAMPLE_TOP_K}) ---")
    sample_lines = sample_text.split("\n")
    for line in sample_lines[:8]:
        print(f"  {line}")
    if len(sample_lines) > 8:
        print(f"  ... ({len(sample_text)} chars total)")

    return {
        "label": label,
        "val_loss": val_loss,
        "val_ppl": val_ppl,
        "train_loss": train_loss,
        "train_ppl": train_ppl,
        "gap": gap,
        "speed": speed,
        "stats": stats,
        "sample": sample_text,
    }


def _print_comparison(results):
    """Print the cross-checkpoint table and the checkpoint with lowest val loss."""
    print(f"\n\n{'=' * 80}")
    print("COMPARISON TABLE")
    print(f"{'=' * 80}")
    print(
        f"{'Checkpoint':<20s} {'Val Loss':>10s} {'Val PPL':>10s} "
        f"{'Train Loss':>11s} {'Gap':>8s} {'Speed':>10s}"
    )
    print(f"{'-' * 20} {'-' * 10} {'-' * 10} {'-' * 11} {'-' * 8} {'-' * 10}")
    for r in results:
        # Speed is 8 chars + "/s" so it lines up with the 10-wide header.
        print(
            f"{r['label']:<20s} {r['val_loss']:>10.4f} {r['val_ppl']:>10.2f} "
            f"{r['train_loss']:>11.4f} {r['gap']:>+8.4f} {r['speed']:>8.1f}/s"
        )
    if not results:
        print("No checkpoints evaluated.")
        return
    best = min(results, key=lambda r: r["val_loss"])
    print(
        f"\nBest checkpoint: {best['label']} "
        f"(val_loss={best['val_loss']:.4f}, ppl={best['val_ppl']:.2f})"
    )


def main():
    """Evaluate every listed checkpoint and print per-checkpoint and summary reports."""
    print(f"Device: {DEVICE}")
    print(f"Eval: {EVAL_STEPS} batches x {BATCH_SIZE} samples, ctx={CTX}")
    print("=" * 80)

    data_dir = os.path.dirname(__file__) or "."
    train_data, val_data = download_data(data_dir)
    print(f"Data: train={len(train_data):,} bytes | val={len(val_data):,} bytes\n")

    seed_text = "ROMEO:\nWhat light through yonder window breaks?\n"
    seed_bytes = list(seed_text.encode("utf-8"))

    checkpoints = [
        ("init (random)", None),
        ("step5000", os.path.join(CKPT_DIR, "trigram-morph-step5000.pt")),
        ("best (step7K)", os.path.join(CKPT_DIR, "trigram-morph-best.pt")),
        ("step13000", os.path.join(CKPT_DIR, "trigram-morph-step13000.pt")),
        ("step25000", os.path.join(CKPT_DIR, "trigram-morph-step25000.pt")),
    ]

    results = []
    for label, path in checkpoints:
        print(f"\n{'=' * 80}")
        print(f"CHECKPOINT: {label}")
        print(f"{'=' * 80}")

        model = _load_model(path)
        if model is None:
            continue

        _print_param_summary(model)
        results.append(_evaluate_checkpoint(model, label, train_data, val_data, seed_bytes))

        # Free GPU memory before building the next model.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    _print_comparison(results)


if __name__ == "__main__":
    main()