# app.py
"""Flask app that serves text generated by two GPT checkpoints side by side.

At import time the module resolves a tokenizer and two model checkpoints
(Hugging Face Hub first, local file as fallback), loads them, and exposes a
``/generate`` JSON endpoint that samples from both models with the same
prompt and sampling parameters.
"""
import math
import os
from typing import Optional

import torch
from flask import Flask, request, jsonify, render_template
from huggingface_hub import hf_hub_download

# Make sure mingpt is accessible
from mingpt.model import GPT
from mingpt.char_tokenizer import CharTokenizer

# ==========================================================
# CONFIGURATION
# ==========================================================
MODEL_REPO = os.environ.get("HF_MODEL_REPO", "morty649/positive-tinystories-model")
PRETRAINED_MODEL_FILENAME = "story_gpt_pretrained.pt"
RL_MODEL_FILENAME = "story_gpt_rl.pt"
TOKENIZER_FILENAME = "tokenizer.pt"
BLOCK_SIZE = int(os.environ.get("BLOCK_SIZE", 64))
DEVICE = torch.device(os.environ.get("TORCH_DEVICE", "cpu"))
# Upper bound on tokens generated per request, so a single request cannot keep
# the server sampling indefinitely. Larger requested values are clamped.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 1024))

print(f"Using device: {DEVICE}")
print(f"Using model repo: {MODEL_REPO}")


# ==========================================================
# UTIL: Download from HF Hub with Local Fallback
# ==========================================================
def download_from_hub_or_local(filename: str) -> str:
    """Return a filesystem path for *filename*, preferring the Hugging Face Hub.

    Tries ``hf_hub_download`` against ``MODEL_REPO`` first. If that fails for
    any reason, falls back to a file called *filename* relative to the current
    working directory.

    Raises:
        FileNotFoundError: if the Hub download fails and no local file exists.
            The Hub error is attached as the cause.
    """
    try:
        print(f"Attempting to download {filename} from {MODEL_REPO}...")
        path = hf_hub_download(repo_id=MODEL_REPO, filename=filename)
        print(f"Downloaded {filename} -> {path}")
        return path
    except Exception as e:
        # Deliberately broad: any Hub failure (network, auth, missing file)
        # should fall through to the local copy.
        print(f"Hub download failed for {filename}: {e}")
        if os.path.exists(filename):
            print(f"Using local fallback: {filename}")
            return filename
        raise FileNotFoundError(f"{filename} not found on Hub or locally.") from e


# ==========================================================
# LOAD TOKENIZER
# ==========================================================
print("Loading tokenizer...")
try:
    tokenizer_path = download_from_hub_or_local(TOKENIZER_FILENAME)
    # The tokenizer is a pickled CharTokenizer instance; allow-list the class
    # so it can be loaded with weights_only=True.
    torch.serialization.add_safe_globals([CharTokenizer])
    tokenizer = torch.load(
        tokenizer_path,
        map_location="cpu",
        weights_only=True,
    )
    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        raise ValueError("Tokenizer missing `vocab_size` attribute.")
    print("Tokenizer loaded successfully.")
except Exception as e:
    # Nothing can be served without a tokenizer, so abort startup.
    raise SystemExit(f"Tokenizer failed to load: {e}") from e


# ==========================================================
# MODEL LOADER
# ==========================================================
def build_and_load_model(model_path: str) -> Optional[GPT]:
    """Build a ``gpt-micro`` GPT and load the weights stored at *model_path*.

    The model is configured with the module-level ``vocab_size`` and
    ``BLOCK_SIZE``, moved to ``DEVICE`` and put in eval mode.

    Returns:
        The loaded model, or ``None`` if building or loading failed (the
        error is printed rather than raised so the app can still start).
    """
    try:
        config = GPT.get_default_config()
        config.model_type = "gpt-micro"
        config.vocab_size = vocab_size
        config.block_size = BLOCK_SIZE

        model = GPT(config)
        # weights_only=True: the checkpoint is passed to load_state_dict, so
        # there is no reason to allow arbitrary pickled objects; this also
        # matches how the tokenizer is loaded.
        state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()

        print(f"Loaded model from {model_path}")
        return model
    except Exception as e:
        print(f"Failed to load model {model_path}: {e}")
        return None


def _resolve_checkpoint(filename: str) -> Optional[str]:
    """Return a path for *filename*, or ``None`` if it cannot be found anywhere."""
    try:
        return download_from_hub_or_local(filename)
    except FileNotFoundError:
        return None


# ==========================================================
# LOAD MODELS
# ==========================================================
print("Loading models...")

pretrained_path = _resolve_checkpoint(PRETRAINED_MODEL_FILENAME)
rl_path = _resolve_checkpoint(RL_MODEL_FILENAME)

