""" LUNA 100M — Text Generation / Interactive Chat Usage: python generate.py # interactive REPL python generate.py --prompt "The future of AI is" # single prompt python generate.py --ckpt Base/out/luna_100m/latest.pt # custom checkpoint python generate.py --max_new 200 --temp 0.8 --top_p 0.9 # tune generation """ import sys import math import argparse import torch import torch.nn as nn import torch.nn.functional as F from pathlib import Path # ─── Model (must match train.py exactly) ────────────────────────────────────── class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len=1024): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_seq_len).float() freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos()) self.register_buffer("sin_cached", emb.sin()) def forward(self, seq_len): return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def apply_rotary(x, cos, sin): c = cos.unsqueeze(0).unsqueeze(0) s = sin.unsqueeze(0).unsqueeze(0) return x * c + rotate_half(x) * s class CausalSelfAttention(nn.Module): def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25): super().__init__() self.n_head = n_head self.head_dim = n_embd // n_head self.rot_dim = int(self.head_dim * rotary_pct) self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) self.c_proj = nn.Linear(n_embd, n_embd, bias=True) self.rotary = RotaryEmbedding(self.rot_dim, block_size) def forward(self, x): B, T, C = x.size() qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) cos, sin = self.rotary(T) q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1) k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C)) class MLP(nn.Module): def __init__(self, n_embd): super().__init__() self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True) self.gelu = nn.GELU() self.proj = nn.Linear(4 * n_embd, n_embd, bias=True) def forward(self, x): return self.proj(self.gelu(self.fc(x))) class Block(nn.Module): def __init__(self, n_embd, n_head, block_size): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = CausalSelfAttention(n_embd, n_head, block_size) self.ln2 = nn.LayerNorm(n_embd) self.mlp = MLP(n_embd) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class LUNAModel(nn.Module): def __init__(self, vocab_size=50304, block_size=1024, n_layer=10, n_embd=768, n_head=12): super().__init__() self.block_size = block_size self.wte = nn.Embedding(vocab_size, n_embd) self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) self.lm_head.weight = self.wte.weight # tied def forward(self, idx): x = self.wte(idx) for block in self.blocks: x = block(x) return self.lm_head(self.ln_f(x)) # ─── Generation ─────────────────────────────────────────────────────────────── @torch.no_grad() def generate(model, input_ids, max_new=200, temperature=0.8, top_p=0.9, top_k=50, repetition_penalty=1.1, device="cpu"): model.eval() ids = input_ids.clone().to(device) generated = [] for _ in range(max_new): # Crop to block_size ctx = ids[:, -model.block_size:] logits = model(ctx) # (1, T, V) logits = logits[:, -1, :] # last token # Repetition penalty if repetition_penalty != 1.0: for token_id in set(ids[0].tolist()): logits[0, token_id] /= repetition_penalty logits = logits / max(temperature, 1e-8) # Top-k if top_k > 0: vals, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < vals[:, -1:]] = -float("inf") # Top-p (nucleus) probs = torch.softmax(logits, dim=-1) if top_p < 1.0: sorted_probs, sorted_idx = torch.sort(probs, descending=True) cum = torch.cumsum(sorted_probs, dim=-1) mask = cum - sorted_probs > top_p sorted_probs[mask] = 0.0 sorted_probs /= sorted_probs.sum() next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)] else: next_token = torch.multinomial(probs[0], 1) ids = torch.cat([ids, next_token.view(1, 1)], dim=1) generated.append(next_token.item()) # Stop at EOS if next_token.item() == 50276: break return generated # ─── Load ───────────────────────────────────────────────────────────────────── def load_model(ckpt_path: str, device: str): print(f"Loading checkpoint: {ckpt_path}") ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) # Handle both raw state_dict and {'model': ...} wrappers state = ckpt["model"] if "model" in ckpt else ckpt step = ckpt.get("step", "?") tokens = ckpt.get("tokens_seen", 0) print(f" Step: {step} | Tokens seen: {tokens:,}") model = LUNAModel() model.load_state_dict(state, strict=True) model = model.to(device) model.eval() print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") return model def load_tokenizer(tok_dir: str): try: from transformers import AutoTokenizer tok = AutoTokenizer.from_pretrained(tok_dir) print(f" Tokenizer: {tok_dir} (vocab {tok.vocab_size})") return tok except Exception as e: print(f" ERROR loading tokenizer: {e}") print(" Install: pip install transformers") sys.exit(1) # ─── Entry ──────────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description="LUNA 100M - Text Generation") p.add_argument("--ckpt", default="Base/out/luna_100m/latest.pt") p.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m") p.add_argument("--prompt", default=None, help="Single prompt (else interactive)") p.add_argument("--max_new", type=int, default=200) p.add_argument("--temp", type=float, default=0.8) p.add_argument("--top_p", type=float, default=0.9) p.add_argument("--top_k", type=int, default=50) p.add_argument("--rep_pen", type=float, default=1.1, help="Repetition penalty") p.add_argument("--device", default="auto") return p.parse_args() def run_prompt(model, tokenizer, prompt, args, device): ids = tokenizer.encode(prompt, return_tensors="pt") print(f"\n{'='*60}") print(f"PROMPT: {prompt}") print(f"{'='*60}") print(prompt, end="", flush=True) new_ids = generate( model, ids, max_new=args.max_new, temperature=args.temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.rep_pen, device=device, ) output = tokenizer.decode(new_ids, skip_special_tokens=True) print(output) print(f"{'='*60}") print(f"Generated {len(new_ids)} tokens") def main(): args = parse_args() if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device print(f"\nDevice: {device}") model = load_model(args.ckpt, device) tokenizer = load_tokenizer(args.tok_dir) if args.prompt: run_prompt(model, tokenizer, args.prompt, args, device) return # Interactive REPL print(f"\n{'='*60}") print(" LUNA 100M - Interactive Generation") print(f" Checkpoint: {args.ckpt}") print(f" max_new={args.max_new} temp={args.temp} top_p={args.top_p} top_k={args.top_k}") print(" Type your prompt and press Enter. Ctrl+C to exit.") print(f"{'='*60}\n") while True: try: prompt = input(">>> ").strip() if not prompt: continue run_prompt(model, tokenizer, prompt, args, device) except KeyboardInterrupt: print("\nBye!") break except EOFError: break if __name__ == "__main__": main()