File size: 4,556 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
eval_metrics.py — Generation quality metrics and BPB/perplexity helpers.

Provides:
  - bpb_from_loss()        : BPB = avg_loss / ln(2)  (D-97)
  - perplexity_from_loss() : Perplexity = exp(avg_loss)  (EVAL-02)
  - repetition_rate()      : n-gram repetition rate (generalization of eval_generation.byte_repetition_rate)
  - distinct_n()           : Standard NLP distinct-n metric
  - self_perplexity()      : Model's own perplexity on generated text
  - assess_generation_quality() : Full generation quality assessment dict
"""

import math
import sys
import os

import torch
import torch.nn.functional as F

# Put the directory containing this file at the front of sys.path so that
# modules located next to it can be imported by bare name.
sys.path.insert(0, os.path.dirname(__file__))


def bpb_from_loss(avg_loss):
    """
    Convert an average loss in nats into bits-per-byte (D-97).

    Uses the batch-average shortcut BPB = avg_loss / ln(2), i.e. a plain
    change of logarithm base from e to 2.
    """
    nats_per_bit = math.log(2)
    bits_per_byte = avg_loss / nats_per_bit
    return bits_per_byte


def perplexity_from_loss(avg_loss):
    """
    Convert an average loss in nats into perplexity (EVAL-02).

    Perplexity is defined here as exp(avg_loss).
    """
    perplexity = math.exp(avg_loss)
    return perplexity


def repetition_rate(byte_list, n=2):
    """
    Share of n-gram windows that duplicate another window in the sequence.

    Computed as 1.0 - (distinct n-grams / total n-grams) over every
    contiguous window of length n. Sequences shorter than n have no
    windows and score 0.0.

    Generalizes eval_generation.byte_repetition_rate (bigram only) to any n.
    """
    length = len(byte_list)
    if length < n:
        return 0.0
    # A sequence of `length` items has exactly length - n + 1 windows; the
    # guard above guarantees this is at least 1.
    window_count = length - n + 1
    distinct = {tuple(byte_list[start:start + n]) for start in range(window_count)}
    return 1.0 - len(distinct) / window_count


def distinct_n(byte_list, n):
    """
    Standard NLP distinct-n: distinct n-grams divided by total n-grams.

    Every contiguous window of length n counts as one n-gram. Sequences
    shorter than n have no windows and score 0.0.
    """
    length = len(byte_list)
    if length < n:
        return 0.0
    # A sequence of `length` items has exactly length - n + 1 windows; the
    # guard above guarantees this is at least 1.
    window_count = length - n + 1
    distinct = {tuple(byte_list[start:start + n]) for start in range(window_count)}
    return len(distinct) / window_count


def _printable_fraction(byte_list):
    """Fraction of bytes that are printable ASCII or common whitespace."""
    printable = sum(1 for b in byte_list if (32 <= b < 127) or b in (10, 13, 9))
    return printable / max(len(byte_list), 1)


def _byte_diversity(byte_list):
    """Fraction of unique byte values used (out of 256)."""
    unique = len(set(b for b in byte_list if b < 256))
    return unique / 256.0


@torch.no_grad()
def self_perplexity(model, byte_list, ctx, device):
    """
    Perplexity the model assigns to a byte sequence (typically its own output).

    The whole sequence is pushed through the model in a single forward pass,
    the `total` field of the returned loss components is read as the average
    NLL, and exp(avg_nll) is returned. This is the model's confidence on its
    own output (D-98 discretion: use self-perplexity instead of KenLM).

    `ctx` is accepted but not used by this function. The model is switched
    to eval mode and is not switched back.
    """
    model.eval()
    sequence = torch.tensor([byte_list], dtype=torch.long, device=device)
    # Targets drop the first three positions of the input.
    # NOTE(review): presumably the model needs a 3-byte lead-in before it
    # starts predicting -- confirm against the model's forward().
    shifted_targets = sequence[:, 3:]
    # Gradient tracking is already disabled by the decorator.
    _out0, loss_components, _out2, _out3 = model(sequence, targets=shifted_targets)
    mean_nll = loss_components.total.item()
    return math.exp(mean_nll)


@torch.no_grad()
def assess_generation_quality(model, seed_bytes, max_new_tokens=500, ctx=64,
                              device="cuda", temperature=0.8, top_k=40):
    """
    Sample a continuation from the model and compute all quality metrics on it.

    Up to `max_new_tokens` bytes are sampled one at a time with temperature
    scaling and optional top-k filtering, conditioning on at most the last
    `ctx` bytes. Generation stops immediately if the conditioning window is
    shorter than 3 bytes (i.e. the seed has fewer than 3 bytes). Metrics are
    computed on the newly generated bytes only; the seed is excluded.

    Args:
        model: Callable returning a 4-tuple whose first element is the
            logits; also called by self_perplexity(). Switched to eval mode
            and left there.
        seed_bytes: Sequence of byte values used as the prompt.
        max_new_tokens: Maximum number of bytes to sample.
        ctx: Maximum number of trailing bytes fed to the model per step.
        device: Torch device on which the token tensor is created.
        temperature: Divisor applied to the last-step logits before sampling.
        top_k: If not None, only the k highest logits stay eligible. Values
            larger than the vocabulary size are clamped to it.

    Returns dict with keys:
      repetition_rate_2, distinct_2, distinct_3, distinct_4,
      self_perplexity, printable_fraction, byte_diversity, n_bytes

    `self_perplexity` is NaN when no bytes were generated.
    """
    model.eval()

    # Generate byte sequence (gradients are already disabled by the decorator).
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ctx:]
        if idx_cond.shape[1] < 3:
            break
        logits, _, _, _ = model(idx_cond)
        last_logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # torch.topk raises if k exceeds the size of the dimension, so
            # clamp k to the vocabulary size.
            k = min(top_k, last_logits.size(-1))
            v, _ = torch.topk(last_logits, k)
            # Everything below the k-th largest logit gets zero probability.
            last_logits[last_logits < v[:, [-1]]] = float("-inf")
        probs = F.softmax(last_logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)

    generated = idx[0].cpu().tolist()
    # Trim seed bytes — keep only newly generated portion
    byte_list = generated[len(seed_bytes):]
    n_bytes = len(byte_list)

    if byte_list:
        self_ppl = self_perplexity(model, byte_list, ctx, device)
    else:
        # Nothing was generated (seed shorter than 3 bytes, or
        # max_new_tokens <= 0): do not forward an empty tensor through the
        # model; report the perplexity as undefined instead.
        self_ppl = float("nan")

    return {
        "repetition_rate_2": repetition_rate(byte_list, n=2),
        "distinct_2": distinct_n(byte_list, n=2),
        "distinct_3": distinct_n(byte_list, n=3),
        "distinct_4": distinct_n(byte_list, n=4),
        "self_perplexity": self_ppl,
        "printable_fraction": _printable_fraction(byte_list),
        "byte_diversity": _byte_diversity(byte_list),
        "n_bytes": n_bytes,
    }