| """ |
| eval_metrics.py — Generation quality metrics and BPB/perplexity helpers. |
| |
| Provides: |
| - bpb_from_loss() : BPB = avg_loss / ln(2) (D-97) |
| - perplexity_from_loss() : Perplexity = exp(avg_loss) (EVAL-02) |
| - repetition_rate() : n-gram repetition rate (generalization of eval_generation.byte_repetition_rate) |
| - distinct_n() : Standard NLP distinct-n metric |
| - self_perplexity() : Model's own perplexity on generated text |
| - assess_generation_quality() : Full generation quality assessment dict |
| """ |
|
|
| import math |
| import sys |
| import os |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.dirname(__file__)) |
|
|
|
|
def bpb_from_loss(avg_loss):
    """Return ``avg_loss / ln(2)``: the loss re-expressed in base-2 units (D-97).

    This is bits-per-byte when ``avg_loss`` is a mean per-byte cross-entropy
    measured in nats (the batch-average shortcut).
    """
    ln_two = math.log(2)
    return avg_loss / ln_two
|
|
|
|
def perplexity_from_loss(avg_loss):
    """Return the perplexity for an average loss: ``exp(avg_loss)`` (EVAL-02)."""
    perplexity = math.exp(avg_loss)
    return perplexity
|
|
|
|
def repetition_rate(byte_list, n=2):
    """
    Share of n-gram windows that duplicate another window.

    Computed as ``1.0 - unique_ngrams / total_ngrams`` over every contiguous
    window of length ``n``. Sequences shorter than ``n`` yield 0.0.

    Generalizes eval_generation.byte_repetition_rate (bigram only) to any n.
    """
    window_count = len(byte_list) - n + 1
    # window_count <= 0 is exactly the "sequence shorter than n" case.
    if window_count <= 0:
        return 0.0
    seen = {tuple(byte_list[start:start + n]) for start in range(window_count)}
    return 1.0 - len(seen) / window_count
|
|
|
|
def distinct_n(byte_list, n):
    """
    Standard NLP distinct-n: unique n-grams divided by total n-grams.

    Every contiguous window of length ``n`` counts as one n-gram. Sequences
    shorter than ``n`` yield 0.0.
    """
    n_windows = len(byte_list) - n + 1
    if n_windows < 1:
        return 0.0
    unique = set()
    for start in range(n_windows):
        unique.add(tuple(byte_list[start:start + n]))
    return len(unique) / n_windows
|
|
|
|
def _printable_fraction(byte_list):
    """Share of values that are printable ASCII (32..126) or LF, CR, TAB.

    An empty sequence yields 0.0 (the denominator is clamped to 1).
    """
    whitespace = (10, 13, 9)
    hits = 0
    for value in byte_list:
        if 32 <= value < 127 or value in whitespace:
            hits += 1
    return hits / max(len(byte_list), 1)
|
|
|
|
def _byte_diversity(byte_list):
    """Number of distinct values below 256, as a fraction of 256.

    Values of 256 or more are ignored; no lower bound is applied.
    """
    distinct_values = {value for value in byte_list if value < 256}
    return len(distinct_values) / 256.0
|
|
|
|
@torch.no_grad()
def self_perplexity(model, byte_list, ctx, device):
    """
    Model's own perplexity on a byte sequence.

    The whole sequence is fed to the model in a single forward pass together
    with ``targets = sequence[3:]``. The scalar ``loss_comps.total`` returned
    by the model is treated as the average NLL and exponentiated
    (D-98 discretion: use self-perplexity instead of KenLM).

    Args:
        model: Called as ``model(x, targets=t)``; must return a 4-tuple whose
            second item exposes a scalar tensor ``.total``. Switched to eval
            mode as a side effect.
        byte_list: Sequence of integer byte values to score.
        ctx: Accepted for interface compatibility; not used by this function.
            NOTE(review): the input is not truncated or windowed to ``ctx`` --
            confirm the model accepts sequences longer than its context.
        device: Torch device on which the input tensor is created.

    Returns:
        float: ``exp(loss_comps.total)``, or NaN when ``byte_list`` holds three
        or fewer bytes and therefore leaves no targets to score.
    """
    # Targets start 3 positions into the sequence.
    # NOTE(review): presumably the model needs 3 bytes of context per
    # prediction -- confirm against the model definition.
    target_offset = 3

    model.eval()

    # With <= 3 bytes the target slice is empty; the perplexity is undefined,
    # so report NaN instead of running the model on an empty target tensor.
    if len(byte_list) <= target_offset:
        return float("nan")

    byte_tensor = torch.tensor([byte_list], dtype=torch.long, device=device)
    targets = byte_tensor[:, target_offset:]
    # Gradients are already disabled by the decorator.
    _, loss_comps, _, _ = model(byte_tensor, targets=targets)
    avg_nll = loss_comps.total.item()
    return math.exp(avg_nll)
