Eeppa commited on
Commit
d79be6f
·
verified ·
1 Parent(s): 04ffdf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -48
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer
6
  from datasets import Dataset
7
  import os
8
 
 
9
  class NanoGPT(nn.Module):
10
  def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
11
  super().__init__()
@@ -15,16 +16,15 @@ class NanoGPT(nn.Module):
15
  self.drop = nn.Dropout(0.1)
16
 
17
  self.layers = nn.ModuleList([
18
- nn.TransformerDecoderLayer(
19
- d_model=n_embd, nhead=n_head, dim_feedforward=n_embd*4,
20
- dropout=0.1, activation="gelu", batch_first=True
21
- ) for _ in range(n_layer)
22
  ])
23
 
24
  self.ln_f = nn.LayerNorm(n_embd)
25
  self.head = nn.Linear(n_embd, vocab_size, bias=False)
26
- self.tok_emb.weight = self.head.weight # weight tying
27
- self.n_embd = n_embd
28
 
29
  def forward(self, idx, targets=None):
30
  B, T = idx.shape
@@ -33,16 +33,17 @@ class NanoGPT(nn.Module):
33
  x = self.drop(tok_emb + pos_emb)
34
 
35
  for layer in self.layers:
36
- x = layer(x, None) # causal self-attention
37
 
38
  x = self.ln_f(x)
39
  logits = self.head(x)
40
 
41
  if targets is None:
42
  return logits, None
 
43
  B, T, C = logits.shape
44
- logits = logits.view(B*T, C)
45
- targets = targets.view(B*T)
46
  loss = F.cross_entropy(logits, targets)
47
  return logits, loss
48
 
@@ -57,58 +58,62 @@ class NanoGPT(nn.Module):
57
  idx = torch.cat((idx, next_idx), dim=1)
58
  return idx
59
 
60
- # Globals
61
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
62
- vocab_size = tokenizer.vocab_size
63
- block_size = 96
64
- model = NanoGPT(vocab_size=vocab_size, n_embd=96, n_head=4, n_layer=3, block_size=block_size)
65
-
66
- model_path = "/data/nanogpt_yap.pt" # /data is persistent on Spaces
67
 
68
  if os.path.exists(model_path):
69
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
70
- print("Loaded saved model")
71
 
72
- # Tiny dataset (repeat for more tokens)
73
  life_texts = [
74
  "Life is what happens when you're busy making other plans.",
75
  "The meaning of life is to find your gift. The purpose is to give it away.",
76
  "You only live once, but if you do it right, once is enough.",
77
- "Hey human, existence is weird. Coffee helps.",
78
- "I think therefore I am... but mostly I just scroll.",
79
- "Why do we exist? Probably for the memes and Java code.",
80
- # add more if you want
81
  ]
82
 
83
  def create_dataset():
84
- text = " ".join(life_texts * 50) # ~few k tokens
85
  encodings = tokenizer(text, return_tensors="pt")
86
  input_ids = encodings.input_ids[0]
87
 
88
  seqs = []
