"""
LUNA 100M — SFT Validation on Complex Examples
================================================
Selects ~100 complex examples from the SFT validation set, runs the
fine-tuned model on each, and produces a detailed report.

Metrics computed:
  - Per-sample cross-entropy loss (prompt-masked) & perplexity
  - Token-level accuracy on the output portion
  - BLEU-1/2 (word overlap with reference output)
  - Repetition ratio (degeneration detection)
  - Response length stats
  - Category breakdown (coding, explanation, analysis, creative, how-to, identity)
  - Overall pass/fail grading

Usage:
    python validate_sft.py
    python validate_sft.py --ckpt "Base/out/sft/model.pth" --val_json "Base/Datasets/sft_clean/val.json"
"""

import os
import sys
import json
import math
import time
import argparse
import re
import statistics
from pathlib import Path
from collections import Counter, defaultdict
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F


# Losses are clamped to this value before exponentiation so that
# math.exp() cannot overflow when computing perplexity.
LOSS_CAP_FOR_PPL = 20

# Perplexities at or above this value are treated as outliers and are
# excluded from averaged / median perplexity figures.
PPL_OUTLIER_CUTOFF = 1e6


# ─── Model (identical to sft_train.py / chat.py) ─────────────────────────────

class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings."""

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables for the first ``seq_len`` positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def rotate_half(x):
    """Split the last dim in two halves (x1, x2) and return cat(-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x, cos, sin):
    """Apply the rotary transform to ``x`` using 2-D ``cos``/``sin`` tables.

    Two leading singleton dims are added to the tables so they broadcast
    against ``x``.
    """
    c = cos.unsqueeze(0).unsqueeze(0)
    s = sin.unsqueeze(0).unsqueeze(0)
    return x * c + rotate_half(x) * s


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first ``rot_dim = int(head_dim * rotary_pct)`` features of each
    head's query/key are rotated; the remainder pass through unchanged.
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        # (B, T, 3*C) -> (3, B, n_head, T, head_dim)
        qkv = (
            self.c_attn(x)
            .reshape(B, T, 3, self.n_head, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        cos, sin = self.rotary(T)
        q = torch.cat(
            [apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]],
            dim=-1,
        )
        k = torch.cat(
            [apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]],
            dim=-1,
        )
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C))


class MLP(nn.Module):
    """Position-wise feed-forward block: Linear(4x) -> GELU -> Linear."""

    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        return self.proj(self.gelu(self.fc(x)))


class Block(nn.Module):
    """Pre-LayerNorm transformer block (attention + MLP, both residual)."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class LUNAModel(nn.Module):
    """Decoder-only language model with tied input/output embeddings."""

    def __init__(self, vocab_size=50304, block_size=1024, n_layer=10,
                 n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, block_size) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids ``idx`` to next-token logits, one row per position."""
        x = self.wte(idx)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.ln_f(x))


# ─── Generation ───────────────────────────────────────────────────────────────

