# ARBS/testing/model/test-stp.py
# Author: CLIWorks — "Upload folder using huggingface_hub" (commit d8bc908, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import urllib.request
import time
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.main import (
VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD,
SPECIAL_VOCAB, ARBModel, StickyZoneSTE,
save_model, load_model,
)
# Hyperparameters shared by the FP32 baseline and the ternary run.
TRAIN_PARAMS = dict(
    batch_size=32,      # windows sampled per optimizer step
    ctx=66,             # bytes per window; get_batch yields ctx - 3 targets per window
    lr=3e-4,            # AdamW learning rate
    weight_decay=0.01,  # AdamW weight decay
    max_steps=5000,     # training steps per configuration
    eval_interval=500,  # evaluate/log every N training steps
    eval_steps=100,     # validation batches averaged per evaluation
)
def download_data():
    """Load Tiny Shakespeare as byte ids and split it 90/10 into train/val.

    The corpus is cached as ``tinyshakespeare.txt`` next to this script and is
    downloaded from GitHub the first time it is missing.

    Returns:
        (train, val): 1-D ``torch.long`` tensors of UTF-8 byte values (0-255).
        ``train`` holds the first 90% of the bytes and ``val`` the rest.
    """
    path = os.path.join(os.path.dirname(__file__), "tinyshakespeare.txt")
    if not os.path.exists(path):
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        # Download to a side file and rename only on success. Otherwise an
        # interrupted download would leave a truncated cache file that later
        # runs would silently reuse (the exists() check above would pass).
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    # Tokens are raw UTF-8 bytes, so multi-byte characters span several ids.
    byte_data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    n = int(0.9 * len(byte_data))
    return byte_data[:n], byte_data[n:]
def get_batch(data, batch_size, ctx, device):
    """Sample ``batch_size`` random contiguous windows of length ``ctx`` from ``data``.

    Returns ``(x, targets)``, both moved to ``device``. ``targets`` is ``x``
    without its first three positions: the trigram model scores byte ``t + 3``
    from bytes ``t .. t + 2``.
    """
    starts = torch.randint(0, len(data) - ctx - 1, (batch_size,))
    # (B, 1) start offsets broadcast against (ctx,) positions give a (B, ctx) index grid.
    window_index = starts.unsqueeze(1) + torch.arange(ctx)
    windows = data[window_index]
    return windows.to(device), windows[:, 3:].to(device)
class FP32Linear(nn.Module):
    """Plain full-precision affine layer ``y = x @ W.T + b``.

    ``W`` has shape ``(out_dim, in_dim)`` and is drawn from N(0, 0.1^2); ``b``
    starts at zero.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        init_weight = 0.1 * torch.randn(out_dim, in_dim)
        # Registration order (weight, then bias) is kept for state_dict compatibility.
        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(torch.zeros(out_dim))
    def forward(self, x):
        w, b = self.weight, self.bias
        return F.linear(x, w, b)
class FP32RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learnable per-feature gain."""
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        # Gain starts at 1, so the layer initially only rescales by 1 / RMS.
        self.scale = nn.Parameter(torch.ones(dim))
        # Added under the square root so all-zero rows do not divide by zero.
        self.eps = eps
    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        normalized = x / rms
        return self.scale * normalized
