| """
|
| LUNA 100M β SFT Validation on Complex Examples
|
| ================================================
|
| Selects ~100 complex examples from the SFT validation set,
|
| runs the fine-tuned model on each, and produces a detailed report.
|
|
|
| Metrics computed:
|
| - Per-sample cross-entropy loss (prompt-masked) & perplexity
|
| - Token-level accuracy on the output portion
|
| - BLEU-1/2 (word overlap with reference output)
|
| - Repetition ratio (degeneration detection)
|
| - Response length stats
|
| - Category breakdown (coding, explanation, analysis, creative, how-to, identity)
|
| - Overall pass/fail grading
|
|
|
| Usage:
|
| python validate_sft.py
|
| python validate_sft.py --ckpt "Base/out/sft/model.pth" --val_json "Base/Datasets/sft_clean/val.json"
|
| """
|
|
|
| import os, sys, json, math, time, argparse, re
|
| from pathlib import Path
|
| from collections import Counter, defaultdict
|
| from datetime import datetime
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position-embedding (RoPE) tables.

    The cos/sin tables for every position up to ``max_seq_len`` are built
    once at construction; ``forward`` just slices the cached buffers.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE frequency schedule: theta_j = 10000^(-2j/dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = torch.pow(10000.0, -exponents)
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        # Outer product position x frequency -> (max_seq_len, dim/2) angles.
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate along the last dim so the table covers all `dim` channels.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables for the first ``seq_len`` positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
|
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    Helper for the rotary-embedding rotation identity.
    """
    parts = x.chunk(2, dim=-1)
    return torch.cat((parts[1].neg(), parts[0]), dim=-1)
|
|
|
|
|
def apply_rotary(x, cos, sin):
    """Apply the rotary rotation to ``x`` using precomputed cos/sin tables.

    ``cos``/``sin`` are broadcast over the leading batch and head dimensions
    of ``x`` (assumes x is (B, H, T, rot_dim) — TODO confirm with callers).
    """
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    return x * cos_b + rotate_half(x) * sin_b
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first ``rotary_pct`` fraction of each head's channels receives
    rotary position encoding; the remaining channels are left unrotated
    (GPT-NeoX / Pythia-style partial rotary).
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Number of leading per-head channels that get the rotary rotation.
        self.rot_dim = int(self.head_dim * rotary_pct)
        # Fused QKV projection and output projection.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        # (B, T, 3, H, hd) -> (3, B, H, T, hd), then split into q, k, v.
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        cos, sin = self.rotary(T)
        # Rotate only the first rot_dim channels; keep the rest untouched.
        q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1)
        # Fused attention kernel with an implicit causal mask.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, H, T, hd) -> (B, T, C), then project back to the model width.
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C))
|
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward block: 4x expansion, GELU, projection back."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        h = self.fc(x)
        h = self.gelu(h)
        return self.proj(h)
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        # LayerNorm is applied before each sub-layer (pre-norm residual).
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
|
|
|
|
class LUNAModel(nn.Module):
    """Decoder-only transformer language model with tied embeddings.

    Defaults match the LUNA 100M configuration used by the checkpoint this
    script loads (vocab 50304, 10 layers, width 768, 12 heads).
    """

    def __init__(self, vocab_size=50304, block_size=1024,
                 n_layer=10, n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, block_size) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output head shares the token-embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids (B, T) to next-token logits (B, T, vocab_size)."""
        hidden = self.wte(idx)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(self.ln_f(hidden))
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, repetition_penalty=1.0, device="cpu"):
    """Autoregressively sample up to ``max_new`` tokens from ``model``.

    Supports temperature scaling, top-k and top-p (nucleus) filtering, and a
    simple repetition penalty. Assumes batch size 1 throughout (indexes row
    0 everywhere). Generation stops early when token id 0 is produced —
    NOTE(review): treats id 0 as end-of-sequence; confirm this matches the
    tokenizer's EOS/pad id.

    Returns the list of newly generated token ids (prompt excluded).
    """
    ids = input_ids.to(device)
    generated = []
    for _ in range(max_new):
        # Crop context to the model's window; keep only the last position's logits.
        logits = model(ids[:, -model.block_size:])[:, -1, :]
        if repetition_penalty != 1.0:
            # Penalize every token already in the sequence: shrink positive
            # logits, amplify negative ones (CTRL-style penalty).
            for tok_id in set(ids[0].tolist()):
                if logits[0, tok_id] > 0:
                    logits[0, tok_id] /= repetition_penalty
                else:
                    logits[0, tok_id] *= repetition_penalty
        if temperature < 1e-6:
            # Near-zero temperature: effectively greedy decoding.
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            if top_k > 0:
                # Zero everything strictly below the k-th largest probability
                # (ties with the k-th value are kept), then renormalize.
                kval = min(top_k, probs.size(-1))
                topk_vals, _ = torch.topk(probs, kval)
                probs[probs < topk_vals[:, [-1]]] = 0.0
                probs /= probs.sum()
            if top_p < 1.0:
                # Nucleus sampling: keep the smallest sorted prefix whose
                # cumulative mass exceeds top_p (the boundary token included).
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = torch.cumsum(sorted_probs, dim=-1)
                mask = cumsum - sorted_probs > top_p
                sorted_probs[mask] = 0.0
                sorted_probs /= sorted_probs.sum()
                next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = torch.multinomial(probs[0], 1)
        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        generated.append(next_token.item())
        if next_token.item() == 0:
            break
    return generated
