---
language: en
license: mit
tags:
- pytorch
- language-model
- causal-lm
- llama-style
- gqa
- rope
- swiglu
- rmsnorm
- pretrained-from-scratch
datasets:
- roneneldan/TinyStories
metrics:
- perplexity
---
# StoryGPT
A **50M parameter** LLaMA-style decoder-only transformer pre-trained from scratch on the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset.
Built as an end-to-end portfolio project demonstrating a complete LLM pre-training pipeline: custom tokenizer, model implementation, training loop, and evaluation.
## Model Description
| Component | Implementation |
|---|---|
| Attention | Grouped Query Attention (GQA) — same as LLaMA 2/3 |
| Position Encoding | Rotary Embeddings (RoPE) |
| Normalization | RMSNorm |
| Activation | SwiGLU FFN |
| Weight Tying | Embedding weight = Output head weight |
| Tokenizer | Custom BPE trained from scratch (16,384 vocab) |
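The custom BPE tokenizer can be trained from scratch with the Hugging Face `tokenizers` library (already used in the Usage section below). A minimal sketch, assuming a whitespace pre-tokenizer and these particular special tokens, neither of which is documented for StoryGPT:

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Toy corpus stands in for the TinyStories text files.
corpus = ["Once upon a time, there was a little boy.", "He loved to play."]

# BPE model with an explicit unknown token (assumption, not confirmed by the repo).
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

# Cap the merges at the model's 16,384-entry vocabulary.
trainer = trainers.BpeTrainer(vocab_size=16_384,
                              special_tokens=["[UNK]", "<|endoftext|>"])
tokenizer.train_from_iterator(corpus, trainer)
```

On a real corpus the vocabulary grows to the full 16,384 entries; on this toy corpus it stops early once all merges are exhausted.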
**Config:**
```
vocab_size : 16,384
context_length: 512
emb_dim : 512
n_heads : 8
n_kv_heads : 4 (GQA)
n_layers : 8
ffn_hidden : 1,376
Parameters : ~50M
```
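The config implies the following attention geometry; a quick arithmetic check of how GQA shrinks the KV cache:

```python
# Derived from the config above (pure arithmetic, no framework needed).
emb_dim, n_heads, n_kv_heads = 512, 8, 4

head_dim = emb_dim // n_heads          # 64 dims per head
group_size = n_heads // n_kv_heads     # 2 query heads share each KV head
kv_cache_ratio = n_kv_heads / n_heads  # 0.5: KV cache is half the size of full MHA
```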
## Training
- **Dataset:** TinyStories (150k stories, ~40M tokens)
- **Steps:** 20,000
- **Optimizer:** AdamW (β=(0.9, 0.95), weight_decay=0.1)
- **LR Schedule:** Cosine decay with linear warmup (500 steps), peak 3e-4 → min 3e-5
- **Gradient Clipping:** 1.0
- **Mixed Precision:** float16 autocast via `torch.cuda.amp`
- **Hardware:** 2× NVIDIA T4 (DataParallel) on Kaggle
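The schedule above can be written as a small function of the step index; a sketch using the stated hyperparameters (500 warmup steps, peak 3e-4, min 3e-5, 20,000 total steps):

```python
import math

def lr_at(step, peak=3e-4, min_lr=3e-5, warmup=500, total=20_000):
    """Linear warmup to the peak LR, then cosine decay to the minimum."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)  # 0.0 -> 1.0 after warmup
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * progress))
```

At step 500 this returns exactly the peak (3e-4), and at step 20,000 exactly the minimum (3e-5).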
## Results
| Metric | Value |
|---|---|
| Train Loss | 1.36 |
| Val Loss | 1.41 |
| **Perplexity** | **4.09** |
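The reported perplexity is simply the exponential of the validation cross-entropy loss, so the two table rows are mutually consistent:

```python
import math

# Perplexity = exp(mean cross-entropy loss per token).
perplexity = math.exp(1.41)  # ~4.10, matching the reported 4.09
```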
## Usage
```python
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
# Download model and tokenizer
weights_path = hf_hub_download(repo_id="YOUR_HF_USERNAME/StoryGPT", filename="best_model.pt")
tok_path = hf_hub_download(repo_id="YOUR_HF_USERNAME/StoryGPT", filename="storygpt_tokenizer.json")
tokenizer = Tokenizer.from_file(tok_path)
# Load model (copy model source files locally first)
from StoryGPT.model.gpt import GPT
from StoryGPT.config import MODEL_CONFIG
model = GPT(MODEL_CONFIG)
weights = torch.load(weights_path, map_location="cpu")
# Strip the "module." prefix that DataParallel adds to every key during training.
weights = {k.removeprefix("module."): v for k, v in weights.items()}
model.load_state_dict(weights)
model.eval()
```
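The repository's own generation code is not shown here. As a sketch, the loaded model can be driven by a simple top-k sampling loop; the forward signature (`model(ids)` returning logits of shape `(batch, seq, vocab)`) and the sampling hyperparameters are assumptions, not confirmed by the repo:

```python
import torch

def sample_top_k(logits, k=50, temperature=0.8):
    """Sample one token id from the k most likely tokens (shape: (batch, vocab))."""
    topv, topi = torch.topk(logits / temperature, k)
    probs = torch.softmax(topv, dim=-1)
    return topi.gather(-1, torch.multinomial(probs, num_samples=1))

@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100, context_length=512):
    """Autoregressively extend `prompt`, cropping to the 512-token context window."""
    ids = torch.tensor([tokenizer.encode(prompt).ids])
    for _ in range(max_new_tokens):
        logits = model(ids[:, -context_length:])   # assumed: (batch, seq, vocab)
        next_id = sample_top_k(logits[:, -1, :])   # sample from the last position
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())
```

Usage: `print(generate(model, tokenizer, "Once upon a time"))`.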
## Sample Output
> *Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys and go on adventures. One day, he decided to explore the forest near his house...*
## License
MIT