# ARBS/testing/eval/eval_checkpoints.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn.functional as F
import urllib.request
import sys
import time
import math
from trigram import (
VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD,
SPECIAL_VOCAB, MORPHTernaryModel, StickyZoneSTE,
)
# Directory holding the checkpoints written by the training run.
CKPT_DIR = os.path.join(os.path.dirname(__file__) or os.curdir, "runs", "ternary-v1")

# Evaluation geometry: EVAL_STEPS batches of BATCH_SIZE windows, each CTX bytes long.
EVAL_STEPS = 500
BATCH_SIZE = 1_024
# NOTE(review): this shadows the CTX imported from `trigram` above — confirm
# that 66 matches the context length the checkpoints were trained with.
CTX = 66

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def download_data(data_dir):
    """Load tinyshakespeare as byte values and split it 90/10 into train/val.

    On first use the corpus is downloaded from Karpathy's char-rnn repository
    and cached as ``<data_dir>/tinyshakespeare.txt``.

    Args:
        data_dir: Directory holding (or receiving) the cached text file.

    Returns:
        ``(train, val)``: 1-D ``torch.long`` tensors of the file's UTF-8 byte
        values. The first 90% of the bytes form ``train`` and the rest form ``val``.
    """
    path = os.path.join(data_dir, "tinyshakespeare.txt")
    if not os.path.exists(path):
        print("Downloading tinyshakespeare...")
        os.makedirs(data_dir or ".", exist_ok=True)
        # Download under a temporary name and rename only on success. An
        # interrupted download then never leaves a truncated file behind that
        # later runs would silently treat as the complete corpus.
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(
                "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
                tmp_path,
            )
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    raw = bytearray(text.encode("utf-8"))
    # frombuffer avoids materialising a Python list with one int per byte.
    # It rejects empty buffers, so an empty file needs its own branch.
    if raw:
        byte_data = torch.frombuffer(raw, dtype=torch.uint8).long()
    else:
        byte_data = torch.empty(0, dtype=torch.long)
    n = int(0.9 * len(byte_data))
    return byte_data[:n], byte_data[n:]
def get_batch(data, batch_size, ctx, device):
    """Sample ``batch_size`` random contiguous windows of length ``ctx`` from ``data``.

    Args:
        data: 1-D tensor of token ids (assumed 1-D, as produced by ``download_data``).
        batch_size: Number of windows to draw.
        ctx: Window length.
        device: Destination device for the returned tensors.

    Returns:
        ``(x, targets)``. ``x`` has shape ``(batch_size, ctx)``. ``targets`` is
        ``x`` without its first three positions, shape ``(batch_size, ctx - 3)``.
        Presumably the model predicts position ``t`` from the three tokens before
        it — confirm against ``MORPHTernaryModel``.
    """
    # Start offsets are drawn from [0, len(data) - ctx - 1), the same range as before.
    starts = torch.randint(0, len(data) - ctx - 1, (batch_size,))
    # Gather all windows with one advanced-indexing op instead of a Python loop
    # of slices plus torch.stack. The values are identical, and an empty batch
    # no longer crashes inside torch.stack.
    offsets = starts.unsqueeze(1) + torch.arange(ctx)
    x = data[offsets]
    targets = x[:, 3:]
    return x.to(device, non_blocking=True), targets.to(device, non_blocking=True)
@torch.no_grad()
def evaluate(model, val_data, n_steps=None):
    """Return the mean loss of ``model`` over random batches drawn from ``val_data``.

    Each batch comes from ``get_batch`` with the module-level ``BATCH_SIZE``,
    ``CTX`` and ``DEVICE``. The model is called as ``model(x, targets=targets)``,
    and its second return value is taken as a scalar loss.

    Args:
        model: Model to score. It is switched to eval mode and left there.
        val_data: 1-D tensor of token ids to sample windows from.
        n_steps: Number of batches to average. ``None`` (the default) uses
            ``EVAL_STEPS``, matching the previous fixed behaviour.

    Returns:
        Mean of the per-batch losses as a Python float.

    Raises:
        ValueError: if ``n_steps`` is not positive, because the mean would be undefined.
    """
    if n_steps is None:
        n_steps = EVAL_STEPS
    if n_steps <= 0:
        raise ValueError(f"n_steps must be positive, got {n_steps}")
    model.eval()
    # bf16 autocast is enabled only on CUDA, which is what the old hard-coded
    # autocast("cuda") amounted to. Naming the real device avoids the
    # "CUDA is not available" warning that autocast("cuda") emits on CPU.
    use_amp = DEVICE == "cuda"
    losses = []
    for _ in range(n_steps):
        x, targets = get_batch(val_data, batch_size=BATCH_SIZE, ctx=CTX, device=DEVICE)
        with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16, enabled=use_amp):
            _, loss = model(x, targets=targets)
        losses.append(loss.item())
    return sum(losses) / len(losses)
