ASTERIZER committed on
Commit
d134029
Β·
verified Β·
1 Parent(s): 56ea5bf

Upload validate_and_quantize.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. validate_and_quantize.py +484 -0
validate_and_quantize.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LUNA 100M — Validate Pretrained + Quantization Benchmark
3
+ =========================================================
4
+ 1. Load pretrained base model (latest.pt β€” auto-downloads from HF)
5
+ 2. Run eval prompts with the base (F32) model
6
+ 3. Simulate quantisation at each level (F16, Q8_0, Q4_K_M) IN PYTORCH
7
+ 4. Run the SAME eval prompts with each quantised copy
8
+ 5. Compute precision metrics (cosine-sim of logits, perplexity delta)
9
+ 6. Export all GGUF files
10
+ 7. Print comparison report + pick the best quantisation
11
+
12
+ Usage:
13
+ python validate_and_quantize.py
14
+ python validate_and_quantize.py --ckpt Base/out/pretrain/luna_100m/latest.pt
15
+ python validate_and_quantize.py --skip_gguf # skip GGUF export
16
+ """
17
+
18
+ import os, sys, copy, math, json, argparse, struct, time
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from pathlib import Path
24
+
25
+ # ─── Model (identical to train.py / sft_train.py) ────────────────────────────
26
+
27
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings (RoPE) with precomputed cos/sin caches."""

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE frequency schedule: theta_i = 10000^(-2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        # Outer product gives one rotation angle per (position, frequency) pair;
        # duplicated along the last dim so cos/sin cover the full rotary width.
        angles = torch.outer(positions, inv_freq)
        table = torch.cat([angles, angles], dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables for the first *seq_len* positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
40
+
41
def rotate_half(x):
    """Map pairs (x1, x2) along the last dim to (-x2, x1): the RoPE rotation."""
    upper, lower = torch.chunk(x, 2, dim=-1)
    return torch.cat((-lower, upper), dim=-1)

def apply_rotary(x, cos, sin):
    """Rotate *x* (shape (B, H, T, D)) by the (T, D) cos/sin tables."""
    # Broadcast the position tables across the batch and head dimensions.
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    return x * cos_b + rotate_half(x) * sin_b
49
+
50
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary position encoding."""

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Only the leading fraction of each head's channels receives RoPE.
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.shape
        # Joint QKV projection, split into three (B, heads, T, head_dim) tensors.
        q, k, v = self.c_attn(x).split(C, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Rotate the first rot_dim channels; pass the remainder through untouched.
        cos, sin = self.rotary(T)
        r = self.rot_dim
        q = torch.cat([apply_rotary(q[..., :r], cos, sin), q[..., r:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :r], cos, sin), k[..., r:]], dim=-1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
69
+
70
class MLP(nn.Module):
    """Position-wise feed-forward network: linear -> GELU -> linear, 4x expansion."""

    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        hidden = self.fc(x)
        hidden = self.gelu(hidden)
        return self.proj(hidden)
78
+
79
class Block(nn.Module):
    """Pre-norm transformer block: attention residual, then MLP residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        attn_out = x + self.attn(self.ln1(x))
        return attn_out + self.mlp(self.ln2(attn_out))
90
+
91
class LUNAModel(nn.Module):
    """Decoder-only GPT-style language model with a weight-tied output head."""

    def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, block_size) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        h = self.wte(idx)
        for blk in self.blocks:
            h = blk(h)
        return self.lm_head(self.ln_f(h))

    @property
    def num_params(self):
        """Parameter count excluding the (tied) embedding matrix."""
        total = sum(p.numel() for p in self.parameters())
        return total - self.wte.weight.numel()
109
+
110
+
111
+ # ─── Quantise-and-dequantise in PyTorch (simulates precision loss) ────────────
112
+
113
+ BLOCK_SIZE = 32
114
+
115
+ def _sim_q8_0(tensor: torch.Tensor) -> torch.Tensor:
116
+ """Simulate Q8_0: blockwise int8 quantise β†’ dequantise."""
117
+ orig_shape = tensor.shape
118
+ flat = tensor.flatten().float()
119
+ pad = (-len(flat)) % BLOCK_SIZE
120
+ if pad:
121
+ flat = F.pad(flat, (0, pad))
122
+ blocks = flat.view(-1, BLOCK_SIZE)
123
+ scales = blocks.abs().max(dim=1, keepdim=True).values / 127.0
124
+ scales = scales.clamp(min=1e-8)
125
+ q = (blocks / scales).round().clamp(-128, 127)
126
+ deq = (q * scales).flatten()[:tensor.numel()]
127
+ return deq.view(orig_shape).to(tensor.dtype)
128
+
129
+ def _sim_q4_k_m(tensor: torch.Tensor) -> torch.Tensor:
130
+ """Simulate Q4_K_M: blockwise 4-bit quantise β†’ dequantise."""
131
+ orig_shape = tensor.shape
132
+ flat = tensor.flatten().float()
133
+ pad = (-len(flat)) % BLOCK_SIZE
134
+ if pad:
135
+ flat = F.pad(flat, (0, pad))
136
+ blocks = flat.view(-1, BLOCK_SIZE)
137
+ abs_max = blocks.abs().max(dim=1, keepdim=True).values
138
+ scales = abs_max / 7.0
139
+ scales = scales.clamp(min=1e-8)
140
+ q = ((blocks / scales) + 8).round().clamp(0, 15)
141
+ deq = ((q - 8) * scales).flatten()[:tensor.numel()]
142
+ return deq.view(orig_shape).to(tensor.dtype)
143
+
144
# Which params get quantised (biases + norms stay F32)
# Only parameters whose names end in ".weight" are eligible; a bare top-level
# "weight" (no dot) would not match.
_QUANT_PARAM_SUFFIXES = (".weight",)
# Name substrings identifying LayerNorm parameters, which are excluded from
# quantisation even though they end in ".weight".
_SKIP_QUANT = ("ln1.", "ln2.", "ln_f.")
147
+
148
def apply_simulated_quant(model: nn.Module, quant: str) -> nn.Module:
    """Apply simulated quantisation to model weights (in-place). Returns model.

    Args:
        model: Any module whose eligible ``.weight`` parameters should be
            degraded (generalized from LUNAModel — only ``named_parameters``
            is used).
        quant: One of "F32" (no-op), "F16", "Q8_0", or "Q4_K_M".

    Raises:
        ValueError: If *quant* is not a recognised level (previously an
            unknown level silently left the model untouched).
    """
    if quant == "F32":
        return model
    if quant not in ("F16", "Q8_0", "Q4_K_M"):
        raise ValueError(f"Unknown quantisation level: {quant!r}")
    for name, p in model.named_parameters():
        # Only weight matrices are quantised; biases stay full precision.
        if not any(name.endswith(s) for s in _QUANT_PARAM_SUFFIXES):
            continue
        # LayerNorm weights are precision-sensitive — keep them F32 too.
        if any(skip in name for skip in _SKIP_QUANT):
            continue
        if quant == "F16":
            p.data = p.data.half().float()  # round-trip through fp16
        elif quant == "Q8_0":
            p.data = _sim_q8_0(p.data)
        else:  # Q4_K_M
            p.data = _sim_q4_k_m(p.data)
    return model
164
+
165
+
166
+ # ─── Generation ───────────────────────────────────────────────────────────────
167
+
168
@torch.no_grad()
def generate(model, input_ids, max_new_tokens=100, temperature=0.7, top_k=40):
    """Autoregressive sampling with temperature scaling and top-k filtering.

    Samples up to *max_new_tokens* tokens, stopping early when token id 0
    (treated as EOS) is drawn. Returns the prompt with the continuation
    appended. Assumes a batch of one (the EOS check reads a scalar).
    """
    for _ in range(max_new_tokens):
        # Crop the context window to the model's maximum sequence length.
        window = input_ids[:, -model.block_size:]
        logits = model(window)[:, -1, :] / max(temperature, 1e-8)
        if top_k > 0:
            # Everything below the k-th largest logit is excluded from sampling.
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        nxt = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        input_ids = torch.cat([input_ids, nxt], dim=1)
        if nxt.item() == 0:  # EOS
            break
    return input_ids
185
+
186
@torch.no_grad()
def get_logits(model, input_ids):
    """Return logits over the trailing block_size tokens (for precision comparison)."""
    window = input_ids[:, -model.block_size:]
    return model(window)
190
+
191
@torch.no_grad()
def compute_perplexity(model, input_ids):
    """Return exp(mean next-token cross-entropy) over the trailing window.

    Sequences shorter than two tokens have no prediction targets and yield
    infinity.
    """
    if input_ids.size(1) < 2:
        return float("inf")
    logits = model(input_ids[:, -model.block_size:])
    # Position t predicts token t+1: drop the final logit row and first label.
    pred = logits[:, :-1, :].contiguous()
    target = input_ids[:, 1:].contiguous()
    nll = F.cross_entropy(pred.view(-1, pred.size(-1)), target.view(-1))
    return math.exp(nll.item())
204
+
205
+
206
+ # ─── Eval prompts ─────────────────────────────────────────────────────────────
207
+
208
# Prompts used to qualitatively compare generations across quantisation levels
# (grouped by the skill each one probes).
EVAL_PROMPTS = [
    # Identity
    "Who are you?",
    "Who created you?",
    "What is your name?",
    # Knowledge
    "The capital of France is",
    "Water boils at a temperature of",
    "The largest planet in our solar system is",
    "Albert Einstein is famous for",
    # English comprehension
    "The quick brown fox jumps over the lazy",
    "In a groundbreaking study, researchers found that",
    "The most important thing about education is",
    "Once upon a time, in a land far away,",
    "The future of artificial intelligence will",
    # Reasoning / grammar
    "If it rains tomorrow, I will",
    "She went to the store because she needed to buy",
    "The difference between a cat and a dog is that",
]
229
+
230
# Reference sentences for perplexity measurement (well-formed English).
# Each quantisation level is scored by its average perplexity on these.
PERPLEXITY_TEXTS = [
    "The quick brown fox jumps over the lazy dog and then runs into the forest.",
    "Artificial intelligence has transformed the way we interact with technology in recent years.",
    "Education is the most powerful weapon which you can use to change the world.",
    "The sun rises in the east and sets in the west, a cycle that has continued for billions of years.",
    "Water is composed of two hydrogen atoms and one oxygen atom, making it essential for all life.",
]
238
+
239
+
240
+ # ─── Main ─────────────────────────────────────────────────────────────────────
241
+
242
def main():
    """Run the full validate-and-quantise pipeline.

    Loads the pretrained checkpoint (downloading it from the Hub if absent),
    benchmarks generation + perplexity at every simulated quantisation level,
    prints a comparison report, and optionally shells out to the GGUF exporter.
    """
    parser = argparse.ArgumentParser(description="LUNA 100M — Validate & Quantize Benchmark")
    parser.add_argument("--ckpt", default="Base/out/pretrain/luna_100m/latest.pt",
                        help="Path to latest.pt checkpoint")
    parser.add_argument("--hf_repo", default="ASTERIZER/LUNA-100M",
                        help="HF model repo to download from if ckpt not found")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m",
                        help="Tokenizer directory")
    parser.add_argument("--max_tokens", type=int, default=80,
                        help="Max tokens to generate per prompt")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--skip_gguf", action="store_true",
                        help="Skip GGUF export (just do the PyTorch comparison)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n{'='*70}")
    print(f" LUNA 100M — Validate & Quantize Benchmark")
    print(f" Device: {device}")
    print(f"{'='*70}")

    # ── 1. Load tokenizer ─────────────────────────────────────────────────────
    # Imported lazily so the quantisation helpers above stay usable without
    # transformers installed.
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"\n Tokenizer: {args.tok_dir} (vocab={tok.vocab_size})")

    # ── 2. Load checkpoint ────────────────────────────────────────────────────
    ckpt_path = Path(args.ckpt)
    if not ckpt_path.exists():
        print(f"\n Checkpoint not found locally: {ckpt_path}")
        print(f" Downloading from HuggingFace: {args.hf_repo}")
        from huggingface_hub import hf_hub_download
        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        # The repo file is named "latest.pt", matching the default --ckpt
        # basename, so the download lands exactly at ckpt_path.
        hf_hub_download(
            repo_id=args.hf_repo,
            filename="latest.pt",
            local_dir=str(ckpt_path.parent),
            token=os.environ.get("HF_TOKEN"),
        )
        print(f" Downloaded to: {ckpt_path}")

    print(f"\n Loading checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Handle both formats: {"model": sd, "step": ...} or raw state_dict
    if isinstance(ckpt, dict) and "model" in ckpt:
        state = ckpt["model"]
        step = ckpt.get("step", "?")
        tokens_seen = ckpt.get("tokens_seen", 0)
    else:
        state = ckpt
        step = "final"
        tokens_seen = 0
    print(f" Pretrained @ step {step}, tokens seen: {tokens_seen:,}")

    # ── 3. Build model ────────────────────────────────────────────────────────
    # NOTE(review): hyperparameters are hard-coded and must match the training
    # config this checkpoint was produced with (strict=True will catch drift).
    model = LUNAModel(
        vocab_size=50304, block_size=1024,
        n_layer=10, n_embd=768, n_head=12,
    )
    model.load_state_dict(state, strict=True)
    model = model.to(device).eval()
    print(f" Parameters: {model.num_params:,}")
    del ckpt, state

    # Save original F32 weights for restoring after each quant
    original_sd = {k: v.clone() for k, v in model.state_dict().items()}

    # ── 4. Run benchmark across all quant levels ──────────────────────────────
    quant_levels = ["F32", "F16", "Q8_0", "Q4_K_M"]
    all_results = {}   # quant -> {prompt: generated_text}
    all_ppls = {}      # quant -> avg perplexity
    logit_cosine = {}  # quant -> avg cosine similarity vs F32
    base_logits = {}   # prompt -> F32 logits (for comparison)

    for qi, quant in enumerate(quant_levels):
        # Restore original weights (quantisation mutates p.data in-place)
        model.load_state_dict(original_sd, strict=True)

        # Apply simulated quantisation
        apply_simulated_quant(model, quant)

        print(f"\n{'='*70}")
        print(f" [{qi+1}/{len(quant_levels)}] {quant}")
        print(f"{'='*70}")

        # ── Generate from eval prompts ────────────────────────────────────────
        results = {}
        cosines = []

        for prompt in EVAL_PROMPTS:
            ids = tok.encode(prompt, return_tensors="pt").to(device)
            out_ids = generate(model, ids, max_new_tokens=args.max_tokens,
                               temperature=args.temperature, top_k=args.top_k)
            text = tok.decode(out_ids[0], skip_special_tokens=True)
            results[prompt] = text

            # Compute logit similarity vs F32. The F32 pass (first in
            # quant_levels) populates base_logits for later levels to read.
            cur_logits = get_logits(model, ids)
            if quant == "F32":
                base_logits[prompt] = cur_logits.cpu()
            else:
                bl = base_logits[prompt].to(device)
                min_len = min(cur_logits.size(1), bl.size(1))
                # Both logit tensors flattened and compared as single vectors.
                cos = F.cosine_similarity(
                    cur_logits[:, :min_len, :].flatten().unsqueeze(0),
                    bl[:, :min_len, :].flatten().unsqueeze(0),
                ).item()
                cosines.append(cos)

            print(f"\n Prompt: \"{prompt}\"")
            print(f" Output: {text}")

        all_results[quant] = results

        # ── Perplexity on reference English text ──────────────────────────────
        ppls = []
        for ref in PERPLEXITY_TEXTS:
            ref_ids = tok.encode(ref, return_tensors="pt").to(device)
            ppl = compute_perplexity(model, ref_ids)
            ppls.append(ppl)
        avg_ppl = sum(ppls) / len(ppls)
        all_ppls[quant] = avg_ppl
        print(f"\n Avg Perplexity: {avg_ppl:.2f}")

        if cosines:
            avg_cos = sum(cosines) / len(cosines)
            logit_cosine[quant] = avg_cos
            print(f" Logit Cosine Sim vs F32: {avg_cos:.6f}")

    # ── 5. Comparison Report ──────────────────────────────────────────────────
    print(f"\n\n{'='*70}")
    print(f" QUANTISATION COMPARISON REPORT")
    print(f"{'='*70}")
    print(f"\n {'Quant':<10} {'Avg PPL':>10} {'Cosine vs F32':>15} {'PPL Delta':>12}")
    print(f" {'-'*50}")

    base_ppl = all_ppls["F32"]
    scores = {}
    for quant in quant_levels:
        ppl = all_ppls[quant]
        # F32 has no cosine entry — it is the reference, so 1.0 by definition.
        cos = logit_cosine.get(quant, 1.0)
        delta = ppl - base_ppl
        scores[quant] = (cos, delta)
        cos_str = f"{cos:.6f}" if quant != "F32" else "1.000000 (ref)"
        delta_str = f"+{delta:.2f}" if delta >= 0 else f"{delta:.2f}"
        if quant == "F32":
            delta_str = "— (ref)"
        print(f" {quant:<10} {ppl:>10.2f} {cos_str:>15} {delta_str:>12}")

    # Pick best non-F32 quant
    best_quant = None
    best_score = -1
    for q in ["F16", "Q8_0", "Q4_K_M"]:
        cos, delta = scores[q]
        # Score: high cosine + low ppl delta = good. The delta term is
        # normalised by the base perplexity and down-weighted by 0.1.
        score = cos - (abs(delta) / max(base_ppl, 1)) * 0.1
        if score > best_score:
            best_score = score
            best_quant = q

    print(f"\n Best quantisation: {best_quant}")
    print(f" (highest logit fidelity with minimal perplexity increase)")

    # ── 6. Side-by-side output comparison ─────────────────────────────────────
    print(f"\n\n{'='*70}")
    print(f" SIDE-BY-SIDE: F32 (base) vs {best_quant}")
    print(f"{'='*70}")
    for prompt in EVAL_PROMPTS:
        f32_out = all_results["F32"][prompt]
        best_out = all_results[best_quant][prompt]
        match = "MATCH" if f32_out.strip() == best_out.strip() else "DIFFER"
        print(f"\n Prompt: \"{prompt}\"")
        print(f" F32 : {f32_out}")
        print(f" {best_quant:<5}: {best_out}")
        print(f" [{match}]")

    # ── 7. English Understanding Validation ───────────────────────────────────
    print(f"\n\n{'='*70}")
    print(f" ENGLISH UNDERSTANDING VALIDATION")
    print(f"{'='*70}")

    english_tests = [
        ("Completion", "The capital of the United Kingdom is"),
        ("Grammar", "She has been working at the company for five"),
        ("Reasoning", "If a train travels at 60 miles per hour for 2 hours, it covers"),
        ("Vocab", "The opposite of hot is"),
        ("Context", "Doctors work in hospitals, and teachers work in"),
        ("Fluency", "In the year 2025, technology has advanced to the point where"),
    ]

    for quant_test in ["F32", best_quant]:
        # Re-quantise from the pristine F32 weights for each pass.
        model.load_state_dict(original_sd, strict=True)
        apply_simulated_quant(model, quant_test)
        print(f"\n --- {quant_test} ---")
        for label, prompt in english_tests:
            ids = tok.encode(prompt, return_tensors="pt").to(device)
            # Lower temperature / tighter top-k for more deterministic output.
            out_ids = generate(model, ids, max_new_tokens=50,
                               temperature=0.3, top_k=10)
            text = tok.decode(out_ids[0], skip_special_tokens=True)
            print(f" [{label:>10}] {text}")

    # ── 8. Export GGUF files ──────────────────────────────────────────────────
    if not args.skip_gguf:
        print(f"\n\n{'='*70}")
        print(f" EXPORTING GGUF FILES")
        print(f"{'='*70}")
        gguf_script = Path("quantisations/convert_to_gguf.py")
        if gguf_script.exists():
            import subprocess
            # Delegates the actual GGUF writing to the sibling converter
            # script, run with the same interpreter.
            cmd = [
                sys.executable, str(gguf_script),
                "--ckpt", str(args.ckpt),
                "--tok_dir", str(args.tok_dir),
                "--quant", "all",
            ]
            print(f" Running: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        else:
            print(f" WARNING: {gguf_script} not found — skipping GGUF export")
    else:
        print(f"\n (GGUF export skipped)")

    # ── 9. Final Summary ──────────────────────────────────────────────────────
    print(f"\n\n{'='*70}")
    print(f" FINAL SUMMARY")
    print(f"{'='*70}")
    print(f" Pretrained step: {step} | Tokens seen: {tokens_seen:,}")
    print(f" Base F32 perplexity: {base_ppl:.2f}")
    print(f" Best quantisation: {best_quant}")
    print(f" Cosine similarity vs F32: {logit_cosine.get(best_quant, 1.0):.6f}")
    print(f" Perplexity: {all_ppls[best_quant]:.2f} (Δ {all_ppls[best_quant] - base_ppl:+.2f})")
    print(f"\n Recommendation:")
    print(f" Use {best_quant} for deployment — best precision/size tradeoff.")
    if not args.skip_gguf:
        print(f" GGUF file: quantisations/LUNA-100M-{best_quant}.gguf")
    print(f"\n{'='*70}")
    print(f" Done!")
    print(f"{'='*70}\n")


if __name__ == "__main__":
    main()