# ARBS/testing/eval/eval_metrics.py
# Uploaded by CLIWorks via huggingface_hub ("Upload folder using huggingface_hub"), commit d8bc908 (verified).
"""
eval_metrics.py — Generation quality metrics and BPB/perplexity helpers.
Provides:
- bpb_from_loss() : BPB = avg_loss / ln(2) (D-97)
- perplexity_from_loss() : Perplexity = exp(avg_loss) (EVAL-02)
- repetition_rate() : n-gram repetition rate (generalization of eval_generation.byte_repetition_rate)
- distinct_n() : Standard NLP distinct-n metric
- self_perplexity() : Model's own perplexity on generated text
- assess_generation_quality() : Full generation quality assessment dict
"""
import math
import sys
import os
import torch
import torch.nn.functional as F
# Put this file's own directory first on the import path so modules living
# next to this file can be imported by bare name.
sys.path.insert(0, os.path.dirname(__file__))
def bpb_from_loss(avg_loss):
    """
    Convert an average loss into bits-per-byte (D-97).

    Batch-average shortcut: BPB = avg_loss / ln(2). Dividing by ln(2) turns a
    natural-log loss into base-2 units; this assumes ``avg_loss`` is already a
    per-byte average.
    """
    ln2 = math.log(2)
    return avg_loss / ln2
def perplexity_from_loss(avg_loss):
    """
    Convert an average loss into perplexity (EVAL-02).

    Perplexity is exp(avg_loss). Like ``math.exp``, this raises OverflowError
    for very large losses.
    """
    exponent = avg_loss
    return math.exp(exponent)
def repetition_rate(byte_list, n=2):
    """
    Share of overlapping n-grams that duplicate an earlier n-gram.

    Computed as ``1 - unique / total`` over every window of length ``n``.
    Returns 0.0 when the sequence is shorter than ``n``. This generalizes
    eval_generation.byte_repetition_rate (bigram only) to any ``n``.
    """
    if len(byte_list) < n:
        return 0.0
    window_count = len(byte_list) - n + 1
    seen = {tuple(byte_list[pos:pos + n]) for pos in range(window_count)}
    return 1.0 - len(seen) / window_count
def distinct_n(byte_list, n):
    """
    Standard NLP distinct-n: unique n-grams divided by total n-grams.

    Windows of length ``n`` overlap by ``n - 1``. Returns 0.0 when the
    sequence is shorter than ``n``.
    """
    if n > len(byte_list):
        return 0.0
    span = len(byte_list) - n + 1
    unique = set()
    for start in range(span):
        unique.add(tuple(byte_list[start:start + n]))
    return len(unique) / span
def _printable_fraction(byte_list):
"""Fraction of bytes that are printable ASCII or common whitespace."""
printable = sum(1 for b in byte_list if (32 <= b < 127) or b in (10, 13, 9))
return printable / max(len(byte_list), 1)
def _byte_diversity(byte_list):
"""Fraction of unique byte values used (out of 256)."""
unique = len(set(b for b in byte_list if b < 256))
return unique / 256.0
@torch.no_grad()
def self_perplexity(model, byte_list, ctx, device):
    """
    Model's own perplexity on a generated byte sequence.

    Runs a single forward pass over ``byte_list`` with ``targets`` set to the
    bytes from position 3 onward. It reads the loss reported by the model and
    returns ``exp(loss)``. This measures the model's confidence in its own
    output (D-98 discretion: use self-perplexity instead of KenLM).

    Args:
        model: Called as ``model(idx, targets=...)``. Must return a 4-tuple
            whose second element has a ``.total`` supporting ``.item()``
            (presumably a scalar loss tensor).
        byte_list: Sequence of integer byte values to score.
        ctx: Currently unused. It is kept for signature compatibility; the
            whole sequence is scored in one pass, not windowed to ``ctx``.
        device: Torch device on which the input tensor is created.

    Returns:
        float: ``exp(avg_nll)``, or ``nan`` when ``byte_list`` has 3 or fewer
        bytes. Such sequences have no targets, so there is nothing to score.
    """
    model.eval()
    # Targets start at offset 3. A sequence of <= 3 bytes would give the model
    # an empty target tensor, so report "undefined" instead of running it.
    if len(byte_list) <= 3:
        return float("nan")
    byte_tensor = torch.tensor([byte_list], dtype=torch.long, device=device)
    # NOTE(review): the offset of 3 is presumably tied to the model's target
    # alignment / minimum context — TODO confirm against the model definition.
    targets = byte_tensor[:, 3:]
    # No inner no_grad needed: the decorator already disables autograd here.
    _, loss_comps, _, _ = model(byte_tensor, targets=targets)
    avg_nll = loss_comps.total.item()
    return math.exp(avg_nll)