@torch.no_grad()
def evaluate_train(model, train_data, n_steps=200):
    """Return the mean loss of ``model`` over ``n_steps`` random batches of ``train_data``.

    This is the training-split counterpart of ``evaluate``, using fewer batches
    by default. Together with the validation loss it gives the train/val gap.
    Batches come from ``get_batch`` with the module-level ``BATCH_SIZE``,
    ``CTX`` and ``DEVICE``. The second value returned by
    ``model(x, targets=targets)`` is taken as a scalar loss.

    Args:
        model: Model to score. It is switched to eval mode and left there.
        train_data: 1-D tensor of token ids to sample windows from.
        n_steps: Number of batches to average.

    Returns:
        Mean of the per-batch losses as a Python float.

    Raises:
        ValueError: if ``n_steps`` is not positive, because the mean would be undefined.
    """
    if n_steps <= 0:
        raise ValueError(f"n_steps must be positive, got {n_steps}")
    model.eval()
    # bf16 autocast only on CUDA, the effective behaviour of the old
    # hard-coded autocast("cuda"), without its warning on CPU-only machines.
    use_amp = DEVICE == "cuda"
    losses = []
    for _ in range(n_steps):
        x, targets = get_batch(train_data, batch_size=BATCH_SIZE, ctx=CTX, device=DEVICE)
        with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16, enabled=use_amp):
            _, loss = model(x, targets=targets)
        losses.append(loss.item())
    return sum(losses) / len(losses)
@torch.no_grad()
def ternary_distribution(model):
    """Summarise the sign split of every quantised weight matrix in ``model``.

    A parameter is inspected when its name contains ``"weight"``, it has at
    least two dimensions, and its name does not contain ``"embed"``. It is
    passed through ``StickyZoneSTE.apply(param, THRESHOLD)``, and the fractions
    of positive, negative and zero entries in the result are recorded. So are
    the mean and standard deviation of the raw parameter's magnitudes.

    Returns:
        dict mapping parameter name to a dict with keys
        ``"pos"``, ``"neg"``, ``"zero"``, ``"s_mean"``, ``"s_std"`` (Python floats).
    """
    report = {}
    for pname, weight in model.named_parameters():
        is_quantised = "weight" in pname and weight.ndim >= 2 and "embed" not in pname
        if not is_quantised:
            continue
        quantised = StickyZoneSTE.apply(weight, THRESHOLD)
        magnitude = weight.abs()
        report[pname] = dict(
            pos=(quantised > 0).float().mean().item(),
            neg=(quantised < 0).float().mean().item(),
            zero=(quantised == 0).float().mean().item(),
            s_mean=magnitude.mean().item(),
            s_std=magnitude.std().item(),
        )
    return report
