---
language: en
license: mit
tags:
  - pytorch
  - language-model
  - causal-lm
  - llama-style
  - gqa
  - rope
  - swiglu
  - rmsnorm
  - pretrained-from-scratch
datasets:
  - roneneldan/TinyStories
metrics:
  - perplexity
---

# StoryGPT

A 50M-parameter LLaMA-style decoder-only transformer pre-trained from scratch on the TinyStories dataset.

Built as an end-to-end portfolio showcase demonstrating a production-grade LLM pre-training pipeline.

## Model Description

| Component         | Implementation                                    |
| ----------------- | ------------------------------------------------- |
| Attention         | Grouped Query Attention (GQA), as in LLaMA 2/3    |
| Position Encoding | Rotary Embeddings (RoPE)                          |
| Normalization     | RMSNorm                                           |
| Activation        | SwiGLU FFN                                        |
| Weight Tying      | Embedding weight = output head weight             |
| Tokenizer         | Custom BPE trained from scratch (16,384 vocab)    |
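The GQA layout above (8 query heads sharing 4 KV heads) can be sketched in a few lines of PyTorch. This is a shape-only illustration, not the repository's implementation: the projections, RoPE, and KV caching are omitted, and the variable names are made up here.

```python
import torch
import torch.nn.functional as F

# Toy GQA matching the config below: 8 query heads, 4 KV heads, head_dim 64.
B, T, n_heads, n_kv_heads, head_dim = 1, 16, 8, 4, 64

q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_kv_heads, T, head_dim)
v = torch.randn(B, n_kv_heads, T, head_dim)

# Each KV head is shared by n_heads // n_kv_heads query heads, so the K/V
# tensors are repeated along the head dimension before attention.
group = n_heads // n_kv_heads              # 2 query heads per KV head
k = k.repeat_interleave(group, dim=1)      # (B, 8, T, 64)
v = v.repeat_interleave(group, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

GQA keeps the full query-head count while halving KV projection and cache size, which is why it is the default in LLaMA 2/3.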

Config:

```text
vocab_size    : 16,384
context_length: 512
emb_dim       : 512
n_heads       : 8
n_kv_heads    : 4   (GQA)
n_layers      : 8
ffn_hidden    : 1,376
parameters    : ~50M
```
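The SwiGLU FFN listed in the architecture table uses three weight matrices rather than the two of a classic MLP, which is why `ffn_hidden` (1,376) is smaller than the usual 4× expansion. A minimal sketch, assuming the dimensions above (the layer and weight names here are illustrative, not the repository's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """SwiGLU FFN as used in LLaMA: silu(x W_gate) * (x W_up), then W_down."""
    def __init__(self, emb_dim: int = 512, hidden: int = 1376):
        super().__init__()
        self.w_gate = nn.Linear(emb_dim, hidden, bias=False)
        self.w_up   = nn.Linear(emb_dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, emb_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

ffn = SwiGLU()
y = ffn(torch.randn(2, 10, 512))
print(y.shape)  # torch.Size([2, 10, 512])
```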

## Training

- Dataset: TinyStories (150k stories, ~40M tokens)
- Steps: 20,000
- Optimizer: AdamW (β=(0.9, 0.95), weight_decay=0.1)
- LR Schedule: Cosine decay with linear warmup (500 steps), peak 3e-4 → min 3e-5
- Gradient Clipping: 1.0
- Mixed Precision: torch.cuda.amp (AMP float16)
- Hardware: 2× NVIDIA T4 (DataParallel) on Kaggle
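The LR schedule above can be sketched as a plain function of the step counter. The warmup/decay values come from the list; the exact shape of the warmup ramp is an assumption, since only "linear warmup (500 steps)" is stated.

```python
import math

def lr_at(step: int, peak: float = 3e-4, min_lr: float = 3e-5,
          warmup: int = 500, max_steps: int = 20_000) -> float:
    """Linear warmup to `peak` over `warmup` steps, then cosine decay to `min_lr`."""
    if step < warmup:
        return peak * (step + 1) / warmup          # linear ramp: 0 -> peak
    progress = (step - warmup) / (max_steps - warmup)
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * progress))
```

At the midpoint of decay (step 10,250) the rate is exactly halfway between peak and minimum, 1.65e-4, as expected from the cosine shape.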

## Results

| Metric     | Value |
| ---------- | ----- |
| Train Loss | 1.36  |
| Val Loss   | 1.41  |
| Perplexity | 4.09  |
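Perplexity here is the exponential of the validation cross-entropy loss. A quick sanity check with the rounded loss from the table reproduces the reported value to within rounding (4.09 presumably comes from the unrounded loss):

```python
import math

val_loss = 1.41                 # rounded validation loss from the table above
perplexity = math.exp(val_loss)
print(f"{perplexity:.2f}")      # 4.10
```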

## Usage

```python
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Download model weights and tokenizer
weights_path = hf_hub_download(repo_id="YOUR_HF_USERNAME/StoryGPT", filename="best_model.pt")
tok_path     = hf_hub_download(repo_id="YOUR_HF_USERNAME/StoryGPT", filename="storygpt_tokenizer.json")

tokenizer = Tokenizer.from_file(tok_path)

# Load model (copy the model source files locally first)
from StoryGPT.model.gpt import GPT
from StoryGPT.config import MODEL_CONFIG

model = GPT(MODEL_CONFIG)
weights = torch.load(weights_path, map_location="cpu")
# Checkpoints saved under nn.DataParallel prefix every key with "module."
if next(iter(weights)).startswith("module."):
    weights = {k[len("module."):]: v for k, v in weights.items()}
model.load_state_dict(weights)
model.eval()
```
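With the model and tokenizer loaded, stories can be generated with a standard sampling loop. This is a hedged sketch, not code from the repository: it assumes `model(ids)` returns logits of shape `(batch, seq, vocab)`, and the `generate` helper, temperature, and top-k values are illustrative defaults.

```python
import torch

@torch.no_grad()
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 200,
             temperature: float = 0.8, top_k: int = 50,
             context_length: int = 512) -> str:
    # Assumes model(ids) -> logits of shape (batch, seq, vocab).
    ids = torch.tensor([tokenizer.encode(prompt).ids])
    for _ in range(max_new_tokens):
        # Crop to the 512-token context window, take logits for the last position
        logits = model(ids[:, -context_length:])[:, -1, :] / temperature
        if top_k is not None:
            # Mask everything below the k-th largest logit
            cutoff = torch.topk(logits, top_k).values[:, -1:]
            logits = logits.masked_fill(logits < cutoff, float("-inf"))
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())

# print(generate(model, tokenizer, "Once upon a time"))
```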

## Sample Output

> Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys and go on adventures. One day, he decided to explore the forest near his house...

## License

MIT