"""
eval_metrics.py — Generation quality metrics and BPB/perplexity helpers.
Provides:
- bpb_from_loss() : BPB = avg_loss / ln(2) (D-97)
- perplexity_from_loss() : Perplexity = exp(avg_loss) (EVAL-02)
- repetition_rate() : n-gram repetition rate (generalization of eval_generation.byte_repetition_rate)
- distinct_n() : Standard NLP distinct-n metric
- self_perplexity() : Model's own perplexity on generated text
- assess_generation_quality() : Full generation quality assessment dict
"""
import math
import sys
import os
import torch
import torch.nn.functional as F
# NOTE(review): nothing in this module imports a sibling file, so this
# sys.path tweak appears unused here; presumably it lets scripts in this
# directory import each other when run directly -- confirm before removing.
sys.path.insert(0, os.path.dirname(__file__))
def bpb_from_loss(avg_loss):
    """Convert an average loss into bits-per-byte (D-97).

    Applies the batch-average shortcut ``BPB = avg_loss / ln(2)``. The ln(2)
    divisor implies ``avg_loss`` is a natural-log (nats) per-byte loss.

    Args:
        avg_loss: Average loss value (anything supporting division by a float).

    Returns:
        The loss expressed in bits per byte.
    """
    ln_two = math.log(2)
    return avg_loss / ln_two
def perplexity_from_loss(avg_loss):
    """Return perplexity ``exp(avg_loss)`` (EVAL-02).

    Args:
        avg_loss: Average natural-log loss per token/byte.

    Returns:
        ``math.exp(avg_loss)`` as a float. A loss too large for ``math.exp``
        (e.g. from a diverged run) yields ``float("inf")`` instead of raising
        ``OverflowError``; a NaN loss propagates as NaN.
    """
    try:
        return math.exp(avg_loss)
    except OverflowError:
        # math.exp raises rather than saturating; +inf is the correct limit.
        return float("inf")
def repetition_rate(byte_list, n=2):
    """Return the fraction of repeated n-grams in ``byte_list``.

    Computed as ``1.0 - unique_ngrams / total_ngrams`` over all overlapping
    windows of length ``n``. Generalizes ``eval_generation.byte_repetition_rate``
    (bigram only) to any ``n``.

    Args:
        byte_list: Sequence of byte values (list of ints, ``bytes``, ...).
        n: Window length; must be >= 1.

    Returns:
        A float in ``[0.0, 1.0)``; 0.0 when ``byte_list`` is shorter than ``n``.

    Raises:
        ValueError: If ``n < 1`` (windows would be empty or ill-defined).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    total = len(byte_list) - n + 1
    if total <= 0:
        # No complete window fits.
        return 0.0
    # Count unique windows directly in a set instead of materializing a list.
    unique = len({tuple(byte_list[i:i + n]) for i in range(total)})
    return 1.0 - unique / total
def distinct_n(byte_list, n):
    """Return the distinct-n score: unique n-grams / total n-grams.

    Standard NLP diversity metric over all overlapping windows of length ``n``.

    Args:
        byte_list: Sequence of byte values (list of ints, ``bytes``, ...).
        n: Window length; must be >= 1.

    Returns:
        A float in ``(0.0, 1.0]``; 0.0 when ``byte_list`` is shorter than ``n``.

    Raises:
        ValueError: If ``n < 1`` (windows would be empty or ill-defined).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    total = len(byte_list) - n + 1
    if total <= 0:
        # No complete window fits.
        return 0.0
    # Count unique windows directly in a set instead of materializing a list.
    unique = len({tuple(byte_list[i:i + n]) for i in range(total)})
    return unique / total
