ASTERIZER committed on
Commit
01e6957
·
verified ·
1 Parent(s): 7bab2fa

Upload validate_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. validate_sft.py +707 -0
validate_sft.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LUNA 100M β€” SFT Validation on Complex Examples
3
+ ================================================
4
+ Selects ~100 complex examples from the SFT validation set,
5
+ runs the fine-tuned model on each, and produces a detailed report.
6
+
7
+ Metrics computed:
8
+ - Per-sample cross-entropy loss (prompt-masked) & perplexity
9
+ - Token-level accuracy on the output portion
10
+ - BLEU-1/2 (word overlap with reference output)
11
+ - Repetition ratio (degeneration detection)
12
+ - Response length stats
13
+ - Category breakdown (coding, explanation, analysis, creative, how-to, identity)
14
+ - Overall pass/fail grading
15
+
16
+ Usage:
17
+ python validate_sft.py
18
+ python validate_sft.py --ckpt "Base/out/sft/model.pth" --val_json "Base/Datasets/sft_clean/val.json"
19
+ """
20
+
21
+ import os, sys, json, math, time, argparse, re
22
+ from pathlib import Path
23
+ from collections import Counter, defaultdict
24
+ from datetime import datetime
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+
31
+ # ─── Model (identical to sft_train.py / chat.py) ─────────────────────────────
32
+
33
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) cos/sin lookup tables."""

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE frequency schedule: theta_i = 10000^(-2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        # Duplicate the angle block so the table spans the full rotary dim.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return (cos, sin) tables truncated to the first seq_len positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
46
+
47
+
48
def rotate_half(x):
    """Map the two halves (a, b) of the last dimension to (-b, a)."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
51
+
52
+
53
def apply_rotary(x, cos, sin):
    """Rotate x by the given cos/sin position tables (RoPE).

    cos/sin are broadcast over the two leading (batch, head) axes.
    """
    c = cos.unsqueeze(0).unsqueeze(0)
    s = sin.unsqueeze(0).unsqueeze(0)
    # Inlined rotate_half: (a, b) halves of the last dim become (-b, a).
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * c + rotated * s
57
+
58
+
59
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the leading rotary_pct fraction of each head dimension receives
    RoPE; the remainder passes through unrotated (Pythia/GPT-NeoX style).
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        batch, seq, embd = x.size()
        # Fused QKV projection, then split into per-head tensors:
        # (3, B, n_head, T, head_dim).
        qkv = self.c_attn(x).reshape(batch, seq, 3, self.n_head, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        cos, sin = self.rotary(seq)
        rot = self.rot_dim
        q_rot = apply_rotary(q[..., :rot], cos, sin)
        k_rot = apply_rotary(k[..., :rot], cos, sin)
        q = torch.cat((q_rot, q[..., rot:]), dim=-1)
        k = torch.cat((k_rot, k[..., rot:]), dim=-1)
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, embd)
        return self.c_proj(merged)
78
+
79
+
80
class MLP(nn.Module):
    """Position-wise feed-forward block: linear -> GELU -> linear (4x expansion)."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        h = self.fc(x)
        h = self.gelu(h)
        return self.proj(h)
89
+
90
+
91
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
103
+
104
+
105
class LUNAModel(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings."""

    def __init__(self, vocab_size=50304, block_size=1024,
                 n_layer=10, n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, block_size) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids (B, T) to next-token logits (B, T, vocab_size)."""
        hidden = self.wte(idx)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(self.ln_f(hidden))
123
+
124
+
125
+ # ─── Generation ───────────────────────────────────────────────────────────────
126
+
127
@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, repetition_penalty=1.0, device="cpu"):
    """Autoregressively sample up to max_new tokens from the model.

    temperature < 1e-6 selects greedy decoding; otherwise top-k filtering is
    applied first, then nucleus (top-p) sampling. Generation stops early on
    token id 0, which this codebase treats as EOS. Assumes batch size 1.
    Returns the list of generated token ids (prompt excluded).
    """
    ids = input_ids.to(device)
    out_tokens = []
    for _ in range(max_new):
        # Crop context to the model window; keep only the last-step logits.
        logits = model(ids[:, -model.block_size:])[:, -1, :]

        if repetition_penalty != 1.0:
            # CTRL-style penalty: shrink positive logits of seen tokens,
            # amplify (more-negative) negative ones.
            for seen in set(ids[0].tolist()):
                val = logits[0, seen]
                if val > 0:
                    logits[0, seen] = val / repetition_penalty
                else:
                    logits[0, seen] = val * repetition_penalty

        if temperature < 1e-6:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            if top_k > 0:
                k = min(top_k, probs.size(-1))
                kth = torch.topk(probs, k).values[:, [-1]]
                probs = probs.masked_fill(probs < kth, 0.0)
                probs = probs / probs.sum()
            if top_p < 1.0:
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                # Drop tokens outside the smallest prefix whose mass > top_p.
                cutoff = torch.cumsum(sorted_probs, dim=-1) - sorted_probs > top_p
                sorted_probs = sorted_probs.masked_fill(cutoff, 0.0)
                sorted_probs = sorted_probs / sorted_probs.sum()
                next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = torch.multinomial(probs[0], 1)

        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        token = next_token.item()
        out_tokens.append(token)
        if token == 0:  # EOS token
            break
    return out_tokens
164
+
165
+
166
+ # ─── Prompt formatting (matches sft_train.py) ────────────────────────────────
167
+
168
def format_prompt(instruction, inp=""):
    """Build the Alpaca-style prompt used during SFT training.

    Emits the Instruction/Input/Response sections, omitting whichever of
    instruction/input is empty after stripping whitespace.
    """
    inst = instruction.strip()
    extra = inp.strip()
    if inst and extra:
        return (
            f"### Instruction:\n{inst}\n\n"
            f"### Input:\n{extra}\n\n"
            "### Response:\n"
        )
    if inst:
        return f"### Instruction:\n{inst}\n\n### Response:\n"
    return f"### Input:\n{extra}\n\n### Response:\n"
177
+
178
+
179
+ # ─── Complexity scoring & selection ───────────────────────────────────────────
180
+
181
# Substring markers suggesting multi-step or explanatory prompts; each hit adds
# a fixed bonus in complexity_score. Matched case-insensitively as plain
# substrings, so "step" also matches "steps" and "first" matches "firstly".
COMPLEXITY_KEYWORDS = [
    "step", "first", "second", "then", "next", "finally",
    "because", "however", "therefore", "explain", "analyze",
    "compare", "evaluate", "describe", "discuss", "provide",
    "example", "detail", "elaborate", "summarize",
]
187
+
188
def complexity_score(entry):
    """Heuristic difficulty score for one SFT sample (higher = more complex).

    Combines total character length, output length (weighted extra), a flat
    bonus for a non-trivial input field, and a per-keyword bonus.
    """
    inst = entry.get("instruction", "")
    inp = entry.get("input", "")
    out = entry.get("output", "")
    blob = f"{inst} {inp} {out}".lower()
    char_total = len(inst) + len(inp) + len(out)
    input_bonus = 500 if len(inp) > 20 else 0
    keyword_hits = sum(kw in blob for kw in COMPLEXITY_KEYWORDS)
    return 0.3 * char_total + 0.5 * len(out) + input_bonus + 200 * keyword_hits
197
+
198
+
199
def categorize(instruction):
    """Bucket an instruction into a coarse task category by keyword matching.

    Rule order matters: earlier categories win, so e.g. "write a python
    function" is "coding" rather than "creative".
    """
    text = instruction.lower()
    rules = (
        ("coding", ("code", "python", "java", "swift", "function", "program",
                    "algorithm", "script", "sql", "html", "css")),
        ("identity", ("who are you", "your name", "who created", "asterizer",
                      "luna", "are you an ai")),
        ("explanation", ("explain", "what is", "define", "describe", "meaning of")),
        ("analysis", ("analyze", "compare", "evaluate", "assess", "critique")),
        ("creative", ("write", "create", "generate", "compose", "draft",
                      "poem", "story", "essay")),
        ("how-to", ("how", "step", "guide", "method", "procedure", "tutorial")),
    )
    for label, needles in rules:
        if any(needle in text for needle in needles):
            return label
    return "other"
214
+
215
+
216
def select_complex_examples(data, n=100):
    """Return the n highest-scoring entries from data.

    Ranking is by complexity_score descending; ties fall back to comparing
    the original index (reverse sort), so later entries win among equals.
    """
    ranked = sorted(
        ((complexity_score(entry), i) for i, entry in enumerate(data)),
        reverse=True,
    )
    return [data[i] for _, i in ranked[:n]]
220
+
221
+
222
+ # ─── Metrics ──────────────────────────────────────────────────────────────────
223
+
224
def compute_bleu(reference, hypothesis, max_n=2):
    """Word-level modified n-gram precision (BLEU-1/2, no brevity penalty).

    Returns a dict {"bleu_1": ..., "bleu_2": ...} with values in [0, 1];
    all zeros when either side is empty after whitespace splitting.
    """
    ref_words = reference.lower().split()
    hyp_words = hypothesis.lower().split()
    if not ref_words or not hyp_words:
        return {f"bleu_{n}": 0.0 for n in range(1, max_n + 1)}

    def ngram_counts(words, n):
        return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))

    scores = {}
    for n in range(1, max_n + 1):
        ref_counts = ngram_counts(ref_words, n)
        hyp_counts = ngram_counts(hyp_words, n)
        # Clipped precision: each hypothesis n-gram counts at most as often
        # as it appears in the reference.
        clipped = sum(min(cnt, ref_counts[gram]) for gram, cnt in hyp_counts.items())
        scores[f"bleu_{n}"] = clipped / max(sum(hyp_counts.values()), 1)
    return scores