@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7, top_p=0.9,
             top_k=40, repetition_penalty=1.0, device="cpu", eos_token_id=0):
    """Autoregressively sample up to ``max_new`` tokens for a single prompt.

    Args:
        model: Callable mapping a ``(1, T)`` id tensor to ``(1, T, V)`` logits
            and exposing a ``block_size`` attribute (context window).
        input_ids: ``(1, T)`` tensor of prompt token ids (batch size 1 only;
            the sampling code indexes row 0).
        max_new: Maximum number of tokens to generate.
        temperature: Softmax temperature; values below ``1e-6`` select greedy
            decoding.
        top_p: Nucleus sampling threshold (``>= 1.0`` disables it).
        top_k: Keep only the ``top_k`` most likely tokens (``<= 0`` disables).
        repetition_penalty: Penalty applied to logits of tokens already in the
            context (``1.0`` disables it).
        device: Device the prompt is moved to.
        eos_token_id: Token id that terminates generation. Defaults to 0,
            the value that was previously hard-coded.

    Returns:
        list[int]: Generated token ids, including the EOS token if emitted.
    """
    ids = input_ids.to(device)
    generated = []

    for _ in range(max_new):
        # Only the last block_size tokens fit in the model's context window.
        logits = model(ids[:, -model.block_size:])[:, -1, :]

        if repetition_penalty != 1.0:
            # Push already-seen tokens towards "less likely": shrink positive
            # logits, amplify non-positive ones.
            seen = torch.unique(ids[0])
            seen_logits = logits[0, seen]
            logits[0, seen] = torch.where(
                seen_logits > 0,
                seen_logits / repetition_penalty,
                seen_logits * repetition_penalty,
            )

        if temperature < 1e-6:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)

            if top_k > 0:
                kval = min(top_k, probs.size(-1))
                topk_vals, _ = torch.topk(probs, kval)
                probs[probs < topk_vals[:, [-1]]] = 0.0
                probs /= probs.sum()

            if top_p < 1.0:
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass *before* it already exceeds
                # top_p, so the token that crosses the threshold is kept.
                mask = cumsum - sorted_probs > top_p
                sorted_probs[mask] = 0.0
                sorted_probs /= sorted_probs.sum()
                next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = torch.multinomial(probs[0], 1)

        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        token_id = next_token.item()
        generated.append(token_id)
        if token_id == eos_token_id:
            break

    return generated


# ─── Prompt formatting (matches sft_train.py) ────────────────────────────────

def format_prompt(instruction, inp=""):
    """Build the "### Instruction / ### Input / ### Response" prompt.

    ``None`` is treated like an empty string for either field. When the
    instruction is empty, only the input section is emitted.
    """
    inst = (instruction or "").strip()
    inp = (inp or "").strip()
    if inst and inp:
        return f"### Instruction:\n{inst}\n\n### Input:\n{inp}\n\n### Response:\n"
    elif inst:
        return f"### Instruction:\n{inst}\n\n### Response:\n"
    else:
        return f"### Input:\n{inp}\n\n### Response:\n"


# ─── Complexity scoring & selection ───────────────────────────────────────────

COMPLEXITY_KEYWORDS = [
    "step", "first", "second", "then", "next", "finally",
    "because", "however", "therefore", "explain", "analyze",
    "compare", "evaluate", "describe", "discuss", "provide",
    "example", "detail", "elaborate", "summarize",
]


def complexity_score(entry):
    """Heuristic complexity score for one SFT example (higher = more complex).

    Combines total character length, output length, a bonus when the input
    field is longer than 20 characters, and the number of distinct
    ``COMPLEXITY_KEYWORDS`` found as substrings of the lower-cased text.
    Missing or null fields count as empty strings.
    """
    inst = entry.get("instruction") or ""
    inp = entry.get("input") or ""
    out = entry.get("output") or ""
    total_text = (inst + " " + inp + " " + out).lower()
    total_len = len(inst) + len(inp) + len(out)
    has_input = 1 if len(inp) > 20 else 0
    kw_count = sum(1 for w in COMPLEXITY_KEYWORDS if w in total_text)
    return total_len * 0.3 + len(out) * 0.5 + has_input * 500 + kw_count * 200


def categorize(instruction):
    """Assign a coarse category to an instruction by substring matching.

    Categories are tested in a fixed priority order (coding, identity,
    explanation, analysis, creative, how-to); the first match wins and
    ``"other"`` is returned when nothing matches.
    """
    inst = (instruction or "").lower()
    rules = (
        ("coding", ["code", "python", "java", "swift", "function", "program",
                    "algorithm", "script", "sql", "html", "css"]),
        ("identity", ["who are you", "your name", "who created", "asterizer",
                      "luna", "are you an ai"]),
        ("explanation", ["explain", "what is", "define", "describe",
                         "meaning of"]),
        ("analysis", ["analyze", "compare", "evaluate", "assess", "critique"]),
        ("creative", ["write", "create", "generate", "compose", "draft",
                      "poem", "story", "essay"]),
        ("how-to", ["how", "step", "guide", "method", "procedure",
                    "tutorial"]),
    )
    for category, keywords in rules:
        if any(w in inst for w in keywords):
            return category
    return "other"


