"""
StorySupra-10M — Interactive Story Generator
Loads model weights directly from HuggingFace: SupraLabs/StorySupra-10M
"""

import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

# ──────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────

MODEL_ID = "SupraLabs/StorySupra-10M"

# Default keyword values for generate_text(); every key here is mirrored by a
# parameter of that function.
GENERATION_DEFAULTS = {
    "max_new_tokens": 100,
    "temperature": 0.55,
    "top_k": 25,
    "top_p": 0.85,
    "repetition_penalty": 1.1,
    "do_sample": True,
}

# Compared against the lower-cased user input in run().
EXIT_COMMANDS = {"exit", "quit", "leave"}


# ──────────────────────────────────────────────
# Model loading
# ──────────────────────────────────────────────

def load_model(model_id: str):
    """Download and return the tokenizer and model from HuggingFace Hub.

    Args:
        model_id: Hub repository id passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise; the
        model has already been moved to it and switched to eval mode.
    """
    print(f"Downloading model from HuggingFace: {model_id}")
    print("(This may take a moment on first run — weights will be cached locally.)\n")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
    model = LlamaForCausalLM.from_pretrained(model_id)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    model.to(device)
    model.eval()

    return tokenizer, model, device


# ──────────────────────────────────────────────
# Text generation
# ──────────────────────────────────────────────

def generate_text(
    prompt: str,
    tokenizer,
    model,
    device: str,
    max_new_tokens: int = GENERATION_DEFAULTS["max_new_tokens"],
    temperature: float = GENERATION_DEFAULTS["temperature"],
    top_k: int = GENERATION_DEFAULTS["top_k"],
    top_p: float = GENERATION_DEFAULTS["top_p"],
    repetition_penalty: float = GENERATION_DEFAULTS["repetition_penalty"],
    do_sample: bool = GENERATION_DEFAULTS["do_sample"],
) -> str:
    """Generate a story continuation from the given prompt.

    Args:
        prompt: Text the model should continue.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.
        device: Device the model lives on; the encoded prompt is moved there.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (only forwarded when sampling).
        top_k: Top-k cutoff (only forwarded when sampling).
        top_p: Nucleus-sampling cutoff (only forwarded when sampling).
        repetition_penalty: Penalty passed through to ``model.generate``.
        do_sample: Sample when true, otherwise decode without sampling.

    Returns:
        The decoded first output sequence (prompt plus continuation) with
        special tokens stripped.
    """
    # Llama models take no token_type_ids; ask the tokenizer not to emit them
    # so generate() does not reject them as unused model kwargs.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(device)

    # Tokenizers without a dedicated padding token report None here; fall back
    # to the EOS id so generate() always receives a concrete pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # These knobs only shape the sampling distribution, so they are left
        # out when sampling is disabled.
        generation_kwargs.update(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    with torch.no_grad():
        output_tokens = model.generate(**inputs, **generation_kwargs)

    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)


# ──────────────────────────────────────────────
# Interactive loop
# ──────────────────────────────────────────────

def run():
    """Load the model and serve prompts from stdin until the user quits.

    The loop ends on EOF, Ctrl-C at the prompt, or any of ``EXIT_COMMANDS``
    (case-insensitive). Blank input is rejected and the prompt is shown again.
    """
    print("=" * 50)
    print("  StorySupra-10M — Interactive Story Generator")
    print("=" * 50)

    tokenizer, model, device = load_model(MODEL_ID)

    print("-" * 50)
    print("Model ready! Type a prompt to generate a story.")
    # EXIT_COMMANDS is a set, so sort it to keep the banner stable across runs.
    print(f"Type {' / '.join(sorted(EXIT_COMMANDS))} to quit.")
    print("-" * 50)

    while True:
        try:
            user_prompt = input("\nYour prompt: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting. Goodbye!")
            break

        if not user_prompt:
            print("Please enter a prompt.")
            continue

        if user_prompt.lower() in EXIT_COMMANDS:
            print("Goodbye!")
            break

        print("\nGenerating...\n")
        story = generate_text(user_prompt, tokenizer, model, device)

        print("Generated story:")
        print("-" * 20)
        print(story)
        print("-" * 20)


# ──────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────

if __name__ == "__main__":
    run()