|
|
|
|
@torch.no_grad()
def assess_generation_quality(model, seed_bytes, max_new_tokens=500, ctx=64,
                              device="cuda", temperature=0.8, top_k=40):
    """
    Sample a continuation of ``seed_bytes`` and score it with all quality metrics.

    Generation is an autoregressive loop with temperature scaling and optional
    top-k filtering. All metrics are computed on the newly generated bytes
    only; the seed is excluded.

    Args:
        model: Called as ``model(x)``; must return a 4-tuple whose first item
            is a logits tensor indexed as ``[batch, position, vocab]``.
            Switched to eval mode as a side effect.
        seed_bytes: Sequence of integer byte values used as the prompt. If
            fewer than 3 bytes are available as context, nothing is generated.
        max_new_tokens: Upper bound on the number of bytes generated.
        ctx: Number of trailing bytes fed to the model at each step.
        device: Torch device for the working tensors.
        temperature: Divisor applied to the logits before sampling; must be
            non-zero.
        top_k: Keep only the ``top_k`` highest logits before sampling. ``None``
            or a non-positive value disables filtering; values larger than the
            vocabulary are clamped to it.

    Returns:
        dict with keys:
            repetition_rate_2, distinct_2, distinct_3, distinct_4,
            self_perplexity, printable_fraction, byte_diversity, n_bytes.
        ``self_perplexity`` is NaN when three or fewer bytes were generated,
        because no scoring targets exist for such a short sequence.
    """
    # Minimum number of context bytes required before the model is called.
    min_context = 3

    model.eval()

    idx = torch.tensor([seed_bytes], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ctx:]
        if idx_cond.shape[1] < min_context:
            break
        # Gradients are already disabled by the decorator.
        logits, _, _, _ = model(idx_cond)
        last_logits = logits[:, -1, :] / temperature
        if top_k is not None and top_k > 0:
            # torch.topk raises if k exceeds the vocabulary dimension.
            k = min(top_k, last_logits.size(-1))
            top_values, _ = torch.topk(last_logits, k)
            kth_value = top_values[:, [-1]]
            last_logits = last_logits.masked_fill(
                last_logits < kth_value, float("-inf")
            )
        probs = F.softmax(last_logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)

    generated = idx[0].cpu().tolist()
    # Drop the prompt so the metrics describe only what the model produced.
    byte_list = generated[len(seed_bytes):]
    n_bytes = len(byte_list)

    # self_perplexity scores positions from index 3 onward, so it needs more
    # than 3 bytes; otherwise the value is undefined.
    if n_bytes > min_context:
        self_ppl = self_perplexity(model, byte_list, ctx, device)
    else:
        self_ppl = float("nan")

    return {
        "repetition_rate_2": repetition_rate(byte_list, n=2),
        "distinct_2": distinct_n(byte_list, n=2),
        "distinct_3": distinct_n(byte_list, n=3),
        "distinct_4": distinct_n(byte_list, n=4),
        "self_perplexity": self_ppl,
        "printable_fraction": _printable_fraction(byte_list),
        "byte_diversity": _byte_diversity(byte_list),
        "n_bytes": n_bytes,
    }
|
|