# ── StoryGPT / app.py (Hugging Face Space entry point) ────────────────────────
import gradio as gr
import torch
import torch.nn.functional as F
# ── Inline all model code so the Space is self-contained ──────────────────────
import torch.nn as nn
import math
# ── RMSNorm ───────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescale to unit RMS, then apply a learned gain."""
    def __init__(self, cfg, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(cfg["emb_dim"]))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        return (x / (rms + self.eps)) * self.gamma
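# Quick sanity check (illustrative, not executed at startup): with gamma at its
# initial value of ones, the output should have RMS close to 1 along the last dim.
#   norm = RMSNorm({"emb_dim": 512})
#   y = norm(torch.randn(2, 10, 512))
#   print(y.pow(2).mean(dim=-1).sqrt())  # values close to 1.0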
# ── SwiGLU FFN ────────────────────────────────────────────────────────────────
class SwiGLUFFN(nn.Module):
    """Feed-forward network with a SwiGLU gate: silu(w1(x)) * w2(x), projected back by w3."""
    def __init__(self, cfg, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(cfg["emb_dim"], hidden_dim)
        self.w2 = nn.Linear(cfg["emb_dim"], hidden_dim)
        self.w3 = nn.Linear(hidden_dim, cfg["emb_dim"])

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
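# Shape check (illustrative): the FFN maps (..., emb_dim) -> (..., emb_dim),
# with hidden_dim as the width of the gated intermediate layer.
#   ffn = SwiGLUFFN({"emb_dim": 512}, hidden_dim=1376)
#   assert ffn(torch.randn(2, 10, 512)).shape == (2, 10, 512)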
# ── RoPE ──────────────────────────────────────────────────────────────────────
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0):
    """Precompute cos/sin tables for rotary position embeddings (RoPE)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)  # (max_seq_len, dim // 2)
    return angles.cos(), angles.sin()

def apply_rope(x, cos, sin):
    """Rotate each (even, odd) channel pair of x by a position-dependent angle."""
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos = cos[:x.shape[2]].unsqueeze(0).unsqueeze(0)  # broadcast over batch and heads
    sin = sin[:x.shape[2]].unsqueeze(0).unsqueeze(0)
    out = torch.stack([x_even * cos - x_odd * sin,
                       x_even * sin + x_odd * cos], dim=-1).flatten(-2)
    return out
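# Sanity check (illustrative): RoPE is a pure rotation, so it preserves both the
# tensor shape and the norm along the head dimension. x is laid out as
# (batch, heads, seq, head_dim).
#   cos, sin = precompute_rope_freqs(dim=64, max_seq_len=128)
#   x = torch.randn(1, 8, 16, 64)
#   y = apply_rope(x, cos, sin)
#   assert y.shape == x.shape
#   assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)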
# ── Grouped Query Attention ───────────────────────────────────────────────────
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where num_heads query heads share n_kv_heads key/value heads."""
    def __init__(self, d_in, d_out, num_heads, n_kv_heads, context_length, dropout):
        super().__init__()
        assert d_out % num_heads == 0
        assert num_heads % n_kv_heads == 0
        self.d_out = d_out
        self.num_heads = num_heads
        self.n_kv_heads = n_kv_heads
        self.dim_head = d_out // num_heads
        self.n_rep = num_heads // n_kv_heads  # query heads per KV head
        kv_dim = n_kv_heads * self.dim_head
        self.W_Query = nn.Linear(d_in, d_out, bias=False)
        self.W_Key = nn.Linear(d_in, kv_dim, bias=False)
        self.W_Value = nn.Linear(d_in, kv_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask hides future positions (causal attention).
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
        cos, sin = precompute_rope_freqs(self.dim_head, context_length)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, x):
        b, t, _ = x.shape
        Q = self.W_Query(x).view(b, t, self.num_heads, self.dim_head).transpose(1, 2)
        K = self.W_Key(x).view(b, t, self.n_kv_heads, self.dim_head).transpose(1, 2)
        V = self.W_Value(x).view(b, t, self.n_kv_heads, self.dim_head).transpose(1, 2)
        Q = apply_rope(Q, self.rope_cos, self.rope_sin)
        K = apply_rope(K, self.rope_cos, self.rope_sin)
        # Duplicate each KV head n_rep times so every query head has a matching K/V.
        K = K.repeat_interleave(self.n_rep, dim=1)
        V = V.repeat_interleave(self.n_rep, dim=1)
        scores = Q @ K.transpose(-2, -1)
        mask = self.mask[:t, :t].unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask.bool(), -torch.inf)
        w = self.dropout(torch.softmax(scores / self.dim_head ** 0.5, dim=-1))
        out = (w @ V).transpose(1, 2).contiguous().view(b, t, self.d_out)
        return out
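# Usage sketch (illustrative): with the config below, 8 query heads share 4 KV
# heads, so the K/V projections are half the size of the Q projection.
#   gqa = GroupedQueryAttention(512, 512, num_heads=8, n_kv_heads=4,
#                               context_length=512, dropout=0.0)
#   assert gqa(torch.randn(2, 16, 512)).shape == (2, 16, 512)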
# ── Transformer Block ─────────────────────────────────────────────────────────
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and FFN sublayers, each with a residual connection."""
    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg)
        self.GQA = GroupedQueryAttention(cfg["emb_dim"], cfg["emb_dim"],
                                         cfg["n_heads"], cfg["n_kv_heads"],
                                         cfg["context_length"], cfg["dropout"])
        self.norm2 = RMSNorm(cfg)
        self.FF = SwiGLUFFN(cfg, cfg["ffn_hidden"])
        self.drop = nn.Dropout(cfg["dropout"])

    def forward(self, x):
        x = x + self.drop(self.GQA(self.norm1(x)))
        x = x + self.drop(self.FF(self.norm2(x)))
        return x
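# Note (illustrative): normalization happens *before* each sublayer (pre-norm),
# so the residual stream itself is never normalized; this is the usual
# LLaMA-style layout and tends to be more stable to train than post-norm.
#   blk = TransformerBlock(MODEL_CONFIG)  # MODEL_CONFIG is defined below
#   assert blk(torch.randn(2, 16, 512)).shape == (2, 16, 512)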
# ── GPT ───────────────────────────────────────────────────────────────────────
class GPT(nn.Module):
    """Decoder-only LM: token embedding -> stacked transformer blocks -> tied output head."""
    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.dropout = nn.Dropout(cfg["dropout"])
        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        self.final_norm = RMSNorm(cfg)
        self.output_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.output_head.weight = self.token_embedding.weight

    def forward(self, x):
        x = self.dropout(self.token_embedding(x))
        x = self.trf_blocks(x)
        return self.output_head(self.final_norm(x))
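# Forward-pass sketch (illustrative): token ids in, next-token logits out.
#   model = GPT(MODEL_CONFIG)  # MODEL_CONFIG is defined below
#   ids = torch.randint(0, MODEL_CONFIG["vocab_size"], (1, 12))
#   assert model(ids).shape == (1, 12, MODEL_CONFIG["vocab_size"])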
# ── Config ────────────────────────────────────────────────────────────────────
MODEL_CONFIG = {
    "vocab_size": 16384,
    "context_length": 512,
    "emb_dim": 512,
    "n_heads": 8,
    "n_kv_heads": 4,
    "n_layers": 8,
    "ffn_hidden": 1376,
    "dropout": 0.0,
    "norm_eps": 1e-6,  # note: RMSNorm above uses its own eps default (1e-8); this entry is unused
}
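# Parameter-count sketch (illustrative, not executed at startup); tied weights
# are counted once by .parameters():
#   m = GPT(MODEL_CONFIG)
#   print(sum(p.numel() for p in m.parameters()) / 1e6, "M parameters")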
# ── Load model & tokenizer once at startup ────────────────────────────────────
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

REPO_ID = "ziadkassem/StoryGPT"
print("Downloading model weights...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename="best_model.pt")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="storygpt_tokenizer.json")
tokenizer = Tokenizer.from_file(tok_path)

device = torch.device("cpu")  # Spaces free tier is CPU only
model = GPT(MODEL_CONFIG)
weights = torch.load(weights_path, map_location=device)
# Checkpoints saved from a DataParallel/DistributedDataParallel wrapper prefix
# every key with "module."; strip it so the keys match this plain module.
if list(weights.keys())[0].startswith("module."):
    weights = {k.replace("module.", ""): v for k, v in weights.items()}
model.load_state_dict(weights)
model.eval()
print("Model loaded!")
# ── Generation logic ──────────────────────────────────────────────────────────
def generate_story(prompt: str, max_tokens: int, temperature: float, top_k: int) -> str:
    max_tokens, top_k = int(max_tokens), int(top_k)  # sliders may deliver floats
    bos_id = tokenizer.token_to_id("<bos>")
    eos_id = tokenizer.token_to_id("<eos>")
    ids = [bos_id] + tokenizer.encode(prompt).ids
    idx = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_tokens):
            idx_cond = idx[:, -MODEL_CONFIG["context_length"]:]  # crop to context window
            logits = model(idx_cond)[:, -1, :]                   # next-token logits only
            logits = logits / temperature
            # Top-k filtering: mask out everything below the k-th largest logit.
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == eos_id:
                break
            idx = torch.cat([idx, next_id], dim=1)
    return tokenizer.decode(idx.squeeze(0).tolist())
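# Example call (illustrative), mirroring the defaults wired into the UI below:
#   story = generate_story("Once upon a time,", max_tokens=200,
#                          temperature=0.8, top_k=40)
#   print(story)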
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="StoryGPT", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 📖 StoryGPT
A **50M parameter** LLaMA-style language model pre-trained from scratch on TinyStories.

Architecture: **GQA · RoPE · RMSNorm · SwiGLU** | Trained with **AMP + Cosine LR** | Perplexity: **4.09**
""")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Story Prompt",
                placeholder="Once upon a time,",
                lines=2,
            )
            with gr.Row():
                max_tokens = gr.Slider(50, 400, value=200, step=10, label="Max Tokens")
                temperature = gr.Slider(0.5, 1.5, value=0.8, step=0.05, label="Temperature")
                top_k = gr.Slider(5, 100, value=40, step=5, label="Top-K")
            btn = gr.Button("✨ Generate Story", variant="primary")
        with gr.Column(scale=3):
            output_box = gr.Textbox(label="Generated Story", lines=12, interactive=False)

    btn.click(
        fn=generate_story,
        inputs=[prompt_box, max_tokens, temperature, top_k],
        outputs=output_box,
    )

    gr.Examples(
        examples=[
            ["Once upon a time, there was a little girl named Lily who"],
            ["One sunny day, a puppy named Max found a"],
            ["The brave knight looked at the dragon and said"],
        ],
        inputs=prompt_box,
    )

demo.launch()