class FP32TrigramModel(nn.Module):
    """Full-precision byte trigram baseline used as the FP32 reference.

    Each position's embedding is concatenated with the next two positions'
    embeddings, projected to ``HIDDEN_DIM``, passed through a two-layer ReLU
    FFN (RMS-normalized, no residual), and mapped to ``VOCAB`` logits.
    """
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB, EMBEDDING_DIM)
        self.embed_norm = FP32RMSNorm(EMBEDDING_DIM)
        self.projection = FP32Linear(EMBEDDING_DIM * 3, HIDDEN_DIM)
        self.proj_norm = FP32RMSNorm(HIDDEN_DIM)
        self.ffn_norm1 = FP32RMSNorm(HIDDEN_DIM)
        self.fc1 = FP32Linear(HIDDEN_DIM, FFN_HIDDEN)
        self.ffn_norm2 = FP32RMSNorm(FFN_HIDDEN)
        self.fc2 = FP32Linear(FFN_HIDDEN, HIDDEN_DIM)
        self.head_norm = FP32RMSNorm(HIDDEN_DIM)
        self.head = FP32Linear(HIDDEN_DIM, VOCAB)
    def forward(self, x, targets=None):
        """Compute next-byte logits and, optionally, the cross-entropy loss.

        Args:
            x: (B, T) integer byte ids.
            targets: optional (B, T-3) ids aligned so that trigram position
                ``t`` (bytes ``t..t+2``) is scored against ``targets[:, t]``,
                which is the layout ``get_batch`` produces. Ids equal to
                ``SPECIAL_VOCAB["PAD"]`` are ignored.

        Returns:
            (logits, loss): logits have shape (B, T-2, VOCAB); loss is ``None``
            when ``targets`` is not given.
        """
        embedded = self.embed_norm(self.embedding(x))
        # unfold -> (B, T-2, D, 3). flatten(2) merges (D, 3) with the window index
        # varying fastest, which is exactly einops 'b t d w -> b t (d w)', so no
        # third-party import is needed on every forward pass.
        trigrams = embedded.unfold(dimension=1, size=3, step=1).flatten(start_dim=2)
        relational = self.proj_norm(self.projection(trigrams))
        h = self.ffn_norm1(relational)
        h = torch.relu(self.fc1(h))
        h = self.ffn_norm2(h)
        h = self.fc2(h)
        logits = self.head(self.head_norm(h))
        loss = None
        if targets is not None:
            # The last trigram's next byte lies outside the window, so it has no target.
            next_byte_logits = logits[:, :-1, :]
            loss = F.cross_entropy(
                next_byte_logits.reshape(-1, VOCAB),
                targets.reshape(-1),
                ignore_index=SPECIAL_VOCAB["PAD"],
            )
        return logits, loss
def evaluate(model, val_data, device, eval_steps=100):
    """Return the mean loss of ``model`` over ``eval_steps`` random validation batches.

    Batches come from ``get_batch`` using the shared ``TRAIN_PARAMS`` batch
    size and context length. The model runs in eval mode under ``no_grad``.
    Its previous train/eval mode is restored afterwards, even if evaluation
    raises.

    Raises:
        ValueError: if ``eval_steps`` is not positive (the mean would be undefined).
    """
    if eval_steps <= 0:
        raise ValueError(f"eval_steps must be positive, got {eval_steps}")
    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for _ in range(eval_steps):
                x, targets = get_batch(val_data, TRAIN_PARAMS["batch_size"], TRAIN_PARAMS["ctx"], device)
                _, loss = model(x, targets=targets)
                losses.append(loss.item())
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        model.train(was_training)
    return sum(losses) / len(losses)
def count_params(model):
    """Split ``model``'s parameter count into ternary-candidate and FP32 buckets.

    A parameter is counted as ternary when its name contains "weight", it is
    at least 2-D, and its name does not contain "embed". Everything else
    (biases, norm gains, embedding tables, other 1-D tensors) counts as FP32.

    Returns:
        (total, ternary, fp32) element counts.
    """
    def _is_ternary(pname, tensor):
        return "weight" in pname and tensor.ndim >= 2 and "embed" not in pname

    total = sum(p.numel() for p in model.parameters())
    buckets = {True: 0, False: 0}
    for pname, tensor in model.named_parameters():
        buckets[_is_ternary(pname, tensor)] += tensor.numel()
    return total, buckets[True], buckets[False]
def log_diagnostics(model, step, train_loss, val_loss, config_name, is_ternary=False):
    """Print a one-line loss summary and, for ternary runs, per-matrix statistics.

    For every parameter that is at least 2-D, has "weight" in its name, and
    does not have "embed" in its name, the weight is quantized with
    ``StickyZoneSTE`` at ``THRESHOLD``. The fractions of positive, negative and
    zero entries are printed together with the gradient norm (0.0 when no
    gradient exists). A collapse warning is printed when more than 95% of a
    matrix quantizes to zero.
    """
    print(f"[{config_name}] step {step} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")
    if not is_ternary:
        return
    for pname, weight in model.named_parameters():
        quantized_layer = "weight" in pname and weight.ndim >= 2 and "embed" not in pname
        if not quantized_layer:
            continue
        with torch.no_grad():
            q = StickyZoneSTE.apply(weight, THRESHOLD)
            zero_share = (q == 0).float().mean().item()
            pos_share = (q > 0).float().mean().item()
            neg_share = (q < 0).float().mean().item()
        grad_norm = 0.0 if weight.grad is None else weight.grad.norm().item()
        print(f" {pname}: +{pos_share:.2%} -{neg_share:.2%} 0{zero_share:.2%} | grad={grad_norm:.6f}")
        if zero_share > 0.95:
            print(f" ⚠ COLLAPSE: {pname} is >95% zeros")
