| import os |
| import sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" |
|
|
| import torch |
| import torch.nn.functional as F |
| import urllib.request |
| import sys |
| import time |
| import math |
|
|
| from trigram import ( |
| VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD, |
| SPECIAL_VOCAB, MORPHTernaryModel, StickyZoneSTE, |
| ) |
|
|
# Checkpoints are expected under runs/ternary-v1 next to this script; fall
# back to the current directory when __file__ carries no directory part.
CKPT_DIR = os.path.join(
    os.path.dirname(__file__) or ".",
    "runs",
    "ternary-v1",
)

# Evaluation settings shared by every function in this script.
BATCH_SIZE = 1024
# NOTE(review): this rebinds the CTX name imported from trigram, so every use
# below sees 66 rather than the imported value -- confirm that is intended.
CTX = 66
EVAL_STEPS = 500

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
def download_data(data_dir):
    """Load tinyshakespeare as UTF-8 byte values and split it 90/10.

    The corpus is downloaded into ``data_dir`` the first time it is needed.
    The download goes to a temporary ``.part`` file that is renamed into place
    only after it completes, so an interrupted transfer can never leave a
    truncated ``tinyshakespeare.txt`` that later runs would silently reuse.

    Args:
        data_dir: Directory that holds (or will hold) ``tinyshakespeare.txt``.
            It is created if it does not exist yet.

    Returns:
        A ``(train, val)`` tuple of 1-D ``torch.long`` tensors containing the
        first 90% and the remaining 10% of the corpus bytes.

    Raises:
        OSError: If the download or a file operation fails
            (``urllib.error.URLError`` is an ``OSError`` subclass).
    """
    path = os.path.join(data_dir, "tinyshakespeare.txt")
    if not os.path.exists(path):
        print("Downloading tinyshakespeare...")
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
        partial_path = path + ".part"
        try:
            urllib.request.urlretrieve(
                "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
                partial_path,
            )
            # Publish the file only once it has been fetched completely.
            os.replace(partial_path, path)
        finally:
            # After a successful replace the partial file is gone; after a
            # failure this removes whatever fragment was written.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    n = int(0.9 * len(byte_data))
    return byte_data[:n], byte_data[n:]
|
|
|
|
def get_batch(data, batch_size, ctx, device):
    """Sample a random batch of fixed-length windows from ``data``.

    Args:
        data: 1-D tensor of token ids to sample from.
        batch_size: Number of windows to draw.
        ctx: Length of each window.
        device: Device the returned tensors are moved to.

    Returns:
        A ``(x, targets)`` tuple. ``x`` has shape ``(batch_size, ctx)`` and
        each row is a contiguous slice of ``data``. ``targets`` is ``x`` with
        the first three positions of every row dropped, i.e. shape
        ``(batch_size, ctx - 3)``.
        NOTE(review): the offset of 3 presumably matches the model's trigram
        input -- confirm against the model definition.

    Raises:
        ValueError: If ``data`` is too short to hold a window of ``ctx``
            tokens with the margin this sampler requires.
    """
    high = len(data) - ctx - 1
    if high <= 0:
        raise ValueError(
            f"data has {len(data)} tokens, need at least {ctx + 2} for ctx={ctx}"
        )
    starts = torch.randint(0, high, (batch_size,))
    # One vectorized gather instead of slicing and stacking row by row:
    # row r of the index grid is starts[r], starts[r] + 1, ..., starts[r] + ctx - 1.
    window = starts.unsqueeze(1) + torch.arange(ctx)
    x = data[window]
    targets = x[:, 3:]
    return x.to(device, non_blocking=True), targets.to(device, non_blocking=True)
|
|
|
|
@torch.no_grad()
def _mean_loss(model, data, n_steps):
    """Average the model's loss over ``n_steps`` random batches of ``data``.

    Puts the model in eval mode (it is not switched back afterwards) and
    draws batches of ``BATCH_SIZE`` windows of length ``CTX``.
    """
    model.eval()
    # Only request CUDA autocast when actually running on CUDA; asking for it
    # on a CPU-only host just emits a warning and disables itself.
    use_amp = DEVICE == "cuda"
    losses = []
    for _ in range(n_steps):
        x, targets = get_batch(data, batch_size=BATCH_SIZE, ctx=CTX, device=DEVICE)
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
            _, loss = model(x, targets=targets)
        losses.append(loss.item())
    return sum(losses) / len(losses)


