# Author: morty649
# Last commit: 83c11ad — "Fix tokenizer load for torch 2.6"
# app.py
import os
import torch
from flask import Flask, request, jsonify, render_template
from huggingface_hub import hf_hub_download
# Make sure mingpt is accessible
from mingpt.model import GPT
from mingpt.char_tokenizer import CharTokenizer
# ==========================================================
# CONFIGURATION
# ==========================================================
# Runtime settings. The repo id, block size and device can each be
# overridden through an environment variable; file names are fixed.
MODEL_REPO = os.getenv("HF_MODEL_REPO", "morty649/positive-tinystories-model")

# Artifact names looked up in MODEL_REPO (or in the working directory).
PRETRAINED_MODEL_FILENAME = "story_gpt_pretrained.pt"
RL_MODEL_FILENAME = "story_gpt_rl.pt"
TOKENIZER_FILENAME = "tokenizer.pt"

# Context length used when building the models and truncating prompts.
BLOCK_SIZE = int(os.getenv("BLOCK_SIZE", 64))
DEVICE = torch.device(os.getenv("TORCH_DEVICE", "cpu"))

print(
    f"Using device: {DEVICE}",
    f"Using model repo: {MODEL_REPO}",
    sep="\n",
)
# ==========================================================
# UTIL: Download from HF Hub with Local Fallback
# ==========================================================
def download_from_hub_or_local(filename: str) -> str:
    """Return a local filesystem path for *filename*.

    First tries ``hf_hub_download`` from the Hugging Face Hub repo
    ``MODEL_REPO``. If that raises anything (every ``Exception`` is caught,
    deliberately, so the app can still start offline), falls back to
    ``filename`` itself when it exists on disk (relative paths resolve
    against the current working directory).

    Args:
        filename: Name of the file inside ``MODEL_REPO`` / on local disk.

    Returns:
        The downloaded path, or ``filename`` for the local fallback.

    Raises:
        FileNotFoundError: if the Hub download fails and no local copy
            exists. The Hub error is chained as ``__cause__`` so the root
            failure is not lost.
    """
    hub_error = None
    try:
        print(f"Attempting to download {filename} from {MODEL_REPO}...")
        path = hf_hub_download(repo_id=MODEL_REPO, filename=filename)
    except Exception as e:
        # Keep a reference: `e` is unbound once the except clause ends.
        hub_error = e
        print(f"Hub download failed for {filename}: {e}")
    else:
        print(f"Downloaded {filename} -> {path}")
        return path

    if os.path.exists(filename):
        print(f"Using local fallback: {filename}")
        return filename
    raise FileNotFoundError(
        f"{filename} not found on Hub or locally."
    ) from hub_error
# ==========================================================
# LOAD TOKENIZER
# ==========================================================
# Runs at import time and defines the module globals `tokenizer_path`,
# `tokenizer` and `vocab_size`. Any failure here (download, unpickling,
# missing `vocab_size`) aborts the process via SystemExit, because
# build_and_load_model needs `vocab_size` to configure the GPT models.
print("Loading tokenizer...")
try:
    tokenizer_path = download_from_hub_or_local(TOKENIZER_FILENAME)
    # With weights_only=True, torch.load only rebuilds allow-listed globals;
    # register CharTokenizer so the pickled tokenizer instance can be
    # reconstructed. Note: this registration is process-wide.
    torch.serialization.add_safe_globals([CharTokenizer])
    tokenizer = torch.load(
        tokenizer_path,
        map_location="cpu",
        weights_only=True
    )
    # vocab_size is copied into the GPT config by build_and_load_model;
    # fail fast if the loaded object does not provide it.
    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        raise ValueError("Tokenizer missing `vocab_size` attribute.")
    print("Tokenizer loaded successfully.")
except Exception as e:
    # The app cannot serve anything without a tokenizer: exit with a message.
    raise SystemExit(f"Tokenizer failed to load: {e}")
# ==========================================================
# MODEL LOADER
# ==========================================================
def build_and_load_model(model_path: str):
    """Build a ``gpt-micro`` GPT and load the state dict stored at *model_path*.

    The model is configured with the tokenizer's ``vocab_size`` and
    ``BLOCK_SIZE``, moved to ``DEVICE`` and put in eval mode.

    Args:
        model_path: Local path to a checkpoint holding a ``state_dict``.

    Returns:
        The ready-to-use model, or ``None`` if anything fails (the error is
        printed). Callers treat ``None`` as "model unavailable".
    """
    try:
        config = GPT.get_default_config()
        config.model_type = "gpt-micro"
        config.vocab_size = vocab_size
        config.block_size = BLOCK_SIZE
        model = GPT(config)

        # Checkpoints come from a configurable Hub repo, so never unpickle
        # arbitrary objects. weights_only=True restricts loading to tensors
        # and plain containers. torch >= 2.6 already defaults to it; passing
        # it explicitly keeps older torch versions safe too.
        state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
    except Exception as e:
        # Best-effort: a missing or corrupt checkpoint must not crash startup.
        print(f"Failed to load model {model_path}: {e}")
        return None
    print(f"Loaded model from {model_path}")
    return model
# ==========================================================
# LOAD MODELS
# ==========================================================
def _optional_checkpoint(filename: str):
    """Resolve *filename* via the Hub/local lookup, or return None if absent."""
    try:
        return download_from_hub_or_local(filename)
    except FileNotFoundError:
        return None


