Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import torch | |
| from flask import Flask, request, jsonify, render_template | |
| from huggingface_hub import hf_hub_download | |
| # Make sure mingpt is accessible | |
| from mingpt.model import GPT | |
| from mingpt.char_tokenizer import CharTokenizer | |
# ==========================================================
# CONFIGURATION
# ==========================================================
# Every value read through os.getenv below can be overridden by the
# environment variable of the same name; the literal is the fallback.

# Hugging Face Hub repository that hosts the checkpoints and tokenizer.
MODEL_REPO = os.getenv("HF_MODEL_REPO", "morty649/positive-tinystories-model")

# Artifact names, looked up in MODEL_REPO (or locally as a fallback).
PRETRAINED_MODEL_FILENAME = "story_gpt_pretrained.pt"
RL_MODEL_FILENAME = "story_gpt_rl.pt"
TOKENIZER_FILENAME = "tokenizer.pt"

# Context length the GPT models are constructed with.
BLOCK_SIZE = int(os.getenv("BLOCK_SIZE", "64"))

# Torch device used for model weights and inference tensors.
DEVICE = torch.device(os.getenv("TORCH_DEVICE", "cpu"))

print(f"Using device: {DEVICE}")
print(f"Using model repo: {MODEL_REPO}")
# ==========================================================
# UTIL: Download from HF Hub with Local Fallback
# ==========================================================
def download_from_hub_or_local(filename: str) -> str:
    """Return a local path to *filename*, preferring the Hugging Face Hub copy.

    The file is first fetched from ``MODEL_REPO`` via ``hf_hub_download``.
    If that fails for any reason, a regular file with the same (relative)
    name is used instead.

    Args:
        filename: Name of the file inside ``MODEL_REPO``; also used as the
            relative path of the local fallback.

    Returns:
        The path returned by ``hf_hub_download``, or *filename* itself when
        the local fallback is used.

    Raises:
        FileNotFoundError: if the Hub download fails and no local file with
            that name exists. The Hub error is chained as ``__cause__`` so
            the real reason is not lost.
    """
    try:
        print(f"Attempting to download {filename} from {MODEL_REPO}...")
        path = hf_hub_download(repo_id=MODEL_REPO, filename=filename)
    except Exception as hub_error:  # deliberately broad: any Hub failure falls back to disk
        print(f"Hub download failed for {filename}: {hub_error}")
        # isfile rather than exists: a directory of this name is not a usable artifact.
        if os.path.isfile(filename):
            print(f"Using local fallback: {filename}")
            return filename
        raise FileNotFoundError(f"{filename} not found on Hub or locally.") from hub_error
    print(f"Downloaded {filename} -> {path}")
    return path
# ==========================================================
# LOAD TOKENIZER
# ==========================================================
# Any failure while fetching or validating the tokenizer aborts start-up:
# the models below are sized from its vocab_size.
print("Loading tokenizer...")
try:
    tokenizer_path = download_from_hub_or_local(TOKENIZER_FILENAME)
    # weights_only loading refuses unknown classes, so the project's
    # CharTokenizer must be allow-listed before the file is unpickled.
    torch.serialization.add_safe_globals([CharTokenizer])
    tokenizer = torch.load(tokenizer_path, map_location="cpu", weights_only=True)
    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        raise ValueError("Tokenizer missing `vocab_size` attribute.")
except Exception as load_error:
    raise SystemExit(f"Tokenizer failed to load: {load_error}")
else:
    print("Tokenizer loaded successfully.")
# ==========================================================
# MODEL LOADER
# ==========================================================
def build_and_load_model(model_path: str, *, model_type: str = "gpt-micro"):
    """Build a GPT with this app's configuration and load weights into it.

    The architecture is the minGPT *model_type* preset. ``vocab_size`` comes
    from the loaded tokenizer and ``block_size`` from ``BLOCK_SIZE``; the
    state dict at *model_path* must match that architecture, otherwise
    ``load_state_dict`` fails.

    Args:
        model_path: Local path to a saved state dict.
        model_type: minGPT preset name. Defaults to ``"gpt-micro"``, the
            preset this app has always used.

    Returns:
        The model moved to ``DEVICE`` and switched to eval mode. Returns
        ``None`` if building or loading failed; the error is printed rather
        than raised, so callers can degrade gracefully.
    """
    try:
        config = GPT.get_default_config()
        config.model_type = model_type
        config.vocab_size = vocab_size
        config.block_size = BLOCK_SIZE
        model = GPT(config)
        # The checkpoint comes from a remote repository, so refuse arbitrary
        # pickled objects. A plain tensor state dict loads fine this way,
        # matching how the tokenizer is loaded.
        state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
    except Exception as e:
        print(f"Failed to load model {model_path}: {e}")
        return None
    print(f"Loaded model from {model_path}")
    return model
# ==========================================================
# LOAD MODELS
# ==========================================================
def _resolve_checkpoint(filename: str):
    """Return a local path for *filename*, or None if it is unavailable."""
    try:
        return download_from_hub_or_local(filename)
    except FileNotFoundError:
        return None


