"""Sample from MORPH ternary checkpoints and score generation quality.

For every checkpoint in CHECKPOINTS, generate a continuation of every seed in
SEEDS at every temperature in TEMPS, print each sample with a few cheap
byte-level quality metrics, then print a table of the metrics averaged across
seeds.
"""
import os
import sys
import contextlib

# Make modules two directories above this file importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
# Set before importing torch so the CUDA allocator sees this configuration.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn.functional as F
import urllib.request
import sys
import math
from collections import Counter

from trigram import (
    VOCAB,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    FFN_HIDDEN,
    CTX,
    THRESHOLD,
    SPECIAL_VOCAB,
    MORPHTernaryModel,
    StickyZoneSTE,
)

CKPT_DIR = os.path.join(os.path.dirname(__file__) or ".", "runs", "ternary-v1")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Minimum context length generate() will sample from; shorter prompts
# produce no new tokens.
MIN_CONTEXT = 3

# Metric keys stored per sample and averaged in the summary table.
_METRIC_KEYS = ("prn", "div", "rep", "eng", "shk")

# Lower-case words counted as "English" by english_word_fraction().
# Built once at import time instead of on every call.
_COMMON_WORDS = frozenset({
    "the", "and", "that", "have", "for", "not", "with", "you", "this", "but",
    "his", "they", "her", "she", "will", "would", "there", "their", "what", "which",
    "out", "all", "were", "your", "when", "who", "him", "been", "has", "more",
    "my", "than", "its", "can", "no", "do", "is", "it", "me", "so", "as", "if",
    "am", "be", "of", "at", "by", "an", "or", "in", "to", "a", "i", "on", "we",
    "our", "us", "from", "them", "he", "was", "are", "had", "did", "shall",
    "king", "lord", "sir", "come", "good", "love", "make", "thee", "thou",
    "now", "here", "then", "where", "how", "why", "let", "go", "must",
    "enter", "exit", "exeunt", "act", "scene",
})


def load_model_from(path):
    """Build a MORPHTernaryModel on DEVICE, optionally loading saved weights.

    Args:
        path: checkpoint file path, or None to return the freshly initialised
            (untrained) model.

    Returns:
        The model on DEVICE. A checkpoint must be a dict holding the weights
        under "model_state_dict".
    """
    model = MORPHTernaryModel().to(DEVICE)
    if path is None:
        return model
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints produced by a trusted training run.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    return model


