| import streamlit as st |
| import torch, tiktoken, os |
| from gpt_model import build_model |
|
|
# Path to the trained checkpoint (expected to be a state_dict saved via torch.save).
CKPT = "ckpt_best.pt"
# Run on GPU when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Architecture hyperparameters: layers, attention heads, embedding dim, context length.
# These must match the configuration the checkpoint was trained with,
# or load_state_dict() below will fail on shape mismatches.
N_LAYER, N_HEAD, N_EMBD, BLOCK = 6, 6, 384, 256
|
|
@st.cache_resource
def load_model():
    """Build the GPT model, load weights from CKPT, and return (model, tokenizer).

    Decorated with st.cache_resource so the model is constructed and the
    checkpoint read only once per server process, not on every rerun.

    Returns:
        (model, enc): the model in eval mode on DEVICE, and the GPT-2
        tiktoken encoder used for both encoding prompts and decoding output.
    """
    enc = tiktoken.get_encoding("gpt2")
    model = build_model(enc.n_vocab, N_LAYER, N_HEAD, N_EMBD, BLOCK).to(DEVICE).eval()
    # weights_only=True restricts unpickling to tensors/containers, closing the
    # arbitrary-code-execution hole of a full pickle load (and avoiding the
    # FutureWarning on recent torch). A plain state_dict checkpoint loads fine.
    # NOTE(review): requires torch >= 1.13; drop the kwarg if pinned older.
    state = torch.load(CKPT, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    return model, enc
|
|
st.title("LittleGPT Space ✨")
# Cached via @st.cache_resource, so this is cheap on every Streamlit rerun.
model, enc = load_model()
|
|
# Generation controls. Slider args are (label, min, max, default, step).
prompt = st.text_area("Enter your prompt", "Once upon a time...")
max_new = st.slider("Max new tokens", 16, 256, 100, 8)
temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1)  # min 0.1 avoids divide-by-zero in sampling
top_k = st.slider("Top-k (0 = off)", 0, 200, 50, 10)  # 0 is mapped to None (disabled) below
|
|
if st.button("Generate"):
    # Guard against an empty/whitespace prompt: encoding "" yields a
    # zero-length context tensor, which model.generate cannot sample from.
    if not prompt.strip():
        st.warning("Please enter a non-empty prompt.")
    else:
        # Encode the prompt as a (1, T) batch of token ids on the model's device.
        x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)
        # NOTE(review): assumes model.generate crops the context to BLOCK tokens
        # internally for long prompts — confirm against gpt_model.build_model.
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            # top_k == 0 means "disabled", which the model expects as None.
            y = model.generate(x, max_new_tokens=max_new, temperature=temp, top_k=top_k or None)
        # y[0] contains the prompt tokens followed by the generated continuation.
        st.write(enc.decode(y[0].tolist()))
|
|