File size: 5,854 Bytes
46fdff5
 
 
 
 
 
 
 
d79be6f
46fdff5
 
 
 
 
 
 
 
 
d79be6f
 
 
 
46fdff5
 
 
 
d79be6f
46fdff5
 
 
 
 
 
 
 
d79be6f
46fdff5
 
 
 
 
 
d79be6f
46fdff5
d79be6f
 
46fdff5
 
 
 
 
 
 
 
 
 
 
 
 
 
d79be6f
46fdff5
d79be6f
 
46fdff5
 
 
d79be6f
46fdff5
d79be6f
46fdff5
 
 
 
d79be6f
 
 
 
46fdff5
 
 
d79be6f
46fdff5
 
 
 
d79be6f
 
 
 
46fdff5
 
 
 
 
d79be6f
 
 
 
 
46fdff5
 
d79be6f
46fdff5
d79be6f
 
46fdff5
d79be6f
46fdff5
 
 
 
 
d79be6f
46fdff5
d79be6f
46fdff5
 
 
 
 
 
 
d79be6f
46fdff5
 
 
 
 
 
d79be6f
46fdff5
 
 
 
d79be6f
46fdff5
d79be6f
 
 
46fdff5
 
 
 
 
 
d79be6f
 
 
46fdff5
 
 
 
d79be6f
46fdff5
d79be6f
 
46fdff5
 
d79be6f
 
46fdff5
d79be6f
 
46fdff5
d79be6f
46fdff5
d79be6f
 
 
46fdff5
d79be6f
 
46fdff5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from datasets import Dataset
import os

# Tiny NanoGPT class
class NanoGPT(nn.Module):
    """Minimal GPT-style (decoder-only) causal language model.

    Token embeddings plus a learned positional table feed ``n_layer``
    self-attention transformer blocks run under a causal mask, followed by a
    final LayerNorm and a vocabulary projection whose weight is tied to the
    token embedding.

    Args:
        vocab_size: Size of the token vocabulary (the default matches the
            bert-base-uncased tokenizer this script uses).
        n_embd: Model width; must be divisible by ``n_head``.
        n_head: Attention heads per block.
        n_layer: Number of transformer blocks.
        block_size: Maximum context length (rows in the positional table).
    """

    def __init__(self, vocab_size=30522, n_embd=96, n_head=4, n_layer=3, block_size=96):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(0.1)

        # Encoder layers are exactly "self-attention + feed-forward", which is
        # what a GPT block needs. nn.TransformerDecoderLayer additionally runs
        # cross-attention over a required `memory` tensor, which a decoder-only
        # model does not have. Causality comes from the mask built in forward().
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=n_embd, nhead=n_head,
                                       dim_feedforward=n_embd * 4, dropout=0.1,
                                       activation="gelu", batch_first=True)
            for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # tie input/output embeddings

    def forward(self, idx=None, targets=None, input_ids=None, labels=None):
        """Run the model and optionally compute the next-token loss.

        Two calling styles are supported:

        * positional, ``model(idx, targets)`` -> ``(logits, loss)`` tuple;
        * Hugging Face style keywords, ``model(input_ids=..., labels=...)``
          -> ``{"loss": loss, "logits": logits}`` (the shape
          ``transformers.Trainer`` reads the loss from).

        Args:
            idx: LongTensor of token ids, shape (B, T) with T <= block_size.
            targets: Optional LongTensor (B, T) holding, at each position, the
                id of the *next* token (the caller does the shifting).
            input_ids: Keyword alias for ``idx``.
            labels: Keyword alias for ``targets``.

        Returns:
            Without targets: logits of shape (B, T, vocab) and ``None``.
            With targets: logits flattened to (B*T, vocab) and the mean
            cross-entropy loss. Returned as a tuple for positional calls and
            as a dict when any keyword alias was used.

        Raises:
            ValueError: if no token ids are given or T exceeds block_size.
        """
        hf_call = input_ids is not None or labels is not None
        if idx is None:
            idx = input_ids
        if targets is None:
            targets = labels
        if idx is None:
            raise ValueError("NanoGPT.forward needs token ids (idx or input_ids)")

        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        # True where the key position is after the query position, i.e. a
        # future token that must not be attended to.
        positions = torch.arange(T, device=idx.device)
        causal_mask = positions[None, :] > positions[:, None]

        x = self.drop(self.tok_emb(idx) + self.pos_emb[:, :T, :])
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)

        logits = self.head(self.ln_f(x))

        loss = None
        if targets is not None:
            logits = logits.view(B * T, logits.size(-1))
            loss = F.cross_entropy(logits, targets.reshape(B * T))

        if hf_call:
            return {"loss": loss, "logits": logits}
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=80, temperature=0.95):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Sampling runs in eval mode (dropout off); the module's previous
        train/eval mode is restored afterwards. Only the last ``block_size``
        tokens are fed to the model at each step.

        Args:
            idx: LongTensor (B, T0) of prompt token ids.
            max_new_tokens: Number of tokens to generate.
            temperature: Softmax temperature; ``<= 0`` means greedy argmax.

        Returns:
            LongTensor (B, T0 + max_new_tokens): the prompt followed by the
            generated tokens.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature <= 0:
                    next_idx = logits.argmax(dim=-1, keepdim=True)
                else:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=1)
        finally:
            self.train(was_training)
        return idx

