# LUNA-Training / validate_sft.py
# Uploaded by ASTERIZER with huggingface_hub (commit 01e6957, verified).
"""
LUNA 100M β€” SFT Validation on Complex Examples
================================================
Selects ~100 complex examples from the SFT validation set,
runs the fine-tuned model on each, and produces a detailed report.
Metrics computed:
- Per-sample cross-entropy loss (prompt-masked) & perplexity
- Token-level accuracy on the output portion
- BLEU-1/2 (word overlap with reference output)
- Repetition ratio (degeneration detection)
- Response length stats
- Category breakdown (coding, explanation, analysis, creative, how-to, identity)
- Overall pass/fail grading
Usage:
python validate_sft.py
python validate_sft.py --ckpt "Base/out/sft/model.pth" --val_json "Base/Datasets/sft_clean/val.json"
"""
import os, sys, json, math, time, argparse, re
from pathlib import Path
from collections import Counter, defaultdict
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─── Model (identical to sft_train.py / chat.py) ─────────────────────────────
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings (RoPE).

    Args:
        dim: number of channels that receive the rotation.
        max_seq_len: number of positions to precompute.

    The tables are registered as (persistent) buffers, so they appear in
    ``state_dict()`` as ``inv_freq``, ``cos_cached`` and ``sin_cached``.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        # Outer product: angle[pos, j] = pos * inv_freq[j].
        angles = torch.outer(positions, inv_freq)
        # Duplicate along the channel axis -> (max_seq_len, dim) for even dim.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the ``(cos, sin)`` rows for the first ``seq_len`` positions."""
        cos_rows = self.cos_cached[:seq_len]
        sin_rows = self.sin_cached[:seq_len]
        return cos_rows, sin_rows
def rotate_half(x):
    """Split the last axis into halves ``(a, b)`` and return ``(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary(x, cos, sin):
    """Rotate ``x`` by the angles whose cosines/sines are given in ``cos``/``sin``.

    Computes ``x * cos + rotate_half(x) * sin`` after giving ``cos`` and
    ``sin`` two leading singleton axes. In CausalSelfAttention, ``x`` is a
    (batch, heads, seq, rot_dim) slice of q/k and the tables are the
    (seq, rot_dim) rows from RotaryEmbedding, so they broadcast over batch
    and heads.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    return x * cos_b + rotate_half(x) * sin_b
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary position embeddings.

    Only the first ``rot_dim = int(head_dim * rotary_pct)`` channels of each
    head's queries and keys are rotated; the remaining channels pass through
    unchanged. Attention uses ``F.scaled_dot_product_attention`` with
    ``is_causal=True``.
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        # Fused Q/K/V projection, then the output projection.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def _rope(self, t, cos, sin):
        """Rotate the leading ``rot_dim`` channels of ``t``; keep the rest as-is."""
        r = self.rot_dim
        return torch.cat([apply_rotary(t[..., :r], cos, sin), t[..., r:]], dim=-1)

    def forward(self, x):
        batch, seq, width = x.size()
        fused = self.c_attn(x).reshape(batch, seq, 3, self.n_head, self.head_dim)
        # -> (3, batch, heads, seq, head_dim), then split into q, k, v.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        cos, sin = self.rotary(seq)
        q = self._rope(q, cos, sin)
        k = self._rope(k, cos, sin)
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Feed-forward sub-layer: Linear (n_embd -> 4*n_embd), GELU, Linear back."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        expanded = self.fc(x)
        activated = self.gelu(expanded)
        return self.proj(activated)
