"""Interactive command-line story generator for the StorySupra-10M model.

Loads a Llama-architecture causal language model and its fast tokenizer from
a local directory, then repeatedly reads a prompt from stdin and prints a
sampled continuation until the user asks to quit.
"""

# Printed before the heavy torch/transformers imports so the user sees
# feedback while they load.
print("Loading...")

import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast


def run_inference():
    """Load StorySupra-10M from ./StorySupra-10M and run a prompt/generate loop.

    The loop ends when the user enters "exit", "quit" or "leave"
    (case-insensitive, surrounding whitespace ignored), closes stdin
    (EOF / Ctrl-D), or presses Ctrl-C at the prompt. For every other input
    the model output (prompt plus sampled continuation) is printed.
    """
    model_path = "./StorySupra-10M"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()

    # Causal-LM tokenizers frequently define no pad token. generate() would
    # warn and fall back to EOS anyway; make that choice explicit here.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    def generate_text(prompt, max_new_tokens=100, temperature=0.55, top_k=25, top_p=0.85, repetition_penalty=1.1):
        """Sample a continuation of ``prompt`` and return it as decoded text.

        Args:
            prompt: Text the model should continue.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature (lower is more deterministic).
            top_k: Keep only the k most likely tokens at each step.
            top_p: Nucleus-sampling cumulative probability cutoff.
            repetition_penalty: Penalty applied to already-generated tokens.

        Returns:
            The decoded output sequence (prompt included), with special
            tokens removed.
        """
        # A bare PreTrainedTokenizerFast may emit token_type_ids. Llama does
        # not accept them, and generate() raises on unused model kwargs.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            return_token_type_ids=False,
        ).to(device)
        with torch.no_grad():
            output_tokens = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    print("-" * 30)
    print("StorySupra Story Generator loaded!")
    print("Enter a prompt (or type 'exit' to quit):")

    exit_commands = {"exit", "quit", "leave"}
    while True:
        try:
            user_prompt = input("\nYour prompt: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: leave cleanly instead of
            # dumping a traceback.
            print()
            break
        if user_prompt.strip().lower() in exit_commands:
            break

        story = generate_text(user_prompt)
        print(f"\nGenerated story:\n{story}")
        print("-" * 20)


if __name__ == "__main__":
    run_inference()