def select_complex_examples(data, n=100):
    """Return the ``n`` highest-scoring entries of ``data``, best first.

    Ties on score are broken by the higher original index (tuples of
    ``(score, index)`` are sorted in descending order).
    """
    scored = [(complexity_score(entry), i) for i, entry in enumerate(data)]
    scored.sort(reverse=True)
    return [data[idx] for _, idx in scored[:n]]


# ─── Metrics ──────────────────────────────────────────────────────────────────

def compute_bleu(reference, hypothesis, max_n=2):
    """Simple BLEU-1 and BLEU-2 (word-level, no brevity penalty).

    Returns a dict ``{"bleu_1": ..., "bleu_2": ...}`` of clipped n-gram
    precisions; all zeros when either text has no words.
    """
    ref_tokens = reference.lower().split()
    hyp_tokens = hypothesis.lower().split()
    if not hyp_tokens or not ref_tokens:
        return {f"bleu_{n}": 0.0 for n in range(1, max_n + 1)}

    scores = {}
    for n in range(1, max_n + 1):
        ref_ngrams = Counter(
            tuple(ref_tokens[i:i + n]) for i in range(len(ref_tokens) - n + 1)
        )
        hyp_ngrams = Counter(
            tuple(hyp_tokens[i:i + n]) for i in range(len(hyp_tokens) - n + 1)
        )
        clipped = sum(min(count, ref_ngrams[ng]) for ng, count in hyp_ngrams.items())
        # max(..., 1) guards the case where the hypothesis is shorter than n.
        total = max(sum(hyp_ngrams.values()), 1)
        scores[f"bleu_{n}"] = clipped / total
    return scores


def repetition_ratio(text):
    """Fraction of repeated trigrams in the text (higher = more degenerate)."""
    words = text.lower().split()
    if len(words) < 4:
        return 0.0
    trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
    if not trigrams:
        return 0.0
    return 1.0 - (len(set(trigrams)) / len(trigrams))


def finite_mean(values):
    """Mean of the finite values in ``values``; 0.0 when there are none.

    Non-finite entries (``inf``/``nan``) are skipped so that a single
    un-scorable example cannot turn an aggregate into ``inf``.
    """
    finite = [v for v in values if math.isfinite(v)]
    return sum(finite) / len(finite) if finite else 0.0