# Setup
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = NanoGPT()
model_path = "nanogpt_yap.pt"  # saved in current dir (non-persistent on restart, but ok for test)

if os.path.exists(model_path):
    try:
        # The checkpoint is a plain state_dict of tensors, so refuse to
        # unpickle arbitrary objects from the file.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    except RuntimeError as err:
        # e.g. the saved weights belong to a different architecture; keep the
        # fresh weights instead of crashing the whole app at import time.
        print(f"Ignoring incompatible weights in {model_path}: {err}")
    else:
        print("Loaded existing model weights")
model.eval()  # inference mode (dropout off) until training is started

# Small dataset for yapping about life (repeat for more data)
life_texts = [
    "Life is what happens when you're busy making other plans.",
    "The meaning of life is to find your gift. The purpose is to give it away.",
    "You only live once, but if you do it right, once is enough.",
    "Existence is weird. Coffee helps sometimes.",
    "I think therefore I am... mostly scrolling though.",
    "Why do we exist? Probably for memes and Java bugs.",
    "Another day, another existential crisis. Pass the tea.",
]

def create_dataset(block_size=96):
    """Build a next-token-prediction dataset from the repeated ``life_texts``.

    The corpus is tokenized once and cut into overlapping windows of
    ``block_size + 1`` tokens with stride ``block_size // 2``. Each window
    yields an ``input_ids`` row (all but its last token) and a ``labels`` row
    (the window shifted left by one), so ``labels[t]`` is the token that
    follows ``input_ids[t]``.

    Args:
        block_size: Length of each training sequence; should not exceed the
            model's context length (``NanoGPT.block_size``, 96 by default).

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns, or
        ``None`` if the corpus is too short to fill a single window.
    """
    text = " ".join(life_texts * 50)  # small but repeated
    input_ids = tokenizer(text, return_tensors="pt").input_ids[0]

    window = block_size + 1
    step = max(1, block_size // 2)
    # A window starting at i needs i + window <= len, so the last valid start
    # is len - window and the exclusive range stop is len - window + 1.
    seqs = [
        input_ids[start : start + window]
        for start in range(0, len(input_ids) - window + 1, step)
    ]

    if not seqs:
        return None

    data = {
        "input_ids": [s[:-1].tolist() for s in seqs],
        "labels": [s[1:].tolist() for s in seqs],
    }
    return Dataset.from_dict(data)

def train_model():
    """Train the global ``model`` on ``create_dataset()`` and save its weights.

    Uses ``transformers.Trainer`` with a small subclass that calls the model
    positionally (``model(input_ids, labels)``) and takes the loss from its
    ``(logits, loss)`` result. After training, the model is moved to the CPU,
    switched to eval mode, and its state_dict is written to ``model_path``.

    Returns:
        A status string for the Gradio status box.
    """
    dataset = create_dataset()
    if dataset is None or len(dataset) == 0:
        return "Dataset creation failed - too small!"

    from transformers import Trainer, TrainingArguments

    def data_collator(features):
        # Every row already has the same length, so no padding is needed.
        # Keep the dataset's pre-shifted next-token labels as-is.
        return {
            "input_ids": torch.tensor([f["input_ids"] for f in features], dtype=torch.long),
            "labels": torch.tensor([f["labels"] for f in features], dtype=torch.long),
        }

    class _NanoGPTTrainer(Trainer):
        """Trainer that calls NanoGPT positionally and reads its (logits, loss) tuple."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extras that newer transformers versions pass
            # (e.g. num_items_in_batch).
            logits, loss = model(inputs["input_ids"], inputs["labels"])
            return (loss, {"logits": logits}) if return_outputs else loss

    args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=4,
        save_strategy="no",
        logging_steps=20,
        report_to="none",
        optim="adamw_torch",
        learning_rate=5e-4,
        fp16=False,  # CPU
        # The model's forward() parameter names need not match the dataset
        # columns; compute_loss above selects the columns explicitly.
        remove_unused_columns=False,
    )

    trainer = _NanoGPTTrainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=data_collator,
    )

    trainer.train()
    # Trainer may have moved the model to an accelerator and left it in train
    # mode; chat inputs are created on the CPU and should sample without dropout.
    model.cpu()
    model.eval()
    torch.save(model.state_dict(), model_path)
    return "Training done! Model saved. You can chat now (responses may be silly)."

def generate_response(message, history):
    """Generate a model reply to ``message`` and append it to the chat history.

    Args:
        message: User text from the textbox.
        history: Chat history as a list of ``[user, bot]`` pairs; ``None``
            (e.g. after the chat is cleared) is treated as an empty history.

    Returns:
        A new history list with the ``[message, reply]`` pair appended.
    """
    history = list(history or [])
    if not message:
        return history + [["", "Ask me something deep... or weird."]]

    prompt = f"Human: {message}\nAI: "
    # No [CLS]/[SEP]: a trailing [SEP] would sit between the prompt and the
    # continuation the model is asked to write.
    device = next(model.parameters()).device
    inputs = tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        generated = model.generate(inputs, max_new_tokens=80, temperature=0.95)

    # Decode only the newly generated tokens. Slicing the fully decoded text by
    # len(prompt) characters is unreliable: the uncased WordPiece round-trip
    # changes case and spacing, so the decoded prefix does not line up with the
    # original prompt string.
    new_tokens = generated[0, inputs.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    history.append([message, response])
    return history

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Nano AI Yap Test")
    # Most of the parameters are the tied 30522x96 token embedding (~2.9M).
    gr.Markdown("Tiny from-scratch model (~3M params). Train first, then chat!")

    chatbot = gr.Chatbot(height=400)
    textbox = gr.Textbox(placeholder="Talk to me about life, existence, or anything...")
    clear_btn = gr.Button("Clear Chat")

    train_button = gr.Button("Start Training (takes 10–60 min on free CPU – run once)")
    status_box = gr.Textbox(label="Training Status", interactive=False)

    train_button.click(train_model, outputs=status_box)

    def submit_chat(msg, hist):
        """Answer ``msg``, clear the textbox, and return the updated history."""
        updated_hist = generate_response(msg, hist)
        return "", updated_hist

    textbox.submit(submit_chat, [textbox, chatbot], [textbox, chatbot])
    # Reset to an empty list (not None) so chat handlers always get a list.
    clear_btn.click(lambda: [], None, chatbot, queue=False)

demo.launch(server_name="0.0.0.0", server_port=7860)