import streamlit as st
import torch, tiktoken, os
from gpt_model import build_model
CKPT = "ckpt_best.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_LAYER, N_HEAD, N_EMBD, BLOCK = 6, 6, 384, 256
@st.cache_resource
def load_model():
    """Build the LittleGPT model and load its best checkpoint (cached).

    Cached by Streamlit so the model is constructed once per server process,
    not on every script rerun.

    Returns:
        tuple: ``(model, enc)`` — the model moved to ``DEVICE`` and switched
        to eval mode, and the GPT-2 BPE tokenizer used for encode/decode.
    """
    enc = tiktoken.get_encoding("gpt2")
    model = build_model(enc.n_vocab, N_LAYER, N_HEAD, N_EMBD, BLOCK).to(DEVICE).eval()
    # weights_only=True restricts the unpickler to tensors/primitive types,
    # preventing arbitrary code execution from a tampered checkpoint file.
    # Safe here because only a plain state_dict is expected.
    state = torch.load(CKPT, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    return model, enc
st.title("LittleGPT Space ✨")
model, enc = load_model()

# UI controls: the prompt text plus the three sampling knobs.
prompt = st.text_area("Enter your prompt", "Once upon a time...")
max_new = st.slider("Max new tokens", 16, 256, 100, 8)
temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1)
top_k = st.slider("Top-k (0 = off)", 0, 200, 50, 10)

if st.button("Generate"):
    # Encode the prompt into a batch-of-one LongTensor on the model's device.
    token_ids = enc.encode(prompt)
    context = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        # A top-k slider value of 0 means "disabled", passed through as None.
        output = model.generate(
            context,
            max_new_tokens=max_new,
            temperature=temp,
            top_k=top_k or None,
        )
    st.write(enc.decode(output[0].tolist()))