@torch.no_grad()
def generate_sample(model, seed_bytes, max_new_tokens=200, temperature=0.8, top_k=40):
    """Autoregressively extend ``seed_bytes`` by sampling from ``model``.

    At each step the model sees at most the last ``CTX`` tokens. Its logits at
    the final position are used to pick the next token.

    Args:
        model: Model called as ``model(idx)``. Its first return value is
            indexed as ``logits[:, -1, :]``, so it presumably has shape
            ``(batch, seq, vocab)`` — confirm against the model.
        seed_bytes: Prompt as a list of token ids.
        max_new_tokens: Number of tokens to append.
        temperature: Softmax temperature. A value ``<= 0`` selects greedy
            (argmax) decoding instead of dividing by zero.
        top_k: Keep only the ``top_k`` most likely tokens before sampling. It
            is clamped to the vocabulary size; ``None`` or ``<= 0`` disables it.

    Returns:
        The prompt followed by the generated ids, as a list of Python ints.
    """
    model.eval()
    # bf16 autocast only on CUDA. This is the effective old behaviour, minus
    # the warning autocast("cuda") emits when CUDA is unavailable.
    use_amp = DEVICE == "cuda"
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=DEVICE)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -CTX:]
        with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16, enabled=use_amp):
            logits, _ = model(idx_cond)
        # Do top-k/softmax in fp32, because bf16 logits are too coarse for sampling.
        last_logits = logits[:, -1, :].float()
        if temperature <= 0:
            idx_next = last_logits.argmax(dim=-1, keepdim=True)
        else:
            # The division makes a new tensor, so the masked assignment below
            # never writes into the model's logits.
            last_logits = last_logits / temperature
            if top_k is not None and top_k > 0:
                k = min(top_k, last_logits.size(-1))
                v, _ = torch.topk(last_logits, k)
                last_logits[last_logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx[0].cpu().tolist()
def bytes_to_text(byte_list):
    r"""Render a sequence of token ids as a human-readable string.

    - Printable ASCII (32..126) is emitted as the character itself.
    - Newline (10) and tab (9) are kept; carriage return (13) is dropped.
    - Ids of 256 or more (outside the byte range) are shown as ``<id>``.
    - Every other value is shown as a two-digit ``\xNN`` hex escape.
    """
    def render(code):
        if 32 <= code < 127:
            return chr(code)
        if code == 10:
            return "\n"
        if code == 13:
            return ""
        if code == 9:
            return "\t"
        return f"<{code}>" if code >= 256 else f"\\x{code:02x}"

    return "".join(render(code) for code in byte_list)
@torch.no_grad()
def measure_inference_speed(model, n_steps=100, n_warmup=10):
    """Return forward-pass throughput of ``model`` in sequences per second.

    A single random ``(1, CTX)`` batch of ids in ``[0, VOCAB)`` is fed through
    the model ``n_warmup`` times untimed, to keep one-off start-up costs out of
    the measurement. It is then fed ``n_steps`` times while timed. On CUDA the
    device is synchronised before each clock read so queued kernels are counted.

    Args:
        model: Model to time. It is switched to eval mode and left there.
        n_steps: Number of timed forward passes.
        n_warmup: Number of untimed warm-up passes (previously fixed at 10).

    Returns:
        ``n_steps`` divided by the elapsed wall-clock seconds.
    """
    model.eval()
    on_cuda = DEVICE == "cuda"
    x = torch.randint(0, VOCAB, (1, CTX), device=DEVICE)
    # bf16 autocast only on CUDA, the effective old behaviour, without the
    # warning autocast("cuda") emits on CPU-only machines.
    with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16, enabled=on_cuda):
        for _ in range(n_warmup):
            model(x)
        if on_cuda:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(n_steps):
            model(x)
        if on_cuda:
            torch.cuda.synchronize()
        t1 = time.perf_counter()
    return n_steps / (t1 - t0)