242
+
243
+
244
def repetition_ratio(text):
    """Share of duplicated trigrams: 0.0 = all distinct, near 1.0 = degenerate.

    Texts shorter than four words score 0.0 by definition.
    """
    words = text.lower().split()
    if len(words) < 4:
        return 0.0
    total = len(words) - 2
    distinct = len({tuple(words[i:i + 3]) for i in range(total)})
    return 1.0 - distinct / total
254
+
255
+
256
@torch.no_grad()
def compute_loss_and_accuracy(model, tokenizer, entry, max_len, device):
    """Teacher-forced CE loss, token accuracy, and token count on the response.

    The prompt portion of the sequence is masked out, so both metrics cover
    only the response (plus EOS) tokens. Returns a tuple
    (loss, accuracy, n_response_tokens); (inf, 0.0, 0) when truncation to
    max_len leaves no response tokens to score.
    """
    prompt = format_prompt(entry.get("instruction", ""), entry.get("input", ""))
    response = entry.get("output", "").strip()
    eos_id = tokenizer.eos_token_id or 0  # fall back to 0 when undefined

    prompt_ids = tokenizer.encode(prompt)
    sequence = prompt_ids + tokenizer.encode(response) + [eos_id]

    if len(sequence) > max_len:
        sequence = sequence[:max_len]
        sequence[-1] = eos_id  # keep a terminator even after truncation
        prompt_len = min(len(prompt_ids), max_len)
    else:
        prompt_len = len(prompt_ids)

    tokens = torch.tensor([sequence], dtype=torch.long, device=device)
    logits = model(tokens)  # (1, T, V)

    # Next-token shift: position t predicts token t+1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_targets = tokens[:, 1:].contiguous()

    # Mask selects only the response span of the shifted targets.
    mask = torch.zeros_like(shift_targets, dtype=torch.float)
    mask[0, max(prompt_len - 1, 0):len(sequence) - 1] = 1.0

    n_scored = mask.sum()
    if n_scored == 0:
        return float("inf"), 0.0, 0

    token_losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1),
        reduction="none",
    ).view(shift_targets.shape)
    loss = (token_losses * mask).sum() / n_scored

    predictions = shift_logits.argmax(dim=-1)
    n_correct = ((predictions == shift_targets).float() * mask).sum()
    accuracy = (n_correct / n_scored).item()

    return loss.item(), accuracy, int(n_scored.item())