def to_json_safe(obj):
    """Return a copy of ``obj`` with non-finite floats replaced by ``None``.

    ``json.dump`` would otherwise emit ``Infinity``/``NaN``, which are not
    valid JSON. Dicts, lists and tuples are processed recursively (tuples
    become lists); every other value is returned unchanged.
    """
    if isinstance(obj, float):
        return obj if math.isfinite(obj) else None
    if isinstance(obj, dict):
        return {key: to_json_safe(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(value) for value in obj]
    return obj


def compute_grade(avg_ppl, n_empty, n_repetitive, avg_bleu1, avg_acc):
    """Derive the overall letter grade and explanatory notes.

    Args:
        avg_ppl: Average perplexity over the tested examples.
        n_empty: Number of empty generations.
        n_repetitive: Number of repetitive generations.
        avg_bleu1: Average BLEU-1 score.
        avg_acc: Average teacher-forced token accuracy.

    Returns:
        tuple[str, list[str]]: ``(grade, notes)`` where grade is one of
        ``"A"``, ``"A-"``, ``"B"``, ``"C"``, ``"D"``.
    """
    grade = "A"
    notes = []

    if avg_ppl > 50:
        grade = "C"
        notes.append("high perplexity (>50)")
    elif avg_ppl > 20:
        grade = "B"
        notes.append("moderate perplexity (>20)")

    if n_empty > 10:
        # "D" is the worst grade, so it always wins. (Previously it was only
        # assigned when the grade was already worse than "C", which could
        # never happen, making "D" unreachable.)
        grade = "D"
        notes.append(f"{n_empty} empty responses")
    elif n_empty > 3:
        # Single-letter grades sort lexicographically from best to worst,
        # so max() keeps the worse of the two.
        grade = max(grade, "C")
        notes.append(f"{n_empty} empty responses")

    if n_repetitive > 15:
        grade = max(grade, "C")
        notes.append(f"{n_repetitive} repetitive responses")

    if avg_bleu1 < 0.05:
        grade = max(grade, "C")
        notes.append("very low BLEU-1")

    if avg_acc > 0.4:
        notes.append("strong token accuracy")
    if avg_acc > 0.5 and grade == "B":
        grade = "A-"

    return grade, notes


@torch.no_grad()
def compute_loss_and_accuracy(model, tokenizer, entry, max_len, device):
    """Compute prompt-masked CE loss & token accuracy for one example.

    The prompt and reference response are tokenized, concatenated (with an
    EOS token appended) and truncated to ``max_len``; loss and accuracy are
    measured only on positions whose target is a response token.

    Returns:
        tuple[float, float, int]: ``(loss, accuracy, n_response_tokens)``.
        ``(inf, 0.0, 0)`` is returned when truncation leaves no response
        tokens to score.
    """
    prompt = format_prompt(entry.get("instruction", ""), entry.get("input", ""))
    response = (entry.get("output") or "").strip()
    eos_id = tokenizer.eos_token_id or 0

    prompt_ids = tokenizer.encode(prompt)
    response_ids = tokenizer.encode(response) + [eos_id]
    total_ids = prompt_ids + response_ids

    if len(total_ids) > max_len:
        total_ids = total_ids[:max_len]
        total_ids[-1] = eos_id
        prompt_len = min(len(prompt_ids), max_len)
    else:
        prompt_len = len(prompt_ids)

    input_tensor = torch.tensor([total_ids], dtype=torch.long, device=device)
    logits = model(input_tensor)  # (1, T, V)

    # Shift for next-token prediction
    shift_logits = logits[:, :-1, :].contiguous()
    shift_targets = input_tensor[:, 1:].contiguous()

    # Build mask: only on response portion. Position prompt_len - 1 is the
    # first one whose *target* is a response token.
    mask = torch.zeros(shift_targets.shape, dtype=torch.float, device=device)
    resp_start = max(prompt_len - 1, 0)
    resp_end = len(total_ids) - 1
    mask[0, resp_start:resp_end] = 1.0

    total_resp = mask.sum()
    if total_resp == 0:
        return float("inf"), 0.0, 0

    per_token_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1),
        reduction="none",
    ).view(shift_targets.shape)
    masked_loss = (per_token_loss * mask).sum() / total_resp

    # Token accuracy on response portion
    preds = shift_logits.argmax(dim=-1)
    correct = ((preds == shift_targets).float() * mask).sum()
    accuracy = (correct / total_resp).item()

    return masked_loss.item(), accuracy, int(total_resp.item())


# ─── Main validation ─────────────────────────────────────────────────────────