|
|
|
|
|
|
|
|
|
def format_prompt(instruction, inp=""):
    """Build the Alpaca-style prompt used during SFT training.

    Emits an Instruction section when the instruction is non-empty, an Input
    section when the input is non-empty (or when both are empty), and always
    ends with the Response header.
    """
    inst = instruction.strip()
    inp = inp.strip()
    sections = []
    if inst:
        sections.append(f"### Instruction:\n{inst}\n\n")
    if inp or not inst:
        sections.append(f"### Input:\n{inp}\n\n")
    sections.append("### Response:\n")
    return "".join(sections)
|
|
|
|
|
|
|
|
|
# Keywords whose presence suggests a multi-step / reasoning-heavy example.
COMPLEXITY_KEYWORDS = [
    "step", "first", "second", "then", "next", "finally",
    "because", "however", "therefore", "explain", "analyze",
    "compare", "evaluate", "describe", "discuss", "provide",
    "example", "detail", "elaborate", "summarize",
]


def complexity_score(entry):
    """Heuristic complexity score for one SFT entry.

    Combines total text length, output length, a flat bonus for having a
    substantial input field, and a per-keyword bonus. Higher = more complex.
    """
    inst = entry.get("instruction", "")
    inp = entry.get("input", "")
    out = entry.get("output", "")
    haystack = f"{inst} {inp} {out}".lower()
    keyword_hits = sum(kw in haystack for kw in COMPLEXITY_KEYWORDS)
    input_bonus = 500 if len(inp) > 20 else 0
    char_total = len(inst) + len(inp) + len(out)
    return char_total * 0.3 + len(out) * 0.5 + input_bonus + keyword_hits * 200
|
|
|
|
|
def categorize(instruction):
    """Assign a coarse task category to an instruction by keyword matching.

    Rules are checked in priority order (coding before identity before
    explanation, ...); the first rule with any matching substring wins, and
    "other" is returned when nothing matches.
    """
    text = instruction.lower()
    rules = [
        ("coding", ("code", "python", "java", "swift", "function", "program",
                    "algorithm", "script", "sql", "html", "css")),
        ("identity", ("who are you", "your name", "who created", "asterizer",
                      "luna", "are you an ai")),
        ("explanation", ("explain", "what is", "define", "describe", "meaning of")),
        ("analysis", ("analyze", "compare", "evaluate", "assess", "critique")),
        ("creative", ("write", "create", "generate", "compose", "draft",
                      "poem", "story", "essay")),
        ("how-to", ("how", "step", "guide", "method", "procedure", "tutorial")),
    ]
    for label, needles in rules:
        if any(needle in text for needle in needles):
            return label
    return "other"
