Pclanglais commited on
Commit
d988f49
·
verified ·
1 Parent(s): d588e39

v7.2.1-ngram release: fp32 + bf16 weights, model.py, predict.py, model card

Browse files
Files changed (8) hide show
  1. README.md +160 -0
  2. __pycache__/model.cpython-312.pyc +0 -0
  3. config.json +7 -0
  4. lang2idx.json +336 -0
  5. model.bf16.pt +3 -0
  6. model.pt +3 -0
  7. model.py +318 -0
  8. predict.py +96 -0
README.md ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - multilingual
5
+ tags:
6
+ - language-identification
7
+ - lid
8
+ - byte-level
9
+ - corpus-curation
10
+ - african-languages
11
+ library_name: pytorch
12
+ pipeline_tag: text-classification
13
+ metrics:
14
+ - f1
15
+ - accuracy
16
+ ---
17
+
18
+ # CommonLingua
19
+
20
+ **Byte-level language identification for 334 languages — 2.35 M parameters, 9 MB on disk, runs on CPU.**
21
+
22
+ CommonLingua sorts raw web, PDF, and digitised text into 334 ISO 639-3 language buckets so it can feed downstream training pipelines. It was built and trained at [PleIAs](https://pleias.fr) for curating [Common Corpus](https://huggingface.co/datasets/PleIAs/common_corpus), with a particular focus on the long tail — **61 African languages** are supported, including languages with no fastText / OpenLID coverage.
23
+
24
+ | | |
25
+ |---|---|
26
+ | **Languages** | 334 (61 African) |
27
+ | **Parameters** | 2,347,854 |
28
+ | **Disk (fp32 / bf16)** | 9.4 MB / 4.7 MB |
29
+ | **Max input** | 512 bytes (~paragraph) |
30
+ | **CommonLID macro F1** | **0.7879** |
31
+ | **CommonLID strict acc**| 77.63% |
32
+ | **License** | Apache-2.0 |
33
+
34
+ ## Quick start
35
+
36
+ ```bash
37
+ pip install "git+https://github.com/PleIAs/bytehybrid-lid#egg=commonlingua[hub]"
38
+ ```
39
+
40
+ ```python
41
+ from commonlingua import LID
42
+
43
+ lid = LID.from_pretrained("PleIAs/CommonLingua") # auto-downloads
44
+ # Use the bf16 build for 2× speedup on GPU at no measurable quality cost:
45
+ # lid = LID.from_pretrained("PleIAs/CommonLingua", dtype="bf16")
46
+
47
+ text = (
48
+ "Wikipédia est une encyclopédie universelle, multilingue, créée par Jimmy "
49
+ "Wales et Larry Sanger le 15 janvier 2001 et fonctionnant selon le principe "
50
+ "du wiki."
51
+ )
52
+ r = lid.predict(text)
53
+ print(r.lang, r.confidence) # fra 0.99
54
+ ```
55
+
56
+ The intended workload is **paragraph-level corpus curation**. For batch annotation of large parquet shards, see `predict_parquet` in the package README.
57
+
58
+ ## Architecture
59
+
60
+ ```
61
+ raw bytes → [trigram hash embed (4096 × 64)]
62
+ ↓ ↘
63
+ + ────────→ 3× depthwise Conv1D (k=15) → 1× attention (RoPE, 4 heads)
64
+
65
+ masked mean-pool → 334 logits
66
+ ```
67
+
68
+ - **No tokenizer.** The model operates directly on raw UTF-8 bytes, padded to 512. This makes it inherently script-agnostic — Latin, Arabic, Ethiopic, N'Ko, Tifinagh, Devanagari, CJK, all handled by the same byte stream.
69
+ - **Trigram hash embedding** (added in v7.2.1): a polynomial rolling hash of byte 3-grams indexes a 4096-bucket embedding table. Hash collisions act as regularisation. Adds ~262 k parameters and ≤ 2% inference overhead, but improves macro F1 by +1.2 points and African F1 by +1.5 points over the no-n-gram baseline.
70
+ - **Causal Conv1D × 3** captures local byte patterns (script ranges, common digraphs, morpheme boundaries).
71
+ - **Bidirectional attention × 1 with RoPE** captures global structure across the paragraph.
72
+
73
+ ## Evaluation
74
+
75
+ Evaluated on **CommonLID** (Ortiz Suárez et al. 2026): 376 k held-out paragraphs, 200+ languages. All baselines re-evaluated through the same pipeline (`iso639-lang` normalisation, equivalence-class collapsing applied identically) for an apples-to-apples comparison.
76
+
77
+ ### Headline
78
+
79
+ | Model | Params | Labels | Strict acc | Equiv acc | Macro F1 |
80
+ |----------------------|-------:|-------:|----------:|----------:|-----------:|
81
+ | OpenLID v2 | ~600 M | 200 | 55.77 % | 70.19 % | 0.6390 |
82
+ | fastText-218 (NLLB) | ~600 M | 218 | 59.53 % | 71.64 % | 0.6590 |
83
+ | GlotLID v3 | ~600 M | 2 102 | 57.69 % | 71.26 % | 0.6729 |
84
+ | **CommonLingua v7.2.1** | **2.35 M** | **334** | **77.63 %** | **82.92 %** | **0.7879** |
85
+
86
+ CommonLingua reaches **+11.5 macro F1** over the next best baseline with **~250× fewer parameters**. The full per-language F1 breakdown ships in `eval_per_language.json`.
87
+
88
+ ### African subset
89
+
90
+ CommonLID's African subset (17 languages with ≥ 100 gold samples — the regime where OpenLID/GlotLID/fastText reportedly underperform):
91
+
92
+ | Model | African macro F1 |
93
+ |---|---:|
94
+ | OpenLID v2 | 0.5xx |
95
+ | GlotLID v3 | 0.725 |
96
+ | **CommonLingua v7.2.1** | **0.7222** |
97
+
98
+ Notably, CommonLingua reaches **F1 = 0.975** on Amharic — a language Lingua does not support.
99
+
100
+ ### fp32 vs bf16
101
+
102
+ The bf16 build is half the disk size and ~2.4× faster on H100, with **no measurable quality drop**: 0 of the 72 evaluated languages drift by more than 0.01 F1.
103
+
104
+ | Build | Disk | Strict acc | Equiv acc | Macro F1 | African F1 | Lingua F1 |
105
+ |---|---:|---:|---:|---:|---:|---:|
106
+ | **fp32** (default) | 9.4 MB | 0.7763 | 0.8292 | **0.7879** | 0.7222 | 0.8806 |
107
+ | **bf16** | 4.7 MB | 0.7763 | 0.8292 | **0.7879** | 0.7221 | 0.8804 |
108
+ | Δ | −50 % | 0 | 0 | 0 | −0.0002 | −0.0003 |
109
+
110
+ ### Throughput
111
+
112
+ Texts/sec (one paragraph = one text, ≤ 512 bytes input, padded). Real CommonLingua weights and the production code path:
113
+
114
+ | Device | Setting | fp32 | bf16 | bf16 vs fp32 |
115
+ |---|---|---:|---:|---:|
116
+ | H100 80GB (bs=4096) | best | 10,962 | **26,236** | 2.4× |
117
+ | H100 80GB (bs=1024) | | 10,892 | 26,130 | 2.4× |
118
+ | H100 80GB (bs=256) | | 10,677 | 25,241 | 2.4× |
119
+ | H100 80GB (bs=64) | low-latency| 10,025 | 22,625 | 2.3× |
120
+ | Sapphire Rapids CPU (8 threads) | bs=32 | _PENDING_ | _PENDING_ | _PENDING_ |
121
+ | Sapphire Rapids CPU (1 thread) | bs=32 | _PENDING_ | _PENDING_ | _PENDING_ |
122
+
123
+ The press release that previously circulated cited "20 t/s on CPU, 3 000 t/s on GPU" — the actual GPU figure is **~9× higher** in fp32 and **~22× higher** in bf16. The bf16 build is recommended whenever the host supports it (essentially: anything Ampere or newer).
124
+
125
+ ## Training data
126
+
127
+ Trained on **2,482,568 paragraphs across 334 languages**, drawn entirely from open-licensed and public-domain sources. Wikipedia provides the bulk (~93 %); the long tail is filled by Pralekha (Indic), VOA Africa, Cultural Heritage, OpenAlex (Indo-Malay journal data + African academic), Common Corpus adversarial pulls, Perseus / OpenPecha / eBible / Sefaria / Ben-Yehuda / Krike-Krake (ancient and minority-script corpora).
128
+
129
+ Per-source contributions, license attribution, and full schema are documented in [PleIAs/CommonLingua-Train](https://huggingface.co/datasets/PleIAs/CommonLingua-Train).
130
+
131
+ ## Known limitations
132
+
133
+ 1. **Arabic dialect cluster** — Modern Standard (`arb`) is robust (F1 ≈ 0.95), but Moroccan (`ary`, F1 ≈ 0.47) and Egyptian (`arz`, F1 ≈ 0.25) Arabic are structurally hard: the dialects share large stretches of vocabulary with MSA and with each other. Not a data-volume problem; needs targeted corpus.
134
+ 2. **Indonesian / Malay (`ind` / `msa`)** — ~48 % msa error rate. Adding 20 k journal-provenance rows for each gave only marginal improvement; this pair will likely need supervised disambiguation features beyond byte-level signal.
135
+ 3. **Estonian (`est`) attractor** — accumulates ~750 false positives from unrelated languages. Model quirk to investigate; in practice a confidence threshold of 0.7 mostly removes spurious `est` predictions.
136
+ 4. **Lingala (`lin`)** — CommonLID's gold has the labels for `lin` reversed with Tiv/Yoruba (paper acknowledged a related labelling issue). Our F1 = 0 on this class is a benchmark artefact, not a model failure. Real-world `lin` predictions are correct on internal held-out data.
137
+ 5. **Short text (<50 chars)** — confidence drops sharply. The model is **not** intended for short-query LID; use CLD3 or a query-tuned model for that regime.
138
+
139
+ ## Files
140
+
141
+ | File | Description |
142
+ |---|---|
143
+ | `model.pt` | fp32 checkpoint (9.4 MB) |
144
+ | `model.bf16.pt` | bf16 checkpoint (4.7 MB) |
145
+ | `model.py` | ByteHybrid v2 architecture |
146
+ | `predict.py` | Standalone CLI (no `commonlingua` package required) |
147
+ | `config.json` | model config |
148
+ | `lang2idx.json` | 334-class label map |
149
+ | `eval_per_language.json` | full per-language F1 on CommonLID |
150
+
151
+ ## Citation
152
+
153
+ ```bibtex
154
+ @misc{commonlingua,
155
+ author = {{PleIAs}},
156
+ title = {CommonLingua: Byte-level Language Identification for 334 Languages},
157
+ year = {2026},
158
+ url = {https://huggingface.co/PleIAs/CommonLingua}
159
+ }
160
+ ```
__pycache__/model.cpython-312.pyc ADDED
Binary file (15.8 kB). View file
 
config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "config": "base_ngram",
3
+ "num_classes": 334,
4
+ "max_len": 512,
5
+ "epoch": 3,
6
+ "val_acc": 0.9615777881207785
7
+ }
lang2idx.json ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "abk": 0,
3
+ "ace": 1,
4
+ "ach": 2,
5
+ "ady": 3,
6
+ "afr": 4,
7
+ "aka": 5,
8
+ "akk": 6,
9
+ "alt": 7,
10
+ "amh": 8,
11
+ "ami": 9,
12
+ "ang": 10,
13
+ "anp": 11,
14
+ "arb": 12,
15
+ "arc": 13,
16
+ "arg": 14,
17
+ "arq": 15,
18
+ "ary": 16,
19
+ "arz": 17,
20
+ "asm": 18,
21
+ "ast": 19,
22
+ "atj": 20,
23
+ "ava": 21,
24
+ "avk": 22,
25
+ "awa": 23,
26
+ "aym": 24,
27
+ "azb": 25,
28
+ "azj": 26,
29
+ "bak": 27,
30
+ "bam": 28,
31
+ "ban": 29,
32
+ "bar": 30,
33
+ "bbc": 31,
34
+ "bel": 32,
35
+ "ben": 33,
36
+ "bho": 34,
37
+ "bik": 35,
38
+ "bis": 36,
39
+ "bjn": 37,
40
+ "blk": 38,
41
+ "bod": 39,
42
+ "bos": 40,
43
+ "bpy": 41,
44
+ "bre": 42,
45
+ "bug": 43,
46
+ "bul": 44,
47
+ "bxr": 45,
48
+ "cat": 46,
49
+ "cbk": 47,
50
+ "cdo": 48,
51
+ "ceb": 49,
52
+ "ces": 50,
53
+ "che": 51,
54
+ "chr": 52,
55
+ "chu": 53,
56
+ "chv": 54,
57
+ "chy": 55,
58
+ "ckb": 56,
59
+ "cor": 57,
60
+ "cos": 58,
61
+ "crh": 59,
62
+ "csb": 60,
63
+ "cym": 61,
64
+ "dag": 62,
65
+ "dan": 63,
66
+ "deu": 64,
67
+ "din": 65,
68
+ "diq": 66,
69
+ "div": 67,
70
+ "dsb": 68,
71
+ "dty": 69,
72
+ "dzo": 70,
73
+ "egx": 71,
74
+ "egy": 72,
75
+ "ell": 73,
76
+ "eml": 74,
77
+ "eng": 75,
78
+ "epo": 76,
79
+ "est": 77,
80
+ "ett": 78,
81
+ "eus": 79,
82
+ "ewe": 80,
83
+ "ext": 81,
84
+ "fao": 82,
85
+ "fas": 83,
86
+ "fij": 84,
87
+ "fin": 85,
88
+ "fon": 86,
89
+ "fra": 87,
90
+ "fro": 88,
91
+ "frp": 89,
92
+ "frr": 90,
93
+ "fry": 91,
94
+ "ful": 92,
95
+ "fur": 93,
96
+ "gag": 94,
97
+ "gan": 95,
98
+ "gaz": 96,
99
+ "gcr": 97,
100
+ "gez": 98,
101
+ "gla": 99,
102
+ "gle": 100,
103
+ "glg": 101,
104
+ "glk": 102,
105
+ "glv": 103,
106
+ "gom": 104,
107
+ "gor": 105,
108
+ "got": 106,
109
+ "gpe": 107,
110
+ "grc": 108,
111
+ "gsw": 109,
112
+ "guc": 110,
113
+ "gug": 111,
114
+ "guj": 112,
115
+ "gur": 113,
116
+ "guw": 114,
117
+ "hak": 115,
118
+ "hat": 116,
119
+ "hau": 117,
120
+ "haw": 118,
121
+ "hbo": 119,
122
+ "hbs": 120,
123
+ "heb": 121,
124
+ "hif": 122,
125
+ "hin": 123,
126
+ "hrv": 124,
127
+ "hsb": 125,
128
+ "hun": 126,
129
+ "hye": 127,
130
+ "hyw": 128,
131
+ "ibo": 129,
132
+ "ido": 130,
133
+ "iku": 131,
134
+ "ile": 132,
135
+ "ilo": 133,
136
+ "ina": 134,
137
+ "ind": 135,
138
+ "inh": 136,
139
+ "ipk": 137,
140
+ "isl": 138,
141
+ "ita": 139,
142
+ "jam": 140,
143
+ "jav": 141,
144
+ "jbo": 142,
145
+ "jpn": 143,
146
+ "kaa": 144,
147
+ "kab": 145,
148
+ "kal": 146,
149
+ "kan": 147,
150
+ "kas": 148,
151
+ "kat": 149,
152
+ "kaz": 150,
153
+ "kbd": 151,
154
+ "kbp": 152,
155
+ "kcg": 153,
156
+ "khk": 154,
157
+ "khm": 155,
158
+ "kik": 156,
159
+ "kin": 157,
160
+ "kir": 158,
161
+ "kmr": 159,
162
+ "koi": 160,
163
+ "kom": 161,
164
+ "kon": 162,
165
+ "kor": 163,
166
+ "krc": 164,
167
+ "ksh": 165,
168
+ "lad": 166,
169
+ "lao": 167,
170
+ "lat": 168,
171
+ "latex": 169,
172
+ "lav": 170,
173
+ "lbe": 171,
174
+ "lez": 172,
175
+ "lfn": 173,
176
+ "lij": 174,
177
+ "lim": 175,
178
+ "lin": 176,
179
+ "lit": 177,
180
+ "lld": 178,
181
+ "lmo": 179,
182
+ "ltg": 180,
183
+ "ltz": 181,
184
+ "lug": 182,
185
+ "luo": 183,
186
+ "lzh": 184,
187
+ "mad": 185,
188
+ "mai": 186,
189
+ "mal": 187,
190
+ "mar": 188,
191
+ "mdf": 189,
192
+ "mhr": 190,
193
+ "min": 191,
194
+ "mkd": 192,
195
+ "mlg": 193,
196
+ "mlt": 194,
197
+ "mni": 195,
198
+ "mnw": 196,
199
+ "mri": 197,
200
+ "mrj": 198,
201
+ "msa": 199,
202
+ "mwl": 200,
203
+ "mya": 201,
204
+ "myv": 202,
205
+ "mzn": 203,
206
+ "nah": 204,
207
+ "nan": 205,
208
+ "nap": 206,
209
+ "nav": 207,
210
+ "nds": 208,
211
+ "nep": 209,
212
+ "new": 210,
213
+ "nia": 211,
214
+ "nld": 212,
215
+ "nno": 213,
216
+ "nor": 214,
217
+ "nov": 215,
218
+ "nqo": 216,
219
+ "nrf": 217,
220
+ "nso": 218,
221
+ "nya": 219,
222
+ "nyn": 220,
223
+ "oci": 221,
224
+ "olo": 222,
225
+ "orm": 223,
226
+ "ory": 224,
227
+ "oss": 225,
228
+ "pag": 226,
229
+ "pam": 227,
230
+ "pan": 228,
231
+ "pap": 229,
232
+ "pcd": 230,
233
+ "pcm": 231,
234
+ "pdc": 232,
235
+ "peo": 233,
236
+ "pfl": 234,
237
+ "pms": 235,
238
+ "pnb": 236,
239
+ "pnt": 237,
240
+ "pol": 238,
241
+ "por": 239,
242
+ "pus": 240,
243
+ "pwn": 241,
244
+ "quy": 242,
245
+ "rcf": 243,
246
+ "rmy": 244,
247
+ "roh": 245,
248
+ "ron": 246,
249
+ "rue": 247,
250
+ "run": 248,
251
+ "rup": 249,
252
+ "rus": 250,
253
+ "sah": 251,
254
+ "san": 252,
255
+ "sat": 253,
256
+ "scn": 254,
257
+ "sgs": 255,
258
+ "shi": 256,
259
+ "shn": 257,
260
+ "sin": 258,
261
+ "skr": 259,
262
+ "slk": 260,
263
+ "slv": 261,
264
+ "sme": 262,
265
+ "smn": 263,
266
+ "smo": 264,
267
+ "sna": 265,
268
+ "snd": 266,
269
+ "som": 267,
270
+ "sot": 268,
271
+ "spa": 269,
272
+ "sqi": 270,
273
+ "srd": 271,
274
+ "srn": 272,
275
+ "srp": 273,
276
+ "ssw": 274,
277
+ "stq": 275,
278
+ "sun": 276,
279
+ "sux": 277,
280
+ "swe": 278,
281
+ "swh": 279,
282
+ "szl": 280,
283
+ "szy": 281,
284
+ "tah": 282,
285
+ "tam": 283,
286
+ "tat": 284,
287
+ "tay": 285,
288
+ "tcy": 286,
289
+ "tel": 287,
290
+ "tet": 288,
291
+ "tgk": 289,
292
+ "tgl": 290,
293
+ "tha": 291,
294
+ "tir": 292,
295
+ "tly": 293,
296
+ "ton": 294,
297
+ "tpi": 295,
298
+ "trv": 296,
299
+ "tsn": 297,
300
+ "tso": 298,
301
+ "tuk": 299,
302
+ "tum": 300,
303
+ "tur": 301,
304
+ "txb": 302,
305
+ "tyv": 303,
306
+ "udm": 304,
307
+ "uig": 305,
308
+ "ukr": 306,
309
+ "urd": 307,
310
+ "uzb": 308,
311
+ "vec": 309,
312
+ "ven": 310,
313
+ "vep": 311,
314
+ "vie": 312,
315
+ "vls": 313,
316
+ "vol": 314,
317
+ "vro": 315,
318
+ "war": 316,
319
+ "wln": 317,
320
+ "wol": 318,
321
+ "wuu": 319,
322
+ "xal": 320,
323
+ "xcl": 321,
324
+ "xho": 322,
325
+ "xmf": 323,
326
+ "xog": 324,
327
+ "xto": 325,
328
+ "ydd": 326,
329
+ "yor": 327,
330
+ "yue": 328,
331
+ "zea": 329,
332
+ "zgh": 330,
333
+ "zha": 331,
334
+ "zho": 332,
335
+ "zul": 333
336
+ }
model.bf16.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a730271a5646cd7f547b4c631ddba181eaa434d113a7d08751625f39a0962bb
3
+ size 4715426
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63a2232378faefbde97c8269490793f0154109380488b4d6ae21f3839565f2f0
3
+ size 9409133
model.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ByteHybrid v2: Byte-level document classifier with optional n-gram hash embeddings.
3
+
4
+ Changes from v1:
5
+ - Added ByteNgramEmbed: rolling hash of byte trigrams into fixed-size embedding table
6
+ - New config "base_ngram" with ngram_buckets=4096, ngram_dim=64 (~262k extra params)
7
+ - Backward compatible: existing configs work unchanged (ngram_buckets=0)
8
+ """
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ # ── Byte N-gram Hash Embedding ───────────────────────────────────────────
17
+
18
+
19
+ class ByteNgramEmbed(nn.Module):
20
+ """Rolling hash of byte n-grams into fixed-size embedding table.
21
+
22
+ Supports single n-gram size or multi-scale (e.g., trigrams + 5-grams).
23
+ Uses polynomial hash. Collisions act as regularization.
24
+ """
25
+
26
+ def __init__(self, num_buckets=4096, embed_dim=64, n=3):
27
+ super().__init__()
28
+ self.n = n
29
+ self.num_buckets = num_buckets
30
+ self.embed = nn.Embedding(num_buckets, embed_dim)
31
+
32
+ def _hash(self, byte_ids, n):
33
+ B, T = byte_ids.shape
34
+ clamped = byte_ids.clamp(max=255)
35
+ padded = F.pad(clamped, (0, n - 1), value=0)
36
+ h = torch.zeros(B, T, dtype=torch.long, device=byte_ids.device)
37
+ for i in range(n):
38
+ h = h * 257 + padded[:, i:i+T]
39
+ h = h % self.num_buckets
40
+ return h
41
+
42
+ def forward(self, byte_ids):
43
+ return self.embed(self._hash(byte_ids, self.n))
44
+
45
+
46
+ class MultiScaleNgramEmbed(nn.Module):
47
+ """Multi-scale n-gram hash embeddings (e.g., 3-gram + 5-gram).
48
+
49
+ Each scale gets its own hash table and embedding. Outputs are summed.
50
+ """
51
+
52
+ def __init__(self, num_buckets=4096, embed_dim=64, scales=(3, 5)):
53
+ super().__init__()
54
+ self.scales = scales
55
+ self.ngrams = nn.ModuleList([
56
+ ByteNgramEmbed(num_buckets, embed_dim, n=n) for n in scales
57
+ ])
58
+
59
+ def forward(self, byte_ids):
60
+ out = self.ngrams[0](byte_ids)
61
+ for ng in self.ngrams[1:]:
62
+ out = out + ng(byte_ids)
63
+ return out
64
+
65
+
66
+ # ── Causal Conv1d Block ──────────────────────────────────────────────────
67
+
68
+
69
+ class ByteConvBlock(nn.Module):
70
+ """Causal conv1d + gated FFN. Captures local byte patterns."""
71
+
72
+ def __init__(self, d_model, kernel_size=15, expand=2):
73
+ super().__init__()
74
+ self.norm1 = nn.LayerNorm(d_model)
75
+ self.pad = kernel_size - 1
76
+ self.conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
77
+ self.norm2 = nn.LayerNorm(d_model)
78
+ ffn_dim = d_model * expand
79
+ self.ffn_gate = nn.Linear(d_model, ffn_dim, bias=False)
80
+ self.ffn_up = nn.Linear(d_model, ffn_dim, bias=False)
81
+ self.ffn_down = nn.Linear(ffn_dim, d_model, bias=False)
82
+
83
+ def forward(self, x):
84
+ residual = x
85
+ x = self.norm1(x)
86
+ x = x.transpose(1, 2)
87
+ x = F.pad(x, (self.pad, 0))
88
+ x = F.silu(self.conv(x))
89
+ x = x.transpose(1, 2)
90
+ x = residual + x
91
+
92
+ residual = x
93
+ x = self.norm2(x)
94
+ x = self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))
95
+ x = residual + x
96
+ return x
97
+
98
+
99
+ # ── Attention Block ──────────────────────────────────────────────────────
100
+
101
+
102
+ class ByteAttnBlock(nn.Module):
103
+ """Standard bidirectional attention + SwiGLU FFN with RoPE."""
104
+
105
+ def __init__(self, d_model, n_heads=4, expand=2):
106
+ super().__init__()
107
+ self.n_heads = n_heads
108
+ self.head_dim = d_model // n_heads
109
+
110
+ self.norm1 = nn.LayerNorm(d_model)
111
+ self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
112
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
113
+
114
+ self.norm2 = nn.LayerNorm(d_model)
115
+ ffn_dim = d_model * expand
116
+ self.ffn_gate = nn.Linear(d_model, ffn_dim, bias=False)
117
+ self.ffn_up = nn.Linear(d_model, ffn_dim, bias=False)
118
+ self.ffn_down = nn.Linear(ffn_dim, d_model, bias=False)
119
+
120
+ def forward(self, x):
121
+ B, T, D = x.shape
122
+ residual = x
123
+
124
+ x = self.norm1(x)
125
+ qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
126
+ q, k, v = qkv.unbind(dim=2)
127
+ q = q.transpose(1, 2)
128
+ k = k.transpose(1, 2)
129
+ v = v.transpose(1, 2)
130
+
131
+ q, k = apply_rope(q, k, T, self.head_dim, x.device)
132
+
133
+ attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
134
+ attn = attn.softmax(dim=-1)
135
+ out = (attn @ v).transpose(1, 2).contiguous().view(B, T, D)
136
+ out = self.out_proj(out)
137
+ x = residual + out
138
+
139
+ residual = x
140
+ x = self.norm2(x)
141
+ x = self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))
142
+ x = residual + x
143
+ return x
144
+
145
+
146
+ # ── Rotary Position Embedding ────────────────────────────────────────────
147
+
148
+
149
+ def precompute_freqs(dim, max_len=4096, theta=10000.0):
150
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
151
+ t = torch.arange(max_len)
152
+ freqs = torch.outer(t, freqs)
153
+ return torch.cos(freqs), torch.sin(freqs)
154
+
155
+
156
+ def apply_rope(q, k, seq_len, head_dim, device):
157
+ cos, sin = precompute_freqs(head_dim, seq_len)
158
+ cos = cos[:seq_len].to(device=device, dtype=q.dtype)
159
+ sin = sin[:seq_len].to(device=device, dtype=q.dtype)
160
+
161
+ def rotate(x):
162
+ x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
163
+ return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
164
+
165
+ return rotate(q), rotate(k)
166
+
167
+
168
+ # ── Full Model ───────────────────────────────────────────────────────────
169
+
170
+
171
+ class ByteHybrid(nn.Module):
172
+ """Byte-level classifier with optional n-gram hash embeddings.
173
+
174
+ Args:
175
+ num_classes: number of output classes
176
+ d_model: hidden dimension
177
+ n_conv: number of conv1d blocks
178
+ n_attn: number of attention blocks
179
+ n_heads: attention heads
180
+ max_len: maximum byte sequence length
181
+ conv_kernel: conv1d kernel size
182
+ ngram_buckets: hash table size for n-gram embeddings (0 = disabled)
183
+ ngram_dim: embedding dimension for n-gram hashes
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ num_classes=13,
189
+ d_model=256,
190
+ n_conv=3,
191
+ n_attn=1,
192
+ n_heads=4,
193
+ ffn_expand=2,
194
+ max_len=2048,
195
+ conv_kernel=15,
196
+ ngram_buckets=0,
197
+ ngram_dim=64,
198
+ ngram_scales=None,
199
+ ):
200
+ super().__init__()
201
+ self.max_len = max_len
202
+
203
+ # Byte embedding: 256 possible byte values + 1 padding
204
+ self.embed = nn.Embedding(257, d_model, padding_idx=256)
205
+
206
+ # Optional n-gram hash embedding
207
+ # ngram_scales: tuple of n-gram sizes, e.g. (3,) or (3, 5)
208
+ self.ngram_embed = None
209
+ if ngram_buckets > 0:
210
+ scales = ngram_scales if ngram_scales else (3,)
211
+ if len(scales) == 1:
212
+ self.ngram_embed = ByteNgramEmbed(ngram_buckets, ngram_dim, n=scales[0])
213
+ else:
214
+ self.ngram_embed = MultiScaleNgramEmbed(ngram_buckets, ngram_dim, scales=scales)
215
+ self.ngram_proj = nn.Linear(ngram_dim, d_model, bias=False)
216
+
217
+ # Conv blocks
218
+ self.conv_layers = nn.ModuleList([
219
+ ByteConvBlock(d_model, kernel_size=conv_kernel, expand=ffn_expand)
220
+ for _ in range(n_conv)
221
+ ])
222
+
223
+ # Attention blocks
224
+ self.attn_layers = nn.ModuleList([
225
+ ByteAttnBlock(d_model, n_heads, ffn_expand)
226
+ for _ in range(n_attn)
227
+ ])
228
+
229
+ self.final_norm = nn.LayerNorm(d_model)
230
+
231
+ # Classification head
232
+ self.head = nn.Sequential(
233
+ nn.Linear(d_model, d_model),
234
+ nn.GELU(),
235
+ nn.Dropout(0.1),
236
+ nn.Linear(d_model, num_classes),
237
+ )
238
+
239
+ def forward(self, byte_ids):
240
+ """
241
+ Args:
242
+ byte_ids: (B, T) long tensor of byte values [0-255], padded with 256
243
+ Returns:
244
+ logits: (B, num_classes)
245
+ """
246
+ pad_mask = byte_ids != 256
247
+
248
+ x = self.embed(byte_ids)
249
+
250
+ # Add n-gram features if enabled
251
+ if self.ngram_embed is not None:
252
+ ng = self.ngram_embed(byte_ids)
253
+ x = x + self.ngram_proj(ng)
254
+
255
+ for layer in self.conv_layers:
256
+ x = layer(x)
257
+
258
+ for layer in self.attn_layers:
259
+ x = layer(x)
260
+
261
+ x = self.final_norm(x)
262
+
263
+ mask = pad_mask.unsqueeze(-1).to(x.dtype)
264
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
265
+
266
+ return self.head(x)
267
+
268
+ @staticmethod
269
+ def encode_text(text, max_len=2048):
270
+ """Convert text string to byte tensor, padded to max_len."""
271
+ raw = text.encode("utf-8", errors="replace")[:max_len]
272
+ byte_ids = list(raw) + [256] * (max_len - len(byte_ids))
273
+ return torch.tensor(byte_ids, dtype=torch.long)
274
+
275
+
276
+ # ── Configurations ───────────────────────────────────────────────────────
277
+
278
+ CONFIGS = {
279
+ # ~2M params: 3 conv + 1 attn, d=256 (original)
280
+ "base": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15),
281
+ # ~2.3M params: base + trigram hash embeddings (4k buckets × 64 dim)
282
+ "base_ngram": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
283
+ ngram_buckets=4096, ngram_dim=64),
284
+ # ~2.5M params: larger hash table
285
+ "base_ngram_large": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
286
+ ngram_buckets=8192, ngram_dim=64),
287
+ # ~3.5M params: 3 conv + 2 attn, d=256
288
+ "large": dict(d_model=256, n_conv=3, n_attn=2, n_heads=4, conv_kernel=15),
289
+ # ~2M params: deeper conv, no attn
290
+ "conv_only": dict(d_model=256, n_conv=5, n_attn=0, n_heads=4, conv_kernel=15),
291
+ # ~2M params: wider kernel conv
292
+ "wide_conv": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=31),
293
+ # Scaled-up configs
294
+ "d384": dict(d_model=384, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15),
295
+ "d384_2attn": dict(d_model=384, n_conv=3, n_attn=2, n_heads=4, conv_kernel=15),
296
+ "d512": dict(d_model=512, n_conv=3, n_attn=1, n_heads=8, conv_kernel=15),
297
+ # 4-gram variant
298
+ "base_4gram": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
299
+ ngram_buckets=4096, ngram_dim=64, ngram_scales=(4,)),
300
+ # Multi-scale: 3-gram + 5-gram (two hash tables, summed)
301
+ "base_multiscale": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
302
+ ngram_buckets=4096, ngram_dim=64, ngram_scales=(3, 5)),
303
+ # Multi-scale: 3-gram + 4-gram + 5-gram
304
+ "base_multiscale3": dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
305
+ ngram_buckets=4096, ngram_dim=64, ngram_scales=(3, 4, 5)),
306
+ }
307
+
308
+
309
+ def count_params(model):
310
+ return sum(p.numel() for p in model.parameters())
311
+
312
+
313
+ if __name__ == "__main__":
314
+ for name, cfg in CONFIGS.items():
315
+ model = ByteHybrid(num_classes=334, max_len=512, **cfg)
316
+ byte_ids = torch.randint(0, 256, (4, 512))
317
+ logits = model(byte_ids)
318
+ print(f"{name:<20s} {count_params(model):>10,} params output={logits.shape}")
predict.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standalone CommonLingua predict — single file, no `commonlingua` package required.
2
+
3
+ Drop this next to `model.py` and the checkpoint, then:
4
+
5
+ python predict.py "Wikipedia is a free online encyclopedia, ..."
6
+ python predict.py --file input.txt
7
+ cat texts.tsv | python predict.py --stdin
8
+
9
+ For the full Python API and parquet batch mode, install the package:
10
+
11
+ pip install "git+https://github.com/PleIAs/bytehybrid-lid#egg=commonlingua[hub]"
12
+ """
13
+ import argparse
14
+ import sys
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ sys.path.insert(0, str(Path(__file__).parent))
21
+ from model import ByteHybrid, CONFIGS # noqa: E402
22
+
23
+
24
+ def load(checkpoint, dtype="fp32", device=None):
25
+ if device is None:
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ device = torch.device(device)
28
+ ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
29
+ cfg = CONFIGS[ckpt["config"]]
30
+ model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"], **cfg)
31
+ model.load_state_dict(ckpt["model_state_dict"])
32
+ if dtype == "bf16":
33
+ model = model.to(torch.bfloat16)
34
+ model.eval().to(device)
35
+ idx2lang = {v: k for k, v in ckpt["lang2idx"].items()}
36
+ return model, idx2lang, ckpt["max_len"], device
37
+
38
+
39
+ def encode(texts, max_len):
40
+ out = np.full((len(texts), max_len), 256, dtype=np.int64)
41
+ for i, t in enumerate(texts):
42
+ if not isinstance(t, str):
43
+ t = "" if t is None else str(t)
44
+ raw = t.encode("utf-8", errors="replace")[:max_len]
45
+ if raw:
46
+ out[i, :len(raw)] = np.frombuffer(raw, dtype=np.uint8)
47
+ return torch.from_numpy(out)
48
+
49
+
50
+ @torch.no_grad()
51
+ def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
52
+ out = []
53
+ for i in range(0, len(texts), batch_size):
54
+ chunk = texts[i:i + batch_size]
55
+ b = encode(chunk, max_len).to(device)
56
+ probs = torch.softmax(model(b).float(), dim=-1)
57
+ top_p, top_idx = probs.topk(top_k, dim=-1)
58
+ for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()):
59
+ out.append([(idx2lang[j], float(p)) for p, j in zip(p_row, idx_row)])
60
+ return out
61
+
62
+
63
+ def main():
64
+ p = argparse.ArgumentParser()
65
+ p.add_argument("text", nargs="*", help="Texts to classify (one per arg).")
66
+ p.add_argument("--file", help="Read a single text from FILE.")
67
+ p.add_argument("--stdin", action="store_true", help="One text per line from stdin.")
68
+ p.add_argument("--checkpoint", default=str(Path(__file__).parent / "model.pt"))
69
+ p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32")
70
+ p.add_argument("--device", default=None)
71
+ p.add_argument("--top-k", type=int, default=3)
72
+ p.add_argument("--batch-size", type=int, default=256)
73
+ args = p.parse_args()
74
+
75
+ if args.stdin:
76
+ texts = [line.rstrip("\n") for line in sys.stdin if line.strip()]
77
+ elif args.file:
78
+ texts = [Path(args.file).read_text(encoding="utf-8")]
79
+ elif args.text:
80
+ texts = args.text
81
+ else:
82
+ p.print_help()
83
+ return
84
+
85
+ model, idx2lang, max_len, device = load(args.checkpoint, dtype=args.dtype, device=args.device)
86
+ print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}",
87
+ file=sys.stderr)
88
+ results = predict(model, texts, idx2lang, max_len, device, args.top_k, args.batch_size)
89
+ for text, top in zip(texts, results):
90
+ preview = text[:80].replace("\n", " ")
91
+ others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:])
92
+ print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()