Upload validate_sft.py with huggingface_hub
Browse files- validate_sft.py +707 -0
validate_sft.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LUNA 100M β SFT Validation on Complex Examples
|
| 3 |
+
================================================
|
| 4 |
+
Selects ~100 complex examples from the SFT validation set,
|
| 5 |
+
runs the fine-tuned model on each, and produces a detailed report.
|
| 6 |
+
|
| 7 |
+
Metrics computed:
|
| 8 |
+
- Per-sample cross-entropy loss (prompt-masked) & perplexity
|
| 9 |
+
- Token-level accuracy on the output portion
|
| 10 |
+
- BLEU-1/2 (word overlap with reference output)
|
| 11 |
+
- Repetition ratio (degeneration detection)
|
| 12 |
+
- Response length stats
|
| 13 |
+
- Category breakdown (coding, explanation, analysis, creative, how-to, identity)
|
| 14 |
+
- Overall pass/fail grading
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
python validate_sft.py
|
| 18 |
+
python validate_sft.py --ckpt "Base/out/sft/model.pth" --val_json "Base/Datasets/sft_clean/val.json"
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os, sys, json, math, time, argparse, re
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from collections import Counter, defaultdict
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# βββ Model (identical to sft_train.py / chat.py) βββββββββββββββββββββββββββββ
|
| 32 |
+
|
| 33 |
+
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position-embedding (RoPE) cos/sin tables.

    Tables for positions 0..max_seq_len-1 are cached as buffers at
    construction time, so the forward pass is a cheap slice.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE inverse frequencies: 1 / 10000^(2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        angles = torch.einsum("i,j->ij", positions, inv_freq)
        table = torch.cat([angles, angles], dim=-1)  # (max_seq_len, dim)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return (cos, sin) tables covering the first seq_len positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def rotate_half(x):
    """Rotate the last dim by swapping halves with a sign flip: (a, b) -> (-b, a)."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat([-second, first], dim=-1)


def apply_rotary(x, cos, sin):
    """Apply the RoPE rotation to x using (seq_len, rot_dim) cos/sin tables.

    The tables are broadcast over the leading batch and head dimensions of x.
    """
    cos_b = cos.unsqueeze(0).unsqueeze(0)
    sin_b = sin.unsqueeze(0).unsqueeze(0)
    return x * cos_b + rotate_half(x) * sin_b
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first rotary_pct fraction of each head dimension receives the
    RoPE rotation (GPT-NeoX / Pythia style); the remainder passes through
    unrotated.
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)  # fused QKV projection
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)      # output projection
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        # (B, T, 3, H, Dh) -> (3, B, H, T, Dh), then split into q, k, v.
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        cos, sin = self.rotary(T)
        rd = self.rot_dim
        # Rotate only the leading rot_dim channels of each head.
        q = torch.cat([apply_rotary(q[..., :rd], cos, sin), q[..., rd:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :rd], cos, sin), k[..., rd:]], dim=-1)
        # Causal masking is handled inside the fused attention kernel.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.c_proj(out.transpose(1, 2).contiguous().view(B, T, C))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear, 4x expansion."""

    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        hidden = self.gelu(self.fc(x))
        return self.proj(hidden)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class LUNAModel(nn.Module):
    """Decoder-only GPT-style language model (~100M params) with tied embeddings.

    The default hyperparameters must match the training script, since the
    checkpoint is loaded with strict=True.
    """

    def __init__(self, vocab_size=50304, block_size=1024,
                 n_layer=10, n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, block_size) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output head shares the input embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        h = self.wte(idx)
        for layer in self.blocks:
            h = layer(h)
        return self.lm_head(self.ln_f(h))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# βββ Generation βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
|
| 127 |
+
@torch.no_grad()
def generate(model, input_ids, max_new=150, temperature=0.7,
             top_p=0.9, top_k=40, repetition_penalty=1.0, device="cpu"):
    """Autoregressively sample up to max_new tokens (batch size 1).

    Applies an optional repetition penalty to already-seen tokens, then
    either greedy decoding (temperature ~ 0) or temperature + top-k +
    top-p (nucleus) sampling. Stops early when token id 0 (EOS) is
    produced. Returns the list of generated token ids.
    """
    ids = input_ids.to(device)
    out_tokens = []
    for _ in range(max_new):
        # Forward at most the last block_size tokens; keep the final step's logits.
        logits = model(ids[:, -model.block_size:])[:, -1, :]

        # Penalize every token already present in the context.
        if repetition_penalty != 1.0:
            for seen in set(ids[0].tolist()):
                if logits[0, seen] > 0:
                    logits[0, seen] /= repetition_penalty
                else:
                    logits[0, seen] *= repetition_penalty

        if temperature < 1e-6:
            # Greedy decoding.
            nxt = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            if top_k > 0:
                # Zero out everything below the k-th largest probability.
                k = min(top_k, probs.size(-1))
                kth = torch.topk(probs, k)[0][:, [-1]]
                probs[probs < kth] = 0.0
                probs /= probs.sum()
            if top_p < 1.0:
                # Nucleus sampling over the smallest prefix with mass >= top_p.
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                drop = torch.cumsum(sorted_probs, dim=-1) - sorted_probs > top_p
                sorted_probs[drop] = 0.0
                sorted_probs /= sorted_probs.sum()
                nxt = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                nxt = torch.multinomial(probs[0], 1)

        ids = torch.cat([ids, nxt.view(1, 1)], dim=1)
        tok = nxt.item()
        out_tokens.append(tok)
        if tok == 0:  # EOS token
            break
    return out_tokens
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# βββ Prompt formatting (matches sft_train.py) ββββββββββββββββββββββββββββββββ
|
| 167 |
+
|
| 168 |
+
def format_prompt(instruction, inp=""):
    """Build the Alpaca-style prompt used during SFT training.

    The template must match sft_train.py exactly, otherwise evaluation
    measures an out-of-distribution prompt format.
    """
    instruction = instruction.strip()
    inp = inp.strip()
    if instruction and inp:
        return f"### Instruction:\n{instruction}\n\n### Input:\n{inp}\n\n### Response:\n"
    if instruction:
        return f"### Instruction:\n{instruction}\n\n### Response:\n"
    return f"### Input:\n{inp}\n\n### Response:\n"
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# βββ Complexity scoring & selection βββββββββββββββββββββββββββββββββββββββββββ
|
| 180 |
+
|
| 181 |
+
# Words whose presence suggests a multi-step / reasoning-heavy example.
COMPLEXITY_KEYWORDS = [
    "step", "first", "second", "then", "next", "finally",
    "because", "however", "therefore", "explain", "analyze",
    "compare", "evaluate", "describe", "discuss", "provide",
    "example", "detail", "elaborate", "summarize",
]


def complexity_score(entry):
    """Heuristic difficulty score for one SFT example (higher = more complex).

    Combines total text length, output length, a bonus for a substantial
    input field, and the number of reasoning keywords found anywhere in
    the example text.
    """
    instruction = entry.get("instruction", "")
    extra_input = entry.get("input", "")
    output = entry.get("output", "")
    haystack = f"{instruction} {extra_input} {output}".lower()
    length_total = len(instruction) + len(extra_input) + len(output)
    input_bonus = 500 if len(extra_input) > 20 else 0
    keyword_hits = sum(kw in haystack for kw in COMPLEXITY_KEYWORDS)
    return length_total * 0.3 + len(output) * 0.5 + input_bonus + keyword_hits * 200
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def categorize(instruction):
    """Bucket an instruction into a coarse task category by keyword match.

    Rules are ordered and the first match wins, so e.g. "write a python
    function" counts as coding rather than creative.
    """
    text = instruction.lower()
    rules = [
        (("code", "python", "java", "swift", "function", "program",
          "algorithm", "script", "sql", "html", "css"), "coding"),
        (("who are you", "your name", "who created", "asterizer", "luna",
          "are you an ai"), "identity"),
        (("explain", "what is", "define", "describe", "meaning of"), "explanation"),
        (("analyze", "compare", "evaluate", "assess", "critique"), "analysis"),
        (("write", "create", "generate", "compose", "draft", "poem",
          "story", "essay"), "creative"),
        (("how", "step", "guide", "method", "procedure", "tutorial"), "how-to"),
    ]
    for keywords, label in rules:
        if any(kw in text for kw in keywords):
            return label
    return "other"
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def select_complex_examples(data, n=100):
    """Return the n highest-scoring examples by complexity_score.

    Equal scores break ties by original index, descending, mirroring the
    tuple ordering of (score, index) sorted in reverse.
    """
    ranked = sorted(
        range(len(data)),
        key=lambda i: (complexity_score(data[i]), i),
        reverse=True,
    )
    return [data[i] for i in ranked[:n]]
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# βββ Metrics ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 223 |
+
|
| 224 |
+
def compute_bleu(reference, hypothesis, max_n=2):
    """Modified n-gram precision (BLEU-1..BLEU-max_n), word level.

    No brevity penalty and no smoothing. Returns 0.0 for every order when
    either string is empty after splitting.
    """
    ref_words = reference.lower().split()
    hyp_words = hypothesis.lower().split()
    if not ref_words or not hyp_words:
        return {f"bleu_{n}": 0.0 for n in range(1, max_n + 1)}

    def ngrams(words, n):
        # Count every contiguous n-gram as a tuple key.
        return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))

    scores = {}
    for n in range(1, max_n + 1):
        ref_counts = ngrams(ref_words, n)
        hyp_counts = ngrams(hyp_words, n)
        # Clipped precision: each hypothesis n-gram counts at most as often
        # as it appears in the reference.
        clipped = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        scores[f"bleu_{n}"] = clipped / max(sum(hyp_counts.values()), 1)
    return scores
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def repetition_ratio(text):
    """Fraction of repeated trigrams in text (higher = more degenerate).

    Returns 0.0 for inputs shorter than four words, where the statistic
    is meaningless.
    """
    words = text.lower().split()
    if len(words) < 4:
        return 0.0
    trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
    if not trigrams:
        return 0.0
    return 1.0 - len(set(trigrams)) / len(trigrams)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@torch.no_grad()