|
|
|
|
|
def select_complex_examples(data, n=100):
    """Return the ``n`` entries with the highest complexity_score.

    Ties are broken by the larger original index first, matching a reverse
    sort on (score, index) tuples.
    """
    order = sorted(range(len(data)),
                   key=lambda i: (complexity_score(data[i]), i),
                   reverse=True)
    return [data[i] for i in order[:n]]
|
|
|
|
|
|
|
|
|
def compute_bleu(reference, hypothesis, max_n=2):
    """Simple BLEU-1 and BLEU-2 (word-level, no brevity penalty).

    Returns a dict mapping "bleu_n" to the clipped n-gram precision of the
    hypothesis against the reference; all zeros when either side is empty.
    """
    ref_words = reference.lower().split()
    hyp_words = hypothesis.lower().split()
    if not ref_words or not hyp_words:
        return {f"bleu_{n}": 0.0 for n in range(1, max_n + 1)}

    def ngram_counts(words, n):
        # Multiset of all n-grams in `words`.
        return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))

    scores = {}
    for n in range(1, max_n + 1):
        ref_counts = ngram_counts(ref_words, n)
        hyp_counts = ngram_counts(hyp_words, n)
        # Clipped counts: each hypothesis n-gram credited at most as often
        # as it appears in the reference.
        clipped = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        scores[f"bleu_{n}"] = clipped / max(sum(hyp_counts.values()), 1)
    return scores
|
|
|
|
|
def repetition_ratio(text):
    """Fraction of repeated trigrams in the text (higher = more degenerate).

    Returns 0.0 for texts shorter than four words, where the measure is
    meaningless.
    """
    words = text.lower().split()
    if len(words) < 4:
        return 0.0
    trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
    if not trigrams:
        return 0.0
    distinct = len(set(trigrams))
    return 1.0 - distinct / len(trigrams)
