"""Streamlit demo app for a small GPT language model.

Builds the model once (cached across Streamlit reruns), loads a trained
checkpoint, and exposes prompt + sampling controls for interactive
text generation.
"""
import os

import streamlit as st
import tiktoken
import torch

from gpt_model import build_model

# Path to the trained checkpoint (a state_dict saved with torch.save).
CKPT = "ckpt_best.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Architecture hyperparameters -- must match the values used at training
# time, otherwise load_state_dict() below fails on shape mismatches.
N_LAYER, N_HEAD, N_EMBD, BLOCK = 6, 6, 384, 256


@st.cache_resource
def load_model():
    """Build the model, load the checkpoint, and return ``(model, tokenizer)``.

    Decorated with ``st.cache_resource`` so the model is constructed and the
    checkpoint deserialized only once per server process, not on every rerun.
    """
    enc = tiktoken.get_encoding("gpt2")
    model = build_model(enc.n_vocab, N_LAYER, N_HEAD, N_EMBD, BLOCK).to(DEVICE).eval()
    # weights_only=True restricts unpickling to tensors and plain containers;
    # a state_dict needs nothing more, and this closes the arbitrary-code-
    # execution hole of unpickling an untrusted checkpoint.
    # NOTE(review): requires torch >= 1.13 -- drop the kwarg if pinned older.
    state = torch.load(CKPT, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    return model, enc


st.title("LittleGPT Space ✨")

# Fail with a readable in-app message instead of an unhandled
# FileNotFoundError traceback when the checkpoint is absent.
if not os.path.exists(CKPT):
    st.error(f"Checkpoint not found: {CKPT}")
    st.stop()

model, enc = load_model()

prompt = st.text_area("Enter your prompt", "Once upon a time...")
max_new = st.slider("Max new tokens", 16, 256, 100, 8)
temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1)
top_k = st.slider("Top-k (0 = off)", 0, 200, 50, 10)

if st.button("Generate"):
    # Encode the prompt as a (1, T) batch of token ids on the model's device.
    x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        # The slider's 0 means "top-k disabled" -> pass None to the sampler.
        y = model.generate(
            x,
            max_new_tokens=max_new,
            temperature=temp,
            top_k=top_k or None,
        )
    st.write(enc.decode(y[0].tolist()))