| """
|
| LUNA 100M β Validate Pretrained + Quantization Benchmark
|
| =========================================================
|
| 1. Load pretrained base model (latest.pt β auto-downloads from HF)
|
| 2. Run eval prompts with the base (F32) model
|
| 3. Simulate quantisation at each level (F16, Q8_0, Q4_K_M) IN PYTORCH
|
| 4. Run the SAME eval prompts with each quantised copy
|
| 5. Compute precision metrics (cosine-sim of logits, perplexity delta)
|
| 6. Export all GGUF files
|
| 7. Print comparison report + pick the best quantisation
|
|
|
| Usage:
|
| python validate_and_quantize.py
|
| python validate_and_quantize.py --ckpt Base/out/pretrain/luna_100m/latest.pt
|
| python validate_and_quantize.py --skip_gguf # skip GGUF export
|
| """
|
|
|
| import os, sys, copy, math, json, argparse, struct, time
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from pathlib import Path
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables.

    All tables for positions 0..max_seq_len-1 are built once at construction,
    so the forward pass is just a slice.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE inverse frequencies: 1 / 10000^(2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        # Outer product: angle[t, i] = t * inv_freq[i].
        angles = torch.outer(positions, inv_freq)
        # Duplicate along the last dim so tables cover the full rotary width.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return (cos, sin) tables for the first ``seq_len`` positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
def rotate_half(x):
    """Rotate the last dimension by half: (a, b) -> (-b, a)."""
    a, b = torch.chunk(x, 2, dim=-1)
    return torch.cat((b.neg(), a), dim=-1)


def apply_rotary(x, cos, sin):
    """Apply rotary position embedding to ``x``.

    ``cos``/``sin`` are (T, D) tables; they are broadcast over the leading
    batch and head dimensions of ``x`` (expected (B, H, T, D)).
    """
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    return x * cos_b + rotate_half(x) * sin_b
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embedding.

    Only the leading ``rotary_pct`` fraction of each head's channels gets
    RoPE applied; the remaining channels carry no positional signal
    (GPT-NeoX-style partial rotary — TODO confirm intent matches training).
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head  # assumes n_embd divisible by n_head
        self.rot_dim = int(self.head_dim * rotary_pct)  # channels that get RoPE
        # Fused QKV projection and output projection (GPT-2 style names).
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        # (B, T, 3C) -> (B, T, 3, H, Dh) -> (3, B, H, T, Dh) so q/k/v split on dim 0.
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        cos, sin = self.rotary(T)
        # Rotate only the first rot_dim channels of each head; pass the rest through.
        q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1)
        # Fused attention kernel with causal masking; v is not rotated.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, H, T, Dh) -> (B, T, C), then project back to the residual stream.
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C))
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward block: linear -> GELU -> linear (4x expansion)."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        # NOTE: submodule creation order is kept (fc, gelu, proj) so that
        # parameter initialisation consumes the RNG in the same sequence.
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        h = self.fc(x)
        h = self.gelu(h)
        return self.proj(h)
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block.

    LayerNorm feeds each sub-layer (attention, then MLP) and each sub-layer's
    output is added back to the residual stream.
    """

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        # Residual around attention, then residual around the MLP.
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
|
|
class LUNAModel(nn.Module):
    """Decoder-only transformer language model with tied embeddings."""

    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, block_size) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map (B, T) token ids to (B, T, vocab) logits."""
        h = self.wte(idx)
        for layer in self.blocks:
            h = layer(h)
        return self.lm_head(self.ln_f(h))

    @property
    def num_params(self):
        """Parameter count excluding the (tied) embedding matrix."""
        total = sum(p.numel() for p in self.parameters())
        return total - self.wte.weight.numel()