|
|
|
|
|
@torch.no_grad()
def compute_loss_and_accuracy(model, tokenizer, entry, max_len, device):
    """Compute prompt-masked CE loss & token accuracy for one example.

    Runs prompt+response through the model once (teacher forcing) and
    averages loss / greedy-argmax accuracy only over positions whose target
    is a response token; prompt positions are masked out.

    Returns:
        (loss, accuracy, n_response_tokens); (inf, 0.0, 0) when truncation
        leaves no response tokens to score.
    """
    prompt = format_prompt(entry.get("instruction", ""), entry.get("input", ""))
    response = entry.get("output", "").strip()

    prompt_ids = tokenizer.encode(prompt)
    # NOTE(review): falls back to id 0 when eos_token_id is None or 0 —
    # assumes 0 is a safe EOS substitute for this tokenizer; confirm.
    response_ids = tokenizer.encode(response) + [tokenizer.eos_token_id or 0]
    total_ids = prompt_ids + response_ids

    if len(total_ids) > max_len:
        # Truncate to the context window and force the last token to EOS.
        total_ids = total_ids[:max_len]
        total_ids[-1] = tokenizer.eos_token_id or 0
        prompt_len = min(len(prompt_ids), max_len)
    else:
        prompt_len = len(prompt_ids)

    input_tensor = torch.tensor([total_ids], dtype=torch.long, device=device)
    logits = model(input_tensor)

    # Standard next-token shift: logits at position t predict token t+1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_targets = input_tensor[:, 1:].contiguous()

    # Mask selecting shifted positions whose target lies in the response:
    # shifted index prompt_len-1 predicts the first response token.
    mask = torch.zeros(shift_targets.shape, dtype=torch.float, device=device)
    resp_start = max(prompt_len - 1, 0)
    resp_end = len(total_ids) - 1
    mask[0, resp_start:resp_end] = 1.0

    if mask.sum() == 0:
        # The whole response was truncated away: nothing to score.
        return float("inf"), 0.0, 0

    # Per-token loss, reshaped back to (1, T-1) so the mask can be applied.
    per_token_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1),
        reduction="none"
    ).view(shift_targets.shape)

    masked_loss = (per_token_loss * mask).sum() / mask.sum()

    # Greedy token accuracy over the same masked positions.
    preds = shift_logits.argmax(dim=-1)
    correct = ((preds == shift_targets).float() * mask).sum()
    total_resp = mask.sum()
    accuracy = (correct / total_resp).item() if total_resp > 0 else 0.0

    return masked_loss.item(), accuracy, int(total_resp.item())
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point.

    Loads the fine-tuned checkpoint and tokenizer, selects the most complex
    validation examples, computes teacher-forced metrics and free-running
    generations for each, then writes a plain-text report and a JSON summary
    to ``--out_dir``.
    """
    parser = argparse.ArgumentParser(description="LUNA SFT β Complex Example Validation")
    parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--val_json", default="Base/Datasets/sft_clean/val.json")
    parser.add_argument("--n_examples", type=int, default=100)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--max_new", type=int, default=150)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--rep_pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--out_dir", default="Base/out/sft/validation_report_v2")
    args = parser.parse_args()

    # Resolve "auto": prefer CUDA when available, otherwise CPU.
    device = "cuda" if args.device == "auto" and torch.cuda.is_available() else args.device
    if device == "auto":
        device = "cpu"

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sep = "=" * 72
    print(f"\n{sep}")
    print(" LUNA 100M β SFT VALIDATION (Complex Examples)")
    print(sep)
    print(f" Checkpoint : {args.ckpt}")
    print(f" Val data : {args.val_json}")
    print(f" N examples : {args.n_examples}")
    print(f" Device : {device}")
    print(f" Max seq : {args.max_len}")
    print(f" Temperature: {args.temperature}")
    print(f" Top-k : {args.top_k}")
    print(f" Top-p : {args.top_p}")
    print(f" Rep penalty: {args.rep_pen}")
    print(sep)

    # ---- [1/5] Model -----------------------------------------------------
    print("\n[1/5] Loading model...")
    t0 = time.time()
    # weights_only=True prevents torch.load from executing arbitrary pickle code.
    state_dict = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]  # unwrap trainer-style checkpoints
    model = LUNAModel()
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Model loaded: {n_params:,} params ({time.time()-t0:.1f}s)")

    # ---- [2/5] Tokenizer -------------------------------------------------
    print("[2/5] Loading tokenizer...")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f" Tokenizer: vocab_size={tokenizer.vocab_size}")

    # ---- [3/5] Example selection ----------------------------------------
    print(f"[3/5] Selecting {args.n_examples} most complex examples from val set...")
    with open(args.val_json, "r", encoding="utf-8") as f:
        all_val = json.load(f)
    print(f" Total val samples: {len(all_val)}")
    examples = select_complex_examples(all_val, args.n_examples)
    print(f" Selected: {len(examples)} complex examples")

    # FIX: use .get() with a default, like every other entry access in this
    # file, so an entry missing "instruction" cannot raise KeyError here.
    cat_counts = Counter(categorize(e.get("instruction", "")) for e in examples)
    print(f" Categories: {dict(cat_counts)}")

    # ---- [4/5] Per-example evaluation -----------------------------------
    print(f"\n[4/5] Running validation ({len(examples)} examples)...")
    print("-" * 72)

    results = []
    cat_metrics = defaultdict(lambda: {"losses": [], "perplexities": [],
                                       "accuracies": [], "bleu1": [],
                                       "bleu2": [], "rep_ratios": [],
                                       "gen_lens": []})

    for i, entry in enumerate(examples):
        inst = entry.get("instruction", "")
        inp = entry.get("input", "")
        ref_output = entry.get("output", "")
        category = categorize(inst)

        # Teacher-forced metrics against the reference continuation.
        loss, tok_acc, n_resp_tokens = compute_loss_and_accuracy(
            model, tokenizer, entry, args.max_len, device
        )
        ppl = math.exp(min(loss, 20))  # cap exponent so inf loss stays finite

        # Free-running generation for surface-quality metrics.
        prompt = format_prompt(inst, inp)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
        gen_tokens = generate(
            model, prompt_ids,
            max_new=args.max_new,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        # Drop anything the model emitted after a spurious "### " section header.
        if "### " in gen_text:
            gen_text = gen_text.split("### ")[0].strip()

        bleu = compute_bleu(ref_output, gen_text)
        rep = repetition_ratio(gen_text)
        gen_words = len(gen_text.split())

        # Failure flags used by grading and the failure-analysis section.
        is_empty = len(gen_text.strip()) < 5
        is_repetitive = rep > 0.5
        is_truncated = len(gen_tokens) >= args.max_new

        result = {
            "index": i,
            "category": category,
            "instruction": inst[:200],
            "input_preview": inp[:100] if inp else "",
            "reference_preview": ref_output[:200],
            "generated_preview": gen_text[:300],
            "loss": round(loss, 4),
            "perplexity": round(ppl, 2),
            "token_accuracy": round(tok_acc, 4),
            "bleu_1": round(bleu["bleu_1"], 4),
            "bleu_2": round(bleu["bleu_2"], 4),
            "repetition_ratio": round(rep, 4),
            "generated_words": gen_words,
            "resp_tokens": n_resp_tokens,
            "is_empty": is_empty,
            "is_repetitive": is_repetitive,
            "is_truncated": is_truncated,
        }
        results.append(result)

        cat_metrics[category]["losses"].append(loss)
        cat_metrics[category]["perplexities"].append(ppl)
        cat_metrics[category]["accuracies"].append(tok_acc)
        cat_metrics[category]["bleu1"].append(bleu["bleu_1"])
        cat_metrics[category]["bleu2"].append(bleu["bleu_2"])
        cat_metrics[category]["rep_ratios"].append(rep)
        cat_metrics[category]["gen_lens"].append(gen_words)

        status = ""
        if is_empty:
            status = " [EMPTY]"
        elif is_repetitive:
            status = " [REPETITIVE]"
        elif is_truncated:
            status = " [TRUNCATED]"

        # Progress line every 5 examples (and for the first one).
        if (i + 1) % 5 == 0 or i == 0:
            print(f" [{i+1:3d}/{len(examples)}] loss={loss:.3f} ppl={ppl:.1f} "
                  f"acc={tok_acc:.3f} B1={bleu['bleu_1']:.3f} "
                  f"rep={rep:.3f} words={gen_words}{status}")

    # ---- [5/5] Report ----------------------------------------------------
    print(f"\n[5/5] Generating report...")

    # Filter out sentinel values (inf loss / capped ppl) from the aggregates.
    all_losses = [r["loss"] for r in results if r["loss"] < float("inf")]
    all_ppls = [r["perplexity"] for r in results if r["perplexity"] < 1e6]
    all_accs = [r["token_accuracy"] for r in results]
    all_b1 = [r["bleu_1"] for r in results]
    all_b2 = [r["bleu_2"] for r in results]
    all_reps = [r["repetition_ratio"] for r in results]
    all_lens = [r["generated_words"] for r in results]

    n_empty = sum(1 for r in results if r["is_empty"])
    n_repetitive = sum(1 for r in results if r["is_repetitive"])
    n_truncated = sum(1 for r in results if r["is_truncated"])

    def avg(xs):
        """Mean of xs, or 0.0 for an empty list."""
        return sum(xs) / len(xs) if xs else 0.0

    report_lines = []

    def P(s=""):
        """Append one line to the report buffer."""
        report_lines.append(s)

    P(sep)
    P(" LUNA 100M β SFT VALIDATION REPORT")
    P(f" Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    P(sep)
    P()
    P(f" Checkpoint : {args.ckpt}")
    P(f" Val source : {args.val_json} ({len(all_val)} total samples)")
    P(f" Examples tested: {len(examples)} (top complex by scoring)")
    P(f" Device : {device}")
    P(f" Max seq len : {args.max_len}")
    P(f" Gen temp : {args.temperature}")
    P(f" Gen top_k : {args.top_k}")
    P(f" Gen top_p : {args.top_p}")
    P(f" Gen rep_pen : {args.rep_pen}")
    P(f" Gen max tokens: {args.max_new}")
    P()

    P("=" * 72)
    P(" OVERALL METRICS")
    P("=" * 72)
    P(f" Avg Loss (CE) : {avg(all_losses):.4f}")
    P(f" Avg Perplexity : {avg(all_ppls):.2f}")
    P(f" Median Perplexity : {sorted(all_ppls)[len(all_ppls)//2]:.2f}" if all_ppls else " Median Perplexity : N/A")
    P(f" Avg Token Accuracy : {avg(all_accs):.4f} ({avg(all_accs)*100:.1f}%)")
    P(f" Avg BLEU-1 : {avg(all_b1):.4f}")
    P(f" Avg BLEU-2 : {avg(all_b2):.4f}")
    P(f" Avg Repetition Ratio : {avg(all_reps):.4f}")
    P(f" Avg Gen Length (words): {avg(all_lens):.1f}")
    P()
    P(f" Empty responses : {n_empty}/{len(results)}")
    P(f" Repetitive responses : {n_repetitive}/{len(results)}")
    P(f" Truncated responses : {n_truncated}/{len(results)}")
    P()

    # Letter grades happen to sort by severity lexicographically
    # ("A" < "A-" < "B" < "C" < "D"), so max() downgrades and never upgrades.
    grade = "A"
    grade_notes = []
    if avg(all_ppls) > 50:
        grade = "C"
        grade_notes.append("high perplexity (>50)")
    elif avg(all_ppls) > 20:
        grade = "B"
        grade_notes.append("moderate perplexity (>20)")
    if n_empty > 10:
        # BUG FIX: was `"D" if grade > "C" else grade`, whose condition is
        # only true when the grade is already D — it never actually
        # downgraded A/B/C to D. max() applies the intended downgrade.
        grade = max(grade, "D")
        grade_notes.append(f"{n_empty} empty responses")
    elif n_empty > 3:
        grade = max(grade, "C")
        grade_notes.append(f"{n_empty} empty responses")
    if n_repetitive > 15:
        grade = max(grade, "C")
        grade_notes.append(f"{n_repetitive} repetitive responses")
    if avg(all_b1) < 0.05:
        grade = max(grade, "C")
        grade_notes.append("very low BLEU-1")
    if avg(all_accs) > 0.4:
        grade_notes.append("strong token accuracy")
    if avg(all_accs) > 0.5:
        # Strong accuracy bumps a B up to A-.
        if grade == "B":
            grade = "A-"

    P(f" OVERALL GRADE: {grade}")
    if grade_notes:
        P(f" Notes: {'; '.join(grade_notes)}")
    P()

    P("=" * 72)
    P(" CATEGORY BREAKDOWN")
    P("=" * 72)
    P(f" {'Category':<14} {'Count':>5} {'Avg Loss':>9} {'Avg PPL':>9} "
      f"{'Avg Acc':>8} {'BLEU-1':>7} {'BLEU-2':>7} {'Rep %':>6}")
    P(" " + "-" * 68)
    for cat in sorted(cat_metrics.keys()):
        m = cat_metrics[cat]
        cnt = len(m["losses"])
        P(f" {cat:<14} {cnt:>5} {avg(m['losses']):>9.4f} "
          f"{avg(m['perplexities']):>9.2f} {avg(m['accuracies']):>8.4f} "
          f"{avg(m['bleu1']):>7.4f} {avg(m['bleu2']):>7.4f} "
          f"{avg(m['rep_ratios'])*100:>5.1f}%")
    P()

    P("=" * 72)
    P(" TOP 5 BEST (lowest perplexity)")
    P("=" * 72)
    by_ppl = sorted(results, key=lambda r: r["perplexity"])
    for r in by_ppl[:5]:
        P(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f" Q: {r['instruction'][:80]}")
        P(f" A: {r['generated_preview'][:100]}")
        P()

    P("=" * 72)
    P(" TOP 5 WORST (highest perplexity)")
    P("=" * 72)
    for r in by_ppl[-5:]:
        P(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f" Q: {r['instruction'][:80]}")
        P(f" A: {r['generated_preview'][:100]}")
        P()

    failures = [r for r in results if r["is_empty"] or r["is_repetitive"]]
    if failures:
        P("=" * 72)
        P(f" FAILURE ANALYSIS ({len(failures)} problematic responses)")
        P("=" * 72)
        for r in failures[:10]:
            flags = []
            if r["is_empty"]:
                flags.append("EMPTY")
            if r["is_repetitive"]:
                flags.append("REPETITIVE")
            P(f" [{r['index']:3d}] {' | '.join(flags)} [{r['category']}]")
            P(f" Q: {r['instruction'][:80]}")
            P(f" A: {r['generated_preview'][:120]}")
            P()

    P("=" * 72)
    P(" PERPLEXITY DISTRIBUTION")
    P("=" * 72)
    buckets = [(0, 5), (5, 10), (10, 20), (20, 50), (50, 100),
               (100, 500), (500, float("inf"))]
    for lo, hi in buckets:
        cnt = sum(1 for p in all_ppls if lo <= p < hi)
        bar = "#" * cnt
        label = f"{lo}-{hi}" if hi != float("inf") else f"{lo}+"
        P(f" {label:>8}: {cnt:>3} {bar}")
    P()

    P("=" * 72)
    P(" SAMPLE GENERATIONS (10 diverse examples)")
    P("=" * 72)

    # Evenly spaced indices across the result list, capped at 10.
    sample_indices = list(range(0, len(results), max(1, len(results) // 10)))[:10]
    for si in sample_indices:
        r = results[si]
        P(f"\n --- Example {r['index']+1} [{r['category']}] ---")
        P(f" Instruction: {r['instruction'][:150]}")
        if r["input_preview"]:
            P(f" Input: {r['input_preview'][:100]}")
        P(f" Reference: {r['reference_preview'][:200]}")
        P(f" Generated: {r['generated_preview'][:300]}")
        P(f" Loss={r['loss']:.4f} PPL={r['perplexity']:.2f} "
          f"Acc={r['token_accuracy']:.4f} BLEU-1={r['bleu_1']:.4f} "
          f"Rep={r['repetition_ratio']:.4f}")
        P()

    P(sep)
    P(" END OF REPORT")
    P(sep)

    report_text = "\n".join(report_lines)
    print(report_text)

    # Persist the plain-text report.
    report_path = out_dir / "SFT_VALIDATION_REPORT.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report_text)
    print(f"\n Report saved: {report_path}")

    # Persist the machine-readable summary.
    json_path = out_dir / "validation_results.json"
    summary = {
        "meta": {
            "checkpoint": args.ckpt,
            "val_source": args.val_json,
            "total_val_samples": len(all_val),
            "n_tested": len(examples),
            "device": device,
            "max_len": args.max_len,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.rep_pen,
            "timestamp": datetime.now().isoformat(),
        },
        "overall": {
            "avg_loss": round(avg(all_losses), 4),
            "avg_perplexity": round(avg(all_ppls), 2),
            "median_perplexity": round(sorted(all_ppls)[len(all_ppls)//2], 2) if all_ppls else None,
            "avg_token_accuracy": round(avg(all_accs), 4),
            "avg_bleu_1": round(avg(all_b1), 4),
            "avg_bleu_2": round(avg(all_b2), 4),
            "avg_repetition_ratio": round(avg(all_reps), 4),
            "avg_gen_length_words": round(avg(all_lens), 1),
            "n_empty": n_empty,
            "n_repetitive": n_repetitive,
            "n_truncated": n_truncated,
            "grade": grade,
        },
        "category_breakdown": {
            cat: {
                "count": len(m["losses"]),
                "avg_loss": round(avg(m["losses"]), 4),
                "avg_perplexity": round(avg(m["perplexities"]), 2),
                "avg_token_accuracy": round(avg(m["accuracies"]), 4),
                "avg_bleu_1": round(avg(m["bleu1"]), 4),
                "avg_bleu_2": round(avg(m["bleu2"]), 4),
            }
            for cat, m in cat_metrics.items()
        },
        "per_example": results,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f" JSON results: {json_path}")


if __name__ == "__main__":
    main()
|
|
|