"""Gradio demo: train and chat with a tiny GPT-style language model.

The module builds a small decoder-only transformer (``NanoGPT``), optionally
restores its weights from ``model_path``, and exposes two actions in a Gradio
UI: a one-off training run on a handful of quotes and a chat box that samples
a continuation of the user's message.
"""

import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from transformers import AutoTokenizer


class NanoGPT(nn.Module):
    """Minimal decoder-only (GPT-style) transformer language model.

    Token embeddings plus a learned positional embedding feed a stack of
    self-attention layers that are run under a causal mask, followed by a
    final LayerNorm and a linear head whose weight is tied to the token
    embedding.

    Args:
        vocab_size: Number of token ids the model can embed / predict.
        n_embd: Embedding (model) width.
        n_head: Number of attention heads per layer.
        n_layer: Number of transformer layers.
        block_size: Maximum sequence length the positional table covers.
    """

    def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
        super().__init__()
        self.block_size = block_size
        self.n_embd = n_embd
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(0.1)
        # A GPT block is self-attention + MLP only.  TransformerDecoderLayer
        # additionally cross-attends to a `memory` tensor and fails when that
        # is None, so the self-attention-only encoder layer is used here and
        # made causal with an explicit mask in forward().
        self.layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=n_embd,
                    nhead=n_head,
                    dim_feedforward=n_embd * 4,
                    dropout=0.1,
                    activation="gelu",
                    batch_first=True,
                )
                for _ in range(n_layer)
            ]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # weight tying

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional integer tensor with ``B * T`` elements holding
                the id expected at each position (i.e. ``idx`` shifted left
                by one by the caller).

        Returns:
            ``(logits, None)`` with logits of shape ``(B, T, vocab)`` when
            ``targets`` is None; otherwise ``(logits, loss)`` with logits
            flattened to ``(B * T, vocab)`` and ``loss`` the mean
            cross-entropy.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        x = self.drop(self.tok_emb(idx) + self.pos_emb[:, :T, :])
        # True above the diagonal == "may not attend": position t only sees
        # positions <= t.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)  # causal self-attention
        logits = self.head(self.ln_f(x))
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=80, temperature=0.95):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        Args:
            idx: Integer tensor ``(B, T)`` with ``T >= 1`` used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Positive softmax temperature; logits are divided by
                it before sampling.

        Returns:
            Tensor ``(B, T + max_new_tokens)``: the prompt followed by the
            sampled tokens.

        Raises:
            ValueError: If the prompt is empty or ``temperature`` is not
                positive.
        """
        if idx.shape[1] == 0:
            raise ValueError("generate() needs at least one prompt token")
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        # Sample with dropout disabled, then put the module back in whatever
        # mode the caller had it in.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last block_size tokens fit the positional table.
                context = idx[:, -self.block_size:]
                logits, _ = self(context)
                probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=1)
        finally:
            self.train(was_training)
        return idx


# Globals
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size
block_size = 96
model = NanoGPT(vocab_size=vocab_size, n_embd=96, n_head=4, n_layer=3, block_size=block_size)
model_path = "/data/nanogpt_yap.pt"  # /data is persistent on Spaces

if os.path.exists(model_path):
    # NOTE(review): torch.load unpickles the file; only point model_path at
    # checkpoints this app wrote itself.
    try:
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
    except RuntimeError as err:
        # load_state_dict raises RuntimeError on missing/unexpected keys or
        # shape mismatches (e.g. a checkpoint from a different architecture).
        # Start from fresh weights instead of failing at import time.
        print(f"Ignoring incompatible checkpoint {model_path}: {err}")
    else:
        print("Loaded saved model")

# Tiny dataset (repeat for more tokens)
life_texts = [
    "Life is what happens when you're busy making other plans.",
    "The meaning of life is to find your gift. The purpose is to give it away.",
    "You only live once, but if you do it right, once is enough.",
    "Hey human, existence is weird. Coffee helps.",
    "I think therefore I am... but mostly I just scroll.",
    "Why do we exist? Probably for the memes and Java code.",
    # add more if you want
]


def create_dataset():
    """Build overlapping next-token training windows from ``life_texts``.

    The quotes are repeated 50 times, joined and tokenized as one stream,
    then cut into windows of ``block_size + 1`` tokens taken every
    ``block_size // 2`` tokens.  Each window yields ``input_ids`` (all but
    the last token) and ``labels`` (all but the first token).

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns, or
        ``None`` when the token stream is shorter than one window.
    """
    text = " ".join(life_texts * 50)  # ~few k tokens
    # No [CLS]/[SEP]: this is a plain language-model stream, not a BERT input.
    encodings = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = encodings.input_ids[0]

    window = block_size + 1
    stride = block_size // 2
    # "+ 1" keeps the final full window; an empty range results when the
    # stream is shorter than one window.
    seqs = [
        input_ids[start:start + window]
        for start in range(0, len(input_ids) - window + 1, stride)
    ]
    if not seqs:
        return None
    data = {
        "input_ids": [s[:-1].tolist() for s in seqs],
        "labels": [s[1:].tolist() for s in seqs],
    }
    return Dataset.from_dict(data)


def train_once():
    """Train the global ``model`` on the tiny dataset and save its weights.

    Runs 5 epochs of AdamW (lr 5e-4, batch size 4, gradient norm clipped to
    1.0) over the windows from ``create_dataset``, printing the loss every 20
    steps, then writes the state dict to ``model_path``.

    A plain PyTorch loop is used because ``NanoGPT.forward`` takes
    ``(idx, targets)`` and returns ``(logits, loss)``, which does not match
    the keyword names and output layout ``transformers.Trainer`` expects.
    The already-shifted ``labels`` column is used as the target.

    Returns:
        A status string for the UI.
    """
    dataset = create_dataset()
    if dataset is None:
        return "Dataset too small!"

    columns = dataset[:]
    inputs = torch.tensor(columns["input_ids"], dtype=torch.long)
    labels = torch.tensor(columns["labels"], dtype=torch.long)

    num_epochs = 5
    batch_size = 4
    logging_steps = 20
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

    model.train()
    step = 0
    for epoch in range(num_epochs):
        order = torch.randperm(inputs.size(0))  # reshuffle every epoch
        for start in range(0, order.numel(), batch_size):
            batch = order[start:start + batch_size]
            _, loss = model(inputs[batch], labels[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            step += 1
            if step % logging_steps == 0:
                print(f"epoch {epoch + 1} step {step} loss {loss.item():.4f}")
    model.eval()

    try:
        save_dir = os.path.dirname(model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), model_path)
    except (OSError, RuntimeError) as err:
        # The in-memory model is trained either way; report the save problem
        # in the status box instead of failing the whole callback.
        return f"Training finished, but saving to {model_path} failed: {err}"
    return "Training finished! Model saved to /data. Chat now!"


def chat_with_nano(message, history):
    """Append one chat turn produced by the global ``model`` to ``history``.

    Args:
        message: The user's text.
        history: List of ``[user, bot]`` pairs, or ``None`` for an empty chat.

    Returns:
        The updated history.  A blank message yields a canned reply and does
        not invoke the model.
    """
    history = history or []
    if not message.strip():
        return history + [["", "Say something existential... or about Java?"]]

    prompt = f"Human: {message}\nAI: "
    # No [CLS]/[SEP]: a trailing [SEP] would be the token the model continues
    # from, which it never sees mid-stream during training.
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
    with torch.no_grad():
        generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
    # Decode only the newly sampled tokens, not the echoed prompt.
    new_tokens = generated[0][inputs.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    history.append([message, response])
    return history


with gr.Blocks() as demo:
    gr.Markdown("# Nano Java/Life Yap AI")
    gr.Markdown("Tiny ~1M param transformer. Train once, then chat!")

    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(placeholder="Ask about life, existence, or Java...")
    clear = gr.Button("Clear")

    train_btn = gr.Button("Train Nano Model (10-60 min on CPU – do once!)")
    status = gr.Textbox(label="Status")
    train_btn.click(train_once, outputs=status)

    def respond(message, chat_history):
        """Handle a submitted message: clear the textbox, return new history."""
        updated_history = chat_with_nano(message, chat_history)
        return "", updated_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch(server_name="0.0.0.0", server_port=7860)