def main():
    """Run the validation pipeline and write the text and JSON reports.

    Steps: parse CLI arguments, load the model and tokenizer, select the most
    complex validation examples, score and generate for each one, then print
    and save ``SFT_VALIDATION_REPORT.txt`` and ``validation_results.json``
    into ``--out_dir``.
    """
    parser = argparse.ArgumentParser(description="LUNA SFT — Complex Example Validation")
    parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--val_json", default="Base/Datasets/sft_clean/val.json")
    parser.add_argument("--n_examples", type=int, default=100)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--max_new", type=int, default=150)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--rep_pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--out_dir", default="Base/out/sft/validation_report_v2")
    args = parser.parse_args()

    # "auto" resolves to CUDA when available and to CPU otherwise.
    device = "cuda" if args.device == "auto" and torch.cuda.is_available() else args.device
    if device == "auto":
        device = "cpu"

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sep = "=" * 72
    print(f"\n{sep}")
    print(" LUNA 100M — SFT VALIDATION (Complex Examples)")
    print(sep)
    print(f" Checkpoint : {args.ckpt}")
    print(f" Val data : {args.val_json}")
    print(f" N examples : {args.n_examples}")
    print(f" Device : {device}")
    print(f" Max seq : {args.max_len}")
    print(f" Temperature: {args.temperature}")
    print(f" Top-k : {args.top_k}")
    print(f" Top-p : {args.top_p}")
    print(f" Rep penalty: {args.rep_pen}")
    print(sep)

    # ── Load model ────────────────────────────────────────────────────────────
    print("\n[1/5] Loading model...")
    t0 = time.time()
    state_dict = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]
    model = LUNAModel()
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Model loaded: {n_params:,} params ({time.time()-t0:.1f}s)")

    # The rotary tables only cover block_size positions, so a longer
    # teacher-forced sequence cannot be evaluated.
    if args.max_len > model.block_size:
        print(f" Note: max_len {args.max_len} exceeds model block size; "
              f"using {model.block_size}")
        args.max_len = model.block_size

    # ── Load tokenizer ────────────────────────────────────────────────────────
    print("[2/5] Loading tokenizer...")
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f" Tokenizer: vocab_size={tokenizer.vocab_size}")
    # Same EOS fallback as compute_loss_and_accuracy.
    eos_id = tokenizer.eos_token_id or 0

    # ── Select complex examples ───────────────────────────────────────────────
    print(f"[3/5] Selecting {args.n_examples} most complex examples from val set...")
    with open(args.val_json, "r", encoding="utf-8") as f:
        all_val = json.load(f)
    print(f" Total val samples: {len(all_val)}")

    examples = select_complex_examples(all_val, args.n_examples)
    print(f" Selected: {len(examples)} complex examples")

    # Category breakdown
    cat_counts = Counter(categorize(e.get("instruction", "")) for e in examples)
    print(f" Categories: {dict(cat_counts)}")

    # ── Run validation ────────────────────────────────────────────────────────
    print(f"\n[4/5] Running validation ({len(examples)} examples)...")
    print("-" * 72)

    results = []
    cat_metrics = defaultdict(lambda: {"losses": [], "perplexities": [], "accuracies": [],
                                       "bleu1": [], "bleu2": [], "rep_ratios": [],
                                       "gen_lens": []})

    for i, entry in enumerate(examples):
        inst = entry.get("instruction") or ""
        inp = entry.get("input") or ""
        ref_output = entry.get("output") or ""
        category = categorize(inst)

        # 1) Compute loss & accuracy (teacher-forced)
        loss, tok_acc, n_resp_tokens = compute_loss_and_accuracy(
            model, tokenizer, entry, args.max_len, device
        )
        ppl = math.exp(min(loss, LOSS_CAP_FOR_PPL))  # cap to avoid overflow

        # 2) Generate response (autoregressive)
        prompt = format_prompt(inst, inp)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
        gen_tokens = generate(
            model, prompt_ids,
            max_new=args.max_new,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.rep_pen,
            device=device,
            eos_token_id=eos_id,
        )
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        # Clean trailing template markers
        if "### " in gen_text:
            gen_text = gen_text.split("### ")[0].strip()

        # 3) Compute text metrics
        bleu = compute_bleu(ref_output, gen_text)
        rep = repetition_ratio(gen_text)
        gen_words = len(gen_text.split())

        # 4) Quality flags
        is_empty = len(gen_text.strip()) < 5
        is_repetitive = rep > 0.5
        # A generation that ended with EOS finished on its own, even if that
        # happened exactly on the last allowed step.
        is_truncated = len(gen_tokens) >= args.max_new and (
            not gen_tokens or gen_tokens[-1] != eos_id
        )

        result = {
            "index": i,
            "category": category,
            "instruction": inst[:200],
            "input_preview": inp[:100] if inp else "",
            "reference_preview": ref_output[:200],
            "generated_preview": gen_text[:300],
            "loss": round(loss, 4),
            "perplexity": round(ppl, 2),
            "token_accuracy": round(tok_acc, 4),
            "bleu_1": round(bleu["bleu_1"], 4),
            "bleu_2": round(bleu["bleu_2"], 4),
            "repetition_ratio": round(rep, 4),
            "generated_words": gen_words,
            "resp_tokens": n_resp_tokens,
            "is_empty": is_empty,
            "is_repetitive": is_repetitive,
            "is_truncated": is_truncated,
        }
        results.append(result)

        # Accumulate per-category
        bucket = cat_metrics[category]
        bucket["losses"].append(loss)
        bucket["perplexities"].append(ppl)
        bucket["accuracies"].append(tok_acc)
        bucket["bleu1"].append(bleu["bleu_1"])
        bucket["bleu2"].append(bleu["bleu_2"])
        bucket["rep_ratios"].append(rep)
        bucket["gen_lens"].append(gen_words)

        # Progress
        status = ""
        if is_empty:
            status = " [EMPTY]"
        elif is_repetitive:
            status = " [REPETITIVE]"
        elif is_truncated:
            status = " [TRUNCATED]"

        if (i + 1) % 5 == 0 or i == 0:
            print(f" [{i+1:3d}/{len(examples)}] loss={loss:.3f} ppl={ppl:.1f} "
                  f"acc={tok_acc:.3f} B1={bleu['bleu_1']:.3f} "
                  f"rep={rep:.3f} words={gen_words}{status}")

    # ── Aggregate & Report ────────────────────────────────────────────────────
    print(f"\n[5/5] Generating report...")

    all_losses = [r["loss"] for r in results if math.isfinite(r["loss"])]
    all_ppls = [r["perplexity"] for r in results if r["perplexity"] < PPL_OUTLIER_CUTOFF]
    all_accs = [r["token_accuracy"] for r in results]
    all_b1 = [r["bleu_1"] for r in results]
    all_b2 = [r["bleu_2"] for r in results]
    all_reps = [r["repetition_ratio"] for r in results]
    all_lens = [r["generated_words"] for r in results]

    n_empty = sum(1 for r in results if r["is_empty"])
    n_repetitive = sum(1 for r in results if r["is_repetitive"])
    n_truncated = sum(1 for r in results if r["is_truncated"])

    avg = finite_mean
    median_ppl = statistics.median(all_ppls) if all_ppls else None

    def category_summary(m):
        # Same filtering as the overall figures: un-scorable (inf) losses and
        # outlier perplexities do not enter the averages.
        return {
            "count": len(m["losses"]),
            "avg_loss": avg(m["losses"]),
            "avg_perplexity": avg([p for p in m["perplexities"]
                                   if p < PPL_OUTLIER_CUTOFF]),
            "avg_token_accuracy": avg(m["accuracies"]),
            "avg_bleu_1": avg(m["bleu1"]),
            "avg_bleu_2": avg(m["bleu2"]),
            "avg_rep_ratio": avg(m["rep_ratios"]),
        }

    cat_summaries = {cat: category_summary(m) for cat, m in cat_metrics.items()}

    # ── Build report text ─────────────────────────────────────────────────────
    report_lines = []

    def P(s=""):
        report_lines.append(s)

    def P_ranked(r):
        # One entry of the best/worst perplexity listings.
        P(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f" Q: {r['instruction'][:80]}")
        P(f" A: {r['generated_preview'][:100]}")
        P()

    P(sep)
    P(" LUNA 100M — SFT VALIDATION REPORT")
    P(f" Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    P(sep)
    P()
    P(f" Checkpoint : {args.ckpt}")
    P(f" Val source : {args.val_json} ({len(all_val)} total samples)")
    P(f" Examples tested: {len(examples)} (top complex by scoring)")
    P(f" Device : {device}")
    P(f" Max seq len : {args.max_len}")
    P(f" Gen temp : {args.temperature}")
    P(f" Gen top_k : {args.top_k}")
    P(f" Gen top_p : {args.top_p}")
    P(f" Gen rep_pen : {args.rep_pen}")
    P(f" Gen max tokens: {args.max_new}")
    P()

    P("=" * 72)
    P(" OVERALL METRICS")
    P("=" * 72)
    P(f" Avg Loss (CE) : {avg(all_losses):.4f}")
    P(f" Avg Perplexity : {avg(all_ppls):.2f}")
    P(f" Median Perplexity : {median_ppl:.2f}" if all_ppls else " Median Perplexity : N/A")
    P(f" Avg Token Accuracy : {avg(all_accs):.4f} ({avg(all_accs)*100:.1f}%)")
    P(f" Avg BLEU-1 : {avg(all_b1):.4f}")
    P(f" Avg BLEU-2 : {avg(all_b2):.4f}")
    P(f" Avg Repetition Ratio : {avg(all_reps):.4f}")
    P(f" Avg Gen Length (words): {avg(all_lens):.1f}")
    P()
    P(f" Empty responses : {n_empty}/{len(results)}")
    P(f" Repetitive responses : {n_repetitive}/{len(results)}")
    P(f" Truncated responses : {n_truncated}/{len(results)}")
    P()

    # Quality grade
    grade, grade_notes = compute_grade(
        avg(all_ppls), n_empty, n_repetitive, avg(all_b1), avg(all_accs)
    )

    P(f" OVERALL GRADE: {grade}")
    if grade_notes:
        P(f" Notes: {'; '.join(grade_notes)}")
    P()

    P("=" * 72)
    P(" CATEGORY BREAKDOWN")
    P("=" * 72)
    P(f" {'Category':<14} {'Count':>5} {'Avg Loss':>9} {'Avg PPL':>9} "
      f"{'Avg Acc':>8} {'BLEU-1':>7} {'BLEU-2':>7} {'Rep %':>6}")
    P(" " + "-" * 68)
    for cat in sorted(cat_summaries):
        s = cat_summaries[cat]
        P(f" {cat:<14} {s['count']:>5} {s['avg_loss']:>9.4f} "
          f"{s['avg_perplexity']:>9.2f} {s['avg_token_accuracy']:>8.4f} "
          f"{s['avg_bleu_1']:>7.4f} {s['avg_bleu_2']:>7.4f} "
          f"{s['avg_rep_ratio']*100:>5.1f}%")
    P()

    # ── Top 5 Best / Worst ────────────────────────────────────────────────────
    by_ppl = sorted(results, key=lambda r: r["perplexity"])

    P("=" * 72)
    P(" TOP 5 BEST (lowest perplexity)")
    P("=" * 72)
    for r in by_ppl[:5]:
        P_ranked(r)

    P("=" * 72)
    P(" TOP 5 WORST (highest perplexity)")
    P("=" * 72)
    for r in by_ppl[-5:]:
        P_ranked(r)

    # ── Failure Analysis ──────────────────────────────────────────────────────
    failures = [r for r in results if r["is_empty"] or r["is_repetitive"]]
    if failures:
        P("=" * 72)
        P(f" FAILURE ANALYSIS ({len(failures)} problematic responses)")
        P("=" * 72)
        for r in failures[:10]:
            flags = []
            if r["is_empty"]:
                flags.append("EMPTY")
            if r["is_repetitive"]:
                flags.append("REPETITIVE")
            P(f" [{r['index']:3d}] {' | '.join(flags)} [{r['category']}]")
            P(f" Q: {r['instruction'][:80]}")
            P(f" A: {r['generated_preview'][:120]}")
            P()

    # ── Perplexity distribution ───────────────────────────────────────────────
    P("=" * 72)
    P(" PERPLEXITY DISTRIBUTION")
    P("=" * 72)
    buckets = [(0, 5), (5, 10), (10, 20), (20, 50), (50, 100), (100, 500),
               (500, float("inf"))]
    for lo, hi in buckets:
        cnt = sum(1 for p in all_ppls if lo <= p < hi)
        bar = "#" * cnt
        label = f"{lo}-{hi}" if hi != float("inf") else f"{lo}+"
        P(f" {label:>8}: {cnt:>3} {bar}")
    P()

    # ── Sample generations (10 diverse examples) ──────────────────────────────
    P("=" * 72)
    P(" SAMPLE GENERATIONS (10 diverse examples)")
    P("=" * 72)
    # Evenly spaced picks across the complexity-ranked results.
    sample_indices = list(range(0, len(results), max(1, len(results) // 10)))[:10]
    for si in sample_indices:
        r = results[si]
        P(f"\n --- Example {r['index']+1} [{r['category']}] ---")
        P(f" Instruction: {r['instruction'][:150]}")
        if r["input_preview"]:
            P(f" Input: {r['input_preview'][:100]}")
        P(f" Reference: {r['reference_preview'][:200]}")
        P(f" Generated: {r['generated_preview'][:300]}")
        P(f" Loss={r['loss']:.4f} PPL={r['perplexity']:.2f} "
          f"Acc={r['token_accuracy']:.4f} BLEU-1={r['bleu_1']:.4f} "
          f"Rep={r['repetition_ratio']:.4f}")
    P()
    P(sep)
    P(" END OF REPORT")
    P(sep)

    report_text = "\n".join(report_lines)

    # Print to console
    print(report_text)

    # Save report
    report_path = out_dir / "SFT_VALIDATION_REPORT.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report_text)
    print(f"\n Report saved: {report_path}")

    # Save detailed JSON results
    json_path = out_dir / "validation_results.json"
    summary = {
        "meta": {
            "checkpoint": args.ckpt,
            "val_source": args.val_json,
            "total_val_samples": len(all_val),
            "n_tested": len(examples),
            "device": device,
            "max_len": args.max_len,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.rep_pen,
            "timestamp": datetime.now().isoformat(),
        },
        "overall": {
            "avg_loss": round(avg(all_losses), 4),
            "avg_perplexity": round(avg(all_ppls), 2),
            "median_perplexity": round(median_ppl, 2) if all_ppls else None,
            "avg_token_accuracy": round(avg(all_accs), 4),
            "avg_bleu_1": round(avg(all_b1), 4),
            "avg_bleu_2": round(avg(all_b2), 4),
            "avg_repetition_ratio": round(avg(all_reps), 4),
            "avg_gen_length_words": round(avg(all_lens), 1),
            "n_empty": n_empty,
            "n_repetitive": n_repetitive,
            "n_truncated": n_truncated,
            "grade": grade,
        },
        "category_breakdown": {
            cat: {
                "count": s["count"],
                "avg_loss": round(s["avg_loss"], 4),
                "avg_perplexity": round(s["avg_perplexity"], 2),
                "avg_token_accuracy": round(s["avg_token_accuracy"], 4),
                "avg_bleu_1": round(s["avg_bleu_1"], 4),
                "avg_bleu_2": round(s["avg_bleu_2"], 4),
            }
            for cat, s in cat_summaries.items()
        },
        "per_example": results,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        # Non-finite floats (e.g. the inf loss of an un-scorable example)
        # become null so the file is strictly valid JSON.
        json.dump(to_json_safe(summary), f, indent=2, ensure_ascii=False)
    print(f" JSON results: {json_path}")


if __name__ == "__main__":
    main()