import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Inline all model code so the Space is self-contained ──────────────────────
# ── RMSNorm ───────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    def __init__(self, cfg, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(cfg["emb_dim"]))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        return (x / (rms + self.eps)) * self.gamma
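
# RMSNorm rescales each token vector by its root-mean-square magnitude:
#   rms(x) = sqrt(mean_i(x_i^2)),   y = x / (rms + eps) * gamma
# Unlike LayerNorm it subtracts no mean and has no bias, which is cheaper
# and is the normalization used throughout the LLaMA family.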

# ── SwiGLU FFN ──────────────────────────────────────────────────────────────────
class SwiGLUFFN(nn.Module):
    def __init__(self, cfg, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(cfg["emb_dim"], hidden_dim)
        self.w2 = nn.Linear(cfg["emb_dim"], hidden_dim)
        self.w3 = nn.Linear(hidden_dim, cfg["emb_dim"])

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
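
# SwiGLU(x) = W3( SiLU(W1 x) ⊙ W2 x ): a gated FFN in which one projection
# is passed through SiLU and used to gate the other. With three matrices
# instead of two, hidden_dim = 1376 (rather than the classic 4 * 512 = 2048)
# keeps the parameter count comparable; 1376 is likely (8/3) * 512 ≈ 1365
# rounded up to a multiple of 32, the usual LLaMA sizing convention.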

# ── RoPE ────────────────────────────────────────────────────────────────────────
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)
    return angles.cos(), angles.sin()

def apply_rope(x, cos, sin):
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos = cos[:x.shape[2]].unsqueeze(0).unsqueeze(0)
    sin = sin[:x.shape[2]].unsqueeze(0).unsqueeze(0)
    out = torch.stack([x_even * cos - x_odd * sin,
                       x_even * sin + x_odd * cos], dim=-1).flatten(-2)
    return out
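
# RoPE rotates each (even, odd) feature pair by a position-dependent angle:
#   theta_i = 10000^(-2i / d),   angle(pos, i) = pos * theta_i,
# i.e. the 2-D rotation [cos -sin; sin cos] applied pairwise. Because a
# rotation by query position m composed with the transpose of a rotation by
# key position n depends only on (m - n), the attention scores become a
# function of relative position, with no learned position embeddings needed.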

# ── Grouped Query Attention ─────────────────────────────────────────────────────
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, n_kv_heads, context_length, dropout):
        super().__init__()
        assert d_out % num_heads == 0
        assert num_heads % n_kv_heads == 0
        self.d_out = d_out
        self.num_heads = num_heads
        self.n_kv_heads = n_kv_heads
        self.dim_head = d_out // num_heads
        self.n_rep = num_heads // n_kv_heads
        kv_dim = n_kv_heads * self.dim_head
        self.W_Query = nn.Linear(d_in, d_out, bias=False)
        self.W_Key = nn.Linear(d_in, kv_dim, bias=False)
        self.W_Value = nn.Linear(d_in, kv_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
        cos, sin = precompute_rope_freqs(self.dim_head, context_length)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, x):
        b, t, _ = x.shape
        Q = self.W_Query(x).view(b, t, self.num_heads, self.dim_head).transpose(1, 2)
        K = self.W_Key(x).view(b, t, self.n_kv_heads, self.dim_head).transpose(1, 2)
        V = self.W_Value(x).view(b, t, self.n_kv_heads, self.dim_head).transpose(1, 2)
        Q = apply_rope(Q, self.rope_cos, self.rope_sin)
        K = apply_rope(K, self.rope_cos, self.rope_sin)
        K = K.repeat_interleave(self.n_rep, dim=1)
        V = V.repeat_interleave(self.n_rep, dim=1)
        scores = Q @ K.transpose(-2, -1)
        mask = self.mask[:t, :t].unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask.bool(), -torch.inf)
        w = self.dropout(torch.softmax(scores / self.dim_head ** 0.5, dim=-1))
        out = (w @ V).transpose(1, 2).contiguous().view(b, t, self.d_out)
        return out
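
# GQA: with this config, 8 query heads share 4 key/value heads (n_rep = 2),
# so the K/V projections are half the size of the Q projection.
# repeat_interleave expands K and V back to 8 heads for standard scaled
# dot-product attention with a causal (upper-triangular) mask; the benefit
# is a KV cache and projection weight count reduced by num_heads / n_kv_heads.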

# ── Transformer Block ───────────────────────────────────────────────────────────
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg)
        self.GQA = GroupedQueryAttention(cfg["emb_dim"], cfg["emb_dim"],
                                         cfg["n_heads"], cfg["n_kv_heads"],
                                         cfg["context_length"], cfg["dropout"])
        self.norm2 = RMSNorm(cfg)
        self.FF = SwiGLUFFN(cfg, cfg["ffn_hidden"])
        self.drop = nn.Dropout(cfg["dropout"])

    def forward(self, x):
        x = x + self.drop(self.GQA(self.norm1(x)))
        x = x + self.drop(self.FF(self.norm2(x)))
        return x
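
# Pre-norm residual layout (norm -> sublayer -> dropout -> add), the same
# ordering as GPT-2 and LLaMA: normalizing the input of each sublayer rather
# than its output keeps gradients well-behaved in deeper stacks.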

# ── GPT ─────────────────────────────────────────────────────────────────────────
class GPT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.dropout = nn.Dropout(cfg["dropout"])
        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        self.final_norm = RMSNorm(cfg)
        self.output_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
        self.output_head.weight = self.token_embedding.weight

    def forward(self, x):
        x = self.dropout(self.token_embedding(x))
        x = self.trf_blocks(x)
        return self.output_head(self.final_norm(x))
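
# Weight tying: the output head reuses the embedding matrix, so the logits
# are x @ E^T. This saves a vocab_size * emb_dim = 16384 * 512 ≈ 8.4M
# parameter matrix, a large share of the budget at this model size.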

# ── Config ──────────────────────────────────────────────────────────────────────
MODEL_CONFIG = {
    "vocab_size": 16384,
    "context_length": 512,
    "emb_dim": 512,
    "n_heads": 8,
    "n_kv_heads": 4,
    "n_layers": 8,
    "ffn_hidden": 1376,
    "dropout": 0.0,
    "norm_eps": 1e-6,
}
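
# These hyperparameters must match the checkpoint exactly: every tensor
# shape in the state_dict is derived from them, so load_state_dict below
# will raise on any mismatch. Note that "norm_eps" is listed here but
# RMSNorm above keeps its own default eps (1e-8) and never reads this
# entry; only the shapes matter for loading.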

# ── Load model & tokenizer once at startup ──────────────────────────────────────
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

REPO_ID = "ziadkassem/StoryGPT"

print("Downloading model weights...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pt")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="storygpt_tokenizer.json")
tokenizer = Tokenizer.from_file(tok_path)

device = torch.device("cpu")  # Spaces free tier is CPU-only
model = GPT(MODEL_CONFIG)
weights = torch.load(weights_path, map_location=device)
if list(weights.keys())[0].startswith("module."):
    # Strip only the leading "module." prefix from each key.
    weights = {k.replace("module.", "", 1): v for k, v in weights.items()}
model.load_state_dict(weights)
model.eval()
print("Model loaded!")

# ── Generation logic ────────────────────────────────────────────────────────────
def generate_story(prompt: str, max_tokens: int, temperature: float, top_k: int) -> str:
    # Gradio sliders deliver floats; cast before using these as counts.
    max_tokens, top_k = int(max_tokens), int(top_k)
    bos_id = tokenizer.token_to_id("<bos>")
    eos_id = tokenizer.token_to_id("<eos>")
    ids = [bos_id] + tokenizer.encode(prompt).ids
    idx = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_tokens):
            idx_cond = idx[:, -MODEL_CONFIG["context_length"]:]
            logits = model(idx_cond)[:, -1, :]
            logits = logits / temperature
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == eos_id:
                break
            idx = torch.cat([idx, next_id], dim=1)
    return tokenizer.decode(idx.squeeze(0).tolist())
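
# Sampling: logits are divided by the temperature T (p_i ∝ exp(z_i / T), so
# T < 1 sharpens and T > 1 flattens the distribution), then every logit below
# the top_k-th largest is set to -inf before the softmax, restricting
# torch.multinomial to the k most likely tokens. Generation stops early when
# the model emits <eos>.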

# ── Gradio UI ───────────────────────────────────────────────────────────────────
with gr.Blocks(title="StoryGPT", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 📖 StoryGPT
    A **50M-parameter** LLaMA-style language model pre-trained from scratch on TinyStories.

    Architecture: **GQA · RoPE · RMSNorm · SwiGLU** | Trained with **AMP + cosine LR** | Perplexity: **4.09**
    """)
    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Story Prompt",
                placeholder="Once upon a time,",
                lines=2,
            )
            with gr.Row():
                max_tokens = gr.Slider(50, 400, value=200, step=10, label="Max Tokens")
                temperature = gr.Slider(0.5, 1.5, value=0.8, step=0.05, label="Temperature")
                top_k = gr.Slider(5, 100, value=40, step=5, label="Top-K")
            btn = gr.Button("✨ Generate Story", variant="primary")
        with gr.Column(scale=3):
            output_box = gr.Textbox(label="Generated Story", lines=12, interactive=False)
    btn.click(
        fn=generate_story,
        inputs=[prompt_box, max_tokens, temperature, top_k],
        outputs=output_box,
    )
    gr.Examples(
        examples=[
            ["Once upon a time, there was a little girl named Lily who"],
            ["One sunny day, a puppy named Max found a"],
            ["The brave knight looked at the dragon and said"],
        ],
        inputs=prompt_box,
    )

demo.launch()
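
# demo.launch() starts the server on the default host/port that Spaces
# expects; the same file also runs locally with `python app.py`. Under load,
# calling demo.queue() before launch() would serialize requests, since the
# model is shared module-level state (an optional hardening, not part of the
# original app).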