Pclanglais committed on
Commit
de97637
·
verified ·
1 Parent(s): 2da630c

simplify model.py: drop unused configs / multi-scale n-gram path

Browse files
Files changed (1) hide show
  1. model.py +75 -210
model.py CHANGED
@@ -1,73 +1,47 @@
1
  """
2
- ByteHybrid v2: Byte-level document classifier with optional n-gram hash embeddings.
3
 
4
- Changes from v1:
5
- - Added ByteNgramEmbed: rolling hash of byte trigrams into fixed-size embedding table
6
- - New config "base_ngram" with ngram_buckets=4096, ngram_dim=64 (~262k extra params)
7
- - Backward compatible: existing configs work unchanged (ngram_buckets=0)
8
- """
 
 
9
 
10
- import math
 
 
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
 
15
 
16
- # ── Byte N-gram Hash Embedding ───────────────────────────────────────────
17
-
18
-
19
  class ByteNgramEmbed(nn.Module):
20
- """Rolling hash of byte n-grams into fixed-size embedding table.
21
-
22
- Supports single n-gram size or multi-scale (e.g., trigrams + 5-grams).
23
- Uses polynomial hash. Collisions act as regularization.
24
  """
25
-
26
  def __init__(self, num_buckets=4096, embed_dim=64, n=3):
27
  super().__init__()
28
  self.n = n
29
  self.num_buckets = num_buckets
30
  self.embed = nn.Embedding(num_buckets, embed_dim)
31
-
32
- def _hash(self, byte_ids, n):
33
  B, T = byte_ids.shape
34
  clamped = byte_ids.clamp(max=255)
35
- padded = F.pad(clamped, (0, n - 1), value=0)
36
  h = torch.zeros(B, T, dtype=torch.long, device=byte_ids.device)
37
- for i in range(n):
38
- h = h * 257 + padded[:, i:i+T]
39
- h = h % self.num_buckets
40
- return h
41
-
42
- def forward(self, byte_ids):
43
- return self.embed(self._hash(byte_ids, self.n))
44
-
45
-
46
- class MultiScaleNgramEmbed(nn.Module):
47
- """Multi-scale n-gram hash embeddings (e.g., 3-gram + 5-gram).
48
-
49
- Each scale gets its own hash table and embedding. Outputs are summed.
50
- """
51
-
52
- def __init__(self, num_buckets=4096, embed_dim=64, scales=(3, 5)):
53
- super().__init__()
54
- self.scales = scales
55
- self.ngrams = nn.ModuleList([
56
- ByteNgramEmbed(num_buckets, embed_dim, n=n) for n in scales
57
- ])
58
-
59
- def forward(self, byte_ids):
60
- out = self.ngrams[0](byte_ids)
61
- for ng in self.ngrams[1:]:
62
- out = out + ng(byte_ids)
63
- return out
64
-
65
-
66
- # ── Causal Conv1d Block ──────────────────────────────────────────────────
67
 
68
 
69
  class ByteConvBlock(nn.Module):
70
- """Causal conv1d + gated FFN. Captures local byte patterns."""
71
 
72
  def __init__(self, d_model, kernel_size=15, expand=2):
73
  super().__init__()
@@ -75,160 +49,108 @@ class ByteConvBlock(nn.Module):
75
  self.pad = kernel_size - 1
76
  self.conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
77
  self.norm2 = nn.LayerNorm(d_model)
78
- ffn_dim = d_model * expand
79
- self.ffn_gate = nn.Linear(d_model, ffn_dim, bias=False)
80
- self.ffn_up = nn.Linear(d_model, ffn_dim, bias=False)
81
- self.ffn_down = nn.Linear(ffn_dim, d_model, bias=False)
82
 
83
  def forward(self, x):
84
  residual = x
85
- x = self.norm1(x)
86
- x = x.transpose(1, 2)
87
  x = F.pad(x, (self.pad, 0))
88
- x = F.silu(self.conv(x))
89
- x = x.transpose(1, 2)
90
  x = residual + x
91
 
92
  residual = x
93
  x = self.norm2(x)
94
  x = self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))
95
- x = residual + x
96
- return x
97
 
 
 
 
 
 
 
 
 
98
 
99
- # ── Attention Block ──────────────────────────────────────────────────────
 
 
 
 
100
 
