Java-AI-Nano-V1 / app.py
Eeppa's picture
Create app.py
46fdff5 verified
raw
history blame
5.59 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from datasets import Dataset
import os
class NanoGPT(nn.Module):
    """Tiny decoder-only (GPT-style) language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    Transformer self-attention blocks run under a causal mask, followed by a
    final LayerNorm and a vocabulary projection whose weight is tied to the
    token embedding.

    Args:
        vocab_size: Number of token ids the model reads and predicts.
        n_embd: Embedding / hidden width.
        n_head: Attention heads per block (must divide ``n_embd``).
        n_layer: Number of Transformer blocks.
        block_size: Maximum context length in tokens.
    """

    def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        # Learned absolute position embeddings, one row per context slot.
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(0.1)
        # An encoder layer is self-attention + feed-forward with no
        # cross-attention. Together with the causal mask built in forward(),
        # that is a GPT block. (nn.TransformerDecoderLayer requires a `memory`
        # tensor for its cross-attention and fails when given None.)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd, nhead=n_head, dim_feedforward=n_embd * 4,
                dropout=0.1, activation="gelu", batch_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # weight tying
        self.n_embd = n_embd

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer tensor of token ids, with T <= block_size.
            targets: Optional (B, T) integer tensor of next-token ids.

        Returns:
            ``(logits, None)`` with logits of shape (B, T, vocab) when
            ``targets`` is None. Otherwise ``(logits, loss)`` with logits
            flattened to (B*T, vocab) and ``loss`` a scalar tensor.

        Raises:
            ValueError: if T exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        x = self.drop(self.tok_emb(idx) + self.pos_emb[:, :T, :])
        # Boolean attention mask: True means "may not attend". The strict
        # upper triangle hides future positions, so token t only sees tokens <= t.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        C = logits.size(-1)
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous target tensors are accepted too.
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=80, temperature=0.95):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Only the last ``block_size`` tokens are fed back as context. Dropout is
        disabled while generating, and the caller's train/eval mode is
        restored afterwards.

        Args:
            idx: (B, T) integer tensor of prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature for sampling. Values <= 0 select
                the arg-max token (greedy decoding) instead of sampling.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by new tokens.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature <= 0:
                    next_idx = logits.argmax(dim=-1, keepdim=True)
                else:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=1)
        finally:
            self.train(was_training)
        return idx
# Globals
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size
# Context length shared by the model and the training windows in create_dataset().
block_size = 96
model = NanoGPT(vocab_size=vocab_size, n_embd=96, n_head=4, n_layer=3, block_size=block_size)
model_path = "/data/nanogpt_yap.pt"  # /data is persistent on Spaces (when persistent storage is enabled)
if os.path.exists(model_path):
    try:
        # The file only ever holds a state_dict of tensors, so refuse to
        # unpickle arbitrary objects.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    except (RuntimeError, EOFError) as err:
        # A checkpoint from a different architecture or vocabulary (key/shape
        # mismatch), or a truncated file, should not take the whole app down.
        # Fall back to untrained weights instead.
        print(f"Could not load {model_path}, starting untrained: {err}")
    else:
        print("Loaded saved model")