pretrained_model = build_and_load_model(pretrained_path) if pretrained_path else None
rl_model = build_and_load_model(rl_path) if rl_path else None

print(
    "Model status:",
    "Pretrained OK" if pretrained_model else "Pretrained MISSING",
    "|",
    "RL OK" if rl_model else "RL MISSING"
)

# ==========================================================
# FLASK APP
# ==========================================================
app = Flask(__name__, template_folder="templates", static_folder="static")


def _value_or_default(payload, key, default):
    """Return ``payload[key]``, treating a missing key or JSON ``null`` as *default*."""
    value = payload.get(key)
    return default if value is None else value


def _parse_generation_params(payload):
    """Validate the ``/generate`` request body.

    Returns:
        ``(prompt_text, max_tokens, temperature, top_k)`` with ``max_tokens``
        clamped to ``MAX_NEW_TOKENS`` and ``top_k`` clamped to ``vocab_size``.

    Raises:
        ValueError: with a user-presentable message if any field is invalid.
    """
    if not isinstance(payload, dict):
        raise ValueError("JSON body must be an object.")

    prompt_text = _value_or_default(payload, "prompt", "Once upon a time")
    if not isinstance(prompt_text, str):
        raise ValueError("`prompt` must be a string.")

    try:
        max_tokens = int(_value_or_default(payload, "max_tokens", 100))
        temperature = float(_value_or_default(payload, "temperature", 0.8))
        top_k = int(_value_or_default(payload, "top_k", 30))
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError(
            "`max_tokens`, `temperature` and `top_k` must be numbers."
        ) from e

    if max_tokens < 1:
        raise ValueError("`max_tokens` must be at least 1.")
    if not math.isfinite(temperature) or temperature <= 0:
        raise ValueError("`temperature` must be a finite number greater than 0.")
    if top_k < 1:
        raise ValueError("`top_k` must be at least 1.")

    # Clamp rather than reject so existing clients that send large values
    # keep working.
    max_tokens = min(max_tokens, MAX_NEW_TOKENS)
    top_k = min(top_k, vocab_size)
    return prompt_text, max_tokens, temperature, top_k


def _sample_text(model, context, max_tokens, temperature, top_k, prompt_prefix):
    """Sample from *model* and return the decoded text without the prompt.

    *prompt_prefix* is removed from the start of the decoded output when
    present; the result is whitespace-stripped either way.
    """
    generated = model.generate(
        context,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
    )[0]
    text = tokenizer.decode(generated)
    if prompt_prefix and text.startswith(prompt_prefix):
        text = text[len(prompt_prefix):]
    return text.strip()


@app.route("/")
def home():
    """Serve the single-page UI."""
    return render_template("index.html")


@app.route("/generate", methods=["POST"])
def generate():
    """Generate a continuation of the prompt with both models.

    Expects a JSON object with optional ``prompt``, ``max_tokens``,
    ``temperature`` and ``top_k`` fields. Responds with
    ``{"pretrained_output": ..., "rl_output": ...}`` on success, a 400 JSON
    error for invalid parameters, or a 500 JSON error if a model is missing
    or generation fails.
    """
    if pretrained_model is None or rl_model is None:
        return jsonify({"error": "One or more models failed to load."}), 500

    payload = request.get_json(silent=True) or {}
    try:
        prompt_text, max_tokens, temperature, top_k = _parse_generation_params(payload)
    except ValueError as e:
        return jsonify({"error": f"Invalid request: {e}"}), 400

    # Prepare prompt
    eot = "⏎"
    full_prompt = eot + prompt_text

    try:
        tokens = tokenizer(full_prompt)

        # Guard against overflow: keep only the most recent BLOCK_SIZE tokens.
        if len(tokens) > BLOCK_SIZE:
            tokens = tokens[-BLOCK_SIZE:]

        context = torch.tensor(tokens)[None, ...].to(DEVICE)

        # Strip what was actually fed to the model, not the raw prompt: after
        # truncation the output no longer starts with `full_prompt`.
        prompt_prefix = tokenizer.decode(context[0])

        with torch.no_grad():
            pre_text = _sample_text(
                pretrained_model, context, max_tokens, temperature, top_k, prompt_prefix
            )
            rl_text = _sample_text(
                rl_model, context, max_tokens, temperature, top_k, prompt_prefix
            )

        return jsonify({
            "pretrained_output": pre_text,
            "rl_output": rl_text
        })

    except Exception as e:
        # Request boundary: report any tokenizer/model failure as JSON.
        return jsonify({"error": f"Generation failed: {str(e)}"}), 500


# ==========================================================
# ENTRY POINT
# ==========================================================
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)