class Block(nn.Module):
    """Transformer block with LayerNorm applied before each residual branch.

    ``x -> x + attn(ln1(x)) -> x + mlp(ln2(x))``.
    """

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
class LUNAModel(nn.Module):
    """Decoder-only language model: token embedding, ``n_layer`` blocks,
    final LayerNorm and an output head tied to the embedding matrix.

    Attribute names (``wte``, ``blocks``, ``ln_f``, ``lm_head``) determine the
    ``state_dict`` keys that ``load_state_dict(..., strict=True)`` expects.
    """

    def __init__(self, vocab_size=50304, block_size=1024,
                 n_layer=10, n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        layers = [Block(n_embd, n_head, block_size) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids ``idx`` to next-token logits (one row per position)."""
        hidden = self.wte(idx)
        for layer in self.blocks:
            hidden = layer(hidden)
        hidden = self.ln_f(hidden)
        return self.lm_head(hidden)
# ─── Generation ───────────────────────────────────────────────────────────────
@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, repetition_penalty=1.0, device="cpu",
             eos_token_id=0):
    """Autoregressively sample up to ``max_new`` tokens (batch size 1).

    Args:
        model: callable mapping token ids of shape (1, T) to logits of shape
            (1, T, V); must expose ``block_size`` (the context is cropped to
            the last ``block_size`` tokens before each step).
        input_ids: (1, T) LongTensor holding the prompt.
        max_new: maximum number of tokens to generate.
        temperature: softmax temperature; values below 1e-6 select greedy
            (argmax) decoding, in which case top-k/top-p are not applied.
        top_p: nucleus threshold; 1.0 disables nucleus filtering.
        top_k: keep only the ``top_k`` most likely tokens; 0 disables.
        repetition_penalty: every token id already present in the context
            (prompt included) has a positive logit divided by, or a
            non-positive logit multiplied by, this factor; 1.0 disables.
        device: device the ids are moved to before decoding.
        eos_token_id: decoding stops right after this id is produced (it is
            still included in the result). Defaults to 0, the id this script
            has always used; ``None`` disables early stopping.

    Returns:
        List of generated token ids (prompt excluded).
    """
    ids = input_ids.to(device)
    generated = []
    for _ in range(max_new):
        logits = model(ids[:, -model.block_size:])[:, -1, :]
        if repetition_penalty != 1.0:
            # Vectorised form of a per-id loop: same arithmetic, no Python
            # iteration over the whole context at every step.
            seen = torch.unique(ids[0])
            vals = logits[0, seen]
            logits[0, seen] = torch.where(vals > 0, vals / repetition_penalty,
                                          vals * repetition_penalty)
        if temperature < 1e-6:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            if top_k > 0:
                kval = min(top_k, probs.size(-1))
                kth = torch.topk(probs, kval).values[:, [-1]]
                probs[probs < kth] = 0.0
                probs /= probs.sum()
            if top_p < 1.0:
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass of the tokens ranked above it
                # already exceeds top_p (the top token survives for top_p >= 0).
                sorted_probs[cumsum - sorted_probs > top_p] = 0.0
                sorted_probs /= sorted_probs.sum()
                next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = torch.multinomial(probs[0], 1)
        token = next_token.item()
        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        generated.append(token)
        if eos_token_id is not None and token == eos_token_id:
            break
    return generated
# ─── Prompt formatting (matches sft_train.py) ────────────────────────────────
def format_prompt(instruction, inp=""):
    """Build the '### Instruction / ### Input / ### Response' prompt.

    Both fields are stripped. The Instruction section appears only when the
    instruction is non-blank; the Input section appears when the input is
    non-blank or when there is no instruction (so an empty Input section is
    emitted if both are blank). The prompt always ends with "### Response:\\n".
    """
    inst = instruction.strip()
    inp = inp.strip()
    sections = []
    if inst:
        sections.append(f"### Instruction:\n{inst}\n\n")
    if inp or not inst:
        sections.append(f"### Input:\n{inp}\n\n")
    sections.append("### Response:\n")
    return "".join(sections)
# ─── Complexity scoring & selection ───────────────────────────────────────────
# Words hinting at multi-step or reasoning-heavy content. complexity_score
# counts how many of them occur (as substrings) anywhere in an example.
COMPLEXITY_KEYWORDS = (
    "step first second then next finally "
    "because however therefore explain analyze "
    "compare evaluate describe discuss provide "
    "example detail elaborate summarize"
).split()


def complexity_score(entry):
    """Heuristic complexity score used to rank SFT examples.

    score = 0.3 * len(instruction + input + output)   (character counts)
          + 0.5 * len(output)
          + 500 if the input is longer than 20 characters
          + 200 for each COMPLEXITY_KEYWORDS word found as a substring of
            the lower-cased "instruction input output" text.

    Missing fields count as empty strings.
    """
    inst, inp, out = (entry.get(key, "") for key in ("instruction", "input", "output"))
    haystack = " ".join((inst, inp, out)).lower()
    keyword_hits = sum(word in haystack for word in COMPLEXITY_KEYWORDS)
    input_bonus = 500 if len(inp) > 20 else 0
    return ((len(inst) + len(inp) + len(out)) * 0.3
            + len(out) * 0.5
            + input_bonus
            + keyword_hits * 200)
def categorize(instruction):
    """Assign an instruction to a coarse task category by keyword matching.

    Rules are tried in priority order; the first rule with any keyword
    occurring as a substring of the lower-cased instruction wins. If no rule
    matches, the result is "other".
    """
    rules = (
        ("coding", ("code", "python", "java", "swift", "function", "program",
                    "algorithm", "script", "sql", "html", "css")),
        ("identity", ("who are you", "your name", "who created", "asterizer",
                      "luna", "are you an ai")),
        ("explanation", ("explain", "what is", "define", "describe", "meaning of")),
        ("analysis", ("analyze", "compare", "evaluate", "assess", "critique")),
        ("creative", ("write", "create", "generate", "compose", "draft", "poem",
                      "story", "essay")),
        ("how-to", ("how", "step", "guide", "method", "procedure", "tutorial")),
    )
    text = instruction.lower()
    return next(
        (label for label, words in rules if any(w in text for w in words)),
        "other",
    )
def select_complex_examples(data, n=100):
    """Return the ``n`` entries of ``data`` with the highest ``complexity_score``.

    Entries are ordered by descending score; ties go to the entry with the
    larger original index (the index is the secondary sort key).
    """
    order = sorted(
        range(len(data)),
        key=lambda i: (complexity_score(data[i]), i),
        reverse=True,
    )
    return [data[i] for i in order[:n]]
# ─── Metrics ──────────────────────────────────────────────────────────────────
def compute_bleu(reference, hypothesis, max_n=2):
    """Word-level clipped n-gram precision for n = 1..max_n.

    Returns ``{"bleu_1": p1, ..., f"bleu_{max_n}": p_max_n}`` where ``p_n`` is
    the modified (clipped) n-gram precision of ``hypothesis`` against
    ``reference``. Unlike standard BLEU there is no brevity penalty and no
    geometric mean across orders: ``bleu_2`` is bigram precision alone.
    Both texts are lower-cased and split on whitespace. If either is empty,
    every score is 0.0.
    """
    ref_words = reference.lower().split()
    hyp_words = hypothesis.lower().split()
    if not (ref_words and hyp_words):
        return {f"bleu_{order}": 0.0 for order in range(1, max_n + 1)}

    def ngram_counts(words, order):
        # Consecutive word tuples of length `order`.
        return Counter(zip(*(words[k:] for k in range(order))))

    result = {}
    for order in range(1, max_n + 1):
        hyp_counts = ngram_counts(hyp_words, order)
        # Counter '&' keeps min(count) per n-gram, i.e. clipped matches.
        overlap = hyp_counts & ngram_counts(ref_words, order)
        result[f"bleu_{order}"] = sum(overlap.values()) / max(sum(hyp_counts.values()), 1)
    return result
def repetition_ratio(text):
    """Share of word trigrams that duplicate another trigram (higher = more degenerate).

    Computed as ``1 - unique_trigrams / total_trigrams`` on the lower-cased,
    whitespace-split text. Texts with fewer than four words score 0.0.
    """
    tokens = text.lower().split()
    if len(tokens) < 4:
        return 0.0
    # Four or more words always yield at least two trigrams.
    grams = list(zip(tokens, tokens[1:], tokens[2:]))
    return 1.0 - (len(set(grams)) / len(grams))
@torch.no_grad()
def compute_loss_and_accuracy(model, tokenizer, entry, max_len, device):
    """Teacher-forced, prompt-masked cross-entropy and token accuracy for one example.

    The prompt (from ``format_prompt``) and the stripped reference ``output``
    are tokenised separately and concatenated; an EOS id is appended to the
    response. Only positions whose *target* token belongs to the response
    (EOS included) count toward the loss and the accuracy.

    Sequences longer than ``max_len`` are cut to ``max_len`` tokens and the
    last kept token is replaced by EOS. ``max_len`` is additionally capped at
    ``model.block_size`` when the model exposes it, because LUNAModel's rotary
    tables only cover ``block_size`` positions.

    Returns:
        ``(loss, accuracy, n_response_tokens)``; ``(float("inf"), 0.0, 0)``
        when no response token survives truncation.
    """
    # Longer inputs overrun the precomputed RoPE tables and fail with a shape
    # mismatch inside attention, so never feed more than block_size tokens.
    max_len = min(max_len, getattr(model, "block_size", max_len))
    eos_id = tokenizer.eos_token_id or 0

    prompt = format_prompt(entry.get("instruction", ""), entry.get("input", ""))
    response = entry.get("output", "").strip()
    prompt_ids = tokenizer.encode(prompt)
    response_ids = tokenizer.encode(response) + [eos_id]

    total_ids = prompt_ids + response_ids
    if len(total_ids) > max_len:
        total_ids = total_ids[:max_len]
        total_ids[-1] = eos_id
    # Equals len(prompt_ids) unless truncation cut into the prompt itself.
    prompt_len = min(len(prompt_ids), len(total_ids))

    input_tensor = torch.tensor([total_ids], dtype=torch.long, device=device)
    logits = model(input_tensor)  # (1, T, V)

    # Next-token prediction: position j predicts token j + 1.
    shift_logits = logits[:, :-1, :]
    shift_targets = input_tensor[:, 1:]

    # Shifted positions prompt_len-1 .. T-2 have response tokens as targets.
    mask = torch.zeros(shift_targets.shape, dtype=torch.float, device=device)
    mask[0, max(prompt_len - 1, 0):len(total_ids) - 1] = 1.0
    n_resp = mask.sum()
    if n_resp == 0:
        return float("inf"), 0.0, 0

    per_token_loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_targets.reshape(-1),
        reduction="none",
    ).view(shift_targets.shape)
    masked_loss = (per_token_loss * mask).sum() / n_resp

    correct = ((shift_logits.argmax(dim=-1) == shift_targets).float() * mask).sum()
    accuracy = (correct / n_resp).item()
    return masked_loss.item(), accuracy, int(n_resp.item())
# ─── Main validation ─────────────────────────────────────────────────────────
def _resolve_device(requested):
    """Return the device string for ``--device``.

    ``"auto"`` becomes ``"cuda"`` when ``torch.cuda.is_available()`` and
    ``"cpu"`` otherwise; any other value is returned unchanged.
    """
    if requested == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return requested


def _mean(values):
    """Arithmetic mean of ``values``; 0.0 for an empty sequence."""
    return sum(values) / len(values) if values else 0.0


def _overall_grade(avg_ppl, avg_bleu1, avg_acc, n_empty, n_repetitive):
    """Compute the report's letter grade and explanatory notes.

    Letters compare lexicographically ("A" < "B" < "C" < "D"), so
    ``max(grade, X)`` means "at least as bad as X". Average perplexity sets
    the starting grade; empty, repetitive or low-BLEU output can only worsen
    it; token accuracy above 0.5 upgrades a "B" to "A-".

    Returns:
        ``(grade, notes)`` where ``notes`` is a list of short strings.
    """
    grade = "A"
    notes = []
    if avg_ppl > 50:
        grade = "C"
        notes.append("high perplexity (>50)")
    elif avg_ppl > 20:
        grade = "B"
        notes.append("moderate perplexity (>20)")
    if n_empty > 10:
        # Take the worse letter; a test like `grade > "C"` could never
        # produce "D" from A/B/C.
        grade = max(grade, "D")
        notes.append(f"{n_empty} empty responses")
    elif n_empty > 3:
        grade = max(grade, "C")
        notes.append(f"{n_empty} empty responses")
    if n_repetitive > 15:
        grade = max(grade, "C")
        notes.append(f"{n_repetitive} repetitive responses")
    if avg_bleu1 < 0.05:
        grade = max(grade, "C")
        notes.append("very low BLEU-1")
    if avg_acc > 0.4:
        notes.append("strong token accuracy")
    if avg_acc > 0.5 and grade == "B":
        grade = "A-"
    return grade, notes


def main(argv=None):
    """Validate an SFT checkpoint on the most complex validation examples.

    Loads the checkpoint into ``LUNAModel`` and the tokenizer, selects the
    ``--n_examples`` highest-scoring validation entries, and for each one
    computes teacher-forced loss/accuracy plus a sampled generation. A text
    report is printed and saved as ``SFT_VALIDATION_REPORT.txt``, alongside
    ``validation_results.json``, under ``--out_dir``.

    Args:
        argv: optional argument list; ``None`` (default) lets argparse read
            the process command line.
    """
    parser = argparse.ArgumentParser(description="LUNA SFT — Complex Example Validation")
    parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--val_json", default="Base/Datasets/sft_clean/val.json")
    parser.add_argument("--n_examples", type=int, default=100)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--max_new", type=int, default=150)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--rep_pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--out_dir", default="Base/out/sft/validation_report_v2")
    args = parser.parse_args(argv)
    device = _resolve_device(args.device)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sep = "=" * 72
    print(f"\n{sep}")
    print(" LUNA 100M — SFT VALIDATION (Complex Examples)")
    print(sep)
    print(f" Checkpoint : {args.ckpt}")
    print(f" Val data : {args.val_json}")
    print(f" N examples : {args.n_examples}")
    print(f" Device : {device}")
    print(f" Max seq : {args.max_len}")
    print(f" Temperature: {args.temperature}")
    print(f" Top-k : {args.top_k}")
    print(f" Top-p : {args.top_p}")
    print(f" Rep penalty: {args.rep_pen}")
    print(sep)

    # ── Load model ───────────────────────────────────────────────────────────
    print("\n[1/5] Loading model...")
    t0 = time.time()
    state_dict = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    if isinstance(state_dict, dict) and "model" in state_dict:
        # Some checkpoints nest the weights under a "model" key.
        state_dict = state_dict["model"]
    model = LUNAModel()
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Model loaded: {n_params:,} params ({time.time()-t0:.1f}s)")

    # ── Load tokenizer ───────────────────────────────────────────────────────
    print("[2/5] Loading tokenizer...")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f" Tokenizer: vocab_size={tokenizer.vocab_size}")

    # ── Select complex examples ──────────────────────────────────────────────
    print(f"[3/5] Selecting {args.n_examples} most complex examples from val set...")
    with open(args.val_json, "r", encoding="utf-8") as f:
        all_val = json.load(f)
    print(f" Total val samples: {len(all_val)}")
    examples = select_complex_examples(all_val, args.n_examples)
    print(f" Selected: {len(examples)} complex examples")
    # .get() matches the rest of the script: a missing instruction is "".
    cat_counts = Counter(categorize(e.get("instruction", "")) for e in examples)
    print(f" Categories: {dict(cat_counts)}")

    # ── Run validation ───────────────────────────────────────────────────────
    print(f"\n[4/5] Running validation ({len(examples)} examples)...")
    print("-" * 72)
    results = []
    # Per-category metrics keep the unrounded values.
    cat_metrics = defaultdict(lambda: {"losses": [], "perplexities": [],
                                       "accuracies": [], "bleu1": [],
                                       "bleu2": [], "rep_ratios": [],
                                       "gen_lens": []})
    for i, entry in enumerate(examples):
        inst = entry.get("instruction", "")
        inp = entry.get("input", "")
        ref_output = entry.get("output", "")
        category = categorize(inst)

        # 1) Teacher-forced loss & accuracy on the reference response.
        loss, tok_acc, n_resp_tokens = compute_loss_and_accuracy(
            model, tokenizer, entry, args.max_len, device
        )
        ppl = math.exp(min(loss, 20))  # cap to avoid overflow

        # 2) Free-running generation from the same prompt.
        prompt_ids = tokenizer.encode(format_prompt(inst, inp), return_tensors="pt")
        gen_tokens = generate(
            model, prompt_ids,
            max_new=args.max_new,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        # Clean trailing template markers (the model starting a new section).
        if "### " in gen_text:
            gen_text = gen_text.split("### ")[0].strip()

        # 3) Text metrics.
        bleu = compute_bleu(ref_output, gen_text)
        rep = repetition_ratio(gen_text)
        gen_words = len(gen_text.split())

        # 4) Quality flags. generate() stops on token id 0 by default, so a
        #    response ending in 0 finished on its own even at exactly max_new.
        is_empty = len(gen_text.strip()) < 5
        is_repetitive = rep > 0.5
        is_truncated = (len(gen_tokens) >= args.max_new
                        and (not gen_tokens or gen_tokens[-1] != 0))

        result = {
            "index": i,
            "category": category,
            "instruction": inst[:200],
            "input_preview": inp[:100] if inp else "",
            "reference_preview": ref_output[:200],
            "generated_preview": gen_text[:300],
            "loss": round(loss, 4),
            "perplexity": round(ppl, 2),
            "token_accuracy": round(tok_acc, 4),
            "bleu_1": round(bleu["bleu_1"], 4),
            "bleu_2": round(bleu["bleu_2"], 4),
            "repetition_ratio": round(rep, 4),
            "generated_words": gen_words,
            "resp_tokens": n_resp_tokens,
            "is_empty": is_empty,
            "is_repetitive": is_repetitive,
            "is_truncated": is_truncated,
        }
        results.append(result)

        bucket = cat_metrics[category]
        bucket["losses"].append(loss)
        bucket["perplexities"].append(ppl)
        bucket["accuracies"].append(tok_acc)
        bucket["bleu1"].append(bleu["bleu_1"])
        bucket["bleu2"].append(bleu["bleu_2"])
        bucket["rep_ratios"].append(rep)
        bucket["gen_lens"].append(gen_words)

        # Progress
        status = ""
        if is_empty:
            status = " [EMPTY]"
        elif is_repetitive:
            status = " [REPETITIVE]"
        elif is_truncated:
            status = " [TRUNCATED]"
        if (i + 1) % 5 == 0 or i == 0:
            print(f" [{i+1:3d}/{len(examples)}] loss={loss:.3f} ppl={ppl:.1f} "
                  f"acc={tok_acc:.3f} B1={bleu['bleu_1']:.3f} "
                  f"rep={rep:.3f} words={gen_words}{status}")

    # ── Aggregate & Report ───────────────────────────────────────────────────
    print("\n[5/5] Generating report...")
    # Overall aggregates use the rounded per-example values; infinite losses
    # and perplexities of 1e6 or more are excluded.
    all_losses = [r["loss"] for r in results if r["loss"] < float("inf")]
    all_ppls = [r["perplexity"] for r in results if r["perplexity"] < 1e6]
    all_accs = [r["token_accuracy"] for r in results]
    all_b1 = [r["bleu_1"] for r in results]
    all_b2 = [r["bleu_2"] for r in results]
    all_reps = [r["repetition_ratio"] for r in results]
    all_lens = [r["generated_words"] for r in results]
    n_empty = sum(1 for r in results if r["is_empty"])
    n_repetitive = sum(1 for r in results if r["is_repetitive"])
    n_truncated = sum(1 for r in results if r["is_truncated"])
    # Upper median for even counts.
    median_ppl = sorted(all_ppls)[len(all_ppls) // 2] if all_ppls else None

    report_lines = []

    def emit(line=""):
        report_lines.append(line)

    emit(sep)
    emit(" LUNA 100M — SFT VALIDATION REPORT")
    emit(f" Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    emit(sep)
    emit()
    emit(f" Checkpoint : {args.ckpt}")
    emit(f" Val source : {args.val_json} ({len(all_val)} total samples)")
    emit(f" Examples tested: {len(examples)} (top complex by scoring)")
    emit(f" Device : {device}")
    emit(f" Max seq len : {args.max_len}")
    emit(f" Gen temp : {args.temperature}")
    emit(f" Gen top_k : {args.top_k}")
    emit(f" Gen top_p : {args.top_p}")
    emit(f" Gen rep_pen : {args.rep_pen}")
    emit(f" Gen max tokens: {args.max_new}")
    emit()
    emit("=" * 72)
    emit(" OVERALL METRICS")
    emit("=" * 72)
    emit(f" Avg Loss (CE) : {_mean(all_losses):.4f}")
    emit(f" Avg Perplexity : {_mean(all_ppls):.2f}")
    emit(f" Median Perplexity : {median_ppl:.2f}" if median_ppl is not None
         else " Median Perplexity : N/A")
    emit(f" Avg Token Accuracy : {_mean(all_accs):.4f} ({_mean(all_accs)*100:.1f}%)")
    emit(f" Avg BLEU-1 : {_mean(all_b1):.4f}")
    emit(f" Avg BLEU-2 : {_mean(all_b2):.4f}")
    emit(f" Avg Repetition Ratio : {_mean(all_reps):.4f}")
    emit(f" Avg Gen Length (words): {_mean(all_lens):.1f}")
    emit()
    emit(f" Empty responses : {n_empty}/{len(results)}")
    emit(f" Repetitive responses : {n_repetitive}/{len(results)}")
    emit(f" Truncated responses : {n_truncated}/{len(results)}")
    emit()

    grade, grade_notes = _overall_grade(
        _mean(all_ppls), _mean(all_b1), _mean(all_accs), n_empty, n_repetitive
    )
    emit(f" OVERALL GRADE: {grade}")
    if grade_notes:
        emit(f" Notes: {'; '.join(grade_notes)}")
    emit()

    emit("=" * 72)
    emit(" CATEGORY BREAKDOWN")
    emit("=" * 72)
    emit(f" {'Category':<14} {'Count':>5} {'Avg Loss':>9} {'Avg PPL':>9} "
         f"{'Avg Acc':>8} {'BLEU-1':>7} {'BLEU-2':>7} {'Rep %':>6}")
    emit(" " + "-" * 68)
    for cat in sorted(cat_metrics):
        m = cat_metrics[cat]
        cnt = len(m["losses"])
        emit(f" {cat:<14} {cnt:>5} {_mean(m['losses']):>9.4f} "
             f"{_mean(m['perplexities']):>9.2f} {_mean(m['accuracies']):>8.4f} "
             f"{_mean(m['bleu1']):>7.4f} {_mean(m['bleu2']):>7.4f} "
             f"{_mean(m['rep_ratios'])*100:>5.1f}%")
    emit()

    # ── Top 5 Best / Worst ───────────────────────────────────────────────────
    emit("=" * 72)
    emit(" TOP 5 BEST (lowest perplexity)")
    emit("=" * 72)
    by_ppl = sorted(results, key=lambda r: r["perplexity"])
    for r in by_ppl[:5]:
        emit(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
             f"B1={r['bleu_1']:.3f} [{r['category']}]")
        emit(f" Q: {r['instruction'][:80]}")
        emit(f" A: {r['generated_preview'][:100]}")
        emit()
    emit("=" * 72)
    emit(" TOP 5 WORST (highest perplexity)")
    emit("=" * 72)
    # Worst first, mirroring the best-first order of the list above.
    for r in reversed(by_ppl[-5:]):
        emit(f" [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
             f"B1={r['bleu_1']:.3f} [{r['category']}]")
        emit(f" Q: {r['instruction'][:80]}")
        emit(f" A: {r['generated_preview'][:100]}")
        emit()

    # ── Failure Analysis ─────────────────────────────────────────────────────
    failures = [r for r in results if r["is_empty"] or r["is_repetitive"]]
    if failures:
        emit("=" * 72)
        emit(f" FAILURE ANALYSIS ({len(failures)} problematic responses)")
        emit("=" * 72)
        for r in failures[:10]:
            flags = []
            if r["is_empty"]:
                flags.append("EMPTY")
            if r["is_repetitive"]:
                flags.append("REPETITIVE")
            emit(f" [{r['index']:3d}] {' | '.join(flags)} [{r['category']}]")
            emit(f" Q: {r['instruction'][:80]}")
            emit(f" A: {r['generated_preview'][:120]}")
            emit()

    # ── Perplexity distribution ──────────────────────────────────────────────
    emit("=" * 72)
    emit(" PERPLEXITY DISTRIBUTION")
    emit("=" * 72)
    buckets = [(0, 5), (5, 10), (10, 20), (20, 50), (50, 100),
               (100, 500), (500, float("inf"))]
    for lo, hi in buckets:
        cnt = sum(1 for p in all_ppls if lo <= p < hi)
        bar = "#" * cnt
        label = f"{lo}-{hi}" if hi != float("inf") else f"{lo}+"
        emit(f" {label:>8}: {cnt:>3} {bar}")
    emit()

    # ── Sample generations (up to 10, evenly spaced) ─────────────────────────
    emit("=" * 72)
    emit(" SAMPLE GENERATIONS (10 diverse examples)")
    emit("=" * 72)
    sample_indices = list(range(0, len(results), max(1, len(results) // 10)))[:10]
    for si in sample_indices:
        r = results[si]
        emit(f"\n --- Example {r['index']+1} [{r['category']}] ---")
        emit(f" Instruction: {r['instruction'][:150]}")
        if r["input_preview"]:
            emit(f" Input: {r['input_preview'][:100]}")
        emit(f" Reference: {r['reference_preview'][:200]}")
        emit(f" Generated: {r['generated_preview'][:300]}")
        emit(f" Loss={r['loss']:.4f} PPL={r['perplexity']:.2f} "
             f"Acc={r['token_accuracy']:.4f} BLEU-1={r['bleu_1']:.4f} "
             f"Rep={r['repetition_ratio']:.4f}")
    emit()
    emit(sep)
    emit(" END OF REPORT")
    emit(sep)

    report_text = "\n".join(report_lines)
    print(report_text)

    report_path = out_dir / "SFT_VALIDATION_REPORT.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report_text)
    print(f"\n Report saved: {report_path}")

    json_path = out_dir / "validation_results.json"
    summary = {
        "meta": {
            "checkpoint": args.ckpt,
            "val_source": args.val_json,
            "total_val_samples": len(all_val),
            "n_tested": len(examples),
            "device": device,
            "max_len": args.max_len,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.rep_pen,
            "timestamp": datetime.now().isoformat(),
        },
        "overall": {
            "avg_loss": round(_mean(all_losses), 4),
            "avg_perplexity": round(_mean(all_ppls), 2),
            "median_perplexity": round(median_ppl, 2) if median_ppl is not None else None,
            "avg_token_accuracy": round(_mean(all_accs), 4),
            "avg_bleu_1": round(_mean(all_b1), 4),
            "avg_bleu_2": round(_mean(all_b2), 4),
            "avg_repetition_ratio": round(_mean(all_reps), 4),
            "avg_gen_length_words": round(_mean(all_lens), 1),
            "n_empty": n_empty,
            "n_repetitive": n_repetitive,
            "n_truncated": n_truncated,
            "grade": grade,
        },
        "category_breakdown": {
            cat: {
                "count": len(m["losses"]),
                "avg_loss": round(_mean(m["losses"]), 4),
                "avg_perplexity": round(_mean(m["perplexities"]), 2),
                "avg_token_accuracy": round(_mean(m["accuracies"]), 4),
                "avg_bleu_1": round(_mean(m["bleu1"]), 4),
                "avg_bleu_2": round(_mean(m["bleu2"]), 4),
            }
            for cat, m in cat_metrics.items()
        },
        "per_example": results,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f" JSON results: {json_path}")


if __name__ == "__main__":
    main()