File size: 5,592 Bytes
46fdff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from datasets import Dataset
import os

class NanoGPT(nn.Module):
    """Small decoder-only (GPT-style) transformer language model.

    Token embeddings plus a learned positional embedding feed a stack of
    self-attention layers that run under a causal mask, followed by a final
    LayerNorm and a linear head whose weight is tied to the token embedding.

    Args:
        vocab_size: Number of token ids the model can embed / predict.
        n_embd: Embedding (model) width.
        n_head: Number of attention heads per layer.
        n_layer: Number of transformer layers.
        block_size: Maximum sequence length accepted by ``forward``.
    """

    def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(0.1)

        # A decoder-only LM has no encoder memory to cross-attend to, so the
        # building block is the self-attention-only encoder layer, made
        # autoregressive by the causal mask passed in forward().
        # nn.TransformerDecoderLayer requires a `memory` tensor and fails
        # when it is None.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd, nhead=n_head, dim_feedforward=n_embd * 4,
                dropout=0.1, activation="gelu", batch_first=True,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # weight tying
        self.n_embd = n_embd

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: Optional integer token ids with ``B * T`` elements.

        Returns:
            ``(logits, None)`` with logits of shape ``(B, T, vocab_size)`` when
            ``targets`` is None; otherwise ``(logits, loss)`` where logits are
            flattened to ``(B * T, vocab_size)`` and ``loss`` is the scalar
            cross-entropy against ``targets``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        x = self.drop(self.tok_emb(idx) + self.pos_emb[:, :T, :])

        # True above the diagonal = "may not attend": position t only sees
        # positions <= t, which is what makes the model autoregressive.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)

        logits = self.head(self.ln_f(x))

        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous target slices are accepted.
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=80, temperature=0.95):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        Args:
            idx: Integer token ids of shape ``(B, T)`` used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Positive softmax temperature applied to the logits.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``: the prompt followed by
            the sampled tokens.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # Sample with dropout disabled, then restore whatever mode the caller
        # had the module in.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit the positional table.
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=1)
        finally:
            self.train(was_training)
        return idx

# Globals shared by the dataset, training and chat functions in this file
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size
block_size = 96
model = NanoGPT(vocab_size=vocab_size, n_embd=96, n_head=4, n_layer=3, block_size=block_size)

model_path = "/data/nanogpt_yap.pt"  # /data is persistent on Spaces

if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensors/containers, which is
    # all a state_dict holds, so a tampered file cannot execute arbitrary code.
    model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
    print("Loaded saved model")

# Start in inference mode so chatting does not sample with dropout active.
model.eval()

# Tiny training corpus: a handful of short sentences. create_dataset() repeats
# this list and joins it with spaces to get enough tokens for training windows.
life_texts = [
    "Life is what happens when you're busy making other plans.",
    "The meaning of life is to find your gift. The purpose is to give it away.",
    "You only live once, but if you do it right, once is enough.",
    "Hey human, existence is weird. Coffee helps.",
    "I think therefore I am... but mostly I just scroll.",
    "Why do we exist? Probably for the memes and Java code.",
    # Append more sentences here to enlarge the corpus.
]