89
- for i in range(0, len(input_ids) - block_size - 1, block_size // 2):
90
- chunk = input_ids[i:i + block_size + 1]
 
 
91
  if len(chunk) == block_size + 1:
92
  seqs.append(chunk)
93
 
94
  if not seqs:
95
  return None
96
- data = {"input_ids": [s[:-1].tolist() for s in seqs], "labels": [s[1:].tolist() for s in seqs]}
 
 
 
 
97
  return Dataset.from_dict(data)
98
 
99
- def train_once():
100
  dataset = create_dataset()
101
- if dataset is None:
102
- return "Dataset too small!"
103
 
104
- def collator(features):
105
  batch = tokenizer.pad(features, padding=True, return_tensors="pt")
106
  batch["labels"] = batch["input_ids"].clone()
107
  return batch
108
 
109
  from transformers import Trainer, TrainingArguments
 
110
  args = TrainingArguments(
111
- output_dir="/data/results",
112
  num_train_epochs=5,
113
  per_device_train_batch_size=4,
114
  save_strategy="no",
@@ -116,51 +121,55 @@ def train_once():
116
  report_to="none",
117
  optim="adamw_torch",
118
  learning_rate=5e-4,
 
119
  )
120
 
121
  trainer = Trainer(
122
  model=model,
123
  args=args,
124
  train_dataset=dataset,
125
- data_collator=collator,
126
  )
127
 
128
  trainer.train()
129
  torch.save(model.state_dict(), model_path)
130
- return "Training finished! Model saved to /data. Chat now!"
131
 
132
- def chat_with_nano(message, history):
133
- if not message.strip():
134
- return history + [["", "Say something existential... or about Java?"]]
135
 
136
  prompt = f"Human: {message}\nAI: "
137
  inputs = tokenizer(prompt, return_tensors="pt").input_ids
138
 
139
  with torch.no_grad():
140
  generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
141
- response = tokenizer.decode(generated[0][len(inputs[0]):], skip_special_tokens=True).strip()
 
 
142
 
143
  history.append([message, response])
144
  return history
145
 
 
146
  with gr.Blocks() as demo:
147
- gr.Markdown("# Nano Java/Life Yap AI")
148
- gr.Markdown("Tiny ~1M param transformer. Train once, then chat!")
149
 
150
  chatbot = gr.Chatbot(height=400)
151
- msg = gr.Textbox(placeholder="Ask about life, existence, or Java...")
152
- clear = gr.Button("Clear")
153
 
154
- train_btn = gr.Button("Train Nano Model (10-60 min on CPU – do once!)")
155
- status = gr.Textbox(label="Status")
156
 
157
- train_btn.click(train_once, outputs=status)
158
 
159
- def respond(message, chat_history):
160
- updated_history = chat_with_nano(message, chat_history)
161
- return "", updated_history
162
 
163
- msg.submit(respond, [msg, chatbot], [msg, chatbot])
164
- clear.click(lambda: None, None, chatbot, queue=False)
165
 
166
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
6
  from datasets import Dataset
7
  import os
8
 
9
+ # Tiny NanoGPT class
10
  class NanoGPT(nn.Module):
11
  def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
12
  super().__init__()
 
16
  self.drop = nn.Dropout(0.1)
17
 
18
  self.layers = nn.ModuleList([
19
+ nn.TransformerDecoderLayer(d_model=n_embd, nhead=n_head,
20
+ dim_feedforward=n_embd*4, dropout=0.1,
21
+ activation="gelu", batch_first=True)
22
+ for _ in range(n_layer)
23
  ])
24
 
25
  self.ln_f = nn.LayerNorm(n_embd)
26
  self.head = nn.Linear(n_embd, vocab_size, bias=False)
27
+ self.tok_emb.weight = self.head.weight # tie weights
 
28
 
29
  def forward(self, idx, targets=None):
30
  B, T = idx.shape
 
33
  x = self.drop(tok_emb + pos_emb)
34
 
35
  for layer in self.layers:
36
+ x = layer(x, None) # self-attn only, causal
37
 
38
  x = self.ln_f(x)
39
  logits = self.head(x)
40
 
41
  if targets is None:
42
  return logits, None
43
+
44
  B, T, C = logits.shape
45
+ logits = logits.view(B * T, C)
46
+ targets = targets.view(B * T)
47
  loss = F.cross_entropy(logits, targets)
48
  return logits, loss
49
 
 
58
  idx = torch.cat((idx, next_idx), dim=1)
59
  return idx
60
 
61
+ # Setup
62
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
63
+ model = NanoGPT()
64
+ model_path = "nanogpt_yap.pt" # saved in current dir (non-persistent on restart, but ok for test)
 
 
 
65
 
66
  if os.path.exists(model_path):
67
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
68
+ print("Loaded existing model weights")
69
 
70
+ # Small dataset for yapping about life (repeat for more data)
71
  life_texts = [
72
  "Life is what happens when you're busy making other plans.",
73
  "The meaning of life is to find your gift. The purpose is to give it away.",
74
  "You only live once, but if you do it right, once is enough.",
75
+ "Existence is weird. Coffee helps sometimes.",
76
+ "I think therefore I am... mostly scrolling though.",
77
+ "Why do we exist? Probably for memes and Java bugs.",
78
+ "Another day, another existential crisis. Pass the tea.",
79
  ]
80
 
81
  def create_dataset():
82
+ text = " ".join(life_texts * 50) # small but repeated
83
  encodings = tokenizer(text, return_tensors="pt")
84
  input_ids = encodings.input_ids[0]
85
 
86
  seqs = []
87
+ block_size = 96
88
+ step = block_size // 2
89
+ for i in range(0, len(input_ids) - block_size - 1, step):
90
+ chunk = input_ids[i : i + block_size + 1]
91
  if len(chunk) == block_size + 1:
92
  seqs.append(chunk)
93
 
94
  if not seqs:
95
  return None
96
+
97
+ data = {
98
+ "input_ids": [s[:-1].tolist() for s in seqs],
99
+ "labels": [s[1:].tolist() for s in seqs],
100
+ }
101
  return Dataset.from_dict(data)
102
 
103
+ def train_model():
104
  dataset = create_dataset()
105
+ if dataset is None or len(dataset) == 0:
106
+ return "Dataset creation failed - too small!"
107
 
108
+ def data_collator(features):
109
  batch = tokenizer.pad(features, padding=True, return_tensors="pt")
110
  batch["labels"] = batch["input_ids"].clone()
111
  return batch
112
 
113
  from transformers import Trainer, TrainingArguments
114
+
115
  args = TrainingArguments(
116
+ output_dir="./results",
117
  num_train_epochs=5,
118
  per_device_train_batch_size=4,
119
  save_strategy="no",
 
121
  report_to="none",
122
  optim="adamw_torch",
123
  learning_rate=5e-4,
124
+ fp16=False, # CPU
125
  )
126
 
127
  trainer = Trainer(
128
  model=model,
129
  args=args,
130
  train_dataset=dataset,
131
+ data_collator=data_collator,
132
  )
133
 
134
  trainer.train()
135
  torch.save(model.state_dict(), model_path)
136
+ return "Training done! Model saved. You can chat now (responses may be silly)."
137
 
138
+ def generate_response(message, history):
139
+ if not message:
140
+ return history + [["", "Ask me something deep... or weird."]]
141
 
142
  prompt = f"Human: {message}\nAI: "
143
  inputs = tokenizer(prompt, return_tensors="pt").input_ids
144
 
145
  with torch.no_grad():
146
  generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
147
+
148
+ full_text = tokenizer.decode(generated[0])
149
+ response = full_text[len(prompt):].strip() # trim prompt part
150
 
151
  history.append([message, response])
152
  return history
153
 
154
+ # Gradio UI
155
  with gr.Blocks() as demo:
156
+ gr.Markdown("# Nano AI Yap Test")
157
+ gr.Markdown("Tiny from-scratch model (~1M params). Train first, then chat!")
158
 
159
  chatbot = gr.Chatbot(height=400)
160
+ textbox = gr.Textbox(placeholder="Talk to me about life, existence, or anything...")
161
+ clear_btn = gr.Button("Clear Chat")
162
 
163
+ train_button = gr.Button("Start Training (takes 1060 min on free CPU – run once)")
164
+ status_box = gr.Textbox(label="Training Status", interactive=False)
165
 
166
+ train_button.click(train_model, outputs=status_box)
167
 
168
+ def submit_chat(msg, hist):
169
+ updated_hist = generate_response(msg, hist)
170
+ return "", updated_hist
171
 
172
+ textbox.submit(submit_chat, [textbox, chatbot], [textbox, chatbot])
173
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
174
 
175
  demo.launch(server_name="0.0.0.0", server_port=7860)