Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoTokenizer | |
| from datasets import Dataset | |
| import os | |
class NanoGPT(nn.Module):
    """Tiny decoder-only (GPT-style) language model.

    Token embeddings plus learned positional embeddings feed a stack of
    PyTorch transformer layers, run with a causal self-attention mask. A final
    LayerNorm and a linear head follow. The head's weight is tied to the token
    embedding.

    Args:
        vocab_size: number of tokens in the vocabulary.
        n_embd: model / embedding width (must be divisible by ``n_head``).
        n_head: attention heads per layer.
        n_layer: number of transformer layers.
        block_size: maximum context length (size of the positional table).
    """

    def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(0.1)
        # A decoder-only model needs self-attention only. nn.TransformerDecoderLayer
        # also contains a cross-attention block that requires an encoder `memory`
        # tensor (passing None crashes). So use "encoder" layers, which are pure
        # self-attention, and make them causal with an explicit mask in forward().
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd, nhead=n_head, dim_feedforward=n_embd * 4,
                dropout=0.1, activation="gelu", batch_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # weight tying
        self.n_embd = n_embd

    @staticmethod
    def _causal_mask(T, like):
        """Return a (T, T) additive mask with -inf above the diagonal.

        Position t may then attend only to positions <= t. The mask uses the
        dtype and device of ``like``.
        """
        full = torch.full((T, T), float("-inf"), dtype=like.dtype, device=like.device)
        return torch.triu(full, diagonal=1)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, with T <= ``block_size``.
            targets: optional (B, T) tensor of token ids aligned with ``idx``.
                ``targets[:, t]`` is the token expected after ``idx[:, t]``.

        Returns:
            ``(logits, None)`` with logits of shape (B, T, vocab) when
            ``targets`` is None. Otherwise ``(logits, loss)`` with logits
            flattened to (B*T, vocab) and ``loss`` a scalar tensor.

        Raises:
            ValueError: if T exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        x = self.drop(self.tok_emb(idx) + self.pos_emb[:, :T, :])
        mask = self._causal_mask(T, x)
        for layer in self.layers:
            x = layer(x, src_mask=mask)  # causal self-attention
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    def generate(self, idx, max_new_tokens=80, temperature=0.95):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Only the last ``block_size`` tokens are fed to the model at each step.
        This method neither switches to eval mode nor disables autograd;
        callers should do both (``model.eval()`` and ``torch.no_grad()``).

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the logits before softmax.

        Returns:
            A (B, T + max_new_tokens) tensor: the prompt followed by samples.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx
# Globals
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size
block_size = 96
model = NanoGPT(vocab_size=vocab_size, n_embd=96, n_head=4, n_layer=3, block_size=block_size)
# /data survives restarts only when persistent storage is enabled for the Space.
model_path = "/data/nanogpt_yap.pt"
if os.path.exists(model_path):
    try:
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
    except Exception as err:  # startup boundary: a bad checkpoint must not stop the app
        # e.g. weights saved by an older architecture (key/shape mismatch) or a
        # truncated file; fall back to the untrained model and keep serving.
        print(f"Could not load saved model from {model_path}: {err}")
    else:
        print("Loaded saved model")
