import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Inline all model code so the Space is self-contained ─────────────────────

# ── RMSNorm ───────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    def __init__(self, cfg, eps=1e-8):
        super().__init__()
        self.eps = cfg.get("norm_eps", eps)
        self.gamma = nn.Parameter(torch.ones(cfg["emb_dim"]))

    def forward(self, x):
        # Scale each token by the reciprocal of its root-mean-square feature norm.
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        return (x / (rms + self.eps)) * self.gamma

# ── SwiGLU FFN ────────────────────────────────────────────────────────────────
class SwiGLUFFN(nn.Module):
    def __init__(self, cfg, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(cfg["emb_dim"], hidden_dim)
        self.w2 = nn.Linear(cfg["emb_dim"], hidden_dim)
        self.w3 = nn.Linear(hidden_dim, cfg["emb_dim"])

    def forward(self, x):
        # SwiGLU: a SiLU-activated branch gates a plain linear branch.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

# ── RoPE ──────────────────────────────────────────────────────────────────────
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)  # (max_seq_len, dim // 2)
    return angles.cos(), angles.sin()

def apply_rope(x, cos, sin):
    # x: (batch, heads, seq_len, dim_head). Rotate interleaved even/odd pairs.
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos = cos[:x.shape[2]].unsqueeze(0).unsqueeze(0)
    sin = sin[:x.shape[2]].unsqueeze(0).unsqueeze(0)
    out = torch.stack([x_even * cos - x_odd * sin,
                       x_even * sin + x_odd * cos], dim=-1).flatten(-2)
    return out

# ── Grouped Query Attention ───────────────────────────────────────────────────
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, n_kv_heads, context_length, dropout):
        super().__init__()
        assert d_out % num_heads == 0
        assert num_heads % n_kv_heads == 0
        self.d_out = d_out
        self.num_heads = num_heads
        self.n_kv_heads = n_kv_heads
        self.dim_head = d_out // num_heads
        self.n_rep = num_heads // n_kv_heads  # query heads per KV head
        kv_dim = n_kv_heads * self.dim_head
        self.W_Query = nn.Linear(d_in, d_out, bias=False)
        self.W_Key = nn.Linear(d_in, kv_dim, bias=False)
        self.W_Value = nn.Linear(d_in, kv_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        cos, sin = precompute_rope_freqs(self.dim_head, context_length)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, x):
        b, t, _ = x.shape
        Q = self.W_Query(x).view(b, t, self.num_heads, self.dim_head).transpose(1, 2)
        K = self.W_Key(x).view(b, t, self.n_kv_heads, self.dim_head).transpose(1, 2)
        V = self.W_Value(x).view(b, t, self.n_kv_heads, self.dim_head).transpose(1, 2)
        Q = apply_rope(Q, self.rope_cos, self.rope_sin)
        K = apply_rope(K, self.rope_cos, self.rope_sin)
        # Expand the shared KV heads so every query head has a matching KV head.
        K = K.repeat_interleave(self.n_rep, dim=1)
        V = V.repeat_interleave(self.n_rep, dim=1)
        scores = Q @ K.transpose(-2, -1)
        mask = self.mask[:t, :t].unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask.bool(), -torch.inf)
        w = self.dropout(torch.softmax(scores / self.dim_head ** 0.5, dim=-1))
        out = (w @ V).transpose(1, 2).contiguous().view(b, t, self.d_out)
        return out

# ── Transformer Block ─────────────────────────────────────────────────────────
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg)
        self.GQA = GroupedQueryAttention(
            cfg["emb_dim"], cfg["emb_dim"], cfg["n_heads"],
            cfg["n_kv_heads"], cfg["context_length"], cfg["dropout"]
        )
        self.norm2 = RMSNorm(cfg)
        self.FF = SwiGLUFFN(cfg, cfg["ffn_hidden"])
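        # Pre-norm residual layout: each sublayer (attention, then FFN) reads an
        # RMSNorm'd copy of the stream and adds its output back in forward().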
        self.drop = nn.Dropout(cfg["dropout"])

    def forward(self, x):
        x = x + self.drop(self.GQA(self.norm1(x)))
        x = x + self.drop(self.FF(self.norm2(x)))
        return x

# ── GPT ───────────────────────────────────────────────────────────────────────
class GPT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.dropout = nn.Dropout(cfg["dropout"])
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = RMSNorm(cfg)
        self.output_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.output_head.weight = self.token_embedding.weight

    def forward(self, x):
        x = self.dropout(self.token_embedding(x))
        x = self.trf_blocks(x)
        return self.output_head(self.final_norm(x))

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_CONFIG = {
    "vocab_size": 16384,
    "context_length": 512,
    "emb_dim": 512,
    "n_heads": 8,
    "n_kv_heads": 4,
    "n_layers": 8,
    "ffn_hidden": 1376,
    "dropout": 0.0,
    "norm_eps": 1e-6,
}

# ── Load model & tokenizer once at startup ────────────────────────────────────
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

REPO_ID = "ziadkassem/StoryGPT"

print("Downloading model weights...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pt")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="storygpt_tokenizer.json")

tokenizer = Tokenizer.from_file(tok_path)
device = torch.device("cpu")  # Spaces free tier is CPU only

model = GPT(MODEL_CONFIG)
weights = torch.load(weights_path, map_location=device)
# Strip the "module." prefix that DataParallel/DDP checkpoints leave behind.
if list(weights.keys())[0].startswith("module."):
    weights = {k.replace("module.", ""): v for k, v in weights.items()}
model.load_state_dict(weights)
model.eval()
print("Model loaded!")

# ── Generation logic ──────────────────────────────────────────────────────────
def generate_story(prompt: str, max_tokens: int, temperature: float, top_k: int) -> str:
    # NOTE: "<s>"/"</s>" are assumed special-token names; adjust them to match
    # whatever BOS/EOS tokens storygpt_tokenizer.json actually defines.
    bos_id = tokenizer.token_to_id("<s>")
    eos_id = tokenizer.token_to_id("</s>")
    ids = ([bos_id] if bos_id is not None else []) + tokenizer.encode(prompt).ids
    idx = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_tokens):
            idx_cond = idx[:, -MODEL_CONFIG["context_length"]:]
            logits = model(idx_cond)[:, -1, :]
            logits = logits / temperature
            # Top-k filtering: mask everything below the k-th largest logit.
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == eos_id:
                break
            idx = torch.cat([idx, next_id], dim=1)
    return tokenizer.decode(idx.squeeze(0).tolist())

# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="StoryGPT", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 📖 StoryGPT
A **50M parameter** LLaMA-style language model pre-trained from scratch on TinyStories.
Architecture: **GQA · RoPE · RMSNorm · SwiGLU** | Trained with **AMP + Cosine LR** | Perplexity: **4.09**
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Story Prompt",
                placeholder="Once upon a time,",
                lines=2,
            )
            with gr.Row():
                max_tokens = gr.Slider(50, 400, value=200, step=10, label="Max Tokens")
                temperature = gr.Slider(0.5, 1.5, value=0.8, step=0.05, label="Temperature")
                top_k = gr.Slider(5, 100, value=40, step=5, label="Top-K")
            btn = gr.Button("✨ Generate Story", variant="primary")
        with gr.Column(scale=3):
            output_box = gr.Textbox(label="Generated Story", lines=12, interactive=False)

    btn.click(
        fn=generate_story,
        inputs=[prompt_box, max_tokens, temperature, top_k],
        outputs=output_box,
    )

    gr.Examples(
        examples=[
            ["Once upon a time, there was a little girl named Lily who"],
            ["One sunny day, a puppy named Max found a"],
            ["The brave knight looked at the dragon and said"],
        ],
        inputs=prompt_box,
    )

demo.launch()
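# ── Example: calling the generator without the UI ─────────────────────────────
# A minimal sketch for local smoke-testing, using the Slider defaults above;
# run it before demo.launch() (which blocks) if you want to try it:
#
#   story = generate_story("Once upon a time,", max_tokens=200,
#                          temperature=0.8, top_k=40)
#   print(story)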