def evaluate(model, val_data):
    """Return the mean loss over ``EVAL_STEPS`` random validation batches.

    Args:
        model: Callable as ``model(x, targets=targets)`` returning a pair
            whose second element is a scalar loss tensor.
        val_data: 1-D tensor of validation token ids.

    Returns:
        The average loss as a Python float.
    """
    return _mean_loss(model, val_data, EVAL_STEPS)


def evaluate_train(model, train_data, n_steps=200):
    """Return the mean loss over ``n_steps`` random training batches.

    Args:
        model: Callable as ``model(x, targets=targets)`` returning a pair
            whose second element is a scalar loss tensor.
        train_data: 1-D tensor of training token ids.
        n_steps: Number of batches to average over.

    Returns:
        The average loss as a Python float.

    Raises:
        ZeroDivisionError: If ``n_steps`` is 0.
    """
    return _mean_loss(model, train_data, n_steps)
|
|
|
|
@torch.no_grad()
def ternary_distribution(model):
    """Summarise the sign distribution of each quantized weight matrix.

    A parameter is inspected when its name contains ``"weight"``, it has at
    least two dimensions, and its name does not contain ``"embed"``. Each such
    parameter is passed through ``StickyZoneSTE.apply`` with ``THRESHOLD``
    (presumably yielding ternary values -- confirm against ``trigram``).

    Returns:
        A dict mapping parameter name to a dict with the fractions of
        positive (``"pos"``), negative (``"neg"``) and zero (``"zero"``)
        entries of the quantized tensor, plus the mean (``"s_mean"``) and
        standard deviation (``"s_std"``) of the raw parameter's absolute
        values.
    """
    report = {}
    for pname, tensor in model.named_parameters():
        inspected = "weight" in pname and tensor.ndim >= 2 and "embed" not in pname
        if not inspected:
            continue
        quantized = StickyZoneSTE.apply(tensor, THRESHOLD)
        magnitude = tensor.abs()
        report[pname] = {
            "pos": (quantized > 0).float().mean().item(),
            "neg": (quantized < 0).float().mean().item(),
            "zero": (quantized == 0).float().mean().item(),
            "s_mean": magnitude.mean().item(),
            "s_std": magnitude.std().item(),
        }
    return report