def compute_loss_and_accuracy(model, tokenizer, entry, max_len, device):
    """Teacher-forced CE loss and greedy token accuracy on the response span.

    The prompt portion is masked out so the metrics reflect only how well
    the model predicts the reference response. Returns a tuple
    (loss, accuracy, n_response_tokens); (inf, 0.0, 0) when truncation
    leaves no response tokens to score.
    """
    prompt = format_prompt(entry.get("instruction", ""), entry.get("input", ""))
    response = entry.get("output", "").strip()

    eos = tokenizer.eos_token_id or 0
    prompt_ids = tokenizer.encode(prompt)
    sequence = prompt_ids + tokenizer.encode(response) + [eos]

    if len(sequence) > max_len:
        sequence = sequence[:max_len]
        sequence[-1] = eos  # keep a terminator even after truncation
        prompt_len = min(len(prompt_ids), max_len)
    else:
        prompt_len = len(prompt_ids)

    batch = torch.tensor([sequence], dtype=torch.long, device=device)
    logits = model(batch)  # (1, T, V)

    # Shift so that position t predicts token t+1.
    pred_logits = logits[:, :-1, :].contiguous()
    targets = batch[:, 1:].contiguous()

    # Mask selects only the response portion (everything after the prompt).
    mask = torch.zeros(targets.shape, dtype=torch.float, device=device)
    mask[0, max(prompt_len - 1, 0):len(sequence) - 1] = 1.0
    if mask.sum() == 0:
        return float("inf"), 0.0, 0

    token_losses = F.cross_entropy(
        pred_logits.view(-1, pred_logits.size(-1)),
        targets.view(-1),
        reduction="none",
    ).view(targets.shape)
    loss = (token_losses * mask).sum() / mask.sum()

    # Greedy-argmax accuracy over the same masked span.
    hits = ((pred_logits.argmax(dim=-1) == targets).float() * mask).sum()
    n_scored = mask.sum()
    accuracy = (hits / n_scored).item() if n_scored > 0 else 0.0
    return loss.item(), accuracy, int(n_scored.item())
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# βββ Main validation βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 307 |
+
|
| 308 |
+
def main():
    """CLI entry point: validate the fine-tuned LUNA checkpoint.

    Pipeline:
      1. Load the model checkpoint and HF tokenizer.
      2. Select the most complex validation examples (complexity_score).
      3. For each example, compute teacher-forced loss/accuracy and run a
         free generation, scoring BLEU / repetition / length.
      4. Aggregate overall and per-category metrics, assign a letter
         grade, and write a text report plus a JSON results file.
    """
    parser = argparse.ArgumentParser(description="LUNA SFT - Complex Example Validation")
    parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--val_json", default="Base/Datasets/sft_clean/val.json")
    parser.add_argument("--n_examples", type=int, default=100)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--max_new", type=int, default=150)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--rep_pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--out_dir", default="Base/out/sft/validation_report_v2")
    args = parser.parse_args()

    # Resolve "auto" to cuda when available, otherwise cpu.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sep = "=" * 72
    print(f"\n{sep}")
    print("  LUNA 100M - SFT VALIDATION (Complex Examples)")
    print(sep)
    print(f"  Checkpoint : {args.ckpt}")
    print(f"  Val data   : {args.val_json}")
    print(f"  N examples : {args.n_examples}")
    print(f"  Device     : {device}")
    print(f"  Max seq    : {args.max_len}")
    print(f"  Temperature: {args.temperature}")
    print(f"  Top-k      : {args.top_k}")
    print(f"  Top-p      : {args.top_p}")
    print(f"  Rep penalty: {args.rep_pen}")
    print(sep)

    # --- Load model -----------------------------------------------------
    print("\n[1/5] Loading model...")
    t0 = time.time()
    state_dict = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    # Training checkpoints may wrap the weights under a "model" key.
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]
    model = LUNAModel()
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Model loaded: {n_params:,} params ({time.time()-t0:.1f}s)")

    # --- Load tokenizer -------------------------------------------------
    print("[2/5] Loading tokenizer...")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"  Tokenizer: vocab_size={tokenizer.vocab_size}")

    # --- Select complex examples ----------------------------------------
    print(f"[3/5] Selecting {args.n_examples} most complex examples from val set...")
    with open(args.val_json, "r", encoding="utf-8") as f:
        all_val = json.load(f)
    print(f"  Total val samples: {len(all_val)}")
    examples = select_complex_examples(all_val, args.n_examples)
    print(f"  Selected: {len(examples)} complex examples")

    cat_counts = Counter(categorize(e["instruction"]) for e in examples)
    print(f"  Categories: {dict(cat_counts)}")

    # --- Run validation -------------------------------------------------
    print(f"\n[4/5] Running validation ({len(examples)} examples)...")
    print("-" * 72)

    results = []
    cat_metrics = defaultdict(lambda: {"losses": [], "perplexities": [],
                                       "accuracies": [], "bleu1": [],
                                       "bleu2": [], "rep_ratios": [],
                                       "gen_lens": []})

    for i, entry in enumerate(examples):
        inst = entry.get("instruction", "")
        inp = entry.get("input", "")
        ref_output = entry.get("output", "")
        category = categorize(inst)

        # 1) Teacher-forced loss & accuracy on the reference response.
        loss, tok_acc, n_resp_tokens = compute_loss_and_accuracy(
            model, tokenizer, entry, args.max_len, device
        )
        ppl = math.exp(min(loss, 20))  # cap the exponent to avoid overflow

        # 2) Free-running generation from the same prompt.
        prompt = format_prompt(inst, inp)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
        gen_tokens = generate(
            model, prompt_ids,
            max_new=args.max_new,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        # Drop anything after a spurious template marker.
        if "### " in gen_text:
            gen_text = gen_text.split("### ")[0].strip()

        # 3) Surface-level text metrics.
        bleu = compute_bleu(ref_output, gen_text)
        rep = repetition_ratio(gen_text)
        gen_words = len(gen_text.split())

        # 4) Quality flags.
        is_empty = len(gen_text.strip()) < 5
        is_repetitive = rep > 0.5
        is_truncated = len(gen_tokens) >= args.max_new

        result = {
            "index": i,
            "category": category,
            "instruction": inst[:200],
            "input_preview": inp[:100] if inp else "",
            "reference_preview": ref_output[:200],
            "generated_preview": gen_text[:300],
            "loss": round(loss, 4),
            "perplexity": round(ppl, 2),
            "token_accuracy": round(tok_acc, 4),
            "bleu_1": round(bleu["bleu_1"], 4),
            "bleu_2": round(bleu["bleu_2"], 4),
            "repetition_ratio": round(rep, 4),
            "generated_words": gen_words,
            "resp_tokens": n_resp_tokens,
            "is_empty": is_empty,
            "is_repetitive": is_repetitive,
            "is_truncated": is_truncated,
        }
        results.append(result)

        # Accumulate per-category aggregates.
        cat_metrics[category]["losses"].append(loss)
        cat_metrics[category]["perplexities"].append(ppl)
        cat_metrics[category]["accuracies"].append(tok_acc)
        cat_metrics[category]["bleu1"].append(bleu["bleu_1"])
        cat_metrics[category]["bleu2"].append(bleu["bleu_2"])
        cat_metrics[category]["rep_ratios"].append(rep)
        cat_metrics[category]["gen_lens"].append(gen_words)

        # Progress line with the most severe flag, if any.
        status = ""
        if is_empty:
            status = " [EMPTY]"
        elif is_repetitive:
            status = " [REPETITIVE]"
        elif is_truncated:
            status = " [TRUNCATED]"

        if (i + 1) % 5 == 0 or i == 0:
            print(f"  [{i+1:3d}/{len(examples)}] loss={loss:.3f} ppl={ppl:.1f} "
                  f"acc={tok_acc:.3f} B1={bleu['bleu_1']:.3f} "
                  f"rep={rep:.3f} words={gen_words}{status}")

    # --- Aggregate & report ---------------------------------------------
    print(f"\n[5/5] Generating report...")

    all_losses = [r["loss"] for r in results if r["loss"] < float("inf")]
    # exp(20) > 1e6, so infinite-loss samples are excluded from PPL stats.
    all_ppls = [r["perplexity"] for r in results if r["perplexity"] < 1e6]
    all_accs = [r["token_accuracy"] for r in results]
    all_b1 = [r["bleu_1"] for r in results]
    all_b2 = [r["bleu_2"] for r in results]
    all_reps = [r["repetition_ratio"] for r in results]
    all_lens = [r["generated_words"] for r in results]

    n_empty = sum(1 for r in results if r["is_empty"])
    n_repetitive = sum(1 for r in results if r["is_repetitive"])
    n_truncated = sum(1 for r in results if r["is_truncated"])

    def avg(xs):
        """Mean of xs, or 0.0 for an empty list."""
        return sum(xs) / len(xs) if xs else 0.0

    # --- Build report text ----------------------------------------------
    report_lines = []

    def P(s=""):
        """Append one line to the report buffer."""
        report_lines.append(s)

    P(sep)
    P("  LUNA 100M - SFT VALIDATION REPORT")
    P(f"  Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    P(sep)
    P()
    P(f"  Checkpoint     : {args.ckpt}")
    P(f"  Val source     : {args.val_json} ({len(all_val)} total samples)")
    P(f"  Examples tested: {len(examples)} (top complex by scoring)")
    P(f"  Device         : {device}")
    P(f"  Max seq len    : {args.max_len}")
    P(f"  Gen temp       : {args.temperature}")
    P(f"  Gen top_k      : {args.top_k}")
    P(f"  Gen top_p      : {args.top_p}")
    P(f"  Gen rep_pen    : {args.rep_pen}")
    P(f"  Gen max tokens : {args.max_new}")
    P()

    P("=" * 72)
    P("  OVERALL METRICS")
    P("=" * 72)
    P(f"  Avg Loss (CE)         : {avg(all_losses):.4f}")
    P(f"  Avg Perplexity        : {avg(all_ppls):.2f}")
    P(f"  Median Perplexity     : {sorted(all_ppls)[len(all_ppls)//2]:.2f}" if all_ppls
      else "  Median Perplexity     : N/A")
    P(f"  Avg Token Accuracy    : {avg(all_accs):.4f} ({avg(all_accs)*100:.1f}%)")
    P(f"  Avg BLEU-1            : {avg(all_b1):.4f}")
    P(f"  Avg BLEU-2            : {avg(all_b2):.4f}")
    P(f"  Avg Repetition Ratio  : {avg(all_reps):.4f}")
    P(f"  Avg Gen Length (words): {avg(all_lens):.1f}")
    P()
    P(f"  Empty responses       : {n_empty}/{len(results)}")
    P(f"  Repetitive responses  : {n_repetitive}/{len(results)}")
    P(f"  Truncated responses   : {n_truncated}/{len(results)}")
    P()

    # Quality grade. Letter grades compare lexicographically
    # ("A" < "B" < "C" < "D"), so max() always keeps the WORSE grade.
    grade = "A"
    grade_notes = []
    avg_ppl = avg(all_ppls)
    if avg_ppl > 50:
        grade = "C"
        grade_notes.append("high perplexity (>50)")
    elif avg_ppl > 20:
        grade = "B"
        grade_notes.append("moderate perplexity (>20)")
    if n_empty > 10:
        # BUGFIX: the previous expression ('"D" if grade > "C" else grade')
        # could never downgrade an A/B/C grade to D, so this penalty never
        # fired. Use the same max() convention as the other downgrades.
        grade = max(grade, "D")
        grade_notes.append(f"{n_empty} empty responses")
    elif n_empty > 3:
        grade = max(grade, "C")
        grade_notes.append(f"{n_empty} empty responses")
    if n_repetitive > 15:
        grade = max(grade, "C")
        grade_notes.append(f"{n_repetitive} repetitive responses")
    if avg(all_b1) < 0.05:
        grade = max(grade, "C")
        grade_notes.append("very low BLEU-1")
    if avg(all_accs) > 0.4:
        grade_notes.append("strong token accuracy")
    if avg(all_accs) > 0.5 and grade == "B":
        grade = "A-"  # strong accuracy promotes a B to an A-

    P(f"  OVERALL GRADE: {grade}")
    if grade_notes:
        P(f"  Notes: {'; '.join(grade_notes)}")
    P()

    P("=" * 72)
    P("  CATEGORY BREAKDOWN")
    P("=" * 72)
    P(f"  {'Category':<14} {'Count':>5} {'Avg Loss':>9} {'Avg PPL':>9} "
      f"{'Avg Acc':>8} {'BLEU-1':>7} {'BLEU-2':>7} {'Rep %':>6}")
    P("  " + "-" * 68)
    for cat in sorted(cat_metrics.keys()):
        m = cat_metrics[cat]
        cnt = len(m["losses"])
        P(f"  {cat:<14} {cnt:>5} {avg(m['losses']):>9.4f} "
          f"{avg(m['perplexities']):>9.2f} {avg(m['accuracies']):>8.4f} "
          f"{avg(m['bleu1']):>7.4f} {avg(m['bleu2']):>7.4f} "
          f"{avg(m['rep_ratios'])*100:>5.1f}%")
    P()

    # --- Top 5 best / worst by perplexity -------------------------------
    P("=" * 72)
    P("  TOP 5 BEST (lowest perplexity)")
    P("=" * 72)
    by_ppl = sorted(results, key=lambda r: r["perplexity"])
    for r in by_ppl[:5]:
        P(f"  [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f"       Q: {r['instruction'][:80]}")
        P(f"       A: {r['generated_preview'][:100]}")
        P()

    P("=" * 72)
    P("  TOP 5 WORST (highest perplexity)")
    P("=" * 72)
    for r in by_ppl[-5:]:
        P(f"  [{r['index']:3d}] PPL={r['perplexity']:>8.2f} Acc={r['token_accuracy']:.3f} "
          f"B1={r['bleu_1']:.3f} [{r['category']}]")
        P(f"       Q: {r['instruction'][:80]}")
        P(f"       A: {r['generated_preview'][:100]}")
        P()

    # --- Failure analysis ------------------------------------------------
    failures = [r for r in results if r["is_empty"] or r["is_repetitive"]]
    if failures:
        P("=" * 72)
        P(f"  FAILURE ANALYSIS ({len(failures)} problematic responses)")
        P("=" * 72)
        for r in failures[:10]:
            flags = []
            if r["is_empty"]:
                flags.append("EMPTY")
            if r["is_repetitive"]:
                flags.append("REPETITIVE")
            P(f"  [{r['index']:3d}] {' | '.join(flags)} [{r['category']}]")
            P(f"       Q: {r['instruction'][:80]}")
            P(f"       A: {r['generated_preview'][:120]}")
            P()

    # --- Perplexity histogram --------------------------------------------
    P("=" * 72)
    P("  PERPLEXITY DISTRIBUTION")
    P("=" * 72)
    buckets = [(0, 5), (5, 10), (10, 20), (20, 50), (50, 100),
               (100, 500), (500, float("inf"))]
    for lo, hi in buckets:
        cnt = sum(1 for p in all_ppls if lo <= p < hi)
        bar = "#" * cnt
        label = f"{lo}-{hi}" if hi != float("inf") else f"{lo}+"
        P(f"  {label:>8}: {cnt:>3} {bar}")
    P()

    # --- Sample generations ----------------------------------------------
    P("=" * 72)
    P("  SAMPLE GENERATIONS (10 diverse examples)")
    P("=" * 72)
    # Evenly spaced sample of at most 10 results.
    sample_indices = list(range(0, len(results), max(1, len(results) // 10)))[:10]
    for si in sample_indices:
        r = results[si]
        P(f"\n  --- Example {r['index']+1} [{r['category']}] ---")
        P(f"  Instruction: {r['instruction'][:150]}")
        if r["input_preview"]:
            P(f"  Input: {r['input_preview'][:100]}")
        P(f"  Reference: {r['reference_preview'][:200]}")
        P(f"  Generated: {r['generated_preview'][:300]}")
        P(f"  Loss={r['loss']:.4f} PPL={r['perplexity']:.2f} "
          f"Acc={r['token_accuracy']:.4f} BLEU-1={r['bleu_1']:.4f} "
          f"Rep={r['repetition_ratio']:.4f}")
        P()

    P(sep)
    P("  END OF REPORT")
    P(sep)

    report_text = "\n".join(report_lines)

    # Print to console, then persist.
    print(report_text)

    report_path = out_dir / "SFT_VALIDATION_REPORT.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report_text)
    print(f"\n  Report saved: {report_path}")

    # Save detailed JSON results alongside the text report.
    json_path = out_dir / "validation_results.json"
    summary = {
        "meta": {
            "checkpoint": args.ckpt,
            "val_source": args.val_json,
            "total_val_samples": len(all_val),
            "n_tested": len(examples),
            "device": device,
            "max_len": args.max_len,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.rep_pen,
            "timestamp": datetime.now().isoformat(),
        },
        "overall": {
            "avg_loss": round(avg(all_losses), 4),
            "avg_perplexity": round(avg(all_ppls), 2),
            "median_perplexity": round(sorted(all_ppls)[len(all_ppls)//2], 2) if all_ppls else None,
            "avg_token_accuracy": round(avg(all_accs), 4),
            "avg_bleu_1": round(avg(all_b1), 4),
            "avg_bleu_2": round(avg(all_b2), 4),
            "avg_repetition_ratio": round(avg(all_reps), 4),
            "avg_gen_length_words": round(avg(all_lens), 1),
            "n_empty": n_empty,
            "n_repetitive": n_repetitive,
            "n_truncated": n_truncated,
            "grade": grade,
        },
        "category_breakdown": {
            cat: {
                "count": len(m["losses"]),
                "avg_loss": round(avg(m["losses"]), 4),
                "avg_perplexity": round(avg(m["perplexities"]), 2),
                "avg_token_accuracy": round(avg(m["accuracies"]), 4),
                "avg_bleu_1": round(avg(m["bleu1"]), 4),
                "avg_bleu_2": round(avg(m["bleu2"]), 4),
            }
            for cat, m in cat_metrics.items()
        },
        "per_example": results,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"  JSON results: {json_path}")


if __name__ == "__main__":
    main()
|