CLIWorks commited on
Commit
d8bc908
Β·
verified Β·
1 Parent(s): 10d05e0

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. REVIEW.md +224 -0
  2. arbitor.egg-info/PKG-INFO +18 -0
  3. arbitor.egg-info/SOURCES.txt +104 -0
  4. arbitor.egg-info/dependency_links.txt +1 -0
  5. arbitor.egg-info/requires.txt +16 -0
  6. arbitor.egg-info/top_level.txt +6 -0
  7. arbitor/__init__.py +35 -0
  8. arbitor/attention/__init__.py +15 -0
  9. arbitor/attention/context_attention.py +109 -0
  10. arbitor/attention/frame_buffer.py +78 -0
  11. arbitor/attention/kq_cache.py +30 -0
  12. arbitor/attention/kv_ledger.py +57 -0
  13. arbitor/attention/mla.py +176 -0
  14. arbitor/attention/ring_buffer.py +49 -0
  15. arbitor/components.py +1218 -0
  16. arbitor/config.py +125 -0
  17. arbitor/converters/convert_to_ternary2.py +81 -0
  18. arbitor/converters/convert_to_ternary54.py +120 -0
  19. arbitor/converters/convert_to_ternary64.py +111 -0
  20. arbitor/converters/convert_to_ternary8.py +101 -0
  21. arbitor/decoders.py +231 -0
  22. arbitor/encoders/__init__.py +11 -0
  23. arbitor/encoders/audio.py +83 -0
  24. arbitor/encoders/mel_frontend.py +70 -0
  25. arbitor/encoders/models/__init__.py +86 -0
  26. arbitor/encoders/models/download.py +132 -0
  27. arbitor/encoders/models/opensora-vae/config.json +35 -0
  28. arbitor/encoders/models/opensora-vae/model.safetensors +3 -0
  29. arbitor/encoders/models/pig-vae/model.safetensors +3 -0
  30. arbitor/encoders/opensora_vae.py +145 -0
  31. arbitor/encoders/opensora_vae_modules/autoencoder_2d.py +339 -0
  32. arbitor/encoders/opensora_vae_modules/autoencoder_kl_causal_3d.py +638 -0
  33. arbitor/encoders/opensora_vae_modules/registry.py +41 -0
  34. arbitor/encoders/opensora_vae_modules/unet_causal_3d_blocks.py +476 -0
  35. arbitor/encoders/opensora_vae_modules/vae.py +340 -0
  36. arbitor/encoders/pig_vae.py +148 -0
  37. arbitor/encoders/vae2d.py +56 -0
  38. arbitor/kernel/flash_vq.py +510 -0
  39. arbitor/kernel/ternary_audit.py +192 -0
  40. arbitor/kernel/ternary_scale.py +1811 -0
  41. arbitor/kernel/triton_video.py +75 -0
  42. arbitor/main.py +585 -0
  43. arbitor/optim/__init__.py +0 -0
  44. arbitor/optim/sign_sgd.py +45 -0
  45. arbitor/profiling.py +196 -0
  46. arbitor/sequencers.py +218 -0
  47. arbitor/vq.py +89 -0
  48. docs/ARB-RENAME-NOTE.md +62 -0
  49. docs/arbs-tts/README.md +90 -0
  50. docs/benchmarks/BENCHMARK.md +151 -0
REVIEW.md ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ARBS Code Audit: Dead Imports, Dead Code, and Triton Kernel Analysis
2
+
3
+ **Reviewed:** 2026-05-20T00:00:00Z
4
+ **Depth:** standard
5
+ **Files Reviewed:** 10
6
+
7
+ ## Summary
8
+
9
+ The ARBS codebase has **3 BLOCKER bugs** that will cause runtime crashes, **8 unused class/function definitions** (dead code), **7 dead Triton kernels in components.py** that should be moved to `arbitor/kernel/`, and **21+ unused imports** across files. Two missing function definitions (`_graph_gather_add`, `_moe_dense_combine`) exist in dead code paths but would crash if those paths were ever activated.
10
+
11
+ ---
12
+
13
+ ## BLOCKER Issues
14
+
15
+ ### CR-01: `_TernaryLinearFn.forward` references undefined `x_2d` (NameError at runtime)
16
+
17
+ **File:** `arbitor/kernel/ternary_scale.py:206-208`
18
+ **Issue:** The TileLang `_TernaryLinearFn.forward()` method references `x_2d` on lines 206-208, but `x_2d` is never defined in the method's scope. This will cause a `NameError` at runtime if the TileLang code path is taken in `TernaryScaleTensor.forward` (line 1069). The Triton variant `_TritonTernaryLinearFn` (line 878) correctly defines `x_2d = x.reshape(-1, k_in).contiguous()` before use, so this was likely an omission when the TileLang function was written.
19
+
20
+ ```python
21
+ # Line 206 β€” NameError: name 'x_2d' is not defined
22
+ M = x_2d.shape[0]
23
+ output = torch.empty(M, N, device=x.device, dtype=torch.float32)
24
+ fwd_kernel(x_2d.half(), T_packed, E, output)
25
+ ```
26
+
27
+ **Fix:** Add `x_2d = x.reshape(-1, K).contiguous()` before line 206:
28
+ ```python
29
+ with torch.no_grad():
30
+ N, K = shape
31
+ x_2d = x.reshape(-1, K).contiguous() # missing definition
32
+ M = x_2d.shape[0]
33
+ output = torch.empty(M, N, device=x.device, dtype=torch.float32)
34
+ fwd_kernel(x_2d.half(), T_packed, E, output)
35
+ ```
36
+
37
+ ---
38
+
39
+ ### CR-02: `_check_tilelang_finite` called but never defined (NameError at runtime)
40
+
41
+ **File:** `arbitor/kernel/ternary_scale.py:1072`
42
+ **Issue:** `_check_tilelang_finite()` is called in `TernaryScaleTensor.forward()` but is never defined anywhere in the codebase. This will cause a `NameError` at runtime when the TileLang path is active and the kernel produces valid output (the check is specifically gated by `_HAS_TILELANG` being True).
43
+
44
+ **Fix:** Either define the function (if the check is intentional) or remove the call:
45
+ ```python
46
+ # Replace line 1072 with a direct finiteness check or remove
47
+ if not torch.isfinite(y).all():
48
+ raise FloatingPointError("TileLang ternary kernel produced non-finite activations")
49
+ ```
50
+
51
+ ---
52
+
53
+ ### CR-03: `self.modality_gate` used but never assigned (AttributeError at runtime)
54
+
55
+ **File:** `arbitor/main.py:129-130`
56
+ **Issue:** `ARBModel.forward()` references `self.modality_gate` but it is never assigned in `ARBModel.__init__()`. While `ModalityGate` is imported at line 19, it is never instantiated and stored as `self.modality_gate`. This will cause an `AttributeError` on any forward pass where `self.modality_gate is not None` is evaluated.
57
+
58
+ The code at lines 129-132:
59
+ ```python
60
+ if self.modality_gate is not None:
61
+ gate_weights, active_count, hops = self.modality_gate(active_mods)
62
+ else:
63
+ gate_weights, active_count, hops = {}, len(active_mods), 1
64
+ ```
65
+
66
+ **Fix:** Add `self.modality_gate = ModalityGate()` in `ARBModel.__init__()` (or assign `self.modality_gate = None` if the gate should be optional):
67
+ ```python
68
+ # In ARBModel.__init__, after line 78:
69
+ self.modality_gate = ModalityGate()
70
+ ```
71
+
72
+ ---
73
+
74
+ ## WARNING: Undefined Functions in Dead Code
75
+
76
+ ### WR-01: `_graph_gather_add` called but never defined
77
+
78
+ **File:** `arbitor/components.py:739`
79
+ **Issue:** `TernaryGraph.forward()` calls `_graph_gather_add(vq_output, node_features, vq_indices)` but this function is never defined anywhere in the codebase. `TernaryGraph` is dead code (never imported or used), so this does not crash currently, but it blocks any future use of `TernaryGraph`.
80
+
81
+ **Fix:** Define `_graph_gather_add` or remove the dead class.
82
+
83
+ ---
84
+
85
+ ### WR-02: `_moe_dense_combine` called but never defined
86
+
87
+ **File:** `arbitor/components.py:941`
88
+ **Issue:** `SharedProjectionMoE.forward()` calls `_moe_dense_combine(torch.stack(...), topk_idx, topk_weights)` but this function is never defined. `SharedProjectionMoE` is dead code, but the missing function is a latent bug.
89
+
90
+ **Fix:** Define `_moe_dense_combine` or remove the dead class.
91
+
92
+ ---
93
+
94
+ ## WARNING: Unused Class/Function Definitions (Dead Code)
95
+
96
+ ### WR-03: `TernaryLSTMCell` class β€” defined but never used
97
+
98
+ **File:** `arbitor/components.py:189-207`
99
+ **Issue:** `TernaryLSTMCell` is defined and re-exported from `__init__.py` (line 23) but is never instantiated anywhere in the codebase. The model uses `MoEGraph` with attention (MLA) instead of LSTM-based processing.
100
+
101
+ ---
102
+
103
+ ### WR-04: `TernaryGraph` class β€” defined but never used
104
+
105
+ **File:** `arbitor/components.py:665-802`
106
+ **Issue:** `TernaryGraph` is defined in `components.py` but never imported or instantiated. It was replaced by `MoEGraph` (line 1342). The only reference is in a comment (line 1348).
107
+ **Also:** `TernaryGraph` references the undefined function `_graph_gather_add` (see WR-01), so it cannot function even if someone tried to use it.
108
+
109
+ ---
110
+
111
+ ### WR-05: `SharedProjectionMoE` class β€” defined but never used
112
+
113
+ **File:** `arbitor/components.py:806-999`
114
+ **Issue:** `SharedProjectionMoE` is defined in `components.py` but never imported or instantiated. It was replaced by `MoEGraph._run_expert()` (line 1429). The only reference is in a comment (line 1348).
115
+ **Also:** References the undefined function `_moe_dense_combine` (see WR-02).
116
+
117
+ ---
118
+
119
+ ### WR-06: 7 dead Triton kernel functions in `components.py`
120
+
121
+ **File:** `arbitor/components.py:266-386`
122
+ **Issue:** These Triton kernel functions are defined inside the `if _HAS_TRITON:` block but are only referenced by their forward/backward wrapper functions which are themselves part of dead code (`TernaryGraph` and `SharedProjectionMoE`):
123
+
124
+ | Line | Function | Used By |
125
+ |------|----------|---------|
126
+ | 268 | `_triton_graph_aggregate_fwd_kernel` | dead (TernaryGraph) |
127
+ | 292 | `_triton_graph_aggregate_bwd_kernel` | dead (TernaryGraph) |
128
+ | 316 | `_triton_graph_gather_add_fwd_kernel` | dead (TernaryGraph) |
129
+ | 329 | `_triton_graph_gather_add_bwd_kernel` | dead (TernaryGraph) |
130
+ | 342 | `_triton_moe_dense_combine_fwd_kernel` | dead (SharedProjectionMoE) |
131
+ | 359 | `_triton_moe_dense_combine_bwd_expert_kernel` | dead (SharedProjectionMoE) |
132
+ | 374 | `_triton_moe_dense_combine_bwd_weight_kernel` | dead (SharedProjectionMoE) |
133
+
134
+ The live Triton kernels (`_triton_video_denoise_fwd_kernel` line 389, `_triton_video_denoise_bwd_kernel` line 402) are still in `components.py` and should also be moved to `arbitor/kernel/`.
135
+
136
+ ---
137
+
138
+ ### WR-07: `_triton_flash_vq_quantize_kernel` β€” dead Triton kernel
139
+
140
+ **File:** `arbitor/kernel/flash_vq.py:370-402`
141
+ **Issue:** This Triton kernel is defined but never called. The `_TritonFlashVQFn.forward()` method uses PyTorch's `embed[indices]` for the gather operation (line 468) instead of this kernel.
142
+
143
+ ---
144
+
145
+ ### WR-08: `TILE_SIZE = 384` β€” unused constant
146
+
147
+ **File:** `arbitor/kernel/ternary_scale.py:949`
148
+ **Issue:** `TILE_SIZE` is defined as a module-level constant but never referenced anywhere in the codebase.
149
+
150
+ ---
151
+
152
+ ## WARNING: Unused Imports
153
+
154
+ ### WR-09: Unused imports in `arbitor/main.py` (line 10)
155
+
156
+ | Symbol | Used In File? |
157
+ |--------|--------------|
158
+ | `EMBEDDING_DIM` | No β€” not referenced in body |
159
+ | `FFN_HIDDEN` | No β€” not referenced in body |
160
+ | `CODEBOOK_DIM` | No β€” not referenced in body |
161
+ | `ATTENTION_STRIDE` | No β€” not referenced in body |
162
+ | `MG_N_EXPERTS` | No β€” MoEGraph uses default, not passed |
163
+ | `MG_CORE_RANK` | No β€” MoEGraph uses default |
164
+ | `MG_SHARED_INTER` | No β€” MoEGraph uses default |
165
+ | `MG_ACT_ITERS` | No β€” MoEGraph uses default |
166
+
167
+ ---
168
+
169
+ ### WR-10: Unused imports in `arbitor/components.py` (line 21)
170
+
171
+ | Symbol | Used In Live Code? | Note |
172
+ |--------|-------------------|------|
173
+ | `FFN_HIDDEN` | No | Not referenced in file body |
174
+ | `CTX` | No | Not referenced in file body |
175
+ | `THRESHOLD` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `threshold=0.05` |
176
+ | `KG_EMA_ALPHA` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `0.99` |
177
+ | `KG_REQUANT_EVERY` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `50` |
178
+ | `KG_TERNARY_THRESHOLD` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `0.3` |
179
+
180
+ ---
181
+
182
+ ### WR-11: Unused imports in `arbitor/profiling.py` (line 17)
183
+
184
+ | Symbol | Used In File? |
185
+ |--------|--------------|
186
+ | `VOCAB` | No β€” not referenced in body |
187
+ | `math` (line 11) | No β€” not referenced in body |
188
+
189
+ ---
190
+
191
+ ## INFO: Triton Kernel Code in `components.py` Should Be Moved to `arbitor/kernel/`
192
+
193
+ ### IN-01: Live Triton kernels reside in `components.py` instead of `arbitor/kernel/`
194
+
195
+ **File:** `arbitor/components.py:389-445`
196
+ **Issue:** The codebase convention places Triton kernels in `arbitor/kernel/` (e.g., `ternary_scale.py`, `flash_vq.py`, `ternary_audit.py`). Two live Triton kernels remain in `components.py`:
197
+
198
+ - `_triton_video_denoise_fwd_kernel` (line 389)
199
+ - `_triton_video_denoise_bwd_kernel` (line 402)
200
+ - `_TritonVideoDenoiseFn` (line 415)
201
+ - `_video_denoise_step` (line 448)
202
+
203
+ These should be extracted into `arbitor/kernel/video_denoise.py` and imported from there, following the pattern established by `ternary_scale.py` and `flash_vq.py`.
204
+
205
+ ---
206
+
207
+ ## INFO: Additional Dead Code
208
+
209
+ ### IN-02: Hardcoded MoEGraph config values bypass config constants
210
+
211
+ **File:** `arbitor/components.py:1381-1383`
212
+ **Issue:** `MoEGraph` uses hardcoded values (`50`, `0.3`, `0.99`) instead of the imported config constants (`KG_REQUANT_EVERY`, `KG_TERNARY_THRESHOLD`, `KG_EMA_ALPHA`). The values happen to match the config, but any future config changes will silently be ignored.
213
+
214
+ ---
215
+
216
+ ### IN-03: `AUDIO_VOCAB` not used meaningfully in `config.py`
217
+
218
+ **File:** `arbitor/config.py:2`
219
+ **Issue:** `AUDIO_VOCAB=288` is imported and used in `TalkerHead` and `TinyNeuralCodec`, but the `SPECIAL_VOCAB` map (line 65) defines tokens up to 287. `AUDIO_VOCAB` = `VOCAB` = 288, meaning the audio head has the same vocabulary as the text head. This may be intentional for the current prototype but is worth flagging given `AUDIO_VOCAB` vs `VOCAB` are separate constants.
220
+
221
+ ---
222
+
223
+ _Reviewed: 2026-05-20T00:00:00Z_
224
+ _Reviewer: gsd-code-reviewer (deep analysis)_
arbitor.egg-info/PKG-INFO ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: arbitor
3
+ Version: 0.2.0
4
+ Summary: ARB (Any Relational Bit) β€” ternary-weighted neural network system
5
+ License: MIT
6
+ Requires-Python: >=3.12
7
+ Requires-Dist: torch>=2.5
8
+ Requires-Dist: einops
9
+ Requires-Dist: tqdm
10
+ Provides-Extra: dev
11
+ Requires-Dist: pytest; extra == "dev"
12
+ Provides-Extra: cuda
13
+ Requires-Dist: torch>=2.5; extra == "cuda"
14
+ Requires-Dist: triton>=3.0; extra == "cuda"
15
+ Provides-Extra: triton
16
+ Requires-Dist: triton>=3.0; extra == "triton"
17
+ Provides-Extra: tilelang
18
+ Requires-Dist: tilelang; extra == "tilelang"
arbitor.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pyproject.toml
2
+ arbitor/__init__.py
3
+ arbitor/components.py
4
+ arbitor/config.py
5
+ arbitor/decoders.py
6
+ arbitor/main.py
7
+ arbitor/profiling.py
8
+ arbitor/sequencers.py
9
+ arbitor/vq.py
10
+ arbitor.egg-info/PKG-INFO
11
+ arbitor.egg-info/SOURCES.txt
12
+ arbitor.egg-info/dependency_links.txt
13
+ arbitor.egg-info/requires.txt
14
+ arbitor.egg-info/top_level.txt
15
+ arbitor/attention/__init__.py
16
+ arbitor/attention/context_attention.py
17
+ arbitor/attention/frame_buffer.py
18
+ arbitor/attention/kq_cache.py
19
+ arbitor/attention/kv_ledger.py
20
+ arbitor/attention/mla.py
21
+ arbitor/attention/ring_buffer.py
22
+ arbitor/converters/convert_to_ternary2.py
23
+ arbitor/converters/convert_to_ternary54.py
24
+ arbitor/converters/convert_to_ternary64.py
25
+ arbitor/converters/convert_to_ternary8.py
26
+ arbitor/encoders/__init__.py
27
+ arbitor/encoders/audio.py
28
+ arbitor/encoders/mel_frontend.py
29
+ arbitor/encoders/opensora_vae.py
30
+ arbitor/encoders/pig_vae.py
31
+ arbitor/encoders/vae2d.py
32
+ arbitor/encoders/models/__init__.py
33
+ arbitor/encoders/models/download.py
34
+ arbitor/encoders/opensora_vae_modules/autoencoder_2d.py
35
+ arbitor/encoders/opensora_vae_modules/autoencoder_kl_causal_3d.py
36
+ arbitor/encoders/opensora_vae_modules/registry.py
37
+ arbitor/encoders/opensora_vae_modules/unet_causal_3d_blocks.py
38
+ arbitor/encoders/opensora_vae_modules/vae.py
39
+ arbitor/kernel/flash_vq.py
40
+ arbitor/kernel/ternary_audit.py
41
+ arbitor/kernel/ternary_scale.py
42
+ arbitor/kernel/triton_video.py
43
+ arbitor/optim/__init__.py
44
+ arbitor/optim/sign_sgd.py
45
+ testing/bigcalc.py
46
+ testing/scaled_optum.py
47
+ testing/sign_gsd.py
48
+ testing/test_200_step_smoke.py
49
+ testing/test_bigint_ternary.py
50
+ testing/test_gradient_capture.py
51
+ testing/test_polarity_validation.py
52
+ testing/test_tilelang_training.py
53
+ testing/test_tscale.py
54
+ testing/tscale_mini.py
55
+ testing/attention/__init__.py
56
+ testing/attention/test_kq_cache.py
57
+ testing/attention/test_kv_cache.py
58
+ testing/attention/test_lstm_removal.py
59
+ testing/attention/test_lstm_removal_clean.py
60
+ testing/attention/test_mla.py
61
+ testing/attention/test_ring_buffer.py
62
+ testing/benchmarks/benchmark.py
63
+ testing/benchmarks/benchmark_phase2.py
64
+ testing/benchmarks/benchmark_true_ternary.py
65
+ testing/eval/eval_checkpoints.py
66
+ testing/eval/eval_generation.py
67
+ testing/eval/eval_metrics.py
68
+ testing/eval/test_eval.py
69
+ testing/kg/test_composite_head.py
70
+ testing/kg/test_kg_edges.py
71
+ testing/kg/test_kv_integration.py
72
+ testing/model/audio-comprehension.py
73
+ testing/model/health.py
74
+ testing/model/image-comprehension.py
75
+ testing/model/test-stp.py
76
+ testing/model/test_arb.py
77
+ testing/model/test_flash.py
78
+ testing/model/test_tscale.py
79
+ testing/model/text-comprehension.py
80
+ testing/model/video-comprehension.py
81
+ testing/vae/test_opensora_vae.py
82
+ tests/test_cross_modal.py
83
+ tests/test_lti.py
84
+ tests/test_moegraph_topk.py
85
+ tests/test_vae2d.py
86
+ tests/test_vae2d_sequencer.py
87
+ training/audio.py
88
+ training/diffusion.py
89
+ training/pretrain.py
90
+ training/text.py
91
+ training/vision.py
92
+ training/data/__init__.py
93
+ training/data/prepare_cc12m.py
94
+ training/data/prepare_fineweb.py
95
+ training/data/prepare_librispeech.py
96
+ training/data/prepare_starcoder.py
97
+ training/data/prepare_webvid.py
98
+ training/data/tokenize_from_hf.py
99
+ training/finetuning/__init__.py
100
+ training/finetuning/audio.py
101
+ training/finetuning/diffusion.py
102
+ training/finetuning/lora.py
103
+ training/finetuning/text.py
104
+ training/finetuning/vision.py
arbitor.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
arbitor.egg-info/requires.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.5
2
+ einops
3
+ tqdm
4
+
5
+ [cuda]
6
+ torch>=2.5
7
+ triton>=3.0
8
+
9
+ [dev]
10
+ pytest
11
+
12
+ [tilelang]
13
+ tilelang
14
+
15
+ [triton]
16
+ triton>=3.0
arbitor.egg-info/top_level.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ arbitor
2
+ docs
3
+ models
4
+ testing
5
+ tests
6
+ training
arbitor/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ARBitor β€” Any Relational Bit System.
2
+
3
+ Core package for the ARB ternary-weighted neural network.
4
+ Quick import: from arbitor import ARBModel, VOCAB
5
+ """
6
+ from .config import VOCAB, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, \
7
+ EMBEDDING_DIM, HIDDEN_DIM, CTX, SPECIAL_VOCAB, \
8
+ CODEBOOK_DIM, SHARED_VQ_SIZE, \
9
+ MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, \
10
+ MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES, MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM
11
+
12
+ from .kernel.ternary_scale import (
13
+ TernaryScaleTensor, TernaryRMSNorm, TScaleType, GROUP_SIZES,
14
+ _HAS_TRITON, _HAS_TILELANG,
15
+ )
16
+ from .kernel.flash_vq import FlashVQCodebook
17
+ from .kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters
18
+ from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
19
+
20
+ from .sequencers import ByteEmbedding, Sequencer, TextSequencer, VAE2DSequencer, VAEAudioSequencer, MultimodalSequencer
21
+ from .vq import SharedVQ
22
+ from .components import (
23
+ TernaryEmbeddingTable, TernaryVQCodebook,
24
+ GNNLoRAAdapter, HaltingUnit,
25
+ MemGram, MoEGraph,
26
+ ByteHead, OutputRouter,
27
+ LossComponents, LossWeights, StickyZoneSTE,
28
+ KGVQCodebook, CompositeProposalHead,
29
+ _BOUNDARY_TOKEN_MAP,
30
+ )
31
+ from .decoders import VideoHead, TalkerHead, MRFBlock, TinyNeuralCodec
32
+ from .main import ARBModel, _extract_boundary_from_input
33
+
34
+ # Re-export encoders
35
+ from .encoders import TinyNeuralCodec as Codec, AudioVQEncoder, load_vae, VAEWrapper
arbitor/attention/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ARB Attention β€” KV Ledger, MLA, Sliding Window Attention."""
2
+ from .ring_buffer import GPURingBuffer
3
+ from .kv_ledger import KVLedger
4
+ from .kq_cache import KQCache
5
+ from .mla import (MultiHeadLatentAttention, apply_rotary_emb,
6
+ precompute_freqs_cis)
7
+ from .context_attention import ContextAttentionScheduler
8
+ from .frame_buffer import TemporalFrameBuffer
9
+
10
+ __all__ = [
11
+ "GPURingBuffer", "KVLedger", "KQCache",
12
+ "MultiHeadLatentAttention", "apply_rotary_emb",
13
+ "precompute_freqs_cis", "ContextAttentionScheduler",
14
+ "TemporalFrameBuffer",
15
+ ]
arbitor/attention/context_attention.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Context Attention Scheduler β€” sliding window + full context orchestration.
2
+
3
+ Schedules 4 sliding window (d=64, CSA-compressed to d=16) and 4 full context
4
+ (d=32, HCA-compressed to d=8) MLA attention passes. Combines both via gating.
5
+
6
+ Pipeline: GNN output β†’ ContextAttentionScheduler β†’ MoE input
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ from ..config import HIDDEN_DIM, MLA_HCA_STRIDE
11
+ from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
12
+ from .mla import (MultiHeadLatentAttention, precompute_freqs_cis,
13
+ MLA_N_LAYERS, MLA_N_HEADS, MLA_SLIDE_DIM, MLA_FULL_DIM,
14
+ MLA_QK_NOPE_HEAD_DIM, MLA_QK_ROPE_HEAD_DIM,
15
+ MLA_V_HEAD_DIM, MLA_ROPE_THETA,
16
+ MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE)
17
+
18
+ SLIDING_WINDOW_SIZE = 32768
19
+ KV_LEDGER_SIZE = 262144
20
+
21
+
22
+ class ContextAttentionScheduler(nn.Module):
23
+ def __init__(self, dim=HIDDEN_DIM):
24
+ super().__init__()
25
+ self.dim = dim
26
+
27
+ # Slide layers with CSA compression (d=64 β†’ d=16) β€” half of total layers
28
+ n_layers_per_pass = max(1, MLA_N_LAYERS // 2)
29
+ self.slide_layers = nn.ModuleList([
30
+ MultiHeadLatentAttention(
31
+ dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
32
+ qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
33
+ qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
34
+ v_head_dim=MLA_V_HEAD_DIM,
35
+ csa_dim=MLA_CSA_DIM, hca_dim=None,
36
+ ) for _ in range(n_layers_per_pass)
37
+ ])
38
+ # CSA: embed motif IDs β†’ kv_lora_rank, then compress β†’ csa_dim
39
+ self.slide_embed = TernaryScaleTensor(1, MLA_SLIDE_DIM, tscale_type=TScaleType.T32)
40
+ self.slide_compress = TernaryScaleTensor(MLA_SLIDE_DIM, MLA_CSA_DIM, tscale_type=TScaleType.T32)
41
+
42
+ # Full context layers with HCA compression (d=32 β†’ d=8) β€” half of total layers
43
+ self.full_layers = nn.ModuleList([
44
+ MultiHeadLatentAttention(
45
+ dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_FULL_DIM,
46
+ qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
47
+ qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
48
+ v_head_dim=MLA_V_HEAD_DIM,
49
+ csa_dim=None, hca_dim=MLA_HCA_DIM,
50
+ ) for _ in range(n_layers_per_pass)
51
+ ])
52
+ # HCA: embed motif IDs β†’ kv_lora_rank, then compress β†’ hca_dim
53
+ self.full_embed = TernaryScaleTensor(1, MLA_FULL_DIM, tscale_type=TScaleType.T32)
54
+ self.full_compress = TernaryScaleTensor(MLA_FULL_DIM, MLA_HCA_DIM, tscale_type=TScaleType.T32)
55
+
56
+ self.gate = TernaryScaleTensor(dim, 1, tscale_type=TScaleType.T32)
57
+
58
+ self._freqs_cis = None
59
+ self._max_freq_len = 0
60
+
61
+ def _ensure_freqs(self, seq_len, device):
62
+ needed = max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE)
63
+ if self._freqs_cis is None or needed > self._max_freq_len:
64
+ self._max_freq_len = needed
65
+ self._freqs_cis = precompute_freqs_cis(
66
+ MLA_QK_ROPE_HEAD_DIM, needed, theta=MLA_ROPE_THETA
67
+ ).to(device)
68
+ return self._freqs_cis
69
+
70
+ def forward(self, x, kv_ledger, full_ledger=None, kq_cache=None):
71
+ bsz, seqlen, _ = x.shape
72
+ device = x.device
73
+ freqs_cis = self._ensure_freqs(seqlen, device)
74
+
75
+ full_ledger = full_ledger or kv_ledger
76
+
77
+ window_size = min(SLIDING_WINDOW_SIZE, kv_ledger.size) if kv_ledger.size > 0 else 0
78
+
79
+ out_slide = x
80
+ if window_size > 0:
81
+ start = max(0, kv_ledger.size - SLIDING_WINDOW_SIZE)
82
+ end = kv_ledger.size
83
+ slide_ids = kv_ledger.get_range(start, end).float().unsqueeze(-1)
84
+ # Embed to kv_lora_rank, then CSA compress to csa_dim
85
+ slide_latent = self.slide_embed(slide_ids)
86
+ csa_cache = self.slide_compress(slide_latent)
87
+ pe_cache = torch.zeros(csa_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)
88
+
89
+ for layer in self.slide_layers:
90
+ out_slide = layer(out_slide, slide_latent, pe_cache,
91
+ start_pos=0, freqs_cis=freqs_cis, mask=None,
92
+ csa_cache=csa_cache)
93
+
94
+ out_full = x
95
+ if full_ledger.size > 0:
96
+ full = full_ledger.get_sparse(stride=MLA_HCA_STRIDE)
97
+ full_ids = full.float().unsqueeze(-1)
98
+ full_latent = self.full_embed(full_ids)
99
+ hca_cache = self.full_compress(full_latent)
100
+ pe_cache = torch.zeros(hca_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)
101
+
102
+ for layer in self.full_layers:
103
+ out_full = layer(out_full, full_latent, pe_cache,
104
+ start_pos=0, freqs_cis=freqs_cis, mask=None,
105
+ hca_cache=hca_cache, hca_pe_cache=pe_cache)
106
+
107
+ gate = torch.sigmoid(self.gate(x.mean(dim=1, keepdim=True)))
108
+ out = gate * out_slide + (1 - gate) * out_full
109
+ return out
arbitor/attention/frame_buffer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TemporalFrameBuffer β€” ring buffer for video latents with HCA compression.
2
+
3
+ Stores the last N video latents (local) and maintains a compressed long-range
4
+ cache via TernaryScaleTensor projection. Used for conditioning video generation
5
+ on previous time steps.
6
+
7
+ Latent shape: [B, C, H', W'] where C=OPEN_SORA_LATENT_CHANNELS=4,
8
+ H'=VIDEO_HEIGHT=32, W'=VIDEO_WIDTH=32. Each "latent" is one 4-frame chunk.
9
+ """
10
+ import torch
11
+ import torch.nn as nn
12
+ from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
13
+ from .ring_buffer import GPURingBuffer
14
+ from ..config import FRAME_BUFFER_LOCAL_SIZE, FRAME_BUFFER_CACHE_STRIDE, \
15
+ OPEN_SORA_LATENT_CHANNELS, VIDEO_HEIGHT, VIDEO_WIDTH
16
+
17
+
18
+ class TemporalFrameBuffer(nn.Module):
19
+ def __init__(self, local_size=FRAME_BUFFER_LOCAL_SIZE,
20
+ cache_stride=FRAME_BUFFER_CACHE_STRIDE,
21
+ latent_channels=OPEN_SORA_LATENT_CHANNELS,
22
+ height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
23
+ tscale_type=TScaleType.T32):
24
+ super().__init__()
25
+ self.latent_channels = latent_channels
26
+ self.spatial_dim = height * width
27
+ self.latent_flat_dim = latent_channels * self.spatial_dim
28
+
29
+ self.local = GPURingBuffer(
30
+ max_size=local_size,
31
+ dtype=torch.float32,
32
+ dim=self.latent_flat_dim,
33
+ )
34
+
35
+ self.compress_proj = TernaryScaleTensor(
36
+ self.latent_flat_dim,
37
+ self.latent_flat_dim // 4,
38
+ tscale_type=tscale_type,
39
+ )
40
+ self.compressed_cache = []
41
+ self.cache_stride = cache_stride
42
+ self._frames_since_compress = 0
43
+
44
+ def append(self, latent):
45
+ B = latent.shape[0]
46
+ flat = latent.reshape(B, -1)
47
+ self.local.append(flat)
48
+
49
+ self._frames_since_compress += 1
50
+ if self._frames_since_compress >= self.cache_stride:
51
+ compressed = self.compress_proj(flat)
52
+ self.compressed_cache.append(compressed.detach())
53
+ self._frames_since_compress = 0
54
+
55
+ def get_local(self, n=None):
56
+ n = n or self.local.max_size
57
+ result = self.local.get_last_n(n)
58
+ if result.dim() == 0 or result.shape[0] == 0:
59
+ return torch.zeros(0, 1, self.latent_flat_dim)
60
+ if result.dim() == 1:
61
+ result = result.unsqueeze(0)
62
+ return result
63
+
64
+ def get_compressed_cache(self):
65
+ if not self.compressed_cache:
66
+ return torch.zeros(0, 1, self.latent_flat_dim // 4)
67
+ return torch.stack(self.compressed_cache, dim=0)
68
+
69
+ def get_conditioning(self, n_local=None):
70
+ return {
71
+ "local": self.get_local(n_local),
72
+ "compressed": self.get_compressed_cache(),
73
+ }
74
+
75
+ def reset(self):
76
+ self.local.reset()
77
+ self.compressed_cache = []
78
+ self._frames_since_compress = 0
arbitor/attention/kq_cache.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """KQ Cache β€” small ring buffer of last 8K motif IDs for O(1) peek.
2
+
3
+ Per D-64: Small ring buffer holding last 8K motif IDs. No compression - just raw IDs.
4
+ O(1) peek for fast motif lookup without MemGram query.
5
+
6
+ Per D-65: Updated after each ByteHead output append to ledger.
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ from ..config import KQ_CACHE_SIZE
11
+ from .ring_buffer import GPURingBuffer
12
+
13
+
14
+ class KQCache(nn.Module):
15
+ def __init__(self, max_size=KQ_CACHE_SIZE):
16
+ super().__init__()
17
+ self.ring = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)
18
+
19
+ def append(self, motif_id: int):
20
+ self.ring.append(torch.tensor(motif_id, dtype=torch.int32, device=self.ring.buffer.device))
21
+
22
+ def peek(self, n=1):
23
+ return self.ring.get_last_n(n)
24
+
25
+ @property
26
+ def size(self):
27
+ return self.ring.size
28
+
29
+ def reset(self):
30
+ self.ring.reset()
arbitor/attention/kv_ledger.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """KV Ledger β€” append-only ring buffer of motif IDs (int32), max 256K entries.
2
+
3
+ Per D-57: Append-only ring buffer of motif IDs (int32), max 256K entries.
4
+ When full, oldest entries are overwritten. Stored as flat tensor on GPU.
5
+
6
+ Per D-59: The ledger stores only what the model outputs (motif IDs),
7
+ not input prompts. Prompts go through VQ -> GNN -> Motif pipeline first.
8
+
9
+ KV is consumed by the ContextAttentionScheduler. Its output is injected into
10
+ MoEGraph, which then conditions the router and output heads through the shared
11
+ processed relational state.
12
+ """
13
+ import torch
14
+ import torch.nn as nn
15
+ from ..config import KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE
16
+ from .ring_buffer import GPURingBuffer
17
+
18
+
19
+ class KVLedger(nn.Module):
20
+ def __init__(self, max_size=KV_LEDGER_SIZE):
21
+ super().__init__()
22
+ self.ring = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)
23
+
24
+ def append(self, motif_id: int):
25
+ self.ring.append(torch.tensor(motif_id, dtype=torch.int32, device=self.ring.buffer.device))
26
+
27
+ def get_sliding_window(self, n=SLIDING_WINDOW_SIZE):
28
+ return self.ring.get_last_n(n)
29
+
30
+ def get_range(self, start, end):
31
+ n = end - start
32
+ if n <= 0 or start >= self.ring.size:
33
+ return torch.zeros(0, dtype=torch.int32, device=self.ring.buffer.device)
34
+ if start + n <= self.ring.max_size:
35
+ return self.ring.buffer[start:start + n].squeeze(-1)
36
+ first = self.ring.buffer[start:].squeeze(-1)
37
+ second = self.ring.buffer[:n - (self.ring.max_size - start)].squeeze(-1)
38
+ return torch.cat([first, second])
39
+
40
+ def get_sparse(self, stride=8):
41
+ size = self.ring.size
42
+ if size == 0:
43
+ return torch.zeros(0, dtype=torch.int32, device=self.ring.buffer.device)
44
+ all_vals = self.ring.get_all()
45
+ indices = torch.arange(0, size, stride, device=self.ring.buffer.device, dtype=torch.long)
46
+ indices = indices[indices < len(all_vals)]
47
+ return all_vals[indices]
48
+
49
+ @property
50
+ def size(self):
51
+ return self.ring.size
52
+
53
+ def __len__(self):
54
+ return self.ring.size
55
+
56
+ def reset(self):
57
+ self.ring.reset()
arbitor/attention/mla.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style).
2
+
3
+ Ternary-weighted. KV cache stores compressed latent at multiple levels:
4
+ - Base: MLA latent (d=kv_lora_rank, typically 64/32)
5
+ - CSA: Secondary compression (d_csa, e.g. 16) β€” 4x compression on cache
6
+ - HCA: Heavily compressed (d_hca, e.g. 8) β€” 8x compression, wider stride
7
+
8
+ Scores = q_nope_absorbed @ decompress(kv_cache) + q_pe @ pe_cache
9
+ """
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS
14
+ from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
15
+
16
+ MLA_N_HEADS = 32
17
+ MLA_QK_NOPE_HEAD_DIM = 96
18
+ MLA_QK_ROPE_HEAD_DIM = 32
19
+ MLA_V_HEAD_DIM = 96
20
+ MLA_ROPE_THETA = 10000.0
21
+ MLA_SLIDE_DIM = 64
22
+ MLA_FULL_DIM = 32
23
+
24
+
25
+ def apply_rotary_emb(x, freqs_cis):
26
+ x_complex = torch.view_as_complex(
27
+ x.float().reshape(*x.shape[:-1], -1, 2)
28
+ )
29
+ freqs = freqs_cis.unsqueeze(1).unsqueeze(0)
30
+ return torch.view_as_real(x_complex * freqs).flatten(-2).to(x.dtype)
31
+
32
+
33
+ def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA):
34
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
35
+ t = torch.arange(end, device=freqs.device)
36
+ freqs = torch.outer(t, freqs)
37
+ return torch.polar(torch.ones_like(freqs), freqs)
38
+
39
+
40
+ class MultiHeadLatentAttention(nn.Module):
41
+ def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
42
+ qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
43
+ v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536,
44
+ csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM,
45
+ tscale_type=TScaleType.T32):
46
+ super().__init__()
47
+ self.dim = dim
48
+ self.n_heads = n_heads
49
+ self.kv_lora_rank = kv_lora_rank
50
+ self.qk_nope_head_dim = qk_nope_head_dim
51
+ self.qk_rope_head_dim = qk_rope_head_dim
52
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
53
+ self.v_head_dim = v_head_dim
54
+ self.softmax_scale = self.qk_head_dim ** -0.5
55
+ self.max_seq_len = max_seq_len
56
+ self.csa_dim = csa_dim
57
+ self.hca_dim = hca_dim
58
+
59
+ self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
60
+ self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type)
61
+
62
+ combined_out = n_heads * (qk_nope_head_dim + v_head_dim)
63
+ self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type)
64
+ self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type)
65
+
66
+ # CSA: secondary compression (kv_lora_rank -> csa_dim)
67
+ if csa_dim and csa_dim < kv_lora_rank:
68
+ self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type)
69
+ self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type)
70
+ else:
71
+ self.csa_compress = None
72
+ self.csa_decompress = None
73
+
74
+ # HCA: heavily compressed (kv_lora_rank -> hca_dim)
75
+ if hca_dim and hca_dim < (csa_dim or kv_lora_rank):
76
+ self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type)
77
+ self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type)
78
+ else:
79
+ self.hca_compress = None
80
+ self.hca_decompress = None
81
+
82
+ def _compress(self, kv_cache, compress_proj):
83
+ """Compress kv_cache from kv_lora_rank to smaller dim."""
84
+ return compress_proj(kv_cache)
85
+
86
+ def _decompress(self, cache, decompress_proj):
87
+ """Decompress cache back to kv_lora_rank."""
88
+ return decompress_proj(cache)
89
+
90
+ def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
91
+ start_pos, seqlen, mask):
92
+ """Shared score computation for base, CSA, and HCA attention."""
93
+ n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
94
+ kv_flat = kv_flat[:n_keys]
95
+ pe_flat = pe_flat[:n_keys]
96
+ if n_keys == 0:
97
+ return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0)
98
+ scores = (
99
+ torch.einsum("bshc,btc->bsht",
100
+ q_nope_absorbed, kv_flat.unsqueeze(0))
101
+ + torch.einsum("bshr,btr->bsht",
102
+ q_pe, pe_flat.unsqueeze(0))
103
+ ) * self.softmax_scale
104
+
105
+ if mask is not None:
106
+ scores = scores + mask.unsqueeze(0).unsqueeze(0)
107
+ if mask is None and seqlen > 1:
108
+ causal = torch.triu(
109
+ torch.full((seqlen, n_keys), float('-inf'), device=q_pe.device),
110
+ diagonal=1 + start_pos
111
+ )
112
+ scores = scores + causal.unsqueeze(0).unsqueeze(2)
113
+ return scores
114
+
115
+ def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None,
116
+ csa_cache=None, hca_cache=None, hca_pe_cache=None):
117
+ bsz, seqlen, _ = x.size()
118
+
119
+ q = self.wq(self.wq_norm(x))
120
+ q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
121
+ q_nope, q_pe = torch.split(
122
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
123
+
124
+ if freqs_cis is not None:
125
+ q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen])
126
+
127
+ wkv_b = self.wkv_b._get_T() * self.wkv_b._get_S()
128
+ wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)
129
+
130
+ q_nope_absorbed = torch.einsum(
131
+ "bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
132
+
133
+ n_cache = min(kv_cache.shape[0], pe_cache.shape[0])
134
+ kv_flat = kv_cache[:n_cache]
135
+ pe_flat = pe_cache[:n_cache]
136
+
137
+ # Decompress CSA cache if provided (replaces base kv_cache)
138
+ if csa_cache is not None and self.csa_decompress is not None:
139
+ n_csa = min(csa_cache.shape[0], pe_flat.shape[0])
140
+ kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress)
141
+ pe_flat = pe_flat[:n_csa]
142
+
143
+ # Base attention (exact, CSA-compressed if applicable)
144
+ scores = self._compute_scores(
145
+ q_nope_absorbed, q_pe, kv_flat, pe_flat,
146
+ start_pos, seqlen, mask,
147
+ )
148
+ scores = scores.softmax(dim=-1, dtype=torch.float32)
149
+
150
+ attn_out = torch.einsum(
151
+ "bsht,btc->bshc", scores, kv_flat.unsqueeze(0))
152
+
153
+ # HCA long-range attention (heavily compressed, strided)
154
+ hca_out = None
155
+ if hca_cache is not None and self.hca_decompress is not None:
156
+ hca_kv = self._decompress(hca_cache, self.hca_decompress)
157
+ if hca_pe_cache is None:
158
+ hca_pe = pe_cache[::MLA_HCA_STRIDE]
159
+ else:
160
+ hca_pe = hca_pe_cache
161
+ n_hca = min(hca_kv.shape[0], hca_pe.shape[0])
162
+ hca_kv = hca_kv[:n_hca]
163
+ hca_pe = hca_pe[:n_hca]
164
+ hca_scores = self._compute_scores(
165
+ q_nope_absorbed, q_pe, hca_kv, hca_pe,
166
+ start_pos, seqlen, mask=None,
167
+ )
168
+ hca_scores = hca_scores.softmax(dim=-1, dtype=torch.float32)
169
+ hca_out = torch.einsum(
170
+ "bsht,btc->bshc", hca_scores, hca_kv.unsqueeze(0))
171
+ attn_out = attn_out + hca_out
172
+
173
+ attn_unproj = torch.einsum(
174
+ "bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:])
175
+
176
+ return self.wo(attn_unproj.flatten(2))
arbitor/attention/ring_buffer.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPURingBuffer β€” generic GPU ring buffer utility.
2
+
3
+ O(1) append via circular pointer, chronological get_last_n with wrap handling.
4
+ All storage via register_buffer for device movement and state_dict serialization.
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class GPURingBuffer(nn.Module):
11
+ def __init__(self, max_size: int, dtype: torch.dtype = torch.int32, dim: int = 1):
12
+ super().__init__()
13
+ self.max_size = max_size
14
+ self.ptr = 0
15
+ self.size = 0
16
+ buffer_shape = (max_size, dim if dim > 1 else 1)
17
+ self.register_buffer("buffer", torch.zeros(buffer_shape, dtype=dtype))
18
+
19
+ def append(self, x):
20
+ if not isinstance(x, torch.Tensor):
21
+ x = torch.tensor(x, dtype=self.buffer.dtype, device=self.buffer.device)
22
+ if self.buffer.dim() == 2 and x.dim() == 0:
23
+ x = x.view(1)
24
+ self.buffer[self.ptr] = x
25
+ self.ptr = (self.ptr + 1) % self.max_size
26
+ self.size = min(self.size + 1, self.max_size)
27
+
28
+ def get_last_n(self, n: int):
29
+ n = min(n, self.size)
30
+ if n == 0:
31
+ return torch.zeros(0, *self.buffer.shape[1:], dtype=self.buffer.dtype, device=self.buffer.device)
32
+ start = (self.ptr - n) % self.max_size
33
+ if start + n <= self.max_size:
34
+ result = self.buffer[start:start + n]
35
+ else:
36
+ first = self.buffer[start:]
37
+ second = self.buffer[:n - (self.max_size - start)]
38
+ result = torch.cat([first, second], dim=0)
39
+ if result.dim() > 1 and result.shape[1] == 1:
40
+ result = result.squeeze(-1)
41
+ return result
42
+
43
+ def get_all(self):
44
+ return self.get_last_n(self.size)
45
+
46
+ def reset(self):
47
+ self.buffer.zero_()
48
+ self.ptr = 0
49
+ self.size = 0
arbitor/components.py ADDED
@@ -0,0 +1,1218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Components β€” core neural network modules for the ARB system."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _COMPONENT_CONTEXT, _HAS_TRITON
7
+ try:
8
+ from .kernel.ternary_scale import _TritonTernaryEmbedFn
9
+ except ImportError:
10
+ _TritonTernaryEmbedFn = None
11
+ from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
12
+ from dataclasses import dataclass, field, fields
13
+ from math import ceil as _ceil, log2 as _log2
14
+ from transformers import AutoModel, AutoFeatureExtractor
15
+ from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, SPECIAL_VOCAB, CODEBOOK_DIM, CODEBOOK_SIZE, FFN_HIDDEN, CTX, THRESHOLD, KG_EMA_ALPHA, KG_REQUANT_EVERY, KG_TERNARY_THRESHOLD, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, KGVQ_DECAY, KGVQ_COMMITMENT_WEIGHT, KGVQ_DEAD_CODE_THRESHOLD, K_MAX_COMPOSITES, MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, MG_WORKSPACE_DIM, BYTEHEAD_ACT_MAX_ITERS, BYTEHEAD_ACT_HALT_CONSECUTIVE
16
+
17
+ _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
18
+
19
+ from .sequencers import ByteEmbedding
20
+
21
+
22
+ @dataclass
23
+ class LossWeights:
24
+ lm: float = 1.0
25
+ vq_commitment: float = 1.0
26
+ moe_aux: float = 1.0
27
+ graph_l1: float = 0.001
28
+ graph_ponder: float = 1.0
29
+ moe_ponder: float = 1.0
30
+ moegraph_ponder: float = 1.0
31
+ memgram_decay_reg: float = 0.01
32
+ composite_vq: float = 1.0
33
+
34
+
35
+ @dataclass
36
+ class LossComponents:
37
+ lm: torch.Tensor = None
38
+ vq_commitment: torch.Tensor = None
39
+ moe_aux: torch.Tensor = None
40
+ graph_l1: torch.Tensor = None
41
+ graph_ponder: torch.Tensor = None
42
+ moe_ponder: torch.Tensor = None
43
+ moegraph_ponder: torch.Tensor = None
44
+ memgram_decay_reg: torch.Tensor = None
45
+ composite_vq: torch.Tensor = None
46
+ weights: LossWeights = field(default_factory=LossWeights)
47
+
48
+ @property
49
+ def total(self) -> torch.Tensor:
50
+ w = self.weights
51
+ loss = None
52
+
53
+ def add_component(current, weight, component):
54
+ if component is None:
55
+ return current
56
+ weighted = weight * component
57
+ return weighted if current is None else current + weighted
58
+
59
+ loss = add_component(loss, w.lm, self.lm)
60
+ loss = add_component(loss, w.vq_commitment, self.vq_commitment)
61
+ loss = add_component(loss, w.moe_aux, self.moe_aux)
62
+ loss = add_component(loss, w.graph_l1, self.graph_l1)
63
+ loss = add_component(loss, w.graph_ponder, self.graph_ponder)
64
+ loss = add_component(loss, w.moe_ponder, self.moe_ponder)
65
+ loss = add_component(loss, w.moegraph_ponder, self.moegraph_ponder)
66
+ loss = add_component(loss, w.memgram_decay_reg, self.memgram_decay_reg)
67
+ loss = add_component(loss, w.composite_vq, self.composite_vq)
68
+ if loss is None:
69
+ raise ValueError("LossComponents.total requested with no active loss tensors")
70
+ return loss
71
+
72
+ @property
73
+ def active_fields(self) -> list[tuple[str, torch.Tensor, float]]:
74
+ result = []
75
+ for field in fields(self):
76
+ name = field.name
77
+ if name == 'weights':
78
+ continue
79
+ tensor = getattr(self, name)
80
+ if tensor is not None:
81
+ weight = getattr(self.weights, name)
82
+ result.append((name, tensor, weight))
83
+ return result
84
+
85
+ def log(self, writer, step, prefix="loss"):
86
+ writer.add_scalar(f"{prefix}/total", self.total.item(), step)
87
+ if self.lm is not None:
88
+ writer.add_scalar(f"{prefix}/lm", self.lm.item(), step)
89
+ if self.vq_commitment is not None:
90
+ writer.add_scalar(f"{prefix}/vq_commitment", self.vq_commitment.item(), step)
91
+ if self.moe_aux is not None:
92
+ writer.add_scalar(f"{prefix}/moe_aux", self.moe_aux.item(), step)
93
+ if self.graph_l1 is not None:
94
+ writer.add_scalar(f"{prefix}/graph_l1", self.graph_l1.item(), step)
95
+ if self.graph_ponder is not None:
96
+ writer.add_scalar(f"{prefix}/graph_ponder", self.graph_ponder.item(), step)
97
+ if self.moe_ponder is not None:
98
+ writer.add_scalar(f"{prefix}/moe_ponder", self.moe_ponder.item(), step)
99
+ if self.moegraph_ponder is not None:
100
+ writer.add_scalar(f"{prefix}/moegraph_ponder", self.moegraph_ponder.item(), step)
101
+ if self.memgram_decay_reg is not None:
102
+ writer.add_scalar(f"{prefix}/memgram_decay_reg", self.memgram_decay_reg.item(), step)
103
+ if self.composite_vq is not None:
104
+ writer.add_scalar(f"{prefix}/composite_vq", self.composite_vq.item(), step)
105
+
106
+ def backward(self, retain_graph=False):
107
+ self.total.backward(retain_graph=retain_graph)
108
+
109
+
110
+ class StickyZoneSTE(torch.autograd.Function):
111
+ @staticmethod
112
+ def forward(ctx, w, threshold):
113
+ ctx.save_for_backward(w, torch.tensor(threshold))
114
+ return w.sign() * (w.abs() > threshold).to(w.dtype)
115
+
116
+ @staticmethod
117
+ def backward(ctx, grad_output):
118
+ w, threshold_tensor = ctx.saved_tensors
119
+ threshold = threshold_tensor.item()
120
+ ratio = torch.clamp(w.abs() / threshold, 0.0, 1.0)
121
+ return grad_output * ratio, None
122
+
123
+
124
+ class TernaryEmbeddingTable(nn.Module):
125
+ def __init__(self, num_embeddings, embedding_dim, tscale_type=TScaleType.T32,
126
+ init_std=0.02, threshold=0.05, normalize=False):
127
+ super().__init__()
128
+ self.num_embeddings = num_embeddings
129
+ self.embedding_dim = embedding_dim
130
+ self.tscale_type = tscale_type
131
+ init_threshold = min(float(threshold), 0.5 * float(init_std)) if init_std > 0 else threshold
132
+ self.threshold = init_threshold
133
+ self.normalize = normalize
134
+ self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
135
+ self.sparse_threshold = 65_536
136
+
137
+ if num_embeddings >= self.sparse_threshold:
138
+ n_trits = num_embeddings * embedding_dim
139
+ n_packed = _ceil_div(n_trits, 5)
140
+ packed_T = torch.randint(0, 243, (n_packed,), dtype=torch.uint8)
141
+ T_pad = n_packed * 5 - n_trits
142
+ gpr = _ceil_div(embedding_dim, self.group_size)
143
+ init_exp = int(round(_log2(max(init_std, 1e-8))))
144
+ self.register_buffer("T_packed", packed_T)
145
+ self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
146
+ self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
147
+ self.register_buffer(
148
+ "E",
149
+ torch.full((num_embeddings * gpr,), init_exp, dtype=torch.int8),
150
+ )
151
+ self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
152
+ self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
153
+ self._ema_alpha: float = 0.1
154
+ self._loss_temp_scale: float = 1.0
155
+ return
156
+
157
+ w_init = torch.randn(num_embeddings, embedding_dim) * init_std
158
+ T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
159
+ packed_T, _, T_pad = pack_ternary(T_init)
160
+ self.register_buffer("T_packed", packed_T)
161
+ self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
162
+ self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
163
+
164
+ gpr = _ceil_div(embedding_dim, self.group_size)
165
+ total_in = gpr * self.group_size
166
+ padded = torch.zeros(num_embeddings, total_in)
167
+ padded[:, :embedding_dim] = w_init.abs()
168
+ grouped = padded.view(num_embeddings, gpr, self.group_size)
169
+ E_vals = torch.where(grouped.mean(dim=2) > 0, grouped.mean(dim=2), torch.ones(num_embeddings, gpr))
170
+ self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
171
+ self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
172
+ self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
173
+ self._ema_alpha: float = 0.1
174
+ self._loss_temp_scale: float = 1.0
175
+
176
+ def _get_T(self):
177
+ return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
178
+
179
+ def _get_T_rows(self, indices):
180
+ indices = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
181
+ dim = self.embedding_dim
182
+ cols = torch.arange(dim, device=indices.device, dtype=torch.long)
183
+ lin = indices[:, None] * dim + cols[None, :]
184
+ pack_idx = lin // 5
185
+ trit_pos = lin - pack_idx * 5
186
+ packed = self.T_packed[pack_idx].to(torch.long)
187
+ divisors = torch.tensor([1, 3, 9, 27, 81], device=indices.device, dtype=torch.long)
188
+ code = (packed // divisors[trit_pos]) % 3
189
+ return (code.to(torch.int8) - 1)
190
+
191
+ def _expand_E_rows(self, indices):
192
+ indices = indices.reshape(-1).to(device=self.E.device, dtype=torch.long)
193
+ gpr = _ceil_div(self.embedding_dim, self.group_size)
194
+ E_rows = self.E.view(self.num_embeddings, gpr)[indices]
195
+ E_exp = E_rows.repeat_interleave(self.group_size, dim=1)
196
+ return E_exp[:, :self.embedding_dim]
197
+
198
+ @torch.no_grad()
199
+ def _set_T_rows(self, row_indices, rows):
200
+ row_indices = row_indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
201
+ rows = rows.to(device=self.T_packed.device, dtype=torch.int8).reshape(row_indices.numel(), self.embedding_dim)
202
+ divisors = [1, 3, 9, 27, 81]
203
+ for row_pos, row_idx in enumerate(row_indices.tolist()):
204
+ row = rows[row_pos]
205
+ for col in range(self.embedding_dim):
206
+ lin = row_idx * self.embedding_dim + col
207
+ pack_idx = lin // 5
208
+ trit_pos = lin - pack_idx * 5
209
+ divisor = divisors[trit_pos]
210
+ old = int(self.T_packed[pack_idx].item())
211
+ old_code = (old // divisor) % 3
212
+ new_code = int(row[col].item()) + 1
213
+ if old_code != new_code:
214
+ self.T_packed[pack_idx] = old - old_code * divisor + new_code * divisor
215
+
216
+ def _expand_E(self):
217
+ out_dim, in_dim = tuple(self._T_shape.tolist())
218
+ gpr = _ceil_div(in_dim, self.group_size)
219
+ E_2d = self.E.view(out_dim, gpr)
220
+ E_exp = E_2d.repeat_interleave(self.group_size, dim=1)
221
+ return E_exp[:, :in_dim]
222
+
223
+ def _ensure_E_accum(self):
224
+ if not hasattr(self, "E_accum"):
225
+ self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
226
+ elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device:
227
+ self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
228
+ return self.E_accum
229
+
230
+ def forward(self, indices):
231
+ use_sparse = self.num_embeddings >= self.sparse_threshold
232
+ if use_sparse:
233
+ idx_flat = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
234
+ T_rows = self._get_T_rows(idx_flat)
235
+ E_exp = self._expand_E_rows(idx_flat)
236
+ w_eff = torch.exp2(E_exp.float()) * T_rows.float()
237
+ w_eff_grad = w_eff.detach().requires_grad_(torch.is_grad_enabled())
238
+ if torch.is_grad_enabled():
239
+ comp_name, _ = _COMPONENT_CONTEXT.get()
240
+ def capture_sparse_grad(grad):
241
+ suffix = f"_{comp_name}" if comp_name is not None else ""
242
+ setattr(self, f"_hook_sparse_indices{suffix}", idx_flat.detach())
243
+ setattr(self, f"_hook_sparse_grad_sign{suffix}", grad.reshape(-1, self.embedding_dim).sign().to(torch.int8).detach())
244
+ setattr(self, f"_hook_sparse_T{suffix}", T_rows.detach())
245
+ w_eff_grad.register_hook(capture_sparse_grad)
246
+ out = w_eff_grad.reshape(*indices.shape, self.embedding_dim)
247
+ return F.normalize(out, dim=-1) if self.normalize else out
248
+ if indices.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
249
+ dummy = torch.zeros(1, device=indices.device, requires_grad=True)
250
+ out = _TritonTernaryEmbedFn.apply(indices, dummy, self)
251
+ else:
252
+ T = self._get_T()
253
+ w_eff = torch.exp2(self._expand_E().float()) * T.float()
254
+ w_eff_grad = w_eff.detach().requires_grad_(True)
255
+ self._hook_T = T
256
+ def capture_w_grad(grad_w):
257
+ self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
258
+ w_eff_grad.register_hook(capture_w_grad)
259
+ out = F.embedding(indices, w_eff_grad)
260
+ return F.normalize(out, dim=-1) if self.normalize else out
261
+
262
+ def ternary_step(self, accum_threshold=3):
263
+ if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
264
+ return self._sparse_ternary_step(accum_threshold=accum_threshold)
265
+ if hasattr(self, "_hook_grad_T_sign"):
266
+ if hasattr(self, "_accumulate_corr_from_grad_sign"):
267
+ self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
268
+ del self._hook_grad_T_sign
269
+
270
+ def update_E(self, loss_signal=None):
271
+ if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
272
+ return self._sparse_update_E(loss_signal=loss_signal)
273
+
274
+ @torch.no_grad()
275
+ def _sparse_ternary_step(self, accum_threshold=3):
276
+ indices = self._hook_sparse_indices.to(device=self.T_accum.device, dtype=torch.long)
277
+ grad_sign = self._hook_sparse_grad_sign.to(device=self.T_accum.device, dtype=torch.int16)
278
+ if indices.numel() == 0:
279
+ return
280
+ unique, inverse = torch.unique(indices, return_inverse=True)
281
+ grad_sum = torch.zeros(unique.numel(), self.embedding_dim, device=self.T_accum.device, dtype=torch.int16)
282
+ grad_sum.index_add_(0, inverse, grad_sign)
283
+ grad_step = grad_sum.sign().to(torch.int16) * int(getattr(self, "_t_accum_step", 1))
284
+ current = self.T_accum[unique].to(torch.int16)
285
+ updated = torch.clamp(current - grad_step, -128, 127).to(torch.int8)
286
+
287
+ pgt = getattr(self, "per_group_threshold", None)
288
+ if pgt is not None:
289
+ gpr = _ceil_div(self.embedding_dim, self.group_size)
290
+ threshold = pgt.view(self.num_embeddings, gpr)[unique]
291
+ threshold = threshold.unsqueeze(-1).expand(unique.numel(), gpr, self.group_size)
292
+ threshold = threshold.reshape(unique.numel(), gpr * self.group_size)[:, :self.embedding_dim]
293
+ threshold = threshold.to(updated.device)
294
+ flip_up = updated > threshold
295
+ flip_down = updated < -threshold
296
+ else:
297
+ flip_up = updated > accum_threshold
298
+ flip_down = updated < -accum_threshold
299
+ self._had_flip = bool((flip_up | flip_down).any().item())
300
+ if self._had_flip:
301
+ rows = self._get_T_rows(unique).to(updated.device)
302
+ rows = torch.where(flip_up, torch.ones_like(rows), torch.where(flip_down, -torch.ones_like(rows), rows))
303
+ self._set_T_rows(unique, rows)
304
+ updated = torch.where(flip_up | flip_down, torch.zeros_like(updated), updated)
305
+ self.T_accum[unique] = updated
306
+ del self._hook_sparse_indices
307
+ del self._hook_sparse_grad_sign
308
+ if hasattr(self, "_hook_sparse_T"):
309
+ del self._hook_sparse_T
310
+
311
+ @torch.no_grad()
312
+ def _sparse_update_E(self, loss_signal=None):
313
+ indices = self._hook_sparse_indices.to(device=self.E.device, dtype=torch.long)
314
+ grad_sign = self._hook_sparse_grad_sign.to(device=self.E.device, dtype=torch.int16)
315
+ T_rows = self._hook_sparse_T if hasattr(self, "_hook_sparse_T") else self._get_T_rows(indices)
316
+ T_rows = T_rows.to(device=self.E.device, dtype=torch.int16)
317
+ if indices.numel() == 0:
318
+ return
319
+ unique, inverse = torch.unique(indices, return_inverse=True)
320
+ gpr = _ceil_div(self.embedding_dim, self.group_size)
321
+ total_in = gpr * self.group_size
322
+ signed = grad_sign * T_rows
323
+ grouped = F.pad(signed, (0, total_in - self.embedding_dim)).view(indices.numel(), gpr, self.group_size)
324
+ score = grouped.sum(dim=2)
325
+ delta = torch.where(
326
+ score > 0,
327
+ torch.full_like(score, -1, dtype=torch.int16),
328
+ torch.where(score < 0, torch.ones_like(score, dtype=torch.int16), torch.zeros_like(score, dtype=torch.int16)),
329
+ )
330
+ delta_sum = torch.zeros(unique.numel(), gpr, device=self.E.device, dtype=torch.int16)
331
+ delta_sum.index_add_(0, inverse, delta)
332
+ delta_sign = delta_sum.sign()
333
+ e_idx = unique[:, None] * gpr + torch.arange(gpr, device=self.E.device, dtype=torch.long)[None, :]
334
+ accum = torch.clamp(self.E_accum[e_idx].to(torch.int16) + delta_sign, -128, 127)
335
+ threshold = int(getattr(self, "_e_accum_threshold", 4))
336
+ step = torch.where(
337
+ accum >= threshold,
338
+ torch.ones_like(accum, dtype=torch.int16),
339
+ torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
340
+ )
341
+ self.E[e_idx] = torch.clamp(self.E[e_idx].to(torch.int16) + step, -128, 127).to(torch.int8)
342
+ self.E_accum[e_idx] = (accum - step * threshold).to(torch.int8)
343
+
344
+
345
+
346
+ class TernaryVQCodebook(nn.Module):
347
+ def __init__(self, codebook_size, codebook_dim, commitment_weight=1.0,
348
+ tscale_type=TScaleType.T32, exact_lookup_max=16384,
349
+ candidate_count=256):
350
+ super().__init__()
351
+ self.codebook_size = codebook_size
352
+ self.codebook_dim = codebook_dim
353
+ self.commitment_weight = commitment_weight
354
+ self.exact_lookup_max = exact_lookup_max
355
+ self.candidate_count = candidate_count
356
+ self.threshold_ema_dead_code = 2
357
+ self.table = TernaryEmbeddingTable(codebook_size, codebook_dim, tscale_type=tscale_type, normalize=True)
358
+ self.register_buffer("cluster_size", torch.zeros(codebook_size, dtype=torch.int16))
359
+
360
+ @property
361
+ def embed(self):
362
+ idx = torch.arange(self.codebook_size, device=self.table.T_packed.device)
363
+ return self.table(idx)
364
+
365
+ def _candidate_ids(self, flat):
366
+ c = min(self.candidate_count, self.codebook_size)
367
+ take = min(flat.shape[1], 16)
368
+ primes = torch.tensor(
369
+ [1009, 9176, 6361, 5333, 4447, 3469, 2531, 1613,
370
+ 811, 421, 211, 109, 59, 31, 17, 7],
371
+ device=flat.device, dtype=torch.float32,
372
+ )[:take]
373
+ signed = torch.sign(flat[:, :take].float())
374
+ base = torch.abs(torch.round((signed * primes).sum(dim=1) * 104729)).to(torch.long)
375
+ offsets = torch.arange(c, device=flat.device, dtype=torch.long)
376
+ stride = 2_654_435_761
377
+ return (base[:, None] + offsets[None, :] * stride) % self.codebook_size
378
+
379
+ def _lookup(self, flat):
380
+ if self.codebook_size <= self.exact_lookup_max:
381
+ x_norm = F.normalize(flat.float(), dim=-1)
382
+ codebook = self.embed.to(device=flat.device)
383
+ sim = x_norm @ codebook.T
384
+ indices = sim.argmax(dim=-1)
385
+ quantized = codebook[indices]
386
+ return quantized, indices
387
+
388
+ candidate_ids = self._candidate_ids(flat)
389
+ x_norm = F.normalize(flat.float(), dim=-1)
390
+ n, c, d = flat.shape[0], candidate_ids.shape[1], flat.shape[1]
391
+ chunk = 64
392
+ quantized = torch.empty_like(flat)
393
+ indices = torch.empty(n, dtype=torch.long, device=flat.device)
394
+ for start in range(0, n, chunk):
395
+ end = min(start + chunk, n)
396
+ chunk_ids = candidate_ids[start:end]
397
+ chunk_vecs = self.table(chunk_ids).float()
398
+ chunk_norm = F.normalize(chunk_vecs, dim=-1)
399
+ chunk_sim = (chunk_norm * x_norm[start:end].unsqueeze(1)).sum(dim=-1)
400
+ chunk_best = chunk_sim.argmax(dim=-1)
401
+ indices[start:end] = candidate_ids[start:end].gather(1, chunk_best.unsqueeze(1)).squeeze(1)
402
+ quantized[start:end] = chunk_vecs[torch.arange(end - start, device=flat.device), chunk_best]
403
+ return quantized, indices
404
+
405
+ def forward(self, x):
406
+ orig_shape = x.shape
407
+ flat = x.reshape(-1, self.codebook_dim)
408
+ quantized, indices = self._lookup(flat)
409
+ commitment = self.commitment_weight * (
410
+ F.mse_loss(flat.float(), quantized.detach().float())
411
+ + 0.25 * F.mse_loss(quantized.float(), flat.detach().float())
412
+ )
413
+ quantized = flat + (quantized - flat).detach()
414
+ with torch.no_grad():
415
+ unique, counts = torch.unique(indices, return_counts=True)
416
+ current = self.cluster_size[unique].to(torch.int32)
417
+ updated = torch.clamp(current + counts.to(device=current.device, dtype=torch.int32), 0, 32767).to(torch.int16)
418
+ self.cluster_size[unique] = updated
419
+ return quantized.reshape(orig_shape), indices.reshape(orig_shape[:-1]), commitment
420
+
421
+
422
+ class GNNLoRAAdapter(nn.Module):
423
+ def __init__(self, dim, rank=32, max_hops=4):
424
+ super().__init__()
425
+ self.max_hops = max_hops
426
+ self.down = TernaryScaleTensor(dim, rank, tscale_type=TScaleType.T32)
427
+ self.up = TernaryScaleTensor(rank, dim, tscale_type=TScaleType.T32)
428
+ self.scale = TernaryEmbeddingTable(max_hops, rank, tscale_type=TScaleType.T32)
429
+
430
+ def forward(self, x, hop_t):
431
+ t_idx = min(hop_t, self.max_hops - 1)
432
+ s = self.scale(torch.tensor(t_idx, device=x.device))
433
+ return self.up(self.down(x) * s)
434
+
435
+
436
+ class HaltingUnit(nn.Module):
437
+ def __init__(self, dim, tscale_type=TScaleType.T32):
438
+ super().__init__()
439
+ self.proj = TernaryScaleTensor(dim, 1, tscale_type=tscale_type)
440
+ self.norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
441
+
442
+ def forward(self, x):
443
+ return torch.sigmoid(self.proj(self.norm(x)))
444
+
445
+
446
+ class _NgramHashMapping:
447
+ """N-gram hash mapping with CPU offloading (Spider Engram style).
448
+
449
+ Hashes token sequences to fixed-size embedding indices. Hash computation
450
+ runs on CPU via numpy, O(1) per token via precomputed multipliers.
451
+ """
452
+
453
+ def __init__(self, max_ngram_size, num_heads, table_size_base, layer_seed=0):
454
+ self.max_ngram_size = max_ngram_size
455
+ self.num_heads = num_heads
456
+ self.num_ngram_orders = max_ngram_size - 1
457
+
458
+ import numpy as np
459
+ PRIME_1 = 10007
460
+ g = torch.Generator()
461
+ g.manual_seed(int(layer_seed + PRIME_1 * int(layer_seed)))
462
+ r = torch.randint(0, 1 << 30, (max_ngram_size,), generator=g, dtype=torch.int64)
463
+ self.multipliers = r.numpy() * 2 + 1
464
+
465
+ seen_primes = set()
466
+ self.prime_table_sizes = []
467
+ for _ in range(self.num_ngram_orders):
468
+ head_sizes = []
469
+ ps = table_size_base - 1
470
+ for _ in range(num_heads):
471
+ p = self._next_prime(ps, seen_primes)
472
+ seen_primes.add(p)
473
+ head_sizes.append(p)
474
+ ps = p
475
+ self.prime_table_sizes.append(head_sizes)
476
+
477
+ self.all_head_sizes = [s for sub in self.prime_table_sizes for s in sub]
478
+ offsets = [0]
479
+ for s in self.all_head_sizes[:-1]:
480
+ offsets.append(offsets[-1] + s)
481
+ self.offsets_arr = offsets
482
+ self.total_slots = sum(self.all_head_sizes)
483
+
484
+ @staticmethod
485
+ def _next_prime(n, seen):
486
+ while n in seen or not _is_prime(n):
487
+ n -= 1
488
+ return n
489
+
490
+ def compute_hashes(self, token_ids):
491
+ import numpy as np
492
+ x = token_ids.cpu().numpy().astype(np.int64)
493
+ B, T = x.shape
494
+
495
+ shifts = [x]
496
+ for k in range(1, self.max_ngram_size):
497
+ shifts.append(np.pad(x, ((0, 0), (k, 0)), constant_values=0)[:, :T])
498
+
499
+ all_hashes = []
500
+ for order_idx in range(self.num_ngram_orders):
501
+ n = order_idx + 2
502
+ mix = shifts[0] * self.multipliers[0]
503
+ for k in range(1, n):
504
+ mix = np.bitwise_xor(mix, shifts[k].astype(np.int64) * self.multipliers[k])
505
+ for j, ms in enumerate(self.prime_table_sizes[order_idx]):
506
+ all_hashes.append((mix % ms).astype(np.int64, copy=False))
507
+
508
+ result = np.stack(all_hashes, axis=2)
509
+ return torch.from_numpy(result).to(device=token_ids.device)
510
+
511
+
512
+ def _is_prime(n):
513
+ if n < 2:
514
+ return False
515
+ import math
516
+ for i in range(2, int(math.sqrt(n)) + 1):
517
+ if n % i == 0:
518
+ return False
519
+ return True
520
+
521
+
522
+ class MemGram(nn.Module):
523
+ """Engram-style associative memory with O(1) hashed lookup (CPU offloaded).
524
+
525
+ Features:
526
+ - O(1) hash -> index -> embedding lookup (no search, no decay for retrieval)
527
+ - CPU-offloaded hash computation (numpy)
528
+ - Single offset-stacked embedding table (not per-head tables)
529
+ - Gated retrieval: sigmoid(Q*K/sqrt(d)) gates the memory read
530
+ - Depthwise conv1d processes retrieved memory (Engram-style)
531
+ - No strength/decay buffers (decay is handled by GraphMoE usage frequency)
532
+ - MemGram lookups do NOT affect KG decaying (separate mechanisms)
533
+ """
534
+
535
+ def __init__(self, struct_primes=[64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997,
536
+ 65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071],
537
+ conv_primes=[8009, 8011, 8017, 8039],
538
+ embed_dim=64, hidden_dim=HIDDEN_DIM, key_dim=32,
539
+ max_ngram_size=3, num_hash_heads=4, layer_seed=0):
540
+ super().__init__()
541
+ self.embed_dim = embed_dim
542
+ self.key_dim = key_dim
543
+ self.hidden_dim = hidden_dim
544
+ self.n_struct_heads = len(struct_primes)
545
+ self.n_conv_heads = len(conv_primes)
546
+
547
+ self.struct_hash = _NgramHashMapping(
548
+ max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
549
+ table_size_base=struct_primes[0], layer_seed=layer_seed,
550
+ )
551
+ self.conv_hash = _NgramHashMapping(
552
+ max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
553
+ table_size_base=conv_primes[0], layer_seed=layer_seed + 1000,
554
+ )
555
+
556
+ total_heads = self.struct_hash.num_ngram_orders * num_hash_heads
557
+ self.total_mem_dim = total_heads * embed_dim
558
+
559
+ total_slots = self.struct_hash.total_slots + self.conv_hash.total_slots
560
+ self.mem_embed = nn.Embedding(total_slots, embed_dim)
561
+
562
+ self.k_proj = nn.Linear(self.total_mem_dim, key_dim, bias=False)
563
+ self.q_proj = nn.Linear(hidden_dim, key_dim, bias=False)
564
+ self.v_proj = nn.Linear(self.total_mem_dim, hidden_dim, bias=False)
565
+
566
+ with torch.no_grad():
567
+ self.v_proj.weight.zero_()
568
+
569
+ self.conv_norm = nn.RMSNorm(hidden_dim)
570
+ self.conv = nn.Conv1d(
571
+ hidden_dim, hidden_dim,
572
+ kernel_size=4, padding=9, dilation=3, groups=hidden_dim,
573
+ )
574
+ with torch.no_grad():
575
+ self.conv.weight.zero_()
576
+ if self.conv.bias is not None:
577
+ self.conv.bias.zero_()
578
+
579
+ def _retrieve(self, token_ids, hash_mapping):
580
+ hash_ids = hash_mapping.compute_hashes(token_ids)
581
+ B, T, H = hash_ids.shape
582
+ flat_ids = hash_ids.reshape(B * T, H)
583
+ offsets = torch.tensor(hash_mapping.offsets_arr, device=flat_ids.device, dtype=torch.long)
584
+ emb = self.mem_embed(flat_ids + offsets)
585
+ return emb.reshape(B, T, H * self.embed_dim)
586
+
587
+ def forward(self, vq_indices, hidden_state):
588
+ B, T, D = hidden_state.shape
589
+
590
+ struct_mem = self._retrieve(vq_indices[:, 1:], self.struct_hash)
591
+ conv_mem = self._retrieve(vq_indices[:, 1:], self.conv_hash)
592
+ mem = struct_mem + conv_mem
593
+
594
+ idx_end = mem.shape[1]
595
+ q_proj = self.q_proj(hidden_state[:, :idx_end])
596
+ k = self.k_proj(mem)
597
+ v = self.v_proj(mem)
598
+ gate = torch.sigmoid((q_proj * k).sum(dim=-1, keepdim=True) / (self.key_dim ** 0.5))
599
+ v_gated = gate * v
600
+
601
+ v_normed = self.conv_norm(v_gated)
602
+ v_t = v_normed.transpose(1, 2)
603
+ conv_out = self.conv(v_t)
604
+ conv_out = conv_out[:, :, :v_t.shape[-1]].transpose(1, 2)
605
+ output = hidden_state[:, :idx_end] + F.silu(conv_out) + v_gated
606
+
607
+ if idx_end < T:
608
+ output = F.pad(output, (0, 0, 0, T - idx_end))
609
+ return output
610
+
611
+ def retrieve_cb(self, vq_indices):
612
+ B, T = vq_indices.shape
613
+ struct_mem = self._retrieve(vq_indices[:, 1:], self.struct_hash)
614
+ conv_mem = self._retrieve(vq_indices[:, 1:], self.conv_hash)
615
+ mem = struct_mem + conv_mem
616
+ idx_end = mem.shape[1]
617
+ pad = torch.zeros(B, T - idx_end, mem.shape[2], device=mem.device)
618
+ mem = torch.cat([mem, pad], dim=1)
619
+ q = mem.mean(dim=-1, keepdim=True)
620
+ gate = torch.sigmoid(q)
621
+ return gate * mem
622
+
623
+
624
+ _BOUNDARY_TOKEN_MAP = {
625
+ SPECIAL_VOCAB['BOS']: 0,
626
+ SPECIAL_VOCAB['SYSTEM']: 1,
627
+ SPECIAL_VOCAB['USER']: 2,
628
+ SPECIAL_VOCAB['ASSISTANT']: 3,
629
+ }
630
+
631
+
632
+ class LTIInjection(nn.Module):
633
+ """LTI state injection: h = A*h + B*e + trans_out.
634
+
635
+ Spectral radius < 1 guaranteed by construction via ZOH discretization.
636
+ Prevents divergence in recurrent/ACT loops at high dimensions.
637
+ """
638
+ def __init__(self, dim: int):
639
+ super().__init__()
640
+ self.log_A = nn.Parameter(torch.zeros(dim))
641
+ self.log_dt = nn.Parameter(torch.zeros(1))
642
+ self.B = nn.Parameter(torch.ones(dim) * 0.1)
643
+ for p in (self.log_A, self.log_dt, self.B):
644
+ p.requires_grad_(False)
645
+
646
+ def get_A(self):
647
+ return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))
648
+
649
+ def forward(self, h, e, trans_out):
650
+ return self.get_A() * h + self.B * e + trans_out
651
+
652
+
653
+ class ByteHead(nn.Module):
654
+ """Deep 3-layer MLP byte prediction head with ACT loop.
655
+
656
+ Architecture: 8192 β†’ 16384 β†’ 8192 β†’ 16384 β†’ 288
657
+ ACT: up to 3 iterations, halts when argmax stable for 2 consecutive steps.
658
+ """
659
+ def __init__(self, tscale_type=TScaleType.T32,
660
+ act_max_iters=BYTEHEAD_ACT_MAX_ITERS,
661
+ act_halt_consecutive=BYTEHEAD_ACT_HALT_CONSECUTIVE):
662
+ super().__init__()
663
+ H = HIDDEN_DIM
664
+ W = HIDDEN_DIM * 2
665
+ self.act_max_iters = act_max_iters
666
+ self.act_halt_consecutive = act_halt_consecutive
667
+ self._last_ponder = 0.0
668
+
669
+ self.norm = TernaryRMSNorm(H, tscale_type=tscale_type)
670
+ self.up = TernaryScaleTensor(H, W, tscale_type=tscale_type)
671
+ self.up_norm = TernaryRMSNorm(W, tscale_type=tscale_type)
672
+ self.hidden = TernaryScaleTensor(W, H, tscale_type=tscale_type)
673
+ self.hidden_norm = TernaryRMSNorm(H, tscale_type=tscale_type)
674
+ self.out = TernaryScaleTensor(H, W, tscale_type=tscale_type)
675
+ self.out_norm = TernaryRMSNorm(W, tscale_type=tscale_type)
676
+ self.head = TernaryScaleTensor(W, VOCAB, tscale_type=tscale_type)
677
+
678
+ if act_max_iters > 1:
679
+ self.act_residual = TernaryScaleTensor(VOCAB, H, tscale_type=tscale_type)
680
+ self.lti = LTIInjection(H)
681
+ else:
682
+ self.act_residual = None
683
+ self.lti = None
684
+
685
+ def forward(self, x):
686
+ if self.act_max_iters <= 1 or self.act_residual is None:
687
+ hn = F.silu(self.up(self.norm(x)))
688
+ hn = F.silu(self.hidden(self.up_norm(hn)))
689
+ hn = F.silu(self.out(self.hidden_norm(hn)))
690
+ return self.head(self.out_norm(hn))
691
+
692
+ h = x
693
+ x_initial = x
694
+ prev_argmax = None
695
+ stable_count = 0
696
+ total_iters = 0
697
+
698
+ for i in range(self.act_max_iters):
699
+ hn = F.silu(self.up(self.norm(h)))
700
+ hn = F.silu(self.hidden(self.up_norm(hn)))
701
+ hn = F.silu(self.out(self.hidden_norm(hn)))
702
+ logits = self.head(self.out_norm(hn))
703
+
704
+ curr_argmax = logits.argmax(dim=-1)
705
+ if prev_argmax is not None and (curr_argmax == prev_argmax).all():
706
+ stable_count += 1
707
+ else:
708
+ stable_count = 0
709
+
710
+ total_iters = i + 1
711
+ if stable_count >= self.act_halt_consecutive:
712
+ break
713
+
714
+ prev_argmax = curr_argmax
715
+ trans_out = self.act_residual(logits)
716
+ h = self.lti(h, x_initial, trans_out)
717
+
718
+ self._last_ponder = total_iters / max(self.act_max_iters, 1)
719
+ return logits
720
+
721
+
722
+ class OutputRouter(nn.Module):
723
+ """Routes HIDDEN_DIM relational tokens to ByteHead, VideoHead, or TalkerHead.
724
+
725
+ 3-layer MLP when depth=3, 2-layer when depth=2, single projection when depth=1.
726
+ Argmax at inference, soft weighted routing at training.
727
+ """
728
+ def __init__(self, tscale_type=TScaleType.T32, depth=3):
729
+ super().__init__()
730
+ if depth >= 3:
731
+ self.hidden1 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
732
+ self.hidden1_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
733
+ self.hidden2 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type)
734
+ self.gate = TernaryScaleTensor(HIDDEN_DIM // 4, 4, tscale_type=tscale_type)
735
+ elif depth == 2:
736
+ self.hidden1 = None
737
+ self.hidden1_norm = None
738
+ self.hidden2 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type)
739
+ self.gate = TernaryScaleTensor(HIDDEN_DIM // 4, 4, tscale_type=tscale_type)
740
+ else:
741
+ self.hidden1 = None
742
+ self.hidden1_norm = None
743
+ self.hidden2 = None
744
+ self.gate = TernaryScaleTensor(HIDDEN_DIM, 4, tscale_type=tscale_type)
745
+ # 0 = Null (continue), 1 = ByteHead, 2 = VideoHead, 3 = TalkerHead
746
+
747
+ def forward(self, x, training=False):
748
+ h = x
749
+ if self.hidden1 is not None:
750
+ h = F.silu(self.hidden1_norm(self.hidden1(h)))
751
+ if self.hidden2 is not None:
752
+ h = self.hidden2(h)
753
+ logits = self.gate(h) # [B, T, 4]
754
+ logits = torch.nan_to_num(logits, nan=0.0, posinf=30.0, neginf=-30.0).clamp(-30.0, 30.0)
755
+ if training:
756
+ weights = F.softmax(logits, dim=-1)
757
+ return weights, logits
758
+ return logits.argmax(dim=-1)
759
+
760
+
761
+ class KGVQCodebook(TernaryVQCodebook):
762
+ """Compatibility wrapper for the KG/composite VQ.
763
+
764
+ The old implementation kept float32 `embed` and `embed_avg` buffers. The
765
+ production path now uses the same packed ternary/int8 backing table as the
766
+ shared VQ so default 5M-code KG construction cannot allocate hidden float
767
+ codebook state.
768
+ """
769
+ def __init__(self, codebook_size=KGVQ_CODEBOOK_SIZE, codebook_dim=KGVQ_CODEBOOK_DIM,
770
+ decay=KGVQ_DECAY, commitment_weight=KGVQ_COMMITMENT_WEIGHT,
771
+ threshold_ema_dead_code=KGVQ_DEAD_CODE_THRESHOLD):
772
+ super().__init__(
773
+ codebook_size=codebook_size,
774
+ codebook_dim=codebook_dim,
775
+ commitment_weight=commitment_weight,
776
+ )
777
+ self.decay = decay
778
+ self.threshold_ema_dead_code = threshold_ema_dead_code
779
+
780
+ @property
781
+ def embed(self):
782
+ if self.codebook_size > self.exact_lookup_max:
783
+ raise RuntimeError(
784
+ "Full KG VQ materialization is disabled for large ternary codebooks; "
785
+ "query rows through `table(indices)` instead."
786
+ )
787
+ return super().embed
788
+
789
+ def _ema_update(self, x_flat, indices):
790
+ unique, counts = torch.unique(indices, return_counts=True)
791
+ current = self.cluster_size[unique].to(torch.int32)
792
+ updated = torch.clamp(
793
+ current + counts.to(device=current.device, dtype=torch.int32),
794
+ 0,
795
+ 32767,
796
+ ).to(torch.int16)
797
+ self.cluster_size[unique] = updated
798
+
799
+ def _dead_code_reset(self, x_flat):
800
+ return None
801
+
802
+
803
+ class CompositeProposalHead(nn.Module):
804
+ """Multi-proposal head from pooled GNN output (Phase 17).
805
+
806
+ Projects GNN pool output (graph_pool_out [B, D]) to K_MAX composite motif
807
+ proposals, quantizes via KGVQ, and applies ACT-style halting.
808
+ """
809
+ def __init__(self, dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
810
+ k_max=K_MAX_COMPOSITES, codebook_size=KGVQ_CODEBOOK_SIZE,
811
+ tscale_type=TScaleType.T32):
812
+ super().__init__()
813
+ self.dim = dim
814
+ self.k_max = k_max
815
+ self.codebook_dim = codebook_dim
816
+ self.proj = TernaryScaleTensor(dim, k_max * codebook_dim, tscale_type=tscale_type)
817
+ self.kgvq = TernaryVQCodebook(codebook_size=codebook_size, codebook_dim=codebook_dim,
818
+ tscale_type=tscale_type)
819
+ self.halt_gate = TernaryScaleTensor(dim, k_max, tscale_type=tscale_type)
820
+ self.diversity_weight = 0.1
821
+
822
+ def forward(self, pool_out):
823
+ B = pool_out.shape[0]
824
+ projections = self.proj(pool_out).view(B, self.k_max, self.codebook_dim)
825
+ quantized, composite_ids, vq_loss = self.kgvq(projections)
826
+
827
+ halt_logits = self.halt_gate(pool_out).clamp(-12.0, 12.0)
828
+ halt = torch.sigmoid(halt_logits) # [B, K_MAX]
829
+ composite_ids = composite_ids.masked_fill(halt < 0.5, -1)
830
+
831
+ normed = F.normalize(projections, dim=-1)
832
+ sim_matrix = normed @ normed.transpose(-1, -2)
833
+ triu = torch.triu(sim_matrix, diagonal=1)
834
+ n_pairs = self.k_max * (self.k_max - 1) / 2
835
+ diversity_loss = triu.sum(dim=(-1, -2)).mean() / max(n_pairs, 1)
836
+ diversity_loss = diversity_loss * self.diversity_weight
837
+
838
+ return composite_ids, vq_loss + diversity_loss, halt
839
+
840
+
841
+ class MoEGraph(nn.Module):
842
+ """Fused graph traversal + centroid-based MoE routing + ACT halting.
843
+
844
+ Each ACT iteration: traverse KG β†’ aggregate neighbor emb β†’ centroid route β†’
845
+ run expert β†’ halt check. All operations at MG_WORKSPACE_DIM (1024).
846
+
847
+ Replaces: TernaryGraph + GraphMoEGate + GraphACTCell + SharedProjectionMoE + MoEACTCell.
848
+ """
849
+ def __init__(self, cb_dim=MG_WORKSPACE_DIM, trigram_dim=HIDDEN_DIM,
850
+ codebook_dim=CODEBOOK_DIM,
851
+ num_experts=MG_N_EXPERTS, core_rank=MG_CORE_RANK,
852
+ shared_inter=MG_SHARED_INTER, max_iters=MG_ACT_ITERS,
853
+ halt_threshold=0.99, tscale_type=TScaleType.T32,
854
+ codebook_size=CODEBOOK_SIZE,
855
+ active_graph_max_nodes=4096,
856
+ top_k=1):
857
+ super().__init__()
858
+ self.cb_dim = cb_dim
859
+ self.trigram_dim = trigram_dim
860
+ self.codebook_dim = codebook_dim
861
+ self.num_experts = num_experts
862
+ self.core_rank = core_rank
863
+ self.shared_inter = shared_inter
864
+ self.max_iters = max_iters
865
+ self.halt_threshold = halt_threshold
866
+ self.codebook_size = codebook_size
867
+ self.active_graph_max_nodes = active_graph_max_nodes
868
+ self.top_k = top_k
869
+
870
+ self.down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
871
+ self.down_norm = TernaryRMSNorm(trigram_dim, tscale_type=tscale_type)
872
+ self.up_proj = TernaryScaleTensor(cb_dim, trigram_dim, tscale_type=tscale_type)
873
+ self.up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
874
+ self.attn_down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
875
+ self.codebook_up = TernaryScaleTensor(codebook_dim, cb_dim, tscale_type=tscale_type)
876
+
877
+ self.use_active_edge_store = self.codebook_size > self.active_graph_max_nodes
878
+ self.active_edge_capacity = max(int(self.active_graph_max_nodes) * 16, 65_536)
879
+ if self.use_active_edge_store:
880
+ self.register_buffer("edge_index", torch.zeros(2, 0, dtype=torch.int32))
881
+ self.register_buffer("edge_attr", torch.zeros(0, dtype=torch.int8))
882
+ self.register_buffer("edge_score", torch.zeros(0, dtype=torch.int8))
883
+ self.register_buffer("active_edge_src", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32))
884
+ self.register_buffer("active_edge_dst", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32))
885
+ self.register_buffer("active_edge_attr", torch.zeros(self.active_edge_capacity, dtype=torch.int8))
886
+ self.register_buffer("active_edge_score", torch.zeros(self.active_edge_capacity, dtype=torch.int8))
887
+ self.register_buffer("active_edge_ptr", torch.zeros((), dtype=torch.long))
888
+ else:
889
+ num_edges = self.codebook_size * 10
890
+ src = torch.arange(self.codebook_size, dtype=torch.int32).repeat_interleave(10)
891
+ dst = torch.randint(0, self.codebook_size, (num_edges,), dtype=torch.int32)
892
+ self.register_buffer("edge_index", torch.stack([src, dst], dim=0))
893
+ edge_init = torch.randint(-1, 2, (num_edges,), dtype=torch.int8)
894
+ self.register_buffer("edge_attr", edge_init)
895
+ self.register_buffer("edge_score", torch.zeros(num_edges, dtype=torch.int8))
896
+ self.register_buffer("_steps_since_requant", torch.tensor(0, dtype=torch.long))
897
+ self.requant_every = KG_REQUANT_EVERY
898
+ self.kg_ternary_threshold = KG_TERNARY_THRESHOLD
899
+ self.kg_ema_alpha = KG_EMA_ALPHA
900
+
901
+ self.centroids = TernaryEmbeddingTable(num_experts, cb_dim, tscale_type=tscale_type, normalize=True)
902
+
903
+ self.shared_up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
904
+ self.shared_up = TernaryScaleTensor(cb_dim, shared_inter, tscale_type=tscale_type)
905
+ self.shared_down_norm = TernaryRMSNorm(shared_inter, tscale_type=tscale_type)
906
+ self.shared_down = TernaryScaleTensor(shared_inter, cb_dim, tscale_type=tscale_type)
907
+
908
+ self.W_gate = nn.ModuleList([
909
+ TernaryScaleTensor(cb_dim, core_rank, tscale_type=tscale_type)
910
+ for _ in range(num_experts)
911
+ ])
912
+ self.W_gate_norms = nn.ModuleList([
913
+ TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
914
+ for _ in range(num_experts)
915
+ ])
916
+ self.W_transform = nn.ModuleList([
917
+ TernaryScaleTensor(core_rank, shared_inter, tscale_type=tscale_type)
918
+ for _ in range(num_experts)
919
+ ])
920
+ self.W_transform_norms = nn.ModuleList([
921
+ TernaryRMSNorm(core_rank, tscale_type=tscale_type)
922
+ for _ in range(num_experts)
923
+ ])
924
+
925
+ self.hop_lora = GNNLoRAAdapter(dim=cb_dim, rank=32, max_hops=max_iters)
926
+ self.halting = HaltingUnit(dim=cb_dim, tscale_type=tscale_type)
927
+ self.lti = LTIInjection(cb_dim)
928
+
929
+ self._codebook_embed = None
930
+ self._codebook_table = None
931
+
932
+ def _codebook_tensor(self, device):
933
+ if self._codebook_table is not None:
934
+ idx = torch.arange(self.codebook_size, device=device)
935
+ codebook = self._codebook_table(idx)
936
+ if codebook.shape[-1] != self.cb_dim:
937
+ codebook = self.codebook_up(codebook)
938
+ return codebook
939
+ if self._codebook_embed is not None:
940
+ codebook = self._codebook_embed.to(device=device).squeeze(0)
941
+ if codebook.shape[-1] != self.cb_dim:
942
+ codebook = self.codebook_up(codebook)
943
+ return codebook
944
+ return torch.zeros(self.codebook_size, self.cb_dim, device=device)
945
+
946
+ def _active_codebook_features(self, vq_indices):
947
+ if self._codebook_table is not None:
948
+ safe_idx = vq_indices.clamp(min=0, max=self.codebook_size - 1)
949
+ active_code = self._codebook_table(safe_idx)
950
+ elif self._codebook_embed is not None:
951
+ codebook = self._codebook_embed.to(device=vq_indices.device).squeeze(0)
952
+ safe_idx = vq_indices.clamp(min=0, max=codebook.shape[0] - 1)
953
+ active_code = codebook[safe_idx]
954
+ else:
955
+ return torch.zeros(*vq_indices.shape, self.cb_dim, device=vq_indices.device)
956
+ if active_code.shape[-1] != self.cb_dim:
957
+ active_code = self.codebook_up(active_code)
958
+ return active_code
959
+
960
+ def _neighbor_aggregate(self, node_features, threshold):
961
+ N, D = node_features.shape
962
+ aggregated = torch.zeros(self.codebook_size, D, device=node_features.device, dtype=node_features.dtype)
963
+ edge_ternary = StickyZoneSTE.apply(self.edge_attr, threshold)
964
+ src_features = node_features[self.edge_index[0]]
965
+ messages = edge_ternary.unsqueeze(1).to(node_features.dtype) * src_features
966
+ dst_idx = self.edge_index[1].unsqueeze(1).expand(-1, D)
967
+ aggregated.scatter_add_(0, dst_idx, messages)
968
+ return aggregated
969
+
970
+ def _run_expert_batch(self, x, expert_idx):
971
+ B, T, D = x.shape
972
+ N = B * T
973
+ x_flat = rearrange(x, 'b t d -> (b t) d')
974
+ exp_flat = rearrange(expert_idx, 'b t -> (b t)')
975
+ shared_hidden = F.silu(self.shared_up(self.shared_up_norm(x_flat)))
976
+ sort_idx = exp_flat.argsort()
977
+ sorted_experts = exp_flat[sort_idx]
978
+ expert_counts = torch.bincount(sorted_experts, minlength=self.num_experts)
979
+ expert_boundaries = torch.cumsum(expert_counts, dim=0)
980
+ out_flat = torch.zeros(N, D, device=x.device, dtype=x.dtype)
981
+ for e in range(self.num_experts):
982
+ start = expert_boundaries[e] - expert_counts[e]
983
+ end = expert_boundaries[e]
984
+ if start == end:
985
+ continue
986
+ tok_idx = sort_idx[start:end]
987
+ inp = x_flat[tok_idx]
988
+ sh = shared_hidden[tok_idx]
989
+ gate = self.W_gate[e](self.W_gate_norms[e](inp))
990
+ core = self.W_transform[e](self.W_transform_norms[e](gate))
991
+ expert_out = self.shared_down(self.shared_down_norm(core * sh))
992
+ out_flat[tok_idx] = expert_out
993
+ return rearrange(out_flat, '(b t) d -> b t d', b=B, t=T)
994
+
995
+ def _run_expert(self, x, expert_idx):
996
+ return self._run_expert_batch(x, expert_idx)
997
+
998
+ def _active_node_add(self, vq_output, vq_indices):
999
+ return vq_output + self._active_codebook_features(vq_indices)
1000
+
1001
+ def forward(self, trigram_input, vq_indices, attention_output=None,
1002
+ memgram_cb_output=None, threshold=0.05):
1003
+ B, T, D = trigram_input.shape
1004
+ device = trigram_input.device
1005
+
1006
+ x = self.down_proj(self.down_norm(trigram_input))
1007
+
1008
+ attn_cb = None
1009
+ if attention_output is not None:
1010
+ attn_cb = self.attn_down_proj(self.down_norm(attention_output))
1011
+
1012
+ halted = torch.zeros(B, T, device=device, dtype=torch.bool)
1013
+ cumulative_p = torch.zeros(B, T, device=device)
1014
+ acc = torch.zeros_like(x)
1015
+ total_ponder = torch.zeros(B, T, device=device)
1016
+ last_x = x
1017
+ initial_x = x
1018
+
1019
+ use_active_graph = self.codebook_size > self.active_graph_max_nodes
1020
+ node_features = None if use_active_graph else self._codebook_tensor(device)
1021
+
1022
+ for iter_t in range(self.max_iters):
1023
+ if use_active_graph:
1024
+ traversal = self._active_node_add(x, vq_indices)
1025
+ else:
1026
+ node_aggregated = self._neighbor_aggregate(node_features, threshold)
1027
+ traversal = x + node_aggregated[vq_indices]
1028
+
1029
+ if attn_cb is not None:
1030
+ traversal = traversal + attn_cb
1031
+
1032
+ if iter_t in [1, 3] and memgram_cb_output is not None:
1033
+ memgram_raw = memgram_cb_output.to(device)
1034
+ if memgram_raw.shape[-1] != self.cb_dim:
1035
+ memgram_raw = memgram_raw.mean(dim=-1, keepdim=True).expand(-1, -1, self.cb_dim)
1036
+ traversal = traversal + memgram_raw
1037
+
1038
+ traversal = traversal + self.hop_lora(traversal, iter_t)
1039
+
1040
+ trav_norm = F.normalize(traversal, dim=-1, eps=1e-8)
1041
+ centroid_ids = torch.arange(self.num_experts, device=device)
1042
+ cent_norm = F.normalize(self.centroids(centroid_ids), dim=-1, eps=1e-8)
1043
+ scores = trav_norm @ cent_norm.T
1044
+ if self.top_k <= 1:
1045
+ _, expert_idx = scores.max(dim=-1)
1046
+ expert_out = self._run_expert(traversal, expert_idx)
1047
+ else:
1048
+ scores_topk, topk_idx = scores.topk(k=self.top_k, dim=-1)
1049
+ weights = F.softmax(scores_topk / 0.1, dim=-1)
1050
+ expert_out = 0
1051
+ for i in range(self.top_k):
1052
+ wi = weights[..., i:i+1]
1053
+ ei = topk_idx[..., i]
1054
+ expert_out = expert_out + wi * self._run_expert(traversal, ei)
1055
+ last_x = expert_out
1056
+
1057
+ p = self.halting(expert_out).squeeze(-1)
1058
+ still_running = ~halted
1059
+ remainder = (1.0 - cumulative_p).clamp(min=0)
1060
+ weight = torch.where(
1061
+ cumulative_p + p >= self.halt_threshold,
1062
+ remainder, p,
1063
+ )
1064
+ weight = weight * still_running.float()
1065
+ acc = acc + weight.unsqueeze(-1) * expert_out
1066
+ cumulative_p = cumulative_p + p * still_running.float()
1067
+ halted = halted | (cumulative_p >= self.halt_threshold)
1068
+ total_ponder = total_ponder + (1.0 - cumulative_p).clamp(min=0)
1069
+
1070
+ x = self.lti(x, initial_x, expert_out)
1071
+
1072
+ if halted.all():
1073
+ break
1074
+
1075
+ never_halted = (~halted).float().unsqueeze(-1)
1076
+ acc = acc + never_halted * last_x
1077
+
1078
+ output = self.up_proj(self.up_norm(acc))
1079
+ ponder_loss = total_ponder.mean() / self.max_iters
1080
+
1081
+ return output, ponder_loss
1082
+
1083
+ @torch.no_grad()
1084
+ def update_kg_edges(self, all_vq_indices):
1085
+ if self.use_active_edge_store:
1086
+ self._update_active_edges(all_vq_indices)
1087
+ return
1088
+
1089
+ unique_ids = torch.unique(all_vq_indices.to(device=self.edge_index.device, dtype=torch.int32))
1090
+ src_in_batch = torch.isin(self.edge_index[0], unique_ids)
1091
+
1092
+ if src_in_batch.any():
1093
+ dst_seen = torch.isin(self.edge_index[1][src_in_batch], unique_ids)
1094
+ delta = torch.where(
1095
+ dst_seen,
1096
+ torch.ones_like(self.edge_score[src_in_batch], dtype=torch.int16),
1097
+ torch.full_like(self.edge_score[src_in_batch], -1, dtype=torch.int16),
1098
+ )
1099
+ score = torch.clamp(self.edge_score[src_in_batch].to(torch.int16) + delta, -128, 127)
1100
+ self.edge_score[src_in_batch] = score.to(torch.int8)
1101
+
1102
+ self._requantize_dense_edges()
1103
+
1104
+ @torch.no_grad()
1105
+ def _update_active_edges(self, all_vq_indices):
1106
+ ids = all_vq_indices.to(device=self.active_edge_src.device, dtype=torch.int32)
1107
+ if ids.numel() < 2:
1108
+ self._steps_since_requant.add_(1)
1109
+ return
1110
+
1111
+ seq = ids.reshape(-1, ids.shape[-1]) if ids.dim() > 1 else ids.reshape(1, -1)
1112
+ src = seq[:, :-1].reshape(-1)
1113
+ dst = seq[:, 1:].reshape(-1)
1114
+ valid = (src >= 0) & (dst >= 0) & (src < self.codebook_size) & (dst < self.codebook_size) & (src != dst)
1115
+ src = src[valid]
1116
+ dst = dst[valid]
1117
+ if src.numel() == 0:
1118
+ self._steps_since_requant.add_(1)
1119
+ return
1120
+
1121
+ n_edges = min(src.numel(), self.active_edge_capacity)
1122
+ src = src[-n_edges:]
1123
+ dst = dst[-n_edges:]
1124
+ ptr = int(self.active_edge_ptr.item())
1125
+ slots = (torch.arange(n_edges, device=src.device, dtype=torch.long) + ptr) % self.active_edge_capacity
1126
+
1127
+ self.active_edge_src[slots] = src
1128
+ self.active_edge_dst[slots] = dst
1129
+ score = torch.clamp(self.active_edge_score[slots].to(torch.int16) + 1, -128, 127)
1130
+ self.active_edge_score[slots] = score.to(torch.int8)
1131
+ self.active_edge_attr[slots] = 1
1132
+ self.active_edge_ptr.fill_((ptr + n_edges) % self.active_edge_capacity)
1133
+ self._requantize_active_edges()
1134
+
1135
+ @torch.no_grad()
1136
+ def _requantize_dense_edges(self):
1137
+ if self._steps_since_requant.item() < self.requant_every:
1138
+ self._steps_since_requant.add_(1)
1139
+ return
1140
+ self.edge_attr = self._score_to_attr(self.edge_score)
1141
+ score = self.edge_score.to(torch.int16)
1142
+ score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score))
1143
+ self.edge_score = score.to(torch.int8)
1144
+ self._steps_since_requant.zero_()
1145
+
1146
+ @torch.no_grad()
1147
+ def _requantize_active_edges(self):
1148
+ if self._steps_since_requant.item() < self.requant_every:
1149
+ self._steps_since_requant.add_(1)
1150
+ return
1151
+ active = self.active_edge_src >= 0
1152
+ if active.any():
1153
+ self.active_edge_attr[active] = self._score_to_attr(self.active_edge_score[active])
1154
+ score = self.active_edge_score[active].to(torch.int16)
1155
+ score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score))
1156
+ self.active_edge_score[active] = score.to(torch.int8)
1157
+ self._steps_since_requant.zero_()
1158
+
1159
+ def _score_to_attr(self, score):
1160
+ threshold = max(1, int(round(float(self.kg_ternary_threshold) * 8)))
1161
+ score_i = score.to(torch.int16)
1162
+ return torch.where(
1163
+ score_i >= threshold,
1164
+ torch.ones_like(score, dtype=torch.int8),
1165
+ torch.where(
1166
+ score_i <= -threshold,
1167
+ torch.full_like(score, -1, dtype=torch.int8),
1168
+ torch.zeros_like(score, dtype=torch.int8),
1169
+ ),
1170
+ )
1171
+
1172
+ @torch.no_grad()
1173
+ def monitor_graph_health(self, threshold=0.05):
1174
+ if self.use_active_edge_store:
1175
+ active = self.active_edge_src >= 0
1176
+ if not active.any():
1177
+ return {
1178
+ "sparsity": 1.0, "isolated_nodes": self.codebook_size,
1179
+ "avg_polarity": 0.0, "dead_edges": 0,
1180
+ "score_mean": 0.0, "score_max": 0.0,
1181
+ "active_edges": 0,
1182
+ }
1183
+ edge_attr = self.active_edge_attr[active]
1184
+ edge_score = self.active_edge_score[active]
1185
+ nodes_with_edges = torch.unique(torch.cat([self.active_edge_src[active], self.active_edge_dst[active]]))
1186
+ else:
1187
+ edge_attr = self.edge_attr
1188
+ edge_score = self.edge_score
1189
+ nodes_with_edges = torch.unique(torch.cat([self.edge_index[0], self.edge_index[1]]))
1190
+
1191
+ ternary_edge = edge_attr.sign()
1192
+ sparsity = (ternary_edge == 0).float().mean().item() if ternary_edge.numel() else 1.0
1193
+ n_isolated = max(int(self.codebook_size) - int(nodes_with_edges.numel()), 0)
1194
+ n_pos = (ternary_edge > 0).sum().item()
1195
+ n_neg = (ternary_edge < 0).sum().item()
1196
+ n_nonzero = n_pos + n_neg
1197
+ avg_polarity = (n_pos - n_neg) / max(n_nonzero, 1)
1198
+ dead_edges = ((ternary_edge == 0) & (edge_score != 0)).sum().item()
1199
+ score_mean = edge_score.float().mean().item() if edge_score.numel() else 0.0
1200
+ score_max = edge_score.float().abs().max().item() if edge_score.numel() else 0.0
1201
+ return {
1202
+ "sparsity": sparsity, "isolated_nodes": n_isolated,
1203
+ "avg_polarity": avg_polarity, "dead_edges": dead_edges,
1204
+ "score_mean": score_mean, "score_max": score_max,
1205
+ "active_edges": int(ternary_edge.numel()),
1206
+ }
1207
+
1208
+ def set_adjacency(self, edge_index, edge_attr_init=None):
1209
+ self.use_active_edge_store = False
1210
+ device = self.edge_attr.device
1211
+ self.edge_index = edge_index.to(device=device, dtype=torch.int32)
1212
+ if edge_attr_init is not None:
1213
+ edge_attr = edge_attr_init.sign() * (edge_attr_init.abs() > 0).to(edge_attr_init.dtype)
1214
+ self.edge_attr = edge_attr.to(device=device, dtype=torch.int8)
1215
+ else:
1216
+ self.edge_attr = torch.randint(-1, 2, (edge_index.size(1),),
1217
+ device=device, dtype=torch.int8)
1218
+ self.edge_score = self.edge_attr.clone()
arbitor/config.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ VOCAB=288
2
+ AUDIO_VOCAB=288
3
+ AUDIO_SR=16000
4
+ AUDIO_FRAME_RATE=50
5
+ THRESHOLD=0.05
6
+
7
+ # -- 3B Target Dimensions --
8
+ EMBEDDING_DIM=1536
9
+ CODEBOOK_DIM=1024
10
+ CODEBOOK_SIZE=524288 # Base unit
11
+ # Shared multimodal VQ (256K entries Γ— 1024-dim)
12
+ SHARED_VQ_SIZE = 262144
13
+ HIDDEN_DIM=8192 # Main hidden dimension
14
+ FFN_HIDDEN=16384 # 2Γ— HIDDEN_DIM
15
+ CTX=256
16
+
17
+ # MoEGraph (256 experts, centroid routing, unified ACT)
18
+ MG_N_EXPERTS = 256
19
+ MG_CORE_RANK = 384
20
+ MG_SHARED_INTER = 1536
21
+ MG_ACT_ITERS = 4
22
+ MG_WORKSPACE_DIM = 768
23
+ MG_TOP_K = 2
24
+
25
+ # VQ
26
+ # MemGram (32 heads Γ— ~65K slots β‰ˆ 2M total associative slots)
27
+ MEMGRAM_STRUCT_PRIMES = [64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997,
28
+ 65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071,
29
+ 65101, 65119, 65123, 65129, 65141, 65147, 65167, 65171,
30
+ 65173, 65179, 65183, 65203, 65213, 65239, 65257, 65269]
31
+ MEMGRAM_CONV_PRIMES = [8009, 8011, 8017, 8039, 8081, 8087, 8089, 8093]
32
+ MEMGRAM_EMBED_DIM = 64
33
+ MEMGRAM_KEY_DIM = 32
34
+
35
+ # KV Ledger
36
+ KV_LEDGER_SIZE = 262144
37
+ SLIDING_WINDOW_SIZE = 32768
38
+ KQ_CACHE_SIZE = 8192
39
+
40
+ # MLA Attention dimensions
41
+ MLA_N_HEADS = 32
42
+ MLA_QK_NOPE_HEAD_DIM = 96
43
+ MLA_QK_ROPE_HEAD_DIM = 32
44
+ MLA_V_HEAD_DIM = 96
45
+ MLA_SLIDE_DIM = 64
46
+ MLA_FULL_DIM = 32
47
+ MLA_N_LAYERS = 24
48
+
49
+ # RoPE
50
+ MLA_ROPE_THETA = 10000.0
51
+
52
+ # Attention
53
+ ATTENTION_STRIDE = 8
54
+ KV_CONTEXT_LENGTH = 33554432
55
+
56
+ # CSA / HCA compression (DeepSeek V4 hybrid attention)
57
+ MLA_CSA_DIM = 16
58
+ MLA_HCA_DIM = 16
59
+ MLA_HCA_STRIDE = 32
60
+
61
+ # KG EMA β€” Phase 17
62
+ KG_EMA_ALPHA=0.99
63
+ KG_REQUANT_EVERY=50
64
+ KG_TERNARY_THRESHOLD=0.3
65
+
66
+ # Composite Motif VQ β€” Phase 17 (64K entries Γ— 1024-dim)
67
+ KGVQ_CODEBOOK_SIZE=65536
68
+ KGVQ_CODEBOOK_DIM=1024
69
+ KGVQ_DECAY=0.99
70
+ KGVQ_COMMITMENT_WEIGHT=1.0
71
+ KGVQ_DEAD_CODE_THRESHOLD=2
72
+ K_MAX_COMPOSITES=20
73
+
74
+ # VideoHead (Open-Sora VAE: 4 latent channels, 8Γ— spatial + 4Γ— temporal compression)
75
+ VIDEO_LATENT_CHANNELS = 4
76
+ VIDEO_MAX_STEPS = 8
77
+ VIDEO_HEIGHT = 64
78
+ VIDEO_WIDTH = 64
79
+
80
+ # -- Open-Sora 3D VAE (Phase 19) --
81
+ OPEN_SORA_VAE_PATH = "arbitor/encoders/models/opensora-vae"
82
+ OPEN_SORA_VAE_REPO = "hpcai-tech/OpenSora-VAE-v1.2"
83
+ OPEN_SORA_LATENT_CHANNELS = 4
84
+ OPEN_SORA_SCALE_FACTOR_SPATIAL = 8
85
+ OPEN_SORA_SCALE_FACTOR_TEMPORAL = 4
86
+
87
+ # -- ACT Loop Parameters (Phase 19) --
88
+ BYTEHEAD_ACT_MAX_ITERS = 3
89
+ BYTEHEAD_ACT_HALT_CONSECUTIVE = 2
90
+ BYTEHEAD_ACT_PONDER_LAMBDA = 0.01
91
+
92
+ VIDEOHEAD_ACT_MIN_FPS = 1
93
+ VIDEOHEAD_ACT_MAX_FPS = 60
94
+ VIDEOHEAD_ACT_FRAME_CHUNK = 8
95
+
96
+ TALKERHEAD_ACT_CHUNK_FRAMES = 500
97
+
98
+ # -- Timestamp Encoding (Phase 19) --
99
+ TIMESTAMP_MAX_PERIOD = 10000.0
100
+
101
+ # -- Temporal Frame Buffer (Phase 19) --
102
+ FRAME_BUFFER_LOCAL_SIZE = 3
103
+ FRAME_BUFFER_CACHE_STRIDE = 4
104
+
105
+ SPECIAL_VOCAB = {
106
+ # Control
107
+ 'PAD': 256, 'BOS': 257, 'EOS': 258, 'STOP': 259,
108
+ # Roles
109
+ 'SYSTEM': 260, 'USER': 261, 'ASSISTANT': 262,
110
+ # Reasoning
111
+ 'SCRATCHPAD': 263, 'PLAN': 264, 'REFLECTION': 265, 'SUMMARY': 266,
112
+ # Tool use
113
+ 'ACTION': 267, 'TOOL': 268, 'TOOL_RESULT': 269,
114
+ # Code
115
+ 'CODE': 270, 'CODE_BLOCK': 271, 'EXECUTION': 272,
116
+ # RAG
117
+ 'SEARCH': 273, 'CONTEXT': 274, 'CITATION': 275,
118
+ # Quality / format
119
+ 'ERROR': 276, 'FORMAT': 277,
120
+ # Multimodal
121
+ 'IMAGE': 278, 'TEXT': 279, 'AUDIO': 280,
122
+ 'VIDEO': 281, 'SPEAK': 282, 'IMG_GEN': 283,
123
+ # Future
124
+ 'RES1': 284, 'RES2': 285, 'RES3': 286, 'RESERVED': 287,
125
+ }
arbitor/converters/convert_to_ternary2.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
+
7
+
8
+ def pack_ternary(w):
9
+ q = torch.empty_like(w, dtype=torch.uint8)
10
+ q[w < 0] = 0
11
+ q[w == 0] = 1
12
+ q[w > 0] = 2
13
+
14
+ flat = q.flatten()
15
+ pad = (-len(flat)) % 4
16
+ if pad:
17
+ flat = torch.cat([flat, torch.zeros(pad, dtype=torch.uint8, device=flat.device)])
18
+
19
+ flat = flat.view(-1, 4)
20
+
21
+ packed = (
22
+ flat[:, 0]
23
+ | (flat[:, 1] << 2)
24
+ | (flat[:, 2] << 4)
25
+ | (flat[:, 3] << 6)
26
+ )
27
+
28
+ return packed.cpu(), w.shape
29
+
30
+
31
+ def save_model(model, path="trigram-morph.pt"):
32
+ ternary_weights = {}
33
+ for name, param in model.named_parameters():
34
+ if "weight" in name and param.ndim >= 2 and "embed" not in name:
35
+ T = StickyZoneSTE.apply(param.data, THRESHOLD)
36
+ packed, shape = pack_ternary(T)
37
+ ternary_weights[name] = {"packed": packed, "shape": shape}
38
+
39
+ torch.save({
40
+ "model_state_dict": model.state_dict(),
41
+ "config": {
42
+ "vocab": VOCAB,
43
+ "embedding_dim": EMBEDDING_DIM,
44
+ "trigram_dim": HIDDEN_DIM,
45
+ "ffn_hidden": FFN_HIDDEN,
46
+ "ctx": CTX,
47
+ "threshold": THRESHOLD,
48
+ },
49
+ "ternary_packed": ternary_weights,
50
+ "format": "factorized_scaled_ternary",
51
+ "bpw": 1.58,
52
+ }, path)
53
+ total = sum(p.numel() for p in model.parameters())
54
+ print(f"Saved {total:,} params to {path}")
55
+
56
+
57
+ def load_model(path="trigram-morph.pt", device="cpu"):
58
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
59
+ model = ARBModel()
60
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
61
+ model.to(device)
62
+ model.eval()
63
+ return model
64
+
65
+
66
+ if __name__ == "__main__":
67
+ from ..trigram import ARBModel
68
+ model = ARBModel()
69
+ total = sum(p.numel() for p in model.parameters())
70
+ ternary = sum(
71
+ p.numel() for n, p in model.named_parameters()
72
+ if "weight" in n and p.ndim >= 2 and "embed" not in n
73
+ )
74
+ fp32 = sum(
75
+ p.numel() for n, p in model.named_parameters()
76
+ if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
77
+ )
78
+ print(f"Total params: {total:,}")
79
+ print(f"Ternary params (1.58 BPW): {ternary:,}")
80
+ print(f"FP32 params: {fp32:,}")
81
+ save_model(model)
arbitor/converters/convert_to_ternary54.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
+
7
+
8
+ def pack_ternary_34(w):
9
+ q = torch.empty_like(w, dtype=torch.uint8)
10
+
11
+ q[w < 0] = 0
12
+ q[w == 0] = 1
13
+ q[w > 0] = 2
14
+
15
+ flat = q.flatten()
16
+
17
+ # pad to multiple of 34
18
+ pad = (-len(flat)) % 34
19
+
20
+ if pad:
21
+ flat = torch.cat([
22
+ flat,
23
+ torch.ones(pad, dtype=torch.uint8, device=flat.device)
24
+ ])
25
+
26
+ flat = flat.view(-1, 34)
27
+
28
+ packed = torch.zeros(
29
+ flat.shape[0],
30
+ dtype=torch.uint64,
31
+ device=flat.device
32
+ )
33
+
34
+ multiplier = 1
35
+
36
+ for i in range(34):
37
+ packed += flat[:, i].to(torch.uint64) * multiplier
38
+ multiplier *= 3
39
+
40
+ return packed.cpu(), w.shape, pad
41
+
42
+
43
+ def unpack_ternary_34(packed, shape, pad=0):
44
+ packed = packed.to(torch.uint64)
45
+
46
+ out = []
47
+
48
+ for _ in range(34):
49
+ trit = packed % 3
50
+ packed //= 3
51
+ out.append(trit)
52
+
53
+ out = torch.stack(out, dim=1).flatten()
54
+
55
+ if pad:
56
+ out = out[:-pad]
57
+
58
+ out = out.view(shape)
59
+
60
+ # restore ternary values
61
+ out = out.to(torch.int8)
62
+
63
+ out[out == 0] = -1
64
+ out[out == 1] = 0
65
+ out[out == 2] = 1
66
+
67
+ return out
68
+
69
+
70
+ def save_model(model, path="trigram-morph.pt"):
71
+ ternary_weights = {}
72
+ for name, param in model.named_parameters():
73
+ if "weight" in name and param.ndim >= 2 and "embed" not in name:
74
+ T = StickyZoneSTE.apply(param.data, THRESHOLD)
75
+ packed, shape = pack_ternary(T)
76
+ ternary_weights[name] = {"packed": packed, "shape": shape}
77
+
78
+ torch.save({
79
+ "model_state_dict": model.state_dict(),
80
+ "config": {
81
+ "vocab": VOCAB,
82
+ "embedding_dim": EMBEDDING_DIM,
83
+ "trigram_dim": HIDDEN_DIM,
84
+ "ffn_hidden": FFN_HIDDEN,
85
+ "ctx": CTX,
86
+ "threshold": THRESHOLD,
87
+ },
88
+ "ternary_packed": ternary_weights,
89
+ "format": "factorized_scaled_ternary",
90
+ "bpw": 1.58,
91
+ }, path)
92
+ total = sum(p.numel() for p in model.parameters())
93
+ print(f"Saved {total:,} params to {path}")
94
+
95
+
96
+ def load_model(path="trigram-morph.pt", device="cpu"):
97
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
98
+ model = ARBModel()
99
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
100
+ model.to(device)
101
+ model.eval()
102
+ return model
103
+
104
+
105
+ if __name__ == "__main__":
106
+ from ..trigram import ARBModel
107
+ model = ARBModel()
108
+ total = sum(p.numel() for p in model.parameters())
109
+ ternary = sum(
110
+ p.numel() for n, p in model.named_parameters()
111
+ if "weight" in n and p.ndim >= 2 and "embed" not in n
112
+ )
113
+ fp32 = sum(
114
+ p.numel() for n, p in model.named_parameters()
115
+ if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
116
+ )
117
+ print(f"Total params: {total:,}")
118
+ print(f"Ternary params (1.58 BPW): {ternary:,}")
119
+ print(f"FP32 params: {fp32:,}")
120
+ save_model(model)
arbitor/converters/convert_to_ternary64.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
+
7
+
8
+ def pack_ternary(w):
9
+ q = torch.empty_like(w, dtype=torch.uint8)
10
+ q[w < 0] = 0
11
+ q[w == 0] = 1
12
+ q[w > 0] = 2
13
+
14
+ flat = q.flatten()
15
+ pad = (-len(flat)) % 40 # 40 trit -> 64 bit packing - Higher conversation than 1 trit -> 2 bit + uint64 performance
16
+ if pad:
17
+ flat = torch.cat([
18
+ flat,
19
+ torch.zeros(pad, dtype=torch.uint8, device=flat.device)
20
+ ])
21
+
22
+ flat = flat.view(-1, 40)
23
+
24
+ packed = torch.zeros(flat.shape[0], dtype=torch.uint64)
25
+
26
+ multiplier = 1
27
+
28
+ for i in range(40):
29
+ packed += flat[:, i].to(torch.uint64) * multiplier
30
+ multiplier *= 3
31
+
32
+ return packed.cpu(), w.shape, pad
33
+
34
+
35
+ def unpack_ternary_40(packed, shape, pad=0):
36
+ packed = packed.to(torch.uint64)
37
+
38
+ out = []
39
+
40
+ for _ in range(40):
41
+ trit = packed % 3
42
+ packed //= 3
43
+ out.append(trit)
44
+
45
+ out = torch.stack(out, dim=1).flatten()
46
+
47
+ if pad:
48
+ out = out[:-pad]
49
+
50
+ out = out.view(shape)
51
+
52
+ out = out.to(torch.int8)
53
+
54
+ out[out == 0] = -1
55
+ out[out == 1] = 0
56
+ out[out == 2] = 1
57
+
58
+ return out
59
+
60
+
61
+ def save_model(model, path="trigram-morph.pt"):
62
+ ternary_weights = {}
63
+ for name, param in model.named_parameters():
64
+ if "weight" in name and param.ndim >= 2 and "embed" not in name:
65
+ T = StickyZoneSTE.apply(param.data, THRESHOLD)
66
+ packed, shape = pack_ternary(T)
67
+ ternary_weights[name] = {"packed": packed, "shape": shape}
68
+
69
+ torch.save({
70
+ "model_state_dict": model.state_dict(),
71
+ "config": {
72
+ "vocab": VOCAB,
73
+ "embedding_dim": EMBEDDING_DIM,
74
+ "trigram_dim": HIDDEN_DIM,
75
+ "ffn_hidden": FFN_HIDDEN,
76
+ "ctx": CTX,
77
+ "threshold": THRESHOLD,
78
+ },
79
+ "ternary_packed": ternary_weights,
80
+ "format": "factorized_scaled_ternary",
81
+ "bpw": 1.58,
82
+ }, path)
83
+ total = sum(p.numel() for p in model.parameters())
84
+ print(f"Saved {total:,} params to {path}")
85
+
86
+
87
+ def load_model(path="trigram-morph.pt", device="cpu"):
88
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
89
+ model = ARBModel()
90
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
91
+ model.to(device)
92
+ model.eval()
93
+ return model
94
+
95
+
96
+ if __name__ == "__main__":
97
+ from ..trigram import ARBModel
98
+ model = ARBModel()
99
+ total = sum(p.numel() for p in model.parameters())
100
+ ternary = sum(
101
+ p.numel() for n, p in model.named_parameters()
102
+ if "weight" in n and p.ndim >= 2 and "embed" not in n
103
+ )
104
+ fp32 = sum(
105
+ p.numel() for n, p in model.named_parameters()
106
+ if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
107
+ )
108
+ print(f"Total params: {total:,}")
109
+ print(f"Ternary params (1.58 BPW): {ternary:,}")
110
+ print(f"FP32 params: {fp32:,}")
111
+ save_model(model)
arbitor/converters/convert_to_ternary8.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ # Lightweight imports used by pack_ternary/unpack_ternary (called by core system)
6
+ # No circular deps here β€” these are just type conversions
7
+
8
+ def pack_ternary(w):
9
+ q = torch.empty_like(w, dtype=torch.uint8)
10
+ q[w < 0] = 0
11
+ q[w == 0] = 1
12
+ q[w > 0] = 2
13
+
14
+ flat = q.flatten()
15
+ pad = (-len(flat)) % 5 # 5 trit -> 8 bit packing - Higher conversation than 1 trit -> 2 bit
16
+ if pad:
17
+ flat = torch.cat([
18
+ flat,
19
+ torch.zeros(pad, dtype=torch.uint8, device=flat.device)
20
+ ])
21
+
22
+ flat = flat.view(-1, 5)
23
+
24
+ packed = (
25
+ flat[:, 0]
26
+ + flat[:, 1] * 3
27
+ + flat[:, 2] * 9
28
+ + flat[:, 3] * 27
29
+ + flat[:, 4] * 81
30
+ ).to(torch.uint8)
31
+
32
+ return packed.cpu(), w.shape, pad
33
+
34
+
35
+ def unpack_ternary(packed, shape, pad=0):
36
+ packed = packed.to(torch.int16)
37
+
38
+ t0 = packed % 3
39
+ packed //= 3
40
+
41
+ t1 = packed % 3
42
+ packed //= 3
43
+
44
+ t2 = packed % 3
45
+ packed //= 3
46
+
47
+ t3 = packed % 3
48
+ packed //= 3
49
+
50
+ t4 = packed % 3
51
+
52
+ out = torch.stack([t0, t1, t2, t3, t4], dim=1).flatten()
53
+
54
+ if pad:
55
+ out = out[:-pad]
56
+
57
+ out = out.view(shape)
58
+
59
+ # map back
60
+ out = out.to(torch.int8)
61
+ out[out == 0] = -1
62
+ out[out == 1] = 0
63
+ out[out == 2] = 1
64
+
65
+ return out
66
+
67
+
68
+ def save_model(model, path="models/conversions/arb-model.pt"):
69
+ import os
70
+ os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
71
+ torch.save({"model_state_dict": model.state_dict()}, path)
72
+ total = sum(p.numel() for p in model.parameters())
73
+ print(f"Saved {total:,} params to {path}")
74
+
75
+
76
+ def load_model(path="models/conversions/arb-model.pt", device="cpu"):
77
+ from ..trigram import ARBModel
78
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
79
+ model = ARBModel()
80
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
81
+ model.to(device)
82
+ model.eval()
83
+ return model
84
+
85
+
86
+ if __name__ == "__main__":
87
+ from ..trigram import ARBModel
88
+ model = ARBModel()
89
+ total = sum(p.numel() for p in model.parameters())
90
+ ternary = sum(
91
+ p.numel() for n, p in model.named_parameters()
92
+ if "weight" in n and p.ndim >= 2 and "embed" not in n
93
+ )
94
+ fp32 = sum(
95
+ p.numel() for n, p in model.named_parameters()
96
+ if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
97
+ )
98
+ print(f"Total params: {total:,}")
99
+ print(f"Ternary params (1.58 BPW): {ternary:,}")
100
+ print(f"FP32 params: {fp32:,}")
101
+ save_model(model)
arbitor/decoders.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Decoder modules β€” video diffusion, audio codec, speech generation.
2
+
3
+ These modules convert HIDDEN_DIM relational states into modality-specific outputs:
4
+ video (latent diffusion), audio (codec tokens), and speech (token striding + codec).
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
10
+ from .kernel.triton_video import video_denoise_step
11
+ from .config import HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, \
12
+ VIDEO_LATENT_CHANNELS, VIDEO_MAX_STEPS, VIDEO_HEIGHT, VIDEO_WIDTH, \
13
+ VIDEOHEAD_ACT_MIN_FPS, VIDEOHEAD_ACT_MAX_FPS, VIDEOHEAD_ACT_FRAME_CHUNK, \
14
+ TALKERHEAD_ACT_CHUNK_FRAMES
15
+ from .components import TernaryEmbeddingTable
16
+
17
+
18
+ class LTIInjection(nn.Module):
19
+ """LTI state injection: h = A*h + B*e + trans_out.
20
+ Spectral radius < 1 guaranteed by construction via ZOH discretization.
21
+ """
22
+ def __init__(self, dim: int):
23
+ super().__init__()
24
+ self.log_A = nn.Parameter(torch.zeros(dim))
25
+ self.log_dt = nn.Parameter(torch.zeros(1))
26
+ self.B = nn.Parameter(torch.ones(dim) * 0.1)
27
+ for p in (self.log_A, self.log_dt, self.B):
28
+ p.requires_grad_(False)
29
+
30
+ def get_A(self):
31
+ return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))
32
+
33
+ def forward(self, h, e, trans_out):
34
+ return self.get_A() * h + self.B * e + trans_out
35
+
36
+
37
+ class VideoHead(nn.Module):
38
+ """Scaled latent diffusion with cross-attention conditioning, frame gate, and 4-frame latent.
39
+
40
+ Produces [B, ch, 4, H', W'] latents (4-frame temporal chunks) per D-102.
41
+ Frame gate controls adaptive fps in [MIN_FPS, MAX_FPS] range.
42
+ """
43
+ def __init__(self, tscale_type=TScaleType.T32, max_steps=VIDEO_MAX_STEPS,
44
+ latent_channels=VIDEO_LATENT_CHANNELS, height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
45
+ min_fps=VIDEOHEAD_ACT_MIN_FPS, max_fps=VIDEOHEAD_ACT_MAX_FPS,
46
+ frame_chunk=VIDEOHEAD_ACT_FRAME_CHUNK):
47
+ super().__init__()
48
+ self.max_steps = max_steps
49
+ self.latent_channels = latent_channels
50
+ self.height = height
51
+ self.width = width
52
+ self.latent_dim = latent_channels * height * width
53
+ self.halt_threshold = 0.05
54
+ self.min_fps = min_fps
55
+ self.max_fps = max_fps
56
+ self.frame_chunk = frame_chunk
57
+
58
+ self.cross_attn_q = TernaryScaleTensor(self.latent_dim, HIDDEN_DIM, tscale_type=tscale_type)
59
+ self.cross_attn_kv = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
60
+ self.diffusion_step = TernaryScaleTensor(HIDDEN_DIM, self.latent_dim, tscale_type=tscale_type)
61
+ self.halt_unit = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
62
+ self.frame_gate = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
63
+ self.noise_embed = TernaryEmbeddingTable(max_steps, HIDDEN_DIM, tscale_type=tscale_type)
64
+ self.lti = LTIInjection(self.latent_dim)
65
+
66
+ @torch.no_grad()
67
+ def _compute_fps(self, cond):
68
+ frame_prob = torch.sigmoid(self.frame_gate(cond))
69
+ fps = self.min_fps + frame_prob * (self.max_fps - self.min_fps)
70
+ return fps.mean().item()
71
+
72
+ def forward(self, relational, max_steps=None, duration_seconds=1.0):
73
+ B, T, D = relational.shape
74
+ max_steps = max_steps or self.max_steps
75
+ cond = relational.mean(dim=1, keepdim=True)
76
+
77
+ fps = self._compute_fps(cond)
78
+ n_frames = max(1, int(fps * duration_seconds))
79
+ n_latents = min((n_frames + self.frame_chunk - 1) // self.frame_chunk, max_steps)
80
+
81
+ all_latents = []
82
+ for chunk_idx in range(n_latents):
83
+ latent = torch.randn(B, 1, self.latent_dim, device=relational.device,
84
+ requires_grad=torch.is_grad_enabled())
85
+ for step in range(max_steps):
86
+ q = self.cross_attn_q(latent)
87
+ kv = self.cross_attn_kv(cond.expand(-1, T, -1))
88
+ context = kv.mean(dim=1, keepdim=True)
89
+ step_embed = self.noise_embed(torch.tensor(step, device=relational.device))
90
+ step_embed = step_embed.expand(B, 1, -1)
91
+ step_input = q + context + step_embed
92
+ pred_noise = self.diffusion_step(step_input)
93
+ alpha = 0.9 ** step
94
+ trans_out = video_denoise_step(latent, pred_noise, alpha)
95
+ h = torch.zeros(B, 1, self.latent_dim, device=context.device)
96
+ h[:, :, :HIDDEN_DIM] = context
97
+ latent = self.lti(latent, h, trans_out)
98
+ halt = torch.sigmoid(self.halt_unit(context))
99
+ if halt.mean() > self.halt_threshold and step > 1:
100
+ break
101
+ all_latents.append(latent.view(B, self.latent_channels, 1, self.height, self.width))
102
+
103
+ return torch.cat(all_latents, dim=2)
104
+
105
+
106
+ class MRFBlock(nn.Module):
107
+ """Multi-Receptive Field Fusion block from HiFi-GAN."""
108
+ def __init__(self, channels, kernel_sizes=(3, 5, 7)):
109
+ super().__init__()
110
+ self.convs = nn.ModuleList([
111
+ nn.Sequential(
112
+ nn.LeakyReLU(0.1),
113
+ nn.Conv1d(channels, channels, k, padding=k//2, dilation=1),
114
+ )
115
+ for k in kernel_sizes
116
+ ])
117
+
118
+ def forward(self, x):
119
+ return sum(conv(x) for conv in self.convs) / len(self.convs)
120
+
121
+
122
+ class TinyNeuralCodec(nn.Module):
123
+ """Lightweight neural audio decoder (frozen float32 sidecar).
124
+
125
+ Maps byte token sequences to 16 kHz audio waveforms via transposed conv.
126
+ Token rate: 50 Hz β†’ output: [B, 1, T * 320] at 16 kHz.
127
+ """
128
+ def __init__(self, vocab=AUDIO_VOCAB, embed_dim=512, upsample_ratios=(5, 4, 4, 4)):
129
+ super().__init__()
130
+ self.embed = nn.Embedding(vocab, embed_dim)
131
+
132
+ in_ch = embed_dim
133
+ self.blocks = nn.ModuleList()
134
+ for i, ratio in enumerate(upsample_ratios):
135
+ out_ch = max(1, embed_dim // (2 ** (i + 1)))
136
+ k = ratio * 2
137
+ pad = (ratio + 1) // 2 if ratio % 2 else ratio // 2
138
+ op = max(0, ratio + 2 * pad - k)
139
+ block = nn.Sequential(
140
+ nn.ConvTranspose1d(in_ch, out_ch, k, stride=ratio, padding=pad, output_padding=op),
141
+ MRFBlock(out_ch),
142
+ )
143
+ self.blocks.append(block)
144
+ in_ch = out_ch
145
+
146
+ self.to_audio = nn.Conv1d(in_ch, 1, kernel_size=7, padding=3)
147
+
148
+ def forward(self, tokens):
149
+ x = self.embed(tokens)
150
+ x = x.permute(0, 2, 1)
151
+ for block in self.blocks:
152
+ x = block(x)
153
+ x = self.to_audio(x)
154
+ return torch.tanh(x)
155
+
156
+ def encode_audio(self, audio, frame_rate=AUDIO_FRAME_RATE, sr=AUDIO_SR):
157
+ B, C, T = audio.shape
158
+ frame_len = sr // frame_rate
159
+ pad = (frame_len - T % frame_len) % frame_len
160
+ if pad > 0:
161
+ audio = F.pad(audio, (0, pad))
162
+ frames = audio.unfold(2, frame_len, frame_len)
163
+ frames = frames.mean(dim=1)
164
+ emb = self.embed.weight
165
+ B, NF, FL = frames.shape
166
+ frames_flat = frames.reshape(-1, FL)
167
+ frame_energy = frames_flat.mean(dim=1)
168
+ tokens = torch.clamp(((frame_energy + 1.0) * 127.5).long(), 0, 255)
169
+ tokens = tokens.reshape(B, NF)
170
+ recon = self(tokens)
171
+ if pad > 0:
172
+ recon = recon[:, :, :T]
173
+ return tokens, recon
174
+
175
+
176
+ class TalkerHead(nn.Module):
177
+ """Audio generation head with temporal stride and chunked ACT generation.
178
+
179
+ 2-layer MLP: 8192 β†’ 8192 β†’ 288.
180
+ Generates byte token predictions at 50 Hz frame rate in 500-frame chunks.
181
+ TinyNeuralCodec decodes the predicted tokens to audio waveform.
182
+ """
183
+ def __init__(self, tscale_type=TScaleType.T32,
184
+ chunk_frames=TALKERHEAD_ACT_CHUNK_FRAMES):
185
+ super().__init__()
186
+ self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
187
+ self.hidden = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
188
+ self.hidden_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
189
+ self.head = TernaryScaleTensor(HIDDEN_DIM, AUDIO_VOCAB, tscale_type=tscale_type)
190
+ self.codec = None
191
+ self.max_frames = chunk_frames
192
+ self.chunk_frames = chunk_frames
193
+
194
+ def load_codec(self, device='cuda'):
195
+ if self.codec is None:
196
+ self.codec = TinyNeuralCodec().to(device)
197
+ self.codec.eval()
198
+ return self.codec
199
+
200
+ def token_logits(self, x, max_frames=None):
201
+ max_frames = max_frames or self.max_frames
202
+ cond = self.norm(x)
203
+ cond = F.silu(self.hidden_norm(self.hidden(cond)))
204
+ stride = max(1, max_frames // max(1, cond.shape[1]))
205
+ logits = self.head(cond)
206
+ logits = logits.repeat_interleave(stride, dim=1)
207
+ if logits.shape[1] > max_frames:
208
+ logits = logits[:, :max_frames, :]
209
+ elif logits.shape[1] < max_frames:
210
+ pad = logits.new_zeros(logits.shape[0], max_frames - logits.shape[1], logits.shape[2])
211
+ logits = torch.cat([logits, pad], dim=1)
212
+ return logits
213
+
214
+ def forward(self, x, max_frames=None):
215
+ return self.token_logits(x, max_frames=max_frames).argmax(dim=-1)
216
+
217
+ def generate_audio(self, x, max_frames=None, return_all=True):
218
+ if max_frames is None:
219
+ max_frames = self.max_frames
220
+ all_tokens = []
221
+ remaining = max_frames
222
+ while remaining > 0:
223
+ chunk = min(remaining, self.chunk_frames)
224
+ tokens = self.forward(x, max_frames=chunk)
225
+ all_tokens.append(tokens)
226
+ remaining -= chunk
227
+ tokens = torch.cat(all_tokens, dim=1)
228
+ codec = self.load_codec(x.device if hasattr(x, 'device') else 'cuda')
229
+ with torch.no_grad():
230
+ waveform = codec(tokens)
231
+ return waveform, tokens
arbitor/encoders/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Encoder sidecar modules for the ARB system.
2
+
3
+ Each module exposes load(), encode(), decode() methods.
4
+ Loaded on-demand as frozen float/int8 sidecars.
5
+ """
6
+ from ..decoders import TinyNeuralCodec, MRFBlock
7
+ from .audio import AudioVQEncoder
8
+ from .pig_vae import load_vae, VAEWrapper
9
+ from .opensora_vae import load_opensora_vae, OpenSoraVAEWrapper
10
+ from .vae2d import VAE2DEncoder, load_vae2d
11
+ from .mel_frontend import MelSpectrogram3Band
arbitor/encoders/audio.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio training encoder β€” VQ encoder for TalkerHead target preparation.
2
+
3
+ Training-only component (~5M float params). Maps audio at 50 Hz to 289-class byte tokens.
4
+ TinyNeuralCodec (the decoder) is in arbitor.components β€” shared with TalkerHead.
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from ..components import TernaryEmbeddingTable
10
+ from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
11
+
12
+
13
+ class TernaryConv1d(nn.Module):
14
+ """Conv1d implemented as unfold + ternary linear projection."""
15
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
16
+ tscale_type=TScaleType.T32, bias=True):
17
+ super().__init__()
18
+ self.in_channels = in_channels
19
+ self.out_channels = out_channels
20
+ self.kernel_size = kernel_size
21
+ self.stride = stride
22
+ self.padding = padding
23
+ self.proj = TernaryScaleTensor(
24
+ in_channels * kernel_size,
25
+ out_channels,
26
+ tscale_type=tscale_type,
27
+ bias=bias,
28
+ )
29
+
30
+ def forward(self, x):
31
+ if self.padding:
32
+ x = F.pad(x, (self.padding, self.padding))
33
+ windows = x.unfold(2, self.kernel_size, self.stride)
34
+ windows = windows.permute(0, 2, 1, 3).reshape(x.size(0), -1, self.in_channels * self.kernel_size)
35
+ return self.proj(windows).permute(0, 2, 1)
36
+
37
+
38
+ class AudioVQEncoder(nn.Module):
39
+ """Encodes audio to discrete byte tokens at 50 Hz for TalkerHead training.
40
+
41
+ Input: [B, 1, T] audio waveform at 16 kHz
42
+ Output: [B, T/320, 288] logits over byte vocab (50 Hz frame rate)
43
+ """
44
+ def __init__(self, vocab=288, codebook_dim=64, downsample_ratios=(4, 4, 4, 5),
45
+ tscale_type=TScaleType.T32):
46
+ super().__init__()
47
+ in_ch = 1
48
+ self.down_blocks = nn.ModuleList()
49
+ for i, ratio in enumerate(downsample_ratios):
50
+ out_ch = min(128, 32 * (2 ** i))
51
+ block = nn.Sequential(
52
+ TernaryConv1d(in_ch, out_ch, kernel_size=ratio * 2, stride=ratio,
53
+ padding=ratio // 2, tscale_type=tscale_type),
54
+ nn.LeakyReLU(0.1),
55
+ TernaryConv1d(out_ch, out_ch, kernel_size=3, padding=1,
56
+ tscale_type=tscale_type),
57
+ nn.LeakyReLU(0.1),
58
+ )
59
+ self.down_blocks.append(block)
60
+ in_ch = out_ch
61
+ self.proj = TernaryScaleTensor(out_ch, codebook_dim, tscale_type=tscale_type, bias=True)
62
+ self.codebook = TernaryEmbeddingTable(vocab, codebook_dim, tscale_type=tscale_type)
63
+ self.out_proj = TernaryScaleTensor(codebook_dim, vocab, tscale_type=tscale_type, bias=True)
64
+
65
+ def forward(self, audio):
66
+ x = audio
67
+ for block in self.down_blocks:
68
+ x = block(x)
69
+ x = x.permute(0, 2, 1)
70
+ x = self.proj(x)
71
+ emb_idx = torch.arange(self.out_proj.out_dim, device=x.device)
72
+ emb = self.codebook(emb_idx).to(device=x.device, dtype=x.dtype)
73
+ dist = torch.cdist(x.float(), emb.unsqueeze(0).float())
74
+ indices = dist.argmin(dim=-1)
75
+ quantized = F.embedding(indices, emb)
76
+ quantized = x + (quantized - x).detach()
77
+ logits = self.out_proj(quantized)
78
+ return logits, indices
79
+
80
+ def encode(self, audio):
81
+ with torch.no_grad():
82
+ _, indices = self.forward(audio)
83
+ return indices
arbitor/encoders/mel_frontend.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mel spectrogram frontend for audio-to-image conversion.
2
+
3
+ Converts [B, T] audio waveform to [B, 3, 64, T_mel] 3-channel mel
4
+ spectrogram (low/mid/high frequency bands β†’ RGB channels) suitable
5
+ for encoding through the 2D VAE encoder.
6
+
7
+ Band split:
8
+ - Channel 0 (low): 0-1000 Hz
9
+ - Channel 1 (mid): 1000-4000 Hz
10
+ - Channel 2 (high): 4000-8000 Hz
11
+ """
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio
15
+
16
+
17
+ class MelSpectrogram3Band(nn.Module):
18
+ """Audio β†’ 3-channel mel spectrogram (low/mid/high bands β†’ RGB).
19
+
20
+ Splits audio into 3 frequency bands and computes mel spectrograms
21
+ independently, stacked as RGB-like 3-channel image for VAE encoding.
22
+ """
23
+ def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512,
24
+ n_mels=64, f_min=0, f_max=8000):
25
+ super().__init__()
26
+ self.sample_rate = sample_rate
27
+ self.n_fft = n_fft
28
+ self.hop_length = hop_length
29
+ self.n_mels = n_mels
30
+
31
+ self.mel_low = torchaudio.transforms.MelSpectrogram(
32
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
33
+ n_mels=n_mels, f_min=f_min, f_max=1000,
34
+ )
35
+ self.mel_mid = torchaudio.transforms.MelSpectrogram(
36
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
37
+ n_mels=n_mels, f_min=1000, f_max=4000,
38
+ )
39
+ self.mel_high = torchaudio.transforms.MelSpectrogram(
40
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
41
+ n_mels=n_mels, f_min=4000, f_max=f_max,
42
+ )
43
+
44
+ def forward(self, waveform):
45
+ if waveform.dim() == 1:
46
+ waveform = waveform.unsqueeze(0)
47
+ elif waveform.dim() == 3:
48
+ if waveform.shape[1] == 1:
49
+ waveform = waveform.squeeze(1)
50
+ else:
51
+ waveform = waveform.mean(dim=1)
52
+
53
+ spec_low = torchaudio.functional.amplitude_to_DB(
54
+ self.mel_low(waveform), multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
55
+ )
56
+ spec_mid = torchaudio.functional.amplitude_to_DB(
57
+ self.mel_mid(waveform), multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
58
+ )
59
+ spec_high = torchaudio.functional.amplitude_to_DB(
60
+ self.mel_high(waveform), multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
61
+ )
62
+
63
+ specs = []
64
+ for spec in [spec_low, spec_mid, spec_high]:
65
+ s_min = spec.amin(dim=(-2, -1), keepdim=True)
66
+ s_max = spec.amax(dim=(-2, -1), keepdim=True)
67
+ s_range = s_max - s_min + 1e-8
68
+ specs.append((spec - s_min) / s_range)
69
+
70
+ return torch.stack(specs, dim=1)
arbitor/encoders/models/__init__.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local model loader β€” loads encoder models from local cache, falls back to HF.
2
+
3
+ Model directories (saved via model.save_pretrained()):
4
+ dinov2-small/ β€” facebook/dinov2-small (21M params, 384-dim) vision
5
+ vit-base/ β€” google/vit-base-patch16-224 (86M, 768-dim) vision fallback
6
+ moonshine-base/ β€” UsefulSensors/moonshine-base (62M, 416-dim) audio
7
+ pig-vae/ β€” Wan2.1 VAE checkpoint (84M params) video latent codec
8
+
9
+ Usage:
10
+ from arbitor.encoders.models import load_encoder, load_processor
11
+
12
+ model = load_encoder("dinov2-small")
13
+ processor = load_processor("dinov2-small", "image")
14
+
15
+ Download models:
16
+ python -m arbitor.encoders.models.download
17
+ """
18
+ import os
19
+
20
+ _MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))
21
+
22
+ # Map short names to (local_dir, hf_repo, type)
23
+ REGISTRY = {
24
+ "dinov2-small": {
25
+ "local": os.path.join(_MODELS_DIR, "dinov2-small"),
26
+ "hf": "facebook/dinov2-small",
27
+ "type": "auto",
28
+ },
29
+ "vit-base": {
30
+ "local": os.path.join(_MODELS_DIR, "vit-base"),
31
+ "hf": "google/vit-base-patch16-224",
32
+ "type": "auto",
33
+ },
34
+ "moonshine-base": {
35
+ "local": os.path.join(_MODELS_DIR, "moonshine-base"),
36
+ "hf": "UsefulSensors/moonshine-base",
37
+ "type": "auto",
38
+ },
39
+ }
40
+
41
+
42
+ def resolve_path(name: str) -> tuple[str, dict]:
43
+ """Return (local_path_or_hf_name, registry_entry)."""
44
+ entry = REGISTRY.get(name)
45
+ if entry is None:
46
+ raise ValueError(f"Unknown model: {name}. Options: {list(REGISTRY.keys())}")
47
+ if os.path.isdir(entry["local"]):
48
+ return entry["local"], entry
49
+ return entry["hf"], entry
50
+
51
+
52
+ def load_encoder(name: str, device=None, **kwargs):
53
+ """Load model from local cache, falling back to HuggingFace.
54
+
55
+ Args:
56
+ name: Short name ("dinov2-small", "vit-base", "moonshine-base")
57
+ device: Optional device to move model to (e.g. "cuda", "cpu")
58
+ Returns:
59
+ Loaded model in eval mode
60
+ """
61
+ from transformers import AutoModel
62
+
63
+ path, entry = resolve_path(name)
64
+ model = AutoModel.from_pretrained(path, low_cpu_mem_usage=True, **kwargs)
65
+ model.eval()
66
+ if device:
67
+ model = model.to(device)
68
+ return model
69
+
70
+
71
+ def load_processor(name: str, modality: str = "image"):
72
+ """Load processor (image processor or feature extractor) from local cache.
73
+
74
+ Args:
75
+ name: Short model name
76
+ modality: "image" for AutoImageProcessor, "audio" for AutoFeatureExtractor
77
+ Returns:
78
+ Processor instance
79
+ """
80
+ path, _ = resolve_path(name)
81
+ if modality == "audio":
82
+ from transformers import AutoFeatureExtractor
83
+ return AutoFeatureExtractor.from_pretrained(path)
84
+ else:
85
+ from transformers import AutoImageProcessor
86
+ return AutoImageProcessor.from_pretrained(path)
arbitor/encoders/models/download.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Download encoder models to local cache (arbitor/encoders/models/).
2
+
3
+ Usage:
4
+ python -m arbitor.encoders.models.download # Download all
5
+ python -m arbitor.encoders.models.download --model pig-vae --convert # Also convert GGUF→safetensors
6
+
7
+ Models are saved to arbitor/encoders/models/{name}/ and loaded from there
8
+ by sequencers and encoders β€” no HuggingFace download needed at runtime.
9
+ """
10
+ import os, sys, argparse, importlib
11
+
12
+ MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))
13
+
14
+ REGISTRY = {
15
+ "pig-vae": {
16
+ "type": "pth",
17
+ "hf_repo": "Wan-AI/Wan2.1-T2V-1.3B",
18
+ "hf_file": "Wan2.1_VAE.pth",
19
+ "desc": "Video VAE (16 latent channels, 84M params)",
20
+ "gguf_repo": "calcuis/pig-vae",
21
+ "gguf_file": "pig_wan_vae_fp32-f16.gguf",
22
+ },
23
+ "opensora-vae": {
24
+ "type": "pipeline",
25
+ "hf_repo": "hpcai-tech/OpenSora-VAE-v1.2",
26
+ "desc": "3D VAE (4 latent channels, 384M params, 8Γ— spatial + 4Γ— temporal compression)",
27
+ },
28
+ }
29
+
30
+
31
+ def convert_gguf_to_safetensors(name: str):
32
+ """Convert GGUF checkpoint to safetensors (pig-vae only)."""
33
+ dest = os.path.join(MODELS_DIR, name)
34
+ gguf_path = os.path.join(dest, f"{name.replace('-', '_')}_fp32-f16.gguf")
35
+ # Try alternate names
36
+ if not os.path.isfile(gguf_path):
37
+ alt = os.path.join(dest, "pig_wan_vae_fp32-f16.gguf")
38
+ if os.path.isfile(alt):
39
+ gguf_path = alt
40
+ if not os.path.isfile(gguf_path):
41
+ print(f" No GGUF file found in {dest}")
42
+ return False
43
+
44
+ print(f" Converting {gguf_path} to safetensors...", flush=True)
45
+ import gguf
46
+ import safetensors.torch
47
+
48
+ reader = gguf.GGUFReader(gguf_path)
49
+ state_dict = {t.name: __import__('torch').tensor(t.data) for t in reader.tensors}
50
+
51
+ safetensors_path = os.path.join(dest, "model.safetensors")
52
+ safetensors.torch.save_file(state_dict, safetensors_path)
53
+ size = os.path.getsize(safetensors_path)
54
+ print(f" βœ“ Saved {safetensors_path} ({size/1e6:.0f} MB, {len(state_dict)} tensors)", flush=True)
55
+ return True
56
+
57
+
58
+ def download_model(name: str, convert: bool = False):
59
+ """Download a single model to local cache."""
60
+ entry = REGISTRY.get(name)
61
+ if entry is None:
62
+ print(f"Unknown model: {name}. Options: {list(REGISTRY.keys())}")
63
+ return False
64
+
65
+ dest = os.path.join(MODELS_DIR, name)
66
+ os.makedirs(dest, exist_ok=True)
67
+
68
+ if entry["type"] == "auto":
69
+ from transformers import AutoModel
70
+ print(f"Downloading {name} ({entry['desc']})...", flush=True)
71
+ model = AutoModel.from_pretrained(entry["hf_repo"], low_cpu_mem_usage=True)
72
+ model.save_pretrained(dest)
73
+ print(f" βœ“ {name} saved to {dest}", flush=True)
74
+
75
+ if "dinov2" in name or "vit" in name:
76
+ from transformers import AutoImageProcessor
77
+ proc = AutoImageProcessor.from_pretrained(entry["hf_repo"])
78
+ proc.save_pretrained(dest)
79
+ elif "moonshine" in name:
80
+ from transformers import AutoFeatureExtractor
81
+ proc = AutoFeatureExtractor.from_pretrained(entry["hf_repo"])
82
+ proc.save_pretrained(dest)
83
+
84
+ elif entry["type"] == "pth":
85
+ from huggingface_hub import hf_hub_download
86
+ print(f"Downloading {name} ({entry['desc']})...", flush=True)
87
+ hf_hub_download(entry["hf_repo"], entry["hf_file"],
88
+ local_dir=dest, local_dir_use_symlinks=False)
89
+ print(f" βœ“ {name} .pth saved to {dest}", flush=True)
90
+
91
+ if convert:
92
+ convert_gguf_to_safetensors(name)
93
+
94
+ return True
95
+
96
+
97
+ def download_all(convert: bool = False):
98
+ success = 0
99
+ for name in REGISTRY:
100
+ if os.path.isdir(os.path.join(MODELS_DIR, name)):
101
+ existing = [f for f in os.listdir(os.path.join(MODELS_DIR, name))
102
+ if f.endswith(('.safetensors', '.pt', '.pth', '.gguf'))]
103
+ if existing:
104
+ print(f" β—‹ {name} already exists β€” skipping")
105
+ success += 1
106
+ continue
107
+ if download_model(name, convert=convert):
108
+ success += 1
109
+ print(f"\nDownloaded {success}/{len(REGISTRY)} models to {MODELS_DIR}")
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser(description="Download encoder models for ARB")
114
+ parser.add_argument("--model", type=str, default=None,
115
+ help=f"Model ({', '.join(REGISTRY.keys())})")
116
+ parser.add_argument("--convert", action="store_true",
117
+ help="Convert pig-vae GGUF→safetensors after download")
118
+ parser.add_argument("--list", action="store_true", help="List available models")
119
+ args = parser.parse_args()
120
+
121
+ if args.list:
122
+ for name, info in REGISTRY.items():
123
+ d = os.path.join(MODELS_DIR, name)
124
+ files = os.listdir(d) if os.path.isdir(d) else []
125
+ status = "βœ“" if any(f.endswith(('.safetensors', '.pt', '.pth')) for f in files) else "β—‹"
126
+ print(f" {status} {name:<20} {info['desc']}")
127
+ sys.exit(0)
128
+
129
+ if args.model:
130
+ download_model(args.model, convert=args.convert)
131
+ else:
132
+ download_all(convert=args.convert)
arbitor/encoders/models/opensora-vae/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "VideoAutoencoderPipeline"
4
+ ],
5
+ "cal_loss": false,
6
+ "freeze_vae_2d": false,
7
+ "from_pretrained": null,
8
+ "micro_frame_size": 17,
9
+ "model_type": "VideoAutoencoderPipeline",
10
+ "scale": [
11
+ 3.85,
12
+ 2.32,
13
+ 2.33,
14
+ 3.06
15
+ ],
16
+ "shift": [
17
+ -0.1,
18
+ 0.34,
19
+ 0.27,
20
+ 0.98
21
+ ],
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.36.2",
24
+ "vae_2d": {
25
+ "from_pretrained": "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
26
+ "local_files_only": false,
27
+ "micro_batch_size": 4,
28
+ "subfolder": "vae",
29
+ "type": "VideoAutoencoderKL"
30
+ },
31
+ "vae_temporal": {
32
+ "from_pretrained": null,
33
+ "type": "VAE_Temporal_SD"
34
+ }
35
+ }
arbitor/encoders/models/opensora-vae/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:057f368f538ca04540c0728a6b8ef80ff529077c5a1c4ba810eb8ba017b8d7c9
3
+ size 1573430548
arbitor/encoders/models/pig-vae/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5a9bb06d188abdf1585142159b84e2c6c8aaa3e64bb5c5792b7316ee2b44785
3
+ size 253879612
arbitor/encoders/opensora_vae.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Open-Sora 3D VAE v1.2 sidecar module.
2
+
3
+ Latent: [B, 4, T/4, H/8, W/8]
4
+ 8Γ— spatial compression, 4Γ— temporal compression.
5
+ Frozen float32 sidecar (no gradients).
6
+
7
+ Uses PixArt SDXL VAE (from diffusers) for spatial encoding/decoding.
8
+ Temporal VAE requires opensora package or custom module loading.
9
+ """
10
+ import os
11
+ import torch
12
+ import torch.nn as nn
13
+ from safetensors import safe_open
14
+
15
+ _LOCAL_VAE_DIR = os.path.join(os.path.dirname(__file__), "models", "opensora-vae")
16
+ _VAE_CONFIG = {
17
+ "scale": (3.85, 2.32, 2.33, 3.06),
18
+ "shift": (-0.10, 0.34, 0.27, 0.98),
19
+ "micro_frame_size": 17,
20
+ }
21
+ _QUANTO_CLASS_MARKERS = ("Q", "Quanto", "Quantized", "WeightQ")
22
+
23
+
24
+ def _mark_quantized_sidecar(module, quant_type, applied):
25
+ module._arb_quantize_requested = quant_type
26
+ module._arb_quantized_int8 = bool(applied and quant_type == "int8")
27
+ module._arb_quantized = bool(applied)
28
+ for p in module.parameters():
29
+ p.requires_grad = False
30
+ return module
31
+
32
+
33
+ def _has_quantized_modules(module):
34
+ return any(
35
+ any(marker in type(child).__name__ for marker in _QUANTO_CLASS_MARKERS)
36
+ for child in module.modules()
37
+ )
38
+
39
+
40
+ def _freeze_sidecar(model, quantize_requested=None, quantized=False):
41
+ _mark_quantized_sidecar(model, quantize_requested, quantized)
42
+ return model
43
+
44
+
45
+ def _quantize_int8_if_requested(model, quantize):
46
+ if quantize is None:
47
+ model = model.to(torch.bfloat16)
48
+ _mark_quantized_sidecar(model, quantize, False)
49
+ return model
50
+ try:
51
+ from optimum.quanto import quantize, freeze
52
+ qtype = {"int8": qint8}.get(quantize)
53
+ if qtype is None:
54
+ model = model.to(torch.bfloat16)
55
+ _mark_quantized_sidecar(model, quantize, False)
56
+ return model
57
+ quantize(model, weights=qtype)
58
+ freeze(model)
59
+ _mark_quantized_sidecar(model, quantize, _has_quantized_modules(model))
60
+ except ImportError:
61
+ model = model.to(torch.bfloat16)
62
+ _mark_quantized_sidecar(model, quantize, False)
63
+ return model
64
+
65
+
66
+ def load_opensora_vae(device="cuda", quantize=None):
67
+ """Load Open-Sora 3D VAE as frozen float32 sidecar.
68
+
69
+ Loads the spatial VAE from PixArt SDXL (diffusers) and the temporal
70
+ VAE from local safetensors. Falls back to spatial-only if temporal
71
+ module can't be loaded.
72
+ """
73
+ try:
74
+ from diffusers import AutoencoderKL
75
+ except ImportError:
76
+ raise RuntimeError("need diffusers for Open-Sora VAE spatial component")
77
+
78
+ # Load spatial VAE
79
+ spatial_vae = AutoencoderKL.from_pretrained(
80
+ "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
81
+ subfolder="vae",
82
+ torch_dtype=torch.float32,
83
+ ).to(device)
84
+ spatial_vae.eval()
85
+
86
+ # Try to load temporal VAE weights
87
+ temporal_state = {}
88
+ safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
89
+ if os.path.isfile(safetensors_path):
90
+ with safe_open(safetensors_path, framework="pt") as f:
91
+ for k in f.keys():
92
+ if k.startswith("temporal_vae."):
93
+ temporal_state[k] = f.get_tensor(k)
94
+ if k.startswith("scale"):
95
+ temporal_state["scale"] = f.get_tensor(k)
96
+ if k.startswith("shift"):
97
+ temporal_state["shift"] = f.get_tensor(k)
98
+
99
+ _freeze_sidecar(spatial_vae, quantize, False)
100
+ return OpenSoraVAEWrapper(spatial_vae, temporal_state)
101
+
102
+
103
+ class OpenSoraVAEWrapper(nn.Module):
104
+ def __init__(self, spatial_vae, temporal_state=None):
105
+ super().__init__()
106
+ self.spatial = spatial_vae
107
+ self.latent_channels = 4
108
+ self.scale_factor_spatial = 8
109
+ self.scale_factor_temporal = 4
110
+ self.temporal_state = temporal_state
111
+ self.temporal_loaded = temporal_state is not None and len(temporal_state) > 0
112
+
113
+ @torch.no_grad()
114
+ def encode(self, video_tensor):
115
+ """Encode video tensor: [B,3,T,H,W] β†’ [B,4,T/4,H/8,W/8]."""
116
+ B, C, T, H, W = video_tensor.shape
117
+ # Process frame-by-frame through spatial VAE
118
+ latents = []
119
+ for t in range(T):
120
+ frame = video_tensor[:, :, t, :, :]
121
+ latent = self.spatial.encode(frame).latent_dist.sample()
122
+ latents.append(latent)
123
+ latent = torch.stack(latents, dim=2)
124
+ # Scale
125
+ latent = latent * 0.18215
126
+ # Temporal downsample (simple: take every 4th)
127
+ if latent.shape[2] >= 4:
128
+ latent = latent[:, :, ::4, :, :]
129
+ return latent
130
+
131
+ @torch.no_grad()
132
+ def decode(self, latents, num_frames=None):
133
+ """Decode latents: [B,4,T/4,H/8,W/8] β†’ [B,3,T,H,W]."""
134
+ B, C, T, H, W = latents.shape
135
+ # Temporal upsample (repeat each latent 4Γ—)
136
+ latents = latents.repeat_interleave(4, dim=2)
137
+ # Unscale
138
+ latents = latents / 0.18215
139
+ # Decode frame-by-frame
140
+ frames = []
141
+ for t in range(latents.shape[2]):
142
+ frame = latents[:, :, t, :, :]
143
+ decoded = self.spatial.decode(frame).sample
144
+ frames.append(decoded)
145
+ return torch.stack(frames, dim=2)
arbitor/encoders/opensora_vae_modules/autoencoder_2d.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from Flux
2
+ #
3
+ # Copyright 2024 Black Forest Labs
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # This source code is licensed under the license found in the
18
+ # LICENSE file in the root directory of this source tree.
19
+
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ from einops import rearrange
24
+ from torch import Tensor, nn
25
+ from torch.nn.functional import silu as swish
26
+
27
+ from opensora.registry import MODELS
28
+ from opensora.utils.ckpt import load_checkpoint
29
+
30
+ from .utils import DiagonalGaussianDistribution
31
+
32
+
33
+ @dataclass
34
+ class AutoEncoderConfig:
35
+ from_pretrained: str | None
36
+ cache_dir: str | None
37
+ resolution: int
38
+ in_channels: int
39
+ ch: int
40
+ out_ch: int
41
+ ch_mult: list[int]
42
+ num_res_blocks: int
43
+ z_channels: int
44
+ scale_factor: float
45
+ shift_factor: float
46
+ sample: bool = True
47
+
48
+
49
+ class AttnBlock(nn.Module):
50
+ def __init__(self, in_channels: int):
51
+ super().__init__()
52
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
53
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
54
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
55
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
56
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
57
+
58
+ def attention(self, h_: Tensor) -> Tensor:
59
+ h_ = self.norm(h_)
60
+ q = self.q(h_)
61
+ k = self.k(h_)
62
+ v = self.v(h_)
63
+ b, c, h, w = q.shape
64
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
65
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
66
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
67
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
68
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w)
69
+
70
+ def forward(self, x: Tensor) -> Tensor:
71
+ return x + self.proj_out(self.attention(x))
72
+
73
+
74
+ class ResnetBlock(nn.Module):
75
+ def __init__(self, in_channels: int, out_channels: int):
76
+ super().__init__()
77
+ self.in_channels = in_channels
78
+ out_channels = in_channels if out_channels is None else out_channels
79
+ self.out_channels = out_channels
80
+
81
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
82
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
83
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
84
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
85
+ if self.in_channels != self.out_channels:
86
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
87
+
88
+ def forward(self, x):
89
+ h = x
90
+ h = self.norm1(h)
91
+ h = swish(h)
92
+ h = self.conv1(h)
93
+
94
+ h = self.norm2(h)
95
+ h = swish(h)
96
+ h = self.conv2(h)
97
+
98
+ if self.in_channels != self.out_channels:
99
+ x = self.nin_shortcut(x)
100
+
101
+ return x + h
102
+
103
+
104
+ class Downsample(nn.Module):
105
+ def __init__(self, in_channels: int):
106
+ super().__init__()
107
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
108
+
109
+ def forward(self, x: Tensor) -> Tensor:
110
+ pad = (0, 1, 0, 1)
111
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
112
+ return self.conv(x)
113
+
114
+
115
+ class Upsample(nn.Module):
116
+ def __init__(self, in_channels: int):
117
+ super().__init__()
118
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
119
+
120
+ def forward(self, x: Tensor) -> Tensor:
121
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
122
+ return self.conv(x)
123
+
124
+
125
+ class Encoder(nn.Module):
126
+ def __init__(self, config: AutoEncoderConfig):
127
+ super().__init__()
128
+ self.ch = config.ch
129
+ self.num_resolutions = len(config.ch_mult)
130
+ self.num_res_blocks = config.num_res_blocks
131
+ self.resolution = config.resolution
132
+ self.in_channels = config.in_channels
133
+
134
+ # downsampling
135
+ self.conv_in = nn.Conv2d(config.in_channels, self.ch, kernel_size=3, stride=1, padding=1)
136
+
137
+ curr_res = config.resolution
138
+ in_ch_mult = (1,) + tuple(config.ch_mult)
139
+ self.in_ch_mult = in_ch_mult
140
+ self.down = nn.ModuleList()
141
+ block_in = self.ch
142
+ for i_level in range(self.num_resolutions):
143
+ block = nn.ModuleList()
144
+ attn = nn.ModuleList()
145
+ block_in = config.ch * in_ch_mult[i_level]
146
+ block_out = config.ch * config.ch_mult[i_level]
147
+ for _ in range(self.num_res_blocks):
148
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
149
+ block_in = block_out
150
+ down = nn.Module()
151
+ down.block = block
152
+ down.attn = attn
153
+ if i_level != self.num_resolutions - 1:
154
+ down.downsample = Downsample(block_in)
155
+ curr_res = curr_res // 2
156
+ self.down.append(down)
157
+
158
+ # middle
159
+ self.mid = nn.Module()
160
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
161
+ self.mid.attn_1 = AttnBlock(block_in)
162
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
163
+
164
+ # end
165
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
166
+ self.conv_out = nn.Conv2d(block_in, 2 * config.z_channels, kernel_size=3, stride=1, padding=1)
167
+
168
+ def forward(self, x: Tensor) -> Tensor:
169
+ # downsampling
170
+ hs = [self.conv_in(x)]
171
+ for i_level in range(self.num_resolutions):
172
+ for i_block in range(self.num_res_blocks):
173
+ h = self.down[i_level].block[i_block](hs[-1])
174
+ if len(self.down[i_level].attn) > 0:
175
+ h = self.down[i_level].attn[i_block](h)
176
+ hs.append(h)
177
+ if i_level != self.num_resolutions - 1:
178
+ hs.append(self.down[i_level].downsample(hs[-1]))
179
+
180
+ # middle
181
+ h = hs[-1]
182
+ h = self.mid.block_1(h)
183
+ h = self.mid.attn_1(h)
184
+ h = self.mid.block_2(h)
185
+ # end
186
+ h = self.norm_out(h)
187
+ h = swish(h)
188
+ h = self.conv_out(h)
189
+ return h
190
+
191
+
192
+ class Decoder(nn.Module):
193
+ def __init__(self, config: AutoEncoderConfig):
194
+ super().__init__()
195
+ self.ch = config.ch
196
+ self.num_resolutions = len(config.ch_mult)
197
+ self.num_res_blocks = config.num_res_blocks
198
+ self.resolution = config.resolution
199
+ self.in_channels = config.in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ block_in = config.ch * config.ch_mult[self.num_resolutions - 1]
203
+ curr_res = config.resolution // 2 ** (self.num_resolutions - 1)
204
+ self.z_shape = (1, config.z_channels, curr_res, curr_res)
205
+
206
+ # z to block_in
207
+ self.conv_in = nn.Conv2d(config.z_channels, block_in, kernel_size=3, stride=1, padding=1)
208
+
209
+ # middle
210
+ self.mid = nn.Module()
211
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
212
+ self.mid.attn_1 = AttnBlock(block_in)
213
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
214
+
215
+ # upsampling
216
+ self.up = nn.ModuleList()
217
+ for i_level in reversed(range(self.num_resolutions)):
218
+ block = nn.ModuleList()
219
+ attn = nn.ModuleList()
220
+ block_out = config.ch * config.ch_mult[i_level]
221
+ for _ in range(self.num_res_blocks + 1):
222
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
223
+ block_in = block_out
224
+ up = nn.Module()
225
+ up.block = block
226
+ up.attn = attn
227
+ if i_level != 0:
228
+ up.upsample = Upsample(block_in)
229
+ curr_res = curr_res * 2
230
+ self.up.insert(0, up) # prepend to get consistent order
231
+
232
+ # end
233
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
234
+ self.conv_out = nn.Conv2d(block_in, config.out_ch, kernel_size=3, stride=1, padding=1)
235
+
236
+ def forward(self, z: Tensor) -> Tensor:
237
+ # z to block_in
238
+ h = self.conv_in(z)
239
+
240
+ # middle
241
+ h = self.mid.block_1(h)
242
+ h = self.mid.attn_1(h)
243
+ h = self.mid.block_2(h)
244
+
245
+ # upsampling
246
+ for i_level in reversed(range(self.num_resolutions)):
247
+ for i_block in range(self.num_res_blocks + 1):
248
+ h = self.up[i_level].block[i_block](h)
249
+ if len(self.up[i_level].attn) > 0:
250
+ h = self.up[i_level].attn[i_block](h)
251
+ if i_level != 0:
252
+ h = self.up[i_level].upsample(h)
253
+
254
+ # end
255
+ h = self.norm_out(h)
256
+ h = swish(h)
257
+ return self.conv_out(h)
258
+
259
+
260
+ class AutoEncoder(nn.Module):
261
+ def __init__(self, config: AutoEncoderConfig):
262
+ super().__init__()
263
+ self.encoder = Encoder(config)
264
+ self.decoder = Decoder(config)
265
+ self.scale_factor = config.scale_factor
266
+ self.shift_factor = config.shift_factor
267
+ self.sample = config.sample
268
+
269
+ def encode_(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution]:
270
+ T = x.shape[2]
271
+ x = rearrange(x, "b c t h w -> (b t) c h w")
272
+ params = self.encoder(x)
273
+ params = rearrange(params, "(b t) c h w -> b c t h w", t=T)
274
+ posterior = DiagonalGaussianDistribution(params)
275
+ if self.sample:
276
+ z = posterior.sample()
277
+ else:
278
+ z = posterior.mode()
279
+ z = self.scale_factor * (z - self.shift_factor)
280
+ return z, posterior
281
+
282
+ def encode(self, x: Tensor) -> Tensor:
283
+ return self.encode_(x)[0]
284
+
285
+ def decode(self, z: Tensor) -> Tensor:
286
+ T = z.shape[2]
287
+ z = rearrange(z, "b c t h w -> (b t) c h w")
288
+ z = z / self.scale_factor + self.shift_factor
289
+ x = self.decoder(z)
290
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=T)
291
+ return x
292
+
293
+ def forward(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution, Tensor]:
294
+ # encode
295
+ x.shape[2]
296
+ z, posterior = self.encode_(x)
297
+ # decode
298
+ x_rec = self.decode(z)
299
+
300
+ return x_rec, posterior, z
301
+
302
+ def get_last_layer(self):
303
+ return self.decoder.conv_out.weight
304
+
305
+
306
+ @MODELS.register_module("autoencoder_2d")
307
+ def AutoEncoderFlux(
308
+ from_pretrained: str,
309
+ cache_dir=None,
310
+ resolution=256,
311
+ in_channels=3,
312
+ ch=128,
313
+ out_ch=3,
314
+ ch_mult=[1, 2, 4, 4],
315
+ num_res_blocks=2,
316
+ z_channels=16,
317
+ scale_factor=0.3611,
318
+ shift_factor=0.1159,
319
+ device_map: str | torch.device = "cuda",
320
+ torch_dtype: torch.dtype = torch.bfloat16,
321
+ ) -> AutoEncoder:
322
+ config = AutoEncoderConfig(
323
+ from_pretrained=from_pretrained,
324
+ cache_dir=cache_dir,
325
+ resolution=resolution,
326
+ in_channels=in_channels,
327
+ ch=ch,
328
+ out_ch=out_ch,
329
+ ch_mult=ch_mult,
330
+ num_res_blocks=num_res_blocks,
331
+ z_channels=z_channels,
332
+ scale_factor=scale_factor,
333
+ shift_factor=shift_factor,
334
+ )
335
+ with torch.device(device_map):
336
+ model = AutoEncoder(config).to(torch_dtype)
337
+ if from_pretrained:
338
+ model = load_checkpoint(model, from_pretrained, cache_dir=cache_dir, device_map=device_map)
339
+ return model
arbitor/encoders/opensora_vae_modules/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from diffusers==0.29.2 and HunyuanVideo
2
+ #
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # Copyright 2024 HunyuanVideo
18
+ #
19
+ # This source code is licensed under the license found in the
20
+ # LICENSE file in the root directory of this source tree.
21
+
22
+
23
+ from dataclasses import dataclass
24
+ from typing import Dict, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+
30
+ from opensora.registry import MODELS
31
+ from opensora.utils.ckpt import load_checkpoint
32
+
33
+ try:
34
+ # This diffusers is modified and packed in the mirror.
35
+ from diffusers.loaders import FromOriginalVAEMixin
36
+ except ImportError:
37
+ # Use this to be compatible with the original diffusers.
38
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
39
+
40
+ from diffusers.models.attention_processor import (
41
+ ADDED_KV_ATTENTION_PROCESSORS,
42
+ CROSS_ATTENTION_PROCESSORS,
43
+ Attention,
44
+ AttentionProcessor,
45
+ AttnAddedKVProcessor,
46
+ AttnProcessor,
47
+ )
48
+ from diffusers.models.modeling_utils import ModelMixin
49
+ from diffusers.utils.accelerate_utils import apply_forward_hook
50
+
51
+ from opensora.models.hunyuan_vae.vae import (
52
+ DecoderCausal3D,
53
+ DecoderOutput,
54
+ DiagonalGaussianDistribution,
55
+ EncoderCausal3D,
56
+ )
57
+
58
+
59
+ @dataclass
60
+ class AutoEncoder3DConfig:
61
+ from_pretrained: str | None
62
+ act_fn: str = "silu"
63
+ in_channels: int = 3
64
+ out_channels: int = 3
65
+ latent_channels: int = 16
66
+ layers_per_block: int = 2
67
+ norm_num_groups: int = 32
68
+ scale_factor: float = 0.476986
69
+ shift_factor: float = 0
70
+ time_compression_ratio: int = 4
71
+ spatial_compression_ratio: int = 8
72
+ mid_block_add_attention: bool = True
73
+ block_out_channels: tuple[int] = (128, 256, 512, 512)
74
+ sample_size: int = 256
75
+ sample_tsize: int = 64
76
+ use_slicing: bool = False
77
+ use_spatial_tiling: bool = False
78
+ use_temporal_tiling: bool = False
79
+ tile_overlap_factor: float = 0.25
80
+ dropout: float = 0.0
81
+ channel: bool = False
82
+
83
+
84
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
85
+ r"""
86
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
87
+
88
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
89
+ for all models (such as downloading or saving).
90
+ """
91
+
92
+ _supports_gradient_checkpointing = True
93
+
94
+ @register_to_config
95
+ def __init__(self, config: AutoEncoder3DConfig):
96
+ super().__init__()
97
+
98
+ self.scale_factor = config.scale_factor
99
+ self.shift_factor = config.shift_factor
100
+
101
+ self.time_compression_ratio = config.time_compression_ratio
102
+ self.spatial_compression_ratio = config.spatial_compression_ratio
103
+ self.z_channels = config.latent_channels
104
+
105
+ self.encoder = EncoderCausal3D(
106
+ in_channels=config.in_channels,
107
+ out_channels=config.latent_channels,
108
+ block_out_channels=config.block_out_channels,
109
+ layers_per_block=config.layers_per_block,
110
+ act_fn=config.act_fn,
111
+ norm_num_groups=config.norm_num_groups,
112
+ double_z=True,
113
+ time_compression_ratio=config.time_compression_ratio,
114
+ spatial_compression_ratio=config.spatial_compression_ratio,
115
+ mid_block_add_attention=config.mid_block_add_attention,
116
+ dropout=config.dropout,
117
+ )
118
+
119
+ self.decoder = DecoderCausal3D(
120
+ in_channels=config.latent_channels,
121
+ out_channels=config.out_channels,
122
+ block_out_channels=config.block_out_channels,
123
+ layers_per_block=config.layers_per_block,
124
+ norm_num_groups=config.norm_num_groups,
125
+ act_fn=config.act_fn,
126
+ time_compression_ratio=config.time_compression_ratio,
127
+ spatial_compression_ratio=config.spatial_compression_ratio,
128
+ mid_block_add_attention=config.mid_block_add_attention,
129
+ dropout=config.dropout,
130
+ )
131
+
132
+ self.quant_conv = nn.Conv3d(2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1)
133
+ self.post_quant_conv = nn.Conv3d(config.latent_channels, config.latent_channels, kernel_size=1)
134
+
135
+ self.use_slicing = config.use_slicing
136
+ self.use_spatial_tiling = config.use_spatial_tiling
137
+ self.use_temporal_tiling = config.use_temporal_tiling
138
+
139
+ # only relevant if vae tiling is enabled
140
+ self.tile_sample_min_tsize = config.sample_tsize
141
+ self.tile_latent_min_tsize = config.sample_tsize // config.time_compression_ratio
142
+
143
+ self.tile_sample_min_size = config.sample_size
144
+ sample_size = config.sample_size[0] if isinstance(config.sample_size, (list, tuple)) else config.sample_size
145
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(config.block_out_channels) - 1)))
146
+ self.tile_overlap_factor = config.tile_overlap_factor
147
+
148
+ def enable_temporal_tiling(self, use_tiling: bool = True):
149
+ self.use_temporal_tiling = use_tiling
150
+
151
+ def disable_temporal_tiling(self):
152
+ self.enable_temporal_tiling(False)
153
+
154
+ def enable_spatial_tiling(self, use_tiling: bool = True):
155
+ self.use_spatial_tiling = use_tiling
156
+
157
+ def disable_spatial_tiling(self):
158
+ self.enable_spatial_tiling(False)
159
+
160
+ def enable_tiling(self, use_tiling: bool = True):
161
+ r"""
162
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
163
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
164
+ processing larger videos.
165
+ """
166
+ self.enable_spatial_tiling(use_tiling)
167
+ self.enable_temporal_tiling(use_tiling)
168
+
169
+ def disable_tiling(self):
170
+ r"""
171
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
172
+ decoding in one step.
173
+ """
174
+ self.disable_spatial_tiling()
175
+ self.disable_temporal_tiling()
176
+
177
+ def enable_slicing(self):
178
+ r"""
179
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
180
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
181
+ """
182
+ self.use_slicing = True
183
+
184
+ def disable_slicing(self):
185
+ r"""
186
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
187
+ decoding in one step.
188
+ """
189
+ self.use_slicing = False
190
+
191
+ @property
192
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
193
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
194
+ r"""
195
+ Returns:
196
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
197
+ indexed by its weight name.
198
+ """
199
+ # set recursively
200
+ processors = {}
201
+
202
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
203
+ if hasattr(module, "get_processor"):
204
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
205
+
206
+ for sub_name, child in module.named_children():
207
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
208
+
209
+ return processors
210
+
211
+ for name, module in self.named_children():
212
+ fn_recursive_add_processors(name, module, processors)
213
+
214
+ return processors
215
+
216
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
217
+ def set_attn_processor(
218
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
219
+ ):
220
+ r"""
221
+ Sets the attention processor to use to compute attention.
222
+
223
+ Parameters:
224
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
225
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
226
+ for **all** `Attention` layers.
227
+
228
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
229
+ processor. This is strongly recommended when setting trainable attention processors.
230
+
231
+ """
232
+ count = len(self.attn_processors.keys())
233
+
234
+ if isinstance(processor, dict) and len(processor) != count:
235
+ raise ValueError(
236
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
237
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
238
+ )
239
+
240
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
241
+ if hasattr(module, "set_processor"):
242
+ if not isinstance(processor, dict):
243
+ module.set_processor(processor, _remove_lora=_remove_lora)
244
+ else:
245
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
246
+
247
+ for sub_name, child in module.named_children():
248
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
249
+
250
+ for name, module in self.named_children():
251
+ fn_recursive_attn_processor(name, module, processor)
252
+
253
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
254
+ def set_default_attn_processor(self):
255
+ """
256
+ Disables custom attention processors and sets the default attention implementation.
257
+ """
258
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
259
+ processor = AttnAddedKVProcessor()
260
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
261
+ processor = AttnProcessor()
262
+ else:
263
+ raise ValueError(
264
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
265
+ )
266
+
267
+ self.set_attn_processor(processor, _remove_lora=True)
268
+
269
+ @apply_forward_hook
270
+ def encode(
271
+ self,
272
+ x: torch.FloatTensor,
273
+ sample_posterior: bool = True,
274
+ return_posterior: bool = False,
275
+ generator: Optional[torch.Generator] = None,
276
+ ) -> Union[torch.FloatTensor, Tuple[DiagonalGaussianDistribution]]:
277
+ """
278
+ Encode a batch of images/videos into latents.
279
+
280
+ Args:
281
+ x (`torch.FloatTensor`): Input batch of images/videos.
282
+ return_dict (`bool`, *optional*, defaults to `True`):
283
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
284
+
285
+ Returns:
286
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
287
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
288
+ """
289
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
290
+
291
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
292
+ posterior = self.temporal_tiled_encode(x)
293
+ elif self.use_spatial_tiling and (
294
+ x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
295
+ ):
296
+ posterior = self.spatial_tiled_encode(x)
297
+ else:
298
+ if self.use_slicing and x.shape[0] > 1:
299
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
300
+ h = torch.cat(encoded_slices)
301
+ else:
302
+ h = self.encoder(x)
303
+ moments = self.quant_conv(h)
304
+ posterior = DiagonalGaussianDistribution(moments)
305
+
306
+ if sample_posterior:
307
+ z = posterior.sample(generator=generator)
308
+ else:
309
+ z = posterior.mode()
310
+
311
+ z = self.scale_factor * (z - self.shift_factor) # shift & scale
312
+
313
+ if return_posterior:
314
+ return z, posterior
315
+ else:
316
+ return z
317
+
318
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
319
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
320
+
321
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
322
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
323
+
324
+ if self.use_spatial_tiling and (
325
+ z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
326
+ ):
327
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
328
+
329
+ z = self.post_quant_conv(z)
330
+ dec = self.decoder(z)
331
+
332
+ if not return_dict:
333
+ return (dec,)
334
+
335
+ return DecoderOutput(sample=dec)
336
+
337
+ @apply_forward_hook
338
+ def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
339
+ """
340
+ Decode a batch of images/videos.
341
+
342
+ Args:
343
+ z (`torch.FloatTensor`): Input batch of latent vectors.
344
+
345
+ Returns:
346
+ [`~models.vae.DecoderOutput`] or `tuple`:
347
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
348
+ returned.
349
+
350
+ """
351
+ z = z / self.scale_factor + self.shift_factor # scale & shift
352
+
353
+ if self.use_slicing and z.shape[0] > 1:
354
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
355
+ decoded = torch.cat(decoded_slices)
356
+ else:
357
+ decoded = self._decode(z).sample
358
+ return decoded
359
+
360
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
361
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
362
+ for y in range(blend_extent):
363
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
364
+ y / blend_extent
365
+ )
366
+ return b
367
+
368
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
369
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
370
+ for x in range(blend_extent):
371
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
372
+ x / blend_extent
373
+ )
374
+ return b
375
+
376
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
377
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
378
+ for x in range(blend_extent):
379
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
380
+ x / blend_extent
381
+ )
382
+ return b
383
+
384
+ def spatial_tiled_encode(self, x: torch.FloatTensor, return_moments: bool = False) -> DiagonalGaussianDistribution:
385
+ r"""Encode a batch of images/videos using a tiled encoder.
386
+
387
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
388
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
389
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
390
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
391
+ output, but they should be much less noticeable.
392
+
393
+ Args:
394
+ x (`torch.FloatTensor`): Input batch of images/videos.
395
+ return_dict (`bool`, *optional*, defaults to `True`):
396
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
397
+
398
+ Returns:
399
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
400
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
401
+ `tuple` is returned.
402
+ """
403
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
404
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
405
+ row_limit = self.tile_latent_min_size - blend_extent
406
+
407
+ # Split video into tiles and encode them separately.
408
+ rows = []
409
+ for i in range(0, x.shape[-2], overlap_size):
410
+ row = []
411
+ for j in range(0, x.shape[-1], overlap_size):
412
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
413
+ tile = self.encoder(tile)
414
+ tile = self.quant_conv(tile)
415
+ row.append(tile)
416
+ rows.append(row)
417
+ result_rows = []
418
+ for i, row in enumerate(rows):
419
+ result_row = []
420
+ for j, tile in enumerate(row):
421
+ # blend the above tile and the left tile
422
+ # to the current tile and add the current tile to the result row
423
+ if i > 0:
424
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
425
+ if j > 0:
426
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
427
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
428
+ result_rows.append(torch.cat(result_row, dim=-1))
429
+
430
+ moments = torch.cat(result_rows, dim=-2)
431
+ if return_moments:
432
+ return moments
433
+ posterior = DiagonalGaussianDistribution(moments)
434
+ return posterior
435
+
436
+ def spatial_tiled_decode(
437
+ self, z: torch.FloatTensor, return_dict: bool = True
438
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
439
+ r"""
440
+ Decode a batch of images/videos using a tiled decoder.
441
+
442
+ Args:
443
+ z (`torch.FloatTensor`): Input batch of latent vectors.
444
+ return_dict (`bool`, *optional*, defaults to `True`):
445
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
446
+
447
+ Returns:
448
+ [`~models.vae.DecoderOutput`] or `tuple`:
449
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
450
+ returned.
451
+ """
452
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
453
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
454
+ row_limit = self.tile_sample_min_size - blend_extent
455
+
456
+ # Split z into overlapping tiles and decode them separately.
457
+ # The tiles have an overlap to avoid seams between tiles.
458
+ rows = []
459
+ for i in range(0, z.shape[-2], overlap_size):
460
+ row = []
461
+ for j in range(0, z.shape[-1], overlap_size):
462
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
463
+ tile = self.post_quant_conv(tile)
464
+ decoded = self.decoder(tile)
465
+ row.append(decoded)
466
+ rows.append(row)
467
+ result_rows = []
468
+ for i, row in enumerate(rows):
469
+ result_row = []
470
+ for j, tile in enumerate(row):
471
+ # blend the above tile and the left tile
472
+ # to the current tile and add the current tile to the result row
473
+ if i > 0:
474
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
475
+ if j > 0:
476
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
477
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
478
+ result_rows.append(torch.cat(result_row, dim=-1))
479
+
480
+ dec = torch.cat(result_rows, dim=-2)
481
+ if not return_dict:
482
+ return (dec,)
483
+
484
+ return DecoderOutput(sample=dec)
485
+
486
+ def temporal_tiled_encode(self, x: torch.FloatTensor) -> DiagonalGaussianDistribution:
487
+ B, C, T, H, W = x.shape
488
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
489
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
490
+ t_limit = self.tile_latent_min_tsize - blend_extent
491
+
492
+ # Split the video into tiles and encode them separately.
493
+ row = []
494
+ for i in range(0, T, overlap_size):
495
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
496
+ if self.use_spatial_tiling and (
497
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
498
+ ):
499
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
500
+ else:
501
+ tile = self.encoder(tile)
502
+ tile = self.quant_conv(tile)
503
+ if i > 0:
504
+ tile = tile[:, :, 1:, :, :]
505
+ row.append(tile)
506
+ result_row = []
507
+ for i, tile in enumerate(row):
508
+ if i > 0:
509
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
510
+ result_row.append(tile[:, :, :t_limit, :, :])
511
+ else:
512
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
513
+ moments = torch.cat(result_row, dim=2)
514
+ posterior = DiagonalGaussianDistribution(moments)
515
+ return posterior
516
+
517
+ def temporal_tiled_decode(
518
+ self, z: torch.FloatTensor, return_dict: bool = True
519
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
520
+ # Split z into overlapping tiles and decode them separately.
521
+
522
+ B, C, T, H, W = z.shape
523
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
524
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
525
+ t_limit = self.tile_sample_min_tsize - blend_extent
526
+
527
+ row = []
528
+ for i in range(0, T, overlap_size):
529
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
530
+ if self.use_spatial_tiling and (
531
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
532
+ ):
533
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
534
+ else:
535
+ tile = self.post_quant_conv(tile)
536
+ decoded = self.decoder(tile)
537
+ if i > 0:
538
+ decoded = decoded[:, :, 1:, :, :]
539
+ row.append(decoded)
540
+ result_row = []
541
+ for i, tile in enumerate(row):
542
+ if i > 0:
543
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
544
+ result_row.append(tile[:, :, :t_limit, :, :])
545
+ else:
546
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
547
+
548
+ dec = torch.cat(result_row, dim=2)
549
+ if not return_dict:
550
+ return (dec,)
551
+
552
+ return DecoderOutput(sample=dec)
553
+
554
+ def forward(
555
+ self,
556
+ sample: torch.FloatTensor,
557
+ sample_posterior: bool = True,
558
+ generator: Optional[torch.Generator] = None,
559
+ ) -> Tuple[torch.FloatTensor, DiagonalGaussianDistribution, torch.FloatTensor]:
560
+ r"""
561
+ Args:
562
+ sample (`torch.FloatTensor`): Input sample.
563
+ sample_posterior (`bool`, *optional*, defaults to `False`):
564
+ Whether to sample from the posterior.
565
+ return_dict (`bool`, *optional*, defaults to `True`):
566
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
567
+ """
568
+ x = sample
569
+ z, posterior = self.encode(x, return_posterior=True, sample_posterior=sample_posterior, generator=generator)
570
+ dec = self.decode(z)
571
+
572
+ return (dec, posterior, z)
573
+
574
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
575
+ def fuse_qkv_projections(self):
576
+ """
577
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
578
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
579
+
580
+ <Tip warning={true}>
581
+
582
+ This API is πŸ§ͺ experimental.
583
+
584
+ </Tip>
585
+ """
586
+ self.original_attn_processors = None
587
+
588
+ for _, attn_processor in self.attn_processors.items():
589
+ if "Added" in str(attn_processor.__class__.__name__):
590
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
591
+
592
+ self.original_attn_processors = self.attn_processors
593
+
594
+ for module in self.modules():
595
+ if isinstance(module, Attention):
596
+ module.fuse_projections(fuse=True)
597
+
598
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
599
+ def unfuse_qkv_projections(self):
600
+ """Disables the fused QKV projection if enabled.
601
+
602
+ <Tip warning={true}>
603
+
604
+ This API is πŸ§ͺ experimental.
605
+
606
+ </Tip>
607
+
608
+ """
609
+ if self.original_attn_processors is not None:
610
+ self.set_attn_processor(self.original_attn_processors)
611
+
612
+ def get_last_layer(self):
613
+ return self.decoder.conv_out.conv.weight
614
+
615
+ def get_latent_size(self, input_size: list[int]) -> list[int]:
616
+ latent_size = []
617
+ # T
618
+ latent_size.append((input_size[0] - 1) // self.time_compression_ratio + 1)
619
+ # H, w
620
+ for i in range(1, 3):
621
+ latent_size.append((input_size[i] - 1) // self.spatial_compression_ratio + 1)
622
+ return latent_size
623
+
624
+
625
+ @MODELS.register_module("hunyuan_vae")
626
+ def CausalVAE3D_HUNYUAN(
627
+ from_pretrained: str = None,
628
+ device_map: str | torch.device = "cuda",
629
+ torch_dtype: torch.dtype = torch.bfloat16,
630
+ **kwargs,
631
+ ) -> AutoencoderKLCausal3D:
632
+ config = AutoEncoder3DConfig(from_pretrained=from_pretrained, **kwargs)
633
+ with torch.device(device_map):
634
+ model = AutoencoderKLCausal3D(config).to(torch_dtype)
635
+ if from_pretrained:
636
+ model = load_checkpoint(model, from_pretrained, device_map=device_map, strict=True)
637
+
638
+ return model
arbitor/encoders/opensora_vae_modules/registry.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ import torch.nn as nn
4
+ from mmengine.registry import Registry
5
+
6
+
7
+ def build_module(module: dict | nn.Module, builder: Registry, **kwargs) -> nn.Module | None:
8
+ """Build module from config or return the module itself.
9
+
10
+ Args:
11
+ module (dict | nn.Module): The module to build.
12
+ builder (Registry): The registry to build module.
13
+ *args, **kwargs: Arguments passed to build function.
14
+
15
+ Returns:
16
+ (None | nn.Module): The created model.
17
+ """
18
+ if module is None:
19
+ return None
20
+ if isinstance(module, dict):
21
+ cfg = deepcopy(module)
22
+ for k, v in kwargs.items():
23
+ cfg[k] = v
24
+ return builder.build(cfg)
25
+ elif isinstance(module, nn.Module):
26
+ return module
27
+ elif module is None:
28
+ return None
29
+ else:
30
+ raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
31
+
32
+
33
+ MODELS = Registry(
34
+ "model",
35
+ locations=["opensora.models"],
36
+ )
37
+
38
+ DATASETS = Registry(
39
+ "dataset",
40
+ locations=["opensora.datasets"],
41
+ )
arbitor/encoders/opensora_vae_modules/unet_causal_3d_blocks.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from diffusers==0.29.2 and HunyuanVideo
2
+ #
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # #
17
+ # Copyright 2024 HunyuanVideo
18
+ #
19
+ # This source code is licensed under the license found in the
20
+ # LICENSE file in the root directory of this source tree.
21
+
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.attention_processor import Attention
29
+ from diffusers.utils import logging
30
+ from einops import rearrange
31
+ from torch import nn
32
+
33
+ from opensora.acceleration.checkpoint import auto_grad_checkpoint
34
+ from opensora.models.vae.utils import ChannelChunkConv3d, get_conv3d_n_chunks
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+ INTERPOLATE_NUMEL_LIMIT = 2**31 - 1
39
+
40
+
41
+ def chunk_nearest_interpolate(
42
+ x: torch.Tensor,
43
+ scale_factor,
44
+ ):
45
+ limit = INTERPOLATE_NUMEL_LIMIT // np.prod(scale_factor)
46
+ n_chunks = get_conv3d_n_chunks(x.numel(), x.size(1), limit)
47
+ x_chunks = x.chunk(n_chunks, dim=1)
48
+ x_chunks = [F.interpolate(x_chunk, scale_factor=scale_factor, mode="nearest") for x_chunk in x_chunks]
49
+ return torch.cat(x_chunks, dim=1)
50
+
51
+
52
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
53
+ seq_len = n_frame * n_hw
54
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
55
+ for i in range(seq_len):
56
+ i_frame = i // n_hw
57
+ mask[i, : (i_frame + 1) * n_hw] = 0
58
+ if batch_size is not None:
59
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
60
+ return mask
61
+
62
+
63
+ class CausalConv3d(nn.Module):
64
+ """
65
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
66
+ This maintains temporal causality in video generation tasks.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ chan_in,
72
+ chan_out,
73
+ kernel_size: Union[int, Tuple[int, int, int]],
74
+ stride: Union[int, Tuple[int, int, int]] = 1,
75
+ dilation: Union[int, Tuple[int, int, int]] = 1,
76
+ pad_mode="replicate",
77
+ **kwargs,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.pad_mode = pad_mode
82
+ padding = (
83
+ kernel_size // 2,
84
+ kernel_size // 2,
85
+ kernel_size // 2,
86
+ kernel_size // 2,
87
+ kernel_size - 1,
88
+ 0,
89
+ ) # W, H, T
90
+ self.time_causal_padding = padding
91
+
92
+ self.conv = ChannelChunkConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
93
+
94
+ def forward(self, x):
95
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
96
+ return self.conv(x)
97
+
98
+ class UpsampleCausal3D(nn.Module):
99
+ """
100
+ A 3D upsampling layer with an optional convolution.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ channels: int,
106
+ out_channels: Optional[int] = None,
107
+ kernel_size: int = 3,
108
+ bias=True,
109
+ upsample_factor=(2, 2, 2),
110
+ ):
111
+ super().__init__()
112
+ self.channels = channels
113
+ self.out_channels = out_channels or channels
114
+ self.upsample_factor = upsample_factor
115
+ self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
116
+
117
+ def forward(
118
+ self,
119
+ input_tensor: torch.FloatTensor,
120
+ ) -> torch.FloatTensor:
121
+ assert input_tensor.shape[1] == self.channels
122
+
123
+ #######################
124
+ # handle hidden states
125
+ #######################
126
+ hidden_states = input_tensor
127
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
128
+ # dtype = hidden_states.dtype
129
+ # if dtype == torch.bfloat16:
130
+ # hidden_states = hidden_states.to(torch.float32)
131
+
132
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
133
+ if hidden_states.shape[0] >= 64:
134
+ hidden_states = hidden_states.contiguous()
135
+
136
+ # interpolate H & W only for the first frame; interpolate T & H & W for the rest
137
+ T = hidden_states.size(2)
138
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
139
+ # process non-1st frames
140
+ if T > 1:
141
+ other_h = chunk_nearest_interpolate(other_h, scale_factor=self.upsample_factor)
142
+ # proess 1st fram
143
+ first_h = first_h.squeeze(2)
144
+ first_h = chunk_nearest_interpolate(first_h, scale_factor=self.upsample_factor[1:])
145
+ first_h = first_h.unsqueeze(2)
146
+ # concat together
147
+ if T > 1:
148
+ hidden_states = torch.cat((first_h, other_h), dim=2)
149
+ else:
150
+ hidden_states = first_h
151
+
152
+ # If the input is bfloat16, we cast back to bfloat16
153
+ # if dtype == torch.bfloat16:
154
+ # hidden_states = hidden_states.to(dtype)
155
+
156
+ hidden_states = self.conv(hidden_states)
157
+
158
+ return hidden_states
159
+
160
+ class DownsampleCausal3D(nn.Module):
161
+ """
162
+ A 3D downsampling layer with an optional convolution.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ channels: int,
168
+ kernel_size=3,
169
+ bias=True,
170
+ stride=2,
171
+ ):
172
+ super().__init__()
173
+ self.channels = channels
174
+ self.out_channels = channels
175
+ self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
176
+
177
+ def forward(self, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
178
+ assert input_tensor.shape[1] == self.channels
179
+ hidden_states = self.conv(input_tensor)
180
+
181
+ return hidden_states
182
+
183
+
184
+ class ResnetBlockCausal3D(nn.Module):
185
+ r"""
186
+ A Resnet block.
187
+ """
188
+
189
+ def __init__(
190
+ self,
191
+ *,
192
+ in_channels: int,
193
+ out_channels: Optional[int] = None,
194
+ dropout: float = 0.0,
195
+ groups: int = 32,
196
+ groups_out: Optional[int] = None,
197
+ pre_norm: bool = True,
198
+ eps: float = 1e-6,
199
+ non_linearity: str = "swish",
200
+ output_scale_factor: float = 1.0,
201
+ use_in_shortcut: Optional[bool] = None,
202
+ conv_shortcut_bias: bool = True,
203
+ conv_3d_out_channels: Optional[int] = None,
204
+ ):
205
+ super().__init__()
206
+ self.pre_norm = pre_norm
207
+ self.pre_norm = True
208
+ self.in_channels = in_channels
209
+ out_channels = in_channels if out_channels is None else out_channels
210
+ self.out_channels = out_channels
211
+ self.output_scale_factor = output_scale_factor
212
+
213
+ if groups_out is None:
214
+ groups_out = groups
215
+
216
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
217
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
218
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
219
+
220
+ self.dropout = torch.nn.Dropout(dropout)
221
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
222
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
223
+
224
+ self.nonlinearity = get_activation(non_linearity)
225
+
226
+ self.upsample = self.downsample = None
227
+
228
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
229
+
230
+ self.conv_shortcut = None
231
+ if self.use_in_shortcut:
232
+ self.conv_shortcut = CausalConv3d(
233
+ in_channels,
234
+ conv_3d_out_channels,
235
+ kernel_size=1,
236
+ stride=1,
237
+ bias=conv_shortcut_bias,
238
+ )
239
+
240
+ def forward(
241
+ self,
242
+ input_tensor: torch.FloatTensor,
243
+ ) -> torch.FloatTensor:
244
+ hidden_states = input_tensor
245
+
246
+ hidden_states = self.norm1(hidden_states)
247
+ hidden_states = self.nonlinearity(hidden_states)
248
+ hidden_states = self.conv1(hidden_states)
249
+ hidden_states = self.norm2(hidden_states)
250
+ hidden_states = self.nonlinearity(hidden_states)
251
+ hidden_states = self.dropout(hidden_states)
252
+ hidden_states = self.conv2(hidden_states)
253
+
254
+ if self.conv_shortcut is not None:
255
+ input_tensor = self.conv_shortcut(input_tensor)
256
+
257
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
258
+
259
+ return output_tensor
260
+
261
+
262
+ class UNetMidBlockCausal3D(nn.Module):
263
+ """
264
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ in_channels: int,
270
+ dropout: float = 0.0,
271
+ num_layers: int = 1,
272
+ resnet_eps: float = 1e-6,
273
+ resnet_act_fn: str = "swish",
274
+ resnet_groups: int = 32,
275
+ attn_groups: Optional[int] = None,
276
+ resnet_pre_norm: bool = True,
277
+ add_attention: bool = True,
278
+ attention_head_dim: int = 1,
279
+ output_scale_factor: float = 1.0,
280
+ ):
281
+ super().__init__()
282
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
283
+ self.add_attention = add_attention
284
+
285
+ if attn_groups is None:
286
+ attn_groups = resnet_groups
287
+
288
+ # there is always at least one resnet
289
+ resnets = [
290
+ ResnetBlockCausal3D(
291
+ in_channels=in_channels,
292
+ out_channels=in_channels,
293
+ eps=resnet_eps,
294
+ groups=resnet_groups,
295
+ dropout=dropout,
296
+ non_linearity=resnet_act_fn,
297
+ output_scale_factor=output_scale_factor,
298
+ pre_norm=resnet_pre_norm,
299
+ )
300
+ ]
301
+ attentions = []
302
+
303
+ if attention_head_dim is None:
304
+ logger.warn(
305
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
306
+ )
307
+ attention_head_dim = in_channels
308
+
309
+ for _ in range(num_layers):
310
+ if self.add_attention:
311
+ attentions.append(
312
+ Attention(
313
+ in_channels,
314
+ heads=in_channels // attention_head_dim,
315
+ dim_head=attention_head_dim,
316
+ rescale_output_factor=output_scale_factor,
317
+ eps=resnet_eps,
318
+ norm_num_groups=attn_groups,
319
+ spatial_norm_dim=None,
320
+ residual_connection=True,
321
+ bias=True,
322
+ upcast_softmax=True,
323
+ _from_deprecated_attn_block=True,
324
+ )
325
+ )
326
+ else:
327
+ attentions.append(None)
328
+
329
+ resnets.append(
330
+ ResnetBlockCausal3D(
331
+ in_channels=in_channels,
332
+ out_channels=in_channels,
333
+ eps=resnet_eps,
334
+ groups=resnet_groups,
335
+ dropout=dropout,
336
+ non_linearity=resnet_act_fn,
337
+ output_scale_factor=output_scale_factor,
338
+ pre_norm=resnet_pre_norm,
339
+ )
340
+ )
341
+
342
+ self.attentions = nn.ModuleList(attentions)
343
+ self.resnets = nn.ModuleList(resnets)
344
+
345
+ def forward(self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.Tensor]) -> torch.FloatTensor:
346
+ hidden_states = self.resnets[0](hidden_states)
347
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
348
+ if attn is not None:
349
+ B, C, T, H, W = hidden_states.shape
350
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
351
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
352
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
353
+ hidden_states = resnet(hidden_states)
354
+
355
+ return hidden_states
356
+
357
+
358
+ class DownEncoderBlockCausal3D(nn.Module):
359
+ def __init__(
360
+ self,
361
+ in_channels: int,
362
+ out_channels: int,
363
+ dropout: float = 0.0,
364
+ num_layers: int = 1,
365
+ resnet_eps: float = 1e-6,
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ output_scale_factor: float = 1.0,
370
+ add_downsample: bool = True,
371
+ downsample_stride: int = 2,
372
+ ):
373
+ super().__init__()
374
+ resnets = []
375
+
376
+ for i in range(num_layers):
377
+ in_channels = in_channels if i == 0 else out_channels
378
+ resnets.append(
379
+ ResnetBlockCausal3D(
380
+ in_channels=in_channels,
381
+ out_channels=out_channels,
382
+ eps=resnet_eps,
383
+ groups=resnet_groups,
384
+ dropout=dropout,
385
+ non_linearity=resnet_act_fn,
386
+ output_scale_factor=output_scale_factor,
387
+ pre_norm=resnet_pre_norm,
388
+ )
389
+ )
390
+
391
+ self.resnets = nn.ModuleList(resnets)
392
+
393
+ if add_downsample:
394
+ self.downsamplers = nn.ModuleList(
395
+ [
396
+ DownsampleCausal3D(
397
+ out_channels,
398
+ stride=downsample_stride,
399
+ )
400
+ ]
401
+ )
402
+ else:
403
+ self.downsamplers = None
404
+
405
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
406
+ for resnet in self.resnets:
407
+ hidden_states = auto_grad_checkpoint(resnet, hidden_states)
408
+
409
+ if self.downsamplers is not None:
410
+ for downsampler in self.downsamplers:
411
+ hidden_states = auto_grad_checkpoint(downsampler, hidden_states)
412
+
413
+ return hidden_states
414
+
415
+
416
+ class UpDecoderBlockCausal3D(nn.Module):
417
+ def __init__(
418
+ self,
419
+ in_channels: int,
420
+ out_channels: int,
421
+ resolution_idx: Optional[int] = None,
422
+ dropout: float = 0.0,
423
+ num_layers: int = 1,
424
+ resnet_eps: float = 1e-6,
425
+ resnet_act_fn: str = "swish",
426
+ resnet_groups: int = 32,
427
+ resnet_pre_norm: bool = True,
428
+ output_scale_factor: float = 1.0,
429
+ add_upsample: bool = True,
430
+ upsample_scale_factor=(2, 2, 2),
431
+ ):
432
+ super().__init__()
433
+ resnets = []
434
+
435
+ for i in range(num_layers):
436
+ input_channels = in_channels if i == 0 else out_channels
437
+
438
+ resnets.append(
439
+ ResnetBlockCausal3D(
440
+ in_channels=input_channels,
441
+ out_channels=out_channels,
442
+ eps=resnet_eps,
443
+ groups=resnet_groups,
444
+ dropout=dropout,
445
+ non_linearity=resnet_act_fn,
446
+ output_scale_factor=output_scale_factor,
447
+ pre_norm=resnet_pre_norm,
448
+ )
449
+ )
450
+
451
+ self.resnets = nn.ModuleList(resnets)
452
+
453
+ if add_upsample:
454
+ self.upsamplers = nn.ModuleList(
455
+ [
456
+ UpsampleCausal3D(
457
+ out_channels,
458
+ out_channels=out_channels,
459
+ upsample_factor=upsample_scale_factor,
460
+ )
461
+ ]
462
+ )
463
+ else:
464
+ self.upsamplers = None
465
+
466
+ self.resolution_idx = resolution_idx
467
+
468
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
469
+ for resnet in self.resnets:
470
+ hidden_states = auto_grad_checkpoint(resnet, hidden_states)
471
+
472
+ if self.upsamplers is not None:
473
+ for upsampler in self.upsamplers:
474
+ hidden_states = auto_grad_checkpoint(upsampler, hidden_states)
475
+
476
+ return hidden_states
arbitor/encoders/opensora_vae_modules/vae.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from HunyuanVideo
2
+ #
3
+ # Copyright 2024 HunyuanVideo
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from diffusers.utils import BaseOutput
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+
17
+ from opensora.acceleration.checkpoint import auto_grad_checkpoint, checkpoint
18
+ from opensora.models.hunyuan_vae.unet_causal_3d_blocks import (
19
+ CausalConv3d,
20
+ DownEncoderBlockCausal3D,
21
+ UNetMidBlockCausal3D,
22
+ UpDecoderBlockCausal3D,
23
+ prepare_causal_attention_mask,
24
+ )
25
+
26
+
27
+ @dataclass
28
+ class DecoderOutput(BaseOutput):
29
+ r"""
30
+ Output of decoding method.
31
+
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ The decoded output sample from the last layer of the model.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class EncoderCausal3D(nn.Module):
41
+ r"""
42
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ in_channels: int = 3,
48
+ out_channels: int = 3,
49
+ block_out_channels: Tuple[int, ...] = (64,),
50
+ layers_per_block: int = 2,
51
+ norm_num_groups: int = 32,
52
+ act_fn: str = "silu",
53
+ double_z: bool = True,
54
+ mid_block_add_attention=True,
55
+ time_compression_ratio: int = 4,
56
+ spatial_compression_ratio: int = 8,
57
+ dropout: float = 0.0,
58
+ ):
59
+ super().__init__()
60
+ self.layers_per_block = layers_per_block
61
+
62
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
63
+ self.mid_block = None
64
+ self.down_blocks = nn.ModuleList([])
65
+
66
+ # down
67
+ output_channel = block_out_channels[0]
68
+ for i, _ in enumerate(block_out_channels):
69
+ input_channel = output_channel
70
+ output_channel = block_out_channels[i]
71
+ is_final_block = i == len(block_out_channels) - 1
72
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
73
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
74
+
75
+ if time_compression_ratio == 4:
76
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
77
+ add_time_downsample = bool(
78
+ i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
79
+ )
80
+ elif time_compression_ratio == 8:
81
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
82
+ add_time_downsample = bool(i < num_spatial_downsample_layers)
83
+ else:
84
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
85
+
86
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
87
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
88
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
89
+ down_block = DownEncoderBlockCausal3D(
90
+ num_layers=self.layers_per_block,
91
+ in_channels=input_channel,
92
+ out_channels=output_channel,
93
+ dropout=dropout,
94
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
95
+ downsample_stride=downsample_stride,
96
+ resnet_eps=1e-6,
97
+ resnet_act_fn=act_fn,
98
+ resnet_groups=norm_num_groups,
99
+ )
100
+
101
+ self.down_blocks.append(down_block)
102
+
103
+ # mid
104
+ self.mid_block = UNetMidBlockCausal3D(
105
+ in_channels=block_out_channels[-1],
106
+ resnet_eps=1e-6,
107
+ resnet_act_fn=act_fn,
108
+ output_scale_factor=1,
109
+ attention_head_dim=block_out_channels[-1],
110
+ resnet_groups=norm_num_groups,
111
+ add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ # out
115
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
116
+ self.conv_act = nn.SiLU()
117
+
118
+ conv_out_channels = 2 * out_channels if double_z else out_channels
119
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
120
+
121
+ def prepare_attention_mask(self, hidden_states: torch.Tensor) -> torch.Tensor:
122
+ B, C, T, H, W = hidden_states.shape
123
+ attention_mask = prepare_causal_attention_mask(
124
+ T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
125
+ )
126
+ return attention_mask
127
+
128
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
129
+ r"""The forward method of the `EncoderCausal3D` class."""
130
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
131
+
132
+ sample = self.conv_in(sample)
133
+
134
+ # down
135
+ for down_block in self.down_blocks:
136
+ sample = down_block(sample)
137
+
138
+ # middle
139
+ if self.mid_block.add_attention:
140
+ attention_mask = self.prepare_attention_mask(sample)
141
+ else:
142
+ attention_mask = None
143
+ sample = auto_grad_checkpoint(self.mid_block, sample, attention_mask)
144
+
145
+ # post-process
146
+ sample = self.conv_norm_out(sample)
147
+ sample = self.conv_act(sample)
148
+ sample = self.conv_out(sample)
149
+
150
+ return sample
151
+
152
+
153
+ class DecoderCausal3D(nn.Module):
154
+ r"""
155
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ in_channels: int = 3,
161
+ out_channels: int = 3,
162
+ block_out_channels: Tuple[int, ...] = (64,),
163
+ layers_per_block: int = 2,
164
+ norm_num_groups: int = 32,
165
+ act_fn: str = "silu",
166
+ mid_block_add_attention=True,
167
+ time_compression_ratio: int = 4,
168
+ spatial_compression_ratio: int = 8,
169
+ dropout: float = 0.0,
170
+ ):
171
+ super().__init__()
172
+ self.layers_per_block = layers_per_block
173
+
174
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
175
+ self.mid_block = None
176
+ self.up_blocks = nn.ModuleList([])
177
+
178
+ # mid
179
+ self.mid_block = UNetMidBlockCausal3D(
180
+ in_channels=block_out_channels[-1],
181
+ resnet_eps=1e-6,
182
+ resnet_act_fn=act_fn,
183
+ output_scale_factor=1,
184
+ attention_head_dim=block_out_channels[-1],
185
+ resnet_groups=norm_num_groups,
186
+ add_attention=mid_block_add_attention,
187
+ )
188
+
189
+ # up
190
+ reversed_block_out_channels = list(reversed(block_out_channels))
191
+ output_channel = reversed_block_out_channels[0]
192
+ for i, _ in enumerate(block_out_channels):
193
+ prev_output_channel = output_channel
194
+ output_channel = reversed_block_out_channels[i]
195
+ is_final_block = i == len(block_out_channels) - 1
196
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
197
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
198
+
199
+ if time_compression_ratio == 4:
200
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
201
+ add_time_upsample = bool(
202
+ i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
203
+ )
204
+ elif time_compression_ratio == 8:
205
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
206
+ add_time_upsample = bool(i < num_spatial_upsample_layers)
207
+ else:
208
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
209
+
210
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
211
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
212
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
213
+ up_block = UpDecoderBlockCausal3D(
214
+ num_layers=self.layers_per_block + 1,
215
+ in_channels=prev_output_channel,
216
+ out_channels=output_channel,
217
+ resolution_idx=None,
218
+ dropout=dropout,
219
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
220
+ upsample_scale_factor=upsample_scale_factor,
221
+ resnet_eps=1e-6,
222
+ resnet_act_fn=act_fn,
223
+ resnet_groups=norm_num_groups,
224
+ )
225
+
226
+ self.up_blocks.append(up_block)
227
+ prev_output_channel = output_channel
228
+
229
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
230
+ self.conv_act = nn.SiLU()
231
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
232
+
233
+ def post_process(self, sample: torch.Tensor) -> torch.Tensor:
234
+ sample = self.conv_norm_out(sample)
235
+ sample = self.conv_act(sample)
236
+ return sample
237
+
238
+ def prepare_attention_mask(self, hidden_states: torch.Tensor) -> torch.Tensor:
239
+ B, C, T, H, W = hidden_states.shape
240
+ attention_mask = prepare_causal_attention_mask(
241
+ T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
242
+ )
243
+ return attention_mask
244
+
245
+ def forward(
246
+ self,
247
+ sample: torch.FloatTensor,
248
+ ) -> torch.FloatTensor:
249
+ r"""The forward method of the `DecoderCausal3D` class."""
250
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
251
+
252
+ sample = self.conv_in(sample)
253
+
254
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
255
+
256
+ # middle
257
+ if self.mid_block.add_attention:
258
+ attention_mask = self.prepare_attention_mask(sample)
259
+ else:
260
+ attention_mask = None
261
+
262
+ sample = auto_grad_checkpoint(self.mid_block, sample, attention_mask)
263
+ sample = sample.to(upscale_dtype)
264
+
265
+ # up
266
+ for up_block in self.up_blocks:
267
+ sample = up_block(sample)
268
+
269
+ # post-process
270
+ if getattr(self, "grad_checkpointing", False):
271
+ sample = checkpoint(self.post_process, sample, use_reentrant=True)
272
+ else:
273
+ sample = self.post_process(sample)
274
+
275
+ sample = self.conv_out(sample)
276
+
277
+ return sample
278
+
279
+
280
+ class DiagonalGaussianDistribution(object):
281
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
282
+ if parameters.ndim == 3:
283
+ dim = 2 # (B, L, C)
284
+ elif parameters.ndim == 5 or parameters.ndim == 4:
285
+ dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
286
+ else:
287
+ raise NotImplementedError
288
+ self.parameters = parameters
289
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
290
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
291
+ self.deterministic = deterministic
292
+ self.std = torch.exp(0.5 * self.logvar)
293
+ self.var = torch.exp(self.logvar)
294
+ if self.deterministic:
295
+ self.var = self.std = torch.zeros_like(
296
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
297
+ )
298
+
299
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
300
+ # make sure sample is on the same device as the parameters and has same dtype
301
+ sample = randn_tensor(
302
+ self.mean.shape,
303
+ generator=generator,
304
+ device=self.parameters.device,
305
+ dtype=self.parameters.dtype,
306
+ )
307
+ x = self.mean + self.std * sample
308
+ return x
309
+
310
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
311
+ if self.deterministic:
312
+ return torch.Tensor([0.0])
313
+ else:
314
+ reduce_dim = list(range(1, self.mean.ndim))
315
+ if other is None:
316
+ return 0.5 * torch.sum(
317
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
318
+ dim=reduce_dim,
319
+ )
320
+ else:
321
+ return 0.5 * torch.sum(
322
+ torch.pow(self.mean - other.mean, 2) / other.var
323
+ + self.var / other.var
324
+ - 1.0
325
+ - self.logvar
326
+ + other.logvar,
327
+ dim=reduce_dim,
328
+ )
329
+
330
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
331
+ if self.deterministic:
332
+ return torch.Tensor([0.0])
333
+ logtwopi = np.log(2.0 * np.pi)
334
+ return 0.5 * torch.sum(
335
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
336
+ dim=dims,
337
+ )
338
+
339
+ def mode(self) -> torch.Tensor:
340
+ return self.mean
arbitor/encoders/pig_vae.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """pig-vae (WanVAE) sidecar module.
2
+
3
+ Loads from local safetensors, .pth, or diffusers AutoencoderKLWan.
4
+ Exposes encode() and decode() for the VideoHead training pipeline.
5
+
6
+ Latent shape: [B, 16, T/4, H/8, W/8] for input video of T frames at HxW.
7
+ """
8
+ import os, torch
9
+ import torch.nn as nn
10
+
11
+ _LOCAL_VAE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models", "pig-vae")
12
+ _VAE_CONFIG = {
13
+ "base_dim": 96, "z_dim": 16, "dim_mult": [1, 2, 4, 4],
14
+ "num_res_blocks": 2, "dropout": 0.0,
15
+ "temperal_downsample": [False, True, True],
16
+ "in_channels": 3, "out_channels": 3,
17
+ "scale_factor_temporal": 4, "scale_factor_spatial": 8,
18
+ }
19
+
20
+
21
+ def _freeze_sidecar(model, quantize_requested=None, quantized=False):
22
+ model._arb_quantize_requested = quantize_requested
23
+ model._arb_quantized_int8 = bool(quantized and quantize_requested == "int8")
24
+ model._arb_quantized = bool(quantized)
25
+ for p in model.parameters():
26
+ p.requires_grad = False
27
+ return model
28
+
29
+
30
+ def _has_quantized_modules(model):
31
+ markers = ("Q", "Quanto", "Quantized", "WeightQ")
32
+ return any(any(marker in type(module).__name__ for marker in markers) for module in model.modules())
33
+
34
+
35
+ def _quantize_int8_if_requested(model, quantize):
36
+ if quantize == 'int8':
37
+ from optimum.quanto import quantize as quanto_quantize, freeze, qint8
38
+ quanto_quantize(model, weights=qint8)
39
+ freeze(model)
40
+ return _freeze_sidecar(model, quantize_requested=quantize, quantized=_has_quantized_modules(model))
41
+ return _freeze_sidecar(model, quantize_requested=quantize, quantized=False)
42
+
43
+
44
+ def _wan_vae_cls():
45
+ try:
46
+ from diffusers import AutoencoderKLWan
47
+ except ModuleNotFoundError as exc:
48
+ raise RuntimeError(
49
+ "pig-vae requires the optional diffusers dependency. "
50
+ "Install the project with `pip install -e .[diffusers]` in a venv "
51
+ "before loading or verifying pig-vae int8 quantization."
52
+ ) from exc
53
+ return AutoencoderKLWan
54
+
55
+
56
+ def load_vae(device='cuda', quantize='int8'):
57
+ """Load pig-vae from local cache or diffusers. Optionally int8 quantize."""
58
+ safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
59
+ gguf_path = os.path.join(_LOCAL_VAE_DIR, "pig_wan_vae_fp32-f16.gguf")
60
+
61
+ if os.path.isfile(safetensors_path):
62
+ return _load_local(safetensors_path, device, quantize, is_safetensors=True)
63
+ if os.path.isfile(gguf_path):
64
+ return _load_gguf(gguf_path, device, quantize)
65
+ return _load_from_hf(device, quantize)
66
+
67
+
68
+ def _build_vae():
69
+ AutoencoderKLWan = _wan_vae_cls()
70
+ return AutoencoderKLWan(
71
+ **_VAE_CONFIG,
72
+ latents_mean=[-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653,
73
+ -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632,
74
+ -0.1922, -0.9497, 0.2503, -0.2921],
75
+ latents_std=[2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708,
76
+ 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579,
77
+ 1.6382, 1.1253, 2.8251, 1.916],
78
+ )
79
+
80
+
81
+ def _load_local(path, device, quantize, is_safetensors=False):
82
+ if is_safetensors:
83
+ AutoencoderKLWan = _wan_vae_cls()
84
+ model = AutoencoderKLWan.from_single_file(path)
85
+ else:
86
+ model = _build_vae()
87
+ ckpt = torch.load(path, map_location="cpu", weights_only=True)
88
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
89
+ if missing or unexpected:
90
+ raise RuntimeError(
91
+ "pig-vae local .pth checkpoint does not match AutoencoderKLWan "
92
+ f"(missing={len(missing)}, unexpected={len(unexpected)})."
93
+ )
94
+ model = model.to(device)
95
+ model.eval()
96
+ model = _quantize_int8_if_requested(model, quantize)
97
+ return VAEWrapper(model)
98
+
99
+
100
+ def _load_gguf(path, device, quantize):
101
+ import gguf
102
+ reader = gguf.GGUFReader(path)
103
+ state_dict = {t.name: torch.tensor(t.data) for t in reader.tensors}
104
+ model = _build_vae()
105
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
106
+ if missing or unexpected:
107
+ raise RuntimeError(
108
+ "pig-vae local GGUF checkpoint does not match AutoencoderKLWan "
109
+ f"(missing={len(missing)}, unexpected={len(unexpected)})."
110
+ )
111
+ model = model.to(device)
112
+ model.eval()
113
+ model = _quantize_int8_if_requested(model, quantize)
114
+ return VAEWrapper(model)
115
+
116
+
117
+ def _load_from_hf(device, quantize):
118
+ AutoencoderKLWan = _wan_vae_cls()
119
+ model = AutoencoderKLWan.from_pretrained(
120
+ "Wan-AI/Wan2.1-T2V-1.3B", subfolder="vae",
121
+ torch_dtype=torch.bfloat16,
122
+ )
123
+ model = model.to(device)
124
+ model.eval()
125
+ model = _quantize_int8_if_requested(model, quantize)
126
+ return VAEWrapper(model)
127
+
128
+
129
+ class VAEWrapper(nn.Module):
130
+ def __init__(self, vae):
131
+ super().__init__()
132
+ self.vae = vae
133
+ self.latent_channels = _VAE_CONFIG["z_dim"]
134
+ self.scale_factor = 0.476986
135
+
136
+ def encode(self, video_tensor):
137
+ with torch.no_grad():
138
+ dist = self.vae.encode(video_tensor)
139
+ latents = dist.latent_dist.sample() if hasattr(dist, 'latent_dist') else dist
140
+ latents = latents * self.scale_factor
141
+ return latents
142
+
143
+ def decode(self, latents):
144
+ with torch.no_grad():
145
+ latents = latents / self.scale_factor
146
+ video = self.vae.decode(latents)
147
+ video = video.sample if hasattr(video, 'sample') else video
148
+ return video
arbitor/encoders/vae2d.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """2D VAE encoder β€” wraps PixArt SDXL AutoencoderKL encoder half.
2
+
3
+ Encodes images or mel spectrograms to [B, 4, H/8, W/8] latents.
4
+ Same encoder used for images AND audio spectrograms (via MelSpectrogram3Band).
5
+ Frozen float32 sidecar (no gradients).
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def load_vae2d(device="cuda", quantize=None):
13
+ from diffusers import AutoencoderKL
14
+
15
+ vae = AutoencoderKL.from_pretrained(
16
+ "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
17
+ subfolder="vae",
18
+ torch_dtype=torch.float32,
19
+ ).to(device)
20
+ vae.eval()
21
+ for p in vae.parameters():
22
+ p.requires_grad = False
23
+
24
+ return VAE2DEncoder(vae)
25
+
26
+
27
+ class VAE2DEncoder(nn.Module):
28
+ def __init__(self, vae):
29
+ super().__init__()
30
+ self.encoder = vae.encoder
31
+ self.quant_conv = vae.quant_conv
32
+ self.latent_channels = 4
33
+ self.input_scale = 0.18215
34
+
35
+ def forward(self, x):
36
+ H, W = x.shape[-2], x.shape[-1]
37
+ pad_h = (8 - H % 8) % 8
38
+ pad_w = (8 - W % 8) % 8
39
+ if pad_h > 0 or pad_w > 0:
40
+ x = F.pad(x, (0, pad_w, 0, pad_h))
41
+
42
+ h = self.encoder(x)
43
+ moments = self.quant_conv(h)
44
+ posterior = torch.distributions.Normal(
45
+ moments[:, :self.latent_channels],
46
+ torch.nn.functional.softplus(moments[:, self.latent_channels:])
47
+ )
48
+ latent = posterior.rsample()
49
+ latent = latent * self.input_scale
50
+
51
+ if pad_h > 0 or pad_w > 0:
52
+ out_h = H // 8 if H >= 8 else 1
53
+ out_w = W // 8 if W >= 8 else 1
54
+ latent = latent[:, :, :out_h, :out_w]
55
+
56
+ return latent
arbitor/kernel/flash_vq.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FlashVQ: Custom Vector Quantization with dual Triton GPU + PyTorch CPU path.
3
+
4
+ Replaces vector_quantize_pytorch entirely (D-100). FlashVQCodebook is a standalone
5
+ nn.Module implementing all VQ operations:
6
+ - Cosine similarity codebook lookup
7
+ - EMA codebook update
8
+ - Dead code reset
9
+ - Rotation trick (gradient through quantization)
10
+ - Commitment loss
11
+
12
+ Dispatch pattern (following tscale.py):
13
+ if x.is_cuda and _HAS_TRITON β†’ _TritonFlashVQFn.apply()
14
+ else β†’ self._cpu_forward()
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ _HAS_TRITON = False
22
+ try:
23
+ import triton
24
+ import triton.language as tl
25
+ _HAS_TRITON = True
26
+ except ImportError:
27
+ pass
28
+
29
+
30
+ class _RotationTrickFn(torch.autograd.Function):
31
+ """
32
+ Rotation trick gradient through vector quantization.
33
+
34
+ Instead of straight-through estimator (STE), rotate the encoder output
35
+ gradient toward the quantized vector direction. This helps the encoder
36
+ learn to produce outputs that align with codebook entries.
37
+ """
38
+
39
+ @staticmethod
40
+ def forward(ctx, x, quantized):
41
+ ctx.save_for_backward(x.detach(), quantized.detach())
42
+ return quantized
43
+
44
+ @staticmethod
45
+ def backward(ctx, grad_output):
46
+ x, quantized = ctx.saved_tensors
47
+ # Normalize in fp32 for numerical stability
48
+ x_norm = F.normalize(x.float(), dim=-1)
49
+ q_norm = F.normalize(quantized.float(), dim=-1)
50
+ # Gradient deflection: subtract projection onto (x_norm - q_norm)
51
+ # This rotates the gradient toward the quantized direction
52
+ diff = x_norm - q_norm
53
+ proj = (grad_output.float() * x_norm).sum(dim=-1, keepdim=True)
54
+ grad_x = grad_output.float() - proj * diff
55
+ return grad_x.to(grad_output.dtype), None
56
+
57
+
58
+ class FlashVQCodebook(nn.Module):
59
+ """
60
+ Vector quantization codebook with dual GPU (Triton) / CPU (PyTorch) paths.
61
+
62
+ Interface matches vector_quantize_pytorch.VectorQuantize:
63
+ forward(x) β†’ (quantized, indices, commitment_loss)
64
+
65
+ All VQ operations are self-contained:
66
+ - Cosine similarity codebook lookup
67
+ - Straight-through estimator (STE) with optional rotation trick
68
+ - EMA codebook update (decay=0.99)
69
+ - Dead code reset (threshold_ema_dead_code=2)
70
+ - Commitment loss
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ codebook_size: int = 8192,
76
+ codebook_dim: int = 32,
77
+ decay: float = 0.99,
78
+ commitment_weight: float = 1.0,
79
+ threshold_ema_dead_code: int = 2,
80
+ kmeans_init: bool = True,
81
+ kmeans_iters: int = 10,
82
+ rotation_trick: bool = True,
83
+ ):
84
+ super().__init__()
85
+ self.codebook_size = codebook_size
86
+ self.codebook_dim = codebook_dim
87
+ self.decay = decay
88
+ self.commitment_weight = commitment_weight
89
+ self.threshold_ema_dead_code = threshold_ema_dead_code
90
+ self.kmeans_init = kmeans_init
91
+ self.kmeans_iters = kmeans_iters
92
+ self.rotation_trick = rotation_trick
93
+
94
+ # Codebook buffers
95
+ self.register_buffer('embed', torch.randn(codebook_size, codebook_dim) * 0.02)
96
+ self.register_buffer('cluster_size', torch.zeros(codebook_size))
97
+ self.register_buffer('embed_avg', torch.zeros(codebook_size, codebook_dim))
98
+
99
+ # Tile sizes for Triton kernel (set on first GPU forward)
100
+ self._triton_block_bt = 16
101
+ self._triton_tile_k = 1024
102
+
103
+ def _compute_tile_sizes(self):
104
+ """
105
+ Dynamic tile sizing per D-102.
106
+
107
+ Queries GPU device properties to determine SRAM budget, then computes
108
+ BLOCK_BT and TILE_K such that:
109
+ BLOCK_BT * codebook_dim * 2 + TILE_K * codebook_dim * 2 < SRAM * 0.9
110
+
111
+ For sm_89 (RTX 4060, 99KB SRAM per SM):
112
+ codebook_size=8192, codebook_dim=32 β†’ BLOCK_BT=16, TILE_K=1024 (65KB)
113
+ codebook_size=4096, codebook_dim=32 β†’ BLOCK_BT=16, TILE_K=512 (33KB)
114
+ """
115
+ if not torch.cuda.is_available():
116
+ return
117
+ try:
118
+ props = torch.cuda.get_device_properties(0)
119
+ sram_budget = 99 * 1024 # SM 8.9: 99KB per SM
120
+
121
+ # Conservative estimate: each element is 2 bytes (bf16) in SRAM
122
+ elem_bytes = 2
123
+
124
+ # Find largest TILE_K that fits with BLOCK_BT=16
125
+ bt = 16
126
+ for tk in [2048, 1024, 512, 256, 128]:
127
+ sram_usage = bt * self.codebook_dim * elem_bytes + tk * self.codebook_dim * elem_bytes
128
+ if sram_usage < sram_budget * 0.9:
129
+ self._triton_block_bt = bt
130
+ self._triton_tile_k = tk
131
+ return
132
+
133
+ # Fallback for very constrained SRAM or large codebook_dim
134
+ self._triton_block_bt = 8
135
+ self._triton_tile_k = 256
136
+ except Exception:
137
+ # Default values
138
+ self._triton_block_bt = 16
139
+ self._triton_tile_k = 1024
140
+
141
+ def forward(self, x: torch.Tensor):
142
+ """
143
+ Args:
144
+ x: Input tensor of shape [*, codebook_dim]
145
+ Returns:
146
+ quantized: Tensor of same shape as x
147
+ indices: Tensor of shape [*] with codebook indices
148
+ commitment_loss: Scalar tensor
149
+ """
150
+ orig_shape = x.shape
151
+ x_flat = x.reshape(-1, self.codebook_dim)
152
+
153
+ if x.is_cuda and _HAS_TRITON:
154
+ quantized, indices, commitment_loss = self._triton_forward(x_flat)
155
+ else:
156
+ quantized, indices, commitment_loss = self._cpu_forward(x_flat)
157
+
158
+ quantized = quantized.reshape(orig_shape)
159
+ indices = indices.reshape(orig_shape[:-1])
160
+ return quantized, indices, commitment_loss
161
+
162
+ def _triton_forward(self, x_flat: torch.Tensor):
163
+ """Triton GPU path β€” dispatched when CUDA + Triton available."""
164
+ # Use _TritonFlashVQFn for forward + backward via autograd
165
+ quantized, indices, commitment_loss = _TritonFlashVQFn.apply(
166
+ x_flat, self.embed, self.cluster_size, self.embed_avg,
167
+ self.codebook_size, self.codebook_dim,
168
+ self.commitment_weight, self.rotation_trick,
169
+ )
170
+
171
+ # EMA update and dead code reset (under torch.no_grad)
172
+ with torch.no_grad():
173
+ self._ema_update(x_flat, indices)
174
+ self._dead_code_reset(x_flat)
175
+
176
+ return quantized, indices, commitment_loss
177
+
178
+ def _cpu_forward(self, x_flat: torch.Tensor):
179
+ """
180
+ Pure PyTorch CPU path β€” implements all VQ operations.
181
+
182
+ Steps:
183
+ 1. Cosine similarity lookup β†’ nearest codebook entry indices
184
+ 2. Quantize via straight-through estimator (or rotation trick)
185
+ 3. Compute commitment loss
186
+ 4. EMA update codebook (under torch.no_grad)
187
+ 5. Dead code reset (under torch.no_grad)
188
+ """
189
+ # ── Step 1: Cosine similarity lookup ──
190
+ x_norm = F.normalize(x_flat.float(), dim=-1)
191
+ embed_norm = F.normalize(self.embed.float(), dim=-1)
192
+ sim = x_norm @ embed_norm.T # [N, codebook_size]
193
+ indices = sim.argmax(dim=-1) # [N]
194
+
195
+ # ── Step 2: Quantize with STE or rotation trick ──
196
+ with torch.no_grad():
197
+ quantized = self.embed[indices] # [N, D]
198
+
199
+ if self.rotation_trick:
200
+ quantized = _RotationTrickFn.apply(x_flat, quantized)
201
+ else:
202
+ # Straight-through estimator
203
+ quantized = x_flat + (quantized - x_flat).detach()
204
+
205
+ # ── Step 3: Commitment loss ──
206
+ commitment_loss = self.commitment_weight * F.mse_loss(
207
+ x_flat.float(), quantized.detach().float()
208
+ )
209
+
210
+ # ── Step 4: EMA update ──
211
+ with torch.no_grad():
212
+ self._ema_update(x_flat, indices)
213
+
214
+ # ── Step 5: Dead code reset ──
215
+ self._dead_code_reset(x_flat)
216
+
217
+ return quantized, indices, commitment_loss
218
+
219
+ def _ema_update(self, x_flat: torch.Tensor, indices: torch.Tensor):
220
+ """
221
+ Exponential moving average codebook update.
222
+
223
+ Args:
224
+ x_flat: [N, D] input vectors
225
+ indices: [N] codebook indices for each input vector
226
+ """
227
+ one_hot = F.one_hot(indices, num_classes=self.codebook_size).float() # [N, codebook_size]
228
+ n_assign = one_hot.sum(dim=0) # [codebook_size]
229
+
230
+ # EMA on cluster_size (how many inputs assigned to each code)
231
+ self.cluster_size.mul_(self.decay).add_(n_assign * (1 - self.decay))
232
+
233
+ # EMA on embed_avg: weighted sum of assigned inputs
234
+ # embed_avg[c] = decay * embed_avg[c] + (1 - decay) * sum(x assigned to c)
235
+ x_float = x_flat.float()
236
+ for c in range(self.codebook_size):
237
+ mask = indices == c
238
+ count = mask.sum().item()
239
+ if count > 0:
240
+ assigned_sum = x_float[mask].sum(dim=0)
241
+ self.embed_avg[c].mul_(self.decay).add_(assigned_sum * (1 - self.decay))
242
+
243
+ # Normalize: embed = embed_avg / cluster_size (with epsilon)
244
+ cluster_size_safe = self.cluster_size.clamp(min=1e-5)
245
+ self.embed.copy_(self.embed_avg / cluster_size_safe.unsqueeze(1))
246
+
247
+ def _dead_code_reset(self, x_flat: torch.Tensor):
248
+ """
249
+ Replace dead codebook entries (cluster_size < threshold) with
250
+ random vectors from the current input batch.
251
+ """
252
+ dead_mask = self.cluster_size < self.threshold_ema_dead_code
253
+ n_dead = dead_mask.sum().item()
254
+ if n_dead == 0:
255
+ return
256
+ dead_indices = torch.where(dead_mask)[0]
257
+ # Replace with random input vectors
258
+ rand_idx = torch.randint(0, x_flat.shape[0], (n_dead,), device=x_flat.device)
259
+ self.embed[dead_indices] = x_flat[rand_idx].detach()
260
+ self.cluster_size[dead_indices] = 0.0
261
+ self.embed_avg[dead_indices] = 0.0
262
+
263
+ @torch.no_grad()
264
+ def kmeans_init_codebook(self, x: torch.Tensor):
265
+ """Initialize codebook via k-means on first batch."""
266
+ x_flat = x.reshape(-1, self.codebook_dim).float()
267
+ centroids = x_flat[torch.randperm(x_flat.shape[0])[:self.codebook_size]].clone()
268
+ for _ in range(self.kmeans_iters):
269
+ dist = torch.cdist(x_flat, centroids)
270
+ assign = dist.argmin(dim=-1)
271
+ for i in range(self.codebook_size):
272
+ mask = assign == i
273
+ if mask.sum() > 0:
274
+ centroids[i] = x_flat[mask].mean(dim=0)
275
+ self.embed.copy_(centroids)
276
+
277
+ @torch.no_grad()
278
+ def get_codebook_utilization(self) -> float:
279
+ """Fraction of codebook entries with any usage."""
280
+ return (self.cluster_size > 0).float().mean().item()
281
+
282
+ @torch.no_grad()
283
+ def get_dead_code_count(self) -> int:
284
+ """Number of codebook entries below EMA dead threshold."""
285
+ return (self.cluster_size < self.threshold_ema_dead_code).sum().item()
286
+
287
+
288
+ # ─── Triton GPU Kernels ───
289
+ # Only defined when Triton is available
290
+
291
+ if _HAS_TRITON:
292
+
293
+ @triton.jit
294
+ def _triton_flash_vq_lookup_kernel(
295
+ x_ptr, codebook_ptr, indices_ptr,
296
+ stride_xb, stride_xd,
297
+ stride_cb, stride_cd,
298
+ N_CTX: tl.constexpr,
299
+ CODEBOOK_SIZE: tl.constexpr,
300
+ CODEBOOK_DIM: tl.constexpr,
301
+ BLOCK_BT: tl.constexpr,
302
+ TILE_K: tl.constexpr,
303
+ ):
304
+ """
305
+ Tiled cosine similarity + argmax lookup for VQ codebook.
306
+
307
+ Architecture:
308
+ pid = batch tile index
309
+ Load input tile [BLOCK_BT, CODEBOOK_DIM]
310
+ Normalize in fp32
311
+ Tile over codebook in TILE_K chunks:
312
+ Load codebook tile [TILE_K, CODEBOOK_DIM]
313
+ Normalize in fp32
314
+ Compute dot product via tl.dot β†’ [BLOCK_BT, TILE_K]
315
+ Update running argmax
316
+ Store best indices
317
+
318
+ SRAM: all arithmetic in fp32 with small tiles to fit 99KB budget.
319
+ """
320
+ pid = tl.program_id(0)
321
+ offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
322
+ offs_d = tl.arange(0, CODEBOOK_DIM)
323
+
324
+ # ── Load input tile ──
325
+ x_ptrs = x_ptr + offs_bt[:, None] * stride_xb + offs_d[None, :] * stride_xd
326
+ x = tl.load(x_ptrs, mask=offs_bt[:, None] < N_CTX, other=0.0)
327
+
328
+ # ── Normalize input in fp32 (no keepdims in Triton tl.sum) ──
329
+ x_f32 = x.to(tl.float32)
330
+ x_sq = tl.sum(x_f32 * x_f32, axis=1) # [BLOCK_BT]
331
+ x_norm_f32 = x_f32 / tl.sqrt(x_sq[:, None] + 1e-8)
332
+
333
+ # ── Running argmax over tiled codebook ──
334
+ best_sim = tl.full([BLOCK_BT], -float('inf'), dtype=tl.float32)
335
+ best_idx = tl.zeros([BLOCK_BT], dtype=tl.int32)
336
+
337
+ for k_start in range(0, CODEBOOK_SIZE, TILE_K):
338
+ offs_k = k_start + tl.arange(0, TILE_K)
339
+ k_mask = offs_k < CODEBOOK_SIZE
340
+
341
+ # Load codebook tile into fp32 directly for normalization
342
+ cb_ptrs = (codebook_ptr
343
+ + offs_k[:, None] * stride_cb
344
+ + offs_d[None, :] * stride_cd)
345
+ cb = tl.load(cb_ptrs, mask=k_mask[:, None], other=0.0)
346
+
347
+ # Normalize codebook tile in fp32
348
+ cb_f32 = cb.to(tl.float32)
349
+ cb_sq = tl.sum(cb_f32 * cb_f32, axis=1) # [TILE_K]
350
+ cb_norm_f32 = cb_f32 / tl.sqrt(cb_sq[:, None] + 1e-8)
351
+
352
+ # Cosine similarity via tl.dot (tf32 on sm_89)
353
+ sim = tl.dot(x_norm_f32, tl.trans(cb_norm_f32)) # [BLOCK_BT, TILE_K]
354
+
355
+ # Running argmax within this tile
356
+ tile_max = tl.max(sim, axis=1)
357
+ tile_argmax = tl.argmax(sim, axis=1)
358
+ tile_idx = k_start + tile_argmax
359
+
360
+ # Merge with best across tiles using element-wise mask
361
+ update_mask = tile_max > best_sim
362
+ best_sim = tl.where(update_mask, tile_max, best_sim)
363
+ best_idx = tl.where(update_mask, tile_idx, best_idx)
364
+
365
+ # ── Store results ──
366
+ tl.store(indices_ptr + offs_bt, best_idx, mask=offs_bt < N_CTX)
367
+
368
+
369
+ @triton.jit
370
+ def _triton_flash_vq_quantize_kernel(
371
+ codebook_ptr, indices_ptr, quantized_ptr,
372
+ stride_cb, stride_cd,
373
+ stride_qb, stride_qd,
374
+ N_CTX: tl.constexpr,
375
+ CODEBOOK_DIM: tl.constexpr,
376
+ BLOCK_BT: tl.constexpr,
377
+ ):
378
+ """
379
+ Gather quantized vectors from codebook at given indices.
380
+ Kernel form of: quantized[i] = codebook[indices[i]]
381
+ """
382
+ pid = tl.program_id(0)
383
+ offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
384
+ offs_d = tl.arange(0, CODEBOOK_DIM)
385
+
386
+ # Load indices for this batch tile
387
+ idx = tl.load(indices_ptr + offs_bt, mask=offs_bt < N_CTX, other=0)
388
+
389
+ # Gather: for each i in BLOCK_BT, load codebook[idx[i], :]
390
+ # Pointer arithmetic with broadcasting
391
+ gather_ptrs = (codebook_ptr
392
+ + idx[:, None] * stride_cb
393
+ + offs_d[None, :] * stride_cd)
394
+ quantized = tl.load(gather_ptrs,
395
+ mask=offs_bt[:, None] < N_CTX,
396
+ other=0.0)
397
+
398
+ # Store quantized output
399
+ out_ptrs = (quantized_ptr
400
+ + offs_bt[:, None] * stride_qb
401
+ + offs_d[None, :] * stride_qd)
402
+ tl.store(out_ptrs, quantized, mask=offs_bt[:, None] < N_CTX)
403
+
404
+
405
+ def _triton_lookup(x, embed, block_bt=None, tile_k=None):
406
+ """
407
+ Launch Triton VQ lookup kernel with SRAM-safe tile sizes.
408
+
409
+ Args:
410
+ x: [N, D] input tensor (cuda, contiguous)
411
+ embed: [codebook_size, D] codebook (cuda, contiguous)
412
+ block_bt: BLOCK_BT tile size (auto-computed if None)
413
+ tile_k: TILE_K tile size (auto-computed if None)
414
+ Returns:
415
+ indices: [N] int64 tensor of argmax indices
416
+ """
417
+ N, D = x.shape
418
+ codebook_size = embed.shape[0]
419
+ assert embed.shape[1] == D, f"Codebook dim {embed.shape[1]} != input dim {D}"
420
+
421
+ # SRAM-safe tile sizes: kernel uses tf32 (fp32 math), and Triton
422
+ # pipelines data through shared memory. Conservative sizing ensures
423
+ # fits within ~99KB (sm_89) even with default num_stages=3.
424
+ #
425
+ # fp32 codebook tile: TILE_K * D * 4 β†’ 128*32*4 = 16KB
426
+ # fp32 input tile: BLOCK_BT * D * 4 β†’ 8*32*4 = 1KB
427
+ # Accumulator: BLOCK_BT*TILE_K*4 β†’ 8*128*4 = 4KB
428
+ # Per stage: ~21KB. With 3 pipeline stages: ~63KB (fits in 99KB).
429
+ #
430
+ # Larger tiles oversubscribe SRAM (tested: TILE_K=1024 β†’ 321KB needed).
431
+ if block_bt is None or tile_k is None:
432
+ BLOCK_BT = 8
433
+ TILE_K = 128
434
+ else:
435
+ BLOCK_BT, TILE_K = block_bt, tile_k
436
+
437
+ grid = (triton.cdiv(N, BLOCK_BT),)
438
+
439
+ indices = torch.empty(N, dtype=torch.int32, device=x.device)
440
+
441
+ _triton_flash_vq_lookup_kernel[grid](
442
+ x, embed, indices,
443
+ x.stride(0), x.stride(1),
444
+ embed.stride(0), embed.stride(1),
445
+ N, codebook_size, D,
446
+ BLOCK_BT=BLOCK_BT, TILE_K=TILE_K,
447
+ )
448
+
449
+ return indices.long()
450
+
451
+
452
+ class _TritonFlashVQFn(torch.autograd.Function):
453
+ """
454
+ Custom autograd Function wrapping Triton VQ kernels.
455
+
456
+ Forward: Triton tiled cosine similarity + argmax lookup
457
+ Backward: Rotation trick gradient or straight-through estimator
458
+ """
459
+
460
+ @staticmethod
461
+ def forward(ctx, x_flat, embed, cluster_size, embed_avg,
462
+ codebook_size, codebook_dim,
463
+ commitment_weight, rotation_trick):
464
+ # Triton tiled lookup for indices
465
+ with torch.no_grad():
466
+ indices = _triton_lookup(x_flat.contiguous(), embed.contiguous())
467
+
468
+ quantized = embed[indices]
469
+ commitment_loss = commitment_weight * F.mse_loss(x_flat.float(), quantized.detach().float())
470
+
471
+ # Clone saved tensors to avoid version conflicts with in-place EMA updates
472
+ ctx.save_for_backward(
473
+ x_flat.detach().clone(),
474
+ quantized.detach().clone(),
475
+ embed.detach().clone(),
476
+ )
477
+ ctx.codebook_dim = codebook_dim
478
+ ctx.rotation_trick = rotation_trick
479
+
480
+ return quantized, indices, commitment_loss
481
+
482
+ @staticmethod
483
+ def backward(ctx, grad_quantized, grad_indices, grad_commitment):
484
+ x_flat, quantized, embed = ctx.saved_tensors
485
+
486
+ if ctx.rotation_trick:
487
+ # Rotation trick gradient
488
+ x_norm = F.normalize(x_flat.float(), dim=-1)
489
+ q_norm = F.normalize(quantized.float(), dim=-1)
490
+ diff = x_norm - q_norm
491
+ proj = (grad_quantized.float() * x_norm).sum(dim=-1, keepdim=True)
492
+ grad_x = grad_quantized.float() - proj * diff
493
+ else:
494
+ # Straight-through estimator
495
+ grad_x = grad_quantized.float()
496
+
497
+ return grad_x.to(grad_quantized.dtype), None, None, None, None, None, None, None
498
+
499
+
500
+ # When Triton is not available, define a fallback lookup
501
+ if not _HAS_TRITON:
502
+
503
+ def _triton_lookup(x, embed):
504
+ """Fallback: torch-based cosine similarity lookup (CPU or CUDA without Triton)."""
505
+ with torch.no_grad():
506
+ x_norm = F.normalize(x.float(), dim=-1)
507
+ embed_norm = F.normalize(embed.float(), dim=-1)
508
+ sim = x_norm @ embed_norm.T
509
+ indices = sim.argmax(dim=-1)
510
+ return indices
arbitor/kernel/ternary_audit.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Iterable
5
+
6
+ import torch
7
+
8
+
9
+ @dataclass
10
+ class TensorState:
11
+ name: str
12
+ shape: tuple[int, ...]
13
+ dtype: str
14
+ bytes: int
15
+ trainable: bool = False
16
+
17
+
18
+ @dataclass
19
+ class TernaryAudit:
20
+ logical_ternary_weights: int
21
+ ternary_packed_bytes: int
22
+ ternary_scale_bytes: int
23
+ ternary_scale_accum_bytes: int
24
+ ternary_accum_bytes: int
25
+ ternary_corr_accum_bytes: int
26
+ ternary_step_counter_bytes: int
27
+ trainable_float_params: list[TensorState]
28
+ frozen_float_params: list[TensorState]
29
+ float_buffers: list[TensorState]
30
+
31
+ @property
32
+ def ternary_training_bytes(self) -> int:
33
+ return (
34
+ self.ternary_packed_bytes
35
+ + self.ternary_scale_bytes
36
+ + self.ternary_scale_accum_bytes
37
+ + self.ternary_accum_bytes
38
+ + self.ternary_corr_accum_bytes
39
+ + self.ternary_step_counter_bytes
40
+ )
41
+
42
+ @property
43
+ def trainable_float_bytes(self) -> int:
44
+ return sum(item.bytes for item in self.trainable_float_params)
45
+
46
+ @property
47
+ def frozen_float_bytes(self) -> int:
48
+ return sum(item.bytes for item in self.frozen_float_params)
49
+
50
+ @property
51
+ def float_buffer_bytes(self) -> int:
52
+ return sum(item.bytes for item in self.float_buffers)
53
+
54
+
55
+ def _tensor_bytes(t: torch.Tensor) -> int:
56
+ return t.numel() * t.element_size()
57
+
58
+
59
+ def _tensor_state(name: str, t: torch.Tensor, trainable: bool = False) -> TensorState:
60
+ return TensorState(
61
+ name=name,
62
+ shape=tuple(t.shape),
63
+ dtype=str(t.dtype).replace("torch.", ""),
64
+ bytes=_tensor_bytes(t),
65
+ trainable=trainable,
66
+ )
67
+
68
+
69
+ def _mb(n_bytes: int) -> float:
70
+ return n_bytes / (1024 * 1024)
71
+
72
+
73
+ def audit_model(model: torch.nn.Module) -> TernaryAudit:
74
+ logical_ternary_weights = 0
75
+ ternary_packed_bytes = 0
76
+ ternary_scale_bytes = 0
77
+ ternary_scale_accum_bytes = 0
78
+ ternary_accum_bytes = 0
79
+ ternary_corr_accum_bytes = 0
80
+ ternary_step_counter_bytes = 0
81
+
82
+ for module in model.modules():
83
+ if hasattr(module, "T_packed") and hasattr(module, "_T_shape"):
84
+ shape = tuple(int(x) for x in module._T_shape.tolist())
85
+ n_weights = 1
86
+ for dim in shape:
87
+ n_weights *= dim
88
+ logical_ternary_weights += n_weights
89
+ ternary_packed_bytes += _tensor_bytes(module.T_packed)
90
+ if hasattr(module, "E"):
91
+ ternary_scale_bytes += _tensor_bytes(module.E)
92
+ if hasattr(module, "E_accum"):
93
+ ternary_scale_accum_bytes += _tensor_bytes(module.E_accum)
94
+ if hasattr(module, "T_accum"):
95
+ ternary_accum_bytes += _tensor_bytes(module.T_accum)
96
+ if hasattr(module, "corr_accum"):
97
+ ternary_corr_accum_bytes += _tensor_bytes(module.corr_accum)
98
+ if hasattr(module, "step_counter"):
99
+ ternary_step_counter_bytes += _tensor_bytes(module.step_counter)
100
+
101
+ trainable_float_params: list[TensorState] = []
102
+ frozen_float_params: list[TensorState] = []
103
+ for name, param in model.named_parameters():
104
+ if not param.dtype.is_floating_point:
105
+ continue
106
+ state = _tensor_state(name, param, trainable=param.requires_grad)
107
+ if param.requires_grad:
108
+ trainable_float_params.append(state)
109
+ else:
110
+ frozen_float_params.append(state)
111
+
112
+ float_buffers = [
113
+ _tensor_state(name, buf)
114
+ for name, buf in model.named_buffers()
115
+ if buf.dtype.is_floating_point
116
+ ]
117
+
118
+ return TernaryAudit(
119
+ logical_ternary_weights=logical_ternary_weights,
120
+ ternary_packed_bytes=ternary_packed_bytes,
121
+ ternary_scale_bytes=ternary_scale_bytes,
122
+ ternary_scale_accum_bytes=ternary_scale_accum_bytes,
123
+ ternary_accum_bytes=ternary_accum_bytes,
124
+ ternary_corr_accum_bytes=ternary_corr_accum_bytes,
125
+ ternary_step_counter_bytes=ternary_step_counter_bytes,
126
+ trainable_float_params=trainable_float_params,
127
+ frozen_float_params=frozen_float_params,
128
+ float_buffers=float_buffers,
129
+ )
130
+
131
+
132
+ def format_audit(audit: TernaryAudit, limit: int = 12) -> str:
133
+ lines = [
134
+ "Ternary state audit:",
135
+ f" logical ternary weights: {audit.logical_ternary_weights:,}",
136
+ (
137
+ " ternary training state: "
138
+ f"{_mb(audit.ternary_training_bytes):.2f} MB "
139
+ f"(T={_mb(audit.ternary_packed_bytes):.2f}, "
140
+ f"E={_mb(audit.ternary_scale_bytes):.2f}, "
141
+ f"E_accum={_mb(audit.ternary_scale_accum_bytes):.2f}, "
142
+ f"T_accum={_mb(audit.ternary_accum_bytes):.2f}, "
143
+ f"corr_accum={_mb(audit.ternary_corr_accum_bytes):.2f}, "
144
+ f"steps={_mb(audit.ternary_step_counter_bytes):.4f})"
145
+ ),
146
+ (
147
+ " trainable float params: "
148
+ f"{len(audit.trainable_float_params)} tensors, "
149
+ f"{_mb(audit.trainable_float_bytes):.2f} MB"
150
+ ),
151
+ (
152
+ " frozen float params: "
153
+ f"{len(audit.frozen_float_params)} tensors, "
154
+ f"{_mb(audit.frozen_float_bytes):.2f} MB"
155
+ ),
156
+ (
157
+ " float buffers: "
158
+ f"{len(audit.float_buffers)} tensors, "
159
+ f"{_mb(audit.float_buffer_bytes):.2f} MB"
160
+ ),
161
+ ]
162
+
163
+ if audit.trainable_float_params:
164
+ lines.append(" largest trainable float params:")
165
+ for item in sorted(audit.trainable_float_params, key=lambda x: x.bytes, reverse=True)[:limit]:
166
+ lines.append(f" {item.name}: {item.shape} {item.dtype} {_mb(item.bytes):.2f} MB")
167
+
168
+ if audit.float_buffers:
169
+ lines.append(" largest float buffers:")
170
+ for item in sorted(audit.float_buffers, key=lambda x: x.bytes, reverse=True)[:limit]:
171
+ lines.append(f" {item.name}: {item.shape} {item.dtype} {_mb(item.bytes):.2f} MB")
172
+
173
+ return "\n".join(lines)
174
+
175
+
176
+ def freeze_float_parameters(
177
+ model: torch.nn.Module,
178
+ allow_prefixes: Iterable[str] = (),
179
+ ) -> list[TensorState]:
180
+ allow = tuple(allow_prefixes)
181
+ frozen: list[TensorState] = []
182
+ for name, param in model.named_parameters():
183
+ if allow and name.startswith(allow):
184
+ continue
185
+ if param.dtype.is_floating_point and param.requires_grad:
186
+ frozen.append(_tensor_state(name, param, trainable=True))
187
+ param.requires_grad_(False)
188
+ return frozen
189
+
190
+
191
+ def trainable_parameters(model: torch.nn.Module) -> list[torch.nn.Parameter]:
192
+ return [p for p in model.parameters() if p.requires_grad]
arbitor/kernel/ternary_scale.py ADDED
@@ -0,0 +1,1811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import warnings
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from enum import IntEnum
9
+ from math import ceil
10
+
11
+ from ..converters.convert_to_ternary8 import pack_ternary, unpack_ternary
12
+
13
+ _HAS_TILELANG = False
14
+ try:
15
+ import tilelang
16
+ import tilelang.language as T
17
+ _HAS_TILELANG = True
18
+ except ImportError:
19
+ pass
20
+
21
+ _HAS_TRITON = False
22
+ try:
23
+ import triton
24
+ import triton.language as tl
25
+ _HAS_TRITON = True
26
+ except ImportError:
27
+ pass
28
+
29
+
30
+ def _backend_preference() -> str:
31
+ backend = os.environ.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
32
+ if backend not in {"auto", "tilelang", "triton", "torch"}:
33
+ warnings.warn(
34
+ f"Unknown ARB_TERNARY_BACKEND={backend!r}; falling back to auto.",
35
+ RuntimeWarning,
36
+ stacklevel=2,
37
+ )
38
+ return "auto"
39
+ return backend
40
+
41
+
42
+ def _rmsnorm_triton_max_dim() -> int:
43
+ raw = os.environ.get("ARB_RMSNORM_TRITON_MAX_DIM", "4096").strip()
44
+ try:
45
+ return max(0, int(raw))
46
+ except ValueError:
47
+ warnings.warn(
48
+ f"Invalid ARB_RMSNORM_TRITON_MAX_DIM={raw!r}; using 4096.",
49
+ RuntimeWarning,
50
+ stacklevel=2,
51
+ )
52
+ return 4096
53
+
54
+
55
+ def _bigint_corr_strength() -> float:
56
+ raw = os.environ.get("ARB_BIGINT_CORR_STRENGTH", "4.0").strip()
57
+ try:
58
+ return float(raw)
59
+ except ValueError:
60
+ warnings.warn(
61
+ f"Invalid ARB_BIGINT_CORR_STRENGTH={raw!r}; using 4.0.",
62
+ RuntimeWarning,
63
+ stacklevel=2,
64
+ )
65
+ return 4.0
66
+
67
+
68
+ class _ComponentContext:
69
+ _local = threading.local()
70
+
71
+ @classmethod
72
+ def get(cls):
73
+ val = getattr(cls._local, "current", None)
74
+ if val is None:
75
+ return None, 1.0
76
+ return val
77
+
78
+ @classmethod
79
+ def set(cls, name, weight=1.0):
80
+ if name is None:
81
+ cls._local.current = None
82
+ else:
83
+ cls._local.current = (name, weight)
84
+
85
+ @classmethod
86
+ def clear(cls):
87
+ cls._local.current = None
88
+
89
+
90
+ _COMPONENT_CONTEXT = _ComponentContext
91
+
92
+
93
+ def _tilelang_training_enabled() -> bool:
94
+ return os.environ.get("ARB_TILELANG_TRAINING", "0").strip().lower() in {"1", "true", "yes"}
95
+
96
+
97
+ if _HAS_TILELANG:
98
+
99
+ tilelang_jit = tilelang.jit(pass_configs={"tl.disable_warp_specialized": True})
100
+
101
+ def _ternary_fwd_kernel(
102
+ M: int, N: int, K: int, group_size: int = 12,
103
+ corr_strength: float = 4.0,
104
+ block_M: int = 64, block_N: int = 64, block_K: int = 32,
105
+ threads: int = 128, num_stages: int = 2,
106
+ ):
107
+ gpr = (K + group_size - 1) // group_size
108
+ cs = corr_strength
109
+
110
+ @T.prim_func
111
+ def kernel(
112
+ x: T.Tensor((M, K), "float16"),
113
+ T_packed: T.Tensor((N * K + 4) // 5, "uint8"),
114
+ E: T.Tensor((N * gpr), "int8"),
115
+ corr_accum: T.Tensor((N * gpr), "int64"),
116
+ step_counter: T.Tensor((1,), "int64"),
117
+ output: T.Tensor((M, N), "float32"),
118
+ ):
119
+ steps = T.cast(step_counter[0], "int32")
120
+ with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
121
+ x_shared = T.alloc_shared((block_M, block_K), dtype="float16")
122
+ dq_shared = T.alloc_shared((block_N, block_K), dtype="float16")
123
+ acc = T.alloc_fragment((block_M, block_N), dtype="float32")
124
+ T.use_swizzle(10)
125
+ T.clear(acc)
126
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
127
+ T.copy(x[bx * block_M, k * block_K], x_shared)
128
+ for i, j in T.Parallel(block_N, block_K):
129
+ i_glob = by * block_N + i
130
+ j_glob = k * block_K + j
131
+ if i_glob < N and j_glob < K:
132
+ lin_idx = i_glob * K + j_glob
133
+ pack_idx = lin_idx // 5
134
+ trit_pos = lin_idx % 5
135
+ packed_val = T.cast(T_packed[pack_idx], "int32")
136
+ trit = T.if_then_else(
137
+ trit_pos == 0, packed_val % 3,
138
+ T.if_then_else(trit_pos == 1, (packed_val // 3) % 3,
139
+ T.if_then_else(trit_pos == 2, (packed_val // 9) % 3,
140
+ T.if_then_else(trit_pos == 3, (packed_val // 27) % 3,
141
+ (packed_val // 81) % 3))))
142
+ sign_val = T.cast(trit, "int32") - 1
143
+ exp_idx = i_glob * gpr + j_glob // group_size
144
+ exp_val = T.cast(E[exp_idx], "int32")
145
+ ca = T.cast(corr_accum[exp_idx], "int32")
146
+ den = T.max(steps * group_size, 1)
147
+ mc = T.cast(ca, "float32") / T.cast(den, "float32")
148
+ e_adj = T.cast(exp_val, "float32") + mc * cs
149
+ ecl = T.min(T.max(e_adj, -14.0), 15.0)
150
+ dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16")
151
+ T.gemm(x_shared, dq_shared, acc, transpose_B=True)
152
+ T.copy(acc, output[bx * block_M, by * block_N])
153
+ return tilelang_jit(kernel)
154
+
155
+ def _ternary_grad_x_kernel(
156
+ M: int, N: int, K: int, group_size: int = 12,
157
+ corr_strength: float = 4.0,
158
+ block_M: int = 64, block_N: int = 64, block_K: int = 32,
159
+ threads: int = 128, num_stages: int = 2,
160
+ ):
161
+ gpr = (K + group_size - 1) // group_size
162
+ cs = corr_strength
163
+
164
+ @T.prim_func
165
+ def kernel(
166
+ grad_y: T.Tensor((M, N), "float16"),
167
+ T_packed: T.Tensor((N * K + 4) // 5, "uint8"),
168
+ E: T.Tensor((N * gpr), "int8"),
169
+ corr_accum: T.Tensor((N * gpr), "int64"),
170
+ step_counter: T.Tensor((1,), "int64"),
171
+ output: T.Tensor((M, K), "float32"),
172
+ ):
173
+ steps = T.cast(step_counter[0], "int32")
174
+ with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=threads) as (bx, by):
175
+ gy_shared = T.alloc_shared((block_M, block_N), dtype="float16")
176
+ dq_shared = T.alloc_shared((block_N, block_K), dtype="float16")
177
+ acc = T.alloc_fragment((block_M, block_K), dtype="float32")
178
+ T.use_swizzle(10)
179
+ T.clear(acc)
180
+ for n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
181
+ T.copy(grad_y[bx * block_M, n * block_N], gy_shared)
182
+ for i, j in T.Parallel(block_N, block_K):
183
+ i_glob = n * block_N + i
184
+ j_glob = by * block_K + j
185
+ if i_glob < N and j_glob < K:
186
+ lin_idx = i_glob * K + j_glob
187
+ pack_idx = lin_idx // 5
188
+ trit_pos = lin_idx % 5
189
+ packed_val = T.cast(T_packed[pack_idx], "int32")
190
+ trit = T.if_then_else(
191
+ trit_pos == 0, packed_val % 3,
192
+ T.if_then_else(trit_pos == 1, (packed_val // 3) % 3,
193
+ T.if_then_else(trit_pos == 2, (packed_val // 9) % 3,
194
+ T.if_then_else(trit_pos == 3, (packed_val // 27) % 3,
195
+ (packed_val // 81) % 3))))
196
+ sign_val = T.cast(trit, "int32") - 1
197
+ exp_idx = i_glob * gpr + j_glob // group_size
198
+ exp_val = T.cast(E[exp_idx], "int32")
199
+ ca = T.cast(corr_accum[exp_idx], "int32")
200
+ den = T.max(steps * group_size, 1)
201
+ mc = T.cast(ca, "float32") / T.cast(den, "float32")
202
+ e_adj = T.cast(exp_val, "float32") + mc * cs
203
+ ecl = T.min(T.max(e_adj, -14.0), 15.0)
204
+ dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16")
205
+ T.gemm(gy_shared, dq_shared, acc)
206
+ T.copy(acc, output[bx * block_M, by * block_K])
207
+ return tilelang_jit(kernel)
208
+
209
+ _KERNEL_CACHE_FWD = {}
210
+ _KERNEL_CACHE_GX = {}
211
+
212
+ def _get_kernel(M, N, K, group_size, mode, corr_strength=4.0):
213
+ cs = corr_strength
214
+ if mode == "fwd":
215
+ cache = _KERNEL_CACHE_FWD
216
+ key = (M, N, K, group_size, cs)
217
+ if key not in cache:
218
+ cache[key] = _ternary_fwd_kernel(M, N, K, group_size, corr_strength=cs)
219
+ return cache[key]
220
+ elif mode == "grad_x":
221
+ cache = _KERNEL_CACHE_GX
222
+ key = (M, N, K, group_size)
223
+ if key not in cache:
224
+ cache[key] = _ternary_grad_x_kernel(M, N, K, group_size)
225
+ return cache[key]
226
+ raise ValueError(f"Unknown TileLang kernel mode: {mode}")
227
+
228
+
229
+ def _get_grad_kernels(M, N, K, group_size):
230
+ return _get_kernel(M, N, K, group_size, "grad_x")
231
+
232
+
233
+ class _TernaryLinearFn(torch.autograd.Function):
234
+ @staticmethod
235
+ def forward(ctx, x, module, fwd_kernel):
236
+ ctx.module = module
237
+ T_packed = module.T_packed
238
+ E = module.E
239
+ shape = tuple(module._T_shape.tolist())
240
+ N, K = shape
241
+ x_2d = x.reshape(-1, K).contiguous()
242
+ ctx.group_size = module.group_size
243
+ ctx.shape = shape
244
+ ctx.x_shape = x.shape
245
+ comp_name, _ = _COMPONENT_CONTEXT.get()
246
+ ctx.comp_name = comp_name
247
+ ctx.x_dtype = x.dtype
248
+ has_corr = hasattr(module, "corr_accum") and hasattr(module, "step_counter")
249
+ ctx.save_for_backward(x_2d, T_packed, E)
250
+ ctx.has_corr = has_corr
251
+ ctx.step_snapshot = int(module.step_counter.item()) if has_corr else 0
252
+ with torch.no_grad():
253
+ M = x_2d.shape[0]
254
+ output = torch.empty(M, N, device=x.device, dtype=torch.float32)
255
+ if has_corr:
256
+ fwd_kernel(x_2d.half(), T_packed, E,
257
+ module.corr_accum.contiguous(),
258
+ module.step_counter.contiguous(), output)
259
+ else:
260
+ fwd_kernel(x_2d.half(), T_packed, E,
261
+ torch.zeros(N * ((K + module.group_size - 1) // module.group_size),
262
+ dtype=torch.int64, device=x.device),
263
+ torch.zeros(1, dtype=torch.int64, device=x.device), output)
264
+ return output.reshape(*x.shape[:-1], N)
265
+
266
+ @staticmethod
267
+ def backward(ctx, grad_output):
268
+ x_2d, T_packed, E = ctx.saved_tensors
269
+ group_size = ctx.group_size
270
+ N, K = ctx.shape
271
+ M = x_2d.shape[0]
272
+ grad_2d = grad_output.reshape(-1, N).contiguous()
273
+ if ctx.has_corr:
274
+ corr_accum = ctx.module.corr_accum.contiguous()
275
+ step_counter = torch.tensor([ctx.step_snapshot], dtype=torch.int64, device=x_2d.device)
276
+ else:
277
+ corr_accum = torch.zeros(N * ((K + group_size - 1) // group_size),
278
+ dtype=torch.int64, device=x_2d.device)
279
+ step_counter = torch.zeros(1, dtype=torch.int64, device=x_2d.device)
280
+ grad_x_kernel = _get_grad_kernels(M, N, K, group_size)
281
+ with torch.no_grad():
282
+ grad_x = torch.empty(M, K, device=x_2d.device, dtype=torch.float32)
283
+ grad_x_kernel(grad_2d.half(), T_packed, E, corr_accum, step_counter, grad_x)
284
+ comp_name = ctx.comp_name
285
+ if _HAS_TRITON and ctx.has_corr and getattr(ctx.module, "_stream_backward_updates", True):
286
+ bwd_name, bwd_weight = _COMPONENT_CONTEXT.get()
287
+ if bwd_name is None:
288
+ bwd_weight = 1.0
289
+ base_step = int(getattr(ctx.module, "_backward_t_accum_step", 1))
290
+ corr_step = max(1, int(round(abs(float(bwd_weight)) * base_step)))
291
+ if bwd_weight < 0:
292
+ corr_step = -corr_step
293
+ _triton_accumulate_corr_direct(
294
+ T_packed, grad_2d, x_2d, ctx.module.corr_accum,
295
+ N, K, group_size, corr_step=corr_step,
296
+ )
297
+ ctx.module.step_counter.add_(abs(corr_step))
298
+ ctx.module._streamed_bigint_backward = True
299
+ elif _HAS_TRITON:
300
+ grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, N, K)
301
+ if comp_name is not None:
302
+ setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
303
+ else:
304
+ ctx.module._hook_grad_T_sign = grad_sign.detach()
305
+ elif comp_name is not None:
306
+ setattr(ctx.module, f"_hook_grad_2d_{comp_name}", grad_2d.detach())
307
+ setattr(ctx.module, f"_hook_x_2d_{comp_name}", x_2d.detach())
308
+ else:
309
+ ctx.module._hook_grad_2d = grad_2d.detach()
310
+ ctx.module._hook_x_2d = x_2d.detach()
311
+ grad_x_reshaped = grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype)
312
+ return grad_x_reshaped, None, None
313
+
314
+
315
+ if _HAS_TRITON:
316
+
317
+ @triton.jit
318
+ def _triton_ternary_fwd_kernel(
319
+ x_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
320
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
321
+ GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
322
+ CORR_STRENGTH: tl.constexpr,
323
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
324
+ ):
325
+ pid_m = tl.program_id(0)
326
+ pid_n = tl.program_id(1)
327
+
328
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
329
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
330
+ offs_k = tl.arange(0, BLOCK_K)
331
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
332
+
333
+ for k0 in range(0, K, BLOCK_K):
334
+ k = k0 + offs_k
335
+ x = tl.load(
336
+ x_ptr + offs_m[:, None] * K + k[None, :],
337
+ mask=(offs_m[:, None] < M) & (k[None, :] < K),
338
+ other=0.0,
339
+ )
340
+
341
+ lin = offs_n[:, None] * K + k[None, :]
342
+ pack_idx = lin // 5
343
+ trit_pos = lin - pack_idx * 5
344
+ packed = tl.load(
345
+ packed_ptr + pack_idx,
346
+ mask=(offs_n[:, None] < N) & (k[None, :] < K),
347
+ other=0,
348
+ ).to(tl.int32)
349
+ divisor = tl.where(
350
+ trit_pos == 0, 1,
351
+ tl.where(trit_pos == 1, 3,
352
+ tl.where(trit_pos == 2, 9,
353
+ tl.where(trit_pos == 3, 27, 81))),
354
+ )
355
+ trit = (packed // divisor) % 3
356
+ sign = trit.to(tl.int32) - 1
357
+
358
+ e_idx = offs_n[:, None] * GPR + k[None, :] // GROUP_SIZE
359
+ e_val = tl.load(
360
+ e_ptr + e_idx,
361
+ mask=(offs_n[:, None] < N) & (k[None, :] < K),
362
+ other=0,
363
+ ).to(tl.float32)
364
+ corr_val = tl.load(
365
+ corr_ptr + e_idx,
366
+ mask=(offs_n[:, None] < N) & (k[None, :] < K),
367
+ other=0,
368
+ ).to(tl.float32)
369
+ step_val = tl.load(step_ptr).to(tl.float32)
370
+ denom = tl.maximum(step_val * GROUP_SIZE, 1.0)
371
+ e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
372
+ w = sign.to(tl.float32) * tl.exp2(e_adj)
373
+ w = tl.where((offs_n[:, None] < N) & (k[None, :] < K), w, 0.0)
374
+ acc += tl.dot(x, tl.trans(w))
375
+
376
+ tl.store(
377
+ out_ptr + offs_m[:, None] * N + offs_n[None, :],
378
+ acc,
379
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
380
+ )
381
+
382
+
383
+ @triton.jit
384
+ def _triton_ternary_grad_x_kernel(
385
+ grad_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
386
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
387
+ GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
388
+ CORR_STRENGTH: tl.constexpr,
389
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
390
+ ):
391
+ pid_m = tl.program_id(0)
392
+ pid_k = tl.program_id(1)
393
+
394
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
395
+ offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
396
+ offs_n = tl.arange(0, BLOCK_N)
397
+ acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
398
+
399
+ for n0 in range(0, N, BLOCK_N):
400
+ n = n0 + offs_n
401
+ grad = tl.load(
402
+ grad_ptr + offs_m[:, None] * N + n[None, :],
403
+ mask=(offs_m[:, None] < M) & (n[None, :] < N),
404
+ other=0.0,
405
+ )
406
+
407
+ lin = n[:, None] * K + offs_k[None, :]
408
+ pack_idx = lin // 5
409
+ trit_pos = lin - pack_idx * 5
410
+ packed = tl.load(
411
+ packed_ptr + pack_idx,
412
+ mask=(n[:, None] < N) & (offs_k[None, :] < K),
413
+ other=0,
414
+ ).to(tl.int32)
415
+ divisor = tl.where(
416
+ trit_pos == 0, 1,
417
+ tl.where(trit_pos == 1, 3,
418
+ tl.where(trit_pos == 2, 9,
419
+ tl.where(trit_pos == 3, 27, 81))),
420
+ )
421
+ trit = (packed // divisor) % 3
422
+ sign = trit.to(tl.int32) - 1
423
+
424
+ e_idx = n[:, None] * GPR + offs_k[None, :] // GROUP_SIZE
425
+ e_val = tl.load(
426
+ e_ptr + e_idx,
427
+ mask=(n[:, None] < N) & (offs_k[None, :] < K),
428
+ other=0,
429
+ ).to(tl.float32)
430
+ corr_val = tl.load(
431
+ corr_ptr + e_idx,
432
+ mask=(n[:, None] < N) & (offs_k[None, :] < K),
433
+ other=0,
434
+ ).to(tl.float32)
435
+ step_val = tl.load(step_ptr).to(tl.float32)
436
+ denom = tl.maximum(step_val * GROUP_SIZE, 1.0)
437
+ e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
438
+ w = sign.to(tl.float32) * tl.exp2(e_adj)
439
+ w = tl.where((n[:, None] < N) & (offs_k[None, :] < K), w, 0.0)
440
+ acc += tl.dot(grad, w)
441
+
442
+ tl.store(
443
+ out_ptr + offs_m[:, None] * K + offs_k[None, :],
444
+ acc,
445
+ mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
446
+ )
447
+
448
+
449
+ @triton.jit
450
+ def _triton_ternary_grad_sign_kernel(
451
+ grad_ptr, x_ptr, sign_ptr,
452
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
453
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
454
+ ):
455
+ pid_n = tl.program_id(0)
456
+ pid_k = tl.program_id(1)
457
+
458
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
459
+ offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
460
+ offs_m = tl.arange(0, BLOCK_M)
461
+ acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
462
+
463
+ for m0 in range(0, M, BLOCK_M):
464
+ m = m0 + offs_m
465
+ grad = tl.load(
466
+ grad_ptr + m[:, None] * N + offs_n[None, :],
467
+ mask=(m[:, None] < M) & (offs_n[None, :] < N),
468
+ other=0.0,
469
+ )
470
+ x = tl.load(
471
+ x_ptr + m[:, None] * K + offs_k[None, :],
472
+ mask=(m[:, None] < M) & (offs_k[None, :] < K),
473
+ other=0.0,
474
+ )
475
+ acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
476
+
477
+ sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0))
478
+ tl.store(
479
+ sign_ptr + offs_n[:, None] * K + offs_k[None, :],
480
+ sign.to(tl.int8),
481
+ mask=(offs_n[:, None] < N) & (offs_k[None, :] < K),
482
+ )
483
+
484
+
485
+ @triton.jit
486
+ def _triton_update_e_kernel(
487
+ packed_ptr, grad_sign_ptr, e_ptr, e_accum_ptr,
488
+ N: tl.constexpr, K: tl.constexpr,
489
+ GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
490
+ E_ACCUM_THRESHOLD: tl.constexpr,
491
+ BLOCK_N: tl.constexpr, BLOCK_G: tl.constexpr, BLOCK_K: tl.constexpr,
492
+ ):
493
+ pid_n = tl.program_id(0)
494
+ pid_g = tl.program_id(1)
495
+
496
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
497
+ offs_g = pid_g * BLOCK_G + tl.arange(0, BLOCK_G)
498
+ offs_r = tl.arange(0, BLOCK_K)
499
+ k = offs_g[:, None] * GROUP_SIZE + offs_r[None, :]
500
+ valid_group = offs_g < GPR
501
+
502
+ lin = offs_n[:, None, None] * K + k[None, :, :]
503
+ pack_idx = lin // 5
504
+ trit_pos = lin - pack_idx * 5
505
+ packed = tl.load(
506
+ packed_ptr + pack_idx,
507
+ mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K),
508
+ other=0,
509
+ ).to(tl.int32)
510
+ divisor = tl.where(
511
+ trit_pos == 0, 1,
512
+ tl.where(trit_pos == 1, 3,
513
+ tl.where(trit_pos == 2, 9,
514
+ tl.where(trit_pos == 3, 27, 81))),
515
+ )
516
+ trit = (packed // divisor) % 3
517
+ ternary = trit.to(tl.int32) - 1
518
+
519
+ grad_sign = tl.load(
520
+ grad_sign_ptr + offs_n[:, None, None] * K + k[None, :, :],
521
+ mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K),
522
+ other=0,
523
+ ).to(tl.int32)
524
+ contrib = grad_sign * ternary
525
+ score = tl.sum(contrib, axis=2)
526
+ delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
527
+
528
+ e_idx = offs_n[:, None] * GPR + offs_g[None, :]
529
+ old_accum = tl.load(
530
+ e_accum_ptr + e_idx,
531
+ mask=(offs_n[:, None] < N) & valid_group[None, :],
532
+ other=0,
533
+ ).to(tl.int32)
534
+ new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta))
535
+ step_up = new_accum >= E_ACCUM_THRESHOLD
536
+ step_down = new_accum <= -E_ACCUM_THRESHOLD
537
+ e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0))
538
+ stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD
539
+
540
+ old_e = tl.load(
541
+ e_ptr + e_idx,
542
+ mask=(offs_n[:, None] < N) & valid_group[None, :],
543
+ other=0,
544
+ ).to(tl.int32)
545
+ new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step))
546
+ tl.store(
547
+ e_ptr + e_idx,
548
+ new_e.to(tl.int8),
549
+ mask=(offs_n[:, None] < N) & valid_group[None, :],
550
+ )
551
+ tl.store(
552
+ e_accum_ptr + e_idx,
553
+ stored_accum.to(tl.int8),
554
+ mask=(offs_n[:, None] < N) & valid_group[None, :],
555
+ )
556
+
557
+
558
+ @triton.jit
559
+ def _triton_ternary_step_kernel(
560
+ packed_ptr, grad_sign_ptr, accum_ptr, per_group_threshold_ptr,
561
+ TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
562
+ T_ACCUM_STEP: tl.constexpr,
563
+ K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
564
+ HAS_PER_GROUP_THRESHOLD: tl.constexpr,
565
+ BLOCK_T: tl.constexpr,
566
+ ):
567
+ pack_idx = tl.program_id(0)
568
+ offs_t = tl.arange(0, BLOCK_T)
569
+ valid_trit = offs_t < 5
570
+ lin = pack_idx * 5 + offs_t
571
+ valid = valid_trit & (lin < TOTAL)
572
+
573
+ old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
574
+ divisor = tl.where(
575
+ offs_t == 0, 1,
576
+ tl.where(offs_t == 1, 3,
577
+ tl.where(offs_t == 2, 9,
578
+ tl.where(offs_t == 3, 27, 81))),
579
+ )
580
+ old_code = (old_packed // divisor) % 3
581
+ old_sign = old_code.to(tl.int32) - 1
582
+
583
+ grad_sign = tl.load(grad_sign_ptr + lin, mask=valid, other=0).to(tl.int32)
584
+ old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
585
+ new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
586
+
587
+ if HAS_PER_GROUP_THRESHOLD:
588
+ n = lin // K
589
+ k = lin - n * K
590
+ g_idx = n * GPR + k // GROUP_SIZE
591
+ threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
592
+ else:
593
+ threshold = ACCUM_THRESHOLD
594
+
595
+ flip_up = new_accum > threshold
596
+ flip_down = new_accum < -threshold
597
+ did_flip = valid & (flip_up | flip_down)
598
+ new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
599
+ stored_accum = tl.where(did_flip, 0, new_accum)
600
+ tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)
601
+
602
+ new_code = tl.where(valid, new_sign + 1, 0)
603
+ packed_val = tl.sum(new_code * divisor, axis=0)
604
+ tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
605
+
606
+
607
+ @triton.jit
608
+ def _triton_update_e_direct_kernel(
609
+ packed_ptr, grad_ptr, x_ptr, e_ptr, e_accum_ptr,
610
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
611
+ GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
612
+ E_ACCUM_THRESHOLD: tl.constexpr,
613
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
614
+ ):
615
+ pid_n = tl.program_id(0)
616
+ pid_g = tl.program_id(1)
617
+
618
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
619
+ offs_r = tl.arange(0, BLOCK_K)
620
+ k = pid_g * GROUP_SIZE + offs_r
621
+ offs_m = tl.arange(0, BLOCK_M)
622
+ acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
623
+
624
+ for m0 in range(0, M, BLOCK_M):
625
+ m = m0 + offs_m
626
+ grad = tl.load(
627
+ grad_ptr + m[:, None] * N + offs_n[None, :],
628
+ mask=(m[:, None] < M) & (offs_n[None, :] < N),
629
+ other=0.0,
630
+ )
631
+ x = tl.load(
632
+ x_ptr + m[:, None] * K + k[None, :],
633
+ mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
634
+ other=0.0,
635
+ )
636
+ acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
637
+
638
+ grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
639
+ lin = offs_n[:, None] * K + k[None, :]
640
+ pack_idx = lin // 5
641
+ trit_pos = lin - pack_idx * 5
642
+ packed = tl.load(
643
+ packed_ptr + pack_idx,
644
+ mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
645
+ other=0,
646
+ ).to(tl.int32)
647
+ divisor = tl.where(
648
+ trit_pos == 0, 1,
649
+ tl.where(trit_pos == 1, 3,
650
+ tl.where(trit_pos == 2, 9,
651
+ tl.where(trit_pos == 3, 27, 81))),
652
+ )
653
+ trit = (packed // divisor) % 3
654
+ ternary = trit.to(tl.int32) - 1
655
+ contrib = tl.where(
656
+ (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
657
+ grad_sign * ternary,
658
+ 0,
659
+ )
660
+ score = tl.sum(contrib, axis=1)
661
+ delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
662
+
663
+ e_idx = offs_n * GPR + pid_g
664
+ old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
665
+ new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta))
666
+ step_up = new_accum >= E_ACCUM_THRESHOLD
667
+ step_down = new_accum <= -E_ACCUM_THRESHOLD
668
+ e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0))
669
+ stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD
670
+
671
+ old_e = tl.load(e_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
672
+ new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step))
673
+ tl.store(e_ptr + e_idx, new_e.to(tl.int8), mask=offs_n < N)
674
+ tl.store(e_accum_ptr + e_idx, stored_accum.to(tl.int8), mask=offs_n < N)
675
+
676
+
677
+ @triton.jit
678
+ def _triton_ternary_step_direct_kernel(
679
+ packed_ptr, grad_ptr, x_ptr, accum_ptr, per_group_threshold_ptr,
680
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
681
+ TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
682
+ T_ACCUM_STEP: tl.constexpr,
683
+ GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
684
+ HAS_PER_GROUP_THRESHOLD: tl.constexpr,
685
+ BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
686
+ ):
687
+ pack_idx = tl.program_id(0)
688
+ offs_t = tl.arange(0, BLOCK_T)
689
+ lin = pack_idx * 5 + offs_t
690
+ valid_trit = offs_t < 5
691
+ valid = valid_trit & (lin < TOTAL)
692
+ n = lin // K
693
+ k = lin - n * K
694
+
695
+ offs_m = tl.arange(0, BLOCK_M)
696
+ acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
697
+ for m0 in range(0, M, BLOCK_M):
698
+ m = m0 + offs_m
699
+ grad = tl.load(
700
+ grad_ptr + m[:, None] * N + n[None, :],
701
+ mask=(m[:, None] < M) & valid[None, :],
702
+ other=0.0,
703
+ )
704
+ x = tl.load(
705
+ x_ptr + m[:, None] * K + k[None, :],
706
+ mask=(m[:, None] < M) & valid[None, :],
707
+ other=0.0,
708
+ )
709
+ acc += tl.sum(grad * x, axis=0)
710
+
711
+ grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
712
+
713
+ old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
714
+ divisor = tl.where(
715
+ offs_t == 0, 1,
716
+ tl.where(offs_t == 1, 3,
717
+ tl.where(offs_t == 2, 9,
718
+ tl.where(offs_t == 3, 27, 81))),
719
+ )
720
+ old_code = (old_packed // divisor) % 3
721
+ old_sign = old_code.to(tl.int32) - 1
722
+
723
+ old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
724
+ new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
725
+
726
+ if HAS_PER_GROUP_THRESHOLD:
727
+ g_idx = n * GPR + k // GROUP_SIZE
728
+ threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
729
+ else:
730
+ threshold = ACCUM_THRESHOLD
731
+
732
+ flip_up = new_accum > threshold
733
+ flip_down = new_accum < -threshold
734
+ did_flip = valid & (flip_up | flip_down)
735
+ new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
736
+ stored_accum = tl.where(did_flip, 0, new_accum)
737
+ tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)
738
+
739
+ new_code = tl.where(valid, new_sign + 1, 0)
740
+ packed_val = tl.sum(new_code * divisor, axis=0)
741
+ tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
742
+
743
+
744
+ @triton.jit
745
+ def _triton_accumulate_t_direct_kernel(
746
+ grad_ptr, x_ptr, accum_ptr,
747
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
748
+ TOTAL: tl.constexpr, T_ACCUM_STEP: tl.constexpr,
749
+ BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
750
+ ):
751
+ pack_idx = tl.program_id(0)
752
+ offs_t = tl.arange(0, BLOCK_T)
753
+ lin = pack_idx * 5 + offs_t
754
+ valid_trit = offs_t < 5
755
+ valid = valid_trit & (lin < TOTAL)
756
+ n = lin // K
757
+ k = lin - n * K
758
+
759
+ offs_m = tl.arange(0, BLOCK_M)
760
+ acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
761
+ for m0 in range(0, M, BLOCK_M):
762
+ m = m0 + offs_m
763
+ grad = tl.load(
764
+ grad_ptr + m[:, None] * N + n[None, :],
765
+ mask=(m[:, None] < M) & valid[None, :],
766
+ other=0.0,
767
+ )
768
+ x = tl.load(
769
+ x_ptr + m[:, None] * K + k[None, :],
770
+ mask=(m[:, None] < M) & valid[None, :],
771
+ other=0.0,
772
+ )
773
+ acc += tl.sum(grad * x, axis=0)
774
+
775
+ grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
776
+ old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
777
+ new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
778
+ tl.store(accum_ptr + lin, new_accum.to(tl.int8), mask=valid)
779
+
780
+
781
+ @triton.jit
782
+ def _triton_accumulate_e_direct_kernel(
783
+ packed_ptr, grad_ptr, x_ptr, e_accum_ptr,
784
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
785
+ GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
786
+ E_ACCUM_STEP: tl.constexpr,
787
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
788
+ ):
789
+ pid_n = tl.program_id(0)
790
+ pid_g = tl.program_id(1)
791
+
792
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
793
+ offs_r = tl.arange(0, BLOCK_K)
794
+ k = pid_g * GROUP_SIZE + offs_r
795
+ offs_m = tl.arange(0, BLOCK_M)
796
+ acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
797
+
798
+ for m0 in range(0, M, BLOCK_M):
799
+ m = m0 + offs_m
800
+ grad = tl.load(
801
+ grad_ptr + m[:, None] * N + offs_n[None, :],
802
+ mask=(m[:, None] < M) & (offs_n[None, :] < N),
803
+ other=0.0,
804
+ )
805
+ x = tl.load(
806
+ x_ptr + m[:, None] * K + k[None, :],
807
+ mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
808
+ other=0.0,
809
+ )
810
+ acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
811
+
812
+ grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
813
+ lin = offs_n[:, None] * K + k[None, :]
814
+ pack_idx = lin // 5
815
+ trit_pos = lin - pack_idx * 5
816
+ packed = tl.load(
817
+ packed_ptr + pack_idx,
818
+ mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
819
+ other=0,
820
+ ).to(tl.int32)
821
+ divisor = tl.where(
822
+ trit_pos == 0, 1,
823
+ tl.where(trit_pos == 1, 3,
824
+ tl.where(trit_pos == 2, 9,
825
+ tl.where(trit_pos == 3, 27, 81))),
826
+ )
827
+ trit = (packed // divisor) % 3
828
+ ternary = trit.to(tl.int32) - 1
829
+ contrib = tl.where(
830
+ (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
831
+ grad_sign * ternary,
832
+ 0,
833
+ )
834
+ score = tl.sum(contrib, axis=1)
835
+ delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
836
+
837
+ e_idx = offs_n * GPR + pid_g
838
+ old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
839
+ new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta * E_ACCUM_STEP))
840
+ tl.store(e_accum_ptr + e_idx, new_accum.to(tl.int8), mask=offs_n < N)
841
+
842
+
843
+ @triton.jit
844
+ def _triton_accumulate_corr_direct_kernel(
845
+ packed_ptr, grad_ptr, x_ptr, corr_ptr,
846
+ M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
847
+ GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
848
+ CORR_STEP: tl.constexpr,
849
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
850
+ ):
851
+ pid_n = tl.program_id(0)
852
+ pid_g = tl.program_id(1)
853
+
854
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
855
+ offs_r = tl.arange(0, BLOCK_K)
856
+ k = pid_g * GROUP_SIZE + offs_r
857
+ offs_m = tl.arange(0, BLOCK_M)
858
+ acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
859
+
860
+ for m0 in range(0, M, BLOCK_M):
861
+ m = m0 + offs_m
862
+ grad = tl.load(
863
+ grad_ptr + m[:, None] * N + offs_n[None, :],
864
+ mask=(m[:, None] < M) & (offs_n[None, :] < N),
865
+ other=0.0,
866
+ )
867
+ x = tl.load(
868
+ x_ptr + m[:, None] * K + k[None, :],
869
+ mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
870
+ other=0.0,
871
+ )
872
+ acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
873
+
874
+ grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
875
+ lin = offs_n[:, None] * K + k[None, :]
876
+ pack_idx = lin // 5
877
+ trit_pos = lin - pack_idx * 5
878
+ packed = tl.load(
879
+ packed_ptr + pack_idx,
880
+ mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
881
+ other=0,
882
+ ).to(tl.int32)
883
+ divisor = tl.where(
884
+ trit_pos == 0, 1,
885
+ tl.where(trit_pos == 1, 3,
886
+ tl.where(trit_pos == 2, 9,
887
+ tl.where(trit_pos == 3, 27, 81))),
888
+ )
889
+ trit = (packed // divisor) % 3
890
+ ternary = trit.to(tl.int32) - 1
891
+ contrib = tl.where(
892
+ (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
893
+ grad_sign * ternary,
894
+ 0,
895
+ )
896
+ score = tl.sum(contrib, axis=1)
897
+
898
+ corr_idx = offs_n * GPR + pid_g
899
+ old_corr = tl.load(corr_ptr + corr_idx, mask=offs_n < N, other=0).to(tl.int64)
900
+ new_corr = old_corr - score.to(tl.int64) * CORR_STEP
901
+ tl.store(corr_ptr + corr_idx, new_corr, mask=offs_n < N)
902
+
903
+
904
+ @triton.jit
905
+ def _triton_apply_accumulated_flips_kernel(
906
+ packed_ptr, accum_ptr, per_group_threshold_ptr,
907
+ TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
908
+ K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
909
+ HAS_PER_GROUP_THRESHOLD: tl.constexpr,
910
+ BLOCK_T: tl.constexpr,
911
+ ):
912
+ pack_idx = tl.program_id(0)
913
+ offs_t = tl.arange(0, BLOCK_T)
914
+ valid_trit = offs_t < 5
915
+ lin = pack_idx * 5 + offs_t
916
+ valid = valid_trit & (lin < TOTAL)
917
+
918
+ old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
919
+ divisor = tl.where(
920
+ offs_t == 0, 1,
921
+ tl.where(offs_t == 1, 3,
922
+ tl.where(offs_t == 2, 9,
923
+ tl.where(offs_t == 3, 27, 81))),
924
+ )
925
+ old_code = (old_packed // divisor) % 3
926
+ old_sign = old_code.to(tl.int32) - 1
927
+
928
+ old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
929
+ if HAS_PER_GROUP_THRESHOLD:
930
+ n = lin // K
931
+ k = lin - n * K
932
+ g_idx = n * GPR + k // GROUP_SIZE
933
+ threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
934
+ else:
935
+ threshold = ACCUM_THRESHOLD
936
+
937
+ flip_up = old_accum > threshold
938
+ flip_down = old_accum < -threshold
939
+ did_flip = valid & (flip_up | flip_down)
940
+ new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
941
+ stored_accum = tl.where(did_flip, 0, old_accum)
942
+ tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)
943
+
944
+ new_code = tl.where(valid, new_sign + 1, 0)
945
+ packed_val = tl.sum(new_code * divisor, axis=0)
946
+ tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
947
+
948
+
949
+ def _triton_ternary_forward(x_2d, packed, e, corr_accum, step_counter, n_out, k_in, group_size):
950
+ block_m, block_n, block_k = 16, 16, 32
951
+ out = torch.empty((x_2d.shape[0], n_out), device=x_2d.device, dtype=torch.float32)
952
+ grid = (triton.cdiv(x_2d.shape[0], block_m), triton.cdiv(n_out, block_n))
953
+ _triton_ternary_fwd_kernel[grid](
954
+ x_2d, packed, e, corr_accum, step_counter, out,
955
+ x_2d.shape[0], n_out, k_in, ceil(k_in / group_size), group_size,
956
+ _bigint_corr_strength(),
957
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
958
+ )
959
+ return out
960
+
961
+
962
+ def _triton_ternary_grad_x(grad_2d, packed, e, corr_accum, step_counter, m_rows, n_out, k_in, group_size):
963
+ block_m, block_n, block_k = 16, 16, 32
964
+ out = torch.empty((m_rows, k_in), device=grad_2d.device, dtype=torch.float32)
965
+ grid = (triton.cdiv(m_rows, block_m), triton.cdiv(k_in, block_k))
966
+ _triton_ternary_grad_x_kernel[grid](
967
+ grad_2d, packed, e, corr_accum, step_counter, out,
968
+ m_rows, n_out, k_in, ceil(k_in / group_size), group_size,
969
+ _bigint_corr_strength(),
970
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
971
+ )
972
+ return out
973
+
974
+
975
+ def _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in):
976
+ block_m, block_n, block_k = 32, 16, 32
977
+ out = torch.empty((n_out, k_in), device=grad_2d.device, dtype=torch.int8)
978
+ grid = (triton.cdiv(n_out, block_n), triton.cdiv(k_in, block_k))
979
+ _triton_ternary_grad_sign_kernel[grid](
980
+ grad_2d, x_2d, out,
981
+ x_2d.shape[0], n_out, k_in,
982
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
983
+ )
984
+ return out
985
+
986
+
987
+ def _triton_update_e(packed, grad_sign, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
988
+ block_n, block_g = 8, 4
989
+ gpr = ceil(k_in / group_size)
990
+ block_k = 1 << (group_size - 1).bit_length()
991
+ grid = (triton.cdiv(n_out, block_n), triton.cdiv(gpr, block_g))
992
+ _triton_update_e_kernel[grid](
993
+ packed, grad_sign, e, e_accum,
994
+ n_out, k_in, group_size, gpr, int(e_accum_threshold),
995
+ BLOCK_N=block_n, BLOCK_G=block_g, BLOCK_K=block_k,
996
+ )
997
+
998
+
999
+ def _triton_update_e_direct(packed, grad_2d, x_2d, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
1000
+ block_m, block_n = 32, 8
1001
+ block_k = 1 << (group_size - 1).bit_length()
1002
+ gpr = ceil(k_in / group_size)
1003
+ grid = (triton.cdiv(n_out, block_n), gpr)
1004
+ _triton_update_e_direct_kernel[grid](
1005
+ packed, grad_2d, x_2d, e, e_accum,
1006
+ x_2d.shape[0], n_out, k_in, group_size, gpr, int(e_accum_threshold),
1007
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
1008
+ )
1009
+
1010
+
1011
+ def _triton_ternary_step(packed, grad_sign, accum, total, accum_threshold, t_accum_step=1,
1012
+ per_group_threshold=None, n_out=0, k_in=0, group_size=0):
1013
+ block_t = 8
1014
+ grid = (triton.cdiv(total, 5),)
1015
+ has_pgt = per_group_threshold is not None
1016
+ dummy = torch.empty(1, device=accum.device, dtype=torch.int8)
1017
+ gpr = (k_in + group_size - 1) // group_size if has_pgt else 0
1018
+ _triton_ternary_step_kernel[grid](
1019
+ packed, grad_sign, accum,
1020
+ per_group_threshold if has_pgt else dummy,
1021
+ total, accum_threshold, int(t_accum_step),
1022
+ k_in if has_pgt else 0, gpr, group_size if has_pgt else 0,
1023
+ has_pgt,
1024
+ BLOCK_T=block_t,
1025
+ )
1026
+
1027
+
1028
+ def _triton_ternary_step_direct(packed, grad_2d, x_2d, accum, n_out, k_in, total, accum_threshold, t_accum_step=1,
1029
+ per_group_threshold=None, group_size=0):
1030
+ block_m, block_t = 32, 8
1031
+ grid = (triton.cdiv(total, 5),)
1032
+ has_pgt = per_group_threshold is not None
1033
+ dummy = torch.empty(1, device=accum.device, dtype=torch.int8)
1034
+ gpr = (k_in + group_size - 1) // group_size if has_pgt else 0
1035
+ _triton_ternary_step_direct_kernel[grid](
1036
+ packed, grad_2d, x_2d, accum,
1037
+ per_group_threshold if has_pgt else dummy,
1038
+ x_2d.shape[0], n_out, k_in,
1039
+ total, accum_threshold, int(t_accum_step),
1040
+ gpr, group_size if has_pgt else 0,
1041
+ has_pgt,
1042
+ BLOCK_M=block_m, BLOCK_T=block_t,
1043
+ )
1044
+
1045
+
1046
+ def _triton_accumulate_direct(packed, grad_2d, x_2d, t_accum, e_accum,
1047
+ n_out, k_in, group_size,
1048
+ t_accum_step=1, e_accum_step=1,
1049
+ update_scales=True):
1050
+ block_m, block_t = 32, 8
1051
+ total = n_out * k_in
1052
+ grid = (triton.cdiv(total, 5),)
1053
+ _triton_accumulate_t_direct_kernel[grid](
1054
+ grad_2d, x_2d, t_accum,
1055
+ grad_2d.shape[0], n_out, k_in, total, int(t_accum_step),
1056
+ BLOCK_M=block_m, BLOCK_T=block_t,
1057
+ )
1058
+ if update_scales and e_accum is not None:
1059
+ block_n = 8
1060
+ block_k = 1 << (group_size - 1).bit_length()
1061
+ gpr = ceil(k_in / group_size)
1062
+ grid_e = (triton.cdiv(n_out, block_n), gpr)
1063
+ _triton_accumulate_e_direct_kernel[grid_e](
1064
+ packed, grad_2d, x_2d, e_accum,
1065
+ grad_2d.shape[0], n_out, k_in, group_size, gpr, int(e_accum_step),
1066
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
1067
+ )
1068
+
1069
+
1070
+ def _triton_accumulate_corr_direct(packed, grad_2d, x_2d, corr_accum,
1071
+ n_out, k_in, group_size, corr_step=1):
1072
+ block_m, block_n = 32, 8
1073
+ block_k = 1 << (group_size - 1).bit_length()
1074
+ gpr = ceil(k_in / group_size)
1075
+ grid = (triton.cdiv(n_out, block_n), gpr)
1076
+ _triton_accumulate_corr_direct_kernel[grid](
1077
+ packed, grad_2d, x_2d, corr_accum,
1078
+ grad_2d.shape[0], n_out, k_in, group_size, gpr, int(corr_step),
1079
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
1080
+ )
1081
+
1082
+
1083
+ def _triton_apply_accumulated_flips(packed, accum, total, accum_threshold,
1084
+ per_group_threshold=None,
1085
+ k_in=0, group_size=0):
1086
+ block_t = 8
1087
+ grid = (triton.cdiv(total, 5),)
1088
+ has_pgt = per_group_threshold is not None
1089
+ dummy = torch.empty(1, device=accum.device, dtype=torch.int8)
1090
+ gpr = (k_in + group_size - 1) // group_size if has_pgt else 0
1091
+ _triton_apply_accumulated_flips_kernel[grid](
1092
+ packed, accum,
1093
+ per_group_threshold if has_pgt else dummy,
1094
+ total, accum_threshold,
1095
+ k_in if has_pgt else 0, gpr, group_size if has_pgt else 0,
1096
+ has_pgt,
1097
+ BLOCK_T=block_t,
1098
+ )
1099
+
1100
+
1101
+ @triton.jit
1102
+ def _triton_ternary_embed_fwd_kernel(
1103
+ idx_ptr, packed_ptr, e_ptr, out_ptr,
1104
+ NUM_IDX: tl.constexpr, DIM: tl.constexpr,
1105
+ VOCAB: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
1106
+ BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
1107
+ ):
1108
+ pid = tl.program_id(0)
1109
+ offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
1110
+ offs_d = tl.arange(0, BLOCK_D)
1111
+
1112
+ idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32)
1113
+
1114
+ lin = idx[:, None] * DIM + offs_d[None, :]
1115
+ pack_idx = lin // 5
1116
+ trit_pos = lin - pack_idx * 5
1117
+ packed = tl.load(packed_ptr + pack_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.int32)
1118
+ divisor = tl.where(
1119
+ trit_pos == 0, 1,
1120
+ tl.where(trit_pos == 1, 3,
1121
+ tl.where(trit_pos == 2, 9,
1122
+ tl.where(trit_pos == 3, 27, 81))),
1123
+ )
1124
+ trit = (packed // divisor) % 3
1125
+ sign = trit.to(tl.int32) - 1
1126
+
1127
+ e_idx = idx[:, None] * GPR + offs_d[None, :] // GROUP_SIZE
1128
+ e_val = tl.load(e_ptr + e_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.float32)
1129
+ w = sign.to(tl.float32) * tl.exp2(e_val)
1130
+ w = tl.where((offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), w, 0.0)
1131
+
1132
+ tl.store(
1133
+ out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
1134
+ w,
1135
+ mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM),
1136
+ )
1137
+
1138
+
1139
+ @triton.jit
1140
+ def _triton_ternary_embed_bwd_accum_kernel(
1141
+ idx_ptr, grad_ptr, accum_ptr,
1142
+ NUM_IDX: tl.constexpr, DIM: tl.constexpr,
1143
+ BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
1144
+ ):
1145
+ pid = tl.program_id(0)
1146
+ offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
1147
+ offs_d = tl.arange(0, BLOCK_D)
1148
+ valid = (offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM)
1149
+ idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32)
1150
+ g = tl.load(grad_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0)
1151
+ dst = idx[:, None] * DIM + offs_d[None, :]
1152
+ tl.atomic_add(accum_ptr + dst, g, mask=valid)
1153
+
1154
+
1155
+ @triton.jit
1156
+ def _triton_ternary_embed_bwd_sign_kernel(
1157
+ accum_ptr, sign_ptr,
1158
+ VOCAB: tl.constexpr, DIM: tl.constexpr,
1159
+ BLOCK_V: tl.constexpr, BLOCK_D: tl.constexpr,
1160
+ ):
1161
+ pid_v = tl.program_id(0)
1162
+ offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V)
1163
+ offs_d = tl.arange(0, BLOCK_D)
1164
+ valid = (offs_v[:, None] < VOCAB) & (offs_d[None, :] < DIM)
1165
+ acc = tl.load(accum_ptr + offs_v[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0)
1166
+ sign_val = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int8)
1167
+ tl.store(sign_ptr + offs_v[:, None] * DIM + offs_d[None, :], sign_val, mask=valid)
1168
+
1169
+
1170
+ def _triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim):
1171
+ flat_idx = indices.reshape(-1).contiguous().to(torch.int32)
1172
+ grad_2d = grad_output.reshape(-1, dim).contiguous()
1173
+ num_idx = flat_idx.shape[0]
1174
+ accum = torch.zeros(vocab, dim, device=grad_output.device, dtype=torch.float32)
1175
+ block_b = 64
1176
+ grid = (triton.cdiv(num_idx, block_b),)
1177
+ _triton_ternary_embed_bwd_accum_kernel[grid](
1178
+ flat_idx, grad_2d, accum,
1179
+ num_idx, dim,
1180
+ BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
1181
+ )
1182
+ sign_out = torch.empty(vocab, dim, device=grad_output.device, dtype=torch.int8)
1183
+ block_v = 32
1184
+ grid2 = (triton.cdiv(vocab, block_v),)
1185
+ _triton_ternary_embed_bwd_sign_kernel[grid2](
1186
+ accum, sign_out,
1187
+ vocab, dim,
1188
+ BLOCK_V=block_v, BLOCK_D=triton.next_power_of_2(dim),
1189
+ )
1190
+ return sign_out
1191
+
1192
+
1193
+ def _triton_ternary_embed(indices, packed, e, vocab, dim, group_size):
1194
+ flat_idx = indices.reshape(-1).contiguous().to(torch.int32)
1195
+ num_idx = flat_idx.shape[0]
1196
+ out = torch.empty((num_idx, dim), device=indices.device, dtype=torch.float32)
1197
+ block_b, block_d = 32, triton.next_power_of_2(dim)
1198
+ gpr = ceil(dim / group_size)
1199
+ grid = (triton.cdiv(num_idx, block_b),)
1200
+ _triton_ternary_embed_fwd_kernel[grid](
1201
+ flat_idx, packed, e, out,
1202
+ num_idx, dim, vocab, gpr, group_size,
1203
+ BLOCK_B=block_b, BLOCK_D=block_d,
1204
+ )
1205
+ return out.reshape(*indices.shape, dim)
1206
+
1207
+
1208
+ class _TritonTernaryEmbedFn(torch.autograd.Function):
1209
+ @staticmethod
1210
+ def forward(ctx, indices, _dummy, module):
1211
+ shape = tuple(module._T_shape.tolist())
1212
+ vocab, dim = shape
1213
+ packed = module.T_packed.contiguous()
1214
+ e = module.E.contiguous()
1215
+ ctx.save_for_backward(indices, packed, e)
1216
+ ctx.module = module
1217
+ ctx.shape = shape
1218
+ ctx.group_size = module.group_size
1219
+ comp_name, _ = _COMPONENT_CONTEXT.get()
1220
+ ctx.comp_name = comp_name
1221
+ return _triton_ternary_embed(indices, packed, e, vocab, dim, module.group_size)
1222
+
1223
+ @staticmethod
1224
+ def backward(ctx, grad_output):
1225
+ indices, packed, e = ctx.saved_tensors
1226
+ vocab, dim = ctx.shape
1227
+ grad_2d = grad_output.reshape(-1, dim).contiguous()
1228
+ comp_name = ctx.comp_name
1229
+ has_corr = hasattr(ctx.module, "corr_accum") and hasattr(ctx.module, "_accumulate_corr_from_grad_sign")
1230
+ if getattr(ctx.module, "_stream_backward_updates", True) and has_corr:
1231
+ # BigInt streaming: accumulate correlation directly
1232
+ grad_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)
1233
+ T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item())).to(device=grad_sign.device)
1234
+ signed = grad_sign.to(torch.int16) * T.to(torch.int16)
1235
+ ctx.module._accumulate_corr_from_grad_sign(grad_sign)
1236
+ ctx.module._streamed_bigint_backward = True
1237
+ elif comp_name is not None:
1238
+ setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim))
1239
+ T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item()))
1240
+ setattr(ctx.module, f"_hook_T_{comp_name}", T.to(device=grad_2d.device))
1241
+ else:
1242
+ ctx.module._hook_grad_T_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)
1243
+ T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item()))
1244
+ ctx.module._hook_T = T.to(device=grad_2d.device)
1245
+ return None, None, None
1246
+
1247
+
1248
+ class _TritonTernaryLinearFn(torch.autograd.Function):
1249
+ @staticmethod
1250
+ def forward(ctx, x, module):
1251
+ shape = tuple(module._T_shape.tolist())
1252
+ n_out, k_in = shape
1253
+ x_2d = x.reshape(-1, k_in).contiguous()
1254
+ packed = module.T_packed.contiguous()
1255
+ e = module.E.contiguous()
1256
+ ctx.save_for_backward(x_2d, packed, e)
1257
+ ctx.step_snapshot = int(module.step_counter.item())
1258
+ ctx.x_shape = x.shape
1259
+ ctx.shape = shape
1260
+ ctx.group_size = module.group_size
1261
+ ctx.module = module
1262
+ comp_name, _ = _COMPONENT_CONTEXT.get()
1263
+ ctx.comp_name = comp_name
1264
+ corr = module.corr_accum.contiguous()
1265
+ step = module.step_counter.contiguous()
1266
+ out = _triton_ternary_forward(x_2d, packed, e, corr, step, n_out, k_in, module.group_size)
1267
+ return out.reshape(*x.shape[:-1], n_out)
1268
+
1269
+ @staticmethod
1270
+ def backward(ctx, grad_output):
1271
+ x_2d, packed, e = ctx.saved_tensors
1272
+ n_out, k_in = ctx.shape
1273
+ grad_2d = grad_output.reshape(-1, n_out).contiguous()
1274
+ corr = ctx.module.corr_accum.contiguous()
1275
+ step = torch.tensor([ctx.step_snapshot], device=e.device, dtype=torch.int64)
1276
+ grad_x = _triton_ternary_grad_x(
1277
+ grad_2d, packed, e, corr, step, x_2d.shape[0], n_out, k_in, ctx.group_size
1278
+ )
1279
+ with torch.no_grad():
1280
+ if getattr(ctx.module, "_stream_backward_updates", True):
1281
+ _, bwd_weight = _COMPONENT_CONTEXT.get()
1282
+ corr_step = max(1, int(round(abs(float(bwd_weight)))))
1283
+ if bwd_weight < 0:
1284
+ corr_step = -corr_step
1285
+ _triton_accumulate_corr_direct(
1286
+ packed, grad_2d, x_2d, ctx.module.corr_accum,
1287
+ n_out, k_in, ctx.group_size, corr_step=corr_step,
1288
+ )
1289
+ ctx.module.step_counter.add_(abs(corr_step))
1290
+ ctx.module._streamed_bigint_backward = True
1291
+ else:
1292
+ grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in)
1293
+ comp_name = ctx.comp_name
1294
+ if comp_name is not None:
1295
+ setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
1296
+ else:
1297
+ ctx.module._hook_grad_T_sign = grad_sign.detach()
1298
+ return grad_x.reshape(*ctx.x_shape), None
1299
+
1300
+
1301
+ class _BigIntTernaryLinearFn(torch.autograd.Function):
1302
+ @staticmethod
1303
+ def forward(ctx, x, module):
1304
+ shape = tuple(module._T_shape.tolist())
1305
+ n_out, k_in = shape
1306
+ x_2d = x.reshape(-1, k_in).contiguous()
1307
+ ctx.module = module
1308
+ ctx.x_shape = x.shape
1309
+ ctx.shape = shape
1310
+ ctx.x_dtype = x.dtype
1311
+ ctx.save_for_backward(x_2d)
1312
+ with torch.no_grad():
1313
+ w_eff = module.dequantize().to(device=x.device, dtype=torch.float32)
1314
+ out = F.linear(x_2d.float(), w_eff, module.bias.float() if module.bias is not None else None)
1315
+ return out.reshape(*x.shape[:-1], n_out)
1316
+
1317
+ @staticmethod
1318
+ def backward(ctx, grad_output):
1319
+ (x_2d,) = ctx.saved_tensors
1320
+ module = ctx.module
1321
+ n_out, k_in = ctx.shape
1322
+ grad_2d = grad_output.reshape(-1, n_out).contiguous()
1323
+ with torch.no_grad():
1324
+ w_eff = module.dequantize().to(device=grad_2d.device, dtype=torch.float32)
1325
+ grad_x = grad_2d.float() @ w_eff
1326
+ grad_sign = (grad_2d.float().transpose(0, 1) @ x_2d.float()).sign().to(torch.int8)
1327
+ module._accumulate_corr_from_grad_sign(grad_sign)
1328
+ module._streamed_bigint_backward = True
1329
+ return grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype), None
1330
+
1331
+
1332
+ """
1333
+ Log-Space Group Scale Representation
1334
+
1335
+ Convention (matching agents' Option B recommendation):
1336
+ S = 2^E where S = scale, E = int8 log-space exponent
1337
+ W_eff = T * 2^E
1338
+
1339
+ Key log-space properties exploited:
1340
+ Multiplication β†’ addition: S1 * S2 = 2^(E1 + E2)
1341
+ Division β†’ subtraction: S1 / S2 = 2^(E1 - E2)
1342
+ Dequant β†’ integer shift: 2^E * T = T << E (for E >= 0)
1343
+
1344
+ No IEEE floats in persistent state. E is stored as int8.
1345
+ Ephemeral float only exists in autograd's computation graph.
1346
+ """
1347
+
1348
+ class TScaleType(IntEnum):
1349
+ T4 = 4
1350
+ T6 = 6
1351
+ T8 = 8
1352
+ T16 = 16
1353
+ T32 = 32
1354
+ T64 = 64
1355
+ T96 = 96
1356
+
1357
+ GROUP_SIZES = {
1358
+ TScaleType.T4: 4,
1359
+ TScaleType.T6: 6,
1360
+ TScaleType.T8: 8,
1361
+ TScaleType.T16: 16,
1362
+ TScaleType.T32: 32,
1363
+ TScaleType.T64: 64,
1364
+ TScaleType.T96: 96,
1365
+ }
1366
+ TILE_SIZE = 384
1367
+
1368
+ def _n_groups(shape, group_size):
1369
+ out_dim, in_dim = shape
1370
+ return out_dim * ceil(in_dim / group_size)
1371
+
1372
+ def _expand_E(E, shape, group_size):
1373
+ out_dim, in_dim = shape
1374
+ gpr = ceil(in_dim / group_size)
1375
+ E_2d = E.view(out_dim, gpr)
1376
+ E_exp = E_2d.repeat_interleave(group_size, dim=1)
1377
+ if E_exp.shape[1] > in_dim:
1378
+ E_exp = E_exp[:, :in_dim]
1379
+ return E_exp
1380
+
1381
+ def _ternarize(x, threshold=0.05):
1382
+ return x.sign() * (x.abs() > threshold).to(x.dtype)
1383
+
1384
+
1385
+ def _scaled_init_threshold(threshold: float, init_std: float) -> float:
1386
+ if init_std <= 0:
1387
+ return threshold
1388
+ return min(float(threshold), 0.5 * float(init_std))
1389
+
1390
+ class TernaryScaleTensor(nn.Module):
1391
+ def __init__(
1392
+ self,
1393
+ in_dim: int,
1394
+ out_dim: int,
1395
+ threshold: float = 0.05,
1396
+ weight_init_std: float | None = None,
1397
+ tscale_type: TScaleType = TScaleType.T32,
1398
+ bias: bool = False,
1399
+ ):
1400
+ super().__init__()
1401
+ self.in_dim = in_dim
1402
+ self.out_dim = out_dim
1403
+ init_std = min(0.1, in_dim ** -0.5) if weight_init_std is None else float(weight_init_std)
1404
+ init_threshold = _scaled_init_threshold(threshold, init_std)
1405
+ self.threshold = init_threshold
1406
+ self.tscale_type = tscale_type
1407
+ self.group_size = GROUP_SIZES[tscale_type]
1408
+ shape = (out_dim, in_dim)
1409
+ n_grp = _n_groups(shape, self.group_size)
1410
+
1411
+ w_init = torch.randn(out_dim, in_dim) * init_std
1412
+ T_init = _ternarize(w_init, init_threshold)
1413
+ packed_T, T_shape, T_pad = pack_ternary(T_init)
1414
+
1415
+ self.register_buffer("T_packed", packed_T)
1416
+ self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
1417
+ self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
1418
+
1419
+ gpr = ceil(in_dim / self.group_size)
1420
+ total_in = gpr * self.group_size
1421
+ padded = torch.zeros(out_dim, total_in)
1422
+ abs_w = w_init.abs()
1423
+ padded[:, :in_dim] = abs_w
1424
+ grouped = padded.view(out_dim, gpr, self.group_size)
1425
+ grp_means = grouped.mean(dim=2)
1426
+ E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
1427
+ E_int = E_vals.log2().clamp(-128, 127).to(torch.int8)
1428
+ self.register_buffer("E", E_int.flatten())
1429
+ self.register_buffer("corr_accum", torch.zeros_like(self.E, dtype=torch.int64))
1430
+ self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
1431
+
1432
+ if bias:
1433
+ self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
1434
+ else:
1435
+ self.bias = None
1436
+
1437
+ def _get_T(self):
1438
+ return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
1439
+
1440
+ def _get_S(self):
1441
+ gpr = ceil(self.in_dim / self.group_size)
1442
+ e_adj = self.E.float()
1443
+ if hasattr(self, "corr_accum") and hasattr(self, "step_counter"):
1444
+ step = int(self.step_counter.item())
1445
+ if step > 0:
1446
+ denom = max(step * self.group_size, 1)
1447
+ e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
1448
+ E_exp = _expand_E(e_adj, (self.out_dim, self.in_dim), self.group_size)
1449
+ return torch.exp2(E_exp)
1450
+
1451
+ def _ensure_group_lr(self):
1452
+ if not hasattr(self, "group_lr"):
1453
+ self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
1454
+ elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device:
1455
+ self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
1456
+ return self.group_lr
1457
+
1458
+ def precompile_kernels(self, M: int):
1459
+ pass
1460
+
1461
+ def forward(self, x):
1462
+ backend = _backend_preference()
1463
+ if backend == "tilelang" and _HAS_TILELANG:
1464
+ if torch.is_grad_enabled() and not _tilelang_training_enabled():
1465
+ raise RuntimeError(
1466
+ "ARB_TERNARY_BACKEND='tilelang' is inference-only by default. "
1467
+ "BigInt ternary training should use ARB_TERNARY_BACKEND='triton'. "
1468
+ "Set ARB_TILELANG_TRAINING=1 only for experimental TileLang training."
1469
+ )
1470
+ x_for_grad = x
1471
+ if torch.is_grad_enabled() and not x.requires_grad:
1472
+ x_for_grad = x.detach().requires_grad_(True)
1473
+ N, K = tuple(self._T_shape.tolist())
1474
+ x_2d = x_for_grad.reshape(-1, K)
1475
+ M = x_2d.shape[0]
1476
+ try:
1477
+ fwd_kernel = _get_kernel(M, N, K, self.group_size, "fwd")
1478
+ y = _TernaryLinearFn.apply(x_for_grad, self, fwd_kernel)
1479
+ if self.bias is not None:
1480
+ y = y + self.bias.float()
1481
+ return y
1482
+ except Exception as e:
1483
+ warnings.warn(f"TileLang forward failed for {self._T_shape.tolist()}: {e}")
1484
+ if _HAS_TRITON:
1485
+ backend = "triton"
1486
+ else:
1487
+ backend = "torch"
1488
+ if x.is_cuda and _HAS_TRITON and backend in {"auto", "triton"}:
1489
+ x_for_grad = x
1490
+ if torch.is_grad_enabled() and not x.requires_grad:
1491
+ x_for_grad = x.detach().requires_grad_(True)
1492
+ y = _TritonTernaryLinearFn.apply(x_for_grad, self)
1493
+ if self.bias is not None:
1494
+ y = y + self.bias.float()
1495
+ return y
1496
+ if backend == "triton":
1497
+ raise RuntimeError("ARB_TERNARY_BACKEND='triton' requested, but Triton is unavailable for this input.")
1498
+ x_for_grad = x
1499
+ if torch.is_grad_enabled() and not x.requires_grad:
1500
+ x_for_grad = x.detach().requires_grad_(True)
1501
+ return _BigIntTernaryLinearFn.apply(x_for_grad, self)
1502
+
1503
+ @torch.no_grad()
1504
+ def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
1505
+ shape = tuple(self._T_shape.tolist())
1506
+ out_dim, in_dim = shape
1507
+ if tuple(grad_sign.shape) != shape:
1508
+ return
1509
+ T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
1510
+ signed = grad_sign.to(torch.int16) * T
1511
+ gpr = ceil(in_dim / self.group_size)
1512
+ total_in = gpr * self.group_size
1513
+ if total_in > in_dim:
1514
+ signed = F.pad(signed, (0, total_in - in_dim))
1515
+ score = signed.view(out_dim, gpr, self.group_size).sum(dim=2, dtype=torch.int16)
1516
+ self.corr_accum -= score.flatten().to(device=self.corr_accum.device, dtype=torch.int64) * int(corr_step)
1517
+ self.step_counter += abs(int(corr_step))
1518
+
1519
+ def ternary_step(self, lr=1, accum_threshold=None):
1520
+ self._had_flip = False
1521
+ if hasattr(self, "_hook_grad_T_sign"):
1522
+ self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
1523
+ del self._hook_grad_T_sign
1524
+
1525
+ def update_E(self, lr=1, loss_signal=None):
1526
+ has_dense_grad = hasattr(self, "_hook_grad_T_sign")
1527
+ has_direct_grad = hasattr(self, "_hook_grad_2d") and hasattr(self, "_hook_x_2d")
1528
+ if not has_dense_grad and not has_direct_grad:
1529
+ return
1530
+ if has_dense_grad:
1531
+ self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
1532
+ del self._hook_grad_T_sign
1533
+ else:
1534
+ grad = self._hook_grad_2d.to(device=self.E.device, dtype=torch.float32)
1535
+ x = self._hook_x_2d.to(device=self.E.device, dtype=torch.float32)
1536
+ grad_sign = (grad.transpose(0, 1) @ x).sign().to(torch.int8)
1537
+ self._accumulate_corr_from_grad_sign(grad_sign)
1538
+ del self._hook_grad_2d
1539
+ del self._hook_x_2d
1540
+ if hasattr(self, "_hook_T"):
1541
+ del self._hook_T
1542
+
1543
+ @property
1544
+ def effective_bpw(self) -> float:
1545
+ group_size = self.group_size
1546
+ total = self._T_shape[0].item() * self._T_shape[1].item()
1547
+ n_grp = _n_groups(tuple(self._T_shape.tolist()), group_size)
1548
+ sign_bits = total * (8 / 5)
1549
+ scale_bits = n_grp * 8.0
1550
+ corr_bits = n_grp * 64.0
1551
+ bias_bits = self.bias.numel() * 32.0 if self.bias is not None else 0.0
1552
+ return (sign_bits + scale_bits + corr_bits + bias_bits) / total
1553
+
1554
+ def dequantize(self) -> torch.Tensor:
1555
+ T = self._get_T().float()
1556
+ S = self._get_S()
1557
+ return S * T
1558
+
1559
+ def tscale_to(self, tscale_type: TScaleType):
1560
+ self.tscale_type = tscale_type
1561
+ old_group_size = self.group_size
1562
+ self.group_size = GROUP_SIZES[tscale_type]
1563
+ shape = tuple(self._T_shape.tolist())
1564
+ out_dim, in_dim = shape
1565
+ new_gpr = ceil(in_dim / self.group_size)
1566
+ new_n_grp = out_dim * new_gpr
1567
+ if self.E.shape[0] != new_n_grp:
1568
+ T = self._get_T().float()
1569
+ total_in = new_gpr * self.group_size
1570
+ padded = torch.zeros(out_dim, total_in, device=self.T_packed.device)
1571
+ abs_w = T.abs()
1572
+ padded[:, :in_dim] = abs_w
1573
+ grouped = padded.view(out_dim, new_gpr, self.group_size)
1574
+ grp_means = grouped.mean(dim=2)
1575
+ E_new = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
1576
+ E_int = E_new.log2().clamp(-128, 127).to(torch.int8)
1577
+ self.E = E_int.flatten()
1578
+ self.corr_accum = torch.zeros_like(self.E, dtype=torch.int64)
1579
+ self.step_counter = torch.zeros(1, dtype=torch.int64, device=self.E.device)
1580
+ return self
1581
+
1582
+ tscale_cast = tscale_to
1583
+
1584
+ def extra_repr(self) -> str:
1585
+ return (
1586
+ f"in_dim={self.in_dim}, out_dim={self.out_dim}, "
1587
+ f"tscale_type={self.tscale_type.name}, group_size={self.group_size}, "
1588
+ f"effective_bpw={self.effective_bpw:.2f}"
1589
+ )
1590
+
1591
+ if _HAS_TRITON:
1592
+
1593
+ @triton.jit
1594
+ def _triton_rmsnorm_fwd_kernel(
1595
+ x_ptr, packed_ptr, e_ptr, out_ptr,
1596
+ BATCH: tl.constexpr, DIM: tl.constexpr,
1597
+ GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
1598
+ BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
1599
+ ):
1600
+ pid_b = tl.program_id(0)
1601
+ offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
1602
+ offs_d = tl.arange(0, BLOCK_D)
1603
+
1604
+ x = tl.load(
1605
+ x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
1606
+ mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
1607
+ other=0.0,
1608
+ )
1609
+ sq = x * x
1610
+ msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
1611
+ rms = tl.sqrt(msq + 1e-5)
1612
+ x_norm = x / rms
1613
+
1614
+ pack_idx = offs_d // 5
1615
+ trit_pos = offs_d - pack_idx * 5
1616
+ packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32)
1617
+ divisor = tl.where(
1618
+ trit_pos == 0, 1,
1619
+ tl.where(trit_pos == 1, 3,
1620
+ tl.where(trit_pos == 2, 9,
1621
+ tl.where(trit_pos == 3, 27, 81))),
1622
+ )
1623
+ trit = (packed // divisor) % 3
1624
+ sign = trit.to(tl.int32) - 1
1625
+
1626
+ e_idx = offs_d // GROUP_SIZE
1627
+ e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32)
1628
+ w = sign.to(tl.float32) * tl.exp2(e_val)
1629
+ w = tl.where(offs_d < DIM, w, 0.0)
1630
+
1631
+ out = x_norm * w[None, :]
1632
+ tl.store(
1633
+ out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
1634
+ out,
1635
+ mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
1636
+ )
1637
+
1638
+
1639
+ @triton.jit
1640
+ def _triton_rmsnorm_bwd_kernel(
1641
+ grad_out_ptr, x_ptr, packed_ptr, e_ptr,
1642
+ grad_x_ptr,
1643
+ BATCH: tl.constexpr, DIM: tl.constexpr,
1644
+ GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
1645
+ BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
1646
+ ):
1647
+ pid_b = tl.program_id(0)
1648
+ offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
1649
+ offs_d = tl.arange(0, BLOCK_D)
1650
+
1651
+ x = tl.load(
1652
+ x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
1653
+ mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
1654
+ other=0.0,
1655
+ )
1656
+ sq = x * x
1657
+ msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
1658
+ rms = tl.sqrt(msq + 1e-5)
1659
+ x_norm = x / rms
1660
+
1661
+ pack_idx = offs_d // 5
1662
+ trit_pos = offs_d - pack_idx * 5
1663
+ packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32)
1664
+ divisor = tl.where(
1665
+ trit_pos == 0, 1,
1666
+ tl.where(trit_pos == 1, 3,
1667
+ tl.where(trit_pos == 2, 9,
1668
+ tl.where(trit_pos == 3, 27, 81))),
1669
+ )
1670
+ trit = (packed // divisor) % 3
1671
+ sign = trit.to(tl.int32) - 1
1672
+
1673
+ e_idx = offs_d // GROUP_SIZE
1674
+ e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32)
1675
+ w = sign.to(tl.float32) * tl.exp2(e_val)
1676
+ w = tl.where(offs_d < DIM, w, 0.0)
1677
+
1678
+ dy = tl.load(
1679
+ grad_out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
1680
+ mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
1681
+ other=0.0,
1682
+ )
1683
+ dyw = dy * w[None, :]
1684
+
1685
+ c1 = tl.sum(x_norm * dyw, axis=1, keep_dims=True) / DIM
1686
+ dx = (dyw - x_norm * c1) / rms
1687
+
1688
+ tl.store(
1689
+ grad_x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
1690
+ dx,
1691
+ mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
1692
+ )
1693
+
1694
+
1695
+ class _TritonRMSNormFn(torch.autograd.Function):
1696
+ @staticmethod
1697
+ def forward(ctx, x, module, packed, e, dim, group_size):
1698
+ ctx.module = module
1699
+ x_2d = x.reshape(-1, dim).contiguous()
1700
+ batch = x_2d.shape[0]
1701
+ out = torch.empty_like(x_2d)
1702
+ block_b = 16
1703
+ grid = (triton.cdiv(batch, block_b),)
1704
+ _triton_rmsnorm_fwd_kernel[grid](
1705
+ x_2d, packed, e, out,
1706
+ batch, dim, ceil(dim / group_size), group_size,
1707
+ BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
1708
+ )
1709
+ ctx.save_for_backward(x_2d, packed, e)
1710
+ ctx.dim = dim
1711
+ ctx.group_size = group_size
1712
+ comp_name, _ = _COMPONENT_CONTEXT.get()
1713
+ ctx.comp_name = comp_name
1714
+ return out.reshape(*x.shape)
1715
+
1716
+ @staticmethod
1717
+ def backward(ctx, grad_output):
1718
+ x_2d, packed, e = ctx.saved_tensors
1719
+ dim = ctx.dim
1720
+ group_size = ctx.group_size
1721
+ grad_2d = grad_output.reshape(-1, dim).contiguous()
1722
+ batch = grad_2d.shape[0]
1723
+ grad_x = torch.empty_like(x_2d)
1724
+ block_b = 16
1725
+ grid = (triton.cdiv(batch, block_b),)
1726
+ _triton_rmsnorm_bwd_kernel[grid](
1727
+ grad_2d, x_2d, packed, e, grad_x,
1728
+ batch, dim, ceil(dim / group_size), group_size,
1729
+ BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
1730
+ )
1731
+ return grad_x.reshape(*grad_output.shape), None, None, None, None, None
1732
+
1733
+
1734
+ class TernaryRMSNorm(nn.Module):
1735
+ def __init__(self, dim, eps=1e-5, threshold=0.05, tscale_type=TScaleType.T64):
1736
+ super().__init__()
1737
+ self.dim = dim
1738
+ self.eps = eps
1739
+ self.threshold = threshold
1740
+ self.tscale_type = tscale_type
1741
+ self.group_size = GROUP_SIZES[tscale_type]
1742
+ shape = (1, dim)
1743
+ n_grp = _n_groups(shape, self.group_size)
1744
+
1745
+ w_init = torch.ones(1, dim)
1746
+ T_init = _ternarize(w_init, threshold)
1747
+ packed_T, T_shape, T_pad = pack_ternary(T_init)
1748
+
1749
+ self.register_buffer("T_packed", packed_T)
1750
+ self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
1751
+ self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
1752
+
1753
+ gpr = ceil(dim / self.group_size)
1754
+ total_in = gpr * self.group_size
1755
+ padded = torch.zeros(1, total_in)
1756
+ abs_w = w_init.abs()
1757
+ padded[:, :dim] = abs_w
1758
+ grouped = padded.view(1, gpr, self.group_size)
1759
+ grp_means = grouped.mean(dim=2)
1760
+ E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
1761
+ self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
1762
+ self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
1763
+ self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
1764
+
1765
+ self.register_buffer("T_accum", torch.zeros(1, dim, dtype=torch.int8))
1766
+
1767
+ def _ensure_E_accum(self):
1768
+ if not hasattr(self, "E_accum"):
1769
+ self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
1770
+ elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device:
1771
+ self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
1772
+ return self.E_accum
1773
+
1774
+ def _ensure_group_lr(self):
1775
+ if not hasattr(self, "group_lr"):
1776
+ self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
1777
+ elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device:
1778
+ self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
1779
+ return self.group_lr
1780
+
1781
+ def _get_T(self):
1782
+ return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item())).squeeze(0)
1783
+
1784
+ def forward(self, x):
1785
+ if x.is_cuda and _HAS_TRITON and self.dim <= _rmsnorm_triton_max_dim():
1786
+ return _TritonRMSNormFn.apply(
1787
+ x, self, self.T_packed.contiguous(), self.E.contiguous(),
1788
+ self.dim, self.group_size,
1789
+ )
1790
+
1791
+ inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
1792
+ if x.is_cuda:
1793
+ # TernaryRMSNorm is initialized as an identity scale and does not
1794
+ # train E/T. Avoid unpacking a full large-dim weight or launching
1795
+ # the high-register Triton backward kernel on 8GB GPUs.
1796
+ return x * inv_rms
1797
+
1798
+ T = self._get_T()
1799
+ E_exp = _expand_E(self.E, tuple(self._T_shape.tolist()), self.group_size).squeeze(0)
1800
+ S = torch.exp2(E_exp.float())
1801
+ weight = S * T.float()
1802
+ return weight * (x * inv_rms)
1803
+
1804
+ def ternary_step(self, lr=1, accum_threshold=3):
1805
+ pass
1806
+
1807
+ def update_E(self, lr=1, loss_signal=None):
1808
+ pass
1809
+
1810
+ def extra_repr(self):
1811
+ return f"dim={self.dim}, tscale_type={self.tscale_type.name}"
arbitor/kernel/triton_video.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Triton kernels for video denoising (used by VideoHead)."""
2
+ import torch
3
+ import torch.nn as nn
4
+ from math import ceil as _ceil
5
+
6
+ from .ternary_scale import _HAS_TRITON
7
+
8
+ if _HAS_TRITON:
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ @triton.jit
13
+ def _triton_video_denoise_fwd_kernel(
14
+ latent, pred_noise, out,
15
+ TOTAL: tl.constexpr, ALPHA: tl.constexpr, BLOCK: tl.constexpr,
16
+ ):
17
+ offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
18
+ mask = offsets < TOTAL
19
+ l = tl.load(latent + offsets, mask=mask, other=0.0)
20
+ p = tl.load(pred_noise + offsets, mask=mask, other=0.0)
21
+ beta = 1.0 - ALPHA
22
+ inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
23
+ tl.store(out + offsets, (l - beta * p) * inv_sqrt, mask=mask)
24
+
25
+ @triton.jit
26
+ def _triton_video_denoise_bwd_kernel(
27
+ grad_out, grad_latent, grad_pred,
28
+ TOTAL: tl.constexpr, ALPHA: tl.constexpr, BLOCK: tl.constexpr,
29
+ ):
30
+ offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
31
+ mask = offsets < TOTAL
32
+ g = tl.load(grad_out + offsets, mask=mask, other=0.0)
33
+ beta = 1.0 - ALPHA
34
+ inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
35
+ tl.store(grad_latent + offsets, g * inv_sqrt, mask=mask)
36
+ tl.store(grad_pred + offsets, -beta * g * inv_sqrt, mask=mask)
37
+
38
+
39
+ class _TritonVideoDenoiseFn(torch.autograd.Function):
40
+ @staticmethod
41
+ def forward(ctx, latent, pred_noise, alpha):
42
+ latent_c = latent.contiguous()
43
+ pred_c = pred_noise.contiguous()
44
+ out = torch.empty_like(latent_c)
45
+ total = latent_c.numel()
46
+ block = 256
47
+ grid = (_ceil_div(total, block),)
48
+ alpha_f = float(alpha)
49
+ _triton_video_denoise_fwd_kernel[grid](
50
+ latent_c, pred_c, out,
51
+ total, alpha_f, BLOCK=block,
52
+ )
53
+ ctx.alpha = alpha_f
54
+ ctx.shape = latent.shape
55
+ return out.reshape_as(latent)
56
+
57
+ @staticmethod
58
+ def backward(ctx, grad_out):
59
+ grad_c = grad_out.contiguous()
60
+ grad_latent = torch.empty_like(grad_c)
61
+ grad_pred = torch.empty_like(grad_c)
62
+ total = grad_c.numel()
63
+ block = 256
64
+ grid = (_ceil_div(total, block),)
65
+ _triton_video_denoise_bwd_kernel[grid](
66
+ grad_c, grad_latent, grad_pred,
67
+ total, ctx.alpha, BLOCK=block,
68
+ )
69
+ return grad_latent.reshape(ctx.shape), grad_pred.reshape(ctx.shape), None
70
+
71
+
72
+ def video_denoise_step(latent, pred_noise, alpha):
73
+ if _HAS_TRITON and latent.is_cuda and pred_noise.is_cuda and _TritonVideoDenoiseFn is not None:
74
+ return _TritonVideoDenoiseFn.apply(latent, pred_noise, alpha)
75
+ return (latent - (1 - alpha) * pred_noise) / (alpha ** 0.5 + 1e-8)
arbitor/main.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ARB β€” Any Relational Bit. Core model assembly."""
2
+ import warnings
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from math import ceil as _ceil
7
+
8
+ _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
9
+
10
+ from .config import VOCAB, HIDDEN_DIM, SPECIAL_VOCAB, CTX, THRESHOLD, CODEBOOK_DIM, CODEBOOK_SIZE, KV_LEDGER_SIZE, KQ_CACHE_SIZE, MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES, MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES, MG_TOP_K
11
+ from .kernel.ternary_scale import TScaleType, TernaryScaleTensor, TernaryRMSNorm, _HAS_TRITON
12
+ try:
13
+ from .kernel.ternary_scale import _triton_apply_accumulated_flips
14
+ except ImportError:
15
+ _triton_apply_accumulated_flips = None
16
+ from .converters.convert_to_ternary8 import pack_ternary
17
+ try:
18
+ from .kernel.ternary_scale import _TritonTernaryEmbedFn
19
+ except ImportError:
20
+ _TritonTernaryEmbedFn = None
21
+ from .sequencers import ByteEmbedding, MultimodalSequencer
22
+ from .vq import SharedVQ
23
+ from .components import (
24
+ ByteHead, OutputRouter,
25
+ MemGram, LossComponents, LossWeights,
26
+ CompositeProposalHead, MoEGraph,
27
+ )
28
+ from .decoders import VideoHead, TalkerHead
29
+ from .components import _BOUNDARY_TOKEN_MAP as _BOUNDARY_MAP
30
+ from .attention import KVLedger, KQCache, ContextAttentionScheduler
31
+ from .kernel.flash_vq import FlashVQCodebook
32
+ def _extract_boundary_from_input(x):
33
+ if x.dim() != 2:
34
+ return None
35
+ first_token = x[0, 0].item()
36
+ if first_token in _BOUNDARY_MAP:
37
+ return first_token
38
+ for tok in x[0].tolist():
39
+ if tok in _BOUNDARY_MAP:
40
+ return tok
41
+ return None
42
+
43
+
44
+ class ARBModel(nn.Module):
45
+ def __init__(self, tscale_type=TScaleType.T32, threshold=THRESHOLD,
46
+ max_graph_hops=4, max_moe_iters=4, halt_threshold=0.99,
47
+ enable_image=False, enable_audio=False, enable_vq=True, enable_graph=True,
48
+ enable_memory_modules=False, enable_moe=True,
49
+ shared_vq_size=None, kgvq_codebook_size=None,
50
+ enable_attention=True, enable_output_router=True,
51
+ enable_video_output=True, enable_talker_output=True):
52
+ super().__init__()
53
+ self.image_enabled = enable_image
54
+ self.audio_enabled = enable_audio
55
+ self.embedding = ByteEmbedding(tscale_type=tscale_type)
56
+ self.multimodal_sequencer = MultimodalSequencer(
57
+ tscale_type=tscale_type,
58
+ enable_text=True, enable_image=enable_image, enable_audio=enable_audio,
59
+ )
60
+ self.text_sequencer = self.multimodal_sequencer.text
61
+ self.image_sequencer = self.multimodal_sequencer.image
62
+ self.audio_sequencer = self.multimodal_sequencer.audio
63
+ self.vq_enabled = enable_vq
64
+ self.bridge = SharedVQ(
65
+ codebook_size=shared_vq_size,
66
+ tscale_type=tscale_type, enable_image=enable_image, enable_audio=enable_audio,
67
+ ) if enable_vq else None
68
+ self.vq_to_trigram = TernaryScaleTensor(CODEBOOK_DIM, HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
69
+ self.vq_to_trigram_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
70
+ self.graph_enabled = enable_graph and enable_vq
71
+ graph_vocab_size = self.bridge.total_codebook_size if self.graph_enabled else None
72
+ self.threshold = threshold
73
+ self.moegraph = MoEGraph(
74
+ trigram_dim=HIDDEN_DIM, codebook_size=graph_vocab_size or CODEBOOK_SIZE,
75
+ max_iters=max_moe_iters, halt_threshold=halt_threshold,
76
+ top_k=MG_TOP_K,
77
+ ) if self.graph_enabled else None
78
+ self.byte_head = ByteHead(tscale_type=tscale_type)
79
+ # Composite motif generation (Phase 17)
80
+ self.composite_head = CompositeProposalHead(
81
+ dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
82
+ k_max=K_MAX_COMPOSITES, codebook_size=kgvq_codebook_size or KGVQ_CODEBOOK_SIZE,
83
+ tscale_type=tscale_type,
84
+ ) if self.graph_enabled else None
85
+ self.output_router = OutputRouter(tscale_type=tscale_type, depth=3) if enable_output_router else None
86
+ self.video_head = VideoHead(tscale_type=tscale_type) if enable_video_output else None
87
+ self.talker_head = TalkerHead(tscale_type=tscale_type) if enable_talker_output else None
88
+ self.memgram = MemGram(
89
+ struct_primes=MEMGRAM_STRUCT_PRIMES,
90
+ conv_primes=MEMGRAM_CONV_PRIMES,
91
+ embed_dim=MEMGRAM_EMBED_DIM, key_dim=MEMGRAM_KEY_DIM, hidden_dim=HIDDEN_DIM,
92
+ ) if enable_memory_modules else None
93
+ self.memgram_enabled = self.memgram is not None
94
+
95
+ # KV Ledger + Attention (Phase 16 β€” replaces LSTM)
96
+ self.kv_ledger = KVLedger(max_size=KV_LEDGER_SIZE) if enable_attention else None
97
+ self.kq_cache = KQCache(max_size=KQ_CACHE_SIZE) if enable_attention else None
98
+ self.attention = ContextAttentionScheduler(dim=HIDDEN_DIM) if enable_attention else None
99
+ self.attention_enabled = bool(enable_attention)
100
+
101
+ def forward(self, x, targets=None, commitment_warmup_weight=1.0,
102
+ act_warmup_mode=False, ponder_lambda=0.01, images=None,
103
+ audio=None, timestep=0, loss_weights=None, output_mode=None):
104
+ has_image = images is not None
105
+ has_audio = audio is not None
106
+ if has_image and (not self.image_enabled or self.image_sequencer is None):
107
+ raise ValueError("images provided but model has enable_image=False")
108
+ if has_audio and (not self.audio_enabled or self.audio_sequencer is None):
109
+ raise ValueError("audio provided but model has enable_audio=False")
110
+
111
+ embedded = self.embedding(x)
112
+ seq_inputs = {'text': embedded}
113
+ if has_image:
114
+ seq_inputs['image'] = images
115
+ if has_audio:
116
+ seq_inputs['audio'] = audio
117
+ seq_outputs = self.multimodal_sequencer(seq_inputs)
118
+ relational = seq_outputs['text']
119
+
120
+ indices_dict = {}
121
+ if self.vq_enabled:
122
+ bridge_inputs = {'text': relational}
123
+ if 'image' in seq_outputs:
124
+ bridge_inputs['image'] = seq_outputs['image']
125
+ if 'audio' in seq_outputs:
126
+ bridge_inputs['audio'] = seq_outputs['audio']
127
+
128
+ combined, vq_losses, indices_dict = self.bridge(bridge_inputs, timestep=timestep)
129
+ if combined is None:
130
+ combined = relational
131
+ elif combined.shape[-1] == CODEBOOK_DIM:
132
+ combined = self.vq_to_trigram_norm(self.vq_to_trigram(combined))
133
+ vq_loss = vq_losses.get('text_vq', torch.zeros((), device=x.device))
134
+ if 'image_vq' in vq_losses:
135
+ vq_loss = vq_loss + vq_losses['image_vq']
136
+ if 'audio_vq' in vq_losses:
137
+ vq_loss = vq_loss + vq_losses['audio_vq']
138
+ else:
139
+ combined = relational
140
+ vq_loss = torch.zeros((), device=x.device)
141
+
142
+ active_mods = ['text']
143
+ if has_image:
144
+ active_mods.append('image')
145
+ if has_audio:
146
+ active_mods.append('audio')
147
+ active_count = len(active_mods)
148
+
149
+ # MemGram injection (after VQ, before Graph β€” D92)
150
+ memgram_decay_reg = torch.tensor(0.0, device=x.device)
151
+
152
+ if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
153
+ vq_indices = indices_dict.get('text', torch.zeros(combined.shape[0], combined.shape[1], dtype=torch.long, device=x.device))
154
+ combined = self.memgram(
155
+ vq_indices=vq_indices,
156
+ hidden_state=combined,
157
+ )
158
+
159
+ all_indices = None
160
+ composite_ids = None
161
+ composite_vq_loss = None
162
+ processed = combined
163
+ moegraph_ponder_loss = torch.tensor(0.0, device=x.device)
164
+
165
+ if self.graph_enabled and self.moegraph is not None and self.vq_enabled and vq_loss is not None:
166
+ self.moegraph._codebook_table = self.bridge.vq.table
167
+ self.moegraph._codebook_embed = None
168
+
169
+ all_indices = indices_dict.get('text', combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long))
170
+ if has_image and 'image' in indices_dict:
171
+ all_indices = torch.cat([all_indices, indices_dict['image']], dim=1)
172
+ if has_audio and 'audio' in indices_dict:
173
+ all_indices = torch.cat([all_indices, indices_dict['audio']], dim=1)
174
+
175
+ # MemGram retrieval for MoEGraph injection
176
+ memgram_cb = None
177
+ if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
178
+ vq_idx = indices_dict.get('text', combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long))
179
+ memgram_cb = self.memgram.retrieve_cb(vq_idx)
180
+
181
+ # Attention output for KV conditioning
182
+ attn_out = None
183
+ if self.attention_enabled and self.attention is not None and self.kv_ledger is not None:
184
+ attn_out = self.attention(combined, self.kv_ledger, kq_cache=self.kq_cache)
185
+
186
+ # MoEGraph forward (unified ACT loop)
187
+ processed, moegraph_ponder_loss = self.moegraph(
188
+ combined, all_indices,
189
+ attention_output=attn_out,
190
+ memgram_cb_output=memgram_cb,
191
+ threshold=self.threshold,
192
+ )
193
+
194
+ # Composite motif generation (Phase 17)
195
+ if self.composite_head is not None:
196
+ composite_ids, composite_vq_loss, _ = self.composite_head(processed.mean(dim=1))
197
+
198
+ # Update bounded int-only KG co-occurrence state.
199
+ self.moegraph.update_kg_edges(all_indices)
200
+
201
+ # OutputRouter: route to appropriate head
202
+ if targets is not None or output_mode == "text":
203
+ logits = self.byte_head(processed)
204
+ elif output_mode == "video":
205
+ if self.video_head is None:
206
+ raise ValueError("output_mode='video' requested but video output is disabled")
207
+ logits = self.video_head(processed)
208
+ elif output_mode in {"audio", "talker"}:
209
+ if self.talker_head is None:
210
+ raise ValueError("audio/talker output requested but talker output is disabled")
211
+ logits = self.talker_head(processed)
212
+ elif self.training and self.output_router is not None:
213
+ route = self.output_router(processed, training=True)
214
+ route_weights, route_logits = route
215
+ logits = self.byte_head(processed)
216
+ elif self.output_router is not None:
217
+ route = self.output_router(processed, training=False)
218
+ if isinstance(route, torch.Tensor) and route.numel() > 0:
219
+ use_video = (route == 2).any() and self.video_head is not None
220
+ use_talk = (route == 3).any() and self.talker_head is not None
221
+ logits = self.video_head(processed) if use_video else \
222
+ self.talker_head(processed) if use_talk else \
223
+ self.byte_head(processed)
224
+ else:
225
+ logits = self.byte_head(processed)
226
+ else:
227
+ logits = self.byte_head(processed)
228
+
229
+ T_text = relational.shape[1]
230
+ if logits.dim() == 3 and logits.shape[-1] == VOCAB:
231
+ logits = logits[:, :T_text, :]
232
+ with torch.no_grad():
233
+ self._append_predictions_to_kv(logits.argmax(dim=-1), composite_ids=composite_ids)
234
+ losses = None
235
+ if targets is not None:
236
+ next_byte_logits = logits[:, :-1, :].contiguous()
237
+ lm_loss = F.cross_entropy(
238
+ next_byte_logits.view(-1, VOCAB),
239
+ targets.contiguous().view(-1),
240
+ ignore_index=SPECIAL_VOCAB["PAD"]
241
+ )
242
+ vq_component = commitment_warmup_weight * vq_loss if self.vq_enabled else None
243
+ losses = LossComponents(
244
+ lm=lm_loss,
245
+ vq_commitment=vq_component,
246
+ graph_l1=None,
247
+ moegraph_ponder=moegraph_ponder_loss,
248
+ memgram_decay_reg=memgram_decay_reg if self.memgram_enabled else None,
249
+ composite_vq=composite_vq_loss if self.composite_head is not None and composite_ids is not None else None,
250
+ weights=loss_weights if loss_weights is not None else LossWeights(),
251
+ )
252
+
253
+ return logits, losses, all_indices, None
254
+
255
+ @torch.no_grad()
256
+ def _append_predictions_to_kv(self, pred_ids, composite_ids=None):
257
+ if self.kv_ledger is None or self.kq_cache is None:
258
+ return
259
+ for b in range(pred_ids.shape[0]):
260
+ for t in range(pred_ids.shape[1]):
261
+ token_id = int(pred_ids[b, t])
262
+ self.kv_ledger.append(token_id)
263
+ self.kq_cache.append(token_id)
264
+ if composite_ids is None:
265
+ continue
266
+ composite_offset = self.bridge.total_codebook_size if self.vq_enabled and self.bridge is not None else 0
267
+ for k in range(composite_ids.shape[1]):
268
+ cid = int(composite_ids[b, k])
269
+ if cid >= 0:
270
+ self.kv_ledger.append(composite_offset + cid)
271
+
272
+ def _ternary_update_memory(self, accum_threshold=8, update_scales=True,
273
+ loss_components=None, loss_signal=None):
274
+ signal = loss_components.total if loss_components is not None else loss_signal
275
+ t_step = self._ternary_t_step(signal)
276
+ if signal is not None and not torch.isfinite(signal.detach()).all():
277
+ warnings.warn("Non-finite loss detected β€” skipping ternary state update",
278
+ RuntimeWarning, stacklevel=2)
279
+ self._clear_ternary_hooks()
280
+ self.zero_grad(set_to_none=True)
281
+ return
282
+
283
+ if loss_components is not None:
284
+ self._componentwise_ternary_backward(loss_components, t_step, update_scales, accum_threshold)
285
+ else:
286
+ self._apply_regular_ternary_hooks(accum_threshold, update_scales, t_step, loss_signal)
287
+ self._clear_ternary_hooks()
288
+ self._clear_backward_update_flags()
289
+
290
+ def prepare_ternary_backward(self, loss_signal=None, update_scales=True):
291
+ """Configure streaming CUDA ternary updates before `loss.backward()`.
292
+
293
+ BigInt-scaled dense linear backward accumulates directly into int64
294
+ `corr_accum`, while legacy sparse tables still use int8 `T_accum`.
295
+ Calling this before backward lets the streaming path use the same
296
+ loss-scaled step that `_ternary_update_memory()` will finalize.
297
+ """
298
+ t_step = self._ternary_t_step(loss_signal)
299
+ for module in self.modules():
300
+ if hasattr(module, "T_accum") or hasattr(module, "corr_accum"):
301
+ module._backward_t_accum_step = t_step
302
+ module._backward_update_scales = bool(update_scales)
303
+ module._stream_backward_updates = True
304
+
305
+ def _clear_backward_update_flags(self):
306
+ for module in self.modules():
307
+ for attr in (
308
+ "_backward_t_accum_step",
309
+ "_backward_update_scales",
310
+ "_stream_backward_updates",
311
+ "_streamed_ternary_backward",
312
+ "_streamed_bigint_backward",
313
+ ):
314
+ if hasattr(module, attr):
315
+ delattr(module, attr)
316
+
317
+ @staticmethod
318
+ def _ternary_t_step(loss_signal):
319
+ return 1
320
+
321
+ def _clear_ternary_hooks(self):
322
+ base_names = [
323
+ "_hook_grad_T_sign", "_hook_grad_2d", "_hook_x_2d", "_hook_T",
324
+ "_hook_sparse_indices", "_hook_sparse_grad_sign", "_hook_sparse_T",
325
+ ]
326
+ for module in self.modules():
327
+ if hasattr(module, "_T_accum_fp"):
328
+ delattr(module, "_T_accum_fp")
329
+ for hook_name in base_names:
330
+ if hasattr(module, hook_name):
331
+ delattr(module, hook_name)
332
+ for hook_name in list(vars(module).keys()):
333
+ if hook_name.startswith((
334
+ "_hook_grad_T_sign_", "_hook_grad_2d_", "_hook_x_2d_", "_hook_T_",
335
+ "_hook_sparse_indices_", "_hook_sparse_grad_sign_", "_hook_sparse_T_",
336
+ )):
337
+ delattr(module, hook_name)
338
+
339
+ def _componentwise_ternary_backward(self, loss_components, t_step, update_scales, accum_threshold):
340
+ from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT
341
+
342
+ self.prepare_ternary_backward(loss_components.total, update_scales=update_scales)
343
+ active = [(n, t, w) for n, t, w in loss_components.active_fields
344
+ if t is not None and t.dim() == 0 and t.requires_grad and float(w) != 0.0]
345
+ for idx, (name, comp_tensor, weight) in enumerate(active):
346
+ retain = idx < len(active) - 1
347
+ _COMPONENT_CONTEXT.set(name, weight)
348
+ try:
349
+ comp_tensor.backward(retain_graph=retain)
350
+ finally:
351
+ _COMPONENT_CONTEXT.clear()
352
+ self._consume_component_hooks(name, weight, t_step, update_scales, accum_threshold)
353
+
354
+ with torch.no_grad():
355
+ for module in self.modules():
356
+ if self._is_large_sparse_embedding(module):
357
+ continue
358
+ if update_scales:
359
+ self._step_E_from_accum(module)
360
+ self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
361
+
362
+ def _consume_component_hooks(self, name, weight, t_step, update_scales, accum_threshold):
363
+ for module in self.modules():
364
+ sparse_idx_key = f"_hook_sparse_indices_{name}"
365
+ sparse_grad_key = f"_hook_sparse_grad_sign_{name}"
366
+ sparse_t_key = f"_hook_sparse_T_{name}"
367
+ if hasattr(module, sparse_idx_key) and hasattr(module, sparse_grad_key):
368
+ setattr(module, "_hook_sparse_indices", getattr(module, sparse_idx_key))
369
+ setattr(module, "_hook_sparse_grad_sign", getattr(module, sparse_grad_key))
370
+ if hasattr(module, sparse_t_key):
371
+ setattr(module, "_hook_sparse_T", getattr(module, sparse_t_key))
372
+ if update_scales and hasattr(module, "update_E"):
373
+ module._e_accum_threshold = 8
374
+ module.update_E()
375
+ if hasattr(module, "T_accum"):
376
+ module._t_accum_step = max(1, int(round(abs(float(weight)) * t_step)))
377
+ if hasattr(module, "ternary_step"):
378
+ module.ternary_step(accum_threshold=accum_threshold)
379
+ for key in (sparse_idx_key, sparse_grad_key, sparse_t_key):
380
+ if hasattr(module, key):
381
+ delattr(module, key)
382
+ continue
383
+
384
+ dense_key = f"_hook_grad_T_sign_{name}"
385
+ dense_t_key = f"_hook_T_{name}"
386
+ if hasattr(module, dense_key):
387
+ grad_sign = getattr(module, dense_key)
388
+ hook_t = getattr(module, dense_t_key, None)
389
+ self._accumulate_component_grad_continuous(
390
+ module, grad_sign, weight, t_step,
391
+ )
392
+ delattr(module, dense_key)
393
+ if hasattr(module, dense_t_key):
394
+ delattr(module, dense_t_key)
395
+
396
+ grad_key = f"_hook_grad_2d_{name}"
397
+ x_key = f"_hook_x_2d_{name}"
398
+ if not hasattr(module, grad_key) or not hasattr(module, x_key):
399
+ continue
400
+ comp_grad = getattr(module, grad_key)
401
+ comp_x = getattr(module, x_key)
402
+ if torch.isfinite(comp_grad).all() and torch.isfinite(comp_x).all():
403
+ raw_grad = torch.clamp(comp_grad.transpose(0, 1) @ comp_x, -10.0, 10.0)
404
+ self._accumulate_component_grad_continuous(
405
+ module, raw_grad, weight, t_step,
406
+ )
407
+ delattr(module, grad_key)
408
+ delattr(module, x_key)
409
+
410
+ def _accumulate_component_grad_continuous(self, module, raw_grad, weight, t_step):
411
+ """Component loss accumulation without persistent float optimizer state."""
412
+ if not hasattr(module, "_T_shape"):
413
+ return
414
+ shape = tuple(int(x) for x in module._T_shape.tolist())
415
+ if tuple(raw_grad.shape) != shape:
416
+ return
417
+ with torch.no_grad():
418
+ step = max(1, int(round(abs(float(weight)) * t_step)))
419
+ if float(weight) < 0:
420
+ step = -step
421
+ if hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign"):
422
+ signed = raw_grad.sign().to(device=module.corr_accum.device, dtype=torch.int8)
423
+ module._accumulate_corr_from_grad_sign(signed, corr_step=step)
424
+ return
425
+ if not hasattr(module, "T_accum") or tuple(module.T_accum.shape) != shape:
426
+ return
427
+ if hasattr(module, "_T_accum_fp"):
428
+ delattr(module, "_T_accum_fp")
429
+ signed = raw_grad.sign().to(device=module.T_accum.device, dtype=torch.int8)
430
+ module.T_accum.copy_(
431
+ torch.clamp(
432
+ module.T_accum.to(torch.int16) - signed.to(torch.int16) * step,
433
+ -127,
434
+ 127,
435
+ ).to(torch.int8)
436
+ )
437
+
438
+ def _apply_regular_ternary_hooks(self, accum_threshold, update_scales, t_step, loss_signal):
439
+ for module in self.modules():
440
+ is_bigint = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
441
+ is_legacy = hasattr(module, "T_accum") or hasattr(module, "E_accum")
442
+ if is_bigint or is_legacy:
443
+ self._prepare_per_group_threshold(module)
444
+ streamed = bool(getattr(module, "_streamed_ternary_backward", False))
445
+ has_hook = (
446
+ hasattr(module, "_hook_grad_T_sign")
447
+ or (hasattr(module, "_hook_grad_2d") and hasattr(module, "_hook_x_2d"))
448
+ or (hasattr(module, "_hook_sparse_indices") and hasattr(module, "_hook_sparse_grad_sign"))
449
+ )
450
+ bigint_streamed = bool(getattr(module, "_streamed_bigint_backward", False))
451
+ if (streamed or bigint_streamed) and not has_hook:
452
+ if streamed and update_scales:
453
+ self._step_E_from_accum(module)
454
+ if streamed:
455
+ had_flip = self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
456
+ self._record_flip_health(module, had_flip)
457
+ if hasattr(module, "per_group_threshold"):
458
+ del module.per_group_threshold
459
+ continue
460
+ if has_hook:
461
+ if hasattr(module, "_hook_grad_T_sign") and hasattr(module, "_accumulate_corr_from_grad_sign"):
462
+ module._accumulate_corr_from_grad_sign(module._hook_grad_T_sign)
463
+ del module._hook_grad_T_sign
464
+ if hasattr(module, "ternary_step"):
465
+ module.ternary_step(accum_threshold=accum_threshold)
466
+ if hasattr(module, "per_group_threshold"):
467
+ del module.per_group_threshold
468
+
469
+ def _prepare_per_group_threshold(self, module):
470
+ if self._is_large_sparse_embedding(module):
471
+ module.per_group_threshold = None
472
+ return
473
+ if hasattr(module, "corr_accum") and not hasattr(module, "T_accum"):
474
+ module.per_group_threshold = None
475
+ return
476
+ if not hasattr(module, "E") or not hasattr(module, "_T_shape"):
477
+ module.per_group_threshold = None
478
+ return
479
+ shape = tuple(int(x) for x in module._T_shape.tolist())
480
+ out_dim, in_dim = shape
481
+ gpr = _ceil_div(in_dim, module.group_size)
482
+ E_view = module.E.view(out_dim, gpr).float()
483
+ threshold_g = 8.0 + 0.25 * torch.min(E_view.abs(), torch.tensor(32.0, device=E_view.device))
484
+ module.per_group_threshold = torch.clamp(threshold_g, max=16.0).to(torch.int8).reshape(-1)
485
+
486
+ @staticmethod
487
+ def _is_large_sparse_embedding(module):
488
+ return (
489
+ hasattr(module, "num_embeddings")
490
+ and hasattr(module, "sparse_threshold")
491
+ and module.num_embeddings >= module.sparse_threshold
492
+ )
493
+
494
+ @staticmethod
495
+ def _step_E_from_accum(module):
496
+ if hasattr(module, "corr_accum"):
497
+ return # BigInt modules don't use E_accum threshold flips
498
+ if not hasattr(module, "E") or not hasattr(module, "E_accum"):
499
+ return
500
+ threshold = int(getattr(module, "_e_accum_threshold", 8))
501
+ accum = module.E_accum.to(torch.int16)
502
+ step = torch.where(
503
+ accum >= threshold,
504
+ torch.ones_like(accum, dtype=torch.int16),
505
+ torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
506
+ )
507
+ if step.any():
508
+ module.E = torch.clamp(module.E.to(torch.int16) + step, -128, 127).to(torch.int8)
509
+ module.E_accum = (accum - step * threshold).to(torch.int8)
510
+
511
+ @staticmethod
512
+ def _apply_accumulated_flips(module, accum_threshold=3):
513
+ """Packed-byte carry: when T_accum crosses Β±1, move trit by Β±1 via Β±3^pos."""
514
+ if not hasattr(module, "T_accum") or not hasattr(module, "T_packed") or not hasattr(module, "_T_shape"):
515
+ return False
516
+ shape = tuple(int(x) for x in module._T_shape.tolist())
517
+ if tuple(module.T_accum.shape) != shape:
518
+ return False
519
+ carry_up = module.T_accum > 1
520
+ carry_down = module.T_accum < -1
521
+ if not carry_up.any() and not carry_down.any():
522
+ return False
523
+ dev = module.T_packed.device
524
+ out_dim, in_dim = shape
525
+ pows = torch.tensor([1, 3, 9, 27, 81], device=dev, dtype=torch.int16)
526
+ pk = module.T_packed.to(torch.int16).clone()
527
+ for p in range(5):
528
+ if p >= in_dim:
529
+ continue
530
+ cols = torch.arange(p, in_dim, 5, device=dev)
531
+ if cols.numel() == 0:
532
+ continue
533
+ is_up = carry_up[:, cols]
534
+ is_dn = carry_down[:, cols]
535
+ if not is_up.any() and not is_dn.any():
536
+ continue
537
+ rows_2d = torch.arange(out_dim, device=dev)[:, None]
538
+ lin_idx = rows_2d * in_dim + cols[None, :]
539
+ byte_idx = lin_idx // 5
540
+ pv = pk[byte_idx]
541
+ p_up = (pv + pows[p]).clamp(0, 242)
542
+ p_dn = (pv - pows[p]).clamp(0, 242)
543
+ pk[byte_idx] = torch.where(is_up, p_up, torch.where(is_dn, p_dn, pv))
544
+ module.T_packed = pk.to(torch.uint8)
545
+ # Reset T_accum to 0 on carry so W = T_accum Γ— T doesn't jump
546
+ mask = carry_up | carry_down
547
+ module.T_accum[mask] = torch.zeros_like(module.T_accum[mask])
548
+ return True
549
+
550
+ @staticmethod
551
+ def _record_flip_health(module, had_flip):
552
+ if not hasattr(module, "T_accum"):
553
+ return
554
+ steps_since = getattr(module, "_steps_since_flip", 0)
555
+ module._steps_since_flip = 0 if had_flip else steps_since + 1
556
+ module._had_flip = False
557
+
558
+ def generate(self, idx, max_new_token, temperature=1.0, images=None, audio=None,
559
+ conversation_id=None, top_k=None, min_new_tokens=0, return_metadata=False):
560
+ if self.kv_ledger is not None and self.kv_ledger.size == 0:
561
+ with torch.no_grad():
562
+ for token_id in idx.reshape(-1).tolist():
563
+ self.kv_ledger.append(int(token_id))
564
+ self.kq_cache.append(int(token_id))
565
+ for i in range(max_new_token):
566
+ idx_cond = idx[:, -CTX:]
567
+ logits, _, _, _ = self(idx_cond, images=images, audio=audio, timestep=i, output_mode="text")
568
+ last_logits = logits[:, -1, :] / temperature
569
+ # top-k filtering
570
+ if top_k is not None and top_k > 0:
571
+ v, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
572
+ kth = v[:, -1].unsqueeze(-1).expand_as(last_logits)
573
+ last_logits = last_logits.where(last_logits >= kth, float('-inf'))
574
+ probs = F.softmax(last_logits, dim=-1)
575
+ idx_next = torch.multinomial(probs, num_samples=1)
576
+ idx = torch.cat([idx, idx_next], dim=1)
577
+ # Enforce min_new_tokens (only relevant if caller truncates after generation)
578
+ generated = idx.shape[1] - (min_new_tokens if return_metadata else 0)
579
+ if return_metadata:
580
+ return {
581
+ "tokens": idx,
582
+ "n_generated": generated,
583
+ "temperature": temperature,
584
+ }
585
+ return idx
arbitor/optim/__init__.py ADDED
File without changes
arbitor/optim/sign_sgd.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.optim import Optimizer
3
+
4
+
5
+ class SignSGD(Optimizer):
6
+ def __init__(self, params, lr=1e-2, weight_decay=0.0):
7
+ defaults = dict(lr=lr, weight_decay=weight_decay)
8
+ super().__init__(params, defaults)
9
+
10
+ @torch.no_grad()
11
+ def step(self, closure=None):
12
+ loss = None
13
+ if closure is not None:
14
+ with torch.enable_grad():
15
+ loss = closure()
16
+
17
+ for group in self.param_groups:
18
+ lr = group["lr"]
19
+ wd = group["weight_decay"]
20
+
21
+ for p in group["params"]:
22
+ if p.grad is None:
23
+ continue
24
+
25
+ grad = p.grad
26
+ if grad.is_sparse:
27
+ grad = grad.to_dense()
28
+
29
+ update = grad.sign()
30
+
31
+ if wd > 0:
32
+ update = update + wd * p.sign()
33
+
34
+ p.add_(-lr * update)
35
+
36
+ return loss
37
+
38
+ @torch.no_grad()
39
+ def get_memory_mb(self, params=None) -> float:
40
+ if params is None:
41
+ params = []
42
+ for group in self.param_groups:
43
+ params.extend(group["params"])
44
+ total_bytes = sum(p.numel() * p.element_size() for p in params)
45
+ return total_bytes / (1024 * 1024)
arbitor/profiling.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Profiling utilities: torch.profiler wrapper and analysis tools.
3
+
4
+ Following D-103: profile first, optimize only hot paths.
5
+ Uses torch.profiler to identify training loop bottlenecks.
6
+ """
7
+
8
+ import sys
9
+ import os
10
+ import json
11
+ import math
12
+ import torch
13
+
14
+ sys.path.insert(0, os.path.dirname(__file__))
15
+
16
+ from .main import ARBModel
17
+ from .config import VOCAB, CTX
18
+
19
+
20
+ def profile_training(model, train_data, device, n_steps=20, warmup_steps=5,
21
+ top_k=10, batch_size=64, ctx=CTX):
22
+ """
23
+ Profile N training steps using torch.profiler.
24
+
25
+ Runs profiling with CUDA + CPU activity tracing, warmup steps (no profiling),
26
+ then profiled steps. Returns list of top-K hot path tuples and saves JSON.
27
+
28
+ Args:
29
+ model: ARBModel instance
30
+ train_data: 1D byte tensor of training data
31
+ device: 'cuda' or 'cpu'
32
+ n_steps: Number of profiled training steps
33
+ warmup_steps: Steps before profiling begins (no tracing)
34
+ top_k: Number of top operations to return
35
+ batch_size: Batch size for each training step
36
+ ctx: Context window length
37
+
38
+ Returns:
39
+ List of dicts with keys: op_name, cuda_time_us, cpu_time_us, calls
40
+ """
41
+ model.train()
42
+ prof = None
43
+
44
+ if device == "cuda":
45
+ prof = torch.profiler.profile(
46
+ activities=[
47
+ torch.profiler.ProfilerActivity.CPU,
48
+ torch.profiler.ProfilerActivity.CUDA,
49
+ ],
50
+ record_shapes=True,
51
+ with_stack=True,
52
+ with_flops=True,
53
+ )
54
+ else:
55
+ prof = torch.profiler.profile(
56
+ activities=[torch.profiler.ProfilerActivity.CPU],
57
+ record_shapes=True,
58
+ with_stack=False,
59
+ )
60
+
61
+ # Warmup steps (no profiling)
62
+ for _ in range(warmup_steps):
63
+ ix = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
64
+ x = torch.stack([train_data[j: j + ctx] for j in ix])
65
+ targets = x[:, 3:]
66
+ x = x.to(device)
67
+ targets = targets.to(device)
68
+ with torch.no_grad():
69
+ model(x, targets=targets)
70
+
71
+ # Profiled steps
72
+ prof.start()
73
+ for _ in range(n_steps):
74
+ ix = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
75
+ x = torch.stack([train_data[j: j + ctx] for j in ix])
76
+ targets = x[:, 3:]
77
+ x = x.to(device)
78
+ targets = targets.to(device)
79
+ with torch.no_grad():
80
+ model(x, targets=targets)
81
+ if device == "cuda":
82
+ torch.cuda.synchronize()
83
+ prof.stop()
84
+
85
+ # Process profiler output
86
+ if device == "cuda":
87
+ key_avg = prof.key_averages()
88
+ table = key_avg.table(sort_by="cuda_time_total", row_limit=top_k)
89
+ else:
90
+ key_avg = prof.key_averages()
91
+ table = key_avg.table(sort_by="cpu_time_total", row_limit=top_k)
92
+
93
+ # Extract top-K entries
94
+ events = key_avg.events() if hasattr(key_avg, 'events') else key_avg[:top_k]
95
+ top_results = []
96
+ for evt in events[:top_k]:
97
+ # device_time replaces deprecated cuda_time in recent PyTorch
98
+ cuda_t = (evt.device_time if hasattr(evt, 'device_time') and evt.device_time is not None
99
+ else evt.cuda_time if hasattr(evt, 'cuda_time') else 0)
100
+ entry = {
101
+ "op_name": evt.key if hasattr(evt, 'key') else str(evt),
102
+ "cuda_time_us": cuda_t,
103
+ "cpu_time_us": evt.cpu_time if hasattr(evt, 'cpu_time') else 0,
104
+ "calls": evt.count if hasattr(evt, 'count') else 1,
105
+ }
106
+ top_results.append(entry)
107
+
108
+ # Print summary
109
+ print("\n=== Profiling Results (Top-{} Hot Paths) ===".format(top_k))
110
+ print(table)
111
+ print("============================================\n")
112
+
113
+ # Save profiler output as JSON
114
+ prof.export_chrome_trace("/tmp/profiler_trace.json")
115
+
116
+ return top_results
117
+
118
+
119
+ def analyze_profiler_output(prof_path):
120
+ """
121
+ Load saved profiler JSON output and extract key insights.
122
+
123
+ Args:
124
+ prof_path: Path to saved profiler JSON file
125
+
126
+ Returns:
127
+ List of dicts with op_name, cuda_time_us, cpu_time_us, calls
128
+ """
129
+ with open(prof_path, "r") as f:
130
+ data = json.load(f)
131
+
132
+ # Profiler JSON can be a dict with 'traceEvents' or a flat list
133
+ if isinstance(data, dict) and "traceEvents" in data:
134
+ events = data["traceEvents"]
135
+ elif isinstance(data, list):
136
+ events = data
137
+ else:
138
+ events = []
139
+
140
+ # Aggregate events by name
141
+ op_stats = {}
142
+ for evt in events:
143
+ if isinstance(evt, dict):
144
+ name = evt.get("name", "unknown")
145
+ dur = evt.get("dur", 0) # microseconds
146
+ cat = evt.get("cat", "")
147
+ if name not in op_stats:
148
+ op_stats[name] = {"cuda_time_us": 0, "cpu_time_us": 0, "calls": 0}
149
+ if "gpu" in cat.lower():
150
+ op_stats[name]["cuda_time_us"] += dur
151
+ elif "cpu" in cat.lower() or cat == "":
152
+ op_stats[name]["cpu_time_us"] += dur
153
+ op_stats[name]["calls"] += 1
154
+
155
+ # Sort by CUDA time descending
156
+ sorted_ops = sorted(
157
+ op_stats.items(),
158
+ key=lambda x: x[1]["cuda_time_us"],
159
+ reverse=True,
160
+ )
161
+
162
+ results = []
163
+ for name, stats in sorted_ops:
164
+ results.append({
165
+ "op_name": name,
166
+ "cuda_time_us": stats["cuda_time_us"],
167
+ "cpu_time_us": stats["cpu_time_us"],
168
+ "calls": stats["calls"],
169
+ })
170
+
171
+ # Print formatted summary
172
+ print("\n=== Profiler Analysis ===")
173
+ print(f"{'Operation':<40} {'CUDA Time (us)':>15} {'CPU Time (us)':>15} {'Calls':>8}")
174
+ print("-" * 80)
175
+ for r in results[:20]:
176
+ print(f"{r['op_name']:<40} {r['cuda_time_us']:>15.0f} {r['cpu_time_us']:>15.0f} {r['calls']:>8}")
177
+
178
+ # Identify dominating patterns
179
+ total_cuda = sum(r["cuda_time_us"] for r in results)
180
+ if total_cuda > 0:
181
+ print("\n=== Hot Path Analysis ===")
182
+ for r in results[:5]:
183
+ pct = (r["cuda_time_us"] / total_cuda) * 100 if total_cuda > 0 else 0
184
+ label = ""
185
+ if "vq" in r["op_name"].lower() or "flash_vq" in r["op_name"].lower():
186
+ label = " β†’ VQ candidate for Triton kernel"
187
+ elif "moe" in r["op_name"].lower() or "scatter" in r["op_name"].lower():
188
+ label = " β†’ MoE dispatch candidate"
189
+ elif "embed" in r["op_name"].lower() or "gather" in r["op_name"].lower():
190
+ label = " β†’ Embedding gather (existing Triton kernel)"
191
+ elif "mm" in r["op_name"].lower() or "linear" in r["op_name"].lower():
192
+ label = " β†’ General matmul (torch.compile candidate)"
193
+ print(f" {r['op_name']:<40} {pct:>5.1f}%{label}")
194
+
195
+ print("============================================\n")
196
+ return results
arbitor/sequencers.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sequencer modules β€” input processing for all modalities."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG
7
+ if _HAS_TRITON:
8
+ import triton
9
+ import triton.language as tl
10
+ else:
11
+ triton = None
12
+ tl = None
13
+ try:
14
+ from .kernel.ternary_scale import _TritonTernaryEmbedFn
15
+ except ImportError:
16
+ _TritonTernaryEmbedFn = None
17
+ from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
18
+ from math import ceil as _ceil
19
+
20
+ _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
21
+ from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE
22
+
23
+
24
+ class ByteEmbedding(nn.Module):
25
+ """Byte-level embedding via packed ternary + BigInt correlation.
26
+
27
+ All training state is integer. T_accum/E_accum replaced by
28
+ corr_accum (int64 per group, never clips or resets).
29
+
30
+ S = 2^(E + K Γ— mean_corr) where mean_corr = corr_accum / (step Γ— gs)
31
+ """
32
+ def __init__(self, tscale_type=TScaleType.T32):
33
+ super().__init__()
34
+ self.tscale_type = tscale_type
35
+ self.threshold = 0.05
36
+ self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
37
+ shape = (VOCAB, EMBEDDING_DIM)
38
+
39
+ init_std = 0.02
40
+ init_threshold = min(self.threshold, 0.5 * init_std)
41
+ self.threshold = init_threshold
42
+ w_init = torch.randn(VOCAB, EMBEDDING_DIM) * init_std
43
+ T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
44
+ packed_T, T_shape, T_pad = pack_ternary(T_init)
45
+
46
+ self.register_buffer("T_packed", packed_T)
47
+ self.register_buffer("_T_shape", torch.tensor([VOCAB, EMBEDDING_DIM], dtype=torch.long))
48
+ self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
49
+
50
+ out_dim, in_dim = shape
51
+ gpr = _ceil_div(in_dim, self.group_size)
52
+ total_in = gpr * self.group_size
53
+ padded = torch.zeros(out_dim, total_in)
54
+ abs_w = w_init.abs()
55
+ padded[:, :in_dim] = abs_w
56
+ grouped = padded.view(out_dim, gpr, self.group_size)
57
+ grp_means = grouped.mean(dim=2)
58
+ E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
59
+ self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
60
+
61
+ # BigInt correlation accumulator (replaces T_accum + E_accum)
62
+ n_grp = out_dim * gpr
63
+ self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
64
+ self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
65
+
66
+ self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type)
67
+
68
+ def _get_T(self):
69
+ return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
70
+
71
+ def _get_S(self):
72
+ gpr = _ceil_div(EMBEDDING_DIM, self.group_size)
73
+ e_adj = self.E.float()
74
+ step = int(self.step_counter.item())
75
+ if step > 0:
76
+ from .kernel.ternary_scale import _bigint_corr_strength
77
+ denom = max(step * self.group_size, 1)
78
+ e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
79
+ E_exp = e_adj.view(VOCAB, gpr).repeat_interleave(self.group_size, dim=1)
80
+ if E_exp.shape[1] > EMBEDDING_DIM:
81
+ E_exp = E_exp[:, :EMBEDDING_DIM]
82
+ return torch.exp2(E_exp)
83
+
84
+ @torch.no_grad()
85
+ def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
86
+ if grad_sign is None:
87
+ return
88
+ shape = tuple(self._T_shape.tolist())
89
+ out_dim, in_dim = shape
90
+ if tuple(grad_sign.shape) != shape:
91
+ return
92
+ gs = self.group_size
93
+ T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
94
+ signed = grad_sign.to(torch.int16) * T
95
+ gpr = _ceil_div(in_dim, gs)
96
+ total_in = gpr * gs
97
+ if total_in > in_dim:
98
+ signed = F.pad(signed, (0, total_in - in_dim))
99
+ score = signed.view(out_dim, gpr, gs).sum(dim=2, dtype=torch.int16)
100
+ self.corr_accum -= score.flatten().to(dtype=torch.int64) * int(corr_step)
101
+ self.step_counter += abs(int(corr_step))
102
+
103
+ def forward(self, x):
104
+ if x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
105
+ _dummy = torch.zeros(1, device=x.device, requires_grad=True)
106
+ emb = _TritonTernaryEmbedFn.apply(x, _dummy, self)
107
+ return self.norm(emb)
108
+ T = self._get_T()
109
+ S = self._get_S()
110
+ w_eff = S * T.float()
111
+ w_eff_grad = w_eff.detach().requires_grad_(True)
112
+
113
+ def capture_w_grad(grad_w):
114
+ self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
115
+
116
+ w_eff_grad.register_hook(capture_w_grad)
117
+ out = self.norm(F.embedding(x, w_eff_grad))
118
+ return out
119
+
120
+ def ternary_step(self, accum_threshold=3):
121
+ if hasattr(self, "_hook_grad_T_sign"):
122
+ if hasattr(self, "_accumulate_corr_from_grad_sign"):
123
+ self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
124
+ del self._hook_grad_T_sign
125
+
126
+ def update_E(self, loss_signal=None):
127
+ pass # E is fixed; S adjusted via corr_accum
128
+
129
+
130
+ class Sequencer(nn.Module):
131
+ def __init__(self, modality, window_size, tscale_type=TScaleType.T32):
132
+ super().__init__()
133
+ self.modality = modality
134
+ self.window_size = window_size
135
+ self.tscale_type = tscale_type
136
+
137
+ def forward(self, x):
138
+ raise NotImplementedError
139
+
140
+
141
+ class TextSequencer(Sequencer):
142
+ def __init__(self, tscale_type=TScaleType.T32):
143
+ super().__init__(modality='text', window_size=3, tscale_type=tscale_type)
144
+ self.projection = TernaryScaleTensor(EMBEDDING_DIM * self.window_size, HIDDEN_DIM, tscale_type=tscale_type)
145
+ self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
146
+
147
+ def forward(self, x):
148
+ trigrams = x.unfold(dimension=1, size=self.window_size, step=1)
149
+ trigrams = rearrange(trigrams, 'b t d w -> b t (d w)')
150
+ relational = self.projection(trigrams)
151
+ return self.norm(relational)
152
+ class VAE2DSequencer(Sequencer):
153
+ def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
154
+ super().__init__(modality='image', window_size=1, tscale_type=tscale_type)
155
+ from .encoders.vae2d import load_vae2d as _load_vae2d
156
+ self.vae = _load_vae2d(device=device, quantize=quantize)
157
+ self.vae_device = torch.device(device)
158
+ self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
159
+ self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
160
+
161
+ def forward(self, x):
162
+ if x.device != self.vae_device:
163
+ x = x.to(self.vae_device)
164
+ latent = self.vae(x)
165
+ tokens = rearrange(latent, 'b c h w -> b (h w) c')
166
+ out = self.project(tokens)
167
+ return self.norm(out)
168
+
169
+
170
+ class VAEAudioSequencer(Sequencer):
171
+ def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
172
+ super().__init__(modality='audio', window_size=1, tscale_type=tscale_type)
173
+ from .encoders.vae2d import load_vae2d as _load_vae2d
174
+ from .encoders.mel_frontend import MelSpectrogram3Band as _Mel3Band
175
+ self.vae = _load_vae2d(device=device, quantize=quantize)
176
+ self.vae_device = torch.device(device)
177
+ self.mel = _Mel3Band(sample_rate=AUDIO_SR)
178
+ self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
179
+ self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
180
+
181
+ def forward(self, waveform):
182
+ if waveform.dim() == 1:
183
+ waveform = waveform.unsqueeze(0)
184
+ elif waveform.dim() == 3:
185
+ if waveform.shape[1] == 1:
186
+ waveform = waveform.squeeze(1)
187
+ else:
188
+ waveform = waveform.mean(dim=1)
189
+ spec = self.mel(waveform)
190
+ if spec.device != self.vae_device:
191
+ spec = spec.to(self.vae_device)
192
+ latent = self.vae(spec)
193
+ tokens = rearrange(latent, 'b c h w -> b (h w) c')
194
+ out = self.project(tokens)
195
+ return self.norm(out)
196
+
197
+
198
+ class MultimodalSequencer(nn.Module):
199
+ def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True, enable_audio=True):
200
+ super().__init__()
201
+ self.text = TextSequencer(tscale_type=tscale_type) if enable_text else None
202
+ self.image = VAE2DSequencer(tscale_type=tscale_type) if enable_image else None
203
+ self.audio = VAEAudioSequencer(tscale_type=tscale_type) if enable_audio else None
204
+ self.enabled_modalities = []
205
+ if enable_text:
206
+ self.enabled_modalities.append('text')
207
+ if enable_image:
208
+ self.enabled_modalities.append('image')
209
+ if enable_audio:
210
+ self.enabled_modalities.append('audio')
211
+
212
+ def forward(self, modality_inputs):
213
+ outputs = {}
214
+ for mod in self.enabled_modalities:
215
+ seq = getattr(self, mod)
216
+ if mod in modality_inputs and modality_inputs[mod] is not None and seq is not None:
217
+ outputs[mod] = seq(modality_inputs[mod])
218
+ return outputs
arbitor/vq.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VQ modules β€” vector quantization adapters."""
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
7
+ from .components import TernaryVQCodebook
8
+ from .config import EMBEDDING_DIM, HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE, TIMESTAMP_MAX_PERIOD
9
+
10
+
11
+ class SharedVQ(nn.Module):
12
+ """Single shared VQ codebook for all modalities (10M entries).
13
+
14
+ Each modality projects to the shared CODEBOOK_DIM=64 space, then
15
+ quantizes independently through the shared codebook. Text uses
16
+ CODEBOOK_DIM directly.
17
+
18
+ IDs are globally unique: all modalities share the same range [0, 10M).
19
+ """
20
+ def __init__(self, codebook_size=SHARED_VQ_SIZE, codebook_dim=CODEBOOK_DIM,
21
+ tscale_type=TScaleType.T32, enable_image=True, enable_audio=True):
22
+ super().__init__()
23
+ codebook_size = SHARED_VQ_SIZE if codebook_size is None else codebook_size
24
+ self.codebook_size = codebook_size
25
+ self.codebook_dim = codebook_dim
26
+
27
+ # Per-modality input projections (their_dim β†’ CODEBOOK_DIM)
28
+ self.text_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
29
+ if enable_image:
30
+ self.image_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
31
+ if enable_audio:
32
+ self.audio_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
33
+
34
+ # Shared VQ codebook
35
+ self.vq = TernaryVQCodebook(
36
+ codebook_size=codebook_size,
37
+ codebook_dim=codebook_dim,
38
+ commitment_weight=1.0,
39
+ tscale_type=tscale_type,
40
+ )
41
+ self.modalities = ['text']
42
+ if enable_image:
43
+ self.modalities.append('image')
44
+ if enable_audio:
45
+ self.modalities.append('audio')
46
+
47
+ @staticmethod
48
+ def _sinusoidal_timestamp(seq_len, dim, max_period=TIMESTAMP_MAX_PERIOD, device=None):
49
+ freqs = torch.exp(-torch.arange(0, dim, 2, device=device).float() * (math.log(max_period) / dim))
50
+ t = torch.arange(seq_len, device=device).float().unsqueeze(1)
51
+ pe = torch.zeros(seq_len, dim, device=device)
52
+ pe[:, 0::2] = torch.sin(t * freqs)
53
+ pe[:, 1::2] = torch.cos(t * freqs)
54
+ return pe
55
+
56
+ def forward(self, modality_inputs, timestep=0):
57
+ outputs = []
58
+ vq_losses = {}
59
+ indices_dict = {}
60
+ for mod in self.modalities:
61
+ if mod not in modality_inputs or modality_inputs[mod] is None:
62
+ continue
63
+ x = modality_inputs[mod]
64
+ proj = getattr(self, f'{mod}_proj')
65
+ x_proj = proj(x)
66
+ quantized, idx, loss = self.vq(x_proj)
67
+ outputs.append(quantized)
68
+ vq_losses[f'{mod}_vq'] = loss
69
+ indices_dict[mod] = idx
70
+
71
+ combined = torch.cat(outputs, dim=1) if outputs else modality_inputs.get('text', None)
72
+ if combined is not None and timestep > 0:
73
+ ts_enc = self._sinusoidal_timestamp(combined.shape[1], combined.shape[2], device=combined.device)
74
+ combined = combined + ts_enc.unsqueeze(0)
75
+ return combined, vq_losses, indices_dict
76
+
77
+ @property
78
+ def total_codebook_size(self):
79
+ return self.codebook_size
80
+
81
+ @torch.no_grad()
82
+ def get_codebook_utilization(self):
83
+ cluster_size = self.vq.cluster_size
84
+ return (cluster_size > 0).float().mean().item()
85
+
86
+ @torch.no_grad()
87
+ def get_dead_code_count(self):
88
+ cluster_size = self.vq.cluster_size
89
+ return (cluster_size < self.vq.threshold_ema_dead_code).sum().item()
docs/ARB-RENAME-NOTE.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ARB System Rename
3
+ date: 2026-05-18
4
+ context: System renamed from MORPH to ARB (Any Relational Bit)
5
+ ---
6
+
7
+ # ARB System β€” Rename Note
8
+
9
+ ## New Name
10
+
11
+ The system has been renamed from **MORPH** to **ARB** (Any Relational Bit).
12
+
13
+ - **ARB** = Any Relational Bit β€” the core ternary architecture
14
+ - **ARBS** = ARB System β€” the full software system
15
+ - **ARBitor** = The Python package name (`arbitor/`)
16
+
17
+ ## Package Structure
18
+
19
+ All core system files now live under `arbitor/`:
20
+
21
+ ```
22
+ models/Trigram/
23
+ β”œβ”€β”€ arbitor/ # Core ARB system package
24
+ β”‚ β”œβ”€β”€ __init__.py # Public API exports
25
+ β”‚ β”œβ”€β”€ trigram.py # Core model (ARBModel replaces MORPHTernaryModel)
26
+ β”‚ β”œβ”€β”€ tscale.py # Ternary scale tensors
27
+ β”‚ β”œβ”€β”€ convert_to_ternary.py # 5-trit packing
28
+ β”‚ β”œβ”€β”€ convert_to_ternary*.py # Legacy converters
29
+ β”‚ β”œβ”€β”€ flash_vq.py # FlashVQ codebook
30
+ β”‚ β”œβ”€β”€ ternary_audit.py # Model state auditor
31
+ β”‚ β”œβ”€β”€ profiling.py # Profiling utilities
32
+ β”‚ β”œβ”€β”€ train.py # Training pipeline
33
+ β”‚ β”œβ”€β”€ optim/
34
+ β”‚ β”‚ └── sign_sgd.py # SignSGD optimizer
35
+ β”‚ └── encoders/ # Float sidecar encoders
36
+ β”‚ β”œβ”€β”€ __init__.py
37
+ β”‚ β”œβ”€β”€ audio_codec.py
38
+ β”‚ β”œβ”€β”€ audio_vq_encoder.py
39
+ β”‚ └── video_vae.py
40
+ β”œβ”€β”€ testing/ # Tests (import from arbitor)
41
+ β”œβ”€β”€ .planning/ # Planning docs (P0-P10 complete)
42
+ β”œβ”€β”€ TRUE-TERNARY-REFACTOR*.md # Architecture refactor notes
43
+ β”œβ”€β”€ BENCHMARK.md # Benchmark docs
44
+ └── benchmark_true_ternary.py # Benchmark scripts
45
+ ```
46
+
47
+ ## Import Changes
48
+
49
+ | Before | After |
50
+ |--------|-------|
51
+ | `from trigram import ARBModel` | `from arbitor.trigram import ARBModel` |
52
+ | `from tscale import TernaryScaleTensor` | `from arbitor.tscale import TernaryScaleTensor` |
53
+ | `from optim.sign_sgd import SignSGD` | `from arbitor.optim.sign_sgd import SignSGD` |
54
+ | `from encoders.video_vae import load_vae` | `from arbitor.encoders.video_vae import load_vae` |
55
+ | `from arbitor import ARBModel` | Shorthand via `arbitor/__init__.py` |
56
+ | `import trigram` | `from arbitor import trigram` |
57
+
58
+ ## Class Rename
59
+
60
+ | Before | After |
61
+ |--------|-------|
62
+ | `MORPHTernaryModel` | `ARBModel` |
docs/arbs-tts/README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ARBS Ternary Training System (TTS)
2
+
3
+ ## E1TM Format β€” Exponent-1 Ternary Mantissa
4
+
5
+ E1TM encodes each weight group as **one int8 exponent shared across N ternary mantissas**.
6
+
7
+ ```
8
+ W_eff[i] = S Γ— T[i] where T[i] ∈ {-1, 0, +1}, S = 2^{E + Ξ”}
9
+
10
+ E = int8 logβ‚‚ scale (persistent, per group)
11
+ Ξ” = 4 Γ— corr_accum / (step Γ— gs) (from BigInt accumulator)
12
+ S = 2^{E+Ξ”} (float32, ephemeral β€” created per forward, discarded)
13
+ ```
14
+
15
+ ### Format variants
16
+
17
+ | Name | TScaleType | T per E | gs | E bpw | T bpw | Total bpw (inf) | Precision |
18
+ |---|---|---|---|---|---|---|---|
19
+ | E1TM4 | T4 | 4 | 4 | 2.000 | 1.58 | 3.58 | Highest |
20
+ | E1TM6 | T6 | 6 | 6 | 1.333 | 1.58 | 2.91 | |
21
+ | E1TM8 | T8 | 8 | 8 | 1.000 | 1.58 | 2.58 | |
22
+ | E1TM16 | T16 | 16 | 16 | 0.500 | 1.58 | 2.08 | |
23
+ | **E1TM32** | **T32** | **32** | **32** | **0.250** | **1.58** | **1.85** | **Default** |
24
+ | E1TM64 | T64 | 64 | 64 | 0.125 | 1.58 | 1.71 | |
25
+ | E1TM96 | T96 | 96 | 96 | 0.083 | 1.58 | 1.67 | Most packed |
26
+
27
+ Higher T number = more T per E = less storage = coarser per-weight magnitude.
28
+
29
+ ### Group sizes
30
+
31
+ The TScaleType name is the group size:
32
+
33
+ ```python
34
+ TScaleType.T4 β†’ gs = 4 β†’ E shared across 4 ternary mantissas
35
+ TScaleType.T32 β†’ gs = 32 β†’ E shared across 32 ternary mantissas
36
+ TScaleType.T96 β†’ gs = 96 β†’ E shared across 96 ternary mantissas
37
+ ```
38
+
39
+ ### Persistent training state (all integer)
40
+
41
+ | Buffer | Type | Size/weight | Role |
42
+ |---|---|---|---|
43
+ | T_packed | uint8 | 1.58 bpw | Base-3 packed ternary {-1,0,+1}, 5 trits/byte |
44
+ | E | int8 | 8/N bpw | Logβ‚‚ scale, one per N-weight group |
45
+ | corr_accum | int64 | 64/N bpw | BigInt accumulator for gradient sign votes |
46
+ | step_counter | int64 | 0 bpw | Total steps processed |
47
+
48
+ **No float32/16 anywhere in persistent state.** Float32 ephemeral `W_eff` is created per-forward and discarded after backward.
49
+
50
+ ### Why ternary over binary or int4
51
+
52
+ | Format | Values/weight | Packing efficiency | Null state |
53
+ |---|---|---|---|
54
+ | Binary | 2 | 1 bit/bw (100%) | No |
55
+ | Ternary | 3 | 1.58 bpw (logβ‚‚3 β‰ˆ 95%) | **Yes** (T=0 = null) |
56
+ | Int4 | 16 | 4 bpw (100%) | No |
57
+
58
+ Ternary's null state (T=0) provides structural sparsity β€” β‰ˆ38% of weights are zero, skipping matmul tiles. No other low-bit format has this property at equivalent bpw.
59
+
60
+ ### The BigInt difference
61
+
62
+ Unlike conventional quantization where E is static after conversion, ARBS TTS trains **through** E via a BigInt correlation accumulator:
63
+
64
+ ```
65
+ corr_accum[g] -= Ξ£ (grad_sign Γ— T) # int64, never clips or resets
66
+ Ξ” = 4 Γ— corr_accum / (step Γ— gs) # continuous adjustment from integer division
67
+ S = 2^{E + Ξ”} # effective scale (ephemeral float32)
68
+ ```
69
+
70
+ The division `corr_accum / (step Γ— gs)` is the **Big Number Calculator** operation β€” it converts the accumulated integer evidence into a continuous ratio with arbitrary precision. No threshold flips, no discrete steps, no information loss.
71
+
72
+ ### Training vs inference
73
+
74
+ | Phase | T_packed | E | corr_accum | step | S |
75
+ |---|---|---|---|---|---|
76
+ | Training | Read-only | Read-only | **Accumulates** | **Increments** | Computed from corr/step |
77
+ | Inference (Option A) | Frozen | Frozen | Frozen | Frozen | Burned into checkpoint |
78
+ | Inference (Option B) | Frozen | **Fused** | Discarded | Discarded | Static 2^{E_fused} |
79
+
80
+ **Option A** (export): keep corr_accum + step for continuous S.
81
+ **Option B** (fuse): `E_fused = round(E + 4 Γ— corr_accum / (step Γ— gs))` β€” discards corr_accum, drops to 2.6 bpw.
82
+
83
+ ### Relationship to IEEE float
84
+
85
+ ```
86
+ IEEE FP32: 1 sign + 8 exponent + 23 mantissa β†’ per value
87
+ E1TM32: 1 exponent (int8) + 32 ternary signs β†’ per group of 32
88
+ ```
89
+
90
+ In IEEE, the exponent and mantissa belong to the same value. In E1TM, the exponent is **shared** β€” the mantissa is split into N independent ternary signs. The corr_accum provides sub-exponent precision beyond the int8 E, making the effective scale continuous rather than constrained to the 256 discrete `2^E` values.
docs/benchmarks/BENCHMARK.md ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TrueTernary Benchmark
2
+
3
+ Results from `benchmark_true_ternary.py` β€” comparing pure ternary training against standard methods on MORPHTernaryModel.
4
+
5
+ ## Quick Start
6
+
7
+ ```bash
8
+ cd models/Trigram
9
+
10
+ # TrueTernary (strict, 0 float params, 14M ternary weights)
11
+ python benchmark_true_ternary.py --configs TrueTernary --steps 200 --batch 4 --ctx 33
12
+
13
+ # Adam baseline (full model, 102M float params)
14
+ python benchmark_true_ternary.py --configs Adam_FP32 --steps 200 --batch 4 --ctx 33
15
+
16
+ # Compare both
17
+ python benchmark_true_ternary.py --configs Adam_FP32,TrueTernary --steps 200 --batch 4 --ctx 33 --reuse-base
18
+
19
+ # Training script (strict ternary)
20
+ python train.py --max_steps 1000 --batch_size 4 --ctx 33 --strict_ternary
21
+ ```
22
+
23
+ ## Head-to-Head: TrueTernary vs Adam (200 steps, B=4, C=33)
24
+
25
+ | Metric | Adam_FP32 | TrueTernary |
26
+ |--------|-----------|-------------|
27
+ | **Trainable params** | 102,629,376 (float32) | **0** (pure ternary) |
28
+ | **Model weights** | 473.6 MB | **0.0 MB** |
29
+ | **Optimizer state** | 391.5 MB | **0.0 MB** |
30
+ | **Training state** | 473.6 + 391.5 = **865 MB** | **18.3 MB** (buffers only) |
31
+ | **Peak VRAM** | ~2,548 MB | **~232 MB** (includes CUDA context) |
32
+ | **Step time** | ~200 ms | **~131 ms** |
33
+ | **Final loss** | ~12.3 | **5.75** ↓ |
34
+ | **Min loss** | β€” | **4.49** |
35
+ | **Converges?** | Yes (to high loss) | **Yes (near optimal: ln(288)β‰ˆ5.66)** |
36
+
37
+ ### Key Takeaways
38
+
39
+ - **VRAM**: TrueTernary uses **~40Γ— less** persistent state (18 MB vs 865 MB)
40
+ - **Speed**: 1.5Γ— faster per step (131 ms vs 200 ms) β€” pure add/sub/skip, no float GEMM
41
+ - **Convergence**: TrueTernary reaches **5.75** (near theoretical minimum ln(288) β‰ˆ 5.66) β€” Adam stalls at **12.3**
42
+ - **No float params**: TrueTernary has 0 trainable float params, 0 float buffers
43
+
44
+ ## TrueTernary Training Dynamics (200 steps)
45
+
46
+ The loss curve follows a characteristic 3-phase pattern:
47
+
48
+ ```
49
+ Phase 1 (steps 0-15): Mass T flips from random init, loss spikes to ~90
50
+ Phase 2 (steps 15-80): Recovery and convergence, loss drops from ~15 to ~6
51
+ Phase 3 (steps 80-200): Stable convergence, loss hovers at 5.7-6.0
52
+ ```
53
+
54
+ **Convergence evidence:**
55
+
56
+ | Segment | Mean Loss | Min Loss | Trend |
57
+ |---------|-----------|----------|-------|
58
+ | Steps 0-50 | 13.4 | 4.49 | High variance (T flips) |
59
+ | Steps 50-100 | 8.7 | 6.03 | Monotonic decline |
60
+ | Steps 100-150 | 6.4 | 5.69 | Approaching optimum |
61
+ | Steps 150-200 | **5.82** | **5.64** | **Converged** |
62
+
63
+ The minimum loss of **4.49** is well below the uniform-distribution baseline (ln(288) β‰ˆ 5.66), indicating the model captures meaningful byte-level patterns.
64
+
65
+ ## Training State Breakdown (14M ternary weights)
66
+
67
+ | Component | Storage | Size | Role |
68
+ |-----------|---------|------|------|
69
+ | T_packed | 5-trit/byte uint8 | 2.67 MB | Packed {-1, 0, +1} weights |
70
+ | E | int8 per group | 1.12 MB | Logβ‚‚ scale exponent |
71
+ | E_accum | int8 per group | 1.12 MB | Residual E accumulator |
72
+ | T_accum | int8 per weight | 13.36 MB | Gradient sign accumulator |
73
+ | **Total** | | **18.27 MB** | |
74
+
75
+ All int8 or packed ternary β€” no IEEE float anywhere in weight state.
76
+
77
+ ## Scale Projection to 3B Parameters
78
+
79
+ | Component | 14M | 3B (projected) |
80
+ |-----------|-----|----------------|
81
+ | T_packed | 2.67 MB | **~572 MB** |
82
+ | E | 1.12 MB | **~240 MB** |
83
+ | E_accum | 1.12 MB | **~240 MB** |
84
+ | T_accum | 13.36 MB | **~2.86 GB** |
85
+ | **Total training** | **18.27 MB** | **~3.9 GB** |
86
+ | Inference (T+E only) | ~3.8 MB | **~812 MB** |
87
+
88
+ At 3B: **~3.9 GB** training VRAM fits on a single RTX 4060 (8 GB). Compare to BF16 Adam: **~18 GB** (requires server GPU).
89
+
90
+ ## Architecture Components
91
+
92
+ All internal trainable components are now ternary or integer buffers (REFACTOR6+):
93
+
94
+ - `TernaryScaleTensor` β€” packed ternary linear layers
95
+ - `TernaryEmbeddingTable` β€” packed ternary embedding lookup
96
+ - `TernaryLSTMCell` β€” LSTM with ternary projections
97
+ - `TernaryVQCodebook` β€” VQ with ternary embedding table
98
+ - Graph: Triton-backed edge aggregation + gather-add kernels (REFACTOR8)
99
+ - MoE: Triton-backed dense combine kernel (REFACTOR8)
100
+
101
+ The only remaining float parameters are imported frozen encoders (ViT, Whisper).
102
+
103
+ ## Running the Benchmark
104
+
105
+ ```bash
106
+ # Default: 200 steps, batch=4, ctx=33
107
+ python benchmark_true_ternary.py
108
+
109
+ # Custom config
110
+ python benchmark_true_ternary.py \
111
+ --configs TrueTernary \
112
+ --steps 500 \
113
+ --batch 8 \
114
+ --ctx 66 \
115
+ --update-backend gpu \
116
+ --scale-update-interval 1
117
+
118
+ # Compare with Adam
119
+ python benchmark_true_ternary.py \
120
+ --configs Adam_FP32,TrueTernary \
121
+ --steps 200 \
122
+ --batch 4 \
123
+ --ctx 33 \
124
+ --reuse-base
125
+
126
+ # Change T_accum threshold (higher = less frequent flips)
127
+ python benchmark_true_ternary.py \
128
+ --accum-threshold 5
129
+
130
+ # Full training pipeline
131
+ python train.py \
132
+ --max_steps 5000 \
133
+ --batch_size 8 \
134
+ --ctx 66 \
135
+ --strict_ternary \
136
+ --scale_update_interval 1 \
137
+ --run_name my_ternary_run
138
+ ```
139
+
140
+ ## Benchmark CLI Arguments
141
+
142
+ | Argument | Default | Description |
143
+ |----------|---------|-------------|
144
+ | `--configs` | `TrueTernary` | Comma-separated: `Adam_FP32`, `SignSGD_Old`, `TrueTernary` |
145
+ | `--steps` | 200 | Training steps |
146
+ | `--batch` | 4 | Batch size |
147
+ | `--ctx` | 33 | Context length |
148
+ | `--update-backend` | `gpu` | `gpu`, `gpu-signcache`, `dense-fallback`, `none` |
149
+ | `--scale-update-interval` | 1 | E update frequency (0 = disable) |
150
+ | `--accum-threshold` | 3 | T_accum flip threshold |
151
+ | `--print-every` | 50 | Logging frequency |