def _autocast_ctx():
    """Return a fresh bf16 autocast context on CUDA, or a no-op context on CPU.

    Requesting CUDA autocast on a CPU-only host only produces a warning and is
    then disabled, so a null context gives the same fp32 computation quietly.
    """
    if DEVICE == "cuda":
        return torch.autocast("cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()


@torch.no_grad()
def generate(model, seed_bytes, max_new_tokens=300, temperature=0.8, top_k=40):
    """Autoregressively sample tokens from `model`, starting from `seed_bytes`.

    Args:
        model: called as `model(idx)` and expected to return `(logits, _)`
            with the last-position logits at `logits[:, -1, :]`.
        seed_bytes: list of token ids (byte values) used as the prompt.
        max_new_tokens: maximum number of tokens to append.
        temperature: divisor applied to the logits before softmax.
        top_k: if not None, sample only among the k highest-scoring tokens
            (clamped to the vocabulary size).

    Returns:
        List of token ids: the seed followed by the sampled tokens. If the
        context is shorter than MIN_CONTEXT tokens, nothing is sampled and
        the seed is returned unchanged.
    """
    model.eval()
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=DEVICE)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -CTX:]
        if idx_cond.shape[1] < MIN_CONTEXT:
            break
        with _autocast_ctx():
            logits, _ = model(idx_cond)
        # Upcast so temperature scaling, softmax and sampling run in fp32
        # rather than on bf16 autocast output.
        last_logits = logits[:, -1, :].float() / temperature
        if top_k is not None:
            k = min(top_k, last_logits.size(-1))
            v, _ = torch.topk(last_logits, k)
            last_logits[last_logits < v[:, [-1]]] = float("-inf")
        probs = F.softmax(last_logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx[0].cpu().tolist()


def bytes_to_text(byte_list):
    """Render token ids as readable text.

    Printable ASCII is kept as-is; newline and tab are kept; carriage return
    is dropped; ids >= 256 become "<id>"; other bytes become "\\xNN".
    """
    readable = []
    for b in byte_list:
        if 32 <= b < 127:
            readable.append(chr(b))
        elif b == 10:
            readable.append("\n")
        elif b == 13:
            readable.append("")
        elif b == 9:
            readable.append("\t")
        elif b >= 256:
            readable.append(f"<{b}>")
        else:
            readable.append(f"\\x{b:02x}")
    return "".join(readable)


def byte_repetition_rate(byte_list):
    """Return 1 - (distinct adjacent pairs / total adjacent pairs); 0.0 if < 2 tokens."""
    if len(byte_list) < 2:
        return 0.0
    bigrams = list(zip(byte_list, byte_list[1:]))
    return 1.0 - len(set(bigrams)) / len(bigrams)


def byte_diversity(byte_list):
    """Return the number of distinct byte values (< 256) divided by 256."""
    unique = len({b for b in byte_list if b < 256})
    return unique / 256.0


def english_word_fraction(byte_list):
    """Return the fraction of whitespace-separated words found in _COMMON_WORDS.

    Words are lower-cased and stripped of surrounding punctuation before the
    lookup. Returns 0.0 when there are no words.
    """
    words = bytes_to_text(byte_list).lower().split()
    if not words:
        return 0.0
    recognized = sum(1 for w in words if w.strip(".,:;!?\"'()") in _COMMON_WORDS)
    return recognized / len(words)


def shakespeare_character_ratio(byte_list):
    """Return the fraction of non-blank lines that look like "SPEAKER: ..." headers.

    A line counts when it contains ':' and the text before the first ':' is
    upper-case. Returns 0.0 when there are no non-blank lines.
    """
    char_lines = 0
    total_lines = 0
    for line in bytes_to_text(byte_list).split("\n"):
        stripped = line.strip()
        if not stripped:
            continue
        total_lines += 1
        if ":" in stripped and stripped.split(":")[0].strip().isupper():
            char_lines += 1
    return char_lines / max(total_lines, 1)


def printable_fraction(byte_list):
    """Return the fraction of tokens that are printable ASCII, LF, CR or TAB."""
    printable = sum(1 for b in byte_list if (32 <= b < 127) or b in (10, 13, 9))
    return printable / max(len(byte_list), 1)


SEEDS = {
    "romeo": list(b"ROMEO:\nWhat light through yonder window breaks?\n"),
    "king": list(b"KING RICHARD III:\nNow is the winter of our discontent\n"),
    "hamlet": list(b"HAMLET:\nTo be, or not to be, that is the question:\n"),
    "macbeth": list(b"MACBETH:\nTomorrow, and tomorrow, and tomorrow\n"),
    "blank": list(b"\n"),
}

CHECKPOINTS = [
    ("init", None),
    ("step5K", os.path.join(CKPT_DIR, "trigram-morph-step5000.pt")),
    ("best", os.path.join(CKPT_DIR, "trigram-morph-best.pt")),
    ("step13K", os.path.join(CKPT_DIR, "trigram-morph-step13000.pt")),
    ("step25K", os.path.join(CKPT_DIR, "trigram-morph-step25000.pt")),
]

TEMPS = [0.5, 0.8, 1.2]


def _score(tokens):
    """Compute every quality metric for `tokens`, keyed as in _METRIC_KEYS."""
    return {
        "rep": byte_repetition_rate(tokens),
        "div": byte_diversity(tokens),
        "eng": english_word_fraction(tokens),
        "shk": shakespeare_character_ratio(tokens),
        "prn": printable_fraction(tokens),
    }


def _print_preview(text, max_lines=6):
    """Print the first `max_lines` lines of `text`, plus a size note if truncated."""
    lines = text.split("\n")
    for line in lines[:max_lines]:
        print(f" | {line}")
    if len(lines) > max_lines:
        print(f" | ... ({len(text)} chars, {len(lines)} lines)")


def _print_summary(all_results):
    """Print the per-checkpoint, per-temperature table of metrics averaged over seeds."""
    print(f"\n\n{'=' * 90}")
    print("GENERATION QUALITY TABLE (averaged across seeds)")
    print(f"{'=' * 90}")
    print(f"{'Checkpoint':<12} {'Temp':>5} {'Print%':>7} {'Divers%':>8} {'Repeat%':>8} {'English%':>9} {'Shakesp%':>9}")
    print(f"{'-'*12} {'-'*5} {'-'*7} {'-'*8} {'-'*8} {'-'*9} {'-'*9}")
    for ckpt_label, _ in CHECKPOINTS:
        for temp in TEMPS:
            matching = [
                r for r in all_results.values()
                if r["ckpt"] == ckpt_label and r["temp"] == temp
            ]
            if not matching:
                continue
            avg = {k: sum(r[k] for r in matching) / len(matching) for k in _METRIC_KEYS}
            print(
                f"{ckpt_label:<12} {temp:>5.1f} {avg['prn']:>7.1%} {avg['div']:>8.1%} "
                f"{avg['rep']:>8.1%} {avg['eng']:>9.1%} {avg['shk']:>9.1%}"
            )


def main():
    """Run the seed x temperature sweep over every checkpoint and print results.

    Metrics are computed on the generated continuation only, so prompt text
    does not inflate the scores. Samples that generated nothing are reported
    and left out of the averages. Checkpoint files that do not exist are
    skipped.
    """
    print(f"Device: {DEVICE}")
    print("=" * 90)
    n_gen = 400
    all_results = {}
    for ckpt_label, ckpt_path in CHECKPOINTS:
        if ckpt_path is not None and not os.path.isfile(ckpt_path):
            print(f"\nSKIP {ckpt_label}: checkpoint not found at {ckpt_path}")
            continue
        model = load_model_from(ckpt_path)
        print(f"\n{'=' * 90}")
        print(f"CHECKPOINT: {ckpt_label}")
        print(f"{'=' * 90}")
        for seed_name, seed_bytes in SEEDS.items():
            for temp in TEMPS:
                tag = f"{ckpt_label}/{seed_name}/t{temp}"
                tokens = generate(model, seed_bytes, max_new_tokens=n_gen, temperature=temp, top_k=40)
                text = bytes_to_text(tokens)
                continuation = tokens[len(seed_bytes):]
                print(f"\n--- {seed_name} seed, temp={temp} ---")
                if not continuation:
                    print(f" (no tokens generated: seed has fewer than {MIN_CONTEXT} tokens; excluded from averages)")
                    continue
                metrics = _score(continuation)
                all_results[tag] = {
                    "ckpt": ckpt_label,
                    "seed": seed_name,
                    "temp": temp,
                    **metrics,
                    "text": text,
                }
                print(
                    f" printable={metrics['prn']:.2%} diversity={metrics['div']:.2%} "
                    f"repetition={metrics['rep']:.2%} english={metrics['eng']:.2%} "
                    f"shakespeare_fmt={metrics['shk']:.2%}"
                )
                _print_preview(text)
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    _print_summary(all_results)


if __name__ == "__main__":
    main()