def train_model(model, train_data, val_data, config_name, device, steps=5000):
    """Train ``model`` with AdamW for ``steps`` steps and return a metrics dict.

    Every ``TRAIN_PARAMS["eval_interval"]`` steps the model is evaluated on
    ``val_data`` and diagnostics are logged. Ternary diagnostics are enabled
    when ``config_name`` contains "Ternary". Per-step wall time covers batch
    sampling, forward, backward and the optimizer step. On CUDA devices the
    step is synchronized before the timer stops, so queued kernels are counted.

    Returns:
        dict with ``config``, ``final_train_loss``, ``final_val_loss``,
        ``train_losses`` / ``val_losses`` / ``steps`` (one entry per eval
        point), ``avg_step_ms``, ``steps_per_sec``, ``total_s`` and
        ``param_count``. Timing fields are NaN when ``steps`` is 0.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=TRAIN_PARAMS["lr"], weight_decay=TRAIN_PARAMS["weight_decay"]
    )
    is_ternary = "Ternary" in config_name
    # Accept "cuda", "cuda:0" or a torch.device alike when deciding whether to sync.
    sync_cuda = torch.device(device).type == "cuda"
    train_losses = []
    val_losses = []
    step_list = []
    step_times = []
    loss = None
    for step in range(steps):
        t0 = time.perf_counter()
        x, targets = get_batch(train_data, TRAIN_PARAMS["batch_size"], TRAIN_PARAMS["ctx"], device)
        _, loss = model(x, targets=targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if sync_cuda:
            torch.cuda.synchronize()
        step_times.append(time.perf_counter() - t0)
        if (step + 1) % TRAIN_PARAMS["eval_interval"] == 0:
            train_loss = loss.item()
            val_loss = evaluate(model, val_data, device, TRAIN_PARAMS["eval_steps"])
            log_diagnostics(model, step + 1, train_loss, val_loss, config_name, is_ternary)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            step_list.append(step + 1)
    final_val = evaluate(model, val_data, device, TRAIN_PARAMS["eval_steps"])
    total_s = sum(step_times)
    # Guard the steps == 0 case instead of dividing by zero.
    avg_step_ms = total_s / len(step_times) * 1000 if step_times else float("nan")
    steps_per_sec = steps / total_s if total_s > 0 else float("nan")
    print(f"[{config_name}] speed: {avg_step_ms:.2f} ms/step | {steps_per_sec:.1f} steps/s | total {total_s:.1f}s")
    if train_losses:
        final_train_loss = train_losses[-1]
    elif loss is not None:
        # No eval point was reached: fall back to the last step's loss.
        final_train_loss = loss.item()
    else:
        final_train_loss = float("nan")
    return {
        "config": config_name,
        "final_train_loss": final_train_loss,
        "final_val_loss": final_val,
        "train_losses": train_losses,
        # Per-eval-point history (the original returned only the last scalar,
        # and raised NameError when no eval point was reached).
        "val_losses": val_losses,
        "steps": step_list,
        "avg_step_ms": avg_step_ms,
        "steps_per_sec": steps_per_sec,
        "total_s": total_s,
        "param_count": sum(p.numel() for p in model.parameters()),
    }
def analyze_results(results_fp32, results_ternary):
    """Print a side-by-side comparison of the FP32 and ternary training runs.

    The report covers:
      * final validation loss, parameter count and speed for each run;
      * a PASS/FAIL verdict on the ternary/FP32 loss ratio (threshold 1.25x);
      * the ternary/FP32 step-time ratio;
      * an estimated effective bits-per-weight for ``ARBModel``, assuming 1.58
        bits for parameters ``count_params`` classifies as ternary and 32 bits
        for the rest.
    """
    print("\n" + "=" * 80)
    print("MORPH TERNARY vs FP32 — REAL MODEL BENCHMARK")
    print("=" * 80)
    print(f"{'Config':<30} {'Val Loss':>9} {'Params':>10} {'ms/step':>9} {'stps/s':>8}")
    print("-" * 70)
    for r in (results_fp32, results_ternary):
        print(
            f"{r['config']:<30} {r['final_val_loss']:>9.4f} "
            f"{r['param_count']:>10} {r['avg_step_ms']:>9.2f} {r['steps_per_sec']:>8.1f}"
        )
    fp32_loss = results_fp32["final_val_loss"]
    ternary_loss = results_ternary["final_val_loss"]
    fp32_speed = results_fp32["avg_step_ms"]
    ternary_speed = results_ternary["avg_step_ms"]
    ratio = ternary_loss / fp32_loss
    speed_ratio = ternary_speed / fp32_speed
    print("\n--- Precision ---")
    print(f"Ternary / FP32 loss ratio = {ratio:.3f}x")
    if ratio <= 1.25:
        print("✅ Ternary within 1.25x FP32 — PASS")
    else:
        print("❌ Ternary exceeds 1.25x FP32 — FAIL")
    print("\n--- Speed ---")
    print(f"Ternary vs FP32: {speed_ratio:.2f}x ({ternary_speed:.2f}ms vs {fp32_speed:.2f}ms)")
    if speed_ratio <= 1.0:
        print("✅ Ternary is faster or equal to FP32")
    elif speed_ratio <= 1.5:
        print("⚑ Ternary is slower but within 1.5x — expected for unoptimized path")
    else:
        print("⚠ Ternary significantly slower — investigate")
    # A fresh ARBModel is built only to classify its parameters; its weights are unused.
    total, ternary_p, fp32_p = count_params(ARBModel())
    # 1.58 ~= log2(3): information content of one ternary value {-1, 0, +1}.
    eff_bpw = (fp32_p * 32 + ternary_p * 1.58) / total
    print("\n--- Effective BPW ---")
    print(f"Ternary params: {ternary_p:,} (1.58 BPW)")
    print(f"FP32 params: {fp32_p:,} (32 BPW)")
    print(f"Effective BPW: {eff_bpw:.2f} bits/weight (avg across all params)")
    print(f"Memory savings: {32.0/eff_bpw:.1f}x vs pure FP32")
def run_all():
    """Train the FP32 baseline and the ternary ARBModel on Tiny Shakespeare, then compare.

    The trained ternary model is saved to ``models/conversions/arb-model.pt``,
    two directories above this script's folder. It is then reloaded and given
    a smoke-test forward pass before the comparison report is printed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    print(f"Device: {device}")
    print("Downloading data...")
    train_data, val_data = download_data()
    print(f"Train: {len(train_data):,} bytes, Val: {len(val_data):,} bytes")

    print("\n--- FP32 Baseline ---")
    model_fp32 = FP32TrigramModel().to(device)
    total_fp = sum(p.numel() for p in model_fp32.parameters())
    print(f"Params: {total_fp:,}")
    results_fp32 = train_model(
        model_fp32, train_data, val_data, "FP32-Trigram", device, TRAIN_PARAMS["max_steps"]
    )
    print(f"FP32 final val loss: {results_fp32['final_val_loss']:.4f}")
    del model_fp32
    if use_cuda:
        torch.cuda.empty_cache()

    print("\n--- Ternary (Config E: Factorized Scaled) ---")
    model_ternary = ARBModel().to(device)
    total_t = sum(p.numel() for p in model_ternary.parameters())
    print(f"Params: {total_t:,}")
    results_ternary = train_model(
        model_ternary, train_data, val_data, "Ternary-ConfigE", device, TRAIN_PARAMS["max_steps"]
    )
    print(f"Ternary final val loss: {results_ternary['final_val_loss']:.4f}")

    save_path = os.path.join(os.path.dirname(__file__), "..", "..", "models", "conversions", "arb-model.pt")
    # save_model's directory handling is not visible here; create the target
    # directory defensively so a fresh checkout does not fail after training.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_model(model_ternary, save_path)
    test_model = load_model(save_path, device)
    # Smoke-test the reloaded model on windows the same length as training batches.
    test_x = torch.randint(0, VOCAB, (2, TRAIN_PARAMS["ctx"]), device=device)
    with torch.no_grad():
        logits, _ = test_model(test_x)
    print(f"\nLoaded model test: logits shape = {logits.shape} ✅")
    del model_ternary, test_model
    if use_cuda:
        torch.cuda.empty_cache()
    analyze_results(results_fp32, results_ternary)
if __name__ == "__main__":
    # Run the full FP32-vs-ternary benchmark only when executed as a script.
    run_all()