File size: 1,035 Bytes
1d0a8b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import streamlit as st
import torch, tiktoken, os
from gpt_model import build_model

# Path to the trained weights. Defaults to "ckpt_best.pt"; set the
# LITTLEGPT_CKPT environment variable to serve a different checkpoint
# without editing this file.
CKPT = os.environ.get("LITTLEGPT_CKPT", "ckpt_best.pt")
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model hyper-parameters passed to build_model. These must match the values
# the checkpoint was trained with, otherwise load_state_dict will reject the
# weights. NOTE(review): BLOCK is presumably the context length — confirm
# against gpt_model.build_model.
N_LAYER, N_HEAD, N_EMBD, BLOCK = 6, 6, 384, 256

@st.cache_resource
def load_model():
    """Construct the GPT model, load the weights from ``CKPT``, and return it.

    The ``st.cache_resource`` decorator makes Streamlit reuse the returned
    objects across script reruns instead of rebuilding the model on every
    widget interaction.

    Returns:
        A ``(model, enc)`` tuple: the model on ``DEVICE`` in eval mode, and the
        GPT-2 tiktoken encoding whose vocabulary size sized the model.
    """
    tokenizer = tiktoken.get_encoding("gpt2")
    net = build_model(tokenizer.n_vocab, N_LAYER, N_HEAD, N_EMBD, BLOCK)
    net = net.to(DEVICE)
    net = net.eval()
    # NOTE(review): torch.load unpickles the file, so only point CKPT at
    # trusted checkpoints. map_location lets GPU-saved weights load on CPU.
    weights = torch.load(CKPT, map_location=DEVICE)
    net.load_state_dict(weights)
    return net, tokenizer

# --- Page layout and generation controls -----------------------------------
st.title("LittleGPT Space ✨")
model, enc = load_model()

prompt = st.text_area("Enter your prompt", "Once upon a time...")
max_new = st.slider("Max new tokens", 16, 256, 100, 8)
temp = st.slider("Temperature", 0.1, 1.5, 0.9, 0.1)
top_k = st.slider("Top-k (0 = off)", 0, 200, 50, 10)

if st.button("Generate"):
    # encode_ordinary treats text such as "<|endoftext|>" as ordinary text.
    # Plain enc.encode() raises ValueError on special-token text by default,
    # which would crash the app on that user input.
    ids = enc.encode_ordinary(prompt)
    if not ids:
        # An empty prompt would hand a zero-length context to generate().
        st.warning("Please enter a non-empty prompt.")
    else:
        x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            # A top_k of 0 means "off"; it is passed as None.
            y = model.generate(
                x,
                max_new_tokens=max_new,
                temperature=temp,
                top_k=top_k or None,
            )
        # y[0] holds the prompt tokens followed by the newly generated ones.
        st.write(enc.decode(y[0].tolist()))