def _sample_next_byte(last_logits, temperature, top_k):
    """
    Sample one token id per row from 2-D ``last_logits`` (rows x vocab).

    The logits are first divided by ``temperature``. When ``top_k`` is a
    positive int, every logit below the k-th largest is masked to -inf. k is
    clamped to the vocabulary width so ``torch.topk`` never fails.
    Returns a (rows, 1) tensor of sampled ids from ``torch.multinomial``.
    """
    scaled = last_logits / temperature
    if top_k is not None and top_k > 0:
        k = min(top_k, scaled.size(-1))
        kth_largest = torch.topk(scaled, k).values[:, [-1]]
        scaled[scaled < kth_largest] = float("-inf")
    probs = F.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def assess_generation_quality(model, seed_bytes, max_new_tokens=500, ctx=64,
                              device="cuda", temperature=0.8, top_k=40):
    """
    Generate up to ``max_new_tokens`` bytes from ``seed_bytes`` and score them.

    Uses an autoregressive sampling loop (model.generate()-style). The context
    is cropped to the last ``ctx`` bytes, logits are temperature-scaled, and
    tokens are drawn with optional top-k sampling. Generation stops early if
    fewer than 3 bytes of context are available. Only the newly generated
    bytes (the seed excluded) are scored.

    Args:
        model: Called as ``model(idx)``. Must return a 4-tuple whose first
            element is logits indexed ``[:, -1, :]``.
        seed_bytes: Prompt as a list of ints or a ``bytes`` object.
        max_new_tokens: Maximum number of bytes to generate.
        ctx: Number of trailing bytes fed to the model at each step.
        device: Torch device for the generated tensor.
        temperature: Logit divisor applied before sampling.
        top_k: Keep only the k most likely tokens. ``None`` or a value <= 0
            disables the filter; values above the vocab size are clamped.

    Returns:
        dict with keys:
            repetition_rate_2, distinct_2, distinct_3, distinct_4,
            self_perplexity, printable_fraction, byte_diversity, n_bytes
    """
    model.eval()
    # list() lets a ``bytes`` seed work too (iterating bytes yields ints).
    seed = list(seed_bytes)
    idx = torch.tensor([seed], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ctx:]
        # NOTE(review): presumably the model needs >= 3 bytes of context —
        # TODO confirm. Nothing is appended once this triggers, so stop.
        if idx_cond.shape[1] < 3:
            break
        logits, _, _, _ = model(idx_cond)
        idx_next = _sample_next_byte(logits[:, -1, :], temperature, top_k)
        idx = torch.cat([idx, idx_next], dim=1)

    # Trim the seed and keep only the newly generated portion.
    byte_list = idx[0].cpu().tolist()[len(seed):]
    n_bytes = len(byte_list)
    return {
        "repetition_rate_2": repetition_rate(byte_list, n=2),
        "distinct_2": distinct_n(byte_list, n=2),
        "distinct_3": distinct_n(byte_list, n=3),
        "distinct_4": distinct_n(byte_list, n=4),
        "self_perplexity": self_perplexity(model, byte_list, ctx, device),
        "printable_fraction": _printable_fraction(byte_list),
        "byte_diversity": _byte_diversity(byte_list),
        "n_bytes": n_bytes,
    }