Eeppa commited on
Commit
46fdff5
·
verified ·
1 Parent(s): 482e073

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ from transformers import AutoTokenizer
6
+ from datasets import Dataset
7
+ import os
8
+
9
+ class NanoGPT(nn.Module):
10
+ def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
11
+ super().__init__()
12
+ self.block_size = block_size
13
+ self.tok_emb = nn.Embedding(vocab_size, n_embd)
14
+ self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
15
+ self.drop = nn.Dropout(0.1)
16
+
17
+ self.layers = nn.ModuleList([
18
+ nn.TransformerDecoderLayer(
19
+ d_model=n_embd, nhead=n_head, dim_feedforward=n_embd*4,
20
+ dropout=0.1, activation="gelu", batch_first=True
21
+ ) for _ in range(n_layer)
22
+ ])
23
+
24
+ self.ln_f = nn.LayerNorm(n_embd)
25
+ self.head = nn.Linear(n_embd, vocab_size, bias=False)
26
+ self.tok_emb.weight = self.head.weight # weight tying
27
+ self.n_embd = n_embd
28
+
29
+ def forward(self, idx, targets=None):
30
+ B, T = idx.shape
31
+ tok_emb = self.tok_emb(idx)
32
+ pos_emb = self.pos_emb[:, :T, :]
33
+ x = self.drop(tok_emb + pos_emb)
34
+
35
+ for layer in self.layers:
36
+ x = layer(x, None) # causal self-attention
37
+
38
+ x = self.ln_f(x)
39
+ logits = self.head(x)
40
+
41
+ if targets is None:
42
+ return logits, None
43
+ B, T, C = logits.shape
44
+ logits = logits.view(B*T, C)
45
+ targets = targets.view(B*T)
46
+ loss = F.cross_entropy(logits, targets)
47
+ return logits, loss
48
+
49
+ @torch.no_grad()
50
+ def generate(self, idx, max_new_tokens=80, temperature=0.95):
51
+ for _ in range(max_new_tokens):
52
+ idx_cond = idx[:, -self.block_size:]
53
+ logits, _ = self(idx_cond)
54
+ logits = logits[:, -1, :] / temperature
55
+ probs = F.softmax(logits, dim=-1)
56
+ next_idx = torch.multinomial(probs, num_samples=1)
57
+ idx = torch.cat((idx, next_idx), dim=1)
58
+ return idx
59
+
60
+ # Globals
61
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
62
+ vocab_size = tokenizer.vocab_size
63
+ block_size = 96
64
+ model = NanoGPT(vocab_size=vocab_size, n_embd=96, n_head=4, n_layer=3, block_size=block_size)
65
+
66
+ model_path = "/data/nanogpt_yap.pt" # /data is persistent on Spaces
67
+
68
+ if os.path.exists(model_path):
69
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
70
+ print("Loaded saved model")
71
+
72
+ # Tiny dataset (repeat for more tokens)
73
+ life_texts = [
74
+ "Life is what happens when you're busy making other plans.",
75
+ "The meaning of life is to find your gift. The purpose is to give it away.",
76
+ "You only live once, but if you do it right, once is enough.",
77
+ "Hey human, existence is weird. Coffee helps.",
78
+ "I think therefore I am... but mostly I just scroll.",
79
+ "Why do we exist? Probably for the memes and Java code.",
80
+ # add more if you want
81
+ ]
82
+
83
+ def create_dataset():
84
+ text = " ".join(life_texts * 50) # ~few k tokens
85
+ encodings = tokenizer(text, return_tensors="pt")
86
+ input_ids = encodings.input_ids[0]
87
+
88
+ seqs = []
89
+ for i in range(0, len(input_ids) - block_size - 1, block_size // 2):
90
+ chunk = input_ids[i:i + block_size + 1]
91
+ if len(chunk) == block_size + 1:
92
+ seqs.append(chunk)
93
+
94
+ if not seqs:
95
+ return None
96
+ data = {"input_ids": [s[:-1].tolist() for s in seqs], "labels": [s[1:].tolist() for s in seqs]}
97
+ return Dataset.from_dict(data)
98
+
99
+ def train_once():
100
+ dataset = create_dataset()
101
+ if dataset is None:
102
+ return "Dataset too small!"
103
+
104
+ def collator(features):
105
+ batch = tokenizer.pad(features, padding=True, return_tensors="pt")
106
+ batch["labels"] = batch["input_ids"].clone()
107
+ return batch
108
+
109
+ from transformers import Trainer, TrainingArguments
110
+ args = TrainingArguments(
111
+ output_dir="/data/results",
112
+ num_train_epochs=5,
113
+ per_device_train_batch_size=4,
114
+ save_strategy="no",
115
+ logging_steps=20,
116
+ report_to="none",
117
+ optim="adamw_torch",
118
+ learning_rate=5e-4,
119
+ )
120
+
121
+ trainer = Trainer(
122
+ model=model,
123
+ args=args,
124
+ train_dataset=dataset,
125
+ data_collator=collator,
126
+ )
127
+
128
+ trainer.train()
129
+ torch.save(model.state_dict(), model_path)
130
+ return "Training finished! Model saved to /data. Chat now!"
131
+
132
+ def chat_with_nano(message, history):
133
+ if not message.strip():
134
+ return history + [["", "Say something existential... or about Java?"]]
135
+
136
+ prompt = f"Human: {message}\nAI: "
137
+ inputs = tokenizer(prompt, return_tensors="pt").input_ids
138
+
139
+ with torch.no_grad():
140
+ generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
141
+ response = tokenizer.decode(generated[0][len(inputs[0]):], skip_special_tokens=True).strip()
142
+
143
+ history.append([message, response])
144
+ return history
145
+
146
+ with gr.Blocks() as demo:
147
+ gr.Markdown("# Nano Java/Life Yap AI")
148
+ gr.Markdown("Tiny ~1M param transformer. Train once, then chat!")
149
+
150
+ chatbot = gr.Chatbot(height=400)
151
+ msg = gr.Textbox(placeholder="Ask about life, existence, or Java...")
152
+ clear = gr.Button("Clear")
153
+
154
+ train_btn = gr.Button("Train Nano Model (10-60 min on CPU – do once!)")
155
+ status = gr.Textbox(label="Status")
156
+
157
+ train_btn.click(train_once, outputs=status)
158
+
159
+ def respond(message, chat_history):
160
+ updated_history = chat_with_nano(message, chat_history)
161
+ return "", updated_history
162
+
163
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
164
+ clear.click(lambda: None, None, chatbot, queue=False)
165
+
166
+ demo.launch(server_name="0.0.0.0", server_port=7860)