LH-Tech-AI commited on
Commit
160f4be
Β·
verified Β·
1 Parent(s): 366d265

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +124 -0
README.md CHANGED
@@ -50,5 +50,129 @@ You can find the full code in this repo as `train.py` and inference.py. Have fun
50
  ## Usage
51
  Use this to run the model:
52
  ```python3
 
 
 
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  ```
 
50
  ## Usage
51
  Use this to run the model:
52
  ```python3
53
+ """
54
+ StorySupra-10M β€” Interactive Story Generator
55
+ Loads model weights directly from HuggingFace: SupraLabs/StorySupra-10M
56
+ """
57
 
58
+ import torch
59
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
60
+
61
+ # ──────────────────────────────────────────────
62
+ # Configuration
63
+ # ──────────────────────────────────────────────
64
+ MODEL_ID = "SupraLabs/StorySupra-10M"
65
+
66
+ GENERATION_DEFAULTS = {
67
+ "max_new_tokens": 100,
68
+ "temperature": 0.55,
69
+ "top_k": 25,
70
+ "top_p": 0.85,
71
+ "repetition_penalty": 1.1,
72
+ "do_sample": True,
73
+ }
74
+
75
+ EXIT_COMMANDS = {"exit", "quit", "leave"}
76
+
77
+ # ──────────────────────────────────────────────
78
+ # Model loading
79
+ # ──────────────────────────────────────────────
80
+
81
+ def load_model(model_id: str):
82
+ """Download and return the tokenizer and model from HuggingFace Hub."""
83
+ print(f"Downloading model from HuggingFace: {model_id}")
84
+ print("(This may take a moment on first run β€” weights will be cached locally.)\n")
85
+
86
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
87
+ model = LlamaForCausalLM.from_pretrained(model_id)
88
+
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ print(f"Using device: {device}\n")
91
+
92
+ model.to(device)
93
+ model.eval()
94
+
95
+ return tokenizer, model, device
96
+
97
+
98
+ # ──────────────────────────────────────────────
99
+ # Text generation
100
+ # ──────────────────────────────────────────────
101
+
102
+ def generate_text(
103
+ prompt: str,
104
+ tokenizer,
105
+ model,
106
+ device: str,
107
+ max_new_tokens: int = GENERATION_DEFAULTS["max_new_tokens"],
108
+ temperature: float = GENERATION_DEFAULTS["temperature"],
109
+ top_k: int = GENERATION_DEFAULTS["top_k"],
110
+ top_p: float = GENERATION_DEFAULTS["top_p"],
111
+ repetition_penalty: float = GENERATION_DEFAULTS["repetition_penalty"],
112
+ ) -> str:
113
+ """Generate a story continuation from the given prompt."""
114
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
115
+
116
+ with torch.no_grad():
117
+ output_tokens = model.generate(
118
+ **inputs,
119
+ max_new_tokens=max_new_tokens,
120
+ do_sample=True,
121
+ temperature=temperature,
122
+ top_k=top_k,
123
+ top_p=top_p,
124
+ repetition_penalty=repetition_penalty,
125
+ pad_token_id=tokenizer.pad_token_id,
126
+ eos_token_id=tokenizer.eos_token_id,
127
+ )
128
+
129
+ return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
130
+
131
+
132
+ # ──────────────────────────────────────────────
133
+ # Interactive loop
134
+ # ──────────────────────────────────────────────
135
+
136
+ def run():
137
+ print("=" * 50)
138
+ print(" StorySupra-10M β€” Interactive Story Generator")
139
+ print("=" * 50)
140
+
141
+ tokenizer, model, device = load_model(MODEL_ID)
142
+
143
+ print("-" * 50)
144
+ print("Model ready! Type a prompt to generate a story.")
145
+ print(f"Type {' / '.join(EXIT_COMMANDS)} to quit.")
146
+ print("-" * 50)
147
+
148
+ while True:
149
+ try:
150
+ user_prompt = input("\nYour prompt: ").strip()
151
+ except (EOFError, KeyboardInterrupt):
152
+ print("\nExiting. Goodbye!")
153
+ break
154
+
155
+ if not user_prompt:
156
+ print("Please enter a prompt.")
157
+ continue
158
+
159
+ if user_prompt.lower() in EXIT_COMMANDS:
160
+ print("Goodbye!")
161
+ break
162
+
163
+ print("\nGenerating...\n")
164
+ story = generate_text(user_prompt, tokenizer, model, device)
165
+
166
+ print("Generated story:")
167
+ print("-" * 20)
168
+ print(story)
169
+ print("-" * 20)
170
+
171
+
172
+ # ──────────────────────────────────────────────
173
+ # Entry point
174
+ # ──────────────────────────────────────────────
175
+
176
+ if __name__ == "__main__":
177
+ run()
178
  ```