|
|
|
|
|
|
|
|
|
# Elements per quantisation group for the Q8_0 / Q4_K_M simulations below
# (32 matches the group size used by llama.cpp's block formats).
BLOCK_SIZE = 32
|
|
|
| def _sim_q8_0(tensor: torch.Tensor) -> torch.Tensor:
|
| """Simulate Q8_0: blockwise int8 quantise β dequantise."""
|
| orig_shape = tensor.shape
|
| flat = tensor.flatten().float()
|
| pad = (-len(flat)) % BLOCK_SIZE
|
| if pad:
|
| flat = F.pad(flat, (0, pad))
|
| blocks = flat.view(-1, BLOCK_SIZE)
|
| scales = blocks.abs().max(dim=1, keepdim=True).values / 127.0
|
| scales = scales.clamp(min=1e-8)
|
| q = (blocks / scales).round().clamp(-128, 127)
|
| deq = (q * scales).flatten()[:tensor.numel()]
|
| return deq.view(orig_shape).to(tensor.dtype)
|
|
|
| def _sim_q4_k_m(tensor: torch.Tensor) -> torch.Tensor:
|
| """Simulate Q4_K_M: blockwise 4-bit quantise β dequantise."""
|
| orig_shape = tensor.shape
|
| flat = tensor.flatten().float()
|
| pad = (-len(flat)) % BLOCK_SIZE
|
| if pad:
|
| flat = F.pad(flat, (0, pad))
|
| blocks = flat.view(-1, BLOCK_SIZE)
|
| abs_max = blocks.abs().max(dim=1, keepdim=True).values
|
| scales = abs_max / 7.0
|
| scales = scales.clamp(min=1e-8)
|
| q = ((blocks / scales) + 8).round().clamp(0, 15)
|
| deq = ((q - 8) * scales).flatten()[:tensor.numel()]
|
| return deq.view(orig_shape).to(tensor.dtype)
|
|
|
|
|
# Only parameters whose name ends with one of these suffixes are quantised
# (weight matrices; biases stay in full precision).
_QUANT_PARAM_SUFFIXES = (".weight",)
# Name fragments whose parameters are never quantised: LayerNorm weights are
# tiny and precision-sensitive.
_SKIP_QUANT = ("ln1.", "ln2.", "ln_f.")
|
|
|
def apply_simulated_quant(model: "LUNAModel", quant: str):
    """Apply simulated quantisation to model weights (in-place).

    Args:
        model: model whose eligible parameters are degraded in place.
        quant: one of "F32" (no-op), "F16", "Q8_0", "Q4_K_M".

    Returns:
        The same model object, for chaining.

    Raises:
        ValueError: if ``quant`` is not a recognised level (previously an
            unknown level silently behaved like F32).
    """
    if quant not in ("F32", "F16", "Q8_0", "Q4_K_M"):
        raise ValueError(f"Unknown quantisation level: {quant!r}")
    if quant == "F32":
        return model
    for name, p in model.named_parameters():
        # Quantise weight matrices only; biases are untouched.
        if not any(name.endswith(s) for s in _QUANT_PARAM_SUFFIXES):
            continue
        # Keep LayerNorm parameters in full precision.
        if any(skip in name for skip in _SKIP_QUANT):
            continue
        if quant == "F16":
            # Round-trip through half precision.
            p.data = p.data.half().float()
        elif quant == "Q8_0":
            p.data = _sim_q8_0(p.data)
        else:  # Q4_K_M
            p.data = _sim_q4_k_m(p.data)
    return model
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate(model, input_ids, max_new_tokens=100, temperature=0.7, top_k=40):
    """Autoregressively sample up to ``max_new_tokens`` tokens.

    Args:
        model: causal LM with a ``block_size`` attribute; calling it on a
            (B, T) LongTensor must return (B, T, vocab) logits.
        input_ids: (1, T) prompt token ids.
        max_new_tokens: upper bound on generated tokens.
        temperature: softmax temperature (clamped to >= 1e-8).
        top_k: keep only the k most likely tokens; <= 0 disables filtering.

    Returns:
        (1, T + n) tensor of the prompt followed by the sampled tokens.
    """
    # (Removed an unused `device = input_ids.device` local from the original.)
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length.
        idx_cond = input_ids[:, -model.block_size:]
        logits = model(idx_cond)
        logits = logits[:, -1, :] / max(temperature, 1e-8)
        if top_k > 0:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Mask everything below the k-th largest logit.
            logits[logits < v[:, [-1]]] = float("-inf")
        probs = F.softmax(logits, dim=-1)
        nxt = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, nxt], dim=1)
        # Token id 0 is treated as end-of-sequence; .item() assumes batch
        # size 1 — NOTE(review): confirm against the tokenizer's EOS id.
        if nxt.item() == 0:
            break
    return input_ids
|
|
|
@torch.no_grad()
def get_logits(model, input_ids):
    """Return logits over the trailing ``block_size`` window (for precision comparison)."""
    window = input_ids[:, -model.block_size:]
    return model(window)