|
|
|
|
@torch.no_grad()
def generate_sample(model, seed_bytes, max_new_tokens=200, temperature=0.8, top_k=40):
    """Autoregressively sample tokens from ``model`` after a seed sequence.

    Args:
        model: Callable as ``model(idx)`` returning a pair whose first
            element is a logits tensor indexed as ``[batch, position, vocab]``.
        seed_bytes: List of token ids used as the initial context.
        max_new_tokens: Number of tokens to append.
        temperature: Positive divisor applied to the logits before sampling.
        top_k: If not ``None``, sample only among the ``top_k`` highest
            logits. Values larger than the vocabulary are clamped.

    Returns:
        The seed followed by the sampled tokens, as a list of ints.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    model.eval()
    # Only request CUDA autocast when actually running on CUDA; asking for it
    # on a CPU-only host just emits a warning and disables itself.
    use_amp = DEVICE == "cuda"
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=DEVICE)
    for _ in range(max_new_tokens):
        # Feed at most the last CTX tokens to the model.
        idx_cond = idx[:, -CTX:]
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
            logits, _ = model(idx_cond)
        last_logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # torch.topk raises if k exceeds the size of the last dimension.
            k = min(top_k, last_logits.size(-1))
            v, _ = torch.topk(last_logits, k)
            # Mask everything below the k-th largest logit.
            last_logits[last_logits < v[:, [-1]]] = float("-inf")
        probs = F.softmax(last_logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx[0].cpu().tolist()
|
|
|
|
def bytes_to_text(byte_list):
    """Render a sequence of token ids as printable text.

    Printable ASCII (32..126) is emitted as-is, newline and tab are kept,
    carriage returns are dropped, ids of 256 and above are shown as ``<id>``,
    and every other value is shown as a ``\\xNN`` hex escape.
    """
    whitespace = {10: "\n", 13: "", 9: "\t"}

    def render(token):
        if 32 <= token < 127:
            return chr(token)
        if token in whitespace:
            return whitespace[token]
        if token >= 256:
            return f"<{token}>"
        return f"\\x{token:02x}"

    return "".join(render(token) for token in byte_list)
|
|
|
|
@torch.no_grad()
def measure_inference_speed(model, n_steps=100):
    """Time forward passes on a single random sequence of length ``CTX``.

    Ten untimed warm-up passes run first, then ``n_steps`` timed passes.

    Args:
        model: Callable as ``model(x)`` on a ``(1, CTX)`` tensor of token ids.
        n_steps: Number of timed forward passes.

    Returns:
        Forward passes per second as a float.
    """
    model.eval()
    on_cuda = DEVICE == "cuda"
    x = torch.randint(0, VOCAB, (1, CTX), device=DEVICE)
    # Only request CUDA autocast when actually running on CUDA; asking for it
    # on a CPU-only host just emits a warning and disables itself.
    with torch.autocast("cuda", dtype=torch.bfloat16, enabled=on_cuda):
        for _ in range(10):
            model(x)
        # Wait for queued GPU work so the warm-up is excluded from the timing.
        if on_cuda:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(n_steps):
            model(x)
        # Wait again so the timed passes have finished before the clock stops.
        if on_cuda:
            torch.cuda.synchronize()
        t1 = time.perf_counter()
    return n_steps / (t1 - t0)
|
|
|
|
def perplexity(loss):
    """Return ``exp(loss)``, the perplexity for a mean loss in nats.

    Args:
        loss: Mean loss as a float.

    Returns:
        ``math.exp(loss)``, or ``inf`` when the loss is too large for the
        result to fit in a float (``math.exp`` raises ``OverflowError`` there,
        which would otherwise abort the whole comparison run).
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")
|
|
|
|
def _count_params(model):
    """Return ``(total, ternary, fp32)`` parameter counts for ``model``.

    A parameter is counted as ternary when its name contains ``"weight"``,
    it has at least two dimensions and its name does not contain ``"embed"``;
    everything else is counted as fp32.
    """
    total = sum(p.numel() for p in model.parameters())
    ternary = sum(
        p.numel()
        for name, p in model.named_parameters()
        if "weight" in name and p.ndim >= 2 and "embed" not in name
    )
    return total, ternary, total - ternary


def _load_model(path):
    """Build a model on ``DEVICE``, loading weights from ``path`` if given.

    Returns ``None`` (after printing a notice) when ``path`` is set but the
    file does not exist. The existence check runs before the model is built,
    so a missing checkpoint costs no model construction or device transfer.
    """
    if path is None:
        model = MORPHTernaryModel().to(DEVICE)
        print("Init model (random weights, no training)")
        return model
    if not os.path.exists(path):
        print(f"MISSING: {path} — skipping")
        return None
    model = MORPHTernaryModel().to(DEVICE)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    print(f"Loaded: {path}")
    return model


def _evaluate_checkpoint(label, model, train_data, val_data, seed_bytes):
    """Print the parameter summary, run every measurement, return the results."""
    total_params, ternary_params, fp32_params = _count_params(model)
    # Effective bits per weight: 32 for fp32 parameters, 1.58 for ternary ones.
    eff_bpw = (fp32_params * 32 + ternary_params * 1.58) / total_params
    print(f"Params: {total_params:,} | ternary: {ternary_params:,} | fp32: {fp32_params:,} | BPW: {eff_bpw:.2f}")

    val_loss = evaluate(model, val_data)
    train_loss = evaluate_train(model, train_data)
    speed = measure_inference_speed(model)
    stats = ternary_distribution(model)
    sample_tokens = generate_sample(model, seed_bytes, max_new_tokens=150, temperature=0.8)

    return {
        "label": label,
        "val_loss": val_loss,
        "val_ppl": perplexity(val_loss),
        "train_loss": train_loss,
        "train_ppl": perplexity(train_loss),
        "gap": train_loss - val_loss,
        "speed": speed,
        "stats": stats,
        "sample": bytes_to_text(sample_tokens),
    }