# nn.Module starts in training mode; switch off dropout for chat inference.
model.eval()
# Seed corpus of short "life" lines. create_dataset() repeats it many times
# to get enough tokens for training windows.
life_texts = [
    "Life is what happens when you're busy making other plans.",
    (
        "The meaning of life is to find your gift. "
        "The purpose is to give it away."
    ),
    "You only live once, but if you do it right, once is enough.",
    "Hey human, existence is weird. Coffee helps.",
    "I think therefore I am... but mostly I just scroll.",
    "Why do we exist? Probably for the memes and Java code.",
    # Append more lines here to enlarge the corpus.
]
def create_dataset():
    """Tokenize the repeated corpus into overlapping next-token training windows.

    ``life_texts`` is repeated 50 times, joined with spaces and tokenized once.
    The token stream is then cut into windows of ``block_size + 1`` tokens with
    a stride of half a block. Each window gives ``input_ids`` (its first
    ``block_size`` tokens) and ``labels`` (the same window shifted left by one).
    So ``labels[t]`` is the token that follows ``input_ids[t]``.

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns, or
        None if the text is shorter than a single window.
    """
    text = " ".join(life_texts * 50)  # ~few k tokens
    input_ids = tokenizer(text, return_tensors="pt").input_ids[0]
    window = block_size + 1
    stride = max(1, block_size // 2)  # block_size < 2 would otherwise give a zero step
    # The last valid start is len - window; range's stop is exclusive, hence + 1.
    seqs = [input_ids[i:i + window] for i in range(0, len(input_ids) - window + 1, stride)]
    if not seqs:
        return None
    data = {
        "input_ids": [s[:-1].tolist() for s in seqs],
        "labels": [s[1:].tolist() for s in seqs],
    }
    return Dataset.from_dict(data)
def train_once():
    """Train the global ``model`` on ``create_dataset()`` and save its weights.

    Uses a plain PyTorch loop that mirrors the previous Trainer settings:
    AdamW, lr 5e-4, no weight decay, 5 epochs, batch size 4, shuffled
    batches, linear LR decay and grad-norm clipping at 1.0.

    ``transformers.Trainer`` cannot drive ``NanoGPT`` as written:
    - By default it drops dataset columns not named in ``forward``'s signature
      (``idx``/``targets``), so ``input_ids``/``labels`` never arrive.
    - It takes the loss from ``outputs[0]``, which here is the logits.
    - The old collator also replaced the pre-shifted ``labels`` with an
      unshifted copy of ``input_ids``.

    Returns:
        A status string for the Gradio "Status" textbox.
    """
    dataset = create_dataset()
    if dataset is None:
        return "Dataset too small!"

    epochs, batch_size, lr, log_every = 5, 4, 5e-4, 20
    data = dataset.to_dict()
    inputs = torch.tensor(data["input_ids"], dtype=torch.long)
    targets = torch.tensor(data["labels"], dtype=torch.long)  # already shifted by one
    n_samples = inputs.size(0)
    total_steps = epochs * ((n_samples + batch_size - 1) // batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: max(0.0, 1.0 - step / total_steps)
    )

    model.train()
    step = 0
    for _ in range(epochs):
        order = torch.randperm(n_samples)
        for start in range(0, n_samples, batch_size):
            batch = order[start:start + batch_size]
            _, loss = model(inputs[batch], targets[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            step += 1
            if step % log_every == 0:
                print(f"step {step}/{total_steps}  loss {loss.item():.4f}")
    model.eval()  # back to inference mode (dropout off) for chatting

    try:
        # /data may not exist when the Space has no persistent storage.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)
    except OSError as err:
        return f"Training finished, but saving to {model_path} failed: {err}"
    return "Training finished! Model saved to /data. Chat now!"
def chat_with_nano(message, history):
    """Generate one model reply and return the history with the exchange added.

    Args:
        message: the user's text from the Gradio textbox.
        history: list of ``[user, bot]`` pairs from ``gr.Chatbot``. None is
            tolerated (the Clear button sets the chatbot value to None).

    Returns:
        A new history list ending with ``[message, response]``. For a blank
        message, the list ends with a canned nudge instead.
    """
    history = list(history or [])  # copy: don't mutate the caller's list
    if not message or not message.strip():
        return history + [["", "Say something existential... or about Java?"]]
    prompt = f"Human: {message}\nAI: "
    # No [CLS]/[SEP]: a trailing [SEP] would mark the text as already finished.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    model.eval()  # training leaves dropout on; sampling should be deterministic-model
    with torch.no_grad():
        generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
    new_tokens = generated[0][inputs.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    history.append([message, response])
    return history
with gr.Blocks() as demo:
    # Components render top-to-bottom in the order they are created.
    gr.Markdown("# Nano Java/Life Yap AI")
    gr.Markdown("Tiny ~1M param transformer. Train once, then chat!")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(placeholder="Ask about life, existence, or Java...")
    clear = gr.Button("Clear")
    train_btn = gr.Button("Train Nano Model (10-60 min on CPU – do once!)")
    status = gr.Textbox(label="Status")

    def respond(message, chat_history):
        """Run one chat turn: empty the textbox and show the updated history."""
        return "", chat_with_nano(message, chat_history)

    # Event wiring (registered in the same order as before).
    train_btn.click(fn=train_once, outputs=status)
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

demo.launch(server_name="0.0.0.0", server_port=7860)