""" eval_metrics.py — Generation quality metrics and BPB/perplexity helpers. Provides: - bpb_from_loss() : BPB = avg_loss / ln(2) (D-97) - perplexity_from_loss() : Perplexity = exp(avg_loss) (EVAL-02) - repetition_rate() : n-gram repetition rate (generalization of eval_generation.byte_repetition_rate) - distinct_n() : Standard NLP distinct-n metric - self_perplexity() : Model's own perplexity on generated text - assess_generation_quality() : Full generation quality assessment dict """ import math import sys import os import torch import torch.nn.functional as F sys.path.insert(0, os.path.dirname(__file__)) def bpb_from_loss(avg_loss): """Bits-per-byte via batch-average shortcut: BPB = loss / ln(2) (D-97).""" return avg_loss / math.log(2) def perplexity_from_loss(avg_loss): """Perplexity = exp(avg_loss) (EVAL-02).""" return math.exp(avg_loss) def repetition_rate(byte_list, n=2): """ Fraction of repeated n-grams. 1.0 - (unique_ngrams / total_ngrams). 0.0 for sequences shorter than n. Generalizes eval_generation.byte_repetition_rate (bigram only) to any n. """ if len(byte_list) < n: return 0.0 ngrams = [tuple(byte_list[i:i + n]) for i in range(len(byte_list) - n + 1)] if not ngrams: return 0.0 return 1.0 - len(set(ngrams)) / len(ngrams) def distinct_n(byte_list, n): """ Unique n-grams / total n-grams (standard NLP distinct-n metric). 0.0 for sequences shorter than n. """ if len(byte_list) < n: return 0.0 ngrams = [tuple(byte_list[i:i + n]) for i in range(len(byte_list) - n + 1)] if not ngrams: return 0.0 return len(set(ngrams)) / len(ngrams) def _printable_fraction(byte_list): """Fraction of bytes that are printable ASCII or common whitespace.""" printable = sum(1 for b in byte_list if (32 <= b < 127) or b in (10, 13, 9)) return printable / max(len(byte_list), 1) def _byte_diversity(byte_list): """Fraction of unique byte values used (out of 256).""" unique = len(set(b for b in byte_list if b < 256)) return unique / 256.0 @torch.no_grad() def self_perplexity(model, byte_list, ctx, device): """ Model's own perplexity on a generated byte sequence. Runs model forward on the byte sequence, computes average NLL per byte, returns exp(avg_nll). This is the model's confidence on its own output (D-98 discretion: use self-perplexity instead of KenLM). """ model.eval() byte_tensor = torch.tensor([byte_list], dtype=torch.long, device=device) targets = byte_tensor[:, 3:] with torch.no_grad(): _, loss_comps, _, _ = model(byte_tensor, targets=targets) avg_nll = loss_comps.total.item() return math.exp(avg_nll) @torch.no_grad() def assess_generation_quality(model, seed_bytes, max_new_tokens=500, ctx=64, device="cuda", temperature=0.8, top_k=40): """ Generate a 500+ byte sequence and compute all quality metrics. Uses model.generate()-style loop with top_k sampling. Returns dict with keys: repetition_rate_2, distinct_2, distinct_3, distinct_4, self_perplexity, printable_fraction, byte_diversity, n_bytes """ model.eval() # Generate byte sequence idx = torch.tensor([seed_bytes], dtype=torch.long, device=device) for _ in range(max_new_tokens): idx_cond = idx[:, -ctx:] if idx_cond.shape[1] < 3: break with torch.no_grad(): logits, _, _, _ = model(idx_cond) last_logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(last_logits, top_k) last_logits[last_logits < v[:, [-1]]] = float("-inf") probs = F.softmax(last_logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, idx_next], dim=1) generated = idx[0].cpu().tolist() # Trim seed bytes — keep only newly generated portion byte_list = generated[len(seed_bytes):] n_bytes = len(byte_list) return { "repetition_rate_2": repetition_rate(byte_list, n=2), "distinct_2": distinct_n(byte_list, n=2), "distinct_3": distinct_n(byte_list, n=3), "distinct_4": distinct_n(byte_list, n=4), "self_perplexity": self_perplexity(model, byte_list, ctx, device), "printable_fraction": _printable_fraction(byte_list), "byte_diversity": _byte_diversity(byte_list), "n_bytes": n_bytes, }