def _print_checkpoint_report(result):
    """Print metrics, ternary statistics and a sample excerpt for one checkpoint."""
    print("\n--- Metrics ---")
    print(f" Val loss: {result['val_loss']:.4f} (ppl={result['val_ppl']:.2f})")
    print(f" Train loss: {result['train_loss']:.4f} (ppl={result['train_ppl']:.2f})")
    print(f" Train-Val gap: {result['gap']:+.4f}")
    print(f" Inference: {result['speed']:.1f} seq/s")
    print("\n--- Ternary Distribution ---")
    for name, s in result["stats"].items():
        short = name.replace(".weight", "")
        print(f" {short:40s} +{s['pos']:.3f} -{s['neg']:.3f} 0={s['zero']:.3f} S={s['s_mean']:.4f}±{s['s_std']:.4f}")
    print("\n--- Sample (temp=0.8, top_k=40) ---")
    sample_text = result["sample"]
    sample_lines = sample_text.split("\n")
    # Show only the first eight lines and note the total size when truncated.
    for line in sample_lines[:8]:
        print(f" {line}")
    if len(sample_lines) > 8:
        print(f" ... ({len(sample_text)} chars total)")


def _print_comparison(results):
    """Print the side-by-side table and name the checkpoint with the lowest val loss."""
    print(f"\n\n{'=' * 80}")
    print("COMPARISON TABLE")
    print(f"{'=' * 80}")
    print(f"{'Checkpoint':<20s} {'Val Loss':>10s} {'Val PPL':>10s} {'Train Loss':>11s} {'Gap':>8s} {'Speed':>10s}")
    print(f"{'-'*20} {'-'*10} {'-'*10} {'-'*11} {'-'*8} {'-'*10}")
    for r in results:
        print(f"{r['label']:<20s} {r['val_loss']:>10.4f} {r['val_ppl']:>10.2f} {r['train_loss']:>11.4f} {r['gap']:>+8.4f} {r['speed']:>9.1f}/s")

    # min() raises on an empty list; the randomly initialised baseline always
    # produces an entry, but guard anyway so the table never ends in a traceback.
    if results:
        best = min(results, key=lambda r: r["val_loss"])
        print(f"\nBest checkpoint: {best['label']} (val_loss={best['val_loss']:.4f}, ppl={best['val_ppl']:.2f})")


def main():
    """Evaluate a fixed list of checkpoints and print a comparison table.

    For a randomly initialised model and each checkpoint file found under
    ``CKPT_DIR``, this reports validation/train loss and perplexity, inference
    speed, ternary weight statistics and a generated text sample. Checkpoint
    files that are missing are reported and skipped.
    """
    print(f"Device: {DEVICE}")
    print(f"Eval: {EVAL_STEPS} batches x {BATCH_SIZE} samples, ctx={CTX}")
    print("=" * 80)

    data_dir = os.path.dirname(__file__) or "."
    train_data, val_data = download_data(data_dir)
    print(f"Data: train={len(train_data):,} bytes | val={len(val_data):,} bytes\n")

    seed_text = "ROMEO:\nWhat light through yonder window breaks?\n"
    seed_bytes = list(seed_text.encode("utf-8"))

    # (label, checkpoint path); a path of None means "untrained baseline".
    checkpoints = [
        ("init (random)", None),
        ("step5000", os.path.join(CKPT_DIR, "trigram-morph-step5000.pt")),
        ("best (step7K)", os.path.join(CKPT_DIR, "trigram-morph-best.pt")),
        ("step13000", os.path.join(CKPT_DIR, "trigram-morph-step13000.pt")),
        ("step25000", os.path.join(CKPT_DIR, "trigram-morph-step25000.pt")),
    ]

    results = []
    for label, path in checkpoints:
        print(f"\n{'=' * 80}")
        print(f"CHECKPOINT: {label}")
        print(f"{'=' * 80}")

        model = _load_model(path)
        if model is None:
            continue

        result = _evaluate_checkpoint(label, model, train_data, val_data, seed_bytes)
        results.append(result)
        _print_checkpoint_report(result)

        # Release this model's memory before the next checkpoint is loaded.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    _print_comparison(results)


if __name__ == "__main__":
    main()
|
|