print("Loading models...")
# Resolve both checkpoint paths before building either model, so all
# download messages are logged ahead of the build messages.
pretrained_path = _resolve_checkpoint(PRETRAINED_MODEL_FILENAME)
rl_path = _resolve_checkpoint(RL_MODEL_FILENAME)

# A missing path yields a None model; build_and_load_model may also return None.
pretrained_model = build_and_load_model(pretrained_path) if pretrained_path else None
rl_model = build_and_load_model(rl_path) if rl_path else None

pretrained_status = "Pretrained OK" if pretrained_model else "Pretrained MISSING"
rl_status = "RL OK" if rl_model else "RL MISSING"
print("Model status:", pretrained_status, "|", rl_status)
# ==========================================================
# FLASK APP
# ==========================================================
app = Flask(__name__, template_folder="templates", static_folder="static")


@app.route("/")
def home():
    """Serve the web UI (templates/index.html) at the site root."""
    return render_template("index.html")
# Hard cap on tokens generated per request, so a single call cannot tie up
# the worker indefinitely; override with the MAX_NEW_TOKENS env var.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 1000))


def _parse_generation_args(payload: dict) -> tuple:
    """Validate the JSON body of a generate request.

    Returns:
        ``(prompt_text, max_tokens, temperature, top_k)``, with ``max_tokens``
        clamped to ``[0, MAX_NEW_TOKENS]`` and ``top_k`` clamped to the
        tokenizer's vocabulary size.

    Raises:
        ValueError: if a field has the wrong type or an unusable value.
    """
    prompt_text = payload.get("prompt", "Once upon a time")
    if not isinstance(prompt_text, str):
        raise ValueError("`prompt` must be a string.")
    try:
        max_tokens = int(payload.get("max_tokens", 100))
        temperature = float(payload.get("temperature", 0.8))
        top_k = int(payload.get("top_k", 30))
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"Invalid numeric parameter: {err}") from err
    # `not x > 0` also rejects NaN; logits are scaled by 1/temperature.
    if not temperature > 0:
        raise ValueError("`temperature` must be a positive number.")
    if top_k < 1:
        raise ValueError("`top_k` must be at least 1.")
    max_tokens = max(0, min(max_tokens, MAX_NEW_TOKENS))
    # A top-k larger than the vocabulary means "no filtering"; clamp it so
    # the top-k selection does not ask for more entries than exist.
    top_k = min(top_k, vocab_size)
    return prompt_text, max_tokens, temperature, top_k


def _sample_continuation(model, context, prompt_prefix, max_tokens, temperature, top_k) -> str:
    """Sample from *model* given *context*, returning the text after *prompt_prefix*."""
    with torch.no_grad():
        out_tokens = model.generate(
            context,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
        )[0]
    text = tokenizer.decode(out_tokens)
    # Remove prompt prefix if present.
    if text.startswith(prompt_prefix):
        text = text[len(prompt_prefix):]
    return text.strip()


# NOTE(review): confirm this URL matches the endpoint called by templates/index.html.
@app.route("/generate", methods=["POST"])
def generate():
    """Generate a story continuation from both the pretrained and RL models.

    Accepts an optional JSON object body with ``prompt`` (str),
    ``max_tokens`` (int), ``temperature`` (float > 0) and ``top_k`` (int >= 1).
    Missing fields fall back to defaults.

    Responds with ``{"pretrained_output": str, "rl_output": str}``. On error
    it returns ``{"error": str}`` with status 400 (invalid request) or 500
    (models not loaded, or generation failed).
    """
    if pretrained_model is None or rl_model is None:
        return jsonify({"error": "One or more models failed to load."}), 500

    payload = request.get_json(silent=True) or {}
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    try:
        prompt_text, max_tokens, temperature, top_k = _parse_generation_args(payload)
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    # Prompts are prefixed with this marker; presumably it delimits stories
    # in the training data (TODO confirm).
    eot = "⏎"
    try:
        tokens = tokenizer(eot + prompt_text)
        # Keep only the last BLOCK_SIZE tokens, the context length the models were built with.
        if len(tokens) > BLOCK_SIZE:
            tokens = tokens[-BLOCK_SIZE:]
        context = torch.tensor(tokens, dtype=torch.long)[None, ...].to(DEVICE)
        # Strip the decoded *context* rather than the raw prompt string.
        # After truncation the output no longer starts with the full prompt,
        # and the prompt text would otherwise leak into the response.
        prompt_prefix = tokenizer.decode(context[0])
        pre_text = _sample_continuation(
            pretrained_model, context, prompt_prefix, max_tokens, temperature, top_k
        )
        rl_text = _sample_continuation(
            rl_model, context, prompt_prefix, max_tokens, temperature, top_k
        )
    except Exception as e:
        return jsonify({"error": f"Generation failed: {str(e)}"}), 500

    return jsonify({
        "pretrained_output": pre_text,
        "rl_output": rl_text,
    })
# ==========================================================
# ENTRY POINT
# ==========================================================
if __name__ == "__main__":
    # Listen on all interfaces; the port is 7860 unless PORT is set.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")))