304
+
305
+
306
+ # ─── Main validation ─────────────────────────────────────────────────────────
307
+
308
def main():
    """CLI entry point: load checkpoint + tokenizer, score the most complex
    validation examples, and write a text report plus JSON results."""
    parser = argparse.ArgumentParser(description="LUNA SFT — Complex Example Validation")
    parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--val_json", default="Base/Datasets/sft_clean/val.json")
    parser.add_argument("--n_examples", type=int, default=100)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--max_new", type=int, default=150)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--rep_pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--out_dir", default="Base/out/sft/validation_report_v2")
    args = parser.parse_args()

    # "auto" resolves to cuda when available, otherwise cpu.
    device = "cuda" if args.device == "auto" and torch.cuda.is_available() else args.device
    if device == "auto":
        device = "cpu"

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sep = "=" * 72
    print(f"\n{sep}")
    print(" LUNA 100M — SFT VALIDATION (Complex Examples)")
    print(sep)
    print(f" Checkpoint : {args.ckpt}")
    print(f" Val data : {args.val_json}")
    print(f" N examples : {args.n_examples}")
    print(f" Device : {device}")
    print(f" Max seq : {args.max_len}")
    print(f" Temperature: {args.temperature}")
    print(f" Top-k : {args.top_k}")
    print(f" Top-p : {args.top_p}")
    print(f" Rep penalty: {args.rep_pen}")
    print(sep)

    # ── Load model ────────────────────────────────────────────────────────────
    print("\n[1/5] Loading model...")
    t0 = time.time()
    state_dict = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    # Training checkpoints may wrap weights under a "model" key.
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]
    model = LUNAModel()
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Model loaded: {n_params:,} params ({time.time()-t0:.1f}s)")

    # ── Load tokenizer ────────────────────────────────────────────────────────
    print("[2/5] Loading tokenizer...")
    # Imported lazily so transformers is only required when the script runs.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f" Tokenizer: vocab_size={tokenizer.vocab_size}")

    # ── Select complex examples ───────────────────────────────────────────────
    print(f"[3/5] Selecting {args.n_examples} most complex examples from val set...")
    with open(args.val_json, "r", encoding="utf-8") as f:
        all_val = json.load(f)
    print(f" Total val samples: {len(all_val)}")
    examples = select_complex_examples(all_val, args.n_examples)
    print(f" Selected: {len(examples)} complex examples")

    # Category breakdown
    cat_counts = Counter(categorize(e["instruction"]) for e in examples)
    print(f" Categories: {dict(cat_counts)}")

    # ── Run validation ────────────────────────────────────────────────────────
    print(f"\n[4/5] Running validation ({len(examples)} examples)...")
    print("-" * 72)

    results = []
    cat_metrics = defaultdict(lambda: {"losses": [], "perplexities": [],
                                       "accuracies": [], "bleu1": [],
                                       "bleu2": [], "rep_ratios": [],
                                       "gen_lens": []})

    for i, entry in enumerate(examples):
        inst = entry.get("instruction", "")
        inp = entry.get("input", "")
        ref_output = entry.get("output", "")
        category = categorize(inst)

        # 1) Compute loss & accuracy (teacher-forced)
        loss, tok_acc, n_resp_tokens = compute_loss_and_accuracy(
            model, tokenizer, entry, args.max_len, device
        )
        ppl = math.exp(min(loss, 20))  # cap to avoid overflow

        # 2) Generate response (autoregressive)
        prompt = format_prompt(inst, inp)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
        gen_tokens = generate(
            model, prompt_ids,
            max_new=args.max_new,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        # Clean trailing template markers (model may start a new "### ..." section)
        if "### " in gen_text:
            gen_text = gen_text.split("### ")[0].strip()

        # 3) Compute text metrics
        bleu = compute_bleu(ref_output, gen_text)
        rep = repetition_ratio(gen_text)
        gen_words = len(gen_text.split())

        # 4) Quality flags
        is_empty = len(gen_text.strip()) < 5
        is_repetitive = rep > 0.5
        is_truncated = len(gen_tokens) >= args.max_new

        result = {
            "index": i,
            "category": category,
            "instruction": inst[:200],
            "input_preview": inp[:100] if inp else "",
            "reference_preview": ref_output[:200],
            "generated_preview": gen_text[:300],
            "loss": round(loss, 4),
            "perplexity": round(ppl, 2),
            "token_accuracy": round(tok_acc, 4),
            "bleu_1": round(bleu["bleu_1"], 4),
            "bleu_2": round(bleu["bleu_2"], 4),
            "repetition_ratio": round(rep, 4),
            "generated_words": gen_words,
            "resp_tokens": n_resp_tokens,
            "is_empty": is_empty,
            "is_repetitive": is_repetitive,
            "is_truncated": is_truncated,
        }
        results.append(result)

        # Accumulate per-category
        cat_metrics[category]["losses"].append(loss)
        cat_metrics[category]["perplexities"].append(ppl)
        cat_metrics[category]["accuracies"].append(tok_acc)
        cat_metrics[category]["bleu1"].append(bleu["bleu_1"])
        cat_metrics[category]["bleu2"].append(bleu["bleu_2"])
        cat_metrics[category]["rep_ratios"].append(rep)
        cat_metrics[category]["gen_lens"].append(gen_words)

        # Progress
        status = ""
        if is_empty:
            status = " [EMPTY]"
        elif is_repetitive:
            status = " [REPETITIVE]"
        elif is_truncated:
            status = " [TRUNCATED]"

        if (i + 1) % 5 == 0 or i == 0:
            print(f" [{i+1:3d}/{len(examples)}] loss={loss:.3f} ppl={ppl:.1f} "
                  f"acc={tok_acc:.3f} B1={bleu['bleu_1']:.3f} "
                  f"rep={rep:.3f} words={gen_words}{status}")

    # ── Aggregate & Report ────────────────────────────────────────────────────
    print(f"\n[5/5] Generating report...")

    all_losses = [r["loss"] for r in results if r["loss"] < float("inf")]
    all_ppls = [r["perplexity"] for r in results if r["perplexity"] < 1e6]
    all_accs = [r["token_accuracy"] for r in results]
    all_b1 = [r["bleu_1"] for r in results]
    all_b2 = [r["bleu_2"] for r in results]
    all_reps = [r["repetition_ratio"] for r in results]
    all_lens = [r["generated_words"] for r in results]

    n_empty = sum(1 for r in results if r["is_empty"])
    n_repetitive = sum(1 for r in results if r["is_repetitive"])
    n_truncated = sum(1 for r in results if r["is_truncated"])

    avg = lambda xs: sum(xs) / len(xs) if xs else 0.0

    # ── Build report text ─────────────────────────────────────────────────────
    report_lines = []
    def P(s=""):
        report_lines.append(s)

    P(sep)
    P(" LUNA 100M — SFT VALIDATION REPORT")
    P(f" Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    P(sep)
    P()
    P(f" Checkpoint : {args.ckpt}")
    P(f" Val source : {args.val_json} ({len(all_val)} total samples)")
    P(f" Examples tested: {len(examples)} (top complex by scoring)")
    P(f" Device : {device}")
    P(f" Max seq len : {args.max_len}")
    P(f" Gen temp : {args.temperature}")
    P(f" Gen top_k : {args.top_k}")
    P(f" Gen top_p : {args.top_p}")
    P(f" Gen rep_pen : {args.rep_pen}")
    P(f" Gen max tokens: {args.max_new}")
    P()

    P("=" * 72)
    P(" OVERALL METRICS")
    P("=" * 72)
    P(f" Avg Loss (CE) : {avg(all_losses):.4f}")
    P(f" Avg Perplexity : {avg(all_ppls):.2f}")
    P(f" Median Perplexity : {sorted(all_ppls)[len(all_ppls)//2]:.2f}" if all_ppls else " Median Perplexity : N/A")
    P(f" Avg Token Accuracy : {avg(all_accs):.4f} ({avg(all_accs)*100:.1f}%)")
    P(f" Avg BLEU-1 : {avg(all_b1):.4f}")
    P(f" Avg BLEU-2 : {avg(all_b2):.4f}")
    P(f" Avg Repetition Ratio : {avg(all_reps):.4f}")
    P(f" Avg Gen Length (words): {avg(all_lens):.1f}")
    P()
    P(f" Empty responses : {n_empty}/{len(results)}")
    P(f" Repetitive responses : {n_repetitive}/{len(results)}")
    P(f" Truncated responses : {n_truncated}/{len(results)}")
    P()

    # Quality grade
    # NOTE(review): grades are compared lexically ("A" < "B" < "C" < "D"),
    # so max(grade, "C") downgrades A/B to C. The n_empty > 10 branch only
    # assigns "D" when grade is already past "C" — a run graded "A" keeps
    # its grade despite many empty responses; confirm this is intended.
    grade = "A"
    grade_notes = []
    if avg(all_ppls) > 50:
        grade = "C"
        grade_notes.append("high perplexity (>50)")
    elif avg(all_ppls) > 20:
        grade = "B"
        grade_notes.append("moderate perplexity (>20)")
    if n_empty > 10:
        grade = "D" if grade > "C" else grade
        grade_notes.append(f"{n_empty} empty responses")
    elif n_empty > 3:
        grade = max(grade, "C")
        grade_notes.append(f"{n_empty} empty responses")
    if n_repetitive > 15:
        grade = max(grade, "C")
        grade_notes.append(f"{n_repetitive} repetitive responses")
    if avg(all_b1) < 0.05:
        grade = max(grade, "C")
        grade_notes.append("very low BLEU-1")
    if avg(all_accs) > 0.4:
        grade_notes.append("strong token accuracy")
    if avg(all_accs) > 0.5:
        # "A-" is only reachable from "B" (high accuracy upgrade).
        if grade == "B":
            grade = "A-"

    P(f" OVERALL GRADE: {grade}")
    if grade_notes:
        P(f" Notes: {'; '.join(grade_notes)}")
    P()

    P("=" * 72)
    P(" CATEGORY BREAKDOWN")
    P("=" * 72)
    P(f" {'Category':<14} {'Count':>5} {'Avg Loss':>9} {'Avg PPL':>9} "
      f"{'Avg Acc':>8} {'BLEU-1':>7} {'BLEU-2':>7} {'Rep %':>6}")
    P(" " + "-" * 68)
    for cat in sorted(cat_metrics.keys()):
        m = cat_metrics[cat]
        cnt = len(m["losses"])
        P(f" {cat:<14} {cnt:>5} {avg(m['losses']):>9.4f} "
          f"{avg(m['perplexities']):>9.2f} {avg(m['accuracies']):>8.4f} "
          f"{avg(m['bleu1']):>7.4f} {avg(m['bleu2']):>7.4f} "
          f"{avg(m['rep_ratios'])*100:>5.1f}%")
    P()

    # ── Top 5 Best / Worst ────────────────────────────────────────────────────
    P("=" * 72)
    P(" TOP 5 BEST (lowest perplexity)")
    P("=" * 72)
    by_ppl = sorted(results, key=lambda r: r["perplexity"])
    for r in by_ppl[:5]:
        P(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f" Q: {r['instruction'][:80]}")
        P(f" A: {r['generated_preview'][:100]}")
    P()

    P("=" * 72)
    P(" TOP 5 WORST (highest perplexity)")
    P("=" * 72)
    for r in by_ppl[-5:]:
        P(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f" Q: {r['instruction'][:80]}")
        P(f" A: {r['generated_preview'][:100]}")
    P()

    # ── Failure Analysis ──────────────────────────────────────────────────────
    failures = [r for r in results if r["is_empty"] or r["is_repetitive"]]
    if failures:
        P("=" * 72)
        P(f" FAILURE ANALYSIS ({len(failures)} problematic responses)")
        P("=" * 72)
        # Only the first 10 failures are listed to keep the report short.
        for r in failures[:10]:
            flags = []
            if r["is_empty"]:
                flags.append("EMPTY")
            if r["is_repetitive"]:
                flags.append("REPETITIVE")
            P(f" [{r['index']:3d}] {' | '.join(flags)} [{r['category']}]")
            P(f" Q: {r['instruction'][:80]}")
            P(f" A: {r['generated_preview'][:120]}")
        P()

    # ── Perplexity distribution ───────────────────────────────────────────────
    P("=" * 72)
    P(" PERPLEXITY DISTRIBUTION")
    P("=" * 72)
    buckets = [(0, 5), (5, 10), (10, 20), (20, 50), (50, 100),
               (100, 500), (500, float("inf"))]
    for lo, hi in buckets:
        cnt = sum(1 for p in all_ppls if lo <= p < hi)
        bar = "#" * cnt
        label = f"{lo}-{hi}" if hi != float("inf") else f"{lo}+"
        P(f" {label:>8}: {cnt:>3} {bar}")
    P()

    # ── Sample generations (10 diverse examples) ──────────────────────────────
    P("=" * 72)
    P(" SAMPLE GENERATIONS (10 diverse examples)")
    P("=" * 72)
    # Pick every 10th (evenly strided through the ranked results)
    sample_indices = list(range(0, len(results), max(1, len(results) // 10)))[:10]
    for si in sample_indices:
        r = results[si]
        P(f"\n --- Example {r['index']+1} [{r['category']}] ---")
        P(f" Instruction: {r['instruction'][:150]}")
        if r["input_preview"]:
            P(f" Input: {r['input_preview'][:100]}")
        P(f" Reference: {r['reference_preview'][:200]}")
        P(f" Generated: {r['generated_preview'][:300]}")
        P(f" Loss={r['loss']:.4f} PPL={r['perplexity']:.2f} "
          f"Acc={r['token_accuracy']:.4f} BLEU-1={r['bleu_1']:.4f} "
          f"Rep={r['repetition_ratio']:.4f}")
    P()

    P(sep)
    P(" END OF REPORT")
    P(sep)

    report_text = "\n".join(report_lines)

    # Print to console
    print(report_text)

    # Save report
    report_path = out_dir / "SFT_VALIDATION_REPORT.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report_text)
    print(f"\n Report saved: {report_path}")

    # Save detailed JSON results
    json_path = out_dir / "validation_results.json"
    summary = {
        "meta": {
            "checkpoint": args.ckpt,
            "val_source": args.val_json,
            "total_val_samples": len(all_val),
            "n_tested": len(examples),
            "device": device,
            "max_len": args.max_len,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.rep_pen,
            "timestamp": datetime.now().isoformat(),
        },
        "overall": {
            "avg_loss": round(avg(all_losses), 4),
            "avg_perplexity": round(avg(all_ppls), 2),
            "median_perplexity": round(sorted(all_ppls)[len(all_ppls)//2], 2) if all_ppls else None,
            "avg_token_accuracy": round(avg(all_accs), 4),
            "avg_bleu_1": round(avg(all_b1), 4),
            "avg_bleu_2": round(avg(all_b2), 4),
            "avg_repetition_ratio": round(avg(all_reps), 4),
            "avg_gen_length_words": round(avg(all_lens), 1),
            "n_empty": n_empty,
            "n_repetitive": n_repetitive,
            "n_truncated": n_truncated,
            "grade": grade,
        },
        "category_breakdown": {
            cat: {
                "count": len(m["losses"]),
                "avg_loss": round(avg(m["losses"]), 4),
                "avg_perplexity": round(avg(m["perplexities"]), 2),
                "avg_token_accuracy": round(avg(m["accuracies"]), 4),
                "avg_bleu_1": round(avg(m["bleu1"]), 4),
                "avg_bleu_2": round(avg(m["bleu2"]), 4),
            }
            for cat, m in cat_metrics.items()
        },
        "per_example": results,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f" JSON results: {json_path}")
704
+
705
+
706
if __name__ == "__main__":
    main()  # CLI entry point; see main() for arguments.