AlekseyCalvin commited on
Commit
01a520e
·
verified ·
1 Parent(s): 37da9e7

Upload 11 files

Browse files
README.md CHANGED
@@ -1,3 +1,178 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ datasets:
4
+ - roneneldan/TinyStories
5
+ language:
6
+ - en
7
+ pipeline_tag: text-generation
8
+ library_name: transformers
9
+ tags:
10
+ - small
11
+ - tiny
12
+ - story
13
+ - tinystories
14
+ - roneneldan
15
+ - cpu
16
+ - free
17
+ - open-source
18
  ---
19
+
20
+ # 📖 StorySupra 10M
21
+
22
+ ## Config
23
+ - Parameters: 12,587,264 (~10M)
24
+ - Hidden Size: 256
25
+ - Intermediate Size: 1024
26
+ - Hidden Layers: 8
27
+ - Attention Heads: 8
28
+ - Max Position Embeddings: 256
29
+ - Vocab Size: 8192
30
+
31
+ ## Samples
32
+ Once upon a time , a small bird was flying in the sky . It saw a big tree and wanted to rest under it . But the tree was too high for the bird to reach . The bird tried to fly up , but it could not . Then , a wise old owl flew by and saw the bird struggling . The owl said , " Don ' t worry little bird , I can help you ." The owl used its strong beak to climb the tree and get the bird down . The bird was
33
+ <br><br>
34
+ Once upon a time , there was a little boy named Timmy . He loved to play with his toys and run around outside . One day , he found a shiny penny on the ground . It was so pretty that he picked it up and showed it to his mom . " Look , Mommy ! I found a penny !" he said . His mom smiled and said , " That ' s great , Timmy . But be careful , it ' s very special ." Timmy didn ' t understand what " valuable " meant , but he knew it meant something important . So
35
+ <br><br>
36
+ Once upon a time , there was a lovely princess . She had long , blonde hair and a sparkly crown . One day , she wanted to go for a walk in the forest . She put on her dress and started walking . As she walked , she saw something strange . It was a big , scary bear ! The princess was scared , but she didn ' t want to get away . So she just kept walking until she reached the forest . When she got there , she saw a little rabbit . He was wearing a bright red bow and he looked very friendly .
37
+
38
+ ## Training
39
+ - GPU: single RTX 5060 Ti 16GB
40
+ - Time: ~20 minutes
41
+ - Epochs: 3
42
+ - Samples of the dataset: 200k
43
+
44
+ ## Dataset
45
+ 200k samples of roneneldan/TinyStories
46
+
47
+ ## Code
48
+ You can find the full code in this repo as `train.py` and inference.py. Have fun :-)
49
+
50
+ ## Usage
51
+ Use this to run the model:
52
+ ```python3
53
+ """
54
+ StorySupra-10M — Interactive Story Generator
55
+ Loads model weights directly from HuggingFace: SupraLabs/StorySupra-10M
56
+ """
57
+
58
+ import torch
59
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
60
+
61
+ # ──────────────────────────────────────────────
62
+ # Configuration
63
+ # ──────────────────────────────────────────────
64
+ MODEL_ID = "SupraLabs/StorySupra-10M"
65
+
66
+ GENERATION_DEFAULTS = {
67
+ "max_new_tokens": 100,
68
+ "temperature": 0.55,
69
+ "top_k": 25,
70
+ "top_p": 0.85,
71
+ "repetition_penalty": 1.1,
72
+ "do_sample": True,
73
+ }
74
+
75
+ EXIT_COMMANDS = {"exit", "quit", "leave"}
76
+
77
+ # ──────────────────────────────────────────────
78
+ # Model loading
79
+ # ──────────────────────────────────────────────
80
+
81
+ def load_model(model_id: str):
82
+ """Download and return the tokenizer and model from HuggingFace Hub."""
83
+ print(f"Downloading model from HuggingFace: {model_id}")
84
+ print("(This may take a moment on first run — weights will be cached locally.)\n")
85
+
86
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
87
+ model = LlamaForCausalLM.from_pretrained(model_id)
88
+
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ print(f"Using device: {device}\n")
91
+
92
+ model.to(device)
93
+ model.eval()
94
+
95
+ return tokenizer, model, device
96
+
97
+
98
+ # ──────────────────────────────────────────────
99
+ # Text generation
100
+ # ──────────────────────────────────────────────
101
+
102
+ def generate_text(
103
+ prompt: str,
104
+ tokenizer,
105
+ model,
106
+ device: str,
107
+ max_new_tokens: int = GENERATION_DEFAULTS["max_new_tokens"],
108
+ temperature: float = GENERATION_DEFAULTS["temperature"],
109
+ top_k: int = GENERATION_DEFAULTS["top_k"],
110
+ top_p: float = GENERATION_DEFAULTS["top_p"],
111
+ repetition_penalty: float = GENERATION_DEFAULTS["repetition_penalty"],
112
+ ) -> str:
113
+ """Generate a story continuation from the given prompt."""
114
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
115
+
116
+ with torch.no_grad():
117
+ output_tokens = model.generate(
118
+ **inputs,
119
+ max_new_tokens=max_new_tokens,
120
+ do_sample=True,
121
+ temperature=temperature,
122
+ top_k=top_k,
123
+ top_p=top_p,
124
+ repetition_penalty=repetition_penalty,
125
+ pad_token_id=tokenizer.pad_token_id,
126
+ eos_token_id=tokenizer.eos_token_id,
127
+ )
128
+
129
+ return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
130
+
131
+
132
+ # ──────────────────────────────────────────────
133
+ # Interactive loop
134
+ # ──────────────────────────────────────────────
135
+
136
+ def run():
137
+ print("=" * 50)
138
+ print(" StorySupra-10M — Interactive Story Generator")
139
+ print("=" * 50)
140
+
141
+ tokenizer, model, device = load_model(MODEL_ID)
142
+
143
+ print("-" * 50)
144
+ print("Model ready! Type a prompt to generate a story.")
145
+ print(f"Type {' / '.join(EXIT_COMMANDS)} to quit.")
146
+ print("-" * 50)
147
+
148
+ while True:
149
+ try:
150
+ user_prompt = input("\nYour prompt: ").strip()
151
+ except (EOFError, KeyboardInterrupt):
152
+ print("\nExiting. Goodbye!")
153
+ break
154
+
155
+ if not user_prompt:
156
+ print("Please enter a prompt.")
157
+ continue
158
+
159
+ if user_prompt.lower() in EXIT_COMMANDS:
160
+ print("Goodbye!")
161
+ break
162
+
163
+ print("\nGenerating...\n")
164
+ story = generate_text(user_prompt, tokenizer, model, device)
165
+
166
+ print("Generated story:")
167
+ print("-" * 20)
168
+ print(story)
169
+ print("-" * 20)
170
+
171
+
172
+ # ──────────────────────────────────────────────
173
+ # Entry point
174
+ # ──────────────────────────────────────────────
175
+
176
+ if __name__ == "__main__":
177
+ run()
178
+ ```
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 32,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 256,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 1024,
15
+ "max_position_embeddings": 256,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 8,
19
+ "num_hidden_layers": 8,
20
+ "num_key_value_heads": 8,
21
+ "pad_token_id": 1,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-06,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.8.1",
30
+ "use_cache": false,
31
+ "vocab_size": 8192
32
+ }
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 2,
5
+ "output_attentions": false,
6
+ "output_hidden_states": false,
7
+ "pad_token_id": 1,
8
+ "transformers_version": "5.8.1",
9
+ "use_cache": true
10
+ }
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
inference.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("Loading...")
2
+
3
+ import torch
4
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
5
+
6
+ def run_inference():
7
+ model_path = "./StorySupra-10M"
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ print(f"Using device: {device}")
11
+
12
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
13
+
14
+ model = LlamaForCausalLM.from_pretrained(model_path)
15
+ model.to(device)
16
+ model.eval()
17
+
18
+ def generate_text(prompt, max_new_tokens=100, temperature=0.55, top_k=25, top_p=0.85, repetition_penalty=1.1):
19
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
20
+
21
+ with torch.no_grad():
22
+ output_tokens = model.generate(
23
+ **inputs,
24
+ max_new_tokens=max_new_tokens,
25
+ do_sample=True,
26
+ temperature=temperature,
27
+ top_k=top_k,
28
+ top_p=top_p,
29
+ repetition_penalty=repetition_penalty,
30
+ pad_token_id=tokenizer.pad_token_id,
31
+ eos_token_id=tokenizer.eos_token_id
32
+ )
33
+
34
+ return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
35
+
36
+ print("-" * 30)
37
+ print("StorySupra Story Generator loaded!")
38
+ print("Enter a prompt (or type 'exit' to quit):")
39
+
40
+ while True:
41
+ user_prompt = input("\nYour prompt: ")
42
+ if user_prompt.lower() in ["exit", "quit", "leave"]:
43
+ break
44
+
45
+ story = generate_text(user_prompt)
46
+ print(f"\nGenerated story:\n{story}")
47
+ print("-" * 20)
48
+
49
+ if __name__ == "__main__":
50
+ run_inference()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9c874a48b24de2df0d12ec4a8a7e3e9c310d41aeaddff0e79d03803383dbf42
3
+ size 50357208
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "tokenizer_class": "LlamaTokenizer",
8
+ "unk_token": "<unk>"
9
+ }
train.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from datasets import load_dataset
3
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers
4
+ from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
5
+
6
+ dataset = load_dataset("roneneldan/TinyStories", split="train[:200000]")
7
+
8
+ def train_tokenizer(dataset):
9
+ tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
10
+ tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
11
+
12
+ trainer = trainers.BpeTrainer(
13
+ vocab_size=8192,
14
+ special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
15
+ )
16
+
17
+ def batch_iterator():
18
+ for i in range(0, len(dataset), 1000):
19
+ yield dataset[i : i + 1000]["text"]
20
+
21
+ tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
22
+
23
+ from transformers import PreTrainedTokenizerFast
24
+ return PreTrainedTokenizerFast(
25
+ tokenizer_object=tokenizer,
26
+ bos_token="<s>",
27
+ eos_token="</s>",
28
+ unk_token="<unk>",
29
+ pad_token="<pad>"
30
+ )
31
+
32
+ tokenizer = train_tokenizer(dataset)
33
+
34
+ def tokenize_function(examples):
35
+ return tokenizer(examples["text"], truncation=True, max_length=256)
36
+
37
+ tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
38
+
39
+ config = LlamaConfig(
40
+ vocab_size=8192,
41
+ hidden_size=256,
42
+ intermediate_size=1024,
43
+ num_hidden_layers=8,
44
+ num_attention_heads=8,
45
+ max_position_embeddings=256,
46
+ pad_token_id=tokenizer.pad_token_id,
47
+ bos_token_id=tokenizer.bos_token_id,
48
+ eos_token_id=tokenizer.eos_token_id,
49
+ )
50
+
51
+ model = LlamaForCausalLM(config)
52
+ print(f"Model parameters: {model.num_parameters():,}")
53
+
54
+ training_args = TrainingArguments(
55
+ output_dir="./StorySupra-10M",
56
+ per_device_train_batch_size=32,
57
+ num_train_epochs=3,
58
+ save_steps=500,
59
+ logging_steps=100,
60
+ learning_rate=5e-4,
61
+ weight_decay=0.01,
62
+ fp16=True,
63
+ push_to_hub=False,
64
+ report_to="none",
65
+ lr_scheduler_type="cosine"
66
+ )
67
+
68
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
69
+
70
+ trainer = Trainer(
71
+ model=model,
72
+ args=training_args,
73
+ train_dataset=tokenized_dataset,
74
+ data_collator=data_collator,
75
+ )
76
+
77
+ trainer.train()
78
+
79
+ def generate_story(prompt):
80
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
81
+ model.to("cuda")
82
+ outputs = model.generate(**inputs, max_length=100, do_sample=True, temperature=0.55, top_k=25, top_p=0.85, repetition_penalty=1.1)
83
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
84
+
85
+ generate_story("Once upon a time, a small bird")
86
+
87
+ trainer.save_model("./StorySupra-10M")
88
+ tokenizer.save_pretrained("./StorySupra-10M")
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b0cca96b3100c2a57b8e16275daef5d68c6a103f14efbbb8dd80db4ca8f2738
3
+ size 5265
use-from-hf.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ StorySupra-10M — Interactive Story Generator
3
+ Loads model weights directly from HuggingFace: SupraLabs/StorySupra-10M
4
+ """
5
+
6
+ import torch
7
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
8
+
9
+ # ──────────────────────────────────────────────
10
+ # Configuration
11
+ # ──────────────────────────────────────────────
12
+ MODEL_ID = "SupraLabs/StorySupra-10M"
13
+
14
+ GENERATION_DEFAULTS = {
15
+ "max_new_tokens": 100,
16
+ "temperature": 0.55,
17
+ "top_k": 25,
18
+ "top_p": 0.85,
19
+ "repetition_penalty": 1.1,
20
+ "do_sample": True,
21
+ }
22
+
23
+ EXIT_COMMANDS = {"exit", "quit", "leave"}
24
+
25
+ # ──────────────────────────────────────────────
26
+ # Model loading
27
+ # ──────────────────────────────────────────────
28
+
29
+ def load_model(model_id: str):
30
+ """Download and return the tokenizer and model from HuggingFace Hub."""
31
+ print(f"Downloading model from HuggingFace: {model_id}")
32
+ print("(This may take a moment on first run — weights will be cached locally.)\n")
33
+
34
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
35
+ model = LlamaForCausalLM.from_pretrained(model_id)
36
+
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ print(f"Using device: {device}\n")
39
+
40
+ model.to(device)
41
+ model.eval()
42
+
43
+ return tokenizer, model, device
44
+
45
+
46
+ # ──────────────────────────────────────────────
47
+ # Text generation
48
+ # ──────────────────────────────────────────────
49
+
50
+ def generate_text(
51
+ prompt: str,
52
+ tokenizer,
53
+ model,
54
+ device: str,
55
+ max_new_tokens: int = GENERATION_DEFAULTS["max_new_tokens"],
56
+ temperature: float = GENERATION_DEFAULTS["temperature"],
57
+ top_k: int = GENERATION_DEFAULTS["top_k"],
58
+ top_p: float = GENERATION_DEFAULTS["top_p"],
59
+ repetition_penalty: float = GENERATION_DEFAULTS["repetition_penalty"],
60
+ ) -> str:
61
+ """Generate a story continuation from the given prompt."""
62
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
63
+
64
+ with torch.no_grad():
65
+ output_tokens = model.generate(
66
+ **inputs,
67
+ max_new_tokens=max_new_tokens,
68
+ do_sample=True,
69
+ temperature=temperature,
70
+ top_k=top_k,
71
+ top_p=top_p,
72
+ repetition_penalty=repetition_penalty,
73
+ pad_token_id=tokenizer.pad_token_id,
74
+ eos_token_id=tokenizer.eos_token_id,
75
+ )
76
+
77
+ return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
78
+
79
+
80
+ # ──────────────────────────────────────────────
81
+ # Interactive loop
82
+ # ──────────────────────────────────────────────
83
+
84
+ def run():
85
+ print("=" * 50)
86
+ print(" StorySupra-10M — Interactive Story Generator")
87
+ print("=" * 50)
88
+
89
+ tokenizer, model, device = load_model(MODEL_ID)
90
+
91
+ print("-" * 50)
92
+ print("Model ready! Type a prompt to generate a story.")
93
+ print(f"Type {' / '.join(EXIT_COMMANDS)} to quit.")
94
+ print("-" * 50)
95
+
96
+ while True:
97
+ try:
98
+ user_prompt = input("\nYour prompt: ").strip()
99
+ except (EOFError, KeyboardInterrupt):
100
+ print("\nExiting. Goodbye!")
101
+ break
102
+
103
+ if not user_prompt:
104
+ print("Please enter a prompt.")
105
+ continue
106
+
107
+ if user_prompt.lower() in EXIT_COMMANDS:
108
+ print("Goodbye!")
109
+ break
110
+
111
+ print("\nGenerating...\n")
112
+ story = generate_text(user_prompt, tokenizer, model, device)
113
+
114
+ print("Generated story:")
115
+ print("-" * 20)
116
+ print(story)
117
+ print("-" * 20)
118
+
119
+
120
+ # ──────────────────────────────────────────────
121
+ # Entry point
122
+ # ──────────────────────────────────────────────
123
+
124
+ if __name__ == "__main__":
125
+ run()