# ARBS/testing/eval/eval_generation.py
# Uploaded by CLIWorks via huggingface_hub (commit d8bc908, verified).
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn.functional as F
import urllib.request
import sys
import math
from collections import Counter
from trigram import (
VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD,
SPECIAL_VOCAB, MORPHTernaryModel, StickyZoneSTE,
)
# Directory holding the training-run checkpoints, resolved relative to this script
# (falls back to the current directory when __file__ has no directory part).
CKPT_DIR = os.path.join(
    os.path.dirname(__file__) or ".",
    "runs",
    "ternary-v1",
)
# Use the GPU when torch can see one, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_model_from(path):
    """Build a MORPHTernaryModel on DEVICE, optionally restoring saved weights.

    Args:
        path: checkpoint file path, or None to get the freshly initialised
            (untrained) model.

    Returns:
        The model. When ``path`` is given, ``ckpt["model_state_dict"]`` has been
        loaded into it.
    """
    net = MORPHTernaryModel().to(DEVICE)
    if path is not None:
        # NOTE(review): weights_only=False unpickles arbitrary objects from the
        # file; only point this at checkpoints you produced yourself.
        checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
        net.load_state_dict(checkpoint["model_state_dict"])
    return net
@torch.no_grad()
def generate(model, seed_bytes, max_new_tokens=300, temperature=0.8, top_k=40):
    """Autoregressively sample a continuation of ``seed_bytes`` from ``model``.

    Args:
        model: network called as ``model(idx)`` that returns ``(logits, _)``.
            ``logits[:, -1, :]`` is taken as the next-token scores.
        seed_bytes: prompt as a list of token ids.
        max_new_tokens: maximum number of tokens to append.
        temperature: softmax temperature. Values <= 0 select the highest-scoring
            token (greedy decoding) instead of sampling.
        top_k: if not None, sample only among the ``top_k`` highest-scoring
            tokens (clamped to the vocabulary size).

    Returns:
        The prompt followed by the generated tokens, as a list of ints.
        Generation stops early (possibly producing nothing) once the
        conditioning window is shorter than 3 tokens.
    """
    model.eval()
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=DEVICE)
    # Only use bf16 autocast on a GPU. Passing enabled=False avoids the
    # "CUDA is not available" warning torch emits on every call otherwise.
    use_amp = DEVICE == "cuda"
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -CTX:]
        # The model presumably needs at least three context tokens
        # (trigram-style); shorter prompts produce no output. TODO confirm.
        if idx_cond.shape[1] < 3:
            break
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
            logits, _ = model(idx_cond)
        # Upcast before temperature / top-k / softmax. bfloat16 has very
        # little mantissa precision and would distort the distribution.
        last_logits = logits[:, -1, :].float()
        if temperature <= 0:
            idx_next = last_logits.argmax(dim=-1, keepdim=True)
        else:
            # Division creates a new tensor, so the in-place masking below
            # never touches the model's logits.
            last_logits = last_logits / temperature
            if top_k is not None:
                k = min(top_k, last_logits.size(-1))
                v, _ = torch.topk(last_logits, k)
                last_logits[last_logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx[0].cpu().tolist()
def bytes_to_text(byte_list):
    """Render a list of token ids as a human-readable string.

    The rendering rules are:
    - Printable ASCII (32-126) is kept as-is.
    - Line feed and tab are kept.
    - Carriage returns are dropped.
    - Ids >= 256 (non-byte tokens) become ``<id>``.
    - Any other value is shown as a ``\\xNN`` escape.
    """
    whitespace = {10: "\n", 13: "", 9: "\t"}

    def render(code):
        if 32 <= code < 127:
            return chr(code)
        if code in whitespace:
            return whitespace[code]
        if code >= 256:
            return f"<{code}>"
        return f"\\x{code:02x}"

    return "".join(render(code) for code in byte_list)
def byte_repetition_rate(byte_list):
    """Return the fraction of adjacent token pairs that duplicate an earlier pair.

    Computed as ``1 - distinct_bigrams / total_bigrams``. Returns 0.0 when
    there are fewer than two tokens.
    """
    pair_total = len(byte_list) - 1
    if pair_total < 1:
        return 0.0
    distinct_pairs = set(zip(byte_list, byte_list[1:]))
    return 1.0 - len(distinct_pairs) / pair_total
def byte_diversity(byte_list):
    """Return the fraction of the 256 possible byte values present in ``byte_list``.

    Ids >= 256 (non-byte tokens) are ignored.
    """
    seen_bytes = {code for code in byte_list if code < 256}
    return len(seen_bytes) / 256.0
# Frequent English / early-modern English words plus stage-direction keywords.
# Built once at import time rather than on every call.
_COMMON_WORDS = frozenset({
    "the", "and", "that", "have", "for", "not", "with", "you", "this", "but",
    "his", "they", "her", "she", "will", "would", "there", "their", "what", "which",
    "out", "all", "were", "your", "when", "who", "him", "been", "has", "more",
    "my", "than", "its", "can", "no", "do", "is", "it", "me", "so", "as", "if",
    "am", "be", "of", "at", "by", "an", "or", "in", "to", "a", "i", "on", "we",
    "our", "us", "from", "them", "he", "was", "are", "had", "did", "shall",
    "king", "lord", "sir", "come", "good", "love", "make", "thee", "thou",
    "now", "here", "then", "where", "how", "why", "let", "go", "must",
    "enter", "exit", "exeunt", "act", "scene",
})
# Punctuation trimmed from both ends of each word before the lookup.
_EDGE_PUNCTUATION = ".,:;!?\"'()"


def english_word_fraction(byte_list):
    """Return the fraction of whitespace-separated words found in ``_COMMON_WORDS``.

    The tokens are rendered with ``bytes_to_text`` and lower-cased. Each word
    is stripped of surrounding punctuation before it is looked up. Returns 0.0
    when the text contains no words.
    """
    words = bytes_to_text(byte_list).lower().split()
    if not words:
        return 0.0
    recognized = sum(1 for w in words if w.strip(_EDGE_PUNCTUATION) in _COMMON_WORDS)
    return recognized / len(words)
def shakespeare_character_ratio(byte_list):
    """Return the fraction of non-blank lines that look like a speaker tag.

    A line counts when it contains ':' and the text before the first ':'
    (whitespace-trimmed) is upper-case, e.g. ``ROMEO:`` or ``KING RICHARD III:``.
    Returns 0.0 when there are no non-blank lines.
    """
    non_blank = [ln.strip() for ln in bytes_to_text(byte_list).split("\n") if ln.strip()]
    speaker_lines = sum(
        1
        for ln in non_blank
        if ":" in ln and ln.partition(":")[0].strip().isupper()
    )
    return speaker_lines / max(len(non_blank), 1)
def printable_fraction(byte_list):
    """Return the fraction of tokens that are printable ASCII or tab/LF/CR.

    Returns 0.0 for an empty input.
    """
    allowed_controls = (10, 13, 9)
    hits = sum((32 <= code < 127) or code in allowed_controls for code in byte_list)
    return hits / max(len(byte_list), 1)
# Prompts used for generation, as lists of byte values.
# generate() emits nothing when its context is shorter than 3 tokens, so the
# "blank" prompt is three newlines (a single newline produced no output at all).
SEEDS = {
    "romeo": list(b"ROMEO:\nWhat light through yonder window breaks?\n"),
    "king": list(b"KING RICHARD III:\nNow is the winter of our discontent\n"),
    "hamlet": list(b"HAMLET:\nTo be, or not to be, that is the question:\n"),
    "macbeth": list(b"MACBETH:\nTomorrow, and tomorrow, and tomorrow\n"),
    "blank": list(b"\n\n\n"),
}
# (label, checkpoint path) pairs to evaluate; a None path means the untrained model.
CHECKPOINTS = [
    ("init", None),
    ("step5K", os.path.join(CKPT_DIR, "trigram-morph-step5000.pt")),
    ("best", os.path.join(CKPT_DIR, "trigram-morph-best.pt")),
    ("step13K", os.path.join(CKPT_DIR, "trigram-morph-step13000.pt")),
    ("step25K", os.path.join(CKPT_DIR, "trigram-morph-step25000.pt")),
]
# Sampling temperatures tried for every checkpoint/prompt combination.
TEMPS = [0.5, 0.8, 1.2]
def _generation_metrics(tokens):
    """Return the quality metrics for ``tokens``, keyed as in the results table."""
    return {
        "rep": byte_repetition_rate(tokens),
        "div": byte_diversity(tokens),
        "eng": english_word_fraction(tokens),
        "shk": shakespeare_character_ratio(tokens),
        "prn": printable_fraction(tokens),
    }


def _print_sample(seed_name, temp, m, text, preview_lines=6):
    """Print one sample's metrics followed by its first ``preview_lines`` lines."""
    print(f"\n--- {seed_name} seed, temp={temp} ---")
    print(f" printable={m['prn']:.2%} diversity={m['div']:.2%} repetition={m['rep']:.2%} english={m['eng']:.2%} shakespeare_fmt={m['shk']:.2%}")
    lines = text.split("\n")
    for line in lines[:preview_lines]:
        print(f" | {line}")
    if len(lines) > preview_lines:
        print(f" | ... ({len(text)} chars, {len(lines)} lines)")


def _print_summary(all_results):
    """Print metrics averaged over seeds for each (checkpoint, temperature) pair."""
    print(f"\n\n{'=' * 90}")
    print("GENERATION QUALITY TABLE (averaged across seeds)")
    print(f"{'=' * 90}")
    print(f"{'Checkpoint':<12} {'Temp':>5} {'Print%':>7} {'Divers%':>8} {'Repeat%':>8} {'English%':>9} {'Shakesp%':>9}")
    print(f"{'-'*12} {'-'*5} {'-'*7} {'-'*8} {'-'*8} {'-'*9} {'-'*9}")
    # Group once instead of rescanning every result for each table row.
    grouped = {}
    for r in all_results.values():
        grouped.setdefault((r["ckpt"], r["temp"]), []).append(r)
    for ckpt_label, _ in CHECKPOINTS:
        for temp in TEMPS:
            matching = grouped.get((ckpt_label, temp))
            if not matching:
                continue
            avg = {
                key: sum(r[key] for r in matching) / len(matching)
                for key in ("prn", "div", "rep", "eng", "shk")
            }
            print(f"{ckpt_label:<12} {temp:>5.1f} {avg['prn']:>7.1%} {avg['div']:>8.1%} {avg['rep']:>8.1%} {avg['eng']:>9.1%} {avg['shk']:>9.1%}")


def main():
    """Sample from each checkpoint and print per-sample and averaged quality metrics.

    The function iterates over every combination of:
    - a checkpoint in CHECKPOINTS,
    - a prompt in SEEDS,
    - a temperature in TEMPS.

    For each combination it generates ``n_gen`` tokens, then prints the
    metrics and the first lines of the sample. Finally it prints a table of
    metrics averaged over prompts.

    Metrics are computed on the generated continuation only, so the
    Shakespeare prompt does not inflate the scores. The printed text still
    includes the prompt. Checkpoint files that do not exist are reported and
    skipped instead of aborting the run.
    """
    print(f"Device: {DEVICE}")
    print("=" * 90)
    n_gen = 400
    all_results = {}
    for ckpt_label, ckpt_path in CHECKPOINTS:
        if ckpt_path is not None and not os.path.exists(ckpt_path):
            print(f"\nSKIPPING {ckpt_label}: checkpoint not found at {ckpt_path}")
            continue
        model = load_model_from(ckpt_path)
        print(f"\n{'=' * 90}")
        print(f"CHECKPOINT: {ckpt_label}")
        print(f"{'=' * 90}")
        for seed_name, seed_bytes in SEEDS.items():
            for temp in TEMPS:
                tag = f"{ckpt_label}/{seed_name}/t{temp}"
                tokens = generate(model, seed_bytes, max_new_tokens=n_gen, temperature=temp, top_k=40)
                # generate() returns prompt + continuation; score only the part
                # the model actually produced.
                metrics = _generation_metrics(tokens[len(seed_bytes):])
                text = bytes_to_text(tokens)
                all_results[tag] = {
                    "ckpt": ckpt_label, "seed": seed_name, "temp": temp,
                    **metrics,
                    "text": text,
                }
                _print_sample(seed_name, temp, metrics, text)
        # Free the model (and cached GPU memory) before loading the next checkpoint.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    _print_summary(all_results)


if __name__ == "__main__":
    main()