print("Loading models...")
pretrained_path = _optional_checkpoint(PRETRAINED_MODEL_FILENAME)
rl_path = _optional_checkpoint(RL_MODEL_FILENAME)

# Only build a model whose checkpoint was actually located.
pretrained_model = None
if pretrained_path:
    pretrained_model = build_and_load_model(pretrained_path)
rl_model = None
if rl_path:
    rl_model = build_and_load_model(rl_path)

print(
    "Model status:",
    "Pretrained " + ("OK" if pretrained_model else "MISSING"),
    "|",
    "RL " + ("OK" if rl_model else "MISSING"),
)
# ==========================================================
# FLASK APP
# ==========================================================
app = Flask(
    __name__,
    static_folder="static",
    template_folder="templates",
)


@app.route("/")
def home():
    """Serve the front-end page rendered from templates/index.html."""
    page = "index.html"
    return render_template(page)
# Hard ceiling on tokens sampled per model per request. Each requested
# token costs a model call for both models, so an unbounded client value
# could tie up the server. Override with the MAX_NEW_TOKENS env var.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 1000))


def _parse_generation_params(payload):
    """Validate the JSON payload of a /generate request.

    Defaults: prompt "Once upon a time", max_tokens 100, temperature 0.8,
    top_k 30. ``max_tokens`` is clamped to ``[0, MAX_NEW_TOKENS]``.

    Returns:
        ``(prompt_text, max_tokens, temperature, top_k)``.

    Raises:
        ValueError: with a client-facing message when a field is malformed.
    """
    if not isinstance(payload, dict):
        raise ValueError("Request body must be a JSON object.")
    prompt_text = payload.get("prompt", "Once upon a time")
    if not isinstance(prompt_text, str):
        raise ValueError("`prompt` must be a string.")
    try:
        max_tokens = int(payload.get("max_tokens", 100))
        temperature = float(payload.get("temperature", 0.8))
        top_k = int(payload.get("top_k", 30))
    except (TypeError, ValueError, OverflowError):
        raise ValueError(
            "`max_tokens` and `top_k` must be integers and `temperature` a number."
        ) from None
    # A zero, negative or NaN temperature cannot yield a valid sampling
    # distribution. The negated comparison also rejects NaN.
    if not temperature > 0:
        raise ValueError("`temperature` must be a positive number.")
    if top_k < 1:
        raise ValueError("`top_k` must be a positive integer.")
    # Clamp instead of rejecting so clients asking for long outputs still work.
    max_tokens = max(0, min(max_tokens, MAX_NEW_TOKENS))
    return prompt_text, max_tokens, temperature, top_k


def _sample_continuation(model, context, max_tokens, temperature, top_k):
    """Sample a continuation of *context* from *model* and return it as text.

    Only the newly generated tokens are decoded, so the prompt never leaks
    into the output. This holds even when the prompt was truncated to
    BLOCK_SIZE or does not round-trip exactly through the tokenizer.
    """
    generated = model.generate(
        context,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
    )[0]
    # Assumes generate() returns the prompt tokens followed by the new ones
    # (the previous prefix-stripping logic relied on the same) — confirm in mingpt.
    return tokenizer.decode(generated[context.size(1):]).strip()


@app.route("/generate", methods=["POST"])
def generate():
    """Generate continuations of the posted prompt from both models.

    Accepts an optional JSON object with ``prompt``, ``max_tokens``,
    ``temperature`` and ``top_k``, and responds with
    ``{"pretrained_output": ..., "rl_output": ...}``. Returns 400 for
    malformed parameters and 500 if a model is unavailable or generation fails.
    """
    if pretrained_model is None or rl_model is None:
        return jsonify({"error": "One or more models failed to load."}), 500

    try:
        prompt_text, max_tokens, temperature, top_k = _parse_generation_params(
            request.get_json(silent=True) or {}
        )
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    # Top-k filtering cannot keep more candidates than the vocabulary has.
    top_k = min(top_k, vocab_size)

    # Every prompt is prefixed with the "⏎" marker. The name `eot` suggests
    # it is the end-of-text separator used in training — TODO confirm.
    eot = "⏎"
    try:
        tokens = tokenizer(eot + prompt_text)
        # Keep only the most recent BLOCK_SIZE tokens so the context fits.
        if len(tokens) > BLOCK_SIZE:
            tokens = tokens[-BLOCK_SIZE:]
        context = torch.tensor(tokens)[None, ...].to(DEVICE)
        with torch.no_grad():
            pre_text = _sample_continuation(
                pretrained_model, context, max_tokens, temperature, top_k
            )
            rl_text = _sample_continuation(
                rl_model, context, max_tokens, temperature, top_k
            )
    except Exception as e:
        return jsonify({"error": f"Generation failed: {str(e)}"}), 500

    return jsonify({
        "pretrained_output": pre_text,
        "rl_output": rl_text
    })
# ==========================================================
# ENTRY POINT
# ==========================================================
if __name__ == "__main__":
    # Listen on all interfaces. The port comes from $PORT and defaults to
    # 7860 (presumably the Hugging Face Spaces convention — confirm for your host).
    app.run(port=int(os.environ.get("PORT", 7860)), host="0.0.0.0")