| import os |
| import sys |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" |
|
|
| import torch |
| import torch.nn.functional as F |
| import urllib.request |
| import sys |
| import math |
| from collections import Counter |
|
|
| from trigram import ( |
| VOCAB, EMBEDDING_DIM, HIDDEN_DIM, FFN_HIDDEN, CTX, THRESHOLD, |
| SPECIAL_VOCAB, MORPHTernaryModel, StickyZoneSTE, |
| ) |
|
|
# Directory holding the training run's checkpoints, resolved relative to this
# script (falls back to the CWD when __file__ has no directory component).
_HERE = os.path.dirname(__file__) or "."
CKPT_DIR = os.path.join(_HERE, "runs", "ternary-v1")

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
def load_model_from(path):
    """Build a MORPHTernaryModel on DEVICE, optionally restoring checkpoint weights.

    Args:
        path: checkpoint file to restore from, or None to return the freshly
            constructed (untrained) model.

    Returns:
        The model, moved to DEVICE.
    """
    model = MORPHTernaryModel().to(DEVICE)
    if path is not None:
        # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
        # only point this at checkpoints you produced yourself.
        checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
    return model
|
|
|
|
@torch.no_grad()
def generate(model, seed_bytes, max_new_tokens=300, temperature=0.8, top_k=40):
    """Autoregressively extend ``seed_bytes`` with tokens sampled from ``model``.

    Args:
        model: network called as ``model(idx) -> (logits, _)``; the logits are
            indexed as ``[batch, position, vocab]`` below.
        seed_bytes: prompt token ids (ints). Prompts shorter than 3 tokens yield
            no new tokens (see the minimum-context check in the loop).
        max_new_tokens: upper bound on the number of tokens appended.
        temperature: softmax temperature. Values <= 0 select greedy (argmax)
            decoding instead of dividing by zero.
        top_k: keep only the ``top_k`` largest logits before sampling; ``None``
            disables the filter. Clamped to the vocabulary size so
            ``torch.topk`` cannot fail.

    Returns:
        list[int]: the seed tokens followed by the generated tokens.
    """
    model.eval()
    idx = torch.tensor([seed_bytes], dtype=torch.long, device=DEVICE)
    # bf16 autocast only on CUDA. On CPU the context is created for the real
    # device type but disabled, so the model runs in its default precision
    # without torch warning that CUDA is unavailable.
    use_autocast = DEVICE == "cuda"
    for _ in range(max_new_tokens):
        # Feed at most the last CTX tokens to the model.
        idx_cond = idx[:, -CTX:]
        # NOTE(review): presumably the trigram model needs >= 3 tokens of
        # context — confirm. idx only grows, so this can only fire on the
        # first iteration, i.e. when the seed itself is too short.
        if idx_cond.shape[1] < 3:
            break
        with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16, enabled=use_autocast):
            logits, _ = model(idx_cond)
        # Do temperature scaling, top-k and softmax in float32: bf16 keeps only
        # ~3 significant digits, which distorts the sampling distribution.
        last_logits = logits[:, -1, :].float()
        if temperature <= 0:
            idx_next = torch.argmax(last_logits, dim=-1, keepdim=True)
        else:
            # Division allocates a new tensor, so the in-place masking below
            # never writes into the model's logits.
            last_logits = last_logits / temperature
            if top_k is not None:
                k = min(top_k, last_logits.size(-1))
                v, _ = torch.topk(last_logits, k)
                last_logits[last_logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx[0].cpu().tolist()
|
|
|
|
# Fixed renderings for the control bytes that get special treatment;
# carriage returns are dropped entirely.
_CONTROL_RENDERINGS = {10: "\n", 13: "", 9: "\t"}


def bytes_to_text(byte_list):
    """Render a sequence of token ids as a human-readable string.

    Ids 32-126 become their ASCII character; newline and tab are kept; carriage
    return is dropped; ids >= 256 (presumably the extra non-byte vocabulary
    entries — confirm) appear as ``<id>``; every other id is shown as a
    ``\\xNN`` escape.
    """
    def render(code):
        if 32 <= code < 127:
            return chr(code)
        if code in _CONTROL_RENDERINGS:
            return _CONTROL_RENDERINGS[code]
        if code >= 256:
            return f"<{code}>"
        return f"\\x{code:02x}"

    return "".join(render(code) for code in byte_list)
|
|
|
|
def byte_repetition_rate(byte_list):
    """Return the share of adjacent token pairs that duplicate an earlier pair.

    Computed as ``1 - distinct_bigrams / total_bigrams``; 0.0 means every
    bigram is unique. Sequences with fewer than two tokens score 0.0.
    """
    pairs = list(zip(byte_list, byte_list[1:]))
    if not pairs:
        return 0.0
    return 1.0 - len(set(pairs)) / len(pairs)
|
|
|
|
def byte_diversity(byte_list):
    """Return the fraction of the 256 byte values that occur in ``byte_list``.

    Ids >= 256 are not byte values and are ignored.
    """
    distinct_bytes = {code for code in byte_list if code < 256}
    return len(distinct_bytes) / 256.0
|
|
|
|
# Common English / Early Modern English function words plus stage-direction
# keywords used as a cheap "looks like English" signal. Built once at import
# time instead of on every call.
_COMMON_WORDS = frozenset({
    "the", "and", "that", "have", "for", "not", "with", "you", "this", "but",
    "his", "they", "her", "she", "will", "would", "there", "their", "what", "which",
    "out", "all", "were", "your", "when", "who", "him", "been", "has", "more",
    "my", "than", "its", "can", "no", "do", "is", "it", "me", "so", "as", "if",
    "am", "be", "of", "at", "by", "an", "or", "in", "to", "a", "i", "on", "we",
    "our", "us", "from", "them", "he", "was", "are", "had", "did", "shall",
    "king", "lord", "sir", "come", "good", "love", "make", "thee", "thou",
    "now", "here", "then", "where", "how", "why", "let", "go", "must",
    "enter", "exit", "exeunt", "act", "scene",
})
# Punctuation stripped from both ends of each word before lookup.
_WORD_PUNCTUATION = ".,:;!?\"'()"


def english_word_fraction(byte_list):
    """Return the fraction of whitespace-separated words found in ``_COMMON_WORDS``.

    The tokens are rendered with ``bytes_to_text`` and lower-cased, and each word
    has surrounding punctuation stripped before lookup. Returns 0.0 when the
    text contains no words.
    """
    words = bytes_to_text(byte_list).lower().split()
    if not words:
        return 0.0
    recognized = sum(1 for w in words if w.strip(_WORD_PUNCTUATION) in _COMMON_WORDS)
    return recognized / len(words)
|
|
|
|
def shakespeare_character_ratio(byte_list):
    """Return the fraction of non-blank lines that look like speaker tags.

    A line counts when it contains ':' and the text before the first ':' is
    upper-case once stripped (e.g. "ROMEO:" or "KING RICHARD III:"). Blank lines
    are ignored; the result is 0.0 when there are no non-blank lines.
    """
    stripped_lines = [ln.strip() for ln in bytes_to_text(byte_list).split("\n")]
    content = [ln for ln in stripped_lines if ln]
    speakers = sum(
        1
        for ln in content
        if ":" in ln and ln.partition(":")[0].strip().isupper()
    )
    return speakers / max(len(content), 1)
|
|
|
|
def printable_fraction(byte_list):
    """Return the share of ids that are printable ASCII or newline/CR/tab.

    An empty sequence scores 0.0.
    """
    def is_printable(code):
        return 32 <= code < 127 or code in (10, 13, 9)

    hits = sum(1 for code in byte_list if is_printable(code))
    return hits / max(len(byte_list), 1)
|
|
|
|
# Prompt texts, keyed by seed name. "blank" is a single newline; it is
# shorter than the 3-token minimum context that generate() checks for.
_SEED_TEXTS = {
    "romeo": b"ROMEO:\nWhat light through yonder window breaks?\n",
    "king": b"KING RICHARD III:\nNow is the winter of our discontent\n",
    "hamlet": b"HAMLET:\nTo be, or not to be, that is the question:\n",
    "macbeth": b"MACBETH:\nTomorrow, and tomorrow, and tomorrow\n",
    "blank": b"\n",
}
# Same prompts as lists of byte-valued token ids.
SEEDS = {name: list(text) for name, text in _SEED_TEXTS.items()}
|
|
# (label, file name) for each saved checkpoint under CKPT_DIR, in report order.
_CKPT_FILES = (
    ("step5K", "trigram-morph-step5000.pt"),
    ("best", "trigram-morph-best.pt"),
    ("step13K", "trigram-morph-step13000.pt"),
    ("step25K", "trigram-morph-step25000.pt"),
)
# "init" has no file: it evaluates a freshly constructed model.
CHECKPOINTS = [("init", None)] + [
    (label, os.path.join(CKPT_DIR, fname)) for label, fname in _CKPT_FILES
]

# Sampling temperatures evaluated for every checkpoint/seed pair.
TEMPS = [0.5, 0.8, 1.2]
|
|
|
|
def _score_generation(generated):
    """Compute every quality metric for a list of generated token ids."""
    return {
        "rep": byte_repetition_rate(generated),
        "div": byte_diversity(generated),
        "eng": english_word_fraction(generated),
        "shk": shakespeare_character_ratio(generated),
        "prn": printable_fraction(generated),
    }


def _evaluate_checkpoint(model, ckpt_label, n_gen, all_results):
    """Sample from ``model`` for every seed/temperature pair, print each sample,
    and record its metrics in ``all_results`` keyed by "ckpt/seed/tTEMP"."""
    for seed_name, seed_bytes in SEEDS.items():
        for temp in TEMPS:
            tag = f"{ckpt_label}/{seed_name}/t{temp}"
            tokens = generate(model, seed_bytes, max_new_tokens=n_gen, temperature=temp, top_k=40)
            text = bytes_to_text(tokens)
            # Score only the continuation. The seed is hand-written Shakespeare
            # and would inflate every metric, most visibly for the untrained
            # "init" model.
            generated = tokens[len(seed_bytes):]

            print(f"\n--- {seed_name} seed, temp={temp} ---")
            if not generated:
                # generate() stops immediately on seeds that are too short;
                # recording zeros here would silently drag the averages down.
                print(f" (no tokens generated from a {len(seed_bytes)}-token seed; excluded from averages)")
            else:
                metrics = _score_generation(generated)
                all_results[tag] = {
                    "ckpt": ckpt_label, "seed": seed_name, "temp": temp,
                    **metrics,
                    "text": text,
                }
                print(f" printable={metrics['prn']:.2%} diversity={metrics['div']:.2%} repetition={metrics['rep']:.2%} english={metrics['eng']:.2%} shakespeare_fmt={metrics['shk']:.2%}")

            lines = text.split("\n")
            for line in lines[:6]:
                print(f" | {line}")
            if len(lines) > 6:
                print(f" | ... ({len(text)} chars, {len(lines)} lines)")


def _print_summary(all_results):
    """Print one row per (checkpoint, temperature) with metrics averaged over seeds."""
    print(f"\n\n{'=' * 90}")
    print("GENERATION QUALITY TABLE (averaged across seeds)")
    print("(metrics cover generated tokens only; seed prompts are excluded)")
    print(f"{'=' * 90}")
    print(f"{'Checkpoint':<12} {'Temp':>5} {'Print%':>7} {'Divers%':>8} {'Repeat%':>8} {'English%':>9} {'Shakesp%':>9}")
    print(f"{'-'*12} {'-'*5} {'-'*7} {'-'*8} {'-'*8} {'-'*9} {'-'*9}")

    for ckpt_label, _ in CHECKPOINTS:
        for temp in TEMPS:
            matching = [r for r in all_results.values() if r["ckpt"] == ckpt_label and r["temp"] == temp]
            # Skipped checkpoints and seeds with no output leave gaps.
            if not matching:
                continue
            avg = {
                key: sum(r[key] for r in matching) / len(matching)
                for key in ("prn", "div", "rep", "eng", "shk")
            }
            print(f"{ckpt_label:<12} {temp:>5.1f} {avg['prn']:>7.1%} {avg['div']:>8.1%} {avg['rep']:>8.1%} {avg['eng']:>9.1%} {avg['shk']:>9.1%}")


def main():
    """Sample from each checkpoint in CHECKPOINTS, print samples and metrics,
    then print a summary table averaged across seeds."""
    print(f"Device: {DEVICE}")
    print("=" * 90)

    n_gen = 400
    all_results = {}

    for ckpt_label, ckpt_path in CHECKPOINTS:
        # A missing file would otherwise abort the whole run before the
        # summary is printed; report it and evaluate the rest.
        if ckpt_path is not None and not os.path.exists(ckpt_path):
            print(f"\nSKIPPING {ckpt_label}: checkpoint not found at {ckpt_path}")
            continue

        model = load_model_from(ckpt_path)
        print(f"\n{'=' * 90}")
        print(f"CHECKPOINT: {ckpt_label}")
        print(f"{'=' * 90}")

        _evaluate_checkpoint(model, ckpt_label, n_gen, all_results)

        # Release this model before loading the next one to limit peak GPU memory.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    _print_summary(all_results)
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|