def _printable_fraction(byte_list):
"""Fraction of bytes that are printable ASCII or common whitespace."""
printable = sum(1 for b in byte_list if (32 <= b < 127) or b in (10, 13, 9))
return printable / max(len(byte_list), 1)
def _byte_diversity(byte_list):
"""Fraction of unique byte values used (out of 256)."""
unique = len(set(b for b in byte_list if b < 256))
return unique / 256.0
@torch.no_grad()
def self_perplexity(model, byte_list, ctx, device):
    """Return the model's own perplexity on a generated byte sequence.

    Runs one forward pass of ``model`` over ``byte_list`` with targets taken
    from position 3 onward, reads the scalar ``loss_comps.total`` as the mean
    NLL and returns ``exp(mean_nll)``. This is the model's confidence on its
    own output (D-98 discretion: use self-perplexity instead of KenLM).

    Args:
        model: Module called as ``model(idx, targets=...)`` and returning a
            4-tuple whose second item exposes a scalar ``.total`` loss.
        byte_list: Sequence of byte ids (list of ints or ``bytes``).
        ctx: Currently unused. NOTE(review): the whole sequence is scored in
            one pass without cropping to ``ctx``; confirm the model accepts
            sequences of this length.
        device: Torch device on which the input tensor is created.

    Returns:
        Perplexity as a float; ``float("nan")`` when the sequence has no
        target positions (3 bytes or fewer); ``float("inf")`` when the loss
        overflows ``math.exp``.

    The model's train/eval mode is restored before returning.
    """
    # Targets start at offset 3. NOTE(review): presumably the model needs 3
    # bytes of context per prediction (assess_generation_quality also refuses
    # to sample from fewer than 3 bytes) -- confirm against the model.
    first_target = 3
    if len(byte_list) <= first_target:
        # No target positions: the mean NLL is undefined.
        return float("nan")

    was_training = model.training
    model.eval()
    try:
        byte_tensor = torch.tensor([list(byte_list)], dtype=torch.long, device=device)
        targets = byte_tensor[:, first_target:]
        _, loss_comps, _, _ = model(byte_tensor, targets=targets)
        avg_nll = loss_comps.total.item()
    finally:
        # Don't leak eval mode (e.g. disabled dropout) back into a training loop.
        model.train(was_training)

    try:
        return math.exp(avg_nll)
    except OverflowError:
        return float("inf")
@torch.no_grad()
def _sample_continuation(model, seed_bytes, max_new_tokens, ctx, device,
                         temperature, top_k):
    """Autoregressively extend ``seed_bytes`` and return only the new byte ids."""
    idx = torch.tensor([list(seed_bytes)], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ctx:]
        # Never predict from fewer than 3 bytes of context; a shorter seed
        # (or ctx < 3) therefore produces no new bytes.
        if idx_cond.shape[1] < 3:
            break
        logits, _, _, _ = model(idx_cond)
        last_logits = logits[:, -1, :]
        if temperature == 0:
            # Greedy decoding; dividing by a zero temperature would yield NaNs.
            idx_next = torch.argmax(last_logits, dim=-1, keepdim=True)
        else:
            last_logits = last_logits / temperature
            if top_k is not None and top_k > 0:
                # Clamp k so a top_k larger than the vocabulary does not raise.
                k = min(top_k, last_logits.size(-1))
                v, _ = torch.topk(last_logits, k)
                last_logits[last_logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    # Trim the seed: keep only the newly generated portion.
    return idx[0, len(seed_bytes):].cpu().tolist()


@torch.no_grad()
def assess_generation_quality(model, seed_bytes, max_new_tokens=500, ctx=64,
                              device="cuda", temperature=0.8, top_k=40):
    """Sample a continuation of ``seed_bytes`` and compute quality metrics on it.

    Generates up to ``max_new_tokens`` bytes with temperature / top-k
    sampling, feeding the model at most the last ``ctx`` bytes each step.
    Metrics are computed on the newly generated bytes only (the seed is
    excluded).

    Args:
        model: Module called as ``model(idx)`` returning a 4-tuple whose first
            item is logits indexed as ``[batch, position, vocab]``.
        seed_bytes: Prompt byte ids (list of ints or ``bytes``). Fewer than 3
            bytes (or ``ctx < 3``) produces an empty generation.
        max_new_tokens: Upper bound on the number of bytes generated.
        ctx: Sliding context window fed to the model at each step.
        device: Torch device for the generated tensor.
        temperature: Softmax temperature; ``0`` selects greedy decoding.
        top_k: Keep only the k most likely bytes before sampling (clamped to
            the vocabulary size); ``None`` or ``0`` disables filtering.

    Returns:
        Dict with keys ``repetition_rate_2``, ``distinct_2``, ``distinct_3``,
        ``distinct_4``, ``self_perplexity`` (NaN when fewer than 4 bytes were
        generated), ``printable_fraction``, ``byte_diversity``, ``n_bytes``.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    try:
        byte_list = _sample_continuation(model, seed_bytes, max_new_tokens, ctx,
                                         device, temperature, top_k)
        # self_perplexity scores targets from offset 3 onward; with 3 or fewer
        # bytes there is nothing to score.
        if len(byte_list) > 3:
            self_ppl = self_perplexity(model, byte_list, ctx, device)
        else:
            self_ppl = float("nan")
    finally:
        model.train(was_training)

    return {
        "repetition_rate_2": repetition_rate(byte_list, n=2),
        "distinct_2": distinct_n(byte_list, n=2),
        "distinct_3": distinct_n(byte_list, n=3),
        "distinct_4": distinct_n(byte_list, n=4),
        "self_perplexity": self_ppl,
        "printable_fraction": _printable_fraction(byte_list),
        "byte_diversity": _byte_diversity(byte_list),
        "n_bytes": len(byte_list),
    }