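# Java-AI-Nano-V1 / app.py
# Author: Eeppa. Last commit d79be6f ("Update app.py", verified), 5.85 kB.
# (Header captured from the Hugging Face file viewer; "raw" / "history" / "blame" were page links.)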
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from datasets import Dataset
import os
# Tiny NanoGPT class
class NanoGPT(nn.Module):
    """Minimal GPT-style causal language model built from stock PyTorch layers.

    Token embeddings plus a learned positional table feed ``n_layer``
    self-attention blocks, a final LayerNorm and a bias-free linear head whose
    weight is tied to the token embedding.

    Args:
        vocab_size: Number of token ids (default matches bert-base-uncased).
        n_embd: Embedding / model width.
        n_head: Attention heads per layer (must divide ``n_embd``).
        n_layer: Number of transformer blocks.
        block_size: Maximum context length the positional table supports.
    """

    def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(0.1)
        # Encoder layers are self-attention + feed-forward only. A
        # TransformerDecoderLayer additionally cross-attends to a required
        # ``memory`` tensor, which a decoder-only LM does not have.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=n_embd * 4,
                dropout=0.1,
                activation="gelu",
                batch_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # tie input/output embeddings

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: Optional token ids with the same shape as ``idx``;
                entries equal to -100 are ignored by ``F.cross_entropy``.

        Returns:
            ``(logits, None)`` with logits of shape ``(B, T, vocab)`` when no
            targets are given; otherwise ``(logits, loss)`` where logits are
            flattened to ``(B * T, vocab)`` and loss is a scalar.

        Raises:
            ValueError: if ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        x = self.drop(self.tok_emb(idx) + self.pos_emb[:, :T, :])
        # True above the diagonal: position i may not attend to any j > i.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        C = logits.size(-1)
        # reshape (not view) so non-contiguous inputs are accepted too.
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=80, temperature=0.95):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: Integer token ids of shape ``(B, T0)`` used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature; ``<= 0`` selects greedy argmax.

        Returns:
            Tensor of shape ``(B, T0 + max_new_tokens)`` (prompt + new tokens).
        """
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit the positional table.
            context = idx[:, -self.block_size:]
            logits, _ = self(context)
            last = logits[:, -1, :]
            if temperature <= 0:
                next_idx = last.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(last / temperature, dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx
# Setup
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = NanoGPT()
# Saved in the current working directory: lost when the Space restarts,
# which is acceptable for a test app.
model_path = "nanogpt_yap.pt"
if os.path.exists(model_path):
    try:
        # weights_only=True: load tensors only, never arbitrary pickled objects.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        print("Loaded existing model weights")
    except RuntimeError as err:
        # A checkpoint from a different architecture should not stop the app
        # from starting; fall back to the freshly initialised model.
        print(f"Ignoring incompatible weights in {model_path}: {err}")
# nn.Module starts in training mode; switch to eval so dropout is off while chatting.
model.eval()
# Seed corpus for the "life" chatter. It is deliberately tiny;
# create_dataset() repeats it many times to get enough tokens to train on.
life_texts = [
    # Borrowed quotes
    "Life is what happens when you're busy making other plans.",
    "The meaning of life is to find your gift. The purpose is to give it away.",
    "You only live once, but if you do it right, once is enough.",
    # Original silliness
    "Existence is weird. Coffee helps sometimes.",
    "I think therefore I am... mostly scrolling though.",
    "Why do we exist? Probably for memes and Java bugs.",
    "Another day, another existential crisis. Pass the tea.",
]
def create_dataset(block_size=96, repeats=50):
    """Build a next-token-prediction dataset from ``life_texts``.

    The corpus (``life_texts`` repeated ``repeats`` times, space-joined) is
    tokenized once and cut into overlapping windows of ``block_size + 1``
    tokens with a stride of half a block. Each window yields an ``input_ids``
    row (all tokens but the last) and a ``labels`` row (all tokens but the
    first), so labels are already shifted by one position.

    Args:
        block_size: Tokens per training example; should not exceed the
            model's context length.
        repeats: How many times the small corpus is repeated.

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns, or
        ``None`` if the tokenized corpus is shorter than one window.
    """
    text = " ".join(life_texts * repeats)
    # No [CLS]/[SEP]: this is a plain causal LM, so the training stream should
    # contain only the raw text tokens.
    input_ids = tokenizer(
        text, add_special_tokens=False, return_tensors="pt"
    ).input_ids[0]
    window = block_size + 1
    step = max(1, block_size // 2)  # guard against a zero stride for tiny blocks
    # The last valid start is len - window; range's stop is exclusive, hence +1.
    seqs = [
        input_ids[i : i + window]
        for i in range(0, len(input_ids) - window + 1, step)
    ]
    if not seqs:
        return None
    data = {
        "input_ids": [s[:-1].tolist() for s in seqs],
        "labels": [s[1:].tolist() for s in seqs],
    }
    return Dataset.from_dict(data)
def train_model():
    """Train the global ``model`` on ``create_dataset()`` and save its weights.

    Uses a plain PyTorch loop rather than ``transformers.Trainer``. Trainer
    does not fit this model: it drops dataset columns whose names are not
    ``forward`` parameters (``input_ids``/``labels`` vs ``idx``/``targets``),
    and it reads the loss from ``outputs[0]``, which here is the logits.

    Settings mirror the previous Trainer configuration: 5 epochs, batches of 4
    shuffled examples, AdamW at lr 5e-4 without weight decay, gradient
    clipping at norm 1.0, linear learning-rate decay to zero, and loss printed
    every 20 steps. The model is returned to eval mode afterwards, and its
    ``state_dict`` is saved to ``model_path``.

    Returns:
        A status message for the Gradio status box.
    """
    dataset = create_dataset()
    if dataset is None or len(dataset) == 0:
        return "Dataset creation failed - too small!"

    num_epochs = 5
    batch_size = 4
    log_every = 20

    # Rows already hold next-token pairs (labels shifted by one); use them
    # unchanged. Overwriting labels with input_ids would teach the model to copy.
    columns = dataset.to_dict()
    inputs = torch.tensor(columns["input_ids"], dtype=torch.long)
    labels = torch.tensor(columns["labels"], dtype=torch.long)

    steps_per_epoch = (len(inputs) + batch_size - 1) // batch_size
    total_steps = num_epochs * steps_per_epoch
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda s: max(0.0, 1.0 - s / total_steps)
    )

    model.train()
    step = 0
    try:
        for epoch in range(num_epochs):
            order = torch.randperm(len(inputs))
            for start in range(0, len(order), batch_size):
                batch = order[start : start + batch_size]
                _, loss = model(inputs[batch], labels[batch])
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                step += 1
                if step % log_every == 0:
                    print(
                        f"epoch {epoch + 1}/{num_epochs} step {step}/{total_steps} "
                        f"loss {loss.item():.4f}"
                    )
    finally:
        # Back to inference mode (dropout off) even if training fails midway.
        model.eval()

    torch.save(model.state_dict(), model_path)
    return "Training done! Model saved. You can chat now (responses may be silly)."
def generate_response(message, history):
    """Return a new chat history with one (message, reply) exchange appended.

    Args:
        message: The user's text from the textbox.
        history: Chatbot history as a list of ``[user, bot]`` pairs. It may be
            ``None``, e.g. after the chatbot component has been cleared.

    Returns:
        A new list: the previous pairs followed by ``[message, reply]``. The
        caller's list is not mutated.
    """
    history = list(history or [])
    if not message:
        return history + [["", "Ask me something deep... or weird."]]
    prompt = f"Human: {message}\nAI: "
    # Encode without [CLS]/[SEP]: a trailing [SEP] would mark the text as
    # finished before the model starts answering.
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
    with torch.no_grad():
        generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is wrong because the uncased word-piece decode does not
    # reproduce the prompt text character-for-character.
    new_tokens = generated[0, inputs.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return history + [[message, response]]
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Nano AI Yap Test")
    # ~3M parameters, almost all of them in the 30522 x 96 token embedding.
    gr.Markdown("Tiny from-scratch model (~3M params). Train first, then chat!")
    chatbot = gr.Chatbot(height=400)
    textbox = gr.Textbox(placeholder="Talk to me about life, existence, or anything...")
    clear_btn = gr.Button("Clear Chat")
    train_button = gr.Button("Start Training (takes 10–60 min on free CPU – run once)")
    status_box = gr.Textbox(label="Training Status", interactive=False)
    train_button.click(train_model, outputs=status_box)

    def submit_chat(msg, hist):
        """Run one chat turn and clear the input textbox."""
        updated_hist = generate_response(msg, hist)
        return "", updated_hist

    textbox.submit(submit_chat, [textbox, chatbot], [textbox, chatbot])
    # Reset to an empty list (not None) so the next turn receives a list history.
    clear_btn.click(lambda: [], None, chatbot, queue=False)

demo.launch(server_name="0.0.0.0", server_port=7860)