def create_dataset():
    """Build overlapping next-token-prediction windows from ``life_texts``.

    The corpus is repeated 50 times, tokenized as one long sequence and cut
    into windows of ``block_size + 1`` tokens taken every ``block_size // 2``
    tokens. For each window the first ``block_size`` tokens are the inputs and
    the last ``block_size`` tokens (shifted by one) are the labels.

    Returns:
        A ``Dataset`` with ``input_ids`` and ``labels`` columns, or ``None``
        when the tokenized text is shorter than one window.
    """
    text = " ".join(life_texts * 50)  # ~few k tokens
    input_ids = tokenizer(text, return_tensors="pt").input_ids[0]

    window = block_size + 1  # one extra token so labels can be offset by one
    if input_ids.numel() < window:
        return None

    stride = max(1, block_size // 2)  # 50% overlap; never a zero step
    # unfold yields every full window, including the one that ends exactly at
    # the last token (a range() stop of `len - window` would skip it).
    windows = input_ids.unfold(0, window, stride)

    return Dataset.from_dict({
        "input_ids": windows[:, :-1].tolist(),
        "labels": windows[:, 1:].tolist(),
    })

def train_once():
    """Train the global ``model`` on the tiny corpus and save its weights.

    Runs 5 epochs of AdamW (lr 5e-4, batch size 4) over the windows produced
    by ``create_dataset()``, using the dataset's already-shifted ``labels`` as
    next-token targets, then writes the state dict to ``model_path``.

    A plain PyTorch loop is used instead of ``transformers.Trainer`` because
    Trainer calls the model with ``input_ids=``/``labels=`` keywords and reads
    the loss from ``outputs[0]``, whereas the model takes ``(idx, targets)``
    positionally and returns ``(logits, loss)``.

    Returns:
        A status string for the UI.
    """
    dataset = create_dataset()
    if dataset is None:
        return "Dataset too small!"

    num_epochs = 5
    batch_size = 4
    log_every = 20

    device = next(model.parameters()).device
    columns = dataset[:]
    inputs = torch.tensor(columns["input_ids"], dtype=torch.long, device=device)
    # Labels are the inputs shifted by one token; do not replace them with a
    # copy of the inputs or the model only learns to echo its input.
    labels = torch.tensor(columns["labels"], dtype=torch.long, device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)
    num_examples = inputs.size(0)
    step = 0

    model.train()
    try:
        for epoch in range(num_epochs):
            order = torch.randperm(num_examples, device=device)  # reshuffle each epoch
            for start in range(0, num_examples, batch_size):
                batch = order[start:start + batch_size]
                _, loss = model(inputs[batch], labels[batch])

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                step += 1
                if step % log_every == 0:
                    print(f"epoch {epoch + 1}/{num_epochs} step {step} loss {loss.item():.4f}")
    finally:
        # Leave the model in inference mode even if training is interrupted.
        model.eval()

    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    return "Training finished! Model saved to /data. Chat now!"

def chat_with_nano(message, history):
    """Generate a reply to ``message`` and return the extended chat history.

    Args:
        message: The user's text.
        history: List of ``[user, bot]`` pairs, or ``None``/empty for a new chat.

    Returns:
        A new list: the prior history plus one ``[user, bot]`` pair. For a
        blank message the pair is ``["", <hint text>]`` and the model is not
        called. The list passed in is never modified.
    """
    # Work on a copy so both branches behave the same and the caller's list
    # is left alone; also tolerates a None history.
    history = list(history) if history else []

    if not message or not message.strip():
        return history + [["", "Say something existential... or about Java?"]]

    prompt = f"Human: {message}\nAI: "
    # No [CLS]/[SEP]: the model should continue the prompt text itself rather
    # than continue after an end-of-sequence marker.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    inputs = inputs.to(next(model.parameters()).device)

    model.eval()  # make sure dropout is off while sampling
    with torch.no_grad():
        generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)

    # Decode only the newly sampled tokens, not the echoed prompt.
    new_tokens = generated[0][inputs.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    history.append([message, response])
    return history

# Gradio UI: a chat panel plus a one-off training button. Components are
# rendered in the order they are created inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# Nano Java/Life Yap AI")
    # NOTE(review): the global model uses n_embd=96, so the tied embedding
    # table alone is vocab_size x 96 parameters; if vocab_size is 30522 (the
    # NanoGPT default) that is ~2.9M, so "~1M" may understate the size - confirm.
    gr.Markdown("Tiny ~1M param transformer. Train once, then chat!")

    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(placeholder="Ask about life, existence, or Java...")
    clear = gr.Button("Clear")

    train_btn = gr.Button("Train Nano Model (10-60 min on CPU – do once!)")
    status = gr.Textbox(label="Status")

    # train_once takes no inputs; its returned string is shown in `status`.
    train_btn.click(train_once, outputs=status)

    def respond(message, chat_history):
        """Return an empty string for the textbox and the updated chat history."""
        updated_history = chat_with_nano(message, chat_history)
        return "", updated_history

    # Submitting the textbox sends (text, history) and receives (text, history).
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    # Returning None to the chatbot output clears the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

# Bind on all interfaces at port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)