def perplexity(loss):
    """Convert a mean per-token loss to perplexity, ``exp(loss)``.

    This assumes ``loss`` is a natural-log cross-entropy, as the ``math.exp``
    implies. A loss too large to exponentiate (for example from a diverged
    checkpoint) yields ``inf`` instead of raising ``OverflowError``, so one bad
    checkpoint cannot abort the whole evaluation.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")
def _load_model(path):
    """Build a model on DEVICE and load ``path`` into it. Return None if ``path`` is missing.

    ``path=None`` means "evaluate the untrained model". Status lines are printed
    exactly as before.
    """
    if path is not None and not os.path.exists(path):
        print(f"MISSING: {path} — skipping")
        return None
    model = MORPHTernaryModel().to(DEVICE)
    if path is None:
        print("Init model (random weights, no training)")
        return model
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only point
    # this at checkpoints from a trusted source, such as our own training runs.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
    # Accept both {"model_state_dict": ...} bundles and bare state dicts.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    # load_state_dict copies into the model's own tensors. Drop the checkpoint
    # and anything else it bundles, already mapped onto DEVICE, so it is not
    # held in memory while the model is evaluated.
    del ckpt, state_dict
    print(f"Loaded: {path}")
    return model


def _count_params(model):
    """Return ``(total, ternary, fp32, effective_bits_per_weight)`` for ``model``.

    "Ternary" uses the same filter as ``ternary_distribution``: names containing
    ``"weight"``, at least 2-D, and names not containing ``"embed"``.
    """
    total = sum(p.numel() for p in model.parameters())
    ternary = sum(
        p.numel() for n, p in model.named_parameters()
        if "weight" in n and p.ndim >= 2 and "embed" not in n
    )
    fp32 = total - ternary
    # Ternary weights count as log2(3) ≈ 1.58 bits; all others as 32 bits.
    bpw = (fp32 * 32 + ternary * 1.58) / total
    return total, ternary, fp32, bpw


def _evaluate_checkpoint(label, model, train_data, val_data, seed_bytes):
    """Compute every metric for one loaded model and return its result record."""
    # Reseed before scoring so every checkpoint sees the same random batches and
    # sampling noise. This makes the comparison fair and reproducible between runs.
    torch.manual_seed(0)
    val_loss = evaluate(model, val_data)
    train_loss = evaluate_train(model, train_data)
    speed = measure_inference_speed(model)
    stats = ternary_distribution(model)
    sample_tokens = generate_sample(model, seed_bytes, max_new_tokens=150, temperature=0.8)
    return {
        "label": label,
        "val_loss": val_loss,
        "val_ppl": perplexity(val_loss),
        "train_loss": train_loss,
        "train_ppl": perplexity(train_loss),
        "gap": train_loss - val_loss,
        "speed": speed,
        "stats": stats,
        "sample": bytes_to_text(sample_tokens),
    }


def _print_checkpoint_report(r):
    """Print the metrics, ternary-distribution and sample sections for one record."""
    print("\n--- Metrics ---")
    print(f" Val loss: {r['val_loss']:.4f} (ppl={r['val_ppl']:.2f})")
    print(f" Train loss: {r['train_loss']:.4f} (ppl={r['train_ppl']:.2f})")
    print(f" Train-Val gap: {r['gap']:+.4f}")
    print(f" Inference: {r['speed']:.1f} seq/s")
    print("\n--- Ternary Distribution ---")
    for name, s in r["stats"].items():
        short = name.replace(".weight", "")
        print(f" {short:40s} +{s['pos']:.3f} -{s['neg']:.3f} 0={s['zero']:.3f} S={s['s_mean']:.4f}±{s['s_std']:.4f}")
    print("\n--- Sample (temp=0.8, top_k=40) ---")
    sample_text = r["sample"]
    lines = sample_text.split("\n")
    for line in lines[:8]:
        print(f" {line}")
    if len(lines) > 8:
        print(f" ... ({len(sample_text)} chars total)")


def _print_comparison(results):
    """Print the side-by-side table of all evaluated checkpoints and the best one."""
    print(f"\n\n{'=' * 80}")
    print("COMPARISON TABLE")
    print(f"{'=' * 80}")
    print(f"{'Checkpoint':<20s} {'Val Loss':>10s} {'Val PPL':>10s} {'Train Loss':>11s} {'Gap':>8s} {'Speed':>10s}")
    print(f"{'-'*20} {'-'*10} {'-'*10} {'-'*11} {'-'*8} {'-'*10}")
    for r in results:
        print(f"{r['label']:<20s} {r['val_loss']:>10.4f} {r['val_ppl']:>10.2f} {r['train_loss']:>11.4f} {r['gap']:>+8.4f} {r['speed']:>9.1f}/s")
    if not results:
        return
    best = min(results, key=lambda r: r["val_loss"])
    print(f"\nBest checkpoint: {best['label']} (val_loss={best['val_loss']:.4f}, ppl={best['val_ppl']:.2f})")


def main():
    """Evaluate the untrained model and each saved checkpoint, then compare them.

    For every entry in the checkpoint list, this prints parameter counts, val
    and train loss and perplexity, inference throughput, the per-matrix ternary
    distribution, and a short generated sample. Checkpoints missing on disk are
    reported and skipped. A comparison table follows at the end.
    """
    print(f"Device: {DEVICE}")
    print(f"Eval: {EVAL_STEPS} batches x {BATCH_SIZE} samples, ctx={CTX}")
    print("=" * 80)
    data_dir = os.path.dirname(__file__) or "."
    train_data, val_data = download_data(data_dir)
    print(f"Data: train={len(train_data):,} bytes | val={len(val_data):,} bytes\n")
    seed_text = "ROMEO:\nWhat light through yonder window breaks?\n"
    seed_bytes = list(seed_text.encode("utf-8"))
    checkpoints = [
        ("init (random)", None),
        ("step5000", os.path.join(CKPT_DIR, "trigram-morph-step5000.pt")),
        ("best (step7K)", os.path.join(CKPT_DIR, "trigram-morph-best.pt")),
        ("step13000", os.path.join(CKPT_DIR, "trigram-morph-step13000.pt")),
        ("step25000", os.path.join(CKPT_DIR, "trigram-morph-step25000.pt")),
    ]
    results = []
    for label, path in checkpoints:
        print(f"\n{'=' * 80}")
        print(f"CHECKPOINT: {label}")
        print(f"{'=' * 80}")
        model = _load_model(path)
        if model is None:
            continue
        total, ternary, fp32, bpw = _count_params(model)
        print(f"Params: {total:,} | ternary: {ternary:,} | fp32: {fp32:,} | BPW: {bpw:.2f}")
        record = _evaluate_checkpoint(label, model, train_data, val_data, seed_bytes)
        results.append(record)
        _print_checkpoint_report(record)
        # Release this model's memory before the next checkpoint is loaded.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    _print_comparison(results)


if __name__ == "__main__":
    main()