model.eval()  # start in inference mode (dropout off)
# Tiny dataset (repeat for more tokens)
# Built-in training corpus. create_dataset() repeats this list 50 times and
# joins the result with spaces before tokenizing.
life_texts = [
    "Life is what happens when you're busy making other plans.",
    "The meaning of life is to find your gift. The purpose is to give it away.",
    "You only live once, but if you do it right, once is enough.",
    "Hey human, existence is weird. Coffee helps.",
    "I think therefore I am... but mostly I just scroll.",
    "Why do we exist? Probably for the memes and Java code.",
    # add more if you want
]
def create_dataset():
    """Build a next-token-prediction dataset from ``life_texts``.

    The repeated corpus is tokenized once and cut into overlapping windows of
    ``block_size + 1`` tokens, with a stride of ``block_size // 2`` (at least 1).
    Each window yields ``input_ids`` (all but its last token) and ``labels``
    (the same window shifted left by one token).

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns, or
        ``None`` if the tokenized corpus is shorter than one window.
    """
    text = " ".join(life_texts * 50)  # ~few k tokens
    input_ids = tokenizer(text, return_tensors="pt").input_ids[0]
    window = block_size + 1
    stride = max(1, block_size // 2)
    # A window starting at i is complete iff i + window <= len(input_ids), so
    # the last valid start is len(input_ids) - window, inclusive.
    starts = range(0, len(input_ids) - window + 1, stride)
    seqs = [input_ids[i:i + window] for i in starts]
    if not seqs:
        return None
    data = {
        "input_ids": [s[:-1].tolist() for s in seqs],
        "labels": [s[1:].tolist() for s in seqs],
    }
    return Dataset.from_dict(data)
def train_once():
    """Train the global ``model`` on ``create_dataset()`` and save its weights.

    Runs a plain PyTorch loop:
    - 5 epochs with batch size 4;
    - AdamW at lr 5e-4 with gradient-norm clipping at 1.0;
    - the loss is printed every 20 steps.

    Afterwards the weights are written to ``model_path`` and the model is left
    in eval mode.

    Returns:
        A status message for the Gradio "Status" textbox.
    """
    dataset = create_dataset()
    if dataset is None:
        return "Dataset too small!"

    # A plain loop instead of transformers.Trainer: Trainer calls the model
    # with input_ids=/labels=/attention_mask= keywords and treats outputs[0]
    # as the loss. Neither matches NanoGPT.forward(idx, targets) -> (logits, loss).
    num_epochs, batch_size, log_every = 5, 4, 20

    # Every row has the same length, so the "torch" format stacks each column
    # into a (num_rows, block_size) tensor. The labels are already shifted by
    # one token in create_dataset() and are used as-is.
    columns = dataset.with_format("torch")[:]
    input_ids = columns["input_ids"].long()
    labels = columns["labels"].long()
    num_rows = input_ids.size(0)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)
    model.train()
    step = 0
    for epoch in range(num_epochs):
        order = torch.randperm(num_rows)
        for start in range(0, num_rows, batch_size):
            rows = order[start:start + batch_size]
            _, loss = model(input_ids[rows], labels[rows])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            step += 1
            if step % log_every == 0:
                print(f"epoch {epoch + 1} step {step}: loss {loss.item():.4f}")
    model.eval()

    try:
        # The target directory may not exist (e.g. a Space without persistent
        # storage), so create it before saving.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)
    except OSError as err:
        return f"Training finished, but saving the model failed: {err}"
    return "Training finished! Model saved to /data. Chat now!"
def chat_with_nano(message, history):
    """Generate a reply to ``message`` and append the exchange to ``history``.

    Args:
        message: Text typed by the user.
        history: Chat history as a list of ``[user, bot]`` pairs. It may be
            ``None`` (the Clear button sets the chatbot value to None), which
            is treated as an empty history.

    Returns:
        The updated history list. For a blank message, a new list with a
        prompt-for-input reply appended is returned instead.
    """
    if history is None:
        history = []
    if not message.strip():
        return history + [["", "Say something existential... or about Java?"]]
    prompt = f"Human: {message}\nAI: "
    # Skip the tokenizer's [CLS]/[SEP] wrapping: with a trailing [SEP],
    # generation would start after a separator token instead of after "AI:".
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    model.eval()  # no dropout while generating a reply
    with torch.no_grad():
        generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(generated[0][inputs.size(1):], skip_special_tokens=True).strip()
    history.append([message, response])
    return history
with gr.Blocks() as demo:
    # Header shown above the chat area.
    gr.Markdown("# Nano Java/Life Yap AI")
    gr.Markdown("Tiny ~1M param transformer. Train once, then chat!")

    # Components render in creation order; keep this sequence.
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(placeholder="Ask about life, existence, or Java...")
    clear = gr.Button("Clear")
    train_btn = gr.Button("Train Nano Model (10-60 min on CPU – do once!)")
    status = gr.Textbox(label="Status")

    # Training takes no inputs; its returned message goes to the status box.
    train_btn.click(fn=train_once, outputs=status)

    def respond(message, chat_history):
        """Send the message to the model and blank the textbox afterwards."""
        return "", chat_with_nano(message, chat_history)

    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
    # Clearing writes None into the chatbot, bypassing the queue.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

demo.launch(server_name="0.0.0.0", server_port=7860)