|
|
|
@torch.no_grad()
def compute_perplexity(model, input_ids):
    """Compute the model's perplexity on a token sequence.

    The sequence is cropped to the model's ``block_size`` window, and the
    labels are now taken from that SAME window: the previous version sliced
    labels from the full sequence, which raised a shape-mismatch error in
    ``cross_entropy`` whenever the input was longer than ``block_size``.

    Args:
        model: causal LM with a ``block_size`` attribute returning
            (B, T, vocab) logits.
        input_ids: (1, T) token ids.

    Returns:
        Perplexity as a float; ``inf`` when fewer than two tokens are given
        (no (context, target) pair exists).
    """
    if input_ids.size(1) < 2:
        return float("inf")
    window = input_ids[:, -model.block_size:]
    logits = model(window)
    # Position t predicts token t+1 within the window.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = window[:, 1:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
    return math.exp(loss.item())
|
|
|
|
|
|
|
|
|
# Prompts used for the qualitative generation comparison across quant levels.
EVAL_PROMPTS = [
    # Identity / self-description probes.
    "Who are you?",
    "Who created you?",
    "What is your name?",

    # Factual recall.
    "The capital of France is",
    "Water boils at a temperature of",
    "The largest planet in our solar system is",
    "Albert Einstein is famous for",

    # Open-ended continuation / fluency.
    "The quick brown fox jumps over the lazy",
    "In a groundbreaking study, researchers found that",
    "The most important thing about education is",
    "Once upon a time, in a land far away,",
    "The future of artificial intelligence will",

    # Simple reasoning / commonsense completion.
    "If it rains tomorrow, I will",
    "She went to the store because she needed to buy",
    "The difference between a cat and a dog is that",
]
|
|
|
|
|
# Reference sentences for the perplexity metric; kept short so each one
# fits comfortably inside the model's context window.
PERPLEXITY_TEXTS = [
    "The quick brown fox jumps over the lazy dog and then runs into the forest.",
    "Artificial intelligence has transformed the way we interact with technology in recent years.",
    "Education is the most powerful weapon which you can use to change the world.",
    "The sun rises in the east and sets in the west, a cycle that has continued for billions of years.",
    "Water is composed of two hydrogen atoms and one oxygen atom, making it essential for all life.",
]
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: validate the pretrained model, benchmark each
    simulated quantisation level, optionally export GGUF files, and print
    a comparison report."""
    # ---- CLI arguments -------------------------------------------------
    parser = argparse.ArgumentParser(description="LUNA 100M β Validate & Quantize Benchmark")
    parser.add_argument("--ckpt", default="Base/out/pretrain/luna_100m/latest.pt",
                        help="Path to latest.pt checkpoint")
    parser.add_argument("--hf_repo", default="ASTERIZER/LUNA-100M",
                        help="HF model repo to download from if ckpt not found")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m",
                        help="Tokenizer directory")
    parser.add_argument("--max_tokens", type=int, default=80,
                        help="Max tokens to generate per prompt")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--skip_gguf", action="store_true",
                        help="Skip GGUF export (just do the PyTorch comparison)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n{'='*70}")
    print(f" LUNA 100M β Validate & Quantize Benchmark")
    print(f" Device: {device}")
    print(f"{'='*70}")

    # ---- Tokenizer (imported lazily so argparse --help works without
    # transformers installed) -------------------------------------------
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"\n Tokenizer: {args.tok_dir} (vocab={tok.vocab_size})")

    # ---- Checkpoint: local file, or download from HuggingFace ----------
    ckpt_path = Path(args.ckpt)
    if not ckpt_path.exists():
        print(f"\n Checkpoint not found locally: {ckpt_path}")
        print(f" Downloading from HuggingFace: {args.hf_repo}")
        from huggingface_hub import hf_hub_download
        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): this always fetches "latest.pt" regardless of the
        # filename in --ckpt — confirm that is intended.
        hf_hub_download(
            repo_id=args.hf_repo,
            filename="latest.pt",
            local_dir=str(ckpt_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        print(f" Downloaded to: {ckpt_path}")

    print(f"\n Loading checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)

    # The checkpoint is either a training dict ({"model": ..., "step": ...})
    # or a bare state_dict.
    if isinstance(ckpt, dict) and "model" in ckpt:
        state = ckpt["model"]
        step = ckpt.get("step", "?")
        tokens_seen = ckpt.get("tokens_seen", 0)
    else:
        state = ckpt
        step = "final"
        tokens_seen = 0
    print(f" Pretrained @ step {step}, tokens seen: {tokens_seen:,}")

    # ---- Build the model; hyperparameters must match the checkpoint ----
    model = LUNAModel(
        vocab_size=50304, block_size=1024,
        n_layer=10, n_embd=768, n_head=12,
    )
    model.load_state_dict(state, strict=True)
    model = model.to(device).eval()
    print(f" Parameters: {model.num_params:,}")
    del ckpt, state  # free the CPU copies before cloning weights below

    # Pristine F32 copy so every quant level starts from identical weights.
    original_sd = {k: v.clone() for k, v in model.state_dict().items()}

    # ---- Run every quant level over the eval prompts -------------------
    quant_levels = ["F32", "F16", "Q8_0", "Q4_K_M"]
    all_results = {}    # quant -> {prompt: generated text}
    all_ppls = {}       # quant -> average perplexity
    logit_cosine = {}   # quant -> avg cosine similarity of logits vs F32
    base_logits = {}    # prompt -> F32 reference logits (kept on CPU)

    for qi, quant in enumerate(quant_levels):
        # Restore pristine weights, then degrade them for this level.
        model.load_state_dict(original_sd, strict=True)
        apply_simulated_quant(model, quant)

        print(f"\n{'='*70}")
        print(f" [{qi+1}/{len(quant_levels)}] {quant}")
        print(f"{'='*70}")

        results = {}
        cosines = []

        for prompt in EVAL_PROMPTS:
            ids = tok.encode(prompt, return_tensors="pt").to(device)
            out_ids = generate(model, ids, max_new_tokens=args.max_tokens,
                               temperature=args.temperature, top_k=args.top_k)
            text = tok.decode(out_ids[0], skip_special_tokens=True)
            results[prompt] = text

            # Compare prompt logits against the F32 reference via cosine
            # similarity of the flattened logit tensors.
            cur_logits = get_logits(model, ids)
            if quant == "F32":
                base_logits[prompt] = cur_logits.cpu()
            else:
                bl = base_logits[prompt].to(device)
                min_len = min(cur_logits.size(1), bl.size(1))
                cos = F.cosine_similarity(
                    cur_logits[:, :min_len, :].flatten().unsqueeze(0),
                    bl[:, :min_len, :].flatten().unsqueeze(0),
                ).item()
                cosines.append(cos)

            print(f"\n Prompt: \"{prompt}\"")
            print(f" Output: {text}")

        all_results[quant] = results

        # Average perplexity over the fixed reference texts.
        ppls = []
        for ref in PERPLEXITY_TEXTS:
            ref_ids = tok.encode(ref, return_tensors="pt").to(device)
            ppl = compute_perplexity(model, ref_ids)
            ppls.append(ppl)
        avg_ppl = sum(ppls) / len(ppls)
        all_ppls[quant] = avg_ppl
        print(f"\n Avg Perplexity: {avg_ppl:.2f}")

        # Empty for F32 (it is the reference itself).
        if cosines:
            avg_cos = sum(cosines) / len(cosines)
            logit_cosine[quant] = avg_cos
            print(f" Logit Cosine Sim vs F32: {avg_cos:.6f}")

    # ---- Comparison table ----------------------------------------------
    print(f"\n\n{'='*70}")
    print(f" QUANTISATION COMPARISON REPORT")
    print(f"{'='*70}")
    print(f"\n {'Quant':<10} {'Avg PPL':>10} {'Cosine vs F32':>15} {'PPL Delta':>12}")
    print(f" {'-'*50}")

    base_ppl = all_ppls["F32"]
    scores = {}
    for quant in quant_levels:
        ppl = all_ppls[quant]
        cos = logit_cosine.get(quant, 1.0)  # F32 has no entry -> 1.0
        delta = ppl - base_ppl
        scores[quant] = (cos, delta)
        cos_str = f"{cos:.6f}" if quant != "F32" else "1.000000 (ref)"
        delta_str = f"+{delta:.2f}" if delta >= 0 else f"{delta:.2f}"
        if quant == "F32":
            delta_str = "β (ref)"
        print(f" {quant:<10} {ppl:>10.2f} {cos_str:>15} {delta_str:>12}")

    # ---- Pick the best quant level: highest cosine fidelity, lightly
    # penalised by relative perplexity increase ---------------------------
    best_quant = None
    best_score = -1
    for q in ["F16", "Q8_0", "Q4_K_M"]:
        cos, delta = scores[q]
        # 0.1 weight on the relative PPL delta keeps cosine dominant.
        score = cos - (abs(delta) / max(base_ppl, 1)) * 0.1
        if score > best_score:
            best_score = score
            best_quant = q

    print(f"\n Best quantisation: {best_quant}")
    print(f" (highest logit fidelity with minimal perplexity increase)")

    # ---- Side-by-side outputs: F32 vs the winner ------------------------
    print(f"\n\n{'='*70}")
    print(f" SIDE-BY-SIDE: F32 (base) vs {best_quant}")
    print(f"{'='*70}")
    for prompt in EVAL_PROMPTS:
        f32_out = all_results["F32"][prompt]
        best_out = all_results[best_quant][prompt]
        match = "MATCH" if f32_out.strip() == best_out.strip() else "DIFFER"
        print(f"\n Prompt: \"{prompt}\"")
        print(f" F32 : {f32_out}")
        print(f" {best_quant:<5}: {best_out}")
        print(f" [{match}]")

    # ---- Extra qualitative English checks, base vs winner ---------------
    print(f"\n\n{'='*70}")
    print(f" ENGLISH UNDERSTANDING VALIDATION")
    print(f"{'='*70}")

    english_tests = [
        ("Completion", "The capital of the United Kingdom is"),
        ("Grammar", "She has been working at the company for five"),
        ("Reasoning", "If a train travels at 60 miles per hour for 2 hours, it covers"),
        ("Vocab", "The opposite of hot is"),
        ("Context", "Doctors work in hospitals, and teachers work in"),
        ("Fluency", "In the year 2025, technology has advanced to the point where"),
    ]

    for quant_test in ["F32", best_quant]:
        # Restore pristine weights before re-applying each level.
        model.load_state_dict(original_sd, strict=True)
        apply_simulated_quant(model, quant_test)
        print(f"\n --- {quant_test} ---")
        for label, prompt in english_tests:
            ids = tok.encode(prompt, return_tensors="pt").to(device)
            # Lower temperature / top_k than the main eval for more
            # deterministic completions.
            out_ids = generate(model, ids, max_new_tokens=50,
                               temperature=0.3, top_k=10)
            text = tok.decode(out_ids[0], skip_special_tokens=True)
            print(f" [{label:>10}] {text}")

    # ---- GGUF export via the external conversion script ------------------
    if not args.skip_gguf:
        print(f"\n\n{'='*70}")
        print(f" EXPORTING GGUF FILES")
        print(f"{'='*70}")
        gguf_script = Path("quantisations/convert_to_gguf.py")
        if gguf_script.exists():
            import subprocess
            cmd = [
                sys.executable, str(gguf_script),
                "--ckpt", str(args.ckpt),
                "--tok_dir", str(args.tok_dir),
                "--quant", "all",
            ]
            print(f" Running: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        else:
            print(f" WARNING: {gguf_script} not found β skipping GGUF export")
    else:
        print(f"\n (GGUF export skipped)")

    # ---- Final summary ---------------------------------------------------
    print(f"\n\n{'='*70}")
    print(f" FINAL SUMMARY")
    print(f"{'='*70}")
    print(f" Pretrained step: {step} | Tokens seen: {tokens_seen:,}")
    print(f" Base F32 perplexity: {base_ppl:.2f}")
    print(f" Best quantisation: {best_quant}")
    print(f" Cosine similarity vs F32: {logit_cosine.get(best_quant, 1.0):.6f}")
    print(f" Perplexity: {all_ppls[best_quant]:.2f} (Ξ {all_ppls[best_quant] - base_ppl:+.2f})")
    print(f"\n Recommendation:")
    print(f" Use {best_quant} for deployment β best precision/size tradeoff.")
    if not args.skip_gguf:
        print(f" GGUF file: quantisations/LUNA-100M-{best_quant}.gguf")
    print(f"\n{'='*70}")
    print(f" Done!")
    print(f"{'='*70}\n")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|