101
 
102
  class ByteAttnBlock(nn.Module):
103
- """Standard bidirectional attention + SwiGLU FFN with RoPE."""
104
 
105
  def __init__(self, d_model, n_heads=4, expand=2):
106
  super().__init__()
107
  self.n_heads = n_heads
108
  self.head_dim = d_model // n_heads
109
-
110
  self.norm1 = nn.LayerNorm(d_model)
111
  self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
112
  self.out_proj = nn.Linear(d_model, d_model, bias=False)
113
-
114
  self.norm2 = nn.LayerNorm(d_model)
115
- ffn_dim = d_model * expand
116
- self.ffn_gate = nn.Linear(d_model, ffn_dim, bias=False)
117
- self.ffn_up = nn.Linear(d_model, ffn_dim, bias=False)
118
- self.ffn_down = nn.Linear(ffn_dim, d_model, bias=False)
119
 
120
  def forward(self, x):
121
  B, T, D = x.shape
122
  residual = x
123
-
124
- x = self.norm1(x)
125
- qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
126
- q, k, v = qkv.unbind(dim=2)
127
- q = q.transpose(1, 2)
128
- k = k.transpose(1, 2)
129
- v = v.transpose(1, 2)
130
-
131
- q, k = apply_rope(q, k, T, self.head_dim, x.device)
132
-
133
  attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
134
  attn = attn.softmax(dim=-1)
135
  out = (attn @ v).transpose(1, 2).contiguous().view(B, T, D)
136
- out = self.out_proj(out)
137
- x = residual + out
138
 
139
  residual = x
140
- x = self.norm2(x)
141
- x = self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))
142
- x = residual + x
143
- return x
144
-
145
-
146
- # ── Rotary Position Embedding ────────────────────────────────────────────
147
-
148
-
149
- def precompute_freqs(dim, max_len=4096, theta=10000.0):
150
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
151
- t = torch.arange(max_len)
152
- freqs = torch.outer(t, freqs)
153
- return torch.cos(freqs), torch.sin(freqs)
154
-
155
-
156
- def apply_rope(q, k, seq_len, head_dim, device):
157
- cos, sin = precompute_freqs(head_dim, seq_len)
158
- cos = cos[:seq_len].to(device=device, dtype=q.dtype)
159
- sin = sin[:seq_len].to(device=device, dtype=q.dtype)
160
-
161
- def rotate(x):
162
- x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
163
- return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
164
-
165
- return rotate(q), rotate(k)
166
-
167
-
168
- # ── Full Model ───────────────────────────────────────────────────────────
169
 
170
 
171
  class ByteHybrid(nn.Module):
172
- """Byte-level classifier with optional n-gram hash embeddings.
173
-
174
- Args:
175
- num_classes: number of output classes
176
- d_model: hidden dimension
177
- n_conv: number of conv1d blocks
178
- n_attn: number of attention blocks
179
- n_heads: attention heads
180
- max_len: maximum byte sequence length
181
- conv_kernel: conv1d kernel size
182
- ngram_buckets: hash table size for n-gram embeddings (0 = disabled)
183
- ngram_dim: embedding dimension for n-gram hashes
184
- """
185
 
186
  def __init__(
187
  self,
188
- num_classes=13,
189
  d_model=256,
190
  n_conv=3,
191
  n_attn=1,
192
  n_heads=4,
193
  ffn_expand=2,
194
- max_len=2048,
195
  conv_kernel=15,
196
  ngram_buckets=0,
197
  ngram_dim=64,
198
- ngram_scales=None,
199
  ):
200
  super().__init__()
201
  self.max_len = max_len
202
 
203
- # Byte embedding: 256 possible byte values + 1 padding
204
  self.embed = nn.Embedding(257, d_model, padding_idx=256)
205
 
206
- # Optional n-gram hash embedding
207
- # ngram_scales: tuple of n-gram sizes, e.g. (3,) or (3, 5)
208
  self.ngram_embed = None
209
  if ngram_buckets > 0:
210
- scales = ngram_scales if ngram_scales else (3,)
211
- if len(scales) == 1:
212
- self.ngram_embed = ByteNgramEmbed(ngram_buckets, ngram_dim, n=scales[0])
213
- else:
214
- self.ngram_embed = MultiScaleNgramEmbed(ngram_buckets, ngram_dim, scales=scales)
215
  self.ngram_proj = nn.Linear(ngram_dim, d_model, bias=False)
216
 
217
- # Conv blocks
218
- self.conv_layers = nn.ModuleList([
219
- ByteConvBlock(d_model, kernel_size=conv_kernel, expand=ffn_expand)
220
- for _ in range(n_conv)
221
- ])
222
-
223
- # Attention blocks
224
- self.attn_layers = nn.ModuleList([
225
- ByteAttnBlock(d_model, n_heads, ffn_expand)
226
- for _ in range(n_attn)
227
- ])
228
-
229
  self.final_norm = nn.LayerNorm(d_model)
230
-
231
- # Classification head
232
  self.head = nn.Sequential(
233
  nn.Linear(d_model, d_model),
234
  nn.GELU(),
@@ -237,82 +159,25 @@ class ByteHybrid(nn.Module):
237
  )
238
 
239
  def forward(self, byte_ids):
240
- """
241
- Args:
242
- byte_ids: (B, T) long tensor of byte values [0-255], padded with 256
243
- Returns:
244
- logits: (B, num_classes)
245
- """
246
  pad_mask = byte_ids != 256
247
-
248
  x = self.embed(byte_ids)
249
-
250
- # Add n-gram features if enabled
251
  if self.ngram_embed is not None:
252
- ng = self.ngram_embed(byte_ids)
253
- x = x + self.ngram_proj(ng)
254
-
255
  for layer in self.conv_layers:
256
  x = layer(x)
257
-
258
  for layer in self.attn_layers:
259
  x = layer(x)
260
-
261
  x = self.final_norm(x)
262
-
263
  mask = pad_mask.unsqueeze(-1).to(x.dtype)
264
  x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
265
-
266
  return self.head(x)
267
 
268
- @staticmethod
269
- def encode_text(text, max_len=2048):
270
- """Convert text string to byte tensor, padded to max_len."""
271
- raw = text.encode("utf-8", errors="replace")[:max_len]
272
- byte_ids = list(raw) + [256] * (max_len - len(byte_ids))
273
- return torch.tensor(byte_ids, dtype=torch.long)
274
-
275
-
276
- # ── Configurations ───────────────────────────────────────────────────────
277
 
 
 
278
  CONFIGS = {
279
- # ~2M params: 3 conv + 1 attn, d=256 (original)
280
- "base": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15),
281
- # ~2.3M params: base + trigram hash embeddings (4k buckets × 64 dim)
282
- "base_ngram": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
283
- ngram_buckets=4096, ngram_dim=64),
284
- # ~2.5M params: larger hash table
285
- "base_ngram_large": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
286
- ngram_buckets=8192, ngram_dim=64),
287
- # ~3.5M params: 3 conv + 2 attn, d=256
288
- "large": dict(d_model=256, n_conv=3, n_attn=2, n_heads=4, conv_kernel=15),
289
- # ~2M params: deeper conv, no attn
290
- "conv_only": dict(d_model=256, n_conv=5, n_attn=0, n_heads=4, conv_kernel=15),
291
- # ~2M params: wider kernel conv
292
- "wide_conv": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=31),
293
- # Scaled-up configs
294
- "d384": dict(d_model=384, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15),
295
- "d384_2attn": dict(d_model=384, n_conv=3, n_attn=2, n_heads=4, conv_kernel=15),
296
- "d512": dict(d_model=512, n_conv=3, n_attn=1, n_heads=8, conv_kernel=15),
297
- # 4-gram variant
298
- "base_4gram": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
299
- ngram_buckets=4096, ngram_dim=64, ngram_scales=(4,)),
300
- # Multi-scale: 3-gram + 5-gram (two hash tables, summed)
301
- "base_multiscale": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
302
- ngram_buckets=4096, ngram_dim=64, ngram_scales=(3, 5)),
303
- # Multi-scale: 3-gram + 4-gram + 5-gram
304
- "base_multiscale3": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
305
- ngram_buckets=4096, ngram_dim=64, ngram_scales=(3, 4, 5)),
306
  }
307
-
308
-
309
- def count_params(model):
310
- return sum(p.numel() for p in model.parameters())
311
-
312
-
313
- if __name__ == "__main__":
314
- for name, cfg in CONFIGS.items():
315
- model = ByteHybrid(num_classes=334, max_len=512, **cfg)
316
- byte_ids = torch.randint(0, 256, (4, 512))
317
- logits = model(byte_ids)
318
- print(f"{name:<20s} {count_params(model):>10,} params output={logits.shape}")
 
1
  """
2
+ ByteHybrid: byte-level language identification (CommonLingua v7.2.1).
3
 
4
+ Operates directly on raw UTF-8 bytes — no tokenizer required:
5
+
6
+ raw bytes → byte-embed + trigram-hash-embed (summed)
7
+ → 3 × depthwise Conv1D (k=15)
8
+ → 1 × bidirectional attention (RoPE, 4 heads)
9
+ → masked mean-pool
10
+ → classification head (334 logits)
11
 
12
+ The shipped checkpoint uses the `base_ngram` config: d_model=256, 4096 trigram
13
+ hash buckets × 64 dim, max_len=512 bytes. Total parameters ≈ 2.35 M.
14
+ """
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
 
19
 
 
 
 
class ByteNgramEmbed(nn.Module):
    """Hash byte n-grams into a fixed-size embedding table.

    Position ``t`` is assigned the base-257 polynomial hash of the ``n``
    ids starting at ``t``; the hash modulo ``num_buckets`` selects a row of
    the embedding table. Distinct n-grams may share a row (hash collisions),
    which keeps the table size independent of the input distribution.

    Args:
        num_buckets: number of rows in the hash table.
        embed_dim: width of each embedding row.
        n: n-gram length (3 = trigrams).
    """

    def __init__(self, num_buckets=4096, embed_dim=64, n=3):
        super().__init__()
        self.n = n
        self.num_buckets = num_buckets
        self.embed = nn.Embedding(num_buckets, embed_dim)

    def forward(self, byte_ids):
        """Return the n-gram embedding for every position.

        Args:
            byte_ids: 2-D integer tensor ``(B, T)`` of byte ids.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T = byte_ids.shape
        # Ids above 255 are folded onto 255, so they hash like byte 0xFF.
        clamped = byte_ids.clamp(max=255)
        # Right-pad with zeros so the last positions still see n ids.
        padded = F.pad(clamped, (0, self.n - 1), value=0)
        h = torch.zeros(B, T, dtype=torch.long, device=byte_ids.device)
        for i in range(self.n):
            # Reduce at every step: the bucket index is unchanged (modular
            # arithmetic), but the accumulator can no longer overflow int64,
            # which the unreduced form does once n >= 8.
            h = (h * 257 + padded[:, i:i + T]) % self.num_buckets
        return self.embed(h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  class ByteConvBlock(nn.Module):
44
+ """Causal depthwise Conv1D + SwiGLU FFN, with residual + layernorm."""
45
 
46
  def __init__(self, d_model, kernel_size=15, expand=2):
47
  super().__init__()
 
49
  self.pad = kernel_size - 1
50
  self.conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
51
  self.norm2 = nn.LayerNorm(d_model)
52
+ ffn = d_model * expand
53
+ self.ffn_gate = nn.Linear(d_model, ffn, bias=False)
54
+ self.ffn_up = nn.Linear(d_model, ffn, bias=False)
55
+ self.ffn_down = nn.Linear(ffn, d_model, bias=False)
56
 
57
  def forward(self, x):
58
  residual = x
59
+ x = self.norm1(x).transpose(1, 2)
 
60
  x = F.pad(x, (self.pad, 0))
61
+ x = F.silu(self.conv(x)).transpose(1, 2)
 
62
  x = residual + x
63
 
64
  residual = x
65
  x = self.norm2(x)
66
  x = self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))
67
+ return residual + x
68
+
69
 
70
+ def _rope(q, k):
71
+ head_dim = q.shape[-1]
72
+ seq_len = q.shape[-2]
73
+ freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim))
74
+ t = torch.arange(seq_len, device=q.device)
75
+ a = torch.outer(t, freqs)
76
+ cos = a.cos().to(q.dtype)
77
+ sin = a.sin().to(q.dtype)
78
 
79
+ def rot(x):
80
+ x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
81
+ return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
82
+
83
+ return rot(q), rot(k)
84
 
85
 
class ByteAttnBlock(nn.Module):
    """Pre-norm bidirectional self-attention (RoPE) followed by a SwiGLU FFN.

    Both sub-layers are wrapped in a residual connection. Attention is
    computed over all ``T`` positions with no mask.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads; ``d_model // n_heads`` must be
            even because ``_rope`` rotates the two halves of each head.
        expand: FFN hidden width as a multiple of ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or the
            resulting head width is odd.
    """

    def __init__(self, d_model, n_heads=4, expand=2):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error (or a
        # silently mis-shaped rotation) on the first forward pass.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if (d_model // n_heads) % 2 != 0:
            raise ValueError(
                f"head_dim ({d_model // n_heads}) must be even for rotary embeddings"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model)
        ffn = d_model * expand
        self.ffn_gate = nn.Linear(d_model, ffn, bias=False)
        self.ffn_up = nn.Linear(d_model, ffn, bias=False)
        self.ffn_down = nn.Linear(ffn, d_model, bias=False)

    def forward(self, x):
        """Run the block on ``x`` of shape ``(B, T, D)``; returns the same shape."""
        B, T, D = x.shape
        residual = x
        h = self.norm1(x)
        # (B, T, 3, heads, head_dim) -> three tensors of (B, heads, T, head_dim)
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        q, k = _rope(q, k)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, D)
        x = residual + self.out_proj(out)

        residual = x
        h = self.norm2(x)
        # SwiGLU: SiLU-gated product of two projections, then project back.
        h = self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))
        return residual + h
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
 
120
  class ByteHybrid(nn.Module):
121
+ """Byte-level classifier with optional trigram-hash augmentation."""
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  def __init__(
124
  self,
125
+ num_classes,
126
  d_model=256,
127
  n_conv=3,
128
  n_attn=1,
129
  n_heads=4,
130
  ffn_expand=2,
131
+ max_len=512,
132
  conv_kernel=15,
133
  ngram_buckets=0,
134
  ngram_dim=64,
 
135
  ):
136
  super().__init__()
137
  self.max_len = max_len
138
 
139
+ # Byte values 0–255 plus index 256 = padding token
140
  self.embed = nn.Embedding(257, d_model, padding_idx=256)
141
 
 
 
142
  self.ngram_embed = None
143
  if ngram_buckets > 0:
144
+ self.ngram_embed = ByteNgramEmbed(ngram_buckets, ngram_dim, n=3)
 
 
 
 
145
  self.ngram_proj = nn.Linear(ngram_dim, d_model, bias=False)
146
 
147
+ self.conv_layers = nn.ModuleList(
148
+ [ByteConvBlock(d_model, conv_kernel, ffn_expand) for _ in range(n_conv)]
149
+ )
150
+ self.attn_layers = nn.ModuleList(
151
+ [ByteAttnBlock(d_model, n_heads, ffn_expand) for _ in range(n_attn)]
152
+ )
 
 
 
 
 
 
153
  self.final_norm = nn.LayerNorm(d_model)
 
 
154
  self.head = nn.Sequential(
155
  nn.Linear(d_model, d_model),
156
  nn.GELU(),
 
159
  )
160
 
161
  def forward(self, byte_ids):
 
 
 
 
 
 
162
  pad_mask = byte_ids != 256
 
163
  x = self.embed(byte_ids)
 
 
164
  if self.ngram_embed is not None:
165
+ x = x + self.ngram_proj(self.ngram_embed(byte_ids))
 
 
166
  for layer in self.conv_layers:
167
  x = layer(x)
 
168
  for layer in self.attn_layers:
169
  x = layer(x)
 
170
  x = self.final_norm(x)
 
171
  mask = pad_mask.unsqueeze(-1).to(x.dtype)
172
  x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
 
173
  return self.head(x)
174
 
 
 
 
 
 
 
 
 
 
175
 
# The only configuration that ships. The checkpoint records the name of the
# config it was trained with under its "config" key.
CONFIGS = {
    "base_ngram": {
        "d_model": 256,
        "n_conv": 3,
        "n_attn": 1,
        "n_heads": 4,
        "conv_kernel": 15,
        "ngram_buckets": 4096,
        "ngram_dim": 64,
    },
}