OpenTransformer commited on
Commit
44d9388
Β·
verified Β·
1 Parent(s): 4cfe745

Upload stream_trainer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stream_trainer.py +194 -0
stream_trainer.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ WIRE-SPEED TRANSFORMER - Learns directly from network stream
4
+ No batching. No epochs. Just continuous absorption.
5
+
6
+ Receives tokenized data via stdin from Rust feeder.
7
+ Updates weights after every micro-batch (configurable, default 32 tokens).
8
+ """
9
+
10
+ import sys
11
+ import math
12
+ import time
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from collections import deque
17
+
18
+ # ─────────────────── Config ───────────────────
19
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ torch.backends.cuda.matmul.allow_tf32 = True
21
+
22
+ # Tiny model for wire-speed updates
23
+ CONFIG = {
24
+ "d": 256, # embedding dim
25
+ "layers": 4, # transformer layers
26
+ "heads": 8, # attention heads
27
+ "rank": 32, # attention rank (from n.py's tuneable attention)
28
+ "vocab": 128256, # DeepSeek V3.2 vocab
29
+ "ctx": 512, # context window
30
+ }
31
+
32
+ LR = 1e-4
33
+ UPDATE_EVERY = 32 # tokens between weight updates (micro-batch)
34
+ PRINT_EVERY = 10000 # tokens between stats
35
+
36
+ # ─────────────────── Model (simplified from n.py) ───────────────────
37
+ class TuneableAttention(nn.Module):
38
+ def __init__(self, d, h, r):
39
+ super().__init__()
40
+ self.h, self.dk, self.r = h, d // h, r
41
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
42
+ self.U = nn.Parameter(torch.randn(self.dk, r) * 0.02)
43
+ self.proj = nn.Linear(d, d, bias=False)
44
+
45
+ def forward(self, x, mask=None):
46
+ B, N, D = x.shape
47
+ qkv = self.qkv(x).view(B, N, 3, self.h, self.dk)
48
+ q, k, v = qkv.unbind(2) # B, N, h, dk
49
+
50
+ # Project Q and K through U for tuneable rank
51
+ q = (q @ self.U) # B, N, h, r
52
+ k = (k @ self.U) # B, N, h, r
53
+
54
+ # Attention
55
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
56
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.r)
57
+ if mask is not None:
58
+ att = att + mask
59
+ att = F.softmax(att, dim=-1)
60
+ out = (att @ v).transpose(1, 2).reshape(B, N, D)
61
+ return self.proj(out)
62
+
63
+ class Block(nn.Module):
64
+ def __init__(self, d, h, r):
65
+ super().__init__()
66
+ self.ln1 = nn.LayerNorm(d)
67
+ self.attn = TuneableAttention(d, h, r)
68
+ self.ln2 = nn.LayerNorm(d)
69
+ self.ff = nn.Sequential(
70
+ nn.Linear(d, 4 * d),
71
+ nn.GELU(),
72
+ nn.Linear(4 * d, d)
73
+ )
74
+
75
+ def forward(self, x, mask):
76
+ x = x + self.attn(self.ln1(x), mask)
77
+ x = x + self.ff(self.ln2(x))
78
+ return x
79
+
80
+ class StreamingTransformer(nn.Module):
81
+ def __init__(self, cfg):
82
+ super().__init__()
83
+ d, L, h, r, V = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"], cfg["vocab"]
84
+ self.emb = nn.Embedding(V, d)
85
+ self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(L)])
86
+ self.ln = nn.LayerNorm(d)
87
+ self.head = nn.Linear(d, V, bias=False)
88
+ # Weight tying
89
+ self.head.weight = self.emb.weight
90
+
91
+ def forward(self, x):
92
+ B, N = x.shape
93
+ # Causal mask
94
+ mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
95
+
96
+ h = self.emb(x)
97
+ for block in self.blocks:
98
+ h = block(h, mask)
99
+ return self.head(self.ln(h))
100
+
101
+ def count_params(self):
102
+ return sum(p.numel() for p in self.parameters())
103
+
104
+ # ─────────────────── Online Trainer ───────────────────
105
+ class WireSpeedTrainer:
106
+ def __init__(self, model, lr=LR):
107
+ self.model = model.to(DEVICE)
108
+ self.opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95))
109
+ self.ctx_size = CONFIG["ctx"]
110
+
111
+ # Rolling buffer for context
112
+ self.buffer = deque(maxlen=self.ctx_size + 1)
113
+
114
+ # Stats
115
+ self.tokens_seen = 0
116
+ self.total_loss = 0.0
117
+ self.updates = 0
118
+ self.start_time = time.time()
119
+
120
+ def ingest_token(self, token_id):
121
+ """Absorb a single token. Update weights when buffer fills."""
122
+ self.buffer.append(token_id)
123
+ self.tokens_seen += 1
124
+
125
+ # Update every N tokens when we have enough context
126
+ if len(self.buffer) >= UPDATE_EVERY + 1 and self.tokens_seen % UPDATE_EVERY == 0:
127
+ self._update()
128
+
129
+ # Print stats
130
+ if self.tokens_seen % PRINT_EVERY == 0:
131
+ self._print_stats()
132
+
133
+ def _update(self):
134
+ """Single gradient step on current buffer."""
135
+ # Convert buffer to tensor
136
+ tokens = list(self.buffer)
137
+ x = torch.tensor(tokens[:-1], device=DEVICE).unsqueeze(0) # input
138
+ y = torch.tensor(tokens[1:], device=DEVICE).unsqueeze(0) # target
139
+
140
+ # Forward
141
+ self.model.train()
142
+ logits = self.model(x)
143
+
144
+ # Loss on last UPDATE_EVERY positions only (most recent)
145
+ loss = F.cross_entropy(
146
+ logits[:, -UPDATE_EVERY:].reshape(-1, CONFIG["vocab"]),
147
+ y[:, -UPDATE_EVERY:].reshape(-1)
148
+ )
149
+
150
+ # Backward
151
+ self.opt.zero_grad()
152
+ loss.backward()
153
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
154
+ self.opt.step()
155
+
156
+ self.total_loss += loss.item()
157
+ self.updates += 1
158
+
159
+ def _print_stats(self):
160
+ elapsed = time.time() - self.start_time
161
+ tok_per_sec = self.tokens_seen / elapsed if elapsed > 0 else 0
162
+ avg_loss = self.total_loss / max(1, self.updates)
163
+
164
+ print(f"[{elapsed:.0f}s] {self.tokens_seen:,} tok | {tok_per_sec:.0f} tok/s | "
165
+ f"loss={avg_loss:.4f} | updates={self.updates}", flush=True)
166
+
167
+ # ─────────────────── Main ───────────────────
168
+ def main():
169
+ print(f"Wire-Speed Transformer", flush=True)
170
+ print(f"Config: {CONFIG}", flush=True)
171
+ print(f"Device: {DEVICE}", flush=True)
172
+
173
+ model = StreamingTransformer(CONFIG)
174
+ params = model.count_params()
175
+ print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)
176
+
177
+ trainer = WireSpeedTrainer(model)
178
+
179
+ print(f"Listening for tokens on stdin...", flush=True)
180
+ print(f"Update every {UPDATE_EVERY} tokens, print every {PRINT_EVERY}", flush=True)
181
+
182
+ # Read token IDs from stdin (one per line from Rust feeder)
183
+ for line in sys.stdin:
184
+ try:
185
+ token_id = int(line.strip())
186
+ if 0 <= token_id < CONFIG["vocab"]:
187
+ trainer.ingest_token(token_id)
188
+ except ValueError:
189
+ continue # Skip malformed lines
190
+
191
+ print(f"Stream ended. Total tokens: {trainer.tokens_seen:,}", flush=True)
192
+
193
+ if __name__ == "__main__":
194
+ main()