Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. Β See raw diff
- REVIEW.md +224 -0
- arbitor.egg-info/PKG-INFO +18 -0
- arbitor.egg-info/SOURCES.txt +104 -0
- arbitor.egg-info/dependency_links.txt +1 -0
- arbitor.egg-info/requires.txt +16 -0
- arbitor.egg-info/top_level.txt +6 -0
- arbitor/__init__.py +35 -0
- arbitor/attention/__init__.py +15 -0
- arbitor/attention/context_attention.py +109 -0
- arbitor/attention/frame_buffer.py +78 -0
- arbitor/attention/kq_cache.py +30 -0
- arbitor/attention/kv_ledger.py +57 -0
- arbitor/attention/mla.py +176 -0
- arbitor/attention/ring_buffer.py +49 -0
- arbitor/components.py +1218 -0
- arbitor/config.py +125 -0
- arbitor/converters/convert_to_ternary2.py +81 -0
- arbitor/converters/convert_to_ternary54.py +120 -0
- arbitor/converters/convert_to_ternary64.py +111 -0
- arbitor/converters/convert_to_ternary8.py +101 -0
- arbitor/decoders.py +231 -0
- arbitor/encoders/__init__.py +11 -0
- arbitor/encoders/audio.py +83 -0
- arbitor/encoders/mel_frontend.py +70 -0
- arbitor/encoders/models/__init__.py +86 -0
- arbitor/encoders/models/download.py +132 -0
- arbitor/encoders/models/opensora-vae/config.json +35 -0
- arbitor/encoders/models/opensora-vae/model.safetensors +3 -0
- arbitor/encoders/models/pig-vae/model.safetensors +3 -0
- arbitor/encoders/opensora_vae.py +145 -0
- arbitor/encoders/opensora_vae_modules/autoencoder_2d.py +339 -0
- arbitor/encoders/opensora_vae_modules/autoencoder_kl_causal_3d.py +638 -0
- arbitor/encoders/opensora_vae_modules/registry.py +41 -0
- arbitor/encoders/opensora_vae_modules/unet_causal_3d_blocks.py +476 -0
- arbitor/encoders/opensora_vae_modules/vae.py +340 -0
- arbitor/encoders/pig_vae.py +148 -0
- arbitor/encoders/vae2d.py +56 -0
- arbitor/kernel/flash_vq.py +510 -0
- arbitor/kernel/ternary_audit.py +192 -0
- arbitor/kernel/ternary_scale.py +1811 -0
- arbitor/kernel/triton_video.py +75 -0
- arbitor/main.py +585 -0
- arbitor/optim/__init__.py +0 -0
- arbitor/optim/sign_sgd.py +45 -0
- arbitor/profiling.py +196 -0
- arbitor/sequencers.py +218 -0
- arbitor/vq.py +89 -0
- docs/ARB-RENAME-NOTE.md +62 -0
- docs/arbs-tts/README.md +90 -0
- docs/benchmarks/BENCHMARK.md +151 -0
REVIEW.md
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ARBS Code Audit: Dead Imports, Dead Code, and Triton Kernel Analysis
|
| 2 |
+
|
| 3 |
+
**Reviewed:** 2026-05-20T00:00:00Z
|
| 4 |
+
**Depth:** standard
|
| 5 |
+
**Files Reviewed:** 10
|
| 6 |
+
|
| 7 |
+
## Summary
|
| 8 |
+
|
| 9 |
+
The ARBS codebase has **3 BLOCKER bugs** that will cause runtime crashes, **8 unused class/function definitions** (dead code), **7 dead Triton kernels in components.py** that should be moved to `arbitor/kernel/`, and **21+ unused imports** across files. Two missing function definitions (`_graph_gather_add`, `_moe_dense_combine`) exist in dead code paths but would crash if those paths were ever activated.
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## BLOCKER Issues
|
| 14 |
+
|
| 15 |
+
### CR-01: `_TernaryLinearFn.forward` references undefined `x_2d` (NameError at runtime)
|
| 16 |
+
|
| 17 |
+
**File:** `arbitor/kernel/ternary_scale.py:206-208`
|
| 18 |
+
**Issue:** The TileLang `_TernaryLinearFn.forward()` method references `x_2d` on lines 206-208, but `x_2d` is never defined in the method's scope. This will cause a `NameError` at runtime if the TileLang code path is taken in `TernaryScaleTensor.forward` (line 1069). The Triton variant `_TritonTernaryLinearFn` (line 878) correctly defines `x_2d = x.reshape(-1, k_in).contiguous()` before use, so this was likely an omission when the TileLang function was written.
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
# Line 206 β NameError: name 'x_2d' is not defined
|
| 22 |
+
M = x_2d.shape[0]
|
| 23 |
+
output = torch.empty(M, N, device=x.device, dtype=torch.float32)
|
| 24 |
+
fwd_kernel(x_2d.half(), T_packed, E, output)
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
**Fix:** Add `x_2d = x.reshape(-1, K).contiguous()` before line 206:
|
| 28 |
+
```python
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
N, K = shape
|
| 31 |
+
x_2d = x.reshape(-1, K).contiguous() # missing definition
|
| 32 |
+
M = x_2d.shape[0]
|
| 33 |
+
output = torch.empty(M, N, device=x.device, dtype=torch.float32)
|
| 34 |
+
fwd_kernel(x_2d.half(), T_packed, E, output)
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
### CR-02: `_check_tilelang_finite` called but never defined (NameError at runtime)
|
| 40 |
+
|
| 41 |
+
**File:** `arbitor/kernel/ternary_scale.py:1072`
|
| 42 |
+
**Issue:** `_check_tilelang_finite()` is called in `TernaryScaleTensor.forward()` but is never defined anywhere in the codebase. This will cause a `NameError` at runtime when the TileLang path is active and the kernel produces valid output (the check is specifically gated by `_HAS_TILELANG` being True).
|
| 43 |
+
|
| 44 |
+
**Fix:** Either define the function (if the check is intentional) or remove the call:
|
| 45 |
+
```python
|
| 46 |
+
# Replace line 1072 with a direct finiteness check or remove
|
| 47 |
+
if not torch.isfinite(y).all():
|
| 48 |
+
raise FloatingPointError("TileLang ternary kernel produced non-finite activations")
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
### CR-03: `self.modality_gate` used but never assigned (AttributeError at runtime)
|
| 54 |
+
|
| 55 |
+
**File:** `arbitor/main.py:129-130`
|
| 56 |
+
**Issue:** `ARBModel.forward()` references `self.modality_gate` but it is never assigned in `ARBModel.__init__()`. While `ModalityGate` is imported at line 19, it is never instantiated and stored as `self.modality_gate`. This will cause an `AttributeError` on any forward pass where `self.modality_gate is not None` is evaluated.
|
| 57 |
+
|
| 58 |
+
The code at lines 129-132:
|
| 59 |
+
```python
|
| 60 |
+
if self.modality_gate is not None:
|
| 61 |
+
gate_weights, active_count, hops = self.modality_gate(active_mods)
|
| 62 |
+
else:
|
| 63 |
+
gate_weights, active_count, hops = {}, len(active_mods), 1
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
**Fix:** Add `self.modality_gate = ModalityGate()` in `ARBModel.__init__()` (or assign `self.modality_gate = None` if the gate should be optional):
|
| 67 |
+
```python
|
| 68 |
+
# In ARBModel.__init__, after line 78:
|
| 69 |
+
self.modality_gate = ModalityGate()
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
---
|
| 73 |
+
|
| 74 |
+
## WARNING: Undefined Functions in Dead Code
|
| 75 |
+
|
| 76 |
+
### WR-01: `_graph_gather_add` called but never defined
|
| 77 |
+
|
| 78 |
+
**File:** `arbitor/components.py:739`
|
| 79 |
+
**Issue:** `TernaryGraph.forward()` calls `_graph_gather_add(vq_output, node_features, vq_indices)` but this function is never defined anywhere in the codebase. `TernaryGraph` is dead code (never imported or used), so this does not crash currently, but it blocks any future use of `TernaryGraph`.
|
| 80 |
+
|
| 81 |
+
**Fix:** Define `_graph_gather_add` or remove the dead class.
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
### WR-02: `_moe_dense_combine` called but never defined
|
| 86 |
+
|
| 87 |
+
**File:** `arbitor/components.py:941`
|
| 88 |
+
**Issue:** `SharedProjectionMoE.forward()` calls `_moe_dense_combine(torch.stack(...), topk_idx, topk_weights)` but this function is never defined. `SharedProjectionMoE` is dead code, but the missing function is a latent bug.
|
| 89 |
+
|
| 90 |
+
**Fix:** Define `_moe_dense_combine` or remove the dead class.
|
| 91 |
+
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
## WARNING: Unused Class/Function Definitions (Dead Code)
|
| 95 |
+
|
| 96 |
+
### WR-03: `TernaryLSTMCell` class β defined but never used
|
| 97 |
+
|
| 98 |
+
**File:** `arbitor/components.py:189-207`
|
| 99 |
+
**Issue:** `TernaryLSTMCell` is defined and re-exported from `__init__.py` (line 23) but is never instantiated anywhere in the codebase. The model uses `MoEGraph` with attention (MLA) instead of LSTM-based processing.
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
### WR-04: `TernaryGraph` class β defined but never used
|
| 104 |
+
|
| 105 |
+
**File:** `arbitor/components.py:665-802`
|
| 106 |
+
**Issue:** `TernaryGraph` is defined in `components.py` but never imported or instantiated. It was replaced by `MoEGraph` (line 1342). The only reference is in a comment (line 1348).
|
| 107 |
+
**Also:** `TernaryGraph` references the undefined function `_graph_gather_add` (see WR-01), so it cannot function even if someone tried to use it.
|
| 108 |
+
|
| 109 |
+
---
|
| 110 |
+
|
| 111 |
+
### WR-05: `SharedProjectionMoE` class β defined but never used
|
| 112 |
+
|
| 113 |
+
**File:** `arbitor/components.py:806-999`
|
| 114 |
+
**Issue:** `SharedProjectionMoE` is defined in `components.py` but never imported or instantiated. It was replaced by `MoEGraph._run_expert()` (line 1429). The only reference is in a comment (line 1348).
|
| 115 |
+
**Also:** References the undefined function `_moe_dense_combine` (see WR-02).
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
### WR-06: 7 dead Triton kernel functions in `components.py`
|
| 120 |
+
|
| 121 |
+
**File:** `arbitor/components.py:266-386`
|
| 122 |
+
**Issue:** These Triton kernel functions are defined inside the `if _HAS_TRITON:` block but are only referenced by their forward/backward wrapper functions which are themselves part of dead code (`TernaryGraph` and `SharedProjectionMoE`):
|
| 123 |
+
|
| 124 |
+
| Line | Function | Used By |
|
| 125 |
+
|------|----------|---------|
|
| 126 |
+
| 268 | `_triton_graph_aggregate_fwd_kernel` | dead (TernaryGraph) |
|
| 127 |
+
| 292 | `_triton_graph_aggregate_bwd_kernel` | dead (TernaryGraph) |
|
| 128 |
+
| 316 | `_triton_graph_gather_add_fwd_kernel` | dead (TernaryGraph) |
|
| 129 |
+
| 329 | `_triton_graph_gather_add_bwd_kernel` | dead (TernaryGraph) |
|
| 130 |
+
| 342 | `_triton_moe_dense_combine_fwd_kernel` | dead (SharedProjectionMoE) |
|
| 131 |
+
| 359 | `_triton_moe_dense_combine_bwd_expert_kernel` | dead (SharedProjectionMoE) |
|
| 132 |
+
| 374 | `_triton_moe_dense_combine_bwd_weight_kernel` | dead (SharedProjectionMoE) |
|
| 133 |
+
|
| 134 |
+
The live Triton kernels (`_triton_video_denoise_fwd_kernel` line 389, `_triton_video_denoise_bwd_kernel` line 402) are still in `components.py` and should also be moved to `arbitor/kernel/`.
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
### WR-07: `_triton_flash_vq_quantize_kernel` β dead Triton kernel
|
| 139 |
+
|
| 140 |
+
**File:** `arbitor/kernel/flash_vq.py:370-402`
|
| 141 |
+
**Issue:** This Triton kernel is defined but never called. The `_TritonFlashVQFn.forward()` method uses PyTorch's `embed[indices]` for the gather operation (line 468) instead of this kernel.
|
| 142 |
+
|
| 143 |
+
---
|
| 144 |
+
|
| 145 |
+
### WR-08: `TILE_SIZE = 384` β unused constant
|
| 146 |
+
|
| 147 |
+
**File:** `arbitor/kernel/ternary_scale.py:949`
|
| 148 |
+
**Issue:** `TILE_SIZE` is defined as a module-level constant but never referenced anywhere in the codebase.
|
| 149 |
+
|
| 150 |
+
---
|
| 151 |
+
|
| 152 |
+
## WARNING: Unused Imports
|
| 153 |
+
|
| 154 |
+
### WR-09: Unused imports in `arbitor/main.py` (line 10)
|
| 155 |
+
|
| 156 |
+
| Symbol | Used In File? |
|
| 157 |
+
|--------|--------------|
|
| 158 |
+
| `EMBEDDING_DIM` | No β not referenced in body |
|
| 159 |
+
| `FFN_HIDDEN` | No β not referenced in body |
|
| 160 |
+
| `CODEBOOK_DIM` | No β not referenced in body |
|
| 161 |
+
| `ATTENTION_STRIDE` | No β not referenced in body |
|
| 162 |
+
| `MG_N_EXPERTS` | No β MoEGraph uses default, not passed |
|
| 163 |
+
| `MG_CORE_RANK` | No β MoEGraph uses default |
|
| 164 |
+
| `MG_SHARED_INTER` | No β MoEGraph uses default |
|
| 165 |
+
| `MG_ACT_ITERS` | No β MoEGraph uses default |
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
### WR-10: Unused imports in `arbitor/components.py` (line 21)
|
| 170 |
+
|
| 171 |
+
| Symbol | Used In Live Code? | Note |
|
| 172 |
+
|--------|-------------------|------|
|
| 173 |
+
| `FFN_HIDDEN` | No | Not referenced in file body |
|
| 174 |
+
| `CTX` | No | Not referenced in file body |
|
| 175 |
+
| `THRESHOLD` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `threshold=0.05` |
|
| 176 |
+
| `KG_EMA_ALPHA` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `0.99` |
|
| 177 |
+
| `KG_REQUANT_EVERY` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `50` |
|
| 178 |
+
| `KG_TERNARY_THRESHOLD` | No | Only used in dead `TernaryGraph`. Live `MoEGraph` hardcodes `0.3` |
|
| 179 |
+
|
| 180 |
+
---
|
| 181 |
+
|
| 182 |
+
### WR-11: Unused imports in `arbitor/profiling.py` (line 17)
|
| 183 |
+
|
| 184 |
+
| Symbol | Used In File? |
|
| 185 |
+
|--------|--------------|
|
| 186 |
+
| `VOCAB` | No β not referenced in body |
|
| 187 |
+
| `math` (line 11) | No β not referenced in body |
|
| 188 |
+
|
| 189 |
+
---
|
| 190 |
+
|
| 191 |
+
## INFO: Triton Kernel Code in `components.py` Should Be Moved to `arbitor/kernel/`
|
| 192 |
+
|
| 193 |
+
### IN-01: Live Triton kernels reside in `components.py` instead of `arbitor/kernel/`
|
| 194 |
+
|
| 195 |
+
**File:** `arbitor/components.py:389-445`
|
| 196 |
+
**Issue:** The codebase convention places Triton kernels in `arbitor/kernel/` (e.g., `ternary_scale.py`, `flash_vq.py`, `ternary_audit.py`). Two live Triton kernels remain in `components.py`:
|
| 197 |
+
|
| 198 |
+
- `_triton_video_denoise_fwd_kernel` (line 389)
|
| 199 |
+
- `_triton_video_denoise_bwd_kernel` (line 402)
|
| 200 |
+
- `_TritonVideoDenoiseFn` (line 415)
|
| 201 |
+
- `_video_denoise_step` (line 448)
|
| 202 |
+
|
| 203 |
+
These should be extracted into `arbitor/kernel/video_denoise.py` and imported from there, following the pattern established by `ternary_scale.py` and `flash_vq.py`.
|
| 204 |
+
|
| 205 |
+
---
|
| 206 |
+
|
| 207 |
+
## INFO: Additional Dead Code
|
| 208 |
+
|
| 209 |
+
### IN-02: Hardcoded MoEGraph config values bypass config constants
|
| 210 |
+
|
| 211 |
+
**File:** `arbitor/components.py:1381-1383`
|
| 212 |
+
**Issue:** `MoEGraph` uses hardcoded values (`50`, `0.3`, `0.99`) instead of the imported config constants (`KG_REQUANT_EVERY`, `KG_TERNARY_THRESHOLD`, `KG_EMA_ALPHA`). The values happen to match the config, but any future config changes will silently be ignored.
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
### IN-03: `AUDIO_VOCAB` not used meaningfully in `config.py`
|
| 217 |
+
|
| 218 |
+
**File:** `arbitor/config.py:2`
|
| 219 |
+
**Issue:** `AUDIO_VOCAB=288` is imported and used in `TalkerHead` and `TinyNeuralCodec`, but the `SPECIAL_VOCAB` map (line 65) defines tokens up to 287. `AUDIO_VOCAB` = `VOCAB` = 288, meaning the audio head has the same vocabulary as the text head. This may be intentional for the current prototype but is worth flagging given `AUDIO_VOCAB` vs `VOCAB` are separate constants.
|
| 220 |
+
|
| 221 |
+
---
|
| 222 |
+
|
| 223 |
+
_Reviewed: 2026-05-20T00:00:00Z_
|
| 224 |
+
_Reviewer: gsd-code-reviewer (deep analysis)_
|
arbitor.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: arbitor
|
| 3 |
+
Version: 0.2.0
|
| 4 |
+
Summary: ARB (Any Relational Bit) β ternary-weighted neural network system
|
| 5 |
+
License: MIT
|
| 6 |
+
Requires-Python: >=3.12
|
| 7 |
+
Requires-Dist: torch>=2.5
|
| 8 |
+
Requires-Dist: einops
|
| 9 |
+
Requires-Dist: tqdm
|
| 10 |
+
Provides-Extra: dev
|
| 11 |
+
Requires-Dist: pytest; extra == "dev"
|
| 12 |
+
Provides-Extra: cuda
|
| 13 |
+
Requires-Dist: torch>=2.5; extra == "cuda"
|
| 14 |
+
Requires-Dist: triton>=3.0; extra == "cuda"
|
| 15 |
+
Provides-Extra: triton
|
| 16 |
+
Requires-Dist: triton>=3.0; extra == "triton"
|
| 17 |
+
Provides-Extra: tilelang
|
| 18 |
+
Requires-Dist: tilelang; extra == "tilelang"
|
arbitor.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pyproject.toml
|
| 2 |
+
arbitor/__init__.py
|
| 3 |
+
arbitor/components.py
|
| 4 |
+
arbitor/config.py
|
| 5 |
+
arbitor/decoders.py
|
| 6 |
+
arbitor/main.py
|
| 7 |
+
arbitor/profiling.py
|
| 8 |
+
arbitor/sequencers.py
|
| 9 |
+
arbitor/vq.py
|
| 10 |
+
arbitor.egg-info/PKG-INFO
|
| 11 |
+
arbitor.egg-info/SOURCES.txt
|
| 12 |
+
arbitor.egg-info/dependency_links.txt
|
| 13 |
+
arbitor.egg-info/requires.txt
|
| 14 |
+
arbitor.egg-info/top_level.txt
|
| 15 |
+
arbitor/attention/__init__.py
|
| 16 |
+
arbitor/attention/context_attention.py
|
| 17 |
+
arbitor/attention/frame_buffer.py
|
| 18 |
+
arbitor/attention/kq_cache.py
|
| 19 |
+
arbitor/attention/kv_ledger.py
|
| 20 |
+
arbitor/attention/mla.py
|
| 21 |
+
arbitor/attention/ring_buffer.py
|
| 22 |
+
arbitor/converters/convert_to_ternary2.py
|
| 23 |
+
arbitor/converters/convert_to_ternary54.py
|
| 24 |
+
arbitor/converters/convert_to_ternary64.py
|
| 25 |
+
arbitor/converters/convert_to_ternary8.py
|
| 26 |
+
arbitor/encoders/__init__.py
|
| 27 |
+
arbitor/encoders/audio.py
|
| 28 |
+
arbitor/encoders/mel_frontend.py
|
| 29 |
+
arbitor/encoders/opensora_vae.py
|
| 30 |
+
arbitor/encoders/pig_vae.py
|
| 31 |
+
arbitor/encoders/vae2d.py
|
| 32 |
+
arbitor/encoders/models/__init__.py
|
| 33 |
+
arbitor/encoders/models/download.py
|
| 34 |
+
arbitor/encoders/opensora_vae_modules/autoencoder_2d.py
|
| 35 |
+
arbitor/encoders/opensora_vae_modules/autoencoder_kl_causal_3d.py
|
| 36 |
+
arbitor/encoders/opensora_vae_modules/registry.py
|
| 37 |
+
arbitor/encoders/opensora_vae_modules/unet_causal_3d_blocks.py
|
| 38 |
+
arbitor/encoders/opensora_vae_modules/vae.py
|
| 39 |
+
arbitor/kernel/flash_vq.py
|
| 40 |
+
arbitor/kernel/ternary_audit.py
|
| 41 |
+
arbitor/kernel/ternary_scale.py
|
| 42 |
+
arbitor/kernel/triton_video.py
|
| 43 |
+
arbitor/optim/__init__.py
|
| 44 |
+
arbitor/optim/sign_sgd.py
|
| 45 |
+
testing/bigcalc.py
|
| 46 |
+
testing/scaled_optum.py
|
| 47 |
+
testing/sign_gsd.py
|
| 48 |
+
testing/test_200_step_smoke.py
|
| 49 |
+
testing/test_bigint_ternary.py
|
| 50 |
+
testing/test_gradient_capture.py
|
| 51 |
+
testing/test_polarity_validation.py
|
| 52 |
+
testing/test_tilelang_training.py
|
| 53 |
+
testing/test_tscale.py
|
| 54 |
+
testing/tscale_mini.py
|
| 55 |
+
testing/attention/__init__.py
|
| 56 |
+
testing/attention/test_kq_cache.py
|
| 57 |
+
testing/attention/test_kv_cache.py
|
| 58 |
+
testing/attention/test_lstm_removal.py
|
| 59 |
+
testing/attention/test_lstm_removal_clean.py
|
| 60 |
+
testing/attention/test_mla.py
|
| 61 |
+
testing/attention/test_ring_buffer.py
|
| 62 |
+
testing/benchmarks/benchmark.py
|
| 63 |
+
testing/benchmarks/benchmark_phase2.py
|
| 64 |
+
testing/benchmarks/benchmark_true_ternary.py
|
| 65 |
+
testing/eval/eval_checkpoints.py
|
| 66 |
+
testing/eval/eval_generation.py
|
| 67 |
+
testing/eval/eval_metrics.py
|
| 68 |
+
testing/eval/test_eval.py
|
| 69 |
+
testing/kg/test_composite_head.py
|
| 70 |
+
testing/kg/test_kg_edges.py
|
| 71 |
+
testing/kg/test_kv_integration.py
|
| 72 |
+
testing/model/audio-comprehension.py
|
| 73 |
+
testing/model/health.py
|
| 74 |
+
testing/model/image-comprehension.py
|
| 75 |
+
testing/model/test-stp.py
|
| 76 |
+
testing/model/test_arb.py
|
| 77 |
+
testing/model/test_flash.py
|
| 78 |
+
testing/model/test_tscale.py
|
| 79 |
+
testing/model/text-comprehension.py
|
| 80 |
+
testing/model/video-comprehension.py
|
| 81 |
+
testing/vae/test_opensora_vae.py
|
| 82 |
+
tests/test_cross_modal.py
|
| 83 |
+
tests/test_lti.py
|
| 84 |
+
tests/test_moegraph_topk.py
|
| 85 |
+
tests/test_vae2d.py
|
| 86 |
+
tests/test_vae2d_sequencer.py
|
| 87 |
+
training/audio.py
|
| 88 |
+
training/diffusion.py
|
| 89 |
+
training/pretrain.py
|
| 90 |
+
training/text.py
|
| 91 |
+
training/vision.py
|
| 92 |
+
training/data/__init__.py
|
| 93 |
+
training/data/prepare_cc12m.py
|
| 94 |
+
training/data/prepare_fineweb.py
|
| 95 |
+
training/data/prepare_librispeech.py
|
| 96 |
+
training/data/prepare_starcoder.py
|
| 97 |
+
training/data/prepare_webvid.py
|
| 98 |
+
training/data/tokenize_from_hf.py
|
| 99 |
+
training/finetuning/__init__.py
|
| 100 |
+
training/finetuning/audio.py
|
| 101 |
+
training/finetuning/diffusion.py
|
| 102 |
+
training/finetuning/lora.py
|
| 103 |
+
training/finetuning/text.py
|
| 104 |
+
training/finetuning/vision.py
|
arbitor.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
arbitor.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.5
|
| 2 |
+
einops
|
| 3 |
+
tqdm
|
| 4 |
+
|
| 5 |
+
[cuda]
|
| 6 |
+
torch>=2.5
|
| 7 |
+
triton>=3.0
|
| 8 |
+
|
| 9 |
+
[dev]
|
| 10 |
+
pytest
|
| 11 |
+
|
| 12 |
+
[tilelang]
|
| 13 |
+
tilelang
|
| 14 |
+
|
| 15 |
+
[triton]
|
| 16 |
+
triton>=3.0
|
arbitor.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
arbitor
|
| 2 |
+
docs
|
| 3 |
+
models
|
| 4 |
+
testing
|
| 5 |
+
tests
|
| 6 |
+
training
|
arbitor/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ARBitor β Any Relational Bit System.
|
| 2 |
+
|
| 3 |
+
Core package for the ARB ternary-weighted neural network.
|
| 4 |
+
Quick import: from arbitor import ARBModel, VOCAB
|
| 5 |
+
"""
|
| 6 |
+
from .config import VOCAB, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, \
|
| 7 |
+
EMBEDDING_DIM, HIDDEN_DIM, CTX, SPECIAL_VOCAB, \
|
| 8 |
+
CODEBOOK_DIM, SHARED_VQ_SIZE, \
|
| 9 |
+
MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, \
|
| 10 |
+
MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES, MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM
|
| 11 |
+
|
| 12 |
+
from .kernel.ternary_scale import (
|
| 13 |
+
TernaryScaleTensor, TernaryRMSNorm, TScaleType, GROUP_SIZES,
|
| 14 |
+
_HAS_TRITON, _HAS_TILELANG,
|
| 15 |
+
)
|
| 16 |
+
from .kernel.flash_vq import FlashVQCodebook
|
| 17 |
+
from .kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters
|
| 18 |
+
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
|
| 19 |
+
|
| 20 |
+
from .sequencers import ByteEmbedding, Sequencer, TextSequencer, VAE2DSequencer, VAEAudioSequencer, MultimodalSequencer
|
| 21 |
+
from .vq import SharedVQ
|
| 22 |
+
from .components import (
|
| 23 |
+
TernaryEmbeddingTable, TernaryVQCodebook,
|
| 24 |
+
GNNLoRAAdapter, HaltingUnit,
|
| 25 |
+
MemGram, MoEGraph,
|
| 26 |
+
ByteHead, OutputRouter,
|
| 27 |
+
LossComponents, LossWeights, StickyZoneSTE,
|
| 28 |
+
KGVQCodebook, CompositeProposalHead,
|
| 29 |
+
_BOUNDARY_TOKEN_MAP,
|
| 30 |
+
)
|
| 31 |
+
from .decoders import VideoHead, TalkerHead, MRFBlock, TinyNeuralCodec
|
| 32 |
+
from .main import ARBModel, _extract_boundary_from_input
|
| 33 |
+
|
| 34 |
+
# Re-export encoders
|
| 35 |
+
from .encoders import TinyNeuralCodec as Codec, AudioVQEncoder, load_vae, VAEWrapper
|
arbitor/attention/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ARB Attention β KV Ledger, MLA, Sliding Window Attention."""
|
| 2 |
+
from .ring_buffer import GPURingBuffer
|
| 3 |
+
from .kv_ledger import KVLedger
|
| 4 |
+
from .kq_cache import KQCache
|
| 5 |
+
from .mla import (MultiHeadLatentAttention, apply_rotary_emb,
|
| 6 |
+
precompute_freqs_cis)
|
| 7 |
+
from .context_attention import ContextAttentionScheduler
|
| 8 |
+
from .frame_buffer import TemporalFrameBuffer
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"GPURingBuffer", "KVLedger", "KQCache",
|
| 12 |
+
"MultiHeadLatentAttention", "apply_rotary_emb",
|
| 13 |
+
"precompute_freqs_cis", "ContextAttentionScheduler",
|
| 14 |
+
"TemporalFrameBuffer",
|
| 15 |
+
]
|
arbitor/attention/context_attention.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Context Attention Scheduler β sliding window + full context orchestration.
|
| 2 |
+
|
| 3 |
+
Schedules 4 sliding window (d=64, CSA-compressed to d=16) and 4 full context
|
| 4 |
+
(d=32, HCA-compressed to d=8) MLA attention passes. Combines both via gating.
|
| 5 |
+
|
| 6 |
+
Pipeline: GNN output β ContextAttentionScheduler β MoE input
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from ..config import HIDDEN_DIM, MLA_HCA_STRIDE
|
| 11 |
+
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
|
| 12 |
+
from .mla import (MultiHeadLatentAttention, precompute_freqs_cis,
|
| 13 |
+
MLA_N_LAYERS, MLA_N_HEADS, MLA_SLIDE_DIM, MLA_FULL_DIM,
|
| 14 |
+
MLA_QK_NOPE_HEAD_DIM, MLA_QK_ROPE_HEAD_DIM,
|
| 15 |
+
MLA_V_HEAD_DIM, MLA_ROPE_THETA,
|
| 16 |
+
MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE)
|
| 17 |
+
|
| 18 |
+
SLIDING_WINDOW_SIZE = 32768
|
| 19 |
+
KV_LEDGER_SIZE = 262144
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ContextAttentionScheduler(nn.Module):
|
| 23 |
+
def __init__(self, dim=HIDDEN_DIM):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.dim = dim
|
| 26 |
+
|
| 27 |
+
# Slide layers with CSA compression (d=64 β d=16) β half of total layers
|
| 28 |
+
n_layers_per_pass = max(1, MLA_N_LAYERS // 2)
|
| 29 |
+
self.slide_layers = nn.ModuleList([
|
| 30 |
+
MultiHeadLatentAttention(
|
| 31 |
+
dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
|
| 32 |
+
qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
|
| 33 |
+
qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
|
| 34 |
+
v_head_dim=MLA_V_HEAD_DIM,
|
| 35 |
+
csa_dim=MLA_CSA_DIM, hca_dim=None,
|
| 36 |
+
) for _ in range(n_layers_per_pass)
|
| 37 |
+
])
|
| 38 |
+
# CSA: embed motif IDs β kv_lora_rank, then compress β csa_dim
|
| 39 |
+
self.slide_embed = TernaryScaleTensor(1, MLA_SLIDE_DIM, tscale_type=TScaleType.T32)
|
| 40 |
+
self.slide_compress = TernaryScaleTensor(MLA_SLIDE_DIM, MLA_CSA_DIM, tscale_type=TScaleType.T32)
|
| 41 |
+
|
| 42 |
+
# Full context layers with HCA compression (d=32 β d=8) β half of total layers
|
| 43 |
+
self.full_layers = nn.ModuleList([
|
| 44 |
+
MultiHeadLatentAttention(
|
| 45 |
+
dim=dim, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_FULL_DIM,
|
| 46 |
+
qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM,
|
| 47 |
+
qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
|
| 48 |
+
v_head_dim=MLA_V_HEAD_DIM,
|
| 49 |
+
csa_dim=None, hca_dim=MLA_HCA_DIM,
|
| 50 |
+
) for _ in range(n_layers_per_pass)
|
| 51 |
+
])
|
| 52 |
+
# HCA: embed motif IDs β kv_lora_rank, then compress β hca_dim
|
| 53 |
+
self.full_embed = TernaryScaleTensor(1, MLA_FULL_DIM, tscale_type=TScaleType.T32)
|
| 54 |
+
self.full_compress = TernaryScaleTensor(MLA_FULL_DIM, MLA_HCA_DIM, tscale_type=TScaleType.T32)
|
| 55 |
+
|
| 56 |
+
self.gate = TernaryScaleTensor(dim, 1, tscale_type=TScaleType.T32)
|
| 57 |
+
|
| 58 |
+
self._freqs_cis = None
|
| 59 |
+
self._max_freq_len = 0
|
| 60 |
+
|
| 61 |
+
def _ensure_freqs(self, seq_len, device):
|
| 62 |
+
needed = max(seq_len, SLIDING_WINDOW_SIZE, KV_LEDGER_SIZE)
|
| 63 |
+
if self._freqs_cis is None or needed > self._max_freq_len:
|
| 64 |
+
self._max_freq_len = needed
|
| 65 |
+
self._freqs_cis = precompute_freqs_cis(
|
| 66 |
+
MLA_QK_ROPE_HEAD_DIM, needed, theta=MLA_ROPE_THETA
|
| 67 |
+
).to(device)
|
| 68 |
+
return self._freqs_cis
|
| 69 |
+
|
| 70 |
+
def forward(self, x, kv_ledger, full_ledger=None, kq_cache=None):
|
| 71 |
+
bsz, seqlen, _ = x.shape
|
| 72 |
+
device = x.device
|
| 73 |
+
freqs_cis = self._ensure_freqs(seqlen, device)
|
| 74 |
+
|
| 75 |
+
full_ledger = full_ledger or kv_ledger
|
| 76 |
+
|
| 77 |
+
window_size = min(SLIDING_WINDOW_SIZE, kv_ledger.size) if kv_ledger.size > 0 else 0
|
| 78 |
+
|
| 79 |
+
out_slide = x
|
| 80 |
+
if window_size > 0:
|
| 81 |
+
start = max(0, kv_ledger.size - SLIDING_WINDOW_SIZE)
|
| 82 |
+
end = kv_ledger.size
|
| 83 |
+
slide_ids = kv_ledger.get_range(start, end).float().unsqueeze(-1)
|
| 84 |
+
# Embed to kv_lora_rank, then CSA compress to csa_dim
|
| 85 |
+
slide_latent = self.slide_embed(slide_ids)
|
| 86 |
+
csa_cache = self.slide_compress(slide_latent)
|
| 87 |
+
pe_cache = torch.zeros(csa_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)
|
| 88 |
+
|
| 89 |
+
for layer in self.slide_layers:
|
| 90 |
+
out_slide = layer(out_slide, slide_latent, pe_cache,
|
| 91 |
+
start_pos=0, freqs_cis=freqs_cis, mask=None,
|
| 92 |
+
csa_cache=csa_cache)
|
| 93 |
+
|
| 94 |
+
out_full = x
|
| 95 |
+
if full_ledger.size > 0:
|
| 96 |
+
full = full_ledger.get_sparse(stride=MLA_HCA_STRIDE)
|
| 97 |
+
full_ids = full.float().unsqueeze(-1)
|
| 98 |
+
full_latent = self.full_embed(full_ids)
|
| 99 |
+
hca_cache = self.full_compress(full_latent)
|
| 100 |
+
pe_cache = torch.zeros(hca_cache.shape[0], MLA_QK_ROPE_HEAD_DIM, device=device)
|
| 101 |
+
|
| 102 |
+
for layer in self.full_layers:
|
| 103 |
+
out_full = layer(out_full, full_latent, pe_cache,
|
| 104 |
+
start_pos=0, freqs_cis=freqs_cis, mask=None,
|
| 105 |
+
hca_cache=hca_cache, hca_pe_cache=pe_cache)
|
| 106 |
+
|
| 107 |
+
gate = torch.sigmoid(self.gate(x.mean(dim=1, keepdim=True)))
|
| 108 |
+
out = gate * out_slide + (1 - gate) * out_full
|
| 109 |
+
return out
|
arbitor/attention/frame_buffer.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TemporalFrameBuffer β ring buffer for video latents with HCA compression.
|
| 2 |
+
|
| 3 |
+
Stores the last N video latents (local) and maintains a compressed long-range
|
| 4 |
+
cache via TernaryScaleTensor projection. Used for conditioning video generation
|
| 5 |
+
on previous time steps.
|
| 6 |
+
|
| 7 |
+
Latent shape: [B, C, H', W'] where C=OPEN_SORA_LATENT_CHANNELS=4,
|
| 8 |
+
H'=VIDEO_HEIGHT=32, W'=VIDEO_WIDTH=32. Each "latent" is one 4-frame chunk.
|
| 9 |
+
"""
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
|
| 13 |
+
from .ring_buffer import GPURingBuffer
|
| 14 |
+
from ..config import FRAME_BUFFER_LOCAL_SIZE, FRAME_BUFFER_CACHE_STRIDE, \
|
| 15 |
+
OPEN_SORA_LATENT_CHANNELS, VIDEO_HEIGHT, VIDEO_WIDTH
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TemporalFrameBuffer(nn.Module):
|
| 19 |
+
def __init__(self, local_size=FRAME_BUFFER_LOCAL_SIZE,
|
| 20 |
+
cache_stride=FRAME_BUFFER_CACHE_STRIDE,
|
| 21 |
+
latent_channels=OPEN_SORA_LATENT_CHANNELS,
|
| 22 |
+
height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
|
| 23 |
+
tscale_type=TScaleType.T32):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.latent_channels = latent_channels
|
| 26 |
+
self.spatial_dim = height * width
|
| 27 |
+
self.latent_flat_dim = latent_channels * self.spatial_dim
|
| 28 |
+
|
| 29 |
+
self.local = GPURingBuffer(
|
| 30 |
+
max_size=local_size,
|
| 31 |
+
dtype=torch.float32,
|
| 32 |
+
dim=self.latent_flat_dim,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
self.compress_proj = TernaryScaleTensor(
|
| 36 |
+
self.latent_flat_dim,
|
| 37 |
+
self.latent_flat_dim // 4,
|
| 38 |
+
tscale_type=tscale_type,
|
| 39 |
+
)
|
| 40 |
+
self.compressed_cache = []
|
| 41 |
+
self.cache_stride = cache_stride
|
| 42 |
+
self._frames_since_compress = 0
|
| 43 |
+
|
| 44 |
+
def append(self, latent):
|
| 45 |
+
B = latent.shape[0]
|
| 46 |
+
flat = latent.reshape(B, -1)
|
| 47 |
+
self.local.append(flat)
|
| 48 |
+
|
| 49 |
+
self._frames_since_compress += 1
|
| 50 |
+
if self._frames_since_compress >= self.cache_stride:
|
| 51 |
+
compressed = self.compress_proj(flat)
|
| 52 |
+
self.compressed_cache.append(compressed.detach())
|
| 53 |
+
self._frames_since_compress = 0
|
| 54 |
+
|
| 55 |
+
def get_local(self, n=None):
|
| 56 |
+
n = n or self.local.max_size
|
| 57 |
+
result = self.local.get_last_n(n)
|
| 58 |
+
if result.dim() == 0 or result.shape[0] == 0:
|
| 59 |
+
return torch.zeros(0, 1, self.latent_flat_dim)
|
| 60 |
+
if result.dim() == 1:
|
| 61 |
+
result = result.unsqueeze(0)
|
| 62 |
+
return result
|
| 63 |
+
|
| 64 |
+
def get_compressed_cache(self):
|
| 65 |
+
if not self.compressed_cache:
|
| 66 |
+
return torch.zeros(0, 1, self.latent_flat_dim // 4)
|
| 67 |
+
return torch.stack(self.compressed_cache, dim=0)
|
| 68 |
+
|
| 69 |
+
def get_conditioning(self, n_local=None):
|
| 70 |
+
return {
|
| 71 |
+
"local": self.get_local(n_local),
|
| 72 |
+
"compressed": self.get_compressed_cache(),
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
def reset(self):
|
| 76 |
+
self.local.reset()
|
| 77 |
+
self.compressed_cache = []
|
| 78 |
+
self._frames_since_compress = 0
|
arbitor/attention/kq_cache.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KQ Cache β small ring buffer of last 8K motif IDs for O(1) peek.
|
| 2 |
+
|
| 3 |
+
Per D-64: Small ring buffer holding last 8K motif IDs. No compression - just raw IDs.
|
| 4 |
+
O(1) peek for fast motif lookup without MemGram query.
|
| 5 |
+
|
| 6 |
+
Per D-65: Updated after each ByteHead output append to ledger.
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from ..config import KQ_CACHE_SIZE
|
| 11 |
+
from .ring_buffer import GPURingBuffer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class KQCache(nn.Module):
|
| 15 |
+
def __init__(self, max_size=KQ_CACHE_SIZE):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.ring = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)
|
| 18 |
+
|
| 19 |
+
def append(self, motif_id: int):
|
| 20 |
+
self.ring.append(torch.tensor(motif_id, dtype=torch.int32, device=self.ring.buffer.device))
|
| 21 |
+
|
| 22 |
+
def peek(self, n=1):
|
| 23 |
+
return self.ring.get_last_n(n)
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def size(self):
|
| 27 |
+
return self.ring.size
|
| 28 |
+
|
| 29 |
+
def reset(self):
|
| 30 |
+
self.ring.reset()
|
arbitor/attention/kv_ledger.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KV Ledger β append-only ring buffer of motif IDs (int32), max 256K entries.
|
| 2 |
+
|
| 3 |
+
Per D-57: Append-only ring buffer of motif IDs (int32), max 256K entries.
|
| 4 |
+
When full, oldest entries are overwritten. Stored as flat tensor on GPU.
|
| 5 |
+
|
| 6 |
+
Per D-59: The ledger stores only what the model outputs (motif IDs),
|
| 7 |
+
not input prompts. Prompts go through VQ -> GNN -> Motif pipeline first.
|
| 8 |
+
|
| 9 |
+
KV is consumed by the ContextAttentionScheduler. Its output is injected into
|
| 10 |
+
MoEGraph, which then conditions the router and output heads through the shared
|
| 11 |
+
processed relational state.
|
| 12 |
+
"""
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from ..config import KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE
|
| 16 |
+
from .ring_buffer import GPURingBuffer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class KVLedger(nn.Module):
|
| 20 |
+
def __init__(self, max_size=KV_LEDGER_SIZE):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.ring = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)
|
| 23 |
+
|
| 24 |
+
def append(self, motif_id: int):
|
| 25 |
+
self.ring.append(torch.tensor(motif_id, dtype=torch.int32, device=self.ring.buffer.device))
|
| 26 |
+
|
| 27 |
+
def get_sliding_window(self, n=SLIDING_WINDOW_SIZE):
|
| 28 |
+
return self.ring.get_last_n(n)
|
| 29 |
+
|
| 30 |
+
def get_range(self, start, end):
|
| 31 |
+
n = end - start
|
| 32 |
+
if n <= 0 or start >= self.ring.size:
|
| 33 |
+
return torch.zeros(0, dtype=torch.int32, device=self.ring.buffer.device)
|
| 34 |
+
if start + n <= self.ring.max_size:
|
| 35 |
+
return self.ring.buffer[start:start + n].squeeze(-1)
|
| 36 |
+
first = self.ring.buffer[start:].squeeze(-1)
|
| 37 |
+
second = self.ring.buffer[:n - (self.ring.max_size - start)].squeeze(-1)
|
| 38 |
+
return torch.cat([first, second])
|
| 39 |
+
|
| 40 |
+
def get_sparse(self, stride=8):
|
| 41 |
+
size = self.ring.size
|
| 42 |
+
if size == 0:
|
| 43 |
+
return torch.zeros(0, dtype=torch.int32, device=self.ring.buffer.device)
|
| 44 |
+
all_vals = self.ring.get_all()
|
| 45 |
+
indices = torch.arange(0, size, stride, device=self.ring.buffer.device, dtype=torch.long)
|
| 46 |
+
indices = indices[indices < len(all_vals)]
|
| 47 |
+
return all_vals[indices]
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def size(self):
|
| 51 |
+
return self.ring.size
|
| 52 |
+
|
| 53 |
+
def __len__(self):
|
| 54 |
+
return self.ring.size
|
| 55 |
+
|
| 56 |
+
def reset(self):
|
| 57 |
+
self.ring.reset()
|
arbitor/attention/mla.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-Head Latent Attention with CSA + HCA compression (DeepSeek V4 style).
|
| 2 |
+
|
| 3 |
+
Ternary-weighted. KV cache stores compressed latent at multiple levels:
|
| 4 |
+
- Base: MLA latent (d=kv_lora_rank, typically 64/32)
|
| 5 |
+
- CSA: Secondary compression (d_csa, e.g. 16) β 4x compression on cache
|
| 6 |
+
- HCA: Heavily compressed (d_hca, e.g. 8) β 8x compression, wider stride
|
| 7 |
+
|
| 8 |
+
Scores = q_nope_absorbed @ decompress(kv_cache) + q_pe @ pe_cache
|
| 9 |
+
"""
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from ..config import HIDDEN_DIM, MLA_CSA_DIM, MLA_HCA_DIM, MLA_HCA_STRIDE, MLA_N_LAYERS
|
| 14 |
+
from ..kernel.ternary_scale import TernaryScaleTensor, TernaryRMSNorm, TScaleType
|
| 15 |
+
|
| 16 |
+
MLA_N_HEADS = 32
|
| 17 |
+
MLA_QK_NOPE_HEAD_DIM = 96
|
| 18 |
+
MLA_QK_ROPE_HEAD_DIM = 32
|
| 19 |
+
MLA_V_HEAD_DIM = 96
|
| 20 |
+
MLA_ROPE_THETA = 10000.0
|
| 21 |
+
MLA_SLIDE_DIM = 64
|
| 22 |
+
MLA_FULL_DIM = 32
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def apply_rotary_emb(x, freqs_cis):
|
| 26 |
+
x_complex = torch.view_as_complex(
|
| 27 |
+
x.float().reshape(*x.shape[:-1], -1, 2)
|
| 28 |
+
)
|
| 29 |
+
freqs = freqs_cis.unsqueeze(1).unsqueeze(0)
|
| 30 |
+
return torch.view_as_real(x_complex * freqs).flatten(-2).to(x.dtype)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def precompute_freqs_cis(dim, end, theta=MLA_ROPE_THETA):
|
| 34 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 35 |
+
t = torch.arange(end, device=freqs.device)
|
| 36 |
+
freqs = torch.outer(t, freqs)
|
| 37 |
+
return torch.polar(torch.ones_like(freqs), freqs)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiHeadLatentAttention(nn.Module):
|
| 41 |
+
def __init__(self, dim=HIDDEN_DIM, n_heads=MLA_N_HEADS, kv_lora_rank=MLA_SLIDE_DIM,
|
| 42 |
+
qk_nope_head_dim=MLA_QK_NOPE_HEAD_DIM, qk_rope_head_dim=MLA_QK_ROPE_HEAD_DIM,
|
| 43 |
+
v_head_dim=MLA_V_HEAD_DIM, max_seq_len=65536,
|
| 44 |
+
csa_dim=MLA_CSA_DIM, hca_dim=MLA_HCA_DIM,
|
| 45 |
+
tscale_type=TScaleType.T32):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.n_heads = n_heads
|
| 49 |
+
self.kv_lora_rank = kv_lora_rank
|
| 50 |
+
self.qk_nope_head_dim = qk_nope_head_dim
|
| 51 |
+
self.qk_rope_head_dim = qk_rope_head_dim
|
| 52 |
+
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
| 53 |
+
self.v_head_dim = v_head_dim
|
| 54 |
+
self.softmax_scale = self.qk_head_dim ** -0.5
|
| 55 |
+
self.max_seq_len = max_seq_len
|
| 56 |
+
self.csa_dim = csa_dim
|
| 57 |
+
self.hca_dim = hca_dim
|
| 58 |
+
|
| 59 |
+
self.wq_norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
|
| 60 |
+
self.wq = TernaryScaleTensor(dim, n_heads * self.qk_head_dim, tscale_type=tscale_type)
|
| 61 |
+
|
| 62 |
+
combined_out = n_heads * (qk_nope_head_dim + v_head_dim)
|
| 63 |
+
self.wkv_b = TernaryScaleTensor(kv_lora_rank, combined_out, tscale_type=tscale_type)
|
| 64 |
+
self.wo = TernaryScaleTensor(n_heads * v_head_dim, dim, tscale_type=tscale_type)
|
| 65 |
+
|
| 66 |
+
# CSA: secondary compression (kv_lora_rank -> csa_dim)
|
| 67 |
+
if csa_dim and csa_dim < kv_lora_rank:
|
| 68 |
+
self.csa_compress = TernaryScaleTensor(kv_lora_rank, csa_dim, tscale_type=tscale_type)
|
| 69 |
+
self.csa_decompress = TernaryScaleTensor(csa_dim, kv_lora_rank, tscale_type=tscale_type)
|
| 70 |
+
else:
|
| 71 |
+
self.csa_compress = None
|
| 72 |
+
self.csa_decompress = None
|
| 73 |
+
|
| 74 |
+
# HCA: heavily compressed (kv_lora_rank -> hca_dim)
|
| 75 |
+
if hca_dim and hca_dim < (csa_dim or kv_lora_rank):
|
| 76 |
+
self.hca_compress = TernaryScaleTensor(kv_lora_rank, hca_dim, tscale_type=tscale_type)
|
| 77 |
+
self.hca_decompress = TernaryScaleTensor(hca_dim, kv_lora_rank, tscale_type=tscale_type)
|
| 78 |
+
else:
|
| 79 |
+
self.hca_compress = None
|
| 80 |
+
self.hca_decompress = None
|
| 81 |
+
|
| 82 |
+
def _compress(self, kv_cache, compress_proj):
|
| 83 |
+
"""Compress kv_cache from kv_lora_rank to smaller dim."""
|
| 84 |
+
return compress_proj(kv_cache)
|
| 85 |
+
|
| 86 |
+
def _decompress(self, cache, decompress_proj):
|
| 87 |
+
"""Decompress cache back to kv_lora_rank."""
|
| 88 |
+
return decompress_proj(cache)
|
| 89 |
+
|
| 90 |
+
def _compute_scores(self, q_nope_absorbed, q_pe, kv_flat, pe_flat,
|
| 91 |
+
start_pos, seqlen, mask):
|
| 92 |
+
"""Shared score computation for base, CSA, and HCA attention."""
|
| 93 |
+
n_keys = min(kv_flat.shape[0], pe_flat.shape[0])
|
| 94 |
+
kv_flat = kv_flat[:n_keys]
|
| 95 |
+
pe_flat = pe_flat[:n_keys]
|
| 96 |
+
if n_keys == 0:
|
| 97 |
+
return q_pe.new_zeros(q_pe.shape[0], seqlen, q_pe.shape[2], 0)
|
| 98 |
+
scores = (
|
| 99 |
+
torch.einsum("bshc,btc->bsht",
|
| 100 |
+
q_nope_absorbed, kv_flat.unsqueeze(0))
|
| 101 |
+
+ torch.einsum("bshr,btr->bsht",
|
| 102 |
+
q_pe, pe_flat.unsqueeze(0))
|
| 103 |
+
) * self.softmax_scale
|
| 104 |
+
|
| 105 |
+
if mask is not None:
|
| 106 |
+
scores = scores + mask.unsqueeze(0).unsqueeze(0)
|
| 107 |
+
if mask is None and seqlen > 1:
|
| 108 |
+
causal = torch.triu(
|
| 109 |
+
torch.full((seqlen, n_keys), float('-inf'), device=q_pe.device),
|
| 110 |
+
diagonal=1 + start_pos
|
| 111 |
+
)
|
| 112 |
+
scores = scores + causal.unsqueeze(0).unsqueeze(2)
|
| 113 |
+
return scores
|
| 114 |
+
|
| 115 |
+
def forward(self, x, kv_cache, pe_cache, start_pos=0, freqs_cis=None, mask=None,
|
| 116 |
+
csa_cache=None, hca_cache=None, hca_pe_cache=None):
|
| 117 |
+
bsz, seqlen, _ = x.size()
|
| 118 |
+
|
| 119 |
+
q = self.wq(self.wq_norm(x))
|
| 120 |
+
q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
|
| 121 |
+
q_nope, q_pe = torch.split(
|
| 122 |
+
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 123 |
+
|
| 124 |
+
if freqs_cis is not None:
|
| 125 |
+
q_pe = apply_rotary_emb(q_pe, freqs_cis[start_pos:start_pos + seqlen])
|
| 126 |
+
|
| 127 |
+
wkv_b = self.wkv_b._get_T() * self.wkv_b._get_S()
|
| 128 |
+
wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)
|
| 129 |
+
|
| 130 |
+
q_nope_absorbed = torch.einsum(
|
| 131 |
+
"bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
|
| 132 |
+
|
| 133 |
+
n_cache = min(kv_cache.shape[0], pe_cache.shape[0])
|
| 134 |
+
kv_flat = kv_cache[:n_cache]
|
| 135 |
+
pe_flat = pe_cache[:n_cache]
|
| 136 |
+
|
| 137 |
+
# Decompress CSA cache if provided (replaces base kv_cache)
|
| 138 |
+
if csa_cache is not None and self.csa_decompress is not None:
|
| 139 |
+
n_csa = min(csa_cache.shape[0], pe_flat.shape[0])
|
| 140 |
+
kv_flat = self._decompress(csa_cache[:n_csa], self.csa_decompress)
|
| 141 |
+
pe_flat = pe_flat[:n_csa]
|
| 142 |
+
|
| 143 |
+
# Base attention (exact, CSA-compressed if applicable)
|
| 144 |
+
scores = self._compute_scores(
|
| 145 |
+
q_nope_absorbed, q_pe, kv_flat, pe_flat,
|
| 146 |
+
start_pos, seqlen, mask,
|
| 147 |
+
)
|
| 148 |
+
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
| 149 |
+
|
| 150 |
+
attn_out = torch.einsum(
|
| 151 |
+
"bsht,btc->bshc", scores, kv_flat.unsqueeze(0))
|
| 152 |
+
|
| 153 |
+
# HCA long-range attention (heavily compressed, strided)
|
| 154 |
+
hca_out = None
|
| 155 |
+
if hca_cache is not None and self.hca_decompress is not None:
|
| 156 |
+
hca_kv = self._decompress(hca_cache, self.hca_decompress)
|
| 157 |
+
if hca_pe_cache is None:
|
| 158 |
+
hca_pe = pe_cache[::MLA_HCA_STRIDE]
|
| 159 |
+
else:
|
| 160 |
+
hca_pe = hca_pe_cache
|
| 161 |
+
n_hca = min(hca_kv.shape[0], hca_pe.shape[0])
|
| 162 |
+
hca_kv = hca_kv[:n_hca]
|
| 163 |
+
hca_pe = hca_pe[:n_hca]
|
| 164 |
+
hca_scores = self._compute_scores(
|
| 165 |
+
q_nope_absorbed, q_pe, hca_kv, hca_pe,
|
| 166 |
+
start_pos, seqlen, mask=None,
|
| 167 |
+
)
|
| 168 |
+
hca_scores = hca_scores.softmax(dim=-1, dtype=torch.float32)
|
| 169 |
+
hca_out = torch.einsum(
|
| 170 |
+
"bsht,btc->bshc", hca_scores, hca_kv.unsqueeze(0))
|
| 171 |
+
attn_out = attn_out + hca_out
|
| 172 |
+
|
| 173 |
+
attn_unproj = torch.einsum(
|
| 174 |
+
"bshc,hdc->bshd", attn_out, wkv_b[:, -self.v_head_dim:])
|
| 175 |
+
|
| 176 |
+
return self.wo(attn_unproj.flatten(2))
|
arbitor/attention/ring_buffer.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GPURingBuffer β generic GPU ring buffer utility.
|
| 2 |
+
|
| 3 |
+
O(1) append via circular pointer, chronological get_last_n with wrap handling.
|
| 4 |
+
All storage via register_buffer for device movement and state_dict serialization.
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GPURingBuffer(nn.Module):
|
| 11 |
+
def __init__(self, max_size: int, dtype: torch.dtype = torch.int32, dim: int = 1):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.max_size = max_size
|
| 14 |
+
self.ptr = 0
|
| 15 |
+
self.size = 0
|
| 16 |
+
buffer_shape = (max_size, dim if dim > 1 else 1)
|
| 17 |
+
self.register_buffer("buffer", torch.zeros(buffer_shape, dtype=dtype))
|
| 18 |
+
|
| 19 |
+
def append(self, x):
|
| 20 |
+
if not isinstance(x, torch.Tensor):
|
| 21 |
+
x = torch.tensor(x, dtype=self.buffer.dtype, device=self.buffer.device)
|
| 22 |
+
if self.buffer.dim() == 2 and x.dim() == 0:
|
| 23 |
+
x = x.view(1)
|
| 24 |
+
self.buffer[self.ptr] = x
|
| 25 |
+
self.ptr = (self.ptr + 1) % self.max_size
|
| 26 |
+
self.size = min(self.size + 1, self.max_size)
|
| 27 |
+
|
| 28 |
+
def get_last_n(self, n: int):
|
| 29 |
+
n = min(n, self.size)
|
| 30 |
+
if n == 0:
|
| 31 |
+
return torch.zeros(0, *self.buffer.shape[1:], dtype=self.buffer.dtype, device=self.buffer.device)
|
| 32 |
+
start = (self.ptr - n) % self.max_size
|
| 33 |
+
if start + n <= self.max_size:
|
| 34 |
+
result = self.buffer[start:start + n]
|
| 35 |
+
else:
|
| 36 |
+
first = self.buffer[start:]
|
| 37 |
+
second = self.buffer[:n - (self.max_size - start)]
|
| 38 |
+
result = torch.cat([first, second], dim=0)
|
| 39 |
+
if result.dim() > 1 and result.shape[1] == 1:
|
| 40 |
+
result = result.squeeze(-1)
|
| 41 |
+
return result
|
| 42 |
+
|
| 43 |
+
def get_all(self):
|
| 44 |
+
return self.get_last_n(self.size)
|
| 45 |
+
|
| 46 |
+
def reset(self):
|
| 47 |
+
self.buffer.zero_()
|
| 48 |
+
self.ptr = 0
|
| 49 |
+
self.size = 0
|
arbitor/components.py
ADDED
|
@@ -0,0 +1,1218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Components β core neural network modules for the ARB system."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _COMPONENT_CONTEXT, _HAS_TRITON
|
| 7 |
+
try:
|
| 8 |
+
from .kernel.ternary_scale import _TritonTernaryEmbedFn
|
| 9 |
+
except ImportError:
|
| 10 |
+
_TritonTernaryEmbedFn = None
|
| 11 |
+
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
|
| 12 |
+
from dataclasses import dataclass, field, fields
|
| 13 |
+
from math import ceil as _ceil, log2 as _log2
|
| 14 |
+
from transformers import AutoModel, AutoFeatureExtractor
|
| 15 |
+
from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, SPECIAL_VOCAB, CODEBOOK_DIM, CODEBOOK_SIZE, FFN_HIDDEN, CTX, THRESHOLD, KG_EMA_ALPHA, KG_REQUANT_EVERY, KG_TERNARY_THRESHOLD, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, KGVQ_DECAY, KGVQ_COMMITMENT_WEIGHT, KGVQ_DEAD_CODE_THRESHOLD, K_MAX_COMPOSITES, MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, MG_WORKSPACE_DIM, BYTEHEAD_ACT_MAX_ITERS, BYTEHEAD_ACT_HALT_CONSECUTIVE
|
| 16 |
+
|
| 17 |
+
_ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
|
| 18 |
+
|
| 19 |
+
from .sequencers import ByteEmbedding
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class LossWeights:
|
| 24 |
+
lm: float = 1.0
|
| 25 |
+
vq_commitment: float = 1.0
|
| 26 |
+
moe_aux: float = 1.0
|
| 27 |
+
graph_l1: float = 0.001
|
| 28 |
+
graph_ponder: float = 1.0
|
| 29 |
+
moe_ponder: float = 1.0
|
| 30 |
+
moegraph_ponder: float = 1.0
|
| 31 |
+
memgram_decay_reg: float = 0.01
|
| 32 |
+
composite_vq: float = 1.0
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class LossComponents:
|
| 37 |
+
lm: torch.Tensor = None
|
| 38 |
+
vq_commitment: torch.Tensor = None
|
| 39 |
+
moe_aux: torch.Tensor = None
|
| 40 |
+
graph_l1: torch.Tensor = None
|
| 41 |
+
graph_ponder: torch.Tensor = None
|
| 42 |
+
moe_ponder: torch.Tensor = None
|
| 43 |
+
moegraph_ponder: torch.Tensor = None
|
| 44 |
+
memgram_decay_reg: torch.Tensor = None
|
| 45 |
+
composite_vq: torch.Tensor = None
|
| 46 |
+
weights: LossWeights = field(default_factory=LossWeights)
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def total(self) -> torch.Tensor:
|
| 50 |
+
w = self.weights
|
| 51 |
+
loss = None
|
| 52 |
+
|
| 53 |
+
def add_component(current, weight, component):
|
| 54 |
+
if component is None:
|
| 55 |
+
return current
|
| 56 |
+
weighted = weight * component
|
| 57 |
+
return weighted if current is None else current + weighted
|
| 58 |
+
|
| 59 |
+
loss = add_component(loss, w.lm, self.lm)
|
| 60 |
+
loss = add_component(loss, w.vq_commitment, self.vq_commitment)
|
| 61 |
+
loss = add_component(loss, w.moe_aux, self.moe_aux)
|
| 62 |
+
loss = add_component(loss, w.graph_l1, self.graph_l1)
|
| 63 |
+
loss = add_component(loss, w.graph_ponder, self.graph_ponder)
|
| 64 |
+
loss = add_component(loss, w.moe_ponder, self.moe_ponder)
|
| 65 |
+
loss = add_component(loss, w.moegraph_ponder, self.moegraph_ponder)
|
| 66 |
+
loss = add_component(loss, w.memgram_decay_reg, self.memgram_decay_reg)
|
| 67 |
+
loss = add_component(loss, w.composite_vq, self.composite_vq)
|
| 68 |
+
if loss is None:
|
| 69 |
+
raise ValueError("LossComponents.total requested with no active loss tensors")
|
| 70 |
+
return loss
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def active_fields(self) -> list[tuple[str, torch.Tensor, float]]:
|
| 74 |
+
result = []
|
| 75 |
+
for field in fields(self):
|
| 76 |
+
name = field.name
|
| 77 |
+
if name == 'weights':
|
| 78 |
+
continue
|
| 79 |
+
tensor = getattr(self, name)
|
| 80 |
+
if tensor is not None:
|
| 81 |
+
weight = getattr(self.weights, name)
|
| 82 |
+
result.append((name, tensor, weight))
|
| 83 |
+
return result
|
| 84 |
+
|
| 85 |
+
def log(self, writer, step, prefix="loss"):
|
| 86 |
+
writer.add_scalar(f"{prefix}/total", self.total.item(), step)
|
| 87 |
+
if self.lm is not None:
|
| 88 |
+
writer.add_scalar(f"{prefix}/lm", self.lm.item(), step)
|
| 89 |
+
if self.vq_commitment is not None:
|
| 90 |
+
writer.add_scalar(f"{prefix}/vq_commitment", self.vq_commitment.item(), step)
|
| 91 |
+
if self.moe_aux is not None:
|
| 92 |
+
writer.add_scalar(f"{prefix}/moe_aux", self.moe_aux.item(), step)
|
| 93 |
+
if self.graph_l1 is not None:
|
| 94 |
+
writer.add_scalar(f"{prefix}/graph_l1", self.graph_l1.item(), step)
|
| 95 |
+
if self.graph_ponder is not None:
|
| 96 |
+
writer.add_scalar(f"{prefix}/graph_ponder", self.graph_ponder.item(), step)
|
| 97 |
+
if self.moe_ponder is not None:
|
| 98 |
+
writer.add_scalar(f"{prefix}/moe_ponder", self.moe_ponder.item(), step)
|
| 99 |
+
if self.moegraph_ponder is not None:
|
| 100 |
+
writer.add_scalar(f"{prefix}/moegraph_ponder", self.moegraph_ponder.item(), step)
|
| 101 |
+
if self.memgram_decay_reg is not None:
|
| 102 |
+
writer.add_scalar(f"{prefix}/memgram_decay_reg", self.memgram_decay_reg.item(), step)
|
| 103 |
+
if self.composite_vq is not None:
|
| 104 |
+
writer.add_scalar(f"{prefix}/composite_vq", self.composite_vq.item(), step)
|
| 105 |
+
|
| 106 |
+
def backward(self, retain_graph=False):
|
| 107 |
+
self.total.backward(retain_graph=retain_graph)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class StickyZoneSTE(torch.autograd.Function):
|
| 111 |
+
@staticmethod
|
| 112 |
+
def forward(ctx, w, threshold):
|
| 113 |
+
ctx.save_for_backward(w, torch.tensor(threshold))
|
| 114 |
+
return w.sign() * (w.abs() > threshold).to(w.dtype)
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def backward(ctx, grad_output):
|
| 118 |
+
w, threshold_tensor = ctx.saved_tensors
|
| 119 |
+
threshold = threshold_tensor.item()
|
| 120 |
+
ratio = torch.clamp(w.abs() / threshold, 0.0, 1.0)
|
| 121 |
+
return grad_output * ratio, None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class TernaryEmbeddingTable(nn.Module):
|
| 125 |
+
def __init__(self, num_embeddings, embedding_dim, tscale_type=TScaleType.T32,
|
| 126 |
+
init_std=0.02, threshold=0.05, normalize=False):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.num_embeddings = num_embeddings
|
| 129 |
+
self.embedding_dim = embedding_dim
|
| 130 |
+
self.tscale_type = tscale_type
|
| 131 |
+
init_threshold = min(float(threshold), 0.5 * float(init_std)) if init_std > 0 else threshold
|
| 132 |
+
self.threshold = init_threshold
|
| 133 |
+
self.normalize = normalize
|
| 134 |
+
self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
|
| 135 |
+
self.sparse_threshold = 65_536
|
| 136 |
+
|
| 137 |
+
if num_embeddings >= self.sparse_threshold:
|
| 138 |
+
n_trits = num_embeddings * embedding_dim
|
| 139 |
+
n_packed = _ceil_div(n_trits, 5)
|
| 140 |
+
packed_T = torch.randint(0, 243, (n_packed,), dtype=torch.uint8)
|
| 141 |
+
T_pad = n_packed * 5 - n_trits
|
| 142 |
+
gpr = _ceil_div(embedding_dim, self.group_size)
|
| 143 |
+
init_exp = int(round(_log2(max(init_std, 1e-8))))
|
| 144 |
+
self.register_buffer("T_packed", packed_T)
|
| 145 |
+
self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
|
| 146 |
+
self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
|
| 147 |
+
self.register_buffer(
|
| 148 |
+
"E",
|
| 149 |
+
torch.full((num_embeddings * gpr,), init_exp, dtype=torch.int8),
|
| 150 |
+
)
|
| 151 |
+
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
|
| 152 |
+
self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
|
| 153 |
+
self._ema_alpha: float = 0.1
|
| 154 |
+
self._loss_temp_scale: float = 1.0
|
| 155 |
+
return
|
| 156 |
+
|
| 157 |
+
w_init = torch.randn(num_embeddings, embedding_dim) * init_std
|
| 158 |
+
T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
|
| 159 |
+
packed_T, _, T_pad = pack_ternary(T_init)
|
| 160 |
+
self.register_buffer("T_packed", packed_T)
|
| 161 |
+
self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
|
| 162 |
+
self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
|
| 163 |
+
|
| 164 |
+
gpr = _ceil_div(embedding_dim, self.group_size)
|
| 165 |
+
total_in = gpr * self.group_size
|
| 166 |
+
padded = torch.zeros(num_embeddings, total_in)
|
| 167 |
+
padded[:, :embedding_dim] = w_init.abs()
|
| 168 |
+
grouped = padded.view(num_embeddings, gpr, self.group_size)
|
| 169 |
+
E_vals = torch.where(grouped.mean(dim=2) > 0, grouped.mean(dim=2), torch.ones(num_embeddings, gpr))
|
| 170 |
+
self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
|
| 171 |
+
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
|
| 172 |
+
self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
|
| 173 |
+
self._ema_alpha: float = 0.1
|
| 174 |
+
self._loss_temp_scale: float = 1.0
|
| 175 |
+
|
| 176 |
+
def _get_T(self):
|
| 177 |
+
return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
|
| 178 |
+
|
| 179 |
+
def _get_T_rows(self, indices):
|
| 180 |
+
indices = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
|
| 181 |
+
dim = self.embedding_dim
|
| 182 |
+
cols = torch.arange(dim, device=indices.device, dtype=torch.long)
|
| 183 |
+
lin = indices[:, None] * dim + cols[None, :]
|
| 184 |
+
pack_idx = lin // 5
|
| 185 |
+
trit_pos = lin - pack_idx * 5
|
| 186 |
+
packed = self.T_packed[pack_idx].to(torch.long)
|
| 187 |
+
divisors = torch.tensor([1, 3, 9, 27, 81], device=indices.device, dtype=torch.long)
|
| 188 |
+
code = (packed // divisors[trit_pos]) % 3
|
| 189 |
+
return (code.to(torch.int8) - 1)
|
| 190 |
+
|
| 191 |
+
def _expand_E_rows(self, indices):
|
| 192 |
+
indices = indices.reshape(-1).to(device=self.E.device, dtype=torch.long)
|
| 193 |
+
gpr = _ceil_div(self.embedding_dim, self.group_size)
|
| 194 |
+
E_rows = self.E.view(self.num_embeddings, gpr)[indices]
|
| 195 |
+
E_exp = E_rows.repeat_interleave(self.group_size, dim=1)
|
| 196 |
+
return E_exp[:, :self.embedding_dim]
|
| 197 |
+
|
| 198 |
+
@torch.no_grad()
|
| 199 |
+
def _set_T_rows(self, row_indices, rows):
|
| 200 |
+
row_indices = row_indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
|
| 201 |
+
rows = rows.to(device=self.T_packed.device, dtype=torch.int8).reshape(row_indices.numel(), self.embedding_dim)
|
| 202 |
+
divisors = [1, 3, 9, 27, 81]
|
| 203 |
+
for row_pos, row_idx in enumerate(row_indices.tolist()):
|
| 204 |
+
row = rows[row_pos]
|
| 205 |
+
for col in range(self.embedding_dim):
|
| 206 |
+
lin = row_idx * self.embedding_dim + col
|
| 207 |
+
pack_idx = lin // 5
|
| 208 |
+
trit_pos = lin - pack_idx * 5
|
| 209 |
+
divisor = divisors[trit_pos]
|
| 210 |
+
old = int(self.T_packed[pack_idx].item())
|
| 211 |
+
old_code = (old // divisor) % 3
|
| 212 |
+
new_code = int(row[col].item()) + 1
|
| 213 |
+
if old_code != new_code:
|
| 214 |
+
self.T_packed[pack_idx] = old - old_code * divisor + new_code * divisor
|
| 215 |
+
|
| 216 |
+
def _expand_E(self):
|
| 217 |
+
out_dim, in_dim = tuple(self._T_shape.tolist())
|
| 218 |
+
gpr = _ceil_div(in_dim, self.group_size)
|
| 219 |
+
E_2d = self.E.view(out_dim, gpr)
|
| 220 |
+
E_exp = E_2d.repeat_interleave(self.group_size, dim=1)
|
| 221 |
+
return E_exp[:, :in_dim]
|
| 222 |
+
|
| 223 |
+
def _ensure_E_accum(self):
|
| 224 |
+
if not hasattr(self, "E_accum"):
|
| 225 |
+
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
|
| 226 |
+
elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device:
|
| 227 |
+
self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
|
| 228 |
+
return self.E_accum
|
| 229 |
+
|
| 230 |
+
def forward(self, indices):
|
| 231 |
+
use_sparse = self.num_embeddings >= self.sparse_threshold
|
| 232 |
+
if use_sparse:
|
| 233 |
+
idx_flat = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
|
| 234 |
+
T_rows = self._get_T_rows(idx_flat)
|
| 235 |
+
E_exp = self._expand_E_rows(idx_flat)
|
| 236 |
+
w_eff = torch.exp2(E_exp.float()) * T_rows.float()
|
| 237 |
+
w_eff_grad = w_eff.detach().requires_grad_(torch.is_grad_enabled())
|
| 238 |
+
if torch.is_grad_enabled():
|
| 239 |
+
comp_name, _ = _COMPONENT_CONTEXT.get()
|
| 240 |
+
def capture_sparse_grad(grad):
|
| 241 |
+
suffix = f"_{comp_name}" if comp_name is not None else ""
|
| 242 |
+
setattr(self, f"_hook_sparse_indices{suffix}", idx_flat.detach())
|
| 243 |
+
setattr(self, f"_hook_sparse_grad_sign{suffix}", grad.reshape(-1, self.embedding_dim).sign().to(torch.int8).detach())
|
| 244 |
+
setattr(self, f"_hook_sparse_T{suffix}", T_rows.detach())
|
| 245 |
+
w_eff_grad.register_hook(capture_sparse_grad)
|
| 246 |
+
out = w_eff_grad.reshape(*indices.shape, self.embedding_dim)
|
| 247 |
+
return F.normalize(out, dim=-1) if self.normalize else out
|
| 248 |
+
if indices.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
|
| 249 |
+
dummy = torch.zeros(1, device=indices.device, requires_grad=True)
|
| 250 |
+
out = _TritonTernaryEmbedFn.apply(indices, dummy, self)
|
| 251 |
+
else:
|
| 252 |
+
T = self._get_T()
|
| 253 |
+
w_eff = torch.exp2(self._expand_E().float()) * T.float()
|
| 254 |
+
w_eff_grad = w_eff.detach().requires_grad_(True)
|
| 255 |
+
self._hook_T = T
|
| 256 |
+
def capture_w_grad(grad_w):
|
| 257 |
+
self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
|
| 258 |
+
w_eff_grad.register_hook(capture_w_grad)
|
| 259 |
+
out = F.embedding(indices, w_eff_grad)
|
| 260 |
+
return F.normalize(out, dim=-1) if self.normalize else out
|
| 261 |
+
|
| 262 |
+
def ternary_step(self, accum_threshold=3):
|
| 263 |
+
if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
|
| 264 |
+
return self._sparse_ternary_step(accum_threshold=accum_threshold)
|
| 265 |
+
if hasattr(self, "_hook_grad_T_sign"):
|
| 266 |
+
if hasattr(self, "_accumulate_corr_from_grad_sign"):
|
| 267 |
+
self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
|
| 268 |
+
del self._hook_grad_T_sign
|
| 269 |
+
|
| 270 |
+
def update_E(self, loss_signal=None):
|
| 271 |
+
if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
|
| 272 |
+
return self._sparse_update_E(loss_signal=loss_signal)
|
| 273 |
+
|
| 274 |
+
@torch.no_grad()
|
| 275 |
+
def _sparse_ternary_step(self, accum_threshold=3):
|
| 276 |
+
indices = self._hook_sparse_indices.to(device=self.T_accum.device, dtype=torch.long)
|
| 277 |
+
grad_sign = self._hook_sparse_grad_sign.to(device=self.T_accum.device, dtype=torch.int16)
|
| 278 |
+
if indices.numel() == 0:
|
| 279 |
+
return
|
| 280 |
+
unique, inverse = torch.unique(indices, return_inverse=True)
|
| 281 |
+
grad_sum = torch.zeros(unique.numel(), self.embedding_dim, device=self.T_accum.device, dtype=torch.int16)
|
| 282 |
+
grad_sum.index_add_(0, inverse, grad_sign)
|
| 283 |
+
grad_step = grad_sum.sign().to(torch.int16) * int(getattr(self, "_t_accum_step", 1))
|
| 284 |
+
current = self.T_accum[unique].to(torch.int16)
|
| 285 |
+
updated = torch.clamp(current - grad_step, -128, 127).to(torch.int8)
|
| 286 |
+
|
| 287 |
+
pgt = getattr(self, "per_group_threshold", None)
|
| 288 |
+
if pgt is not None:
|
| 289 |
+
gpr = _ceil_div(self.embedding_dim, self.group_size)
|
| 290 |
+
threshold = pgt.view(self.num_embeddings, gpr)[unique]
|
| 291 |
+
threshold = threshold.unsqueeze(-1).expand(unique.numel(), gpr, self.group_size)
|
| 292 |
+
threshold = threshold.reshape(unique.numel(), gpr * self.group_size)[:, :self.embedding_dim]
|
| 293 |
+
threshold = threshold.to(updated.device)
|
| 294 |
+
flip_up = updated > threshold
|
| 295 |
+
flip_down = updated < -threshold
|
| 296 |
+
else:
|
| 297 |
+
flip_up = updated > accum_threshold
|
| 298 |
+
flip_down = updated < -accum_threshold
|
| 299 |
+
self._had_flip = bool((flip_up | flip_down).any().item())
|
| 300 |
+
if self._had_flip:
|
| 301 |
+
rows = self._get_T_rows(unique).to(updated.device)
|
| 302 |
+
rows = torch.where(flip_up, torch.ones_like(rows), torch.where(flip_down, -torch.ones_like(rows), rows))
|
| 303 |
+
self._set_T_rows(unique, rows)
|
| 304 |
+
updated = torch.where(flip_up | flip_down, torch.zeros_like(updated), updated)
|
| 305 |
+
self.T_accum[unique] = updated
|
| 306 |
+
del self._hook_sparse_indices
|
| 307 |
+
del self._hook_sparse_grad_sign
|
| 308 |
+
if hasattr(self, "_hook_sparse_T"):
|
| 309 |
+
del self._hook_sparse_T
|
| 310 |
+
|
| 311 |
+
@torch.no_grad()
|
| 312 |
+
def _sparse_update_E(self, loss_signal=None):
|
| 313 |
+
indices = self._hook_sparse_indices.to(device=self.E.device, dtype=torch.long)
|
| 314 |
+
grad_sign = self._hook_sparse_grad_sign.to(device=self.E.device, dtype=torch.int16)
|
| 315 |
+
T_rows = self._hook_sparse_T if hasattr(self, "_hook_sparse_T") else self._get_T_rows(indices)
|
| 316 |
+
T_rows = T_rows.to(device=self.E.device, dtype=torch.int16)
|
| 317 |
+
if indices.numel() == 0:
|
| 318 |
+
return
|
| 319 |
+
unique, inverse = torch.unique(indices, return_inverse=True)
|
| 320 |
+
gpr = _ceil_div(self.embedding_dim, self.group_size)
|
| 321 |
+
total_in = gpr * self.group_size
|
| 322 |
+
signed = grad_sign * T_rows
|
| 323 |
+
grouped = F.pad(signed, (0, total_in - self.embedding_dim)).view(indices.numel(), gpr, self.group_size)
|
| 324 |
+
score = grouped.sum(dim=2)
|
| 325 |
+
delta = torch.where(
|
| 326 |
+
score > 0,
|
| 327 |
+
torch.full_like(score, -1, dtype=torch.int16),
|
| 328 |
+
torch.where(score < 0, torch.ones_like(score, dtype=torch.int16), torch.zeros_like(score, dtype=torch.int16)),
|
| 329 |
+
)
|
| 330 |
+
delta_sum = torch.zeros(unique.numel(), gpr, device=self.E.device, dtype=torch.int16)
|
| 331 |
+
delta_sum.index_add_(0, inverse, delta)
|
| 332 |
+
delta_sign = delta_sum.sign()
|
| 333 |
+
e_idx = unique[:, None] * gpr + torch.arange(gpr, device=self.E.device, dtype=torch.long)[None, :]
|
| 334 |
+
accum = torch.clamp(self.E_accum[e_idx].to(torch.int16) + delta_sign, -128, 127)
|
| 335 |
+
threshold = int(getattr(self, "_e_accum_threshold", 4))
|
| 336 |
+
step = torch.where(
|
| 337 |
+
accum >= threshold,
|
| 338 |
+
torch.ones_like(accum, dtype=torch.int16),
|
| 339 |
+
torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
|
| 340 |
+
)
|
| 341 |
+
self.E[e_idx] = torch.clamp(self.E[e_idx].to(torch.int16) + step, -128, 127).to(torch.int8)
|
| 342 |
+
self.E_accum[e_idx] = (accum - step * threshold).to(torch.int8)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class TernaryVQCodebook(nn.Module):
|
| 347 |
+
def __init__(self, codebook_size, codebook_dim, commitment_weight=1.0,
|
| 348 |
+
tscale_type=TScaleType.T32, exact_lookup_max=16384,
|
| 349 |
+
candidate_count=256):
|
| 350 |
+
super().__init__()
|
| 351 |
+
self.codebook_size = codebook_size
|
| 352 |
+
self.codebook_dim = codebook_dim
|
| 353 |
+
self.commitment_weight = commitment_weight
|
| 354 |
+
self.exact_lookup_max = exact_lookup_max
|
| 355 |
+
self.candidate_count = candidate_count
|
| 356 |
+
self.threshold_ema_dead_code = 2
|
| 357 |
+
self.table = TernaryEmbeddingTable(codebook_size, codebook_dim, tscale_type=tscale_type, normalize=True)
|
| 358 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size, dtype=torch.int16))
|
| 359 |
+
|
| 360 |
+
@property
|
| 361 |
+
def embed(self):
|
| 362 |
+
idx = torch.arange(self.codebook_size, device=self.table.T_packed.device)
|
| 363 |
+
return self.table(idx)
|
| 364 |
+
|
| 365 |
+
def _candidate_ids(self, flat):
|
| 366 |
+
c = min(self.candidate_count, self.codebook_size)
|
| 367 |
+
take = min(flat.shape[1], 16)
|
| 368 |
+
primes = torch.tensor(
|
| 369 |
+
[1009, 9176, 6361, 5333, 4447, 3469, 2531, 1613,
|
| 370 |
+
811, 421, 211, 109, 59, 31, 17, 7],
|
| 371 |
+
device=flat.device, dtype=torch.float32,
|
| 372 |
+
)[:take]
|
| 373 |
+
signed = torch.sign(flat[:, :take].float())
|
| 374 |
+
base = torch.abs(torch.round((signed * primes).sum(dim=1) * 104729)).to(torch.long)
|
| 375 |
+
offsets = torch.arange(c, device=flat.device, dtype=torch.long)
|
| 376 |
+
stride = 2_654_435_761
|
| 377 |
+
return (base[:, None] + offsets[None, :] * stride) % self.codebook_size
|
| 378 |
+
|
| 379 |
+
def _lookup(self, flat):
|
| 380 |
+
if self.codebook_size <= self.exact_lookup_max:
|
| 381 |
+
x_norm = F.normalize(flat.float(), dim=-1)
|
| 382 |
+
codebook = self.embed.to(device=flat.device)
|
| 383 |
+
sim = x_norm @ codebook.T
|
| 384 |
+
indices = sim.argmax(dim=-1)
|
| 385 |
+
quantized = codebook[indices]
|
| 386 |
+
return quantized, indices
|
| 387 |
+
|
| 388 |
+
candidate_ids = self._candidate_ids(flat)
|
| 389 |
+
x_norm = F.normalize(flat.float(), dim=-1)
|
| 390 |
+
n, c, d = flat.shape[0], candidate_ids.shape[1], flat.shape[1]
|
| 391 |
+
chunk = 64
|
| 392 |
+
quantized = torch.empty_like(flat)
|
| 393 |
+
indices = torch.empty(n, dtype=torch.long, device=flat.device)
|
| 394 |
+
for start in range(0, n, chunk):
|
| 395 |
+
end = min(start + chunk, n)
|
| 396 |
+
chunk_ids = candidate_ids[start:end]
|
| 397 |
+
chunk_vecs = self.table(chunk_ids).float()
|
| 398 |
+
chunk_norm = F.normalize(chunk_vecs, dim=-1)
|
| 399 |
+
chunk_sim = (chunk_norm * x_norm[start:end].unsqueeze(1)).sum(dim=-1)
|
| 400 |
+
chunk_best = chunk_sim.argmax(dim=-1)
|
| 401 |
+
indices[start:end] = candidate_ids[start:end].gather(1, chunk_best.unsqueeze(1)).squeeze(1)
|
| 402 |
+
quantized[start:end] = chunk_vecs[torch.arange(end - start, device=flat.device), chunk_best]
|
| 403 |
+
return quantized, indices
|
| 404 |
+
|
| 405 |
+
def forward(self, x):
|
| 406 |
+
orig_shape = x.shape
|
| 407 |
+
flat = x.reshape(-1, self.codebook_dim)
|
| 408 |
+
quantized, indices = self._lookup(flat)
|
| 409 |
+
commitment = self.commitment_weight * (
|
| 410 |
+
F.mse_loss(flat.float(), quantized.detach().float())
|
| 411 |
+
+ 0.25 * F.mse_loss(quantized.float(), flat.detach().float())
|
| 412 |
+
)
|
| 413 |
+
quantized = flat + (quantized - flat).detach()
|
| 414 |
+
with torch.no_grad():
|
| 415 |
+
unique, counts = torch.unique(indices, return_counts=True)
|
| 416 |
+
current = self.cluster_size[unique].to(torch.int32)
|
| 417 |
+
updated = torch.clamp(current + counts.to(device=current.device, dtype=torch.int32), 0, 32767).to(torch.int16)
|
| 418 |
+
self.cluster_size[unique] = updated
|
| 419 |
+
return quantized.reshape(orig_shape), indices.reshape(orig_shape[:-1]), commitment
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class GNNLoRAAdapter(nn.Module):
|
| 423 |
+
def __init__(self, dim, rank=32, max_hops=4):
|
| 424 |
+
super().__init__()
|
| 425 |
+
self.max_hops = max_hops
|
| 426 |
+
self.down = TernaryScaleTensor(dim, rank, tscale_type=TScaleType.T32)
|
| 427 |
+
self.up = TernaryScaleTensor(rank, dim, tscale_type=TScaleType.T32)
|
| 428 |
+
self.scale = TernaryEmbeddingTable(max_hops, rank, tscale_type=TScaleType.T32)
|
| 429 |
+
|
| 430 |
+
def forward(self, x, hop_t):
|
| 431 |
+
t_idx = min(hop_t, self.max_hops - 1)
|
| 432 |
+
s = self.scale(torch.tensor(t_idx, device=x.device))
|
| 433 |
+
return self.up(self.down(x) * s)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class HaltingUnit(nn.Module):
|
| 437 |
+
def __init__(self, dim, tscale_type=TScaleType.T32):
|
| 438 |
+
super().__init__()
|
| 439 |
+
self.proj = TernaryScaleTensor(dim, 1, tscale_type=tscale_type)
|
| 440 |
+
self.norm = TernaryRMSNorm(dim, tscale_type=tscale_type)
|
| 441 |
+
|
| 442 |
+
def forward(self, x):
|
| 443 |
+
return torch.sigmoid(self.proj(self.norm(x)))
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class _NgramHashMapping:
|
| 447 |
+
"""N-gram hash mapping with CPU offloading (Spider Engram style).
|
| 448 |
+
|
| 449 |
+
Hashes token sequences to fixed-size embedding indices. Hash computation
|
| 450 |
+
runs on CPU via numpy, O(1) per token via precomputed multipliers.
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
def __init__(self, max_ngram_size, num_heads, table_size_base, layer_seed=0):
|
| 454 |
+
self.max_ngram_size = max_ngram_size
|
| 455 |
+
self.num_heads = num_heads
|
| 456 |
+
self.num_ngram_orders = max_ngram_size - 1
|
| 457 |
+
|
| 458 |
+
import numpy as np
|
| 459 |
+
PRIME_1 = 10007
|
| 460 |
+
g = torch.Generator()
|
| 461 |
+
g.manual_seed(int(layer_seed + PRIME_1 * int(layer_seed)))
|
| 462 |
+
r = torch.randint(0, 1 << 30, (max_ngram_size,), generator=g, dtype=torch.int64)
|
| 463 |
+
self.multipliers = r.numpy() * 2 + 1
|
| 464 |
+
|
| 465 |
+
seen_primes = set()
|
| 466 |
+
self.prime_table_sizes = []
|
| 467 |
+
for _ in range(self.num_ngram_orders):
|
| 468 |
+
head_sizes = []
|
| 469 |
+
ps = table_size_base - 1
|
| 470 |
+
for _ in range(num_heads):
|
| 471 |
+
p = self._next_prime(ps, seen_primes)
|
| 472 |
+
seen_primes.add(p)
|
| 473 |
+
head_sizes.append(p)
|
| 474 |
+
ps = p
|
| 475 |
+
self.prime_table_sizes.append(head_sizes)
|
| 476 |
+
|
| 477 |
+
self.all_head_sizes = [s for sub in self.prime_table_sizes for s in sub]
|
| 478 |
+
offsets = [0]
|
| 479 |
+
for s in self.all_head_sizes[:-1]:
|
| 480 |
+
offsets.append(offsets[-1] + s)
|
| 481 |
+
self.offsets_arr = offsets
|
| 482 |
+
self.total_slots = sum(self.all_head_sizes)
|
| 483 |
+
|
| 484 |
+
@staticmethod
|
| 485 |
+
def _next_prime(n, seen):
|
| 486 |
+
while n in seen or not _is_prime(n):
|
| 487 |
+
n -= 1
|
| 488 |
+
return n
|
| 489 |
+
|
| 490 |
+
def compute_hashes(self, token_ids):
|
| 491 |
+
import numpy as np
|
| 492 |
+
x = token_ids.cpu().numpy().astype(np.int64)
|
| 493 |
+
B, T = x.shape
|
| 494 |
+
|
| 495 |
+
shifts = [x]
|
| 496 |
+
for k in range(1, self.max_ngram_size):
|
| 497 |
+
shifts.append(np.pad(x, ((0, 0), (k, 0)), constant_values=0)[:, :T])
|
| 498 |
+
|
| 499 |
+
all_hashes = []
|
| 500 |
+
for order_idx in range(self.num_ngram_orders):
|
| 501 |
+
n = order_idx + 2
|
| 502 |
+
mix = shifts[0] * self.multipliers[0]
|
| 503 |
+
for k in range(1, n):
|
| 504 |
+
mix = np.bitwise_xor(mix, shifts[k].astype(np.int64) * self.multipliers[k])
|
| 505 |
+
for j, ms in enumerate(self.prime_table_sizes[order_idx]):
|
| 506 |
+
all_hashes.append((mix % ms).astype(np.int64, copy=False))
|
| 507 |
+
|
| 508 |
+
result = np.stack(all_hashes, axis=2)
|
| 509 |
+
return torch.from_numpy(result).to(device=token_ids.device)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def _is_prime(n):
|
| 513 |
+
if n < 2:
|
| 514 |
+
return False
|
| 515 |
+
import math
|
| 516 |
+
for i in range(2, int(math.sqrt(n)) + 1):
|
| 517 |
+
if n % i == 0:
|
| 518 |
+
return False
|
| 519 |
+
return True
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class MemGram(nn.Module):
|
| 523 |
+
"""Engram-style associative memory with O(1) hashed lookup (CPU offloaded).
|
| 524 |
+
|
| 525 |
+
Features:
|
| 526 |
+
- O(1) hash -> index -> embedding lookup (no search, no decay for retrieval)
|
| 527 |
+
- CPU-offloaded hash computation (numpy)
|
| 528 |
+
- Single offset-stacked embedding table (not per-head tables)
|
| 529 |
+
- Gated retrieval: sigmoid(Q*K/sqrt(d)) gates the memory read
|
| 530 |
+
- Depthwise conv1d processes retrieved memory (Engram-style)
|
| 531 |
+
- No strength/decay buffers (decay is handled by GraphMoE usage frequency)
|
| 532 |
+
- MemGram lookups do NOT affect KG decaying (separate mechanisms)
|
| 533 |
+
"""
|
| 534 |
+
|
| 535 |
+
def __init__(self, struct_primes=[64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997,
|
| 536 |
+
65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071],
|
| 537 |
+
conv_primes=[8009, 8011, 8017, 8039],
|
| 538 |
+
embed_dim=64, hidden_dim=HIDDEN_DIM, key_dim=32,
|
| 539 |
+
max_ngram_size=3, num_hash_heads=4, layer_seed=0):
|
| 540 |
+
super().__init__()
|
| 541 |
+
self.embed_dim = embed_dim
|
| 542 |
+
self.key_dim = key_dim
|
| 543 |
+
self.hidden_dim = hidden_dim
|
| 544 |
+
self.n_struct_heads = len(struct_primes)
|
| 545 |
+
self.n_conv_heads = len(conv_primes)
|
| 546 |
+
|
| 547 |
+
self.struct_hash = _NgramHashMapping(
|
| 548 |
+
max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
|
| 549 |
+
table_size_base=struct_primes[0], layer_seed=layer_seed,
|
| 550 |
+
)
|
| 551 |
+
self.conv_hash = _NgramHashMapping(
|
| 552 |
+
max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
|
| 553 |
+
table_size_base=conv_primes[0], layer_seed=layer_seed + 1000,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
total_heads = self.struct_hash.num_ngram_orders * num_hash_heads
|
| 557 |
+
self.total_mem_dim = total_heads * embed_dim
|
| 558 |
+
|
| 559 |
+
total_slots = self.struct_hash.total_slots + self.conv_hash.total_slots
|
| 560 |
+
self.mem_embed = nn.Embedding(total_slots, embed_dim)
|
| 561 |
+
|
| 562 |
+
self.k_proj = nn.Linear(self.total_mem_dim, key_dim, bias=False)
|
| 563 |
+
self.q_proj = nn.Linear(hidden_dim, key_dim, bias=False)
|
| 564 |
+
self.v_proj = nn.Linear(self.total_mem_dim, hidden_dim, bias=False)
|
| 565 |
+
|
| 566 |
+
with torch.no_grad():
|
| 567 |
+
self.v_proj.weight.zero_()
|
| 568 |
+
|
| 569 |
+
self.conv_norm = nn.RMSNorm(hidden_dim)
|
| 570 |
+
self.conv = nn.Conv1d(
|
| 571 |
+
hidden_dim, hidden_dim,
|
| 572 |
+
kernel_size=4, padding=9, dilation=3, groups=hidden_dim,
|
| 573 |
+
)
|
| 574 |
+
with torch.no_grad():
|
| 575 |
+
self.conv.weight.zero_()
|
| 576 |
+
if self.conv.bias is not None:
|
| 577 |
+
self.conv.bias.zero_()
|
| 578 |
+
|
| 579 |
+
def _retrieve(self, token_ids, hash_mapping):
|
| 580 |
+
hash_ids = hash_mapping.compute_hashes(token_ids)
|
| 581 |
+
B, T, H = hash_ids.shape
|
| 582 |
+
flat_ids = hash_ids.reshape(B * T, H)
|
| 583 |
+
offsets = torch.tensor(hash_mapping.offsets_arr, device=flat_ids.device, dtype=torch.long)
|
| 584 |
+
emb = self.mem_embed(flat_ids + offsets)
|
| 585 |
+
return emb.reshape(B, T, H * self.embed_dim)
|
| 586 |
+
|
| 587 |
+
def forward(self, vq_indices, hidden_state):
|
| 588 |
+
B, T, D = hidden_state.shape
|
| 589 |
+
|
| 590 |
+
struct_mem = self._retrieve(vq_indices[:, 1:], self.struct_hash)
|
| 591 |
+
conv_mem = self._retrieve(vq_indices[:, 1:], self.conv_hash)
|
| 592 |
+
mem = struct_mem + conv_mem
|
| 593 |
+
|
| 594 |
+
idx_end = mem.shape[1]
|
| 595 |
+
q_proj = self.q_proj(hidden_state[:, :idx_end])
|
| 596 |
+
k = self.k_proj(mem)
|
| 597 |
+
v = self.v_proj(mem)
|
| 598 |
+
gate = torch.sigmoid((q_proj * k).sum(dim=-1, keepdim=True) / (self.key_dim ** 0.5))
|
| 599 |
+
v_gated = gate * v
|
| 600 |
+
|
| 601 |
+
v_normed = self.conv_norm(v_gated)
|
| 602 |
+
v_t = v_normed.transpose(1, 2)
|
| 603 |
+
conv_out = self.conv(v_t)
|
| 604 |
+
conv_out = conv_out[:, :, :v_t.shape[-1]].transpose(1, 2)
|
| 605 |
+
output = hidden_state[:, :idx_end] + F.silu(conv_out) + v_gated
|
| 606 |
+
|
| 607 |
+
if idx_end < T:
|
| 608 |
+
output = F.pad(output, (0, 0, 0, T - idx_end))
|
| 609 |
+
return output
|
| 610 |
+
|
| 611 |
+
def retrieve_cb(self, vq_indices):
|
| 612 |
+
B, T = vq_indices.shape
|
| 613 |
+
struct_mem = self._retrieve(vq_indices[:, 1:], self.struct_hash)
|
| 614 |
+
conv_mem = self._retrieve(vq_indices[:, 1:], self.conv_hash)
|
| 615 |
+
mem = struct_mem + conv_mem
|
| 616 |
+
idx_end = mem.shape[1]
|
| 617 |
+
pad = torch.zeros(B, T - idx_end, mem.shape[2], device=mem.device)
|
| 618 |
+
mem = torch.cat([mem, pad], dim=1)
|
| 619 |
+
q = mem.mean(dim=-1, keepdim=True)
|
| 620 |
+
gate = torch.sigmoid(q)
|
| 621 |
+
return gate * mem
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
_BOUNDARY_TOKEN_MAP = {
|
| 625 |
+
SPECIAL_VOCAB['BOS']: 0,
|
| 626 |
+
SPECIAL_VOCAB['SYSTEM']: 1,
|
| 627 |
+
SPECIAL_VOCAB['USER']: 2,
|
| 628 |
+
SPECIAL_VOCAB['ASSISTANT']: 3,
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class LTIInjection(nn.Module):
|
| 633 |
+
"""LTI state injection: h = A*h + B*e + trans_out.
|
| 634 |
+
|
| 635 |
+
Spectral radius < 1 guaranteed by construction via ZOH discretization.
|
| 636 |
+
Prevents divergence in recurrent/ACT loops at high dimensions.
|
| 637 |
+
"""
|
| 638 |
+
def __init__(self, dim: int):
|
| 639 |
+
super().__init__()
|
| 640 |
+
self.log_A = nn.Parameter(torch.zeros(dim))
|
| 641 |
+
self.log_dt = nn.Parameter(torch.zeros(1))
|
| 642 |
+
self.B = nn.Parameter(torch.ones(dim) * 0.1)
|
| 643 |
+
for p in (self.log_A, self.log_dt, self.B):
|
| 644 |
+
p.requires_grad_(False)
|
| 645 |
+
|
| 646 |
+
def get_A(self):
|
| 647 |
+
return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))
|
| 648 |
+
|
| 649 |
+
def forward(self, h, e, trans_out):
|
| 650 |
+
return self.get_A() * h + self.B * e + trans_out
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
class ByteHead(nn.Module):
|
| 654 |
+
"""Deep 3-layer MLP byte prediction head with ACT loop.
|
| 655 |
+
|
| 656 |
+
Architecture: 8192 β 16384 β 8192 β 16384 β 288
|
| 657 |
+
ACT: up to 3 iterations, halts when argmax stable for 2 consecutive steps.
|
| 658 |
+
"""
|
| 659 |
+
def __init__(self, tscale_type=TScaleType.T32,
|
| 660 |
+
act_max_iters=BYTEHEAD_ACT_MAX_ITERS,
|
| 661 |
+
act_halt_consecutive=BYTEHEAD_ACT_HALT_CONSECUTIVE):
|
| 662 |
+
super().__init__()
|
| 663 |
+
H = HIDDEN_DIM
|
| 664 |
+
W = HIDDEN_DIM * 2
|
| 665 |
+
self.act_max_iters = act_max_iters
|
| 666 |
+
self.act_halt_consecutive = act_halt_consecutive
|
| 667 |
+
self._last_ponder = 0.0
|
| 668 |
+
|
| 669 |
+
self.norm = TernaryRMSNorm(H, tscale_type=tscale_type)
|
| 670 |
+
self.up = TernaryScaleTensor(H, W, tscale_type=tscale_type)
|
| 671 |
+
self.up_norm = TernaryRMSNorm(W, tscale_type=tscale_type)
|
| 672 |
+
self.hidden = TernaryScaleTensor(W, H, tscale_type=tscale_type)
|
| 673 |
+
self.hidden_norm = TernaryRMSNorm(H, tscale_type=tscale_type)
|
| 674 |
+
self.out = TernaryScaleTensor(H, W, tscale_type=tscale_type)
|
| 675 |
+
self.out_norm = TernaryRMSNorm(W, tscale_type=tscale_type)
|
| 676 |
+
self.head = TernaryScaleTensor(W, VOCAB, tscale_type=tscale_type)
|
| 677 |
+
|
| 678 |
+
if act_max_iters > 1:
|
| 679 |
+
self.act_residual = TernaryScaleTensor(VOCAB, H, tscale_type=tscale_type)
|
| 680 |
+
self.lti = LTIInjection(H)
|
| 681 |
+
else:
|
| 682 |
+
self.act_residual = None
|
| 683 |
+
self.lti = None
|
| 684 |
+
|
| 685 |
+
def forward(self, x):
|
| 686 |
+
if self.act_max_iters <= 1 or self.act_residual is None:
|
| 687 |
+
hn = F.silu(self.up(self.norm(x)))
|
| 688 |
+
hn = F.silu(self.hidden(self.up_norm(hn)))
|
| 689 |
+
hn = F.silu(self.out(self.hidden_norm(hn)))
|
| 690 |
+
return self.head(self.out_norm(hn))
|
| 691 |
+
|
| 692 |
+
h = x
|
| 693 |
+
x_initial = x
|
| 694 |
+
prev_argmax = None
|
| 695 |
+
stable_count = 0
|
| 696 |
+
total_iters = 0
|
| 697 |
+
|
| 698 |
+
for i in range(self.act_max_iters):
|
| 699 |
+
hn = F.silu(self.up(self.norm(h)))
|
| 700 |
+
hn = F.silu(self.hidden(self.up_norm(hn)))
|
| 701 |
+
hn = F.silu(self.out(self.hidden_norm(hn)))
|
| 702 |
+
logits = self.head(self.out_norm(hn))
|
| 703 |
+
|
| 704 |
+
curr_argmax = logits.argmax(dim=-1)
|
| 705 |
+
if prev_argmax is not None and (curr_argmax == prev_argmax).all():
|
| 706 |
+
stable_count += 1
|
| 707 |
+
else:
|
| 708 |
+
stable_count = 0
|
| 709 |
+
|
| 710 |
+
total_iters = i + 1
|
| 711 |
+
if stable_count >= self.act_halt_consecutive:
|
| 712 |
+
break
|
| 713 |
+
|
| 714 |
+
prev_argmax = curr_argmax
|
| 715 |
+
trans_out = self.act_residual(logits)
|
| 716 |
+
h = self.lti(h, x_initial, trans_out)
|
| 717 |
+
|
| 718 |
+
self._last_ponder = total_iters / max(self.act_max_iters, 1)
|
| 719 |
+
return logits
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class OutputRouter(nn.Module):
|
| 723 |
+
"""Routes HIDDEN_DIM relational tokens to ByteHead, VideoHead, or TalkerHead.
|
| 724 |
+
|
| 725 |
+
3-layer MLP when depth=3, 2-layer when depth=2, single projection when depth=1.
|
| 726 |
+
Argmax at inference, soft weighted routing at training.
|
| 727 |
+
"""
|
| 728 |
+
def __init__(self, tscale_type=TScaleType.T32, depth=3):
|
| 729 |
+
super().__init__()
|
| 730 |
+
if depth >= 3:
|
| 731 |
+
self.hidden1 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
|
| 732 |
+
self.hidden1_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
|
| 733 |
+
self.hidden2 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type)
|
| 734 |
+
self.gate = TernaryScaleTensor(HIDDEN_DIM // 4, 4, tscale_type=tscale_type)
|
| 735 |
+
elif depth == 2:
|
| 736 |
+
self.hidden1 = None
|
| 737 |
+
self.hidden1_norm = None
|
| 738 |
+
self.hidden2 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type)
|
| 739 |
+
self.gate = TernaryScaleTensor(HIDDEN_DIM // 4, 4, tscale_type=tscale_type)
|
| 740 |
+
else:
|
| 741 |
+
self.hidden1 = None
|
| 742 |
+
self.hidden1_norm = None
|
| 743 |
+
self.hidden2 = None
|
| 744 |
+
self.gate = TernaryScaleTensor(HIDDEN_DIM, 4, tscale_type=tscale_type)
|
| 745 |
+
# 0 = Null (continue), 1 = ByteHead, 2 = VideoHead, 3 = TalkerHead
|
| 746 |
+
|
| 747 |
+
def forward(self, x, training=False):
|
| 748 |
+
h = x
|
| 749 |
+
if self.hidden1 is not None:
|
| 750 |
+
h = F.silu(self.hidden1_norm(self.hidden1(h)))
|
| 751 |
+
if self.hidden2 is not None:
|
| 752 |
+
h = self.hidden2(h)
|
| 753 |
+
logits = self.gate(h) # [B, T, 4]
|
| 754 |
+
logits = torch.nan_to_num(logits, nan=0.0, posinf=30.0, neginf=-30.0).clamp(-30.0, 30.0)
|
| 755 |
+
if training:
|
| 756 |
+
weights = F.softmax(logits, dim=-1)
|
| 757 |
+
return weights, logits
|
| 758 |
+
return logits.argmax(dim=-1)
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
class KGVQCodebook(TernaryVQCodebook):
|
| 762 |
+
"""Compatibility wrapper for the KG/composite VQ.
|
| 763 |
+
|
| 764 |
+
The old implementation kept float32 `embed` and `embed_avg` buffers. The
|
| 765 |
+
production path now uses the same packed ternary/int8 backing table as the
|
| 766 |
+
shared VQ so default 5M-code KG construction cannot allocate hidden float
|
| 767 |
+
codebook state.
|
| 768 |
+
"""
|
| 769 |
+
def __init__(self, codebook_size=KGVQ_CODEBOOK_SIZE, codebook_dim=KGVQ_CODEBOOK_DIM,
|
| 770 |
+
decay=KGVQ_DECAY, commitment_weight=KGVQ_COMMITMENT_WEIGHT,
|
| 771 |
+
threshold_ema_dead_code=KGVQ_DEAD_CODE_THRESHOLD):
|
| 772 |
+
super().__init__(
|
| 773 |
+
codebook_size=codebook_size,
|
| 774 |
+
codebook_dim=codebook_dim,
|
| 775 |
+
commitment_weight=commitment_weight,
|
| 776 |
+
)
|
| 777 |
+
self.decay = decay
|
| 778 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 779 |
+
|
| 780 |
+
@property
|
| 781 |
+
def embed(self):
|
| 782 |
+
if self.codebook_size > self.exact_lookup_max:
|
| 783 |
+
raise RuntimeError(
|
| 784 |
+
"Full KG VQ materialization is disabled for large ternary codebooks; "
|
| 785 |
+
"query rows through `table(indices)` instead."
|
| 786 |
+
)
|
| 787 |
+
return super().embed
|
| 788 |
+
|
| 789 |
+
def _ema_update(self, x_flat, indices):
|
| 790 |
+
unique, counts = torch.unique(indices, return_counts=True)
|
| 791 |
+
current = self.cluster_size[unique].to(torch.int32)
|
| 792 |
+
updated = torch.clamp(
|
| 793 |
+
current + counts.to(device=current.device, dtype=torch.int32),
|
| 794 |
+
0,
|
| 795 |
+
32767,
|
| 796 |
+
).to(torch.int16)
|
| 797 |
+
self.cluster_size[unique] = updated
|
| 798 |
+
|
| 799 |
+
def _dead_code_reset(self, x_flat):
|
| 800 |
+
return None
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
class CompositeProposalHead(nn.Module):
|
| 804 |
+
"""Multi-proposal head from pooled GNN output (Phase 17).
|
| 805 |
+
|
| 806 |
+
Projects GNN pool output (graph_pool_out [B, D]) to K_MAX composite motif
|
| 807 |
+
proposals, quantizes via KGVQ, and applies ACT-style halting.
|
| 808 |
+
"""
|
| 809 |
+
def __init__(self, dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
|
| 810 |
+
k_max=K_MAX_COMPOSITES, codebook_size=KGVQ_CODEBOOK_SIZE,
|
| 811 |
+
tscale_type=TScaleType.T32):
|
| 812 |
+
super().__init__()
|
| 813 |
+
self.dim = dim
|
| 814 |
+
self.k_max = k_max
|
| 815 |
+
self.codebook_dim = codebook_dim
|
| 816 |
+
self.proj = TernaryScaleTensor(dim, k_max * codebook_dim, tscale_type=tscale_type)
|
| 817 |
+
self.kgvq = TernaryVQCodebook(codebook_size=codebook_size, codebook_dim=codebook_dim,
|
| 818 |
+
tscale_type=tscale_type)
|
| 819 |
+
self.halt_gate = TernaryScaleTensor(dim, k_max, tscale_type=tscale_type)
|
| 820 |
+
self.diversity_weight = 0.1
|
| 821 |
+
|
| 822 |
+
def forward(self, pool_out):
|
| 823 |
+
B = pool_out.shape[0]
|
| 824 |
+
projections = self.proj(pool_out).view(B, self.k_max, self.codebook_dim)
|
| 825 |
+
quantized, composite_ids, vq_loss = self.kgvq(projections)
|
| 826 |
+
|
| 827 |
+
halt_logits = self.halt_gate(pool_out).clamp(-12.0, 12.0)
|
| 828 |
+
halt = torch.sigmoid(halt_logits) # [B, K_MAX]
|
| 829 |
+
composite_ids = composite_ids.masked_fill(halt < 0.5, -1)
|
| 830 |
+
|
| 831 |
+
normed = F.normalize(projections, dim=-1)
|
| 832 |
+
sim_matrix = normed @ normed.transpose(-1, -2)
|
| 833 |
+
triu = torch.triu(sim_matrix, diagonal=1)
|
| 834 |
+
n_pairs = self.k_max * (self.k_max - 1) / 2
|
| 835 |
+
diversity_loss = triu.sum(dim=(-1, -2)).mean() / max(n_pairs, 1)
|
| 836 |
+
diversity_loss = diversity_loss * self.diversity_weight
|
| 837 |
+
|
| 838 |
+
return composite_ids, vq_loss + diversity_loss, halt
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class MoEGraph(nn.Module):
|
| 842 |
+
"""Fused graph traversal + centroid-based MoE routing + ACT halting.
|
| 843 |
+
|
| 844 |
+
Each ACT iteration: traverse KG β aggregate neighbor emb β centroid route β
|
| 845 |
+
run expert β halt check. All operations at MG_WORKSPACE_DIM (1024).
|
| 846 |
+
|
| 847 |
+
Replaces: TernaryGraph + GraphMoEGate + GraphACTCell + SharedProjectionMoE + MoEACTCell.
|
| 848 |
+
"""
|
| 849 |
+
def __init__(self, cb_dim=MG_WORKSPACE_DIM, trigram_dim=HIDDEN_DIM,
|
| 850 |
+
codebook_dim=CODEBOOK_DIM,
|
| 851 |
+
num_experts=MG_N_EXPERTS, core_rank=MG_CORE_RANK,
|
| 852 |
+
shared_inter=MG_SHARED_INTER, max_iters=MG_ACT_ITERS,
|
| 853 |
+
halt_threshold=0.99, tscale_type=TScaleType.T32,
|
| 854 |
+
codebook_size=CODEBOOK_SIZE,
|
| 855 |
+
active_graph_max_nodes=4096,
|
| 856 |
+
top_k=1):
|
| 857 |
+
super().__init__()
|
| 858 |
+
self.cb_dim = cb_dim
|
| 859 |
+
self.trigram_dim = trigram_dim
|
| 860 |
+
self.codebook_dim = codebook_dim
|
| 861 |
+
self.num_experts = num_experts
|
| 862 |
+
self.core_rank = core_rank
|
| 863 |
+
self.shared_inter = shared_inter
|
| 864 |
+
self.max_iters = max_iters
|
| 865 |
+
self.halt_threshold = halt_threshold
|
| 866 |
+
self.codebook_size = codebook_size
|
| 867 |
+
self.active_graph_max_nodes = active_graph_max_nodes
|
| 868 |
+
self.top_k = top_k
|
| 869 |
+
|
| 870 |
+
self.down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
|
| 871 |
+
self.down_norm = TernaryRMSNorm(trigram_dim, tscale_type=tscale_type)
|
| 872 |
+
self.up_proj = TernaryScaleTensor(cb_dim, trigram_dim, tscale_type=tscale_type)
|
| 873 |
+
self.up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
|
| 874 |
+
self.attn_down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
|
| 875 |
+
self.codebook_up = TernaryScaleTensor(codebook_dim, cb_dim, tscale_type=tscale_type)
|
| 876 |
+
|
| 877 |
+
self.use_active_edge_store = self.codebook_size > self.active_graph_max_nodes
|
| 878 |
+
self.active_edge_capacity = max(int(self.active_graph_max_nodes) * 16, 65_536)
|
| 879 |
+
if self.use_active_edge_store:
|
| 880 |
+
self.register_buffer("edge_index", torch.zeros(2, 0, dtype=torch.int32))
|
| 881 |
+
self.register_buffer("edge_attr", torch.zeros(0, dtype=torch.int8))
|
| 882 |
+
self.register_buffer("edge_score", torch.zeros(0, dtype=torch.int8))
|
| 883 |
+
self.register_buffer("active_edge_src", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32))
|
| 884 |
+
self.register_buffer("active_edge_dst", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32))
|
| 885 |
+
self.register_buffer("active_edge_attr", torch.zeros(self.active_edge_capacity, dtype=torch.int8))
|
| 886 |
+
self.register_buffer("active_edge_score", torch.zeros(self.active_edge_capacity, dtype=torch.int8))
|
| 887 |
+
self.register_buffer("active_edge_ptr", torch.zeros((), dtype=torch.long))
|
| 888 |
+
else:
|
| 889 |
+
num_edges = self.codebook_size * 10
|
| 890 |
+
src = torch.arange(self.codebook_size, dtype=torch.int32).repeat_interleave(10)
|
| 891 |
+
dst = torch.randint(0, self.codebook_size, (num_edges,), dtype=torch.int32)
|
| 892 |
+
self.register_buffer("edge_index", torch.stack([src, dst], dim=0))
|
| 893 |
+
edge_init = torch.randint(-1, 2, (num_edges,), dtype=torch.int8)
|
| 894 |
+
self.register_buffer("edge_attr", edge_init)
|
| 895 |
+
self.register_buffer("edge_score", torch.zeros(num_edges, dtype=torch.int8))
|
| 896 |
+
self.register_buffer("_steps_since_requant", torch.tensor(0, dtype=torch.long))
|
| 897 |
+
self.requant_every = KG_REQUANT_EVERY
|
| 898 |
+
self.kg_ternary_threshold = KG_TERNARY_THRESHOLD
|
| 899 |
+
self.kg_ema_alpha = KG_EMA_ALPHA
|
| 900 |
+
|
| 901 |
+
self.centroids = TernaryEmbeddingTable(num_experts, cb_dim, tscale_type=tscale_type, normalize=True)
|
| 902 |
+
|
| 903 |
+
self.shared_up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
|
| 904 |
+
self.shared_up = TernaryScaleTensor(cb_dim, shared_inter, tscale_type=tscale_type)
|
| 905 |
+
self.shared_down_norm = TernaryRMSNorm(shared_inter, tscale_type=tscale_type)
|
| 906 |
+
self.shared_down = TernaryScaleTensor(shared_inter, cb_dim, tscale_type=tscale_type)
|
| 907 |
+
|
| 908 |
+
self.W_gate = nn.ModuleList([
|
| 909 |
+
TernaryScaleTensor(cb_dim, core_rank, tscale_type=tscale_type)
|
| 910 |
+
for _ in range(num_experts)
|
| 911 |
+
])
|
| 912 |
+
self.W_gate_norms = nn.ModuleList([
|
| 913 |
+
TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
|
| 914 |
+
for _ in range(num_experts)
|
| 915 |
+
])
|
| 916 |
+
self.W_transform = nn.ModuleList([
|
| 917 |
+
TernaryScaleTensor(core_rank, shared_inter, tscale_type=tscale_type)
|
| 918 |
+
for _ in range(num_experts)
|
| 919 |
+
])
|
| 920 |
+
self.W_transform_norms = nn.ModuleList([
|
| 921 |
+
TernaryRMSNorm(core_rank, tscale_type=tscale_type)
|
| 922 |
+
for _ in range(num_experts)
|
| 923 |
+
])
|
| 924 |
+
|
| 925 |
+
self.hop_lora = GNNLoRAAdapter(dim=cb_dim, rank=32, max_hops=max_iters)
|
| 926 |
+
self.halting = HaltingUnit(dim=cb_dim, tscale_type=tscale_type)
|
| 927 |
+
self.lti = LTIInjection(cb_dim)
|
| 928 |
+
|
| 929 |
+
self._codebook_embed = None
|
| 930 |
+
self._codebook_table = None
|
| 931 |
+
|
| 932 |
+
def _codebook_tensor(self, device):
|
| 933 |
+
if self._codebook_table is not None:
|
| 934 |
+
idx = torch.arange(self.codebook_size, device=device)
|
| 935 |
+
codebook = self._codebook_table(idx)
|
| 936 |
+
if codebook.shape[-1] != self.cb_dim:
|
| 937 |
+
codebook = self.codebook_up(codebook)
|
| 938 |
+
return codebook
|
| 939 |
+
if self._codebook_embed is not None:
|
| 940 |
+
codebook = self._codebook_embed.to(device=device).squeeze(0)
|
| 941 |
+
if codebook.shape[-1] != self.cb_dim:
|
| 942 |
+
codebook = self.codebook_up(codebook)
|
| 943 |
+
return codebook
|
| 944 |
+
return torch.zeros(self.codebook_size, self.cb_dim, device=device)
|
| 945 |
+
|
| 946 |
+
def _active_codebook_features(self, vq_indices):
|
| 947 |
+
if self._codebook_table is not None:
|
| 948 |
+
safe_idx = vq_indices.clamp(min=0, max=self.codebook_size - 1)
|
| 949 |
+
active_code = self._codebook_table(safe_idx)
|
| 950 |
+
elif self._codebook_embed is not None:
|
| 951 |
+
codebook = self._codebook_embed.to(device=vq_indices.device).squeeze(0)
|
| 952 |
+
safe_idx = vq_indices.clamp(min=0, max=codebook.shape[0] - 1)
|
| 953 |
+
active_code = codebook[safe_idx]
|
| 954 |
+
else:
|
| 955 |
+
return torch.zeros(*vq_indices.shape, self.cb_dim, device=vq_indices.device)
|
| 956 |
+
if active_code.shape[-1] != self.cb_dim:
|
| 957 |
+
active_code = self.codebook_up(active_code)
|
| 958 |
+
return active_code
|
| 959 |
+
|
| 960 |
+
def _neighbor_aggregate(self, node_features, threshold):
|
| 961 |
+
N, D = node_features.shape
|
| 962 |
+
aggregated = torch.zeros(self.codebook_size, D, device=node_features.device, dtype=node_features.dtype)
|
| 963 |
+
edge_ternary = StickyZoneSTE.apply(self.edge_attr, threshold)
|
| 964 |
+
src_features = node_features[self.edge_index[0]]
|
| 965 |
+
messages = edge_ternary.unsqueeze(1).to(node_features.dtype) * src_features
|
| 966 |
+
dst_idx = self.edge_index[1].unsqueeze(1).expand(-1, D)
|
| 967 |
+
aggregated.scatter_add_(0, dst_idx, messages)
|
| 968 |
+
return aggregated
|
| 969 |
+
|
| 970 |
+
def _run_expert_batch(self, x, expert_idx):
|
| 971 |
+
B, T, D = x.shape
|
| 972 |
+
N = B * T
|
| 973 |
+
x_flat = rearrange(x, 'b t d -> (b t) d')
|
| 974 |
+
exp_flat = rearrange(expert_idx, 'b t -> (b t)')
|
| 975 |
+
shared_hidden = F.silu(self.shared_up(self.shared_up_norm(x_flat)))
|
| 976 |
+
sort_idx = exp_flat.argsort()
|
| 977 |
+
sorted_experts = exp_flat[sort_idx]
|
| 978 |
+
expert_counts = torch.bincount(sorted_experts, minlength=self.num_experts)
|
| 979 |
+
expert_boundaries = torch.cumsum(expert_counts, dim=0)
|
| 980 |
+
out_flat = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
| 981 |
+
for e in range(self.num_experts):
|
| 982 |
+
start = expert_boundaries[e] - expert_counts[e]
|
| 983 |
+
end = expert_boundaries[e]
|
| 984 |
+
if start == end:
|
| 985 |
+
continue
|
| 986 |
+
tok_idx = sort_idx[start:end]
|
| 987 |
+
inp = x_flat[tok_idx]
|
| 988 |
+
sh = shared_hidden[tok_idx]
|
| 989 |
+
gate = self.W_gate[e](self.W_gate_norms[e](inp))
|
| 990 |
+
core = self.W_transform[e](self.W_transform_norms[e](gate))
|
| 991 |
+
expert_out = self.shared_down(self.shared_down_norm(core * sh))
|
| 992 |
+
out_flat[tok_idx] = expert_out
|
| 993 |
+
return rearrange(out_flat, '(b t) d -> b t d', b=B, t=T)
|
| 994 |
+
|
| 995 |
+
def _run_expert(self, x, expert_idx):
|
| 996 |
+
return self._run_expert_batch(x, expert_idx)
|
| 997 |
+
|
| 998 |
+
def _active_node_add(self, vq_output, vq_indices):
|
| 999 |
+
return vq_output + self._active_codebook_features(vq_indices)
|
| 1000 |
+
|
| 1001 |
+
def forward(self, trigram_input, vq_indices, attention_output=None,
|
| 1002 |
+
memgram_cb_output=None, threshold=0.05):
|
| 1003 |
+
B, T, D = trigram_input.shape
|
| 1004 |
+
device = trigram_input.device
|
| 1005 |
+
|
| 1006 |
+
x = self.down_proj(self.down_norm(trigram_input))
|
| 1007 |
+
|
| 1008 |
+
attn_cb = None
|
| 1009 |
+
if attention_output is not None:
|
| 1010 |
+
attn_cb = self.attn_down_proj(self.down_norm(attention_output))
|
| 1011 |
+
|
| 1012 |
+
halted = torch.zeros(B, T, device=device, dtype=torch.bool)
|
| 1013 |
+
cumulative_p = torch.zeros(B, T, device=device)
|
| 1014 |
+
acc = torch.zeros_like(x)
|
| 1015 |
+
total_ponder = torch.zeros(B, T, device=device)
|
| 1016 |
+
last_x = x
|
| 1017 |
+
initial_x = x
|
| 1018 |
+
|
| 1019 |
+
use_active_graph = self.codebook_size > self.active_graph_max_nodes
|
| 1020 |
+
node_features = None if use_active_graph else self._codebook_tensor(device)
|
| 1021 |
+
|
| 1022 |
+
for iter_t in range(self.max_iters):
|
| 1023 |
+
if use_active_graph:
|
| 1024 |
+
traversal = self._active_node_add(x, vq_indices)
|
| 1025 |
+
else:
|
| 1026 |
+
node_aggregated = self._neighbor_aggregate(node_features, threshold)
|
| 1027 |
+
traversal = x + node_aggregated[vq_indices]
|
| 1028 |
+
|
| 1029 |
+
if attn_cb is not None:
|
| 1030 |
+
traversal = traversal + attn_cb
|
| 1031 |
+
|
| 1032 |
+
if iter_t in [1, 3] and memgram_cb_output is not None:
|
| 1033 |
+
memgram_raw = memgram_cb_output.to(device)
|
| 1034 |
+
if memgram_raw.shape[-1] != self.cb_dim:
|
| 1035 |
+
memgram_raw = memgram_raw.mean(dim=-1, keepdim=True).expand(-1, -1, self.cb_dim)
|
| 1036 |
+
traversal = traversal + memgram_raw
|
| 1037 |
+
|
| 1038 |
+
traversal = traversal + self.hop_lora(traversal, iter_t)
|
| 1039 |
+
|
| 1040 |
+
trav_norm = F.normalize(traversal, dim=-1, eps=1e-8)
|
| 1041 |
+
centroid_ids = torch.arange(self.num_experts, device=device)
|
| 1042 |
+
cent_norm = F.normalize(self.centroids(centroid_ids), dim=-1, eps=1e-8)
|
| 1043 |
+
scores = trav_norm @ cent_norm.T
|
| 1044 |
+
if self.top_k <= 1:
|
| 1045 |
+
_, expert_idx = scores.max(dim=-1)
|
| 1046 |
+
expert_out = self._run_expert(traversal, expert_idx)
|
| 1047 |
+
else:
|
| 1048 |
+
scores_topk, topk_idx = scores.topk(k=self.top_k, dim=-1)
|
| 1049 |
+
weights = F.softmax(scores_topk / 0.1, dim=-1)
|
| 1050 |
+
expert_out = 0
|
| 1051 |
+
for i in range(self.top_k):
|
| 1052 |
+
wi = weights[..., i:i+1]
|
| 1053 |
+
ei = topk_idx[..., i]
|
| 1054 |
+
expert_out = expert_out + wi * self._run_expert(traversal, ei)
|
| 1055 |
+
last_x = expert_out
|
| 1056 |
+
|
| 1057 |
+
p = self.halting(expert_out).squeeze(-1)
|
| 1058 |
+
still_running = ~halted
|
| 1059 |
+
remainder = (1.0 - cumulative_p).clamp(min=0)
|
| 1060 |
+
weight = torch.where(
|
| 1061 |
+
cumulative_p + p >= self.halt_threshold,
|
| 1062 |
+
remainder, p,
|
| 1063 |
+
)
|
| 1064 |
+
weight = weight * still_running.float()
|
| 1065 |
+
acc = acc + weight.unsqueeze(-1) * expert_out
|
| 1066 |
+
cumulative_p = cumulative_p + p * still_running.float()
|
| 1067 |
+
halted = halted | (cumulative_p >= self.halt_threshold)
|
| 1068 |
+
total_ponder = total_ponder + (1.0 - cumulative_p).clamp(min=0)
|
| 1069 |
+
|
| 1070 |
+
x = self.lti(x, initial_x, expert_out)
|
| 1071 |
+
|
| 1072 |
+
if halted.all():
|
| 1073 |
+
break
|
| 1074 |
+
|
| 1075 |
+
never_halted = (~halted).float().unsqueeze(-1)
|
| 1076 |
+
acc = acc + never_halted * last_x
|
| 1077 |
+
|
| 1078 |
+
output = self.up_proj(self.up_norm(acc))
|
| 1079 |
+
ponder_loss = total_ponder.mean() / self.max_iters
|
| 1080 |
+
|
| 1081 |
+
return output, ponder_loss
|
| 1082 |
+
|
| 1083 |
+
@torch.no_grad()
|
| 1084 |
+
def update_kg_edges(self, all_vq_indices):
|
| 1085 |
+
if self.use_active_edge_store:
|
| 1086 |
+
self._update_active_edges(all_vq_indices)
|
| 1087 |
+
return
|
| 1088 |
+
|
| 1089 |
+
unique_ids = torch.unique(all_vq_indices.to(device=self.edge_index.device, dtype=torch.int32))
|
| 1090 |
+
src_in_batch = torch.isin(self.edge_index[0], unique_ids)
|
| 1091 |
+
|
| 1092 |
+
if src_in_batch.any():
|
| 1093 |
+
dst_seen = torch.isin(self.edge_index[1][src_in_batch], unique_ids)
|
| 1094 |
+
delta = torch.where(
|
| 1095 |
+
dst_seen,
|
| 1096 |
+
torch.ones_like(self.edge_score[src_in_batch], dtype=torch.int16),
|
| 1097 |
+
torch.full_like(self.edge_score[src_in_batch], -1, dtype=torch.int16),
|
| 1098 |
+
)
|
| 1099 |
+
score = torch.clamp(self.edge_score[src_in_batch].to(torch.int16) + delta, -128, 127)
|
| 1100 |
+
self.edge_score[src_in_batch] = score.to(torch.int8)
|
| 1101 |
+
|
| 1102 |
+
self._requantize_dense_edges()
|
| 1103 |
+
|
| 1104 |
+
@torch.no_grad()
|
| 1105 |
+
def _update_active_edges(self, all_vq_indices):
|
| 1106 |
+
ids = all_vq_indices.to(device=self.active_edge_src.device, dtype=torch.int32)
|
| 1107 |
+
if ids.numel() < 2:
|
| 1108 |
+
self._steps_since_requant.add_(1)
|
| 1109 |
+
return
|
| 1110 |
+
|
| 1111 |
+
seq = ids.reshape(-1, ids.shape[-1]) if ids.dim() > 1 else ids.reshape(1, -1)
|
| 1112 |
+
src = seq[:, :-1].reshape(-1)
|
| 1113 |
+
dst = seq[:, 1:].reshape(-1)
|
| 1114 |
+
valid = (src >= 0) & (dst >= 0) & (src < self.codebook_size) & (dst < self.codebook_size) & (src != dst)
|
| 1115 |
+
src = src[valid]
|
| 1116 |
+
dst = dst[valid]
|
| 1117 |
+
if src.numel() == 0:
|
| 1118 |
+
self._steps_since_requant.add_(1)
|
| 1119 |
+
return
|
| 1120 |
+
|
| 1121 |
+
n_edges = min(src.numel(), self.active_edge_capacity)
|
| 1122 |
+
src = src[-n_edges:]
|
| 1123 |
+
dst = dst[-n_edges:]
|
| 1124 |
+
ptr = int(self.active_edge_ptr.item())
|
| 1125 |
+
slots = (torch.arange(n_edges, device=src.device, dtype=torch.long) + ptr) % self.active_edge_capacity
|
| 1126 |
+
|
| 1127 |
+
self.active_edge_src[slots] = src
|
| 1128 |
+
self.active_edge_dst[slots] = dst
|
| 1129 |
+
score = torch.clamp(self.active_edge_score[slots].to(torch.int16) + 1, -128, 127)
|
| 1130 |
+
self.active_edge_score[slots] = score.to(torch.int8)
|
| 1131 |
+
self.active_edge_attr[slots] = 1
|
| 1132 |
+
self.active_edge_ptr.fill_((ptr + n_edges) % self.active_edge_capacity)
|
| 1133 |
+
self._requantize_active_edges()
|
| 1134 |
+
|
| 1135 |
+
@torch.no_grad()
|
| 1136 |
+
def _requantize_dense_edges(self):
|
| 1137 |
+
if self._steps_since_requant.item() < self.requant_every:
|
| 1138 |
+
self._steps_since_requant.add_(1)
|
| 1139 |
+
return
|
| 1140 |
+
self.edge_attr = self._score_to_attr(self.edge_score)
|
| 1141 |
+
score = self.edge_score.to(torch.int16)
|
| 1142 |
+
score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score))
|
| 1143 |
+
self.edge_score = score.to(torch.int8)
|
| 1144 |
+
self._steps_since_requant.zero_()
|
| 1145 |
+
|
| 1146 |
+
@torch.no_grad()
|
| 1147 |
+
def _requantize_active_edges(self):
|
| 1148 |
+
if self._steps_since_requant.item() < self.requant_every:
|
| 1149 |
+
self._steps_since_requant.add_(1)
|
| 1150 |
+
return
|
| 1151 |
+
active = self.active_edge_src >= 0
|
| 1152 |
+
if active.any():
|
| 1153 |
+
self.active_edge_attr[active] = self._score_to_attr(self.active_edge_score[active])
|
| 1154 |
+
score = self.active_edge_score[active].to(torch.int16)
|
| 1155 |
+
score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score))
|
| 1156 |
+
self.active_edge_score[active] = score.to(torch.int8)
|
| 1157 |
+
self._steps_since_requant.zero_()
|
| 1158 |
+
|
| 1159 |
+
def _score_to_attr(self, score):
|
| 1160 |
+
threshold = max(1, int(round(float(self.kg_ternary_threshold) * 8)))
|
| 1161 |
+
score_i = score.to(torch.int16)
|
| 1162 |
+
return torch.where(
|
| 1163 |
+
score_i >= threshold,
|
| 1164 |
+
torch.ones_like(score, dtype=torch.int8),
|
| 1165 |
+
torch.where(
|
| 1166 |
+
score_i <= -threshold,
|
| 1167 |
+
torch.full_like(score, -1, dtype=torch.int8),
|
| 1168 |
+
torch.zeros_like(score, dtype=torch.int8),
|
| 1169 |
+
),
|
| 1170 |
+
)
|
| 1171 |
+
|
| 1172 |
+
@torch.no_grad()
|
| 1173 |
+
def monitor_graph_health(self, threshold=0.05):
|
| 1174 |
+
if self.use_active_edge_store:
|
| 1175 |
+
active = self.active_edge_src >= 0
|
| 1176 |
+
if not active.any():
|
| 1177 |
+
return {
|
| 1178 |
+
"sparsity": 1.0, "isolated_nodes": self.codebook_size,
|
| 1179 |
+
"avg_polarity": 0.0, "dead_edges": 0,
|
| 1180 |
+
"score_mean": 0.0, "score_max": 0.0,
|
| 1181 |
+
"active_edges": 0,
|
| 1182 |
+
}
|
| 1183 |
+
edge_attr = self.active_edge_attr[active]
|
| 1184 |
+
edge_score = self.active_edge_score[active]
|
| 1185 |
+
nodes_with_edges = torch.unique(torch.cat([self.active_edge_src[active], self.active_edge_dst[active]]))
|
| 1186 |
+
else:
|
| 1187 |
+
edge_attr = self.edge_attr
|
| 1188 |
+
edge_score = self.edge_score
|
| 1189 |
+
nodes_with_edges = torch.unique(torch.cat([self.edge_index[0], self.edge_index[1]]))
|
| 1190 |
+
|
| 1191 |
+
ternary_edge = edge_attr.sign()
|
| 1192 |
+
sparsity = (ternary_edge == 0).float().mean().item() if ternary_edge.numel() else 1.0
|
| 1193 |
+
n_isolated = max(int(self.codebook_size) - int(nodes_with_edges.numel()), 0)
|
| 1194 |
+
n_pos = (ternary_edge > 0).sum().item()
|
| 1195 |
+
n_neg = (ternary_edge < 0).sum().item()
|
| 1196 |
+
n_nonzero = n_pos + n_neg
|
| 1197 |
+
avg_polarity = (n_pos - n_neg) / max(n_nonzero, 1)
|
| 1198 |
+
dead_edges = ((ternary_edge == 0) & (edge_score != 0)).sum().item()
|
| 1199 |
+
score_mean = edge_score.float().mean().item() if edge_score.numel() else 0.0
|
| 1200 |
+
score_max = edge_score.float().abs().max().item() if edge_score.numel() else 0.0
|
| 1201 |
+
return {
|
| 1202 |
+
"sparsity": sparsity, "isolated_nodes": n_isolated,
|
| 1203 |
+
"avg_polarity": avg_polarity, "dead_edges": dead_edges,
|
| 1204 |
+
"score_mean": score_mean, "score_max": score_max,
|
| 1205 |
+
"active_edges": int(ternary_edge.numel()),
|
| 1206 |
+
}
|
| 1207 |
+
|
| 1208 |
+
def set_adjacency(self, edge_index, edge_attr_init=None):
|
| 1209 |
+
self.use_active_edge_store = False
|
| 1210 |
+
device = self.edge_attr.device
|
| 1211 |
+
self.edge_index = edge_index.to(device=device, dtype=torch.int32)
|
| 1212 |
+
if edge_attr_init is not None:
|
| 1213 |
+
edge_attr = edge_attr_init.sign() * (edge_attr_init.abs() > 0).to(edge_attr_init.dtype)
|
| 1214 |
+
self.edge_attr = edge_attr.to(device=device, dtype=torch.int8)
|
| 1215 |
+
else:
|
| 1216 |
+
self.edge_attr = torch.randint(-1, 2, (edge_index.size(1),),
|
| 1217 |
+
device=device, dtype=torch.int8)
|
| 1218 |
+
self.edge_score = self.edge_attr.clone()
|
arbitor/config.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
VOCAB=288
|
| 2 |
+
AUDIO_VOCAB=288
|
| 3 |
+
AUDIO_SR=16000
|
| 4 |
+
AUDIO_FRAME_RATE=50
|
| 5 |
+
THRESHOLD=0.05
|
| 6 |
+
|
| 7 |
+
# -- 3B Target Dimensions --
|
| 8 |
+
EMBEDDING_DIM=1536
|
| 9 |
+
CODEBOOK_DIM=1024
|
| 10 |
+
CODEBOOK_SIZE=524288 # Base unit
|
| 11 |
+
# Shared multimodal VQ (256K entries Γ 1024-dim)
|
| 12 |
+
SHARED_VQ_SIZE = 262144
|
| 13 |
+
HIDDEN_DIM=8192 # Main hidden dimension
|
| 14 |
+
FFN_HIDDEN=16384 # 2Γ HIDDEN_DIM
|
| 15 |
+
CTX=256
|
| 16 |
+
|
| 17 |
+
# MoEGraph (256 experts, centroid routing, unified ACT)
|
| 18 |
+
MG_N_EXPERTS = 256
|
| 19 |
+
MG_CORE_RANK = 384
|
| 20 |
+
MG_SHARED_INTER = 1536
|
| 21 |
+
MG_ACT_ITERS = 4
|
| 22 |
+
MG_WORKSPACE_DIM = 768
|
| 23 |
+
MG_TOP_K = 2
|
| 24 |
+
|
| 25 |
+
# VQ
|
| 26 |
+
# MemGram (32 heads Γ ~65K slots β 2M total associative slots)
|
| 27 |
+
MEMGRAM_STRUCT_PRIMES = [64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997,
|
| 28 |
+
65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071,
|
| 29 |
+
65101, 65119, 65123, 65129, 65141, 65147, 65167, 65171,
|
| 30 |
+
65173, 65179, 65183, 65203, 65213, 65239, 65257, 65269]
|
| 31 |
+
MEMGRAM_CONV_PRIMES = [8009, 8011, 8017, 8039, 8081, 8087, 8089, 8093]
|
| 32 |
+
MEMGRAM_EMBED_DIM = 64
|
| 33 |
+
MEMGRAM_KEY_DIM = 32
|
| 34 |
+
|
| 35 |
+
# KV Ledger
|
| 36 |
+
KV_LEDGER_SIZE = 262144
|
| 37 |
+
SLIDING_WINDOW_SIZE = 32768
|
| 38 |
+
KQ_CACHE_SIZE = 8192
|
| 39 |
+
|
| 40 |
+
# MLA Attention dimensions
|
| 41 |
+
MLA_N_HEADS = 32
|
| 42 |
+
MLA_QK_NOPE_HEAD_DIM = 96
|
| 43 |
+
MLA_QK_ROPE_HEAD_DIM = 32
|
| 44 |
+
MLA_V_HEAD_DIM = 96
|
| 45 |
+
MLA_SLIDE_DIM = 64
|
| 46 |
+
MLA_FULL_DIM = 32
|
| 47 |
+
MLA_N_LAYERS = 24
|
| 48 |
+
|
| 49 |
+
# RoPE
|
| 50 |
+
MLA_ROPE_THETA = 10000.0
|
| 51 |
+
|
| 52 |
+
# Attention
|
| 53 |
+
ATTENTION_STRIDE = 8
|
| 54 |
+
KV_CONTEXT_LENGTH = 33554432
|
| 55 |
+
|
| 56 |
+
# CSA / HCA compression (DeepSeek V4 hybrid attention)
|
| 57 |
+
MLA_CSA_DIM = 16
|
| 58 |
+
MLA_HCA_DIM = 16
|
| 59 |
+
MLA_HCA_STRIDE = 32
|
| 60 |
+
|
| 61 |
+
# KG EMA β Phase 17
|
| 62 |
+
KG_EMA_ALPHA=0.99
|
| 63 |
+
KG_REQUANT_EVERY=50
|
| 64 |
+
KG_TERNARY_THRESHOLD=0.3
|
| 65 |
+
|
| 66 |
+
# Composite Motif VQ β Phase 17 (64K entries Γ 1024-dim)
|
| 67 |
+
KGVQ_CODEBOOK_SIZE=65536
|
| 68 |
+
KGVQ_CODEBOOK_DIM=1024
|
| 69 |
+
KGVQ_DECAY=0.99
|
| 70 |
+
KGVQ_COMMITMENT_WEIGHT=1.0
|
| 71 |
+
KGVQ_DEAD_CODE_THRESHOLD=2
|
| 72 |
+
K_MAX_COMPOSITES=20
|
| 73 |
+
|
| 74 |
+
# VideoHead (Open-Sora VAE: 4 latent channels, 8Γ spatial + 4Γ temporal compression)
|
| 75 |
+
VIDEO_LATENT_CHANNELS = 4
|
| 76 |
+
VIDEO_MAX_STEPS = 8
|
| 77 |
+
VIDEO_HEIGHT = 64
|
| 78 |
+
VIDEO_WIDTH = 64
|
| 79 |
+
|
| 80 |
+
# -- Open-Sora 3D VAE (Phase 19) --
|
| 81 |
+
OPEN_SORA_VAE_PATH = "arbitor/encoders/models/opensora-vae"
|
| 82 |
+
OPEN_SORA_VAE_REPO = "hpcai-tech/OpenSora-VAE-v1.2"
|
| 83 |
+
OPEN_SORA_LATENT_CHANNELS = 4
|
| 84 |
+
OPEN_SORA_SCALE_FACTOR_SPATIAL = 8
|
| 85 |
+
OPEN_SORA_SCALE_FACTOR_TEMPORAL = 4
|
| 86 |
+
|
| 87 |
+
# -- ACT Loop Parameters (Phase 19) --
|
| 88 |
+
BYTEHEAD_ACT_MAX_ITERS = 3
|
| 89 |
+
BYTEHEAD_ACT_HALT_CONSECUTIVE = 2
|
| 90 |
+
BYTEHEAD_ACT_PONDER_LAMBDA = 0.01
|
| 91 |
+
|
| 92 |
+
VIDEOHEAD_ACT_MIN_FPS = 1
|
| 93 |
+
VIDEOHEAD_ACT_MAX_FPS = 60
|
| 94 |
+
VIDEOHEAD_ACT_FRAME_CHUNK = 8
|
| 95 |
+
|
| 96 |
+
TALKERHEAD_ACT_CHUNK_FRAMES = 500
|
| 97 |
+
|
| 98 |
+
# -- Timestamp Encoding (Phase 19) --
|
| 99 |
+
TIMESTAMP_MAX_PERIOD = 10000.0
|
| 100 |
+
|
| 101 |
+
# -- Temporal Frame Buffer (Phase 19) --
|
| 102 |
+
FRAME_BUFFER_LOCAL_SIZE = 3
|
| 103 |
+
FRAME_BUFFER_CACHE_STRIDE = 4
|
| 104 |
+
|
| 105 |
+
SPECIAL_VOCAB = {
|
| 106 |
+
# Control
|
| 107 |
+
'PAD': 256, 'BOS': 257, 'EOS': 258, 'STOP': 259,
|
| 108 |
+
# Roles
|
| 109 |
+
'SYSTEM': 260, 'USER': 261, 'ASSISTANT': 262,
|
| 110 |
+
# Reasoning
|
| 111 |
+
'SCRATCHPAD': 263, 'PLAN': 264, 'REFLECTION': 265, 'SUMMARY': 266,
|
| 112 |
+
# Tool use
|
| 113 |
+
'ACTION': 267, 'TOOL': 268, 'TOOL_RESULT': 269,
|
| 114 |
+
# Code
|
| 115 |
+
'CODE': 270, 'CODE_BLOCK': 271, 'EXECUTION': 272,
|
| 116 |
+
# RAG
|
| 117 |
+
'SEARCH': 273, 'CONTEXT': 274, 'CITATION': 275,
|
| 118 |
+
# Quality / format
|
| 119 |
+
'ERROR': 276, 'FORMAT': 277,
|
| 120 |
+
# Multimodal
|
| 121 |
+
'IMAGE': 278, 'TEXT': 279, 'AUDIO': 280,
|
| 122 |
+
'VIDEO': 281, 'SPEAK': 282, 'IMG_GEN': 283,
|
| 123 |
+
# Future
|
| 124 |
+
'RES1': 284, 'RES2': 285, 'RES3': 286, 'RESERVED': 287,
|
| 125 |
+
}
|
arbitor/converters/convert_to_ternary2.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def pack_ternary(w):
|
| 9 |
+
q = torch.empty_like(w, dtype=torch.uint8)
|
| 10 |
+
q[w < 0] = 0
|
| 11 |
+
q[w == 0] = 1
|
| 12 |
+
q[w > 0] = 2
|
| 13 |
+
|
| 14 |
+
flat = q.flatten()
|
| 15 |
+
pad = (-len(flat)) % 4
|
| 16 |
+
if pad:
|
| 17 |
+
flat = torch.cat([flat, torch.zeros(pad, dtype=torch.uint8, device=flat.device)])
|
| 18 |
+
|
| 19 |
+
flat = flat.view(-1, 4)
|
| 20 |
+
|
| 21 |
+
packed = (
|
| 22 |
+
flat[:, 0]
|
| 23 |
+
| (flat[:, 1] << 2)
|
| 24 |
+
| (flat[:, 2] << 4)
|
| 25 |
+
| (flat[:, 3] << 6)
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
return packed.cpu(), w.shape
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_model(model, path="trigram-morph.pt"):
|
| 32 |
+
ternary_weights = {}
|
| 33 |
+
for name, param in model.named_parameters():
|
| 34 |
+
if "weight" in name and param.ndim >= 2 and "embed" not in name:
|
| 35 |
+
T = StickyZoneSTE.apply(param.data, THRESHOLD)
|
| 36 |
+
packed, shape = pack_ternary(T)
|
| 37 |
+
ternary_weights[name] = {"packed": packed, "shape": shape}
|
| 38 |
+
|
| 39 |
+
torch.save({
|
| 40 |
+
"model_state_dict": model.state_dict(),
|
| 41 |
+
"config": {
|
| 42 |
+
"vocab": VOCAB,
|
| 43 |
+
"embedding_dim": EMBEDDING_DIM,
|
| 44 |
+
"trigram_dim": HIDDEN_DIM,
|
| 45 |
+
"ffn_hidden": FFN_HIDDEN,
|
| 46 |
+
"ctx": CTX,
|
| 47 |
+
"threshold": THRESHOLD,
|
| 48 |
+
},
|
| 49 |
+
"ternary_packed": ternary_weights,
|
| 50 |
+
"format": "factorized_scaled_ternary",
|
| 51 |
+
"bpw": 1.58,
|
| 52 |
+
}, path)
|
| 53 |
+
total = sum(p.numel() for p in model.parameters())
|
| 54 |
+
print(f"Saved {total:,} params to {path}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_model(path="trigram-morph.pt", device="cpu"):
|
| 58 |
+
checkpoint = torch.load(path, map_location=device, weights_only=False)
|
| 59 |
+
model = ARBModel()
|
| 60 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 61 |
+
model.to(device)
|
| 62 |
+
model.eval()
|
| 63 |
+
return model
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
from ..trigram import ARBModel
|
| 68 |
+
model = ARBModel()
|
| 69 |
+
total = sum(p.numel() for p in model.parameters())
|
| 70 |
+
ternary = sum(
|
| 71 |
+
p.numel() for n, p in model.named_parameters()
|
| 72 |
+
if "weight" in n and p.ndim >= 2 and "embed" not in n
|
| 73 |
+
)
|
| 74 |
+
fp32 = sum(
|
| 75 |
+
p.numel() for n, p in model.named_parameters()
|
| 76 |
+
if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
|
| 77 |
+
)
|
| 78 |
+
print(f"Total params: {total:,}")
|
| 79 |
+
print(f"Ternary params (1.58 BPW): {ternary:,}")
|
| 80 |
+
print(f"FP32 params: {fp32:,}")
|
| 81 |
+
save_model(model)
|
arbitor/converters/convert_to_ternary54.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def pack_ternary_34(w):
|
| 9 |
+
q = torch.empty_like(w, dtype=torch.uint8)
|
| 10 |
+
|
| 11 |
+
q[w < 0] = 0
|
| 12 |
+
q[w == 0] = 1
|
| 13 |
+
q[w > 0] = 2
|
| 14 |
+
|
| 15 |
+
flat = q.flatten()
|
| 16 |
+
|
| 17 |
+
# pad to multiple of 34
|
| 18 |
+
pad = (-len(flat)) % 34
|
| 19 |
+
|
| 20 |
+
if pad:
|
| 21 |
+
flat = torch.cat([
|
| 22 |
+
flat,
|
| 23 |
+
torch.ones(pad, dtype=torch.uint8, device=flat.device)
|
| 24 |
+
])
|
| 25 |
+
|
| 26 |
+
flat = flat.view(-1, 34)
|
| 27 |
+
|
| 28 |
+
packed = torch.zeros(
|
| 29 |
+
flat.shape[0],
|
| 30 |
+
dtype=torch.uint64,
|
| 31 |
+
device=flat.device
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
multiplier = 1
|
| 35 |
+
|
| 36 |
+
for i in range(34):
|
| 37 |
+
packed += flat[:, i].to(torch.uint64) * multiplier
|
| 38 |
+
multiplier *= 3
|
| 39 |
+
|
| 40 |
+
return packed.cpu(), w.shape, pad
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def unpack_ternary_34(packed, shape, pad=0):
|
| 44 |
+
packed = packed.to(torch.uint64)
|
| 45 |
+
|
| 46 |
+
out = []
|
| 47 |
+
|
| 48 |
+
for _ in range(34):
|
| 49 |
+
trit = packed % 3
|
| 50 |
+
packed //= 3
|
| 51 |
+
out.append(trit)
|
| 52 |
+
|
| 53 |
+
out = torch.stack(out, dim=1).flatten()
|
| 54 |
+
|
| 55 |
+
if pad:
|
| 56 |
+
out = out[:-pad]
|
| 57 |
+
|
| 58 |
+
out = out.view(shape)
|
| 59 |
+
|
| 60 |
+
# restore ternary values
|
| 61 |
+
out = out.to(torch.int8)
|
| 62 |
+
|
| 63 |
+
out[out == 0] = -1
|
| 64 |
+
out[out == 1] = 0
|
| 65 |
+
out[out == 2] = 1
|
| 66 |
+
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def save_model(model, path="trigram-morph.pt"):
|
| 71 |
+
ternary_weights = {}
|
| 72 |
+
for name, param in model.named_parameters():
|
| 73 |
+
if "weight" in name and param.ndim >= 2 and "embed" not in name:
|
| 74 |
+
T = StickyZoneSTE.apply(param.data, THRESHOLD)
|
| 75 |
+
packed, shape = pack_ternary(T)
|
| 76 |
+
ternary_weights[name] = {"packed": packed, "shape": shape}
|
| 77 |
+
|
| 78 |
+
torch.save({
|
| 79 |
+
"model_state_dict": model.state_dict(),
|
| 80 |
+
"config": {
|
| 81 |
+
"vocab": VOCAB,
|
| 82 |
+
"embedding_dim": EMBEDDING_DIM,
|
| 83 |
+
"trigram_dim": HIDDEN_DIM,
|
| 84 |
+
"ffn_hidden": FFN_HIDDEN,
|
| 85 |
+
"ctx": CTX,
|
| 86 |
+
"threshold": THRESHOLD,
|
| 87 |
+
},
|
| 88 |
+
"ternary_packed": ternary_weights,
|
| 89 |
+
"format": "factorized_scaled_ternary",
|
| 90 |
+
"bpw": 1.58,
|
| 91 |
+
}, path)
|
| 92 |
+
total = sum(p.numel() for p in model.parameters())
|
| 93 |
+
print(f"Saved {total:,} params to {path}")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def load_model(path="trigram-morph.pt", device="cpu"):
|
| 97 |
+
checkpoint = torch.load(path, map_location=device, weights_only=False)
|
| 98 |
+
model = ARBModel()
|
| 99 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 100 |
+
model.to(device)
|
| 101 |
+
model.eval()
|
| 102 |
+
return model
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
|
| 106 |
+
from ..trigram import ARBModel
|
| 107 |
+
model = ARBModel()
|
| 108 |
+
total = sum(p.numel() for p in model.parameters())
|
| 109 |
+
ternary = sum(
|
| 110 |
+
p.numel() for n, p in model.named_parameters()
|
| 111 |
+
if "weight" in n and p.ndim >= 2 and "embed" not in n
|
| 112 |
+
)
|
| 113 |
+
fp32 = sum(
|
| 114 |
+
p.numel() for n, p in model.named_parameters()
|
| 115 |
+
if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
|
| 116 |
+
)
|
| 117 |
+
print(f"Total params: {total:,}")
|
| 118 |
+
print(f"Ternary params (1.58 BPW): {ternary:,}")
|
| 119 |
+
print(f"FP32 params: {fp32:,}")
|
| 120 |
+
save_model(model)
|
arbitor/converters/convert_to_ternary64.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def pack_ternary(w):
|
| 9 |
+
q = torch.empty_like(w, dtype=torch.uint8)
|
| 10 |
+
q[w < 0] = 0
|
| 11 |
+
q[w == 0] = 1
|
| 12 |
+
q[w > 0] = 2
|
| 13 |
+
|
| 14 |
+
flat = q.flatten()
|
| 15 |
+
pad = (-len(flat)) % 40 # 40 trit -> 64 bit packing - Higher conversation than 1 trit -> 2 bit + uint64 performance
|
| 16 |
+
if pad:
|
| 17 |
+
flat = torch.cat([
|
| 18 |
+
flat,
|
| 19 |
+
torch.zeros(pad, dtype=torch.uint8, device=flat.device)
|
| 20 |
+
])
|
| 21 |
+
|
| 22 |
+
flat = flat.view(-1, 40)
|
| 23 |
+
|
| 24 |
+
packed = torch.zeros(flat.shape[0], dtype=torch.uint64)
|
| 25 |
+
|
| 26 |
+
multiplier = 1
|
| 27 |
+
|
| 28 |
+
for i in range(40):
|
| 29 |
+
packed += flat[:, i].to(torch.uint64) * multiplier
|
| 30 |
+
multiplier *= 3
|
| 31 |
+
|
| 32 |
+
return packed.cpu(), w.shape, pad
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def unpack_ternary_40(packed, shape, pad=0):
|
| 36 |
+
packed = packed.to(torch.uint64)
|
| 37 |
+
|
| 38 |
+
out = []
|
| 39 |
+
|
| 40 |
+
for _ in range(40):
|
| 41 |
+
trit = packed % 3
|
| 42 |
+
packed //= 3
|
| 43 |
+
out.append(trit)
|
| 44 |
+
|
| 45 |
+
out = torch.stack(out, dim=1).flatten()
|
| 46 |
+
|
| 47 |
+
if pad:
|
| 48 |
+
out = out[:-pad]
|
| 49 |
+
|
| 50 |
+
out = out.view(shape)
|
| 51 |
+
|
| 52 |
+
out = out.to(torch.int8)
|
| 53 |
+
|
| 54 |
+
out[out == 0] = -1
|
| 55 |
+
out[out == 1] = 0
|
| 56 |
+
out[out == 2] = 1
|
| 57 |
+
|
| 58 |
+
return out
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def save_model(model, path="trigram-morph.pt"):
|
| 62 |
+
ternary_weights = {}
|
| 63 |
+
for name, param in model.named_parameters():
|
| 64 |
+
if "weight" in name and param.ndim >= 2 and "embed" not in name:
|
| 65 |
+
T = StickyZoneSTE.apply(param.data, THRESHOLD)
|
| 66 |
+
packed, shape = pack_ternary(T)
|
| 67 |
+
ternary_weights[name] = {"packed": packed, "shape": shape}
|
| 68 |
+
|
| 69 |
+
torch.save({
|
| 70 |
+
"model_state_dict": model.state_dict(),
|
| 71 |
+
"config": {
|
| 72 |
+
"vocab": VOCAB,
|
| 73 |
+
"embedding_dim": EMBEDDING_DIM,
|
| 74 |
+
"trigram_dim": HIDDEN_DIM,
|
| 75 |
+
"ffn_hidden": FFN_HIDDEN,
|
| 76 |
+
"ctx": CTX,
|
| 77 |
+
"threshold": THRESHOLD,
|
| 78 |
+
},
|
| 79 |
+
"ternary_packed": ternary_weights,
|
| 80 |
+
"format": "factorized_scaled_ternary",
|
| 81 |
+
"bpw": 1.58,
|
| 82 |
+
}, path)
|
| 83 |
+
total = sum(p.numel() for p in model.parameters())
|
| 84 |
+
print(f"Saved {total:,} params to {path}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def load_model(path="trigram-morph.pt", device="cpu"):
|
| 88 |
+
checkpoint = torch.load(path, map_location=device, weights_only=False)
|
| 89 |
+
model = ARBModel()
|
| 90 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 91 |
+
model.to(device)
|
| 92 |
+
model.eval()
|
| 93 |
+
return model
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
from ..trigram import ARBModel
|
| 98 |
+
model = ARBModel()
|
| 99 |
+
total = sum(p.numel() for p in model.parameters())
|
| 100 |
+
ternary = sum(
|
| 101 |
+
p.numel() for n, p in model.named_parameters()
|
| 102 |
+
if "weight" in n and p.ndim >= 2 and "embed" not in n
|
| 103 |
+
)
|
| 104 |
+
fp32 = sum(
|
| 105 |
+
p.numel() for n, p in model.named_parameters()
|
| 106 |
+
if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
|
| 107 |
+
)
|
| 108 |
+
print(f"Total params: {total:,}")
|
| 109 |
+
print(f"Ternary params (1.58 BPW): {ternary:,}")
|
| 110 |
+
print(f"FP32 params: {fp32:,}")
|
| 111 |
+
save_model(model)
|
arbitor/converters/convert_to_ternary8.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
# Lightweight imports used by pack_ternary/unpack_ternary (called by core system)
|
| 6 |
+
# No circular deps here β these are just type conversions
|
| 7 |
+
|
| 8 |
+
def pack_ternary(w):
|
| 9 |
+
q = torch.empty_like(w, dtype=torch.uint8)
|
| 10 |
+
q[w < 0] = 0
|
| 11 |
+
q[w == 0] = 1
|
| 12 |
+
q[w > 0] = 2
|
| 13 |
+
|
| 14 |
+
flat = q.flatten()
|
| 15 |
+
pad = (-len(flat)) % 5 # 5 trit -> 8 bit packing - Higher conversation than 1 trit -> 2 bit
|
| 16 |
+
if pad:
|
| 17 |
+
flat = torch.cat([
|
| 18 |
+
flat,
|
| 19 |
+
torch.zeros(pad, dtype=torch.uint8, device=flat.device)
|
| 20 |
+
])
|
| 21 |
+
|
| 22 |
+
flat = flat.view(-1, 5)
|
| 23 |
+
|
| 24 |
+
packed = (
|
| 25 |
+
flat[:, 0]
|
| 26 |
+
+ flat[:, 1] * 3
|
| 27 |
+
+ flat[:, 2] * 9
|
| 28 |
+
+ flat[:, 3] * 27
|
| 29 |
+
+ flat[:, 4] * 81
|
| 30 |
+
).to(torch.uint8)
|
| 31 |
+
|
| 32 |
+
return packed.cpu(), w.shape, pad
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def unpack_ternary(packed, shape, pad=0):
|
| 36 |
+
packed = packed.to(torch.int16)
|
| 37 |
+
|
| 38 |
+
t0 = packed % 3
|
| 39 |
+
packed //= 3
|
| 40 |
+
|
| 41 |
+
t1 = packed % 3
|
| 42 |
+
packed //= 3
|
| 43 |
+
|
| 44 |
+
t2 = packed % 3
|
| 45 |
+
packed //= 3
|
| 46 |
+
|
| 47 |
+
t3 = packed % 3
|
| 48 |
+
packed //= 3
|
| 49 |
+
|
| 50 |
+
t4 = packed % 3
|
| 51 |
+
|
| 52 |
+
out = torch.stack([t0, t1, t2, t3, t4], dim=1).flatten()
|
| 53 |
+
|
| 54 |
+
if pad:
|
| 55 |
+
out = out[:-pad]
|
| 56 |
+
|
| 57 |
+
out = out.view(shape)
|
| 58 |
+
|
| 59 |
+
# map back
|
| 60 |
+
out = out.to(torch.int8)
|
| 61 |
+
out[out == 0] = -1
|
| 62 |
+
out[out == 1] = 0
|
| 63 |
+
out[out == 2] = 1
|
| 64 |
+
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def save_model(model, path="models/conversions/arb-model.pt"):
|
| 69 |
+
import os
|
| 70 |
+
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
|
| 71 |
+
torch.save({"model_state_dict": model.state_dict()}, path)
|
| 72 |
+
total = sum(p.numel() for p in model.parameters())
|
| 73 |
+
print(f"Saved {total:,} params to {path}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_model(path="models/conversions/arb-model.pt", device="cpu"):
|
| 77 |
+
from ..trigram import ARBModel
|
| 78 |
+
checkpoint = torch.load(path, map_location=device, weights_only=False)
|
| 79 |
+
model = ARBModel()
|
| 80 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 81 |
+
model.to(device)
|
| 82 |
+
model.eval()
|
| 83 |
+
return model
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
from ..trigram import ARBModel
|
| 88 |
+
model = ARBModel()
|
| 89 |
+
total = sum(p.numel() for p in model.parameters())
|
| 90 |
+
ternary = sum(
|
| 91 |
+
p.numel() for n, p in model.named_parameters()
|
| 92 |
+
if "weight" in n and p.ndim >= 2 and "embed" not in n
|
| 93 |
+
)
|
| 94 |
+
fp32 = sum(
|
| 95 |
+
p.numel() for n, p in model.named_parameters()
|
| 96 |
+
if not ("weight" in n and p.ndim >= 2 and "embed" not in n)
|
| 97 |
+
)
|
| 98 |
+
print(f"Total params: {total:,}")
|
| 99 |
+
print(f"Ternary params (1.58 BPW): {ternary:,}")
|
| 100 |
+
print(f"FP32 params: {fp32:,}")
|
| 101 |
+
save_model(model)
|
arbitor/decoders.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Decoder modules β video diffusion, audio codec, speech generation.
|
| 2 |
+
|
| 3 |
+
These modules convert HIDDEN_DIM relational states into modality-specific outputs:
|
| 4 |
+
video (latent diffusion), audio (codec tokens), and speech (token striding + codec).
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
|
| 10 |
+
from .kernel.triton_video import video_denoise_step
|
| 11 |
+
from .config import HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, \
|
| 12 |
+
VIDEO_LATENT_CHANNELS, VIDEO_MAX_STEPS, VIDEO_HEIGHT, VIDEO_WIDTH, \
|
| 13 |
+
VIDEOHEAD_ACT_MIN_FPS, VIDEOHEAD_ACT_MAX_FPS, VIDEOHEAD_ACT_FRAME_CHUNK, \
|
| 14 |
+
TALKERHEAD_ACT_CHUNK_FRAMES
|
| 15 |
+
from .components import TernaryEmbeddingTable
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LTIInjection(nn.Module):
|
| 19 |
+
"""LTI state injection: h = A*h + B*e + trans_out.
|
| 20 |
+
Spectral radius < 1 guaranteed by construction via ZOH discretization.
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, dim: int):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.log_A = nn.Parameter(torch.zeros(dim))
|
| 25 |
+
self.log_dt = nn.Parameter(torch.zeros(1))
|
| 26 |
+
self.B = nn.Parameter(torch.ones(dim) * 0.1)
|
| 27 |
+
for p in (self.log_A, self.log_dt, self.B):
|
| 28 |
+
p.requires_grad_(False)
|
| 29 |
+
|
| 30 |
+
def get_A(self):
|
| 31 |
+
return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))
|
| 32 |
+
|
| 33 |
+
def forward(self, h, e, trans_out):
|
| 34 |
+
return self.get_A() * h + self.B * e + trans_out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class VideoHead(nn.Module):
|
| 38 |
+
"""Scaled latent diffusion with cross-attention conditioning, frame gate, and 4-frame latent.
|
| 39 |
+
|
| 40 |
+
Produces [B, ch, 4, H', W'] latents (4-frame temporal chunks) per D-102.
|
| 41 |
+
Frame gate controls adaptive fps in [MIN_FPS, MAX_FPS] range.
|
| 42 |
+
"""
|
| 43 |
+
def __init__(self, tscale_type=TScaleType.T32, max_steps=VIDEO_MAX_STEPS,
|
| 44 |
+
latent_channels=VIDEO_LATENT_CHANNELS, height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
|
| 45 |
+
min_fps=VIDEOHEAD_ACT_MIN_FPS, max_fps=VIDEOHEAD_ACT_MAX_FPS,
|
| 46 |
+
frame_chunk=VIDEOHEAD_ACT_FRAME_CHUNK):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.max_steps = max_steps
|
| 49 |
+
self.latent_channels = latent_channels
|
| 50 |
+
self.height = height
|
| 51 |
+
self.width = width
|
| 52 |
+
self.latent_dim = latent_channels * height * width
|
| 53 |
+
self.halt_threshold = 0.05
|
| 54 |
+
self.min_fps = min_fps
|
| 55 |
+
self.max_fps = max_fps
|
| 56 |
+
self.frame_chunk = frame_chunk
|
| 57 |
+
|
| 58 |
+
self.cross_attn_q = TernaryScaleTensor(self.latent_dim, HIDDEN_DIM, tscale_type=tscale_type)
|
| 59 |
+
self.cross_attn_kv = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
|
| 60 |
+
self.diffusion_step = TernaryScaleTensor(HIDDEN_DIM, self.latent_dim, tscale_type=tscale_type)
|
| 61 |
+
self.halt_unit = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
|
| 62 |
+
self.frame_gate = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
|
| 63 |
+
self.noise_embed = TernaryEmbeddingTable(max_steps, HIDDEN_DIM, tscale_type=tscale_type)
|
| 64 |
+
self.lti = LTIInjection(self.latent_dim)
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
|
| 67 |
+
def _compute_fps(self, cond):
|
| 68 |
+
frame_prob = torch.sigmoid(self.frame_gate(cond))
|
| 69 |
+
fps = self.min_fps + frame_prob * (self.max_fps - self.min_fps)
|
| 70 |
+
return fps.mean().item()
|
| 71 |
+
|
| 72 |
+
def forward(self, relational, max_steps=None, duration_seconds=1.0):
|
| 73 |
+
B, T, D = relational.shape
|
| 74 |
+
max_steps = max_steps or self.max_steps
|
| 75 |
+
cond = relational.mean(dim=1, keepdim=True)
|
| 76 |
+
|
| 77 |
+
fps = self._compute_fps(cond)
|
| 78 |
+
n_frames = max(1, int(fps * duration_seconds))
|
| 79 |
+
n_latents = min((n_frames + self.frame_chunk - 1) // self.frame_chunk, max_steps)
|
| 80 |
+
|
| 81 |
+
all_latents = []
|
| 82 |
+
for chunk_idx in range(n_latents):
|
| 83 |
+
latent = torch.randn(B, 1, self.latent_dim, device=relational.device,
|
| 84 |
+
requires_grad=torch.is_grad_enabled())
|
| 85 |
+
for step in range(max_steps):
|
| 86 |
+
q = self.cross_attn_q(latent)
|
| 87 |
+
kv = self.cross_attn_kv(cond.expand(-1, T, -1))
|
| 88 |
+
context = kv.mean(dim=1, keepdim=True)
|
| 89 |
+
step_embed = self.noise_embed(torch.tensor(step, device=relational.device))
|
| 90 |
+
step_embed = step_embed.expand(B, 1, -1)
|
| 91 |
+
step_input = q + context + step_embed
|
| 92 |
+
pred_noise = self.diffusion_step(step_input)
|
| 93 |
+
alpha = 0.9 ** step
|
| 94 |
+
trans_out = video_denoise_step(latent, pred_noise, alpha)
|
| 95 |
+
h = torch.zeros(B, 1, self.latent_dim, device=context.device)
|
| 96 |
+
h[:, :, :HIDDEN_DIM] = context
|
| 97 |
+
latent = self.lti(latent, h, trans_out)
|
| 98 |
+
halt = torch.sigmoid(self.halt_unit(context))
|
| 99 |
+
if halt.mean() > self.halt_threshold and step > 1:
|
| 100 |
+
break
|
| 101 |
+
all_latents.append(latent.view(B, self.latent_channels, 1, self.height, self.width))
|
| 102 |
+
|
| 103 |
+
return torch.cat(all_latents, dim=2)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class MRFBlock(nn.Module):
|
| 107 |
+
"""Multi-Receptive Field Fusion block from HiFi-GAN."""
|
| 108 |
+
def __init__(self, channels, kernel_sizes=(3, 5, 7)):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.convs = nn.ModuleList([
|
| 111 |
+
nn.Sequential(
|
| 112 |
+
nn.LeakyReLU(0.1),
|
| 113 |
+
nn.Conv1d(channels, channels, k, padding=k//2, dilation=1),
|
| 114 |
+
)
|
| 115 |
+
for k in kernel_sizes
|
| 116 |
+
])
|
| 117 |
+
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
return sum(conv(x) for conv in self.convs) / len(self.convs)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class TinyNeuralCodec(nn.Module):
|
| 123 |
+
"""Lightweight neural audio decoder (frozen float32 sidecar).
|
| 124 |
+
|
| 125 |
+
Maps byte token sequences to 16 kHz audio waveforms via transposed conv.
|
| 126 |
+
Token rate: 50 Hz β output: [B, 1, T * 320] at 16 kHz.
|
| 127 |
+
"""
|
| 128 |
+
def __init__(self, vocab=AUDIO_VOCAB, embed_dim=512, upsample_ratios=(5, 4, 4, 4)):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.embed = nn.Embedding(vocab, embed_dim)
|
| 131 |
+
|
| 132 |
+
in_ch = embed_dim
|
| 133 |
+
self.blocks = nn.ModuleList()
|
| 134 |
+
for i, ratio in enumerate(upsample_ratios):
|
| 135 |
+
out_ch = max(1, embed_dim // (2 ** (i + 1)))
|
| 136 |
+
k = ratio * 2
|
| 137 |
+
pad = (ratio + 1) // 2 if ratio % 2 else ratio // 2
|
| 138 |
+
op = max(0, ratio + 2 * pad - k)
|
| 139 |
+
block = nn.Sequential(
|
| 140 |
+
nn.ConvTranspose1d(in_ch, out_ch, k, stride=ratio, padding=pad, output_padding=op),
|
| 141 |
+
MRFBlock(out_ch),
|
| 142 |
+
)
|
| 143 |
+
self.blocks.append(block)
|
| 144 |
+
in_ch = out_ch
|
| 145 |
+
|
| 146 |
+
self.to_audio = nn.Conv1d(in_ch, 1, kernel_size=7, padding=3)
|
| 147 |
+
|
| 148 |
+
def forward(self, tokens):
|
| 149 |
+
x = self.embed(tokens)
|
| 150 |
+
x = x.permute(0, 2, 1)
|
| 151 |
+
for block in self.blocks:
|
| 152 |
+
x = block(x)
|
| 153 |
+
x = self.to_audio(x)
|
| 154 |
+
return torch.tanh(x)
|
| 155 |
+
|
| 156 |
+
def encode_audio(self, audio, frame_rate=AUDIO_FRAME_RATE, sr=AUDIO_SR):
|
| 157 |
+
B, C, T = audio.shape
|
| 158 |
+
frame_len = sr // frame_rate
|
| 159 |
+
pad = (frame_len - T % frame_len) % frame_len
|
| 160 |
+
if pad > 0:
|
| 161 |
+
audio = F.pad(audio, (0, pad))
|
| 162 |
+
frames = audio.unfold(2, frame_len, frame_len)
|
| 163 |
+
frames = frames.mean(dim=1)
|
| 164 |
+
emb = self.embed.weight
|
| 165 |
+
B, NF, FL = frames.shape
|
| 166 |
+
frames_flat = frames.reshape(-1, FL)
|
| 167 |
+
frame_energy = frames_flat.mean(dim=1)
|
| 168 |
+
tokens = torch.clamp(((frame_energy + 1.0) * 127.5).long(), 0, 255)
|
| 169 |
+
tokens = tokens.reshape(B, NF)
|
| 170 |
+
recon = self(tokens)
|
| 171 |
+
if pad > 0:
|
| 172 |
+
recon = recon[:, :, :T]
|
| 173 |
+
return tokens, recon
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class TalkerHead(nn.Module):
|
| 177 |
+
"""Audio generation head with temporal stride and chunked ACT generation.
|
| 178 |
+
|
| 179 |
+
2-layer MLP: 8192 β 8192 β 288.
|
| 180 |
+
Generates byte token predictions at 50 Hz frame rate in 500-frame chunks.
|
| 181 |
+
TinyNeuralCodec decodes the predicted tokens to audio waveform.
|
| 182 |
+
"""
|
| 183 |
+
def __init__(self, tscale_type=TScaleType.T32,
|
| 184 |
+
chunk_frames=TALKERHEAD_ACT_CHUNK_FRAMES):
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
|
| 187 |
+
self.hidden = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
|
| 188 |
+
self.hidden_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
|
| 189 |
+
self.head = TernaryScaleTensor(HIDDEN_DIM, AUDIO_VOCAB, tscale_type=tscale_type)
|
| 190 |
+
self.codec = None
|
| 191 |
+
self.max_frames = chunk_frames
|
| 192 |
+
self.chunk_frames = chunk_frames
|
| 193 |
+
|
| 194 |
+
def load_codec(self, device='cuda'):
|
| 195 |
+
if self.codec is None:
|
| 196 |
+
self.codec = TinyNeuralCodec().to(device)
|
| 197 |
+
self.codec.eval()
|
| 198 |
+
return self.codec
|
| 199 |
+
|
| 200 |
+
def token_logits(self, x, max_frames=None):
|
| 201 |
+
max_frames = max_frames or self.max_frames
|
| 202 |
+
cond = self.norm(x)
|
| 203 |
+
cond = F.silu(self.hidden_norm(self.hidden(cond)))
|
| 204 |
+
stride = max(1, max_frames // max(1, cond.shape[1]))
|
| 205 |
+
logits = self.head(cond)
|
| 206 |
+
logits = logits.repeat_interleave(stride, dim=1)
|
| 207 |
+
if logits.shape[1] > max_frames:
|
| 208 |
+
logits = logits[:, :max_frames, :]
|
| 209 |
+
elif logits.shape[1] < max_frames:
|
| 210 |
+
pad = logits.new_zeros(logits.shape[0], max_frames - logits.shape[1], logits.shape[2])
|
| 211 |
+
logits = torch.cat([logits, pad], dim=1)
|
| 212 |
+
return logits
|
| 213 |
+
|
| 214 |
+
def forward(self, x, max_frames=None):
|
| 215 |
+
return self.token_logits(x, max_frames=max_frames).argmax(dim=-1)
|
| 216 |
+
|
| 217 |
+
def generate_audio(self, x, max_frames=None, return_all=True):
|
| 218 |
+
if max_frames is None:
|
| 219 |
+
max_frames = self.max_frames
|
| 220 |
+
all_tokens = []
|
| 221 |
+
remaining = max_frames
|
| 222 |
+
while remaining > 0:
|
| 223 |
+
chunk = min(remaining, self.chunk_frames)
|
| 224 |
+
tokens = self.forward(x, max_frames=chunk)
|
| 225 |
+
all_tokens.append(tokens)
|
| 226 |
+
remaining -= chunk
|
| 227 |
+
tokens = torch.cat(all_tokens, dim=1)
|
| 228 |
+
codec = self.load_codec(x.device if hasattr(x, 'device') else 'cuda')
|
| 229 |
+
with torch.no_grad():
|
| 230 |
+
waveform = codec(tokens)
|
| 231 |
+
return waveform, tokens
|
arbitor/encoders/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Encoder sidecar modules for the ARB system.
|
| 2 |
+
|
| 3 |
+
Each module exposes load(), encode(), decode() methods.
|
| 4 |
+
Loaded on-demand as frozen float/int8 sidecars.
|
| 5 |
+
"""
|
| 6 |
+
from ..decoders import TinyNeuralCodec, MRFBlock
|
| 7 |
+
from .audio import AudioVQEncoder
|
| 8 |
+
from .pig_vae import load_vae, VAEWrapper
|
| 9 |
+
from .opensora_vae import load_opensora_vae, OpenSoraVAEWrapper
|
| 10 |
+
from .vae2d import VAE2DEncoder, load_vae2d
|
| 11 |
+
from .mel_frontend import MelSpectrogram3Band
|
arbitor/encoders/audio.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio training encoder β VQ encoder for TalkerHead target preparation.
|
| 2 |
+
|
| 3 |
+
Training-only component (~5M float params). Maps audio at 50 Hz to 289-class byte tokens.
|
| 4 |
+
TinyNeuralCodec (the decoder) is in arbitor.components β shared with TalkerHead.
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from ..components import TernaryEmbeddingTable
|
| 10 |
+
from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TernaryConv1d(nn.Module):
|
| 14 |
+
"""Conv1d implemented as unfold + ternary linear projection."""
|
| 15 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
| 16 |
+
tscale_type=TScaleType.T32, bias=True):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.in_channels = in_channels
|
| 19 |
+
self.out_channels = out_channels
|
| 20 |
+
self.kernel_size = kernel_size
|
| 21 |
+
self.stride = stride
|
| 22 |
+
self.padding = padding
|
| 23 |
+
self.proj = TernaryScaleTensor(
|
| 24 |
+
in_channels * kernel_size,
|
| 25 |
+
out_channels,
|
| 26 |
+
tscale_type=tscale_type,
|
| 27 |
+
bias=bias,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
if self.padding:
|
| 32 |
+
x = F.pad(x, (self.padding, self.padding))
|
| 33 |
+
windows = x.unfold(2, self.kernel_size, self.stride)
|
| 34 |
+
windows = windows.permute(0, 2, 1, 3).reshape(x.size(0), -1, self.in_channels * self.kernel_size)
|
| 35 |
+
return self.proj(windows).permute(0, 2, 1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AudioVQEncoder(nn.Module):
|
| 39 |
+
"""Encodes audio to discrete byte tokens at 50 Hz for TalkerHead training.
|
| 40 |
+
|
| 41 |
+
Input: [B, 1, T] audio waveform at 16 kHz
|
| 42 |
+
Output: [B, T/320, 288] logits over byte vocab (50 Hz frame rate)
|
| 43 |
+
"""
|
| 44 |
+
def __init__(self, vocab=288, codebook_dim=64, downsample_ratios=(4, 4, 4, 5),
|
| 45 |
+
tscale_type=TScaleType.T32):
|
| 46 |
+
super().__init__()
|
| 47 |
+
in_ch = 1
|
| 48 |
+
self.down_blocks = nn.ModuleList()
|
| 49 |
+
for i, ratio in enumerate(downsample_ratios):
|
| 50 |
+
out_ch = min(128, 32 * (2 ** i))
|
| 51 |
+
block = nn.Sequential(
|
| 52 |
+
TernaryConv1d(in_ch, out_ch, kernel_size=ratio * 2, stride=ratio,
|
| 53 |
+
padding=ratio // 2, tscale_type=tscale_type),
|
| 54 |
+
nn.LeakyReLU(0.1),
|
| 55 |
+
TernaryConv1d(out_ch, out_ch, kernel_size=3, padding=1,
|
| 56 |
+
tscale_type=tscale_type),
|
| 57 |
+
nn.LeakyReLU(0.1),
|
| 58 |
+
)
|
| 59 |
+
self.down_blocks.append(block)
|
| 60 |
+
in_ch = out_ch
|
| 61 |
+
self.proj = TernaryScaleTensor(out_ch, codebook_dim, tscale_type=tscale_type, bias=True)
|
| 62 |
+
self.codebook = TernaryEmbeddingTable(vocab, codebook_dim, tscale_type=tscale_type)
|
| 63 |
+
self.out_proj = TernaryScaleTensor(codebook_dim, vocab, tscale_type=tscale_type, bias=True)
|
| 64 |
+
|
| 65 |
+
def forward(self, audio):
|
| 66 |
+
x = audio
|
| 67 |
+
for block in self.down_blocks:
|
| 68 |
+
x = block(x)
|
| 69 |
+
x = x.permute(0, 2, 1)
|
| 70 |
+
x = self.proj(x)
|
| 71 |
+
emb_idx = torch.arange(self.out_proj.out_dim, device=x.device)
|
| 72 |
+
emb = self.codebook(emb_idx).to(device=x.device, dtype=x.dtype)
|
| 73 |
+
dist = torch.cdist(x.float(), emb.unsqueeze(0).float())
|
| 74 |
+
indices = dist.argmin(dim=-1)
|
| 75 |
+
quantized = F.embedding(indices, emb)
|
| 76 |
+
quantized = x + (quantized - x).detach()
|
| 77 |
+
logits = self.out_proj(quantized)
|
| 78 |
+
return logits, indices
|
| 79 |
+
|
| 80 |
+
def encode(self, audio):
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
_, indices = self.forward(audio)
|
| 83 |
+
return indices
|
arbitor/encoders/mel_frontend.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mel spectrogram frontend for audio-to-image conversion.
|
| 2 |
+
|
| 3 |
+
Converts [B, T] audio waveform to [B, 3, 64, T_mel] 3-channel mel
|
| 4 |
+
spectrogram (low/mid/high frequency bands β RGB channels) suitable
|
| 5 |
+
for encoding through the 2D VAE encoder.
|
| 6 |
+
|
| 7 |
+
Band split:
|
| 8 |
+
- Channel 0 (low): 0-1000 Hz
|
| 9 |
+
- Channel 1 (mid): 1000-4000 Hz
|
| 10 |
+
- Channel 2 (high): 4000-8000 Hz
|
| 11 |
+
"""
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torchaudio
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MelSpectrogram3Band(nn.Module):
|
| 18 |
+
"""Audio β 3-channel mel spectrogram (low/mid/high bands β RGB).
|
| 19 |
+
|
| 20 |
+
Splits audio into 3 frequency bands and computes mel spectrograms
|
| 21 |
+
independently, stacked as RGB-like 3-channel image for VAE encoding.
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512,
|
| 24 |
+
n_mels=64, f_min=0, f_max=8000):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.sample_rate = sample_rate
|
| 27 |
+
self.n_fft = n_fft
|
| 28 |
+
self.hop_length = hop_length
|
| 29 |
+
self.n_mels = n_mels
|
| 30 |
+
|
| 31 |
+
self.mel_low = torchaudio.transforms.MelSpectrogram(
|
| 32 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
| 33 |
+
n_mels=n_mels, f_min=f_min, f_max=1000,
|
| 34 |
+
)
|
| 35 |
+
self.mel_mid = torchaudio.transforms.MelSpectrogram(
|
| 36 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
| 37 |
+
n_mels=n_mels, f_min=1000, f_max=4000,
|
| 38 |
+
)
|
| 39 |
+
self.mel_high = torchaudio.transforms.MelSpectrogram(
|
| 40 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
| 41 |
+
n_mels=n_mels, f_min=4000, f_max=f_max,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def forward(self, waveform):
|
| 45 |
+
if waveform.dim() == 1:
|
| 46 |
+
waveform = waveform.unsqueeze(0)
|
| 47 |
+
elif waveform.dim() == 3:
|
| 48 |
+
if waveform.shape[1] == 1:
|
| 49 |
+
waveform = waveform.squeeze(1)
|
| 50 |
+
else:
|
| 51 |
+
waveform = waveform.mean(dim=1)
|
| 52 |
+
|
| 53 |
+
spec_low = torchaudio.functional.amplitude_to_DB(
|
| 54 |
+
self.mel_low(waveform), multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
|
| 55 |
+
)
|
| 56 |
+
spec_mid = torchaudio.functional.amplitude_to_DB(
|
| 57 |
+
self.mel_mid(waveform), multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
|
| 58 |
+
)
|
| 59 |
+
spec_high = torchaudio.functional.amplitude_to_DB(
|
| 60 |
+
self.mel_high(waveform), multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
specs = []
|
| 64 |
+
for spec in [spec_low, spec_mid, spec_high]:
|
| 65 |
+
s_min = spec.amin(dim=(-2, -1), keepdim=True)
|
| 66 |
+
s_max = spec.amax(dim=(-2, -1), keepdim=True)
|
| 67 |
+
s_range = s_max - s_min + 1e-8
|
| 68 |
+
specs.append((spec - s_min) / s_range)
|
| 69 |
+
|
| 70 |
+
return torch.stack(specs, dim=1)
|
arbitor/encoders/models/__init__.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local model loader β loads encoder models from local cache, falls back to HF.
|
| 2 |
+
|
| 3 |
+
Model directories (saved via model.save_pretrained()):
|
| 4 |
+
dinov2-small/ β facebook/dinov2-small (21M params, 384-dim) vision
|
| 5 |
+
vit-base/ β google/vit-base-patch16-224 (86M, 768-dim) vision fallback
|
| 6 |
+
moonshine-base/ β UsefulSensors/moonshine-base (62M, 416-dim) audio
|
| 7 |
+
pig-vae/ β Wan2.1 VAE checkpoint (84M params) video latent codec
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
from arbitor.encoders.models import load_encoder, load_processor
|
| 11 |
+
|
| 12 |
+
model = load_encoder("dinov2-small")
|
| 13 |
+
processor = load_processor("dinov2-small", "image")
|
| 14 |
+
|
| 15 |
+
Download models:
|
| 16 |
+
python -m arbitor.encoders.models.download
|
| 17 |
+
"""
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))
|
| 21 |
+
|
| 22 |
+
# Map short names to (local_dir, hf_repo, type)
|
| 23 |
+
REGISTRY = {
|
| 24 |
+
"dinov2-small": {
|
| 25 |
+
"local": os.path.join(_MODELS_DIR, "dinov2-small"),
|
| 26 |
+
"hf": "facebook/dinov2-small",
|
| 27 |
+
"type": "auto",
|
| 28 |
+
},
|
| 29 |
+
"vit-base": {
|
| 30 |
+
"local": os.path.join(_MODELS_DIR, "vit-base"),
|
| 31 |
+
"hf": "google/vit-base-patch16-224",
|
| 32 |
+
"type": "auto",
|
| 33 |
+
},
|
| 34 |
+
"moonshine-base": {
|
| 35 |
+
"local": os.path.join(_MODELS_DIR, "moonshine-base"),
|
| 36 |
+
"hf": "UsefulSensors/moonshine-base",
|
| 37 |
+
"type": "auto",
|
| 38 |
+
},
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def resolve_path(name: str) -> tuple[str, dict]:
|
| 43 |
+
"""Return (local_path_or_hf_name, registry_entry)."""
|
| 44 |
+
entry = REGISTRY.get(name)
|
| 45 |
+
if entry is None:
|
| 46 |
+
raise ValueError(f"Unknown model: {name}. Options: {list(REGISTRY.keys())}")
|
| 47 |
+
if os.path.isdir(entry["local"]):
|
| 48 |
+
return entry["local"], entry
|
| 49 |
+
return entry["hf"], entry
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_encoder(name: str, device=None, **kwargs):
|
| 53 |
+
"""Load model from local cache, falling back to HuggingFace.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
name: Short name ("dinov2-small", "vit-base", "moonshine-base")
|
| 57 |
+
device: Optional device to move model to (e.g. "cuda", "cpu")
|
| 58 |
+
Returns:
|
| 59 |
+
Loaded model in eval mode
|
| 60 |
+
"""
|
| 61 |
+
from transformers import AutoModel
|
| 62 |
+
|
| 63 |
+
path, entry = resolve_path(name)
|
| 64 |
+
model = AutoModel.from_pretrained(path, low_cpu_mem_usage=True, **kwargs)
|
| 65 |
+
model.eval()
|
| 66 |
+
if device:
|
| 67 |
+
model = model.to(device)
|
| 68 |
+
return model
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_processor(name: str, modality: str = "image"):
|
| 72 |
+
"""Load processor (image processor or feature extractor) from local cache.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
name: Short model name
|
| 76 |
+
modality: "image" for AutoImageProcessor, "audio" for AutoFeatureExtractor
|
| 77 |
+
Returns:
|
| 78 |
+
Processor instance
|
| 79 |
+
"""
|
| 80 |
+
path, _ = resolve_path(name)
|
| 81 |
+
if modality == "audio":
|
| 82 |
+
from transformers import AutoFeatureExtractor
|
| 83 |
+
return AutoFeatureExtractor.from_pretrained(path)
|
| 84 |
+
else:
|
| 85 |
+
from transformers import AutoImageProcessor
|
| 86 |
+
return AutoImageProcessor.from_pretrained(path)
|
arbitor/encoders/models/download.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download encoder models to local cache (arbitor/encoders/models/).
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
python -m arbitor.encoders.models.download # Download all
|
| 5 |
+
python -m arbitor.encoders.models.download --model pig-vae --convert # Also convert GGUFβsafetensors
|
| 6 |
+
|
| 7 |
+
Models are saved to arbitor/encoders/models/{name}/ and loaded from there
|
| 8 |
+
by sequencers and encoders β no HuggingFace download needed at runtime.
|
| 9 |
+
"""
|
| 10 |
+
import os, sys, argparse, importlib
|
| 11 |
+
|
| 12 |
+
MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
|
| 14 |
+
REGISTRY = {
|
| 15 |
+
"pig-vae": {
|
| 16 |
+
"type": "pth",
|
| 17 |
+
"hf_repo": "Wan-AI/Wan2.1-T2V-1.3B",
|
| 18 |
+
"hf_file": "Wan2.1_VAE.pth",
|
| 19 |
+
"desc": "Video VAE (16 latent channels, 84M params)",
|
| 20 |
+
"gguf_repo": "calcuis/pig-vae",
|
| 21 |
+
"gguf_file": "pig_wan_vae_fp32-f16.gguf",
|
| 22 |
+
},
|
| 23 |
+
"opensora-vae": {
|
| 24 |
+
"type": "pipeline",
|
| 25 |
+
"hf_repo": "hpcai-tech/OpenSora-VAE-v1.2",
|
| 26 |
+
"desc": "3D VAE (4 latent channels, 384M params, 8Γ spatial + 4Γ temporal compression)",
|
| 27 |
+
},
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def convert_gguf_to_safetensors(name: str):
|
| 32 |
+
"""Convert GGUF checkpoint to safetensors (pig-vae only)."""
|
| 33 |
+
dest = os.path.join(MODELS_DIR, name)
|
| 34 |
+
gguf_path = os.path.join(dest, f"{name.replace('-', '_')}_fp32-f16.gguf")
|
| 35 |
+
# Try alternate names
|
| 36 |
+
if not os.path.isfile(gguf_path):
|
| 37 |
+
alt = os.path.join(dest, "pig_wan_vae_fp32-f16.gguf")
|
| 38 |
+
if os.path.isfile(alt):
|
| 39 |
+
gguf_path = alt
|
| 40 |
+
if not os.path.isfile(gguf_path):
|
| 41 |
+
print(f" No GGUF file found in {dest}")
|
| 42 |
+
return False
|
| 43 |
+
|
| 44 |
+
print(f" Converting {gguf_path} to safetensors...", flush=True)
|
| 45 |
+
import gguf
|
| 46 |
+
import safetensors.torch
|
| 47 |
+
|
| 48 |
+
reader = gguf.GGUFReader(gguf_path)
|
| 49 |
+
state_dict = {t.name: __import__('torch').tensor(t.data) for t in reader.tensors}
|
| 50 |
+
|
| 51 |
+
safetensors_path = os.path.join(dest, "model.safetensors")
|
| 52 |
+
safetensors.torch.save_file(state_dict, safetensors_path)
|
| 53 |
+
size = os.path.getsize(safetensors_path)
|
| 54 |
+
print(f" β Saved {safetensors_path} ({size/1e6:.0f} MB, {len(state_dict)} tensors)", flush=True)
|
| 55 |
+
return True
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def download_model(name: str, convert: bool = False):
|
| 59 |
+
"""Download a single model to local cache."""
|
| 60 |
+
entry = REGISTRY.get(name)
|
| 61 |
+
if entry is None:
|
| 62 |
+
print(f"Unknown model: {name}. Options: {list(REGISTRY.keys())}")
|
| 63 |
+
return False
|
| 64 |
+
|
| 65 |
+
dest = os.path.join(MODELS_DIR, name)
|
| 66 |
+
os.makedirs(dest, exist_ok=True)
|
| 67 |
+
|
| 68 |
+
if entry["type"] == "auto":
|
| 69 |
+
from transformers import AutoModel
|
| 70 |
+
print(f"Downloading {name} ({entry['desc']})...", flush=True)
|
| 71 |
+
model = AutoModel.from_pretrained(entry["hf_repo"], low_cpu_mem_usage=True)
|
| 72 |
+
model.save_pretrained(dest)
|
| 73 |
+
print(f" β {name} saved to {dest}", flush=True)
|
| 74 |
+
|
| 75 |
+
if "dinov2" in name or "vit" in name:
|
| 76 |
+
from transformers import AutoImageProcessor
|
| 77 |
+
proc = AutoImageProcessor.from_pretrained(entry["hf_repo"])
|
| 78 |
+
proc.save_pretrained(dest)
|
| 79 |
+
elif "moonshine" in name:
|
| 80 |
+
from transformers import AutoFeatureExtractor
|
| 81 |
+
proc = AutoFeatureExtractor.from_pretrained(entry["hf_repo"])
|
| 82 |
+
proc.save_pretrained(dest)
|
| 83 |
+
|
| 84 |
+
elif entry["type"] == "pth":
|
| 85 |
+
from huggingface_hub import hf_hub_download
|
| 86 |
+
print(f"Downloading {name} ({entry['desc']})...", flush=True)
|
| 87 |
+
hf_hub_download(entry["hf_repo"], entry["hf_file"],
|
| 88 |
+
local_dir=dest, local_dir_use_symlinks=False)
|
| 89 |
+
print(f" β {name} .pth saved to {dest}", flush=True)
|
| 90 |
+
|
| 91 |
+
if convert:
|
| 92 |
+
convert_gguf_to_safetensors(name)
|
| 93 |
+
|
| 94 |
+
return True
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def download_all(convert: bool = False):
|
| 98 |
+
success = 0
|
| 99 |
+
for name in REGISTRY:
|
| 100 |
+
if os.path.isdir(os.path.join(MODELS_DIR, name)):
|
| 101 |
+
existing = [f for f in os.listdir(os.path.join(MODELS_DIR, name))
|
| 102 |
+
if f.endswith(('.safetensors', '.pt', '.pth', '.gguf'))]
|
| 103 |
+
if existing:
|
| 104 |
+
print(f" β {name} already exists β skipping")
|
| 105 |
+
success += 1
|
| 106 |
+
continue
|
| 107 |
+
if download_model(name, convert=convert):
|
| 108 |
+
success += 1
|
| 109 |
+
print(f"\nDownloaded {success}/{len(REGISTRY)} models to {MODELS_DIR}")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
parser = argparse.ArgumentParser(description="Download encoder models for ARB")
|
| 114 |
+
parser.add_argument("--model", type=str, default=None,
|
| 115 |
+
help=f"Model ({', '.join(REGISTRY.keys())})")
|
| 116 |
+
parser.add_argument("--convert", action="store_true",
|
| 117 |
+
help="Convert pig-vae GGUFβsafetensors after download")
|
| 118 |
+
parser.add_argument("--list", action="store_true", help="List available models")
|
| 119 |
+
args = parser.parse_args()
|
| 120 |
+
|
| 121 |
+
if args.list:
|
| 122 |
+
for name, info in REGISTRY.items():
|
| 123 |
+
d = os.path.join(MODELS_DIR, name)
|
| 124 |
+
files = os.listdir(d) if os.path.isdir(d) else []
|
| 125 |
+
status = "β" if any(f.endswith(('.safetensors', '.pt', '.pth')) for f in files) else "β"
|
| 126 |
+
print(f" {status} {name:<20} {info['desc']}")
|
| 127 |
+
sys.exit(0)
|
| 128 |
+
|
| 129 |
+
if args.model:
|
| 130 |
+
download_model(args.model, convert=args.convert)
|
| 131 |
+
else:
|
| 132 |
+
download_all(convert=args.convert)
|
arbitor/encoders/models/opensora-vae/config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"VideoAutoencoderPipeline"
|
| 4 |
+
],
|
| 5 |
+
"cal_loss": false,
|
| 6 |
+
"freeze_vae_2d": false,
|
| 7 |
+
"from_pretrained": null,
|
| 8 |
+
"micro_frame_size": 17,
|
| 9 |
+
"model_type": "VideoAutoencoderPipeline",
|
| 10 |
+
"scale": [
|
| 11 |
+
3.85,
|
| 12 |
+
2.32,
|
| 13 |
+
2.33,
|
| 14 |
+
3.06
|
| 15 |
+
],
|
| 16 |
+
"shift": [
|
| 17 |
+
-0.1,
|
| 18 |
+
0.34,
|
| 19 |
+
0.27,
|
| 20 |
+
0.98
|
| 21 |
+
],
|
| 22 |
+
"torch_dtype": "float32",
|
| 23 |
+
"transformers_version": "4.36.2",
|
| 24 |
+
"vae_2d": {
|
| 25 |
+
"from_pretrained": "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
| 26 |
+
"local_files_only": false,
|
| 27 |
+
"micro_batch_size": 4,
|
| 28 |
+
"subfolder": "vae",
|
| 29 |
+
"type": "VideoAutoencoderKL"
|
| 30 |
+
},
|
| 31 |
+
"vae_temporal": {
|
| 32 |
+
"from_pretrained": null,
|
| 33 |
+
"type": "VAE_Temporal_SD"
|
| 34 |
+
}
|
| 35 |
+
}
|
arbitor/encoders/models/opensora-vae/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:057f368f538ca04540c0728a6b8ef80ff529077c5a1c4ba810eb8ba017b8d7c9
|
| 3 |
+
size 1573430548
|
arbitor/encoders/models/pig-vae/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5a9bb06d188abdf1585142159b84e2c6c8aaa3e64bb5c5792b7316ee2b44785
|
| 3 |
+
size 253879612
|
arbitor/encoders/opensora_vae.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Open-Sora 3D VAE v1.2 sidecar module.
|
| 2 |
+
|
| 3 |
+
Latent: [B, 4, T/4, H/8, W/8]
|
| 4 |
+
8Γ spatial compression, 4Γ temporal compression.
|
| 5 |
+
Frozen float32 sidecar (no gradients).
|
| 6 |
+
|
| 7 |
+
Uses PixArt SDXL VAE (from diffusers) for spatial encoding/decoding.
|
| 8 |
+
Temporal VAE requires opensora package or custom module loading.
|
| 9 |
+
"""
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from safetensors import safe_open
|
| 14 |
+
|
| 15 |
+
_LOCAL_VAE_DIR = os.path.join(os.path.dirname(__file__), "models", "opensora-vae")
|
| 16 |
+
_VAE_CONFIG = {
|
| 17 |
+
"scale": (3.85, 2.32, 2.33, 3.06),
|
| 18 |
+
"shift": (-0.10, 0.34, 0.27, 0.98),
|
| 19 |
+
"micro_frame_size": 17,
|
| 20 |
+
}
|
| 21 |
+
_QUANTO_CLASS_MARKERS = ("Q", "Quanto", "Quantized", "WeightQ")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _mark_quantized_sidecar(module, quant_type, applied):
|
| 25 |
+
module._arb_quantize_requested = quant_type
|
| 26 |
+
module._arb_quantized_int8 = bool(applied and quant_type == "int8")
|
| 27 |
+
module._arb_quantized = bool(applied)
|
| 28 |
+
for p in module.parameters():
|
| 29 |
+
p.requires_grad = False
|
| 30 |
+
return module
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _has_quantized_modules(module):
|
| 34 |
+
return any(
|
| 35 |
+
any(marker in type(child).__name__ for marker in _QUANTO_CLASS_MARKERS)
|
| 36 |
+
for child in module.modules()
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
|
| 41 |
+
_mark_quantized_sidecar(model, quantize_requested, quantized)
|
| 42 |
+
return model
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _quantize_int8_if_requested(model, quantize):
|
| 46 |
+
if quantize is None:
|
| 47 |
+
model = model.to(torch.bfloat16)
|
| 48 |
+
_mark_quantized_sidecar(model, quantize, False)
|
| 49 |
+
return model
|
| 50 |
+
try:
|
| 51 |
+
from optimum.quanto import quantize, freeze
|
| 52 |
+
qtype = {"int8": qint8}.get(quantize)
|
| 53 |
+
if qtype is None:
|
| 54 |
+
model = model.to(torch.bfloat16)
|
| 55 |
+
_mark_quantized_sidecar(model, quantize, False)
|
| 56 |
+
return model
|
| 57 |
+
quantize(model, weights=qtype)
|
| 58 |
+
freeze(model)
|
| 59 |
+
_mark_quantized_sidecar(model, quantize, _has_quantized_modules(model))
|
| 60 |
+
except ImportError:
|
| 61 |
+
model = model.to(torch.bfloat16)
|
| 62 |
+
_mark_quantized_sidecar(model, quantize, False)
|
| 63 |
+
return model
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_opensora_vae(device="cuda", quantize=None):
|
| 67 |
+
"""Load Open-Sora 3D VAE as frozen float32 sidecar.
|
| 68 |
+
|
| 69 |
+
Loads the spatial VAE from PixArt SDXL (diffusers) and the temporal
|
| 70 |
+
VAE from local safetensors. Falls back to spatial-only if temporal
|
| 71 |
+
module can't be loaded.
|
| 72 |
+
"""
|
| 73 |
+
try:
|
| 74 |
+
from diffusers import AutoencoderKL
|
| 75 |
+
except ImportError:
|
| 76 |
+
raise RuntimeError("need diffusers for Open-Sora VAE spatial component")
|
| 77 |
+
|
| 78 |
+
# Load spatial VAE
|
| 79 |
+
spatial_vae = AutoencoderKL.from_pretrained(
|
| 80 |
+
"PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
| 81 |
+
subfolder="vae",
|
| 82 |
+
torch_dtype=torch.float32,
|
| 83 |
+
).to(device)
|
| 84 |
+
spatial_vae.eval()
|
| 85 |
+
|
| 86 |
+
# Try to load temporal VAE weights
|
| 87 |
+
temporal_state = {}
|
| 88 |
+
safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
|
| 89 |
+
if os.path.isfile(safetensors_path):
|
| 90 |
+
with safe_open(safetensors_path, framework="pt") as f:
|
| 91 |
+
for k in f.keys():
|
| 92 |
+
if k.startswith("temporal_vae."):
|
| 93 |
+
temporal_state[k] = f.get_tensor(k)
|
| 94 |
+
if k.startswith("scale"):
|
| 95 |
+
temporal_state["scale"] = f.get_tensor(k)
|
| 96 |
+
if k.startswith("shift"):
|
| 97 |
+
temporal_state["shift"] = f.get_tensor(k)
|
| 98 |
+
|
| 99 |
+
_freeze_sidecar(spatial_vae, quantize, False)
|
| 100 |
+
return OpenSoraVAEWrapper(spatial_vae, temporal_state)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class OpenSoraVAEWrapper(nn.Module):
|
| 104 |
+
def __init__(self, spatial_vae, temporal_state=None):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.spatial = spatial_vae
|
| 107 |
+
self.latent_channels = 4
|
| 108 |
+
self.scale_factor_spatial = 8
|
| 109 |
+
self.scale_factor_temporal = 4
|
| 110 |
+
self.temporal_state = temporal_state
|
| 111 |
+
self.temporal_loaded = temporal_state is not None and len(temporal_state) > 0
|
| 112 |
+
|
| 113 |
+
@torch.no_grad()
|
| 114 |
+
def encode(self, video_tensor):
|
| 115 |
+
"""Encode video tensor: [B,3,T,H,W] β [B,4,T/4,H/8,W/8]."""
|
| 116 |
+
B, C, T, H, W = video_tensor.shape
|
| 117 |
+
# Process frame-by-frame through spatial VAE
|
| 118 |
+
latents = []
|
| 119 |
+
for t in range(T):
|
| 120 |
+
frame = video_tensor[:, :, t, :, :]
|
| 121 |
+
latent = self.spatial.encode(frame).latent_dist.sample()
|
| 122 |
+
latents.append(latent)
|
| 123 |
+
latent = torch.stack(latents, dim=2)
|
| 124 |
+
# Scale
|
| 125 |
+
latent = latent * 0.18215
|
| 126 |
+
# Temporal downsample (simple: take every 4th)
|
| 127 |
+
if latent.shape[2] >= 4:
|
| 128 |
+
latent = latent[:, :, ::4, :, :]
|
| 129 |
+
return latent
|
| 130 |
+
|
| 131 |
+
@torch.no_grad()
|
| 132 |
+
def decode(self, latents, num_frames=None):
|
| 133 |
+
"""Decode latents: [B,4,T/4,H/8,W/8] β [B,3,T,H,W]."""
|
| 134 |
+
B, C, T, H, W = latents.shape
|
| 135 |
+
# Temporal upsample (repeat each latent 4Γ)
|
| 136 |
+
latents = latents.repeat_interleave(4, dim=2)
|
| 137 |
+
# Unscale
|
| 138 |
+
latents = latents / 0.18215
|
| 139 |
+
# Decode frame-by-frame
|
| 140 |
+
frames = []
|
| 141 |
+
for t in range(latents.shape[2]):
|
| 142 |
+
frame = latents[:, :, t, :, :]
|
| 143 |
+
decoded = self.spatial.decode(frame).sample
|
| 144 |
+
frames.append(decoded)
|
| 145 |
+
return torch.stack(frames, dim=2)
|
arbitor/encoders/opensora_vae_modules/autoencoder_2d.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from Flux
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2024 Black Forest Labs
|
| 4 |
+
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
# This source code is licensed under the license found in the
|
| 18 |
+
# LICENSE file in the root directory of this source tree.
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
from torch import Tensor, nn
|
| 25 |
+
from torch.nn.functional import silu as swish
|
| 26 |
+
|
| 27 |
+
from opensora.registry import MODELS
|
| 28 |
+
from opensora.utils.ckpt import load_checkpoint
|
| 29 |
+
|
| 30 |
+
from .utils import DiagonalGaussianDistribution
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class AutoEncoderConfig:
|
| 35 |
+
from_pretrained: str | None
|
| 36 |
+
cache_dir: str | None
|
| 37 |
+
resolution: int
|
| 38 |
+
in_channels: int
|
| 39 |
+
ch: int
|
| 40 |
+
out_ch: int
|
| 41 |
+
ch_mult: list[int]
|
| 42 |
+
num_res_blocks: int
|
| 43 |
+
z_channels: int
|
| 44 |
+
scale_factor: float
|
| 45 |
+
shift_factor: float
|
| 46 |
+
sample: bool = True
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AttnBlock(nn.Module):
|
| 50 |
+
def __init__(self, in_channels: int):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 53 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 54 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 55 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 56 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 57 |
+
|
| 58 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 59 |
+
h_ = self.norm(h_)
|
| 60 |
+
q = self.q(h_)
|
| 61 |
+
k = self.k(h_)
|
| 62 |
+
v = self.v(h_)
|
| 63 |
+
b, c, h, w = q.shape
|
| 64 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 65 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 66 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 67 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 68 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w)
|
| 69 |
+
|
| 70 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 71 |
+
return x + self.proj_out(self.attention(x))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ResnetBlock(nn.Module):
|
| 75 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.in_channels = in_channels
|
| 78 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 79 |
+
self.out_channels = out_channels
|
| 80 |
+
|
| 81 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 82 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 83 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 84 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 85 |
+
if self.in_channels != self.out_channels:
|
| 86 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
h = x
|
| 90 |
+
h = self.norm1(h)
|
| 91 |
+
h = swish(h)
|
| 92 |
+
h = self.conv1(h)
|
| 93 |
+
|
| 94 |
+
h = self.norm2(h)
|
| 95 |
+
h = swish(h)
|
| 96 |
+
h = self.conv2(h)
|
| 97 |
+
|
| 98 |
+
if self.in_channels != self.out_channels:
|
| 99 |
+
x = self.nin_shortcut(x)
|
| 100 |
+
|
| 101 |
+
return x + h
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Downsample(nn.Module):
|
| 105 |
+
def __init__(self, in_channels: int):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 108 |
+
|
| 109 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 110 |
+
pad = (0, 1, 0, 1)
|
| 111 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 112 |
+
return self.conv(x)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Upsample(nn.Module):
|
| 116 |
+
def __init__(self, in_channels: int):
|
| 117 |
+
super().__init__()
|
| 118 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 119 |
+
|
| 120 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 121 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 122 |
+
return self.conv(x)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Encoder(nn.Module):
|
| 126 |
+
def __init__(self, config: AutoEncoderConfig):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.ch = config.ch
|
| 129 |
+
self.num_resolutions = len(config.ch_mult)
|
| 130 |
+
self.num_res_blocks = config.num_res_blocks
|
| 131 |
+
self.resolution = config.resolution
|
| 132 |
+
self.in_channels = config.in_channels
|
| 133 |
+
|
| 134 |
+
# downsampling
|
| 135 |
+
self.conv_in = nn.Conv2d(config.in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 136 |
+
|
| 137 |
+
curr_res = config.resolution
|
| 138 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 139 |
+
self.in_ch_mult = in_ch_mult
|
| 140 |
+
self.down = nn.ModuleList()
|
| 141 |
+
block_in = self.ch
|
| 142 |
+
for i_level in range(self.num_resolutions):
|
| 143 |
+
block = nn.ModuleList()
|
| 144 |
+
attn = nn.ModuleList()
|
| 145 |
+
block_in = config.ch * in_ch_mult[i_level]
|
| 146 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 147 |
+
for _ in range(self.num_res_blocks):
|
| 148 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 149 |
+
block_in = block_out
|
| 150 |
+
down = nn.Module()
|
| 151 |
+
down.block = block
|
| 152 |
+
down.attn = attn
|
| 153 |
+
if i_level != self.num_resolutions - 1:
|
| 154 |
+
down.downsample = Downsample(block_in)
|
| 155 |
+
curr_res = curr_res // 2
|
| 156 |
+
self.down.append(down)
|
| 157 |
+
|
| 158 |
+
# middle
|
| 159 |
+
self.mid = nn.Module()
|
| 160 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 161 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 162 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 163 |
+
|
| 164 |
+
# end
|
| 165 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 166 |
+
self.conv_out = nn.Conv2d(block_in, 2 * config.z_channels, kernel_size=3, stride=1, padding=1)
|
| 167 |
+
|
| 168 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 169 |
+
# downsampling
|
| 170 |
+
hs = [self.conv_in(x)]
|
| 171 |
+
for i_level in range(self.num_resolutions):
|
| 172 |
+
for i_block in range(self.num_res_blocks):
|
| 173 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
| 174 |
+
if len(self.down[i_level].attn) > 0:
|
| 175 |
+
h = self.down[i_level].attn[i_block](h)
|
| 176 |
+
hs.append(h)
|
| 177 |
+
if i_level != self.num_resolutions - 1:
|
| 178 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 179 |
+
|
| 180 |
+
# middle
|
| 181 |
+
h = hs[-1]
|
| 182 |
+
h = self.mid.block_1(h)
|
| 183 |
+
h = self.mid.attn_1(h)
|
| 184 |
+
h = self.mid.block_2(h)
|
| 185 |
+
# end
|
| 186 |
+
h = self.norm_out(h)
|
| 187 |
+
h = swish(h)
|
| 188 |
+
h = self.conv_out(h)
|
| 189 |
+
return h
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class Decoder(nn.Module):
|
| 193 |
+
def __init__(self, config: AutoEncoderConfig):
|
| 194 |
+
super().__init__()
|
| 195 |
+
self.ch = config.ch
|
| 196 |
+
self.num_resolutions = len(config.ch_mult)
|
| 197 |
+
self.num_res_blocks = config.num_res_blocks
|
| 198 |
+
self.resolution = config.resolution
|
| 199 |
+
self.in_channels = config.in_channels
|
| 200 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 201 |
+
|
| 202 |
+
block_in = config.ch * config.ch_mult[self.num_resolutions - 1]
|
| 203 |
+
curr_res = config.resolution // 2 ** (self.num_resolutions - 1)
|
| 204 |
+
self.z_shape = (1, config.z_channels, curr_res, curr_res)
|
| 205 |
+
|
| 206 |
+
# z to block_in
|
| 207 |
+
self.conv_in = nn.Conv2d(config.z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 208 |
+
|
| 209 |
+
# middle
|
| 210 |
+
self.mid = nn.Module()
|
| 211 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 212 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 213 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 214 |
+
|
| 215 |
+
# upsampling
|
| 216 |
+
self.up = nn.ModuleList()
|
| 217 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 218 |
+
block = nn.ModuleList()
|
| 219 |
+
attn = nn.ModuleList()
|
| 220 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 221 |
+
for _ in range(self.num_res_blocks + 1):
|
| 222 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 223 |
+
block_in = block_out
|
| 224 |
+
up = nn.Module()
|
| 225 |
+
up.block = block
|
| 226 |
+
up.attn = attn
|
| 227 |
+
if i_level != 0:
|
| 228 |
+
up.upsample = Upsample(block_in)
|
| 229 |
+
curr_res = curr_res * 2
|
| 230 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 231 |
+
|
| 232 |
+
# end
|
| 233 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 234 |
+
self.conv_out = nn.Conv2d(block_in, config.out_ch, kernel_size=3, stride=1, padding=1)
|
| 235 |
+
|
| 236 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 237 |
+
# z to block_in
|
| 238 |
+
h = self.conv_in(z)
|
| 239 |
+
|
| 240 |
+
# middle
|
| 241 |
+
h = self.mid.block_1(h)
|
| 242 |
+
h = self.mid.attn_1(h)
|
| 243 |
+
h = self.mid.block_2(h)
|
| 244 |
+
|
| 245 |
+
# upsampling
|
| 246 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 247 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 248 |
+
h = self.up[i_level].block[i_block](h)
|
| 249 |
+
if len(self.up[i_level].attn) > 0:
|
| 250 |
+
h = self.up[i_level].attn[i_block](h)
|
| 251 |
+
if i_level != 0:
|
| 252 |
+
h = self.up[i_level].upsample(h)
|
| 253 |
+
|
| 254 |
+
# end
|
| 255 |
+
h = self.norm_out(h)
|
| 256 |
+
h = swish(h)
|
| 257 |
+
return self.conv_out(h)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class AutoEncoder(nn.Module):
|
| 261 |
+
def __init__(self, config: AutoEncoderConfig):
|
| 262 |
+
super().__init__()
|
| 263 |
+
self.encoder = Encoder(config)
|
| 264 |
+
self.decoder = Decoder(config)
|
| 265 |
+
self.scale_factor = config.scale_factor
|
| 266 |
+
self.shift_factor = config.shift_factor
|
| 267 |
+
self.sample = config.sample
|
| 268 |
+
|
| 269 |
+
def encode_(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution]:
|
| 270 |
+
T = x.shape[2]
|
| 271 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 272 |
+
params = self.encoder(x)
|
| 273 |
+
params = rearrange(params, "(b t) c h w -> b c t h w", t=T)
|
| 274 |
+
posterior = DiagonalGaussianDistribution(params)
|
| 275 |
+
if self.sample:
|
| 276 |
+
z = posterior.sample()
|
| 277 |
+
else:
|
| 278 |
+
z = posterior.mode()
|
| 279 |
+
z = self.scale_factor * (z - self.shift_factor)
|
| 280 |
+
return z, posterior
|
| 281 |
+
|
| 282 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 283 |
+
return self.encode_(x)[0]
|
| 284 |
+
|
| 285 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 286 |
+
T = z.shape[2]
|
| 287 |
+
z = rearrange(z, "b c t h w -> (b t) c h w")
|
| 288 |
+
z = z / self.scale_factor + self.shift_factor
|
| 289 |
+
x = self.decoder(z)
|
| 290 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=T)
|
| 291 |
+
return x
|
| 292 |
+
|
| 293 |
+
def forward(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution, Tensor]:
|
| 294 |
+
# encode
|
| 295 |
+
x.shape[2]
|
| 296 |
+
z, posterior = self.encode_(x)
|
| 297 |
+
# decode
|
| 298 |
+
x_rec = self.decode(z)
|
| 299 |
+
|
| 300 |
+
return x_rec, posterior, z
|
| 301 |
+
|
| 302 |
+
def get_last_layer(self):
|
| 303 |
+
return self.decoder.conv_out.weight
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@MODELS.register_module("autoencoder_2d")
|
| 307 |
+
def AutoEncoderFlux(
|
| 308 |
+
from_pretrained: str,
|
| 309 |
+
cache_dir=None,
|
| 310 |
+
resolution=256,
|
| 311 |
+
in_channels=3,
|
| 312 |
+
ch=128,
|
| 313 |
+
out_ch=3,
|
| 314 |
+
ch_mult=[1, 2, 4, 4],
|
| 315 |
+
num_res_blocks=2,
|
| 316 |
+
z_channels=16,
|
| 317 |
+
scale_factor=0.3611,
|
| 318 |
+
shift_factor=0.1159,
|
| 319 |
+
device_map: str | torch.device = "cuda",
|
| 320 |
+
torch_dtype: torch.dtype = torch.bfloat16,
|
| 321 |
+
) -> AutoEncoder:
|
| 322 |
+
config = AutoEncoderConfig(
|
| 323 |
+
from_pretrained=from_pretrained,
|
| 324 |
+
cache_dir=cache_dir,
|
| 325 |
+
resolution=resolution,
|
| 326 |
+
in_channels=in_channels,
|
| 327 |
+
ch=ch,
|
| 328 |
+
out_ch=out_ch,
|
| 329 |
+
ch_mult=ch_mult,
|
| 330 |
+
num_res_blocks=num_res_blocks,
|
| 331 |
+
z_channels=z_channels,
|
| 332 |
+
scale_factor=scale_factor,
|
| 333 |
+
shift_factor=shift_factor,
|
| 334 |
+
)
|
| 335 |
+
with torch.device(device_map):
|
| 336 |
+
model = AutoEncoder(config).to(torch_dtype)
|
| 337 |
+
if from_pretrained:
|
| 338 |
+
model = load_checkpoint(model, from_pretrained, cache_dir=cache_dir, device_map=device_map)
|
| 339 |
+
return model
|
arbitor/encoders/opensora_vae_modules/autoencoder_kl_causal_3d.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from diffusers==0.29.2 and HunyuanVideo
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
# Copyright 2024 HunyuanVideo
|
| 18 |
+
#
|
| 19 |
+
# This source code is licensed under the license found in the
|
| 20 |
+
# LICENSE file in the root directory of this source tree.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Dict, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 29 |
+
|
| 30 |
+
from opensora.registry import MODELS
|
| 31 |
+
from opensora.utils.ckpt import load_checkpoint
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# This diffusers is modified and packed in the mirror.
|
| 35 |
+
from diffusers.loaders import FromOriginalVAEMixin
|
| 36 |
+
except ImportError:
|
| 37 |
+
# Use this to be compatible with the original diffusers.
|
| 38 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
|
| 39 |
+
|
| 40 |
+
from diffusers.models.attention_processor import (
|
| 41 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
| 42 |
+
CROSS_ATTENTION_PROCESSORS,
|
| 43 |
+
Attention,
|
| 44 |
+
AttentionProcessor,
|
| 45 |
+
AttnAddedKVProcessor,
|
| 46 |
+
AttnProcessor,
|
| 47 |
+
)
|
| 48 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 49 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 50 |
+
|
| 51 |
+
from opensora.models.hunyuan_vae.vae import (
|
| 52 |
+
DecoderCausal3D,
|
| 53 |
+
DecoderOutput,
|
| 54 |
+
DiagonalGaussianDistribution,
|
| 55 |
+
EncoderCausal3D,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class AutoEncoder3DConfig:
|
| 61 |
+
from_pretrained: str | None
|
| 62 |
+
act_fn: str = "silu"
|
| 63 |
+
in_channels: int = 3
|
| 64 |
+
out_channels: int = 3
|
| 65 |
+
latent_channels: int = 16
|
| 66 |
+
layers_per_block: int = 2
|
| 67 |
+
norm_num_groups: int = 32
|
| 68 |
+
scale_factor: float = 0.476986
|
| 69 |
+
shift_factor: float = 0
|
| 70 |
+
time_compression_ratio: int = 4
|
| 71 |
+
spatial_compression_ratio: int = 8
|
| 72 |
+
mid_block_add_attention: bool = True
|
| 73 |
+
block_out_channels: tuple[int] = (128, 256, 512, 512)
|
| 74 |
+
sample_size: int = 256
|
| 75 |
+
sample_tsize: int = 64
|
| 76 |
+
use_slicing: bool = False
|
| 77 |
+
use_spatial_tiling: bool = False
|
| 78 |
+
use_temporal_tiling: bool = False
|
| 79 |
+
tile_overlap_factor: float = 0.25
|
| 80 |
+
dropout: float = 0.0
|
| 81 |
+
channel: bool = False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
| 85 |
+
r"""
|
| 86 |
+
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
|
| 87 |
+
|
| 88 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 89 |
+
for all models (such as downloading or saving).
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
_supports_gradient_checkpointing = True
|
| 93 |
+
|
| 94 |
+
@register_to_config
|
| 95 |
+
def __init__(self, config: AutoEncoder3DConfig):
|
| 96 |
+
super().__init__()
|
| 97 |
+
|
| 98 |
+
self.scale_factor = config.scale_factor
|
| 99 |
+
self.shift_factor = config.shift_factor
|
| 100 |
+
|
| 101 |
+
self.time_compression_ratio = config.time_compression_ratio
|
| 102 |
+
self.spatial_compression_ratio = config.spatial_compression_ratio
|
| 103 |
+
self.z_channels = config.latent_channels
|
| 104 |
+
|
| 105 |
+
self.encoder = EncoderCausal3D(
|
| 106 |
+
in_channels=config.in_channels,
|
| 107 |
+
out_channels=config.latent_channels,
|
| 108 |
+
block_out_channels=config.block_out_channels,
|
| 109 |
+
layers_per_block=config.layers_per_block,
|
| 110 |
+
act_fn=config.act_fn,
|
| 111 |
+
norm_num_groups=config.norm_num_groups,
|
| 112 |
+
double_z=True,
|
| 113 |
+
time_compression_ratio=config.time_compression_ratio,
|
| 114 |
+
spatial_compression_ratio=config.spatial_compression_ratio,
|
| 115 |
+
mid_block_add_attention=config.mid_block_add_attention,
|
| 116 |
+
dropout=config.dropout,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
self.decoder = DecoderCausal3D(
|
| 120 |
+
in_channels=config.latent_channels,
|
| 121 |
+
out_channels=config.out_channels,
|
| 122 |
+
block_out_channels=config.block_out_channels,
|
| 123 |
+
layers_per_block=config.layers_per_block,
|
| 124 |
+
norm_num_groups=config.norm_num_groups,
|
| 125 |
+
act_fn=config.act_fn,
|
| 126 |
+
time_compression_ratio=config.time_compression_ratio,
|
| 127 |
+
spatial_compression_ratio=config.spatial_compression_ratio,
|
| 128 |
+
mid_block_add_attention=config.mid_block_add_attention,
|
| 129 |
+
dropout=config.dropout,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.quant_conv = nn.Conv3d(2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1)
|
| 133 |
+
self.post_quant_conv = nn.Conv3d(config.latent_channels, config.latent_channels, kernel_size=1)
|
| 134 |
+
|
| 135 |
+
self.use_slicing = config.use_slicing
|
| 136 |
+
self.use_spatial_tiling = config.use_spatial_tiling
|
| 137 |
+
self.use_temporal_tiling = config.use_temporal_tiling
|
| 138 |
+
|
| 139 |
+
# only relevant if vae tiling is enabled
|
| 140 |
+
self.tile_sample_min_tsize = config.sample_tsize
|
| 141 |
+
self.tile_latent_min_tsize = config.sample_tsize // config.time_compression_ratio
|
| 142 |
+
|
| 143 |
+
self.tile_sample_min_size = config.sample_size
|
| 144 |
+
sample_size = config.sample_size[0] if isinstance(config.sample_size, (list, tuple)) else config.sample_size
|
| 145 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(config.block_out_channels) - 1)))
|
| 146 |
+
self.tile_overlap_factor = config.tile_overlap_factor
|
| 147 |
+
|
| 148 |
+
def enable_temporal_tiling(self, use_tiling: bool = True):
|
| 149 |
+
self.use_temporal_tiling = use_tiling
|
| 150 |
+
|
| 151 |
+
def disable_temporal_tiling(self):
|
| 152 |
+
self.enable_temporal_tiling(False)
|
| 153 |
+
|
| 154 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
| 155 |
+
self.use_spatial_tiling = use_tiling
|
| 156 |
+
|
| 157 |
+
def disable_spatial_tiling(self):
|
| 158 |
+
self.enable_spatial_tiling(False)
|
| 159 |
+
|
| 160 |
+
def enable_tiling(self, use_tiling: bool = True):
|
| 161 |
+
r"""
|
| 162 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 163 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 164 |
+
processing larger videos.
|
| 165 |
+
"""
|
| 166 |
+
self.enable_spatial_tiling(use_tiling)
|
| 167 |
+
self.enable_temporal_tiling(use_tiling)
|
| 168 |
+
|
| 169 |
+
def disable_tiling(self):
|
| 170 |
+
r"""
|
| 171 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
| 172 |
+
decoding in one step.
|
| 173 |
+
"""
|
| 174 |
+
self.disable_spatial_tiling()
|
| 175 |
+
self.disable_temporal_tiling()
|
| 176 |
+
|
| 177 |
+
def enable_slicing(self):
|
| 178 |
+
r"""
|
| 179 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 180 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 181 |
+
"""
|
| 182 |
+
self.use_slicing = True
|
| 183 |
+
|
| 184 |
+
def disable_slicing(self):
|
| 185 |
+
r"""
|
| 186 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
| 187 |
+
decoding in one step.
|
| 188 |
+
"""
|
| 189 |
+
self.use_slicing = False
|
| 190 |
+
|
| 191 |
+
@property
|
| 192 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 193 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 194 |
+
r"""
|
| 195 |
+
Returns:
|
| 196 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 197 |
+
indexed by its weight name.
|
| 198 |
+
"""
|
| 199 |
+
# set recursively
|
| 200 |
+
processors = {}
|
| 201 |
+
|
| 202 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 203 |
+
if hasattr(module, "get_processor"):
|
| 204 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
| 205 |
+
|
| 206 |
+
for sub_name, child in module.named_children():
|
| 207 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 208 |
+
|
| 209 |
+
return processors
|
| 210 |
+
|
| 211 |
+
for name, module in self.named_children():
|
| 212 |
+
fn_recursive_add_processors(name, module, processors)
|
| 213 |
+
|
| 214 |
+
return processors
|
| 215 |
+
|
| 216 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 217 |
+
def set_attn_processor(
|
| 218 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
| 219 |
+
):
|
| 220 |
+
r"""
|
| 221 |
+
Sets the attention processor to use to compute attention.
|
| 222 |
+
|
| 223 |
+
Parameters:
|
| 224 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 225 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 226 |
+
for **all** `Attention` layers.
|
| 227 |
+
|
| 228 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 229 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 230 |
+
|
| 231 |
+
"""
|
| 232 |
+
count = len(self.attn_processors.keys())
|
| 233 |
+
|
| 234 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 235 |
+
raise ValueError(
|
| 236 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 237 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 241 |
+
if hasattr(module, "set_processor"):
|
| 242 |
+
if not isinstance(processor, dict):
|
| 243 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
| 244 |
+
else:
|
| 245 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
| 246 |
+
|
| 247 |
+
for sub_name, child in module.named_children():
|
| 248 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 249 |
+
|
| 250 |
+
for name, module in self.named_children():
|
| 251 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 252 |
+
|
| 253 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
| 254 |
+
def set_default_attn_processor(self):
|
| 255 |
+
"""
|
| 256 |
+
Disables custom attention processors and sets the default attention implementation.
|
| 257 |
+
"""
|
| 258 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 259 |
+
processor = AttnAddedKVProcessor()
|
| 260 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 261 |
+
processor = AttnProcessor()
|
| 262 |
+
else:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
| 268 |
+
|
| 269 |
+
@apply_forward_hook
|
| 270 |
+
def encode(
|
| 271 |
+
self,
|
| 272 |
+
x: torch.FloatTensor,
|
| 273 |
+
sample_posterior: bool = True,
|
| 274 |
+
return_posterior: bool = False,
|
| 275 |
+
generator: Optional[torch.Generator] = None,
|
| 276 |
+
) -> Union[torch.FloatTensor, Tuple[DiagonalGaussianDistribution]]:
|
| 277 |
+
"""
|
| 278 |
+
Encode a batch of images/videos into latents.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
| 282 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 283 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
The latent representations of the encoded images/videos. If `return_dict` is True, a
|
| 287 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
| 288 |
+
"""
|
| 289 |
+
assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
|
| 290 |
+
|
| 291 |
+
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
|
| 292 |
+
posterior = self.temporal_tiled_encode(x)
|
| 293 |
+
elif self.use_spatial_tiling and (
|
| 294 |
+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
|
| 295 |
+
):
|
| 296 |
+
posterior = self.spatial_tiled_encode(x)
|
| 297 |
+
else:
|
| 298 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 299 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
| 300 |
+
h = torch.cat(encoded_slices)
|
| 301 |
+
else:
|
| 302 |
+
h = self.encoder(x)
|
| 303 |
+
moments = self.quant_conv(h)
|
| 304 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 305 |
+
|
| 306 |
+
if sample_posterior:
|
| 307 |
+
z = posterior.sample(generator=generator)
|
| 308 |
+
else:
|
| 309 |
+
z = posterior.mode()
|
| 310 |
+
|
| 311 |
+
z = self.scale_factor * (z - self.shift_factor) # shift & scale
|
| 312 |
+
|
| 313 |
+
if return_posterior:
|
| 314 |
+
return z, posterior
|
| 315 |
+
else:
|
| 316 |
+
return z
|
| 317 |
+
|
| 318 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 319 |
+
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
|
| 320 |
+
|
| 321 |
+
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
|
| 322 |
+
return self.temporal_tiled_decode(z, return_dict=return_dict)
|
| 323 |
+
|
| 324 |
+
if self.use_spatial_tiling and (
|
| 325 |
+
z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
|
| 326 |
+
):
|
| 327 |
+
return self.spatial_tiled_decode(z, return_dict=return_dict)
|
| 328 |
+
|
| 329 |
+
z = self.post_quant_conv(z)
|
| 330 |
+
dec = self.decoder(z)
|
| 331 |
+
|
| 332 |
+
if not return_dict:
|
| 333 |
+
return (dec,)
|
| 334 |
+
|
| 335 |
+
return DecoderOutput(sample=dec)
|
| 336 |
+
|
| 337 |
+
@apply_forward_hook
|
| 338 |
+
def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
|
| 339 |
+
"""
|
| 340 |
+
Decode a batch of images/videos.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 347 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 348 |
+
returned.
|
| 349 |
+
|
| 350 |
+
"""
|
| 351 |
+
z = z / self.scale_factor + self.shift_factor # scale & shift
|
| 352 |
+
|
| 353 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 354 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
| 355 |
+
decoded = torch.cat(decoded_slices)
|
| 356 |
+
else:
|
| 357 |
+
decoded = self._decode(z).sample
|
| 358 |
+
return decoded
|
| 359 |
+
|
| 360 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 361 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
| 362 |
+
for y in range(blend_extent):
|
| 363 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
| 364 |
+
y / blend_extent
|
| 365 |
+
)
|
| 366 |
+
return b
|
| 367 |
+
|
| 368 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 369 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
| 370 |
+
for x in range(blend_extent):
|
| 371 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
| 372 |
+
x / blend_extent
|
| 373 |
+
)
|
| 374 |
+
return b
|
| 375 |
+
|
| 376 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 377 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
| 378 |
+
for x in range(blend_extent):
|
| 379 |
+
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
| 380 |
+
x / blend_extent
|
| 381 |
+
)
|
| 382 |
+
return b
|
| 383 |
+
|
| 384 |
+
def spatial_tiled_encode(self, x: torch.FloatTensor, return_moments: bool = False) -> DiagonalGaussianDistribution:
|
| 385 |
+
r"""Encode a batch of images/videos using a tiled encoder.
|
| 386 |
+
|
| 387 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
| 388 |
+
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
|
| 389 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
| 390 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
| 391 |
+
output, but they should be much less noticeable.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
| 395 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 396 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 397 |
+
|
| 398 |
+
Returns:
|
| 399 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
| 400 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
| 401 |
+
`tuple` is returned.
|
| 402 |
+
"""
|
| 403 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
| 404 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
| 405 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
| 406 |
+
|
| 407 |
+
# Split video into tiles and encode them separately.
|
| 408 |
+
rows = []
|
| 409 |
+
for i in range(0, x.shape[-2], overlap_size):
|
| 410 |
+
row = []
|
| 411 |
+
for j in range(0, x.shape[-1], overlap_size):
|
| 412 |
+
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
| 413 |
+
tile = self.encoder(tile)
|
| 414 |
+
tile = self.quant_conv(tile)
|
| 415 |
+
row.append(tile)
|
| 416 |
+
rows.append(row)
|
| 417 |
+
result_rows = []
|
| 418 |
+
for i, row in enumerate(rows):
|
| 419 |
+
result_row = []
|
| 420 |
+
for j, tile in enumerate(row):
|
| 421 |
+
# blend the above tile and the left tile
|
| 422 |
+
# to the current tile and add the current tile to the result row
|
| 423 |
+
if i > 0:
|
| 424 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
| 425 |
+
if j > 0:
|
| 426 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
| 427 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
| 428 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 429 |
+
|
| 430 |
+
moments = torch.cat(result_rows, dim=-2)
|
| 431 |
+
if return_moments:
|
| 432 |
+
return moments
|
| 433 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 434 |
+
return posterior
|
| 435 |
+
|
| 436 |
+
def spatial_tiled_decode(
|
| 437 |
+
self, z: torch.FloatTensor, return_dict: bool = True
|
| 438 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 439 |
+
r"""
|
| 440 |
+
Decode a batch of images/videos using a tiled decoder.
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
| 444 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 445 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 449 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 450 |
+
returned.
|
| 451 |
+
"""
|
| 452 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
| 453 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
| 454 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
| 455 |
+
|
| 456 |
+
# Split z into overlapping tiles and decode them separately.
|
| 457 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 458 |
+
rows = []
|
| 459 |
+
for i in range(0, z.shape[-2], overlap_size):
|
| 460 |
+
row = []
|
| 461 |
+
for j in range(0, z.shape[-1], overlap_size):
|
| 462 |
+
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
| 463 |
+
tile = self.post_quant_conv(tile)
|
| 464 |
+
decoded = self.decoder(tile)
|
| 465 |
+
row.append(decoded)
|
| 466 |
+
rows.append(row)
|
| 467 |
+
result_rows = []
|
| 468 |
+
for i, row in enumerate(rows):
|
| 469 |
+
result_row = []
|
| 470 |
+
for j, tile in enumerate(row):
|
| 471 |
+
# blend the above tile and the left tile
|
| 472 |
+
# to the current tile and add the current tile to the result row
|
| 473 |
+
if i > 0:
|
| 474 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
| 475 |
+
if j > 0:
|
| 476 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
| 477 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
| 478 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 479 |
+
|
| 480 |
+
dec = torch.cat(result_rows, dim=-2)
|
| 481 |
+
if not return_dict:
|
| 482 |
+
return (dec,)
|
| 483 |
+
|
| 484 |
+
return DecoderOutput(sample=dec)
|
| 485 |
+
|
| 486 |
+
def temporal_tiled_encode(self, x: torch.FloatTensor) -> DiagonalGaussianDistribution:
|
| 487 |
+
B, C, T, H, W = x.shape
|
| 488 |
+
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
|
| 489 |
+
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
|
| 490 |
+
t_limit = self.tile_latent_min_tsize - blend_extent
|
| 491 |
+
|
| 492 |
+
# Split the video into tiles and encode them separately.
|
| 493 |
+
row = []
|
| 494 |
+
for i in range(0, T, overlap_size):
|
| 495 |
+
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
|
| 496 |
+
if self.use_spatial_tiling and (
|
| 497 |
+
tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
|
| 498 |
+
):
|
| 499 |
+
tile = self.spatial_tiled_encode(tile, return_moments=True)
|
| 500 |
+
else:
|
| 501 |
+
tile = self.encoder(tile)
|
| 502 |
+
tile = self.quant_conv(tile)
|
| 503 |
+
if i > 0:
|
| 504 |
+
tile = tile[:, :, 1:, :, :]
|
| 505 |
+
row.append(tile)
|
| 506 |
+
result_row = []
|
| 507 |
+
for i, tile in enumerate(row):
|
| 508 |
+
if i > 0:
|
| 509 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
| 510 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
| 511 |
+
else:
|
| 512 |
+
result_row.append(tile[:, :, : t_limit + 1, :, :])
|
| 513 |
+
moments = torch.cat(result_row, dim=2)
|
| 514 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 515 |
+
return posterior
|
| 516 |
+
|
| 517 |
+
def temporal_tiled_decode(
|
| 518 |
+
self, z: torch.FloatTensor, return_dict: bool = True
|
| 519 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 520 |
+
# Split z into overlapping tiles and decode them separately.
|
| 521 |
+
|
| 522 |
+
B, C, T, H, W = z.shape
|
| 523 |
+
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
|
| 524 |
+
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
|
| 525 |
+
t_limit = self.tile_sample_min_tsize - blend_extent
|
| 526 |
+
|
| 527 |
+
row = []
|
| 528 |
+
for i in range(0, T, overlap_size):
|
| 529 |
+
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
|
| 530 |
+
if self.use_spatial_tiling and (
|
| 531 |
+
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
|
| 532 |
+
):
|
| 533 |
+
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
|
| 534 |
+
else:
|
| 535 |
+
tile = self.post_quant_conv(tile)
|
| 536 |
+
decoded = self.decoder(tile)
|
| 537 |
+
if i > 0:
|
| 538 |
+
decoded = decoded[:, :, 1:, :, :]
|
| 539 |
+
row.append(decoded)
|
| 540 |
+
result_row = []
|
| 541 |
+
for i, tile in enumerate(row):
|
| 542 |
+
if i > 0:
|
| 543 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
| 544 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
| 545 |
+
else:
|
| 546 |
+
result_row.append(tile[:, :, : t_limit + 1, :, :])
|
| 547 |
+
|
| 548 |
+
dec = torch.cat(result_row, dim=2)
|
| 549 |
+
if not return_dict:
|
| 550 |
+
return (dec,)
|
| 551 |
+
|
| 552 |
+
return DecoderOutput(sample=dec)
|
| 553 |
+
|
| 554 |
+
def forward(
|
| 555 |
+
self,
|
| 556 |
+
sample: torch.FloatTensor,
|
| 557 |
+
sample_posterior: bool = True,
|
| 558 |
+
generator: Optional[torch.Generator] = None,
|
| 559 |
+
) -> Tuple[torch.FloatTensor, DiagonalGaussianDistribution, torch.FloatTensor]:
|
| 560 |
+
r"""
|
| 561 |
+
Args:
|
| 562 |
+
sample (`torch.FloatTensor`): Input sample.
|
| 563 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
| 564 |
+
Whether to sample from the posterior.
|
| 565 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 566 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
| 567 |
+
"""
|
| 568 |
+
x = sample
|
| 569 |
+
z, posterior = self.encode(x, return_posterior=True, sample_posterior=sample_posterior, generator=generator)
|
| 570 |
+
dec = self.decode(z)
|
| 571 |
+
|
| 572 |
+
return (dec, posterior, z)
|
| 573 |
+
|
| 574 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
| 575 |
+
def fuse_qkv_projections(self):
|
| 576 |
+
"""
|
| 577 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
| 578 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 579 |
+
|
| 580 |
+
<Tip warning={true}>
|
| 581 |
+
|
| 582 |
+
This API is π§ͺ experimental.
|
| 583 |
+
|
| 584 |
+
</Tip>
|
| 585 |
+
"""
|
| 586 |
+
self.original_attn_processors = None
|
| 587 |
+
|
| 588 |
+
for _, attn_processor in self.attn_processors.items():
|
| 589 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 590 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 591 |
+
|
| 592 |
+
self.original_attn_processors = self.attn_processors
|
| 593 |
+
|
| 594 |
+
for module in self.modules():
|
| 595 |
+
if isinstance(module, Attention):
|
| 596 |
+
module.fuse_projections(fuse=True)
|
| 597 |
+
|
| 598 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 599 |
+
def unfuse_qkv_projections(self):
|
| 600 |
+
"""Disables the fused QKV projection if enabled.
|
| 601 |
+
|
| 602 |
+
<Tip warning={true}>
|
| 603 |
+
|
| 604 |
+
This API is π§ͺ experimental.
|
| 605 |
+
|
| 606 |
+
</Tip>
|
| 607 |
+
|
| 608 |
+
"""
|
| 609 |
+
if self.original_attn_processors is not None:
|
| 610 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 611 |
+
|
| 612 |
+
def get_last_layer(self):
|
| 613 |
+
return self.decoder.conv_out.conv.weight
|
| 614 |
+
|
| 615 |
+
def get_latent_size(self, input_size: list[int]) -> list[int]:
|
| 616 |
+
latent_size = []
|
| 617 |
+
# T
|
| 618 |
+
latent_size.append((input_size[0] - 1) // self.time_compression_ratio + 1)
|
| 619 |
+
# H, w
|
| 620 |
+
for i in range(1, 3):
|
| 621 |
+
latent_size.append((input_size[i] - 1) // self.spatial_compression_ratio + 1)
|
| 622 |
+
return latent_size
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
@MODELS.register_module("hunyuan_vae")
|
| 626 |
+
def CausalVAE3D_HUNYUAN(
|
| 627 |
+
from_pretrained: str = None,
|
| 628 |
+
device_map: str | torch.device = "cuda",
|
| 629 |
+
torch_dtype: torch.dtype = torch.bfloat16,
|
| 630 |
+
**kwargs,
|
| 631 |
+
) -> AutoencoderKLCausal3D:
|
| 632 |
+
config = AutoEncoder3DConfig(from_pretrained=from_pretrained, **kwargs)
|
| 633 |
+
with torch.device(device_map):
|
| 634 |
+
model = AutoencoderKLCausal3D(config).to(torch_dtype)
|
| 635 |
+
if from_pretrained:
|
| 636 |
+
model = load_checkpoint(model, from_pretrained, device_map=device_map, strict=True)
|
| 637 |
+
|
| 638 |
+
return model
|
arbitor/encoders/opensora_vae_modules/registry.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmengine.registry import Registry
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def build_module(module: dict | nn.Module, builder: Registry, **kwargs) -> nn.Module | None:
|
| 8 |
+
"""Build module from config or return the module itself.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
module (dict | nn.Module): The module to build.
|
| 12 |
+
builder (Registry): The registry to build module.
|
| 13 |
+
*args, **kwargs: Arguments passed to build function.
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
(None | nn.Module): The created model.
|
| 17 |
+
"""
|
| 18 |
+
if module is None:
|
| 19 |
+
return None
|
| 20 |
+
if isinstance(module, dict):
|
| 21 |
+
cfg = deepcopy(module)
|
| 22 |
+
for k, v in kwargs.items():
|
| 23 |
+
cfg[k] = v
|
| 24 |
+
return builder.build(cfg)
|
| 25 |
+
elif isinstance(module, nn.Module):
|
| 26 |
+
return module
|
| 27 |
+
elif module is None:
|
| 28 |
+
return None
|
| 29 |
+
else:
|
| 30 |
+
raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
MODELS = Registry(
|
| 34 |
+
"model",
|
| 35 |
+
locations=["opensora.models"],
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
DATASETS = Registry(
|
| 39 |
+
"dataset",
|
| 40 |
+
locations=["opensora.datasets"],
|
| 41 |
+
)
|
arbitor/encoders/opensora_vae_modules/unet_causal_3d_blocks.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from diffusers==0.29.2 and HunyuanVideo
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# #
|
| 17 |
+
# Copyright 2024 HunyuanVideo
|
| 18 |
+
#
|
| 19 |
+
# This source code is licensed under the license found in the
|
| 20 |
+
# LICENSE file in the root directory of this source tree.
|
| 21 |
+
|
| 22 |
+
from typing import Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from diffusers.models.activations import get_activation
|
| 28 |
+
from diffusers.models.attention_processor import Attention
|
| 29 |
+
from diffusers.utils import logging
|
| 30 |
+
from einops import rearrange
|
| 31 |
+
from torch import nn
|
| 32 |
+
|
| 33 |
+
from opensora.acceleration.checkpoint import auto_grad_checkpoint
|
| 34 |
+
from opensora.models.vae.utils import ChannelChunkConv3d, get_conv3d_n_chunks
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 37 |
+
|
| 38 |
+
INTERPOLATE_NUMEL_LIMIT = 2**31 - 1
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def chunk_nearest_interpolate(
|
| 42 |
+
x: torch.Tensor,
|
| 43 |
+
scale_factor,
|
| 44 |
+
):
|
| 45 |
+
limit = INTERPOLATE_NUMEL_LIMIT // np.prod(scale_factor)
|
| 46 |
+
n_chunks = get_conv3d_n_chunks(x.numel(), x.size(1), limit)
|
| 47 |
+
x_chunks = x.chunk(n_chunks, dim=1)
|
| 48 |
+
x_chunks = [F.interpolate(x_chunk, scale_factor=scale_factor, mode="nearest") for x_chunk in x_chunks]
|
| 49 |
+
return torch.cat(x_chunks, dim=1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
|
| 53 |
+
seq_len = n_frame * n_hw
|
| 54 |
+
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
| 55 |
+
for i in range(seq_len):
|
| 56 |
+
i_frame = i // n_hw
|
| 57 |
+
mask[i, : (i_frame + 1) * n_hw] = 0
|
| 58 |
+
if batch_size is not None:
|
| 59 |
+
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
| 60 |
+
return mask
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CausalConv3d(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
|
| 66 |
+
This maintains temporal causality in video generation tasks.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
chan_in,
|
| 72 |
+
chan_out,
|
| 73 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 74 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 75 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
| 76 |
+
pad_mode="replicate",
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
super().__init__()
|
| 80 |
+
|
| 81 |
+
self.pad_mode = pad_mode
|
| 82 |
+
padding = (
|
| 83 |
+
kernel_size // 2,
|
| 84 |
+
kernel_size // 2,
|
| 85 |
+
kernel_size // 2,
|
| 86 |
+
kernel_size // 2,
|
| 87 |
+
kernel_size - 1,
|
| 88 |
+
0,
|
| 89 |
+
) # W, H, T
|
| 90 |
+
self.time_causal_padding = padding
|
| 91 |
+
|
| 92 |
+
self.conv = ChannelChunkConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
| 96 |
+
return self.conv(x)
|
| 97 |
+
|
| 98 |
+
class UpsampleCausal3D(nn.Module):
|
| 99 |
+
"""
|
| 100 |
+
A 3D upsampling layer with an optional convolution.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
channels: int,
|
| 106 |
+
out_channels: Optional[int] = None,
|
| 107 |
+
kernel_size: int = 3,
|
| 108 |
+
bias=True,
|
| 109 |
+
upsample_factor=(2, 2, 2),
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.channels = channels
|
| 113 |
+
self.out_channels = out_channels or channels
|
| 114 |
+
self.upsample_factor = upsample_factor
|
| 115 |
+
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
| 116 |
+
|
| 117 |
+
def forward(
|
| 118 |
+
self,
|
| 119 |
+
input_tensor: torch.FloatTensor,
|
| 120 |
+
) -> torch.FloatTensor:
|
| 121 |
+
assert input_tensor.shape[1] == self.channels
|
| 122 |
+
|
| 123 |
+
#######################
|
| 124 |
+
# handle hidden states
|
| 125 |
+
#######################
|
| 126 |
+
hidden_states = input_tensor
|
| 127 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
| 128 |
+
# dtype = hidden_states.dtype
|
| 129 |
+
# if dtype == torch.bfloat16:
|
| 130 |
+
# hidden_states = hidden_states.to(torch.float32)
|
| 131 |
+
|
| 132 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
| 133 |
+
if hidden_states.shape[0] >= 64:
|
| 134 |
+
hidden_states = hidden_states.contiguous()
|
| 135 |
+
|
| 136 |
+
# interpolate H & W only for the first frame; interpolate T & H & W for the rest
|
| 137 |
+
T = hidden_states.size(2)
|
| 138 |
+
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
| 139 |
+
# process non-1st frames
|
| 140 |
+
if T > 1:
|
| 141 |
+
other_h = chunk_nearest_interpolate(other_h, scale_factor=self.upsample_factor)
|
| 142 |
+
# proess 1st fram
|
| 143 |
+
first_h = first_h.squeeze(2)
|
| 144 |
+
first_h = chunk_nearest_interpolate(first_h, scale_factor=self.upsample_factor[1:])
|
| 145 |
+
first_h = first_h.unsqueeze(2)
|
| 146 |
+
# concat together
|
| 147 |
+
if T > 1:
|
| 148 |
+
hidden_states = torch.cat((first_h, other_h), dim=2)
|
| 149 |
+
else:
|
| 150 |
+
hidden_states = first_h
|
| 151 |
+
|
| 152 |
+
# If the input is bfloat16, we cast back to bfloat16
|
| 153 |
+
# if dtype == torch.bfloat16:
|
| 154 |
+
# hidden_states = hidden_states.to(dtype)
|
| 155 |
+
|
| 156 |
+
hidden_states = self.conv(hidden_states)
|
| 157 |
+
|
| 158 |
+
return hidden_states
|
| 159 |
+
|
| 160 |
+
class DownsampleCausal3D(nn.Module):
|
| 161 |
+
"""
|
| 162 |
+
A 3D downsampling layer with an optional convolution.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
channels: int,
|
| 168 |
+
kernel_size=3,
|
| 169 |
+
bias=True,
|
| 170 |
+
stride=2,
|
| 171 |
+
):
|
| 172 |
+
super().__init__()
|
| 173 |
+
self.channels = channels
|
| 174 |
+
self.out_channels = channels
|
| 175 |
+
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
|
| 176 |
+
|
| 177 |
+
def forward(self, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
| 178 |
+
assert input_tensor.shape[1] == self.channels
|
| 179 |
+
hidden_states = self.conv(input_tensor)
|
| 180 |
+
|
| 181 |
+
return hidden_states
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class ResnetBlockCausal3D(nn.Module):
|
| 185 |
+
r"""
|
| 186 |
+
A Resnet block.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def __init__(
|
| 190 |
+
self,
|
| 191 |
+
*,
|
| 192 |
+
in_channels: int,
|
| 193 |
+
out_channels: Optional[int] = None,
|
| 194 |
+
dropout: float = 0.0,
|
| 195 |
+
groups: int = 32,
|
| 196 |
+
groups_out: Optional[int] = None,
|
| 197 |
+
pre_norm: bool = True,
|
| 198 |
+
eps: float = 1e-6,
|
| 199 |
+
non_linearity: str = "swish",
|
| 200 |
+
output_scale_factor: float = 1.0,
|
| 201 |
+
use_in_shortcut: Optional[bool] = None,
|
| 202 |
+
conv_shortcut_bias: bool = True,
|
| 203 |
+
conv_3d_out_channels: Optional[int] = None,
|
| 204 |
+
):
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.pre_norm = pre_norm
|
| 207 |
+
self.pre_norm = True
|
| 208 |
+
self.in_channels = in_channels
|
| 209 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 210 |
+
self.out_channels = out_channels
|
| 211 |
+
self.output_scale_factor = output_scale_factor
|
| 212 |
+
|
| 213 |
+
if groups_out is None:
|
| 214 |
+
groups_out = groups
|
| 215 |
+
|
| 216 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
| 217 |
+
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
| 218 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
| 219 |
+
|
| 220 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 221 |
+
conv_3d_out_channels = conv_3d_out_channels or out_channels
|
| 222 |
+
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
|
| 223 |
+
|
| 224 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 225 |
+
|
| 226 |
+
self.upsample = self.downsample = None
|
| 227 |
+
|
| 228 |
+
self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
|
| 229 |
+
|
| 230 |
+
self.conv_shortcut = None
|
| 231 |
+
if self.use_in_shortcut:
|
| 232 |
+
self.conv_shortcut = CausalConv3d(
|
| 233 |
+
in_channels,
|
| 234 |
+
conv_3d_out_channels,
|
| 235 |
+
kernel_size=1,
|
| 236 |
+
stride=1,
|
| 237 |
+
bias=conv_shortcut_bias,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def forward(
|
| 241 |
+
self,
|
| 242 |
+
input_tensor: torch.FloatTensor,
|
| 243 |
+
) -> torch.FloatTensor:
|
| 244 |
+
hidden_states = input_tensor
|
| 245 |
+
|
| 246 |
+
hidden_states = self.norm1(hidden_states)
|
| 247 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 248 |
+
hidden_states = self.conv1(hidden_states)
|
| 249 |
+
hidden_states = self.norm2(hidden_states)
|
| 250 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 251 |
+
hidden_states = self.dropout(hidden_states)
|
| 252 |
+
hidden_states = self.conv2(hidden_states)
|
| 253 |
+
|
| 254 |
+
if self.conv_shortcut is not None:
|
| 255 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
| 256 |
+
|
| 257 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
| 258 |
+
|
| 259 |
+
return output_tensor
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class UNetMidBlockCausal3D(nn.Module):
|
| 263 |
+
"""
|
| 264 |
+
A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
in_channels: int,
|
| 270 |
+
dropout: float = 0.0,
|
| 271 |
+
num_layers: int = 1,
|
| 272 |
+
resnet_eps: float = 1e-6,
|
| 273 |
+
resnet_act_fn: str = "swish",
|
| 274 |
+
resnet_groups: int = 32,
|
| 275 |
+
attn_groups: Optional[int] = None,
|
| 276 |
+
resnet_pre_norm: bool = True,
|
| 277 |
+
add_attention: bool = True,
|
| 278 |
+
attention_head_dim: int = 1,
|
| 279 |
+
output_scale_factor: float = 1.0,
|
| 280 |
+
):
|
| 281 |
+
super().__init__()
|
| 282 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 283 |
+
self.add_attention = add_attention
|
| 284 |
+
|
| 285 |
+
if attn_groups is None:
|
| 286 |
+
attn_groups = resnet_groups
|
| 287 |
+
|
| 288 |
+
# there is always at least one resnet
|
| 289 |
+
resnets = [
|
| 290 |
+
ResnetBlockCausal3D(
|
| 291 |
+
in_channels=in_channels,
|
| 292 |
+
out_channels=in_channels,
|
| 293 |
+
eps=resnet_eps,
|
| 294 |
+
groups=resnet_groups,
|
| 295 |
+
dropout=dropout,
|
| 296 |
+
non_linearity=resnet_act_fn,
|
| 297 |
+
output_scale_factor=output_scale_factor,
|
| 298 |
+
pre_norm=resnet_pre_norm,
|
| 299 |
+
)
|
| 300 |
+
]
|
| 301 |
+
attentions = []
|
| 302 |
+
|
| 303 |
+
if attention_head_dim is None:
|
| 304 |
+
logger.warn(
|
| 305 |
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
| 306 |
+
)
|
| 307 |
+
attention_head_dim = in_channels
|
| 308 |
+
|
| 309 |
+
for _ in range(num_layers):
|
| 310 |
+
if self.add_attention:
|
| 311 |
+
attentions.append(
|
| 312 |
+
Attention(
|
| 313 |
+
in_channels,
|
| 314 |
+
heads=in_channels // attention_head_dim,
|
| 315 |
+
dim_head=attention_head_dim,
|
| 316 |
+
rescale_output_factor=output_scale_factor,
|
| 317 |
+
eps=resnet_eps,
|
| 318 |
+
norm_num_groups=attn_groups,
|
| 319 |
+
spatial_norm_dim=None,
|
| 320 |
+
residual_connection=True,
|
| 321 |
+
bias=True,
|
| 322 |
+
upcast_softmax=True,
|
| 323 |
+
_from_deprecated_attn_block=True,
|
| 324 |
+
)
|
| 325 |
+
)
|
| 326 |
+
else:
|
| 327 |
+
attentions.append(None)
|
| 328 |
+
|
| 329 |
+
resnets.append(
|
| 330 |
+
ResnetBlockCausal3D(
|
| 331 |
+
in_channels=in_channels,
|
| 332 |
+
out_channels=in_channels,
|
| 333 |
+
eps=resnet_eps,
|
| 334 |
+
groups=resnet_groups,
|
| 335 |
+
dropout=dropout,
|
| 336 |
+
non_linearity=resnet_act_fn,
|
| 337 |
+
output_scale_factor=output_scale_factor,
|
| 338 |
+
pre_norm=resnet_pre_norm,
|
| 339 |
+
)
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
self.attentions = nn.ModuleList(attentions)
|
| 343 |
+
self.resnets = nn.ModuleList(resnets)
|
| 344 |
+
|
| 345 |
+
def forward(self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.Tensor]) -> torch.FloatTensor:
|
| 346 |
+
hidden_states = self.resnets[0](hidden_states)
|
| 347 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
| 348 |
+
if attn is not None:
|
| 349 |
+
B, C, T, H, W = hidden_states.shape
|
| 350 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
| 351 |
+
hidden_states = attn(hidden_states, attention_mask=attention_mask)
|
| 352 |
+
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
| 353 |
+
hidden_states = resnet(hidden_states)
|
| 354 |
+
|
| 355 |
+
return hidden_states
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class DownEncoderBlockCausal3D(nn.Module):
|
| 359 |
+
def __init__(
|
| 360 |
+
self,
|
| 361 |
+
in_channels: int,
|
| 362 |
+
out_channels: int,
|
| 363 |
+
dropout: float = 0.0,
|
| 364 |
+
num_layers: int = 1,
|
| 365 |
+
resnet_eps: float = 1e-6,
|
| 366 |
+
resnet_act_fn: str = "swish",
|
| 367 |
+
resnet_groups: int = 32,
|
| 368 |
+
resnet_pre_norm: bool = True,
|
| 369 |
+
output_scale_factor: float = 1.0,
|
| 370 |
+
add_downsample: bool = True,
|
| 371 |
+
downsample_stride: int = 2,
|
| 372 |
+
):
|
| 373 |
+
super().__init__()
|
| 374 |
+
resnets = []
|
| 375 |
+
|
| 376 |
+
for i in range(num_layers):
|
| 377 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 378 |
+
resnets.append(
|
| 379 |
+
ResnetBlockCausal3D(
|
| 380 |
+
in_channels=in_channels,
|
| 381 |
+
out_channels=out_channels,
|
| 382 |
+
eps=resnet_eps,
|
| 383 |
+
groups=resnet_groups,
|
| 384 |
+
dropout=dropout,
|
| 385 |
+
non_linearity=resnet_act_fn,
|
| 386 |
+
output_scale_factor=output_scale_factor,
|
| 387 |
+
pre_norm=resnet_pre_norm,
|
| 388 |
+
)
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
self.resnets = nn.ModuleList(resnets)
|
| 392 |
+
|
| 393 |
+
if add_downsample:
|
| 394 |
+
self.downsamplers = nn.ModuleList(
|
| 395 |
+
[
|
| 396 |
+
DownsampleCausal3D(
|
| 397 |
+
out_channels,
|
| 398 |
+
stride=downsample_stride,
|
| 399 |
+
)
|
| 400 |
+
]
|
| 401 |
+
)
|
| 402 |
+
else:
|
| 403 |
+
self.downsamplers = None
|
| 404 |
+
|
| 405 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 406 |
+
for resnet in self.resnets:
|
| 407 |
+
hidden_states = auto_grad_checkpoint(resnet, hidden_states)
|
| 408 |
+
|
| 409 |
+
if self.downsamplers is not None:
|
| 410 |
+
for downsampler in self.downsamplers:
|
| 411 |
+
hidden_states = auto_grad_checkpoint(downsampler, hidden_states)
|
| 412 |
+
|
| 413 |
+
return hidden_states
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class UpDecoderBlockCausal3D(nn.Module):
|
| 417 |
+
def __init__(
|
| 418 |
+
self,
|
| 419 |
+
in_channels: int,
|
| 420 |
+
out_channels: int,
|
| 421 |
+
resolution_idx: Optional[int] = None,
|
| 422 |
+
dropout: float = 0.0,
|
| 423 |
+
num_layers: int = 1,
|
| 424 |
+
resnet_eps: float = 1e-6,
|
| 425 |
+
resnet_act_fn: str = "swish",
|
| 426 |
+
resnet_groups: int = 32,
|
| 427 |
+
resnet_pre_norm: bool = True,
|
| 428 |
+
output_scale_factor: float = 1.0,
|
| 429 |
+
add_upsample: bool = True,
|
| 430 |
+
upsample_scale_factor=(2, 2, 2),
|
| 431 |
+
):
|
| 432 |
+
super().__init__()
|
| 433 |
+
resnets = []
|
| 434 |
+
|
| 435 |
+
for i in range(num_layers):
|
| 436 |
+
input_channels = in_channels if i == 0 else out_channels
|
| 437 |
+
|
| 438 |
+
resnets.append(
|
| 439 |
+
ResnetBlockCausal3D(
|
| 440 |
+
in_channels=input_channels,
|
| 441 |
+
out_channels=out_channels,
|
| 442 |
+
eps=resnet_eps,
|
| 443 |
+
groups=resnet_groups,
|
| 444 |
+
dropout=dropout,
|
| 445 |
+
non_linearity=resnet_act_fn,
|
| 446 |
+
output_scale_factor=output_scale_factor,
|
| 447 |
+
pre_norm=resnet_pre_norm,
|
| 448 |
+
)
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
self.resnets = nn.ModuleList(resnets)
|
| 452 |
+
|
| 453 |
+
if add_upsample:
|
| 454 |
+
self.upsamplers = nn.ModuleList(
|
| 455 |
+
[
|
| 456 |
+
UpsampleCausal3D(
|
| 457 |
+
out_channels,
|
| 458 |
+
out_channels=out_channels,
|
| 459 |
+
upsample_factor=upsample_scale_factor,
|
| 460 |
+
)
|
| 461 |
+
]
|
| 462 |
+
)
|
| 463 |
+
else:
|
| 464 |
+
self.upsamplers = None
|
| 465 |
+
|
| 466 |
+
self.resolution_idx = resolution_idx
|
| 467 |
+
|
| 468 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 469 |
+
for resnet in self.resnets:
|
| 470 |
+
hidden_states = auto_grad_checkpoint(resnet, hidden_states)
|
| 471 |
+
|
| 472 |
+
if self.upsamplers is not None:
|
| 473 |
+
for upsampler in self.upsamplers:
|
| 474 |
+
hidden_states = auto_grad_checkpoint(upsampler, hidden_states)
|
| 475 |
+
|
| 476 |
+
return hidden_states
|
arbitor/encoders/opensora_vae_modules/vae.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from HunyuanVideo
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2024 HunyuanVideo
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from diffusers.utils import BaseOutput
|
| 15 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
+
|
| 17 |
+
from opensora.acceleration.checkpoint import auto_grad_checkpoint, checkpoint
|
| 18 |
+
from opensora.models.hunyuan_vae.unet_causal_3d_blocks import (
|
| 19 |
+
CausalConv3d,
|
| 20 |
+
DownEncoderBlockCausal3D,
|
| 21 |
+
UNetMidBlockCausal3D,
|
| 22 |
+
UpDecoderBlockCausal3D,
|
| 23 |
+
prepare_causal_attention_mask,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class DecoderOutput(BaseOutput):
|
| 29 |
+
r"""
|
| 30 |
+
Output of decoding method.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 34 |
+
The decoded output sample from the last layer of the model.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
sample: torch.FloatTensor
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EncoderCausal3D(nn.Module):
|
| 41 |
+
r"""
|
| 42 |
+
The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
in_channels: int = 3,
|
| 48 |
+
out_channels: int = 3,
|
| 49 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 50 |
+
layers_per_block: int = 2,
|
| 51 |
+
norm_num_groups: int = 32,
|
| 52 |
+
act_fn: str = "silu",
|
| 53 |
+
double_z: bool = True,
|
| 54 |
+
mid_block_add_attention=True,
|
| 55 |
+
time_compression_ratio: int = 4,
|
| 56 |
+
spatial_compression_ratio: int = 8,
|
| 57 |
+
dropout: float = 0.0,
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.layers_per_block = layers_per_block
|
| 61 |
+
|
| 62 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
| 63 |
+
self.mid_block = None
|
| 64 |
+
self.down_blocks = nn.ModuleList([])
|
| 65 |
+
|
| 66 |
+
# down
|
| 67 |
+
output_channel = block_out_channels[0]
|
| 68 |
+
for i, _ in enumerate(block_out_channels):
|
| 69 |
+
input_channel = output_channel
|
| 70 |
+
output_channel = block_out_channels[i]
|
| 71 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 72 |
+
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
| 73 |
+
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
| 74 |
+
|
| 75 |
+
if time_compression_ratio == 4:
|
| 76 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
| 77 |
+
add_time_downsample = bool(
|
| 78 |
+
i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
|
| 79 |
+
)
|
| 80 |
+
elif time_compression_ratio == 8:
|
| 81 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
| 82 |
+
add_time_downsample = bool(i < num_spatial_downsample_layers)
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
|
| 85 |
+
|
| 86 |
+
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
| 87 |
+
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
| 88 |
+
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
| 89 |
+
down_block = DownEncoderBlockCausal3D(
|
| 90 |
+
num_layers=self.layers_per_block,
|
| 91 |
+
in_channels=input_channel,
|
| 92 |
+
out_channels=output_channel,
|
| 93 |
+
dropout=dropout,
|
| 94 |
+
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
| 95 |
+
downsample_stride=downsample_stride,
|
| 96 |
+
resnet_eps=1e-6,
|
| 97 |
+
resnet_act_fn=act_fn,
|
| 98 |
+
resnet_groups=norm_num_groups,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.down_blocks.append(down_block)
|
| 102 |
+
|
| 103 |
+
# mid
|
| 104 |
+
self.mid_block = UNetMidBlockCausal3D(
|
| 105 |
+
in_channels=block_out_channels[-1],
|
| 106 |
+
resnet_eps=1e-6,
|
| 107 |
+
resnet_act_fn=act_fn,
|
| 108 |
+
output_scale_factor=1,
|
| 109 |
+
attention_head_dim=block_out_channels[-1],
|
| 110 |
+
resnet_groups=norm_num_groups,
|
| 111 |
+
add_attention=mid_block_add_attention,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# out
|
| 115 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
| 116 |
+
self.conv_act = nn.SiLU()
|
| 117 |
+
|
| 118 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
| 119 |
+
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
|
| 120 |
+
|
| 121 |
+
def prepare_attention_mask(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
B, C, T, H, W = hidden_states.shape
|
| 123 |
+
attention_mask = prepare_causal_attention_mask(
|
| 124 |
+
T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
|
| 125 |
+
)
|
| 126 |
+
return attention_mask
|
| 127 |
+
|
| 128 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
| 129 |
+
r"""The forward method of the `EncoderCausal3D` class."""
|
| 130 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
|
| 131 |
+
|
| 132 |
+
sample = self.conv_in(sample)
|
| 133 |
+
|
| 134 |
+
# down
|
| 135 |
+
for down_block in self.down_blocks:
|
| 136 |
+
sample = down_block(sample)
|
| 137 |
+
|
| 138 |
+
# middle
|
| 139 |
+
if self.mid_block.add_attention:
|
| 140 |
+
attention_mask = self.prepare_attention_mask(sample)
|
| 141 |
+
else:
|
| 142 |
+
attention_mask = None
|
| 143 |
+
sample = auto_grad_checkpoint(self.mid_block, sample, attention_mask)
|
| 144 |
+
|
| 145 |
+
# post-process
|
| 146 |
+
sample = self.conv_norm_out(sample)
|
| 147 |
+
sample = self.conv_act(sample)
|
| 148 |
+
sample = self.conv_out(sample)
|
| 149 |
+
|
| 150 |
+
return sample
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class DecoderCausal3D(nn.Module):
|
| 154 |
+
r"""
|
| 155 |
+
The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(
|
| 159 |
+
self,
|
| 160 |
+
in_channels: int = 3,
|
| 161 |
+
out_channels: int = 3,
|
| 162 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 163 |
+
layers_per_block: int = 2,
|
| 164 |
+
norm_num_groups: int = 32,
|
| 165 |
+
act_fn: str = "silu",
|
| 166 |
+
mid_block_add_attention=True,
|
| 167 |
+
time_compression_ratio: int = 4,
|
| 168 |
+
spatial_compression_ratio: int = 8,
|
| 169 |
+
dropout: float = 0.0,
|
| 170 |
+
):
|
| 171 |
+
super().__init__()
|
| 172 |
+
self.layers_per_block = layers_per_block
|
| 173 |
+
|
| 174 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
| 175 |
+
self.mid_block = None
|
| 176 |
+
self.up_blocks = nn.ModuleList([])
|
| 177 |
+
|
| 178 |
+
# mid
|
| 179 |
+
self.mid_block = UNetMidBlockCausal3D(
|
| 180 |
+
in_channels=block_out_channels[-1],
|
| 181 |
+
resnet_eps=1e-6,
|
| 182 |
+
resnet_act_fn=act_fn,
|
| 183 |
+
output_scale_factor=1,
|
| 184 |
+
attention_head_dim=block_out_channels[-1],
|
| 185 |
+
resnet_groups=norm_num_groups,
|
| 186 |
+
add_attention=mid_block_add_attention,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# up
|
| 190 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 191 |
+
output_channel = reversed_block_out_channels[0]
|
| 192 |
+
for i, _ in enumerate(block_out_channels):
|
| 193 |
+
prev_output_channel = output_channel
|
| 194 |
+
output_channel = reversed_block_out_channels[i]
|
| 195 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 196 |
+
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
| 197 |
+
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
| 198 |
+
|
| 199 |
+
if time_compression_ratio == 4:
|
| 200 |
+
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
| 201 |
+
add_time_upsample = bool(
|
| 202 |
+
i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
|
| 203 |
+
)
|
| 204 |
+
elif time_compression_ratio == 8:
|
| 205 |
+
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
| 206 |
+
add_time_upsample = bool(i < num_spatial_upsample_layers)
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
|
| 209 |
+
|
| 210 |
+
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
| 211 |
+
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
| 212 |
+
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
| 213 |
+
up_block = UpDecoderBlockCausal3D(
|
| 214 |
+
num_layers=self.layers_per_block + 1,
|
| 215 |
+
in_channels=prev_output_channel,
|
| 216 |
+
out_channels=output_channel,
|
| 217 |
+
resolution_idx=None,
|
| 218 |
+
dropout=dropout,
|
| 219 |
+
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
| 220 |
+
upsample_scale_factor=upsample_scale_factor,
|
| 221 |
+
resnet_eps=1e-6,
|
| 222 |
+
resnet_act_fn=act_fn,
|
| 223 |
+
resnet_groups=norm_num_groups,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
self.up_blocks.append(up_block)
|
| 227 |
+
prev_output_channel = output_channel
|
| 228 |
+
|
| 229 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
| 230 |
+
self.conv_act = nn.SiLU()
|
| 231 |
+
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
| 232 |
+
|
| 233 |
+
def post_process(self, sample: torch.Tensor) -> torch.Tensor:
|
| 234 |
+
sample = self.conv_norm_out(sample)
|
| 235 |
+
sample = self.conv_act(sample)
|
| 236 |
+
return sample
|
| 237 |
+
|
| 238 |
+
def prepare_attention_mask(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 239 |
+
B, C, T, H, W = hidden_states.shape
|
| 240 |
+
attention_mask = prepare_causal_attention_mask(
|
| 241 |
+
T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
|
| 242 |
+
)
|
| 243 |
+
return attention_mask
|
| 244 |
+
|
| 245 |
+
def forward(
|
| 246 |
+
self,
|
| 247 |
+
sample: torch.FloatTensor,
|
| 248 |
+
) -> torch.FloatTensor:
|
| 249 |
+
r"""The forward method of the `DecoderCausal3D` class."""
|
| 250 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
|
| 251 |
+
|
| 252 |
+
sample = self.conv_in(sample)
|
| 253 |
+
|
| 254 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
| 255 |
+
|
| 256 |
+
# middle
|
| 257 |
+
if self.mid_block.add_attention:
|
| 258 |
+
attention_mask = self.prepare_attention_mask(sample)
|
| 259 |
+
else:
|
| 260 |
+
attention_mask = None
|
| 261 |
+
|
| 262 |
+
sample = auto_grad_checkpoint(self.mid_block, sample, attention_mask)
|
| 263 |
+
sample = sample.to(upscale_dtype)
|
| 264 |
+
|
| 265 |
+
# up
|
| 266 |
+
for up_block in self.up_blocks:
|
| 267 |
+
sample = up_block(sample)
|
| 268 |
+
|
| 269 |
+
# post-process
|
| 270 |
+
if getattr(self, "grad_checkpointing", False):
|
| 271 |
+
sample = checkpoint(self.post_process, sample, use_reentrant=True)
|
| 272 |
+
else:
|
| 273 |
+
sample = self.post_process(sample)
|
| 274 |
+
|
| 275 |
+
sample = self.conv_out(sample)
|
| 276 |
+
|
| 277 |
+
return sample
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class DiagonalGaussianDistribution(object):
|
| 281 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
| 282 |
+
if parameters.ndim == 3:
|
| 283 |
+
dim = 2 # (B, L, C)
|
| 284 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
| 285 |
+
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
|
| 286 |
+
else:
|
| 287 |
+
raise NotImplementedError
|
| 288 |
+
self.parameters = parameters
|
| 289 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
| 290 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 291 |
+
self.deterministic = deterministic
|
| 292 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 293 |
+
self.var = torch.exp(self.logvar)
|
| 294 |
+
if self.deterministic:
|
| 295 |
+
self.var = self.std = torch.zeros_like(
|
| 296 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
| 300 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
| 301 |
+
sample = randn_tensor(
|
| 302 |
+
self.mean.shape,
|
| 303 |
+
generator=generator,
|
| 304 |
+
device=self.parameters.device,
|
| 305 |
+
dtype=self.parameters.dtype,
|
| 306 |
+
)
|
| 307 |
+
x = self.mean + self.std * sample
|
| 308 |
+
return x
|
| 309 |
+
|
| 310 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
| 311 |
+
if self.deterministic:
|
| 312 |
+
return torch.Tensor([0.0])
|
| 313 |
+
else:
|
| 314 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
| 315 |
+
if other is None:
|
| 316 |
+
return 0.5 * torch.sum(
|
| 317 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
| 318 |
+
dim=reduce_dim,
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
return 0.5 * torch.sum(
|
| 322 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
| 323 |
+
+ self.var / other.var
|
| 324 |
+
- 1.0
|
| 325 |
+
- self.logvar
|
| 326 |
+
+ other.logvar,
|
| 327 |
+
dim=reduce_dim,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
| 331 |
+
if self.deterministic:
|
| 332 |
+
return torch.Tensor([0.0])
|
| 333 |
+
logtwopi = np.log(2.0 * np.pi)
|
| 334 |
+
return 0.5 * torch.sum(
|
| 335 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
| 336 |
+
dim=dims,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
def mode(self) -> torch.Tensor:
|
| 340 |
+
return self.mean
|
arbitor/encoders/pig_vae.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""pig-vae (WanVAE) sidecar module.
|
| 2 |
+
|
| 3 |
+
Loads from local safetensors, .pth, or diffusers AutoencoderKLWan.
|
| 4 |
+
Exposes encode() and decode() for the VideoHead training pipeline.
|
| 5 |
+
|
| 6 |
+
Latent shape: [B, 16, T/4, H/8, W/8] for input video of T frames at HxW.
|
| 7 |
+
"""
|
| 8 |
+
import os, torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
_LOCAL_VAE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models", "pig-vae")
|
| 12 |
+
_VAE_CONFIG = {
|
| 13 |
+
"base_dim": 96, "z_dim": 16, "dim_mult": [1, 2, 4, 4],
|
| 14 |
+
"num_res_blocks": 2, "dropout": 0.0,
|
| 15 |
+
"temperal_downsample": [False, True, True],
|
| 16 |
+
"in_channels": 3, "out_channels": 3,
|
| 17 |
+
"scale_factor_temporal": 4, "scale_factor_spatial": 8,
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _freeze_sidecar(model, quantize_requested=None, quantized=False):
|
| 22 |
+
model._arb_quantize_requested = quantize_requested
|
| 23 |
+
model._arb_quantized_int8 = bool(quantized and quantize_requested == "int8")
|
| 24 |
+
model._arb_quantized = bool(quantized)
|
| 25 |
+
for p in model.parameters():
|
| 26 |
+
p.requires_grad = False
|
| 27 |
+
return model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _has_quantized_modules(model):
|
| 31 |
+
markers = ("Q", "Quanto", "Quantized", "WeightQ")
|
| 32 |
+
return any(any(marker in type(module).__name__ for marker in markers) for module in model.modules())
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _quantize_int8_if_requested(model, quantize):
|
| 36 |
+
if quantize == 'int8':
|
| 37 |
+
from optimum.quanto import quantize as quanto_quantize, freeze, qint8
|
| 38 |
+
quanto_quantize(model, weights=qint8)
|
| 39 |
+
freeze(model)
|
| 40 |
+
return _freeze_sidecar(model, quantize_requested=quantize, quantized=_has_quantized_modules(model))
|
| 41 |
+
return _freeze_sidecar(model, quantize_requested=quantize, quantized=False)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _wan_vae_cls():
|
| 45 |
+
try:
|
| 46 |
+
from diffusers import AutoencoderKLWan
|
| 47 |
+
except ModuleNotFoundError as exc:
|
| 48 |
+
raise RuntimeError(
|
| 49 |
+
"pig-vae requires the optional diffusers dependency. "
|
| 50 |
+
"Install the project with `pip install -e .[diffusers]` in a venv "
|
| 51 |
+
"before loading or verifying pig-vae int8 quantization."
|
| 52 |
+
) from exc
|
| 53 |
+
return AutoencoderKLWan
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_vae(device='cuda', quantize='int8'):
|
| 57 |
+
"""Load pig-vae from local cache or diffusers. Optionally int8 quantize."""
|
| 58 |
+
safetensors_path = os.path.join(_LOCAL_VAE_DIR, "model.safetensors")
|
| 59 |
+
gguf_path = os.path.join(_LOCAL_VAE_DIR, "pig_wan_vae_fp32-f16.gguf")
|
| 60 |
+
|
| 61 |
+
if os.path.isfile(safetensors_path):
|
| 62 |
+
return _load_local(safetensors_path, device, quantize, is_safetensors=True)
|
| 63 |
+
if os.path.isfile(gguf_path):
|
| 64 |
+
return _load_gguf(gguf_path, device, quantize)
|
| 65 |
+
return _load_from_hf(device, quantize)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _build_vae():
|
| 69 |
+
AutoencoderKLWan = _wan_vae_cls()
|
| 70 |
+
return AutoencoderKLWan(
|
| 71 |
+
**_VAE_CONFIG,
|
| 72 |
+
latents_mean=[-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653,
|
| 73 |
+
-0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632,
|
| 74 |
+
-0.1922, -0.9497, 0.2503, -0.2921],
|
| 75 |
+
latents_std=[2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708,
|
| 76 |
+
2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579,
|
| 77 |
+
1.6382, 1.1253, 2.8251, 1.916],
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _load_local(path, device, quantize, is_safetensors=False):
|
| 82 |
+
if is_safetensors:
|
| 83 |
+
AutoencoderKLWan = _wan_vae_cls()
|
| 84 |
+
model = AutoencoderKLWan.from_single_file(path)
|
| 85 |
+
else:
|
| 86 |
+
model = _build_vae()
|
| 87 |
+
ckpt = torch.load(path, map_location="cpu", weights_only=True)
|
| 88 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
| 89 |
+
if missing or unexpected:
|
| 90 |
+
raise RuntimeError(
|
| 91 |
+
"pig-vae local .pth checkpoint does not match AutoencoderKLWan "
|
| 92 |
+
f"(missing={len(missing)}, unexpected={len(unexpected)})."
|
| 93 |
+
)
|
| 94 |
+
model = model.to(device)
|
| 95 |
+
model.eval()
|
| 96 |
+
model = _quantize_int8_if_requested(model, quantize)
|
| 97 |
+
return VAEWrapper(model)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _load_gguf(path, device, quantize):
|
| 101 |
+
import gguf
|
| 102 |
+
reader = gguf.GGUFReader(path)
|
| 103 |
+
state_dict = {t.name: torch.tensor(t.data) for t in reader.tensors}
|
| 104 |
+
model = _build_vae()
|
| 105 |
+
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
| 106 |
+
if missing or unexpected:
|
| 107 |
+
raise RuntimeError(
|
| 108 |
+
"pig-vae local GGUF checkpoint does not match AutoencoderKLWan "
|
| 109 |
+
f"(missing={len(missing)}, unexpected={len(unexpected)})."
|
| 110 |
+
)
|
| 111 |
+
model = model.to(device)
|
| 112 |
+
model.eval()
|
| 113 |
+
model = _quantize_int8_if_requested(model, quantize)
|
| 114 |
+
return VAEWrapper(model)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _load_from_hf(device, quantize):
|
| 118 |
+
AutoencoderKLWan = _wan_vae_cls()
|
| 119 |
+
model = AutoencoderKLWan.from_pretrained(
|
| 120 |
+
"Wan-AI/Wan2.1-T2V-1.3B", subfolder="vae",
|
| 121 |
+
torch_dtype=torch.bfloat16,
|
| 122 |
+
)
|
| 123 |
+
model = model.to(device)
|
| 124 |
+
model.eval()
|
| 125 |
+
model = _quantize_int8_if_requested(model, quantize)
|
| 126 |
+
return VAEWrapper(model)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class VAEWrapper(nn.Module):
|
| 130 |
+
def __init__(self, vae):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.vae = vae
|
| 133 |
+
self.latent_channels = _VAE_CONFIG["z_dim"]
|
| 134 |
+
self.scale_factor = 0.476986
|
| 135 |
+
|
| 136 |
+
def encode(self, video_tensor):
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
dist = self.vae.encode(video_tensor)
|
| 139 |
+
latents = dist.latent_dist.sample() if hasattr(dist, 'latent_dist') else dist
|
| 140 |
+
latents = latents * self.scale_factor
|
| 141 |
+
return latents
|
| 142 |
+
|
| 143 |
+
def decode(self, latents):
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
latents = latents / self.scale_factor
|
| 146 |
+
video = self.vae.decode(latents)
|
| 147 |
+
video = video.sample if hasattr(video, 'sample') else video
|
| 148 |
+
return video
|
arbitor/encoders/vae2d.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""2D VAE encoder β wraps PixArt SDXL AutoencoderKL encoder half.
|
| 2 |
+
|
| 3 |
+
Encodes images or mel spectrograms to [B, 4, H/8, W/8] latents.
|
| 4 |
+
Same encoder used for images AND audio spectrograms (via MelSpectrogram3Band).
|
| 5 |
+
Frozen float32 sidecar (no gradients).
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_vae2d(device="cuda", quantize=None):
|
| 13 |
+
from diffusers import AutoencoderKL
|
| 14 |
+
|
| 15 |
+
vae = AutoencoderKL.from_pretrained(
|
| 16 |
+
"PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
| 17 |
+
subfolder="vae",
|
| 18 |
+
torch_dtype=torch.float32,
|
| 19 |
+
).to(device)
|
| 20 |
+
vae.eval()
|
| 21 |
+
for p in vae.parameters():
|
| 22 |
+
p.requires_grad = False
|
| 23 |
+
|
| 24 |
+
return VAE2DEncoder(vae)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class VAE2DEncoder(nn.Module):
|
| 28 |
+
def __init__(self, vae):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.encoder = vae.encoder
|
| 31 |
+
self.quant_conv = vae.quant_conv
|
| 32 |
+
self.latent_channels = 4
|
| 33 |
+
self.input_scale = 0.18215
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
H, W = x.shape[-2], x.shape[-1]
|
| 37 |
+
pad_h = (8 - H % 8) % 8
|
| 38 |
+
pad_w = (8 - W % 8) % 8
|
| 39 |
+
if pad_h > 0 or pad_w > 0:
|
| 40 |
+
x = F.pad(x, (0, pad_w, 0, pad_h))
|
| 41 |
+
|
| 42 |
+
h = self.encoder(x)
|
| 43 |
+
moments = self.quant_conv(h)
|
| 44 |
+
posterior = torch.distributions.Normal(
|
| 45 |
+
moments[:, :self.latent_channels],
|
| 46 |
+
torch.nn.functional.softplus(moments[:, self.latent_channels:])
|
| 47 |
+
)
|
| 48 |
+
latent = posterior.rsample()
|
| 49 |
+
latent = latent * self.input_scale
|
| 50 |
+
|
| 51 |
+
if pad_h > 0 or pad_w > 0:
|
| 52 |
+
out_h = H // 8 if H >= 8 else 1
|
| 53 |
+
out_w = W // 8 if W >= 8 else 1
|
| 54 |
+
latent = latent[:, :, :out_h, :out_w]
|
| 55 |
+
|
| 56 |
+
return latent
|
arbitor/kernel/flash_vq.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FlashVQ: Custom Vector Quantization with dual Triton GPU + PyTorch CPU path.
|
| 3 |
+
|
| 4 |
+
Replaces vector_quantize_pytorch entirely (D-100). FlashVQCodebook is a standalone
|
| 5 |
+
nn.Module implementing all VQ operations:
|
| 6 |
+
- Cosine similarity codebook lookup
|
| 7 |
+
- EMA codebook update
|
| 8 |
+
- Dead code reset
|
| 9 |
+
- Rotation trick (gradient through quantization)
|
| 10 |
+
- Commitment loss
|
| 11 |
+
|
| 12 |
+
Dispatch pattern (following tscale.py):
|
| 13 |
+
if x.is_cuda and _HAS_TRITON β _TritonFlashVQFn.apply()
|
| 14 |
+
else β self._cpu_forward()
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
_HAS_TRITON = False
|
| 22 |
+
try:
|
| 23 |
+
import triton
|
| 24 |
+
import triton.language as tl
|
| 25 |
+
_HAS_TRITON = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class _RotationTrickFn(torch.autograd.Function):
|
| 31 |
+
"""
|
| 32 |
+
Rotation trick gradient through vector quantization.
|
| 33 |
+
|
| 34 |
+
Instead of straight-through estimator (STE), rotate the encoder output
|
| 35 |
+
gradient toward the quantized vector direction. This helps the encoder
|
| 36 |
+
learn to produce outputs that align with codebook entries.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
@staticmethod
|
| 40 |
+
def forward(ctx, x, quantized):
|
| 41 |
+
ctx.save_for_backward(x.detach(), quantized.detach())
|
| 42 |
+
return quantized
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def backward(ctx, grad_output):
|
| 46 |
+
x, quantized = ctx.saved_tensors
|
| 47 |
+
# Normalize in fp32 for numerical stability
|
| 48 |
+
x_norm = F.normalize(x.float(), dim=-1)
|
| 49 |
+
q_norm = F.normalize(quantized.float(), dim=-1)
|
| 50 |
+
# Gradient deflection: subtract projection onto (x_norm - q_norm)
|
| 51 |
+
# This rotates the gradient toward the quantized direction
|
| 52 |
+
diff = x_norm - q_norm
|
| 53 |
+
proj = (grad_output.float() * x_norm).sum(dim=-1, keepdim=True)
|
| 54 |
+
grad_x = grad_output.float() - proj * diff
|
| 55 |
+
return grad_x.to(grad_output.dtype), None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class FlashVQCodebook(nn.Module):
|
| 59 |
+
"""
|
| 60 |
+
Vector quantization codebook with dual GPU (Triton) / CPU (PyTorch) paths.
|
| 61 |
+
|
| 62 |
+
Interface matches vector_quantize_pytorch.VectorQuantize:
|
| 63 |
+
forward(x) β (quantized, indices, commitment_loss)
|
| 64 |
+
|
| 65 |
+
All VQ operations are self-contained:
|
| 66 |
+
- Cosine similarity codebook lookup
|
| 67 |
+
- Straight-through estimator (STE) with optional rotation trick
|
| 68 |
+
- EMA codebook update (decay=0.99)
|
| 69 |
+
- Dead code reset (threshold_ema_dead_code=2)
|
| 70 |
+
- Commitment loss
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
codebook_size: int = 8192,
|
| 76 |
+
codebook_dim: int = 32,
|
| 77 |
+
decay: float = 0.99,
|
| 78 |
+
commitment_weight: float = 1.0,
|
| 79 |
+
threshold_ema_dead_code: int = 2,
|
| 80 |
+
kmeans_init: bool = True,
|
| 81 |
+
kmeans_iters: int = 10,
|
| 82 |
+
rotation_trick: bool = True,
|
| 83 |
+
):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.codebook_size = codebook_size
|
| 86 |
+
self.codebook_dim = codebook_dim
|
| 87 |
+
self.decay = decay
|
| 88 |
+
self.commitment_weight = commitment_weight
|
| 89 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 90 |
+
self.kmeans_init = kmeans_init
|
| 91 |
+
self.kmeans_iters = kmeans_iters
|
| 92 |
+
self.rotation_trick = rotation_trick
|
| 93 |
+
|
| 94 |
+
# Codebook buffers
|
| 95 |
+
self.register_buffer('embed', torch.randn(codebook_size, codebook_dim) * 0.02)
|
| 96 |
+
self.register_buffer('cluster_size', torch.zeros(codebook_size))
|
| 97 |
+
self.register_buffer('embed_avg', torch.zeros(codebook_size, codebook_dim))
|
| 98 |
+
|
| 99 |
+
# Tile sizes for Triton kernel (set on first GPU forward)
|
| 100 |
+
self._triton_block_bt = 16
|
| 101 |
+
self._triton_tile_k = 1024
|
| 102 |
+
|
| 103 |
+
def _compute_tile_sizes(self):
|
| 104 |
+
"""
|
| 105 |
+
Dynamic tile sizing per D-102.
|
| 106 |
+
|
| 107 |
+
Queries GPU device properties to determine SRAM budget, then computes
|
| 108 |
+
BLOCK_BT and TILE_K such that:
|
| 109 |
+
BLOCK_BT * codebook_dim * 2 + TILE_K * codebook_dim * 2 < SRAM * 0.9
|
| 110 |
+
|
| 111 |
+
For sm_89 (RTX 4060, 99KB SRAM per SM):
|
| 112 |
+
codebook_size=8192, codebook_dim=32 β BLOCK_BT=16, TILE_K=1024 (65KB)
|
| 113 |
+
codebook_size=4096, codebook_dim=32 β BLOCK_BT=16, TILE_K=512 (33KB)
|
| 114 |
+
"""
|
| 115 |
+
if not torch.cuda.is_available():
|
| 116 |
+
return
|
| 117 |
+
try:
|
| 118 |
+
props = torch.cuda.get_device_properties(0)
|
| 119 |
+
sram_budget = 99 * 1024 # SM 8.9: 99KB per SM
|
| 120 |
+
|
| 121 |
+
# Conservative estimate: each element is 2 bytes (bf16) in SRAM
|
| 122 |
+
elem_bytes = 2
|
| 123 |
+
|
| 124 |
+
# Find largest TILE_K that fits with BLOCK_BT=16
|
| 125 |
+
bt = 16
|
| 126 |
+
for tk in [2048, 1024, 512, 256, 128]:
|
| 127 |
+
sram_usage = bt * self.codebook_dim * elem_bytes + tk * self.codebook_dim * elem_bytes
|
| 128 |
+
if sram_usage < sram_budget * 0.9:
|
| 129 |
+
self._triton_block_bt = bt
|
| 130 |
+
self._triton_tile_k = tk
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
# Fallback for very constrained SRAM or large codebook_dim
|
| 134 |
+
self._triton_block_bt = 8
|
| 135 |
+
self._triton_tile_k = 256
|
| 136 |
+
except Exception:
|
| 137 |
+
# Default values
|
| 138 |
+
self._triton_block_bt = 16
|
| 139 |
+
self._triton_tile_k = 1024
|
| 140 |
+
|
| 141 |
+
def forward(self, x: torch.Tensor):
|
| 142 |
+
"""
|
| 143 |
+
Args:
|
| 144 |
+
x: Input tensor of shape [*, codebook_dim]
|
| 145 |
+
Returns:
|
| 146 |
+
quantized: Tensor of same shape as x
|
| 147 |
+
indices: Tensor of shape [*] with codebook indices
|
| 148 |
+
commitment_loss: Scalar tensor
|
| 149 |
+
"""
|
| 150 |
+
orig_shape = x.shape
|
| 151 |
+
x_flat = x.reshape(-1, self.codebook_dim)
|
| 152 |
+
|
| 153 |
+
if x.is_cuda and _HAS_TRITON:
|
| 154 |
+
quantized, indices, commitment_loss = self._triton_forward(x_flat)
|
| 155 |
+
else:
|
| 156 |
+
quantized, indices, commitment_loss = self._cpu_forward(x_flat)
|
| 157 |
+
|
| 158 |
+
quantized = quantized.reshape(orig_shape)
|
| 159 |
+
indices = indices.reshape(orig_shape[:-1])
|
| 160 |
+
return quantized, indices, commitment_loss
|
| 161 |
+
|
| 162 |
+
def _triton_forward(self, x_flat: torch.Tensor):
|
| 163 |
+
"""Triton GPU path β dispatched when CUDA + Triton available."""
|
| 164 |
+
# Use _TritonFlashVQFn for forward + backward via autograd
|
| 165 |
+
quantized, indices, commitment_loss = _TritonFlashVQFn.apply(
|
| 166 |
+
x_flat, self.embed, self.cluster_size, self.embed_avg,
|
| 167 |
+
self.codebook_size, self.codebook_dim,
|
| 168 |
+
self.commitment_weight, self.rotation_trick,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# EMA update and dead code reset (under torch.no_grad)
|
| 172 |
+
with torch.no_grad():
|
| 173 |
+
self._ema_update(x_flat, indices)
|
| 174 |
+
self._dead_code_reset(x_flat)
|
| 175 |
+
|
| 176 |
+
return quantized, indices, commitment_loss
|
| 177 |
+
|
| 178 |
+
def _cpu_forward(self, x_flat: torch.Tensor):
|
| 179 |
+
"""
|
| 180 |
+
Pure PyTorch CPU path β implements all VQ operations.
|
| 181 |
+
|
| 182 |
+
Steps:
|
| 183 |
+
1. Cosine similarity lookup β nearest codebook entry indices
|
| 184 |
+
2. Quantize via straight-through estimator (or rotation trick)
|
| 185 |
+
3. Compute commitment loss
|
| 186 |
+
4. EMA update codebook (under torch.no_grad)
|
| 187 |
+
5. Dead code reset (under torch.no_grad)
|
| 188 |
+
"""
|
| 189 |
+
# ββ Step 1: Cosine similarity lookup ββ
|
| 190 |
+
x_norm = F.normalize(x_flat.float(), dim=-1)
|
| 191 |
+
embed_norm = F.normalize(self.embed.float(), dim=-1)
|
| 192 |
+
sim = x_norm @ embed_norm.T # [N, codebook_size]
|
| 193 |
+
indices = sim.argmax(dim=-1) # [N]
|
| 194 |
+
|
| 195 |
+
# ββ Step 2: Quantize with STE or rotation trick ββ
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
quantized = self.embed[indices] # [N, D]
|
| 198 |
+
|
| 199 |
+
if self.rotation_trick:
|
| 200 |
+
quantized = _RotationTrickFn.apply(x_flat, quantized)
|
| 201 |
+
else:
|
| 202 |
+
# Straight-through estimator
|
| 203 |
+
quantized = x_flat + (quantized - x_flat).detach()
|
| 204 |
+
|
| 205 |
+
# ββ Step 3: Commitment loss ββ
|
| 206 |
+
commitment_loss = self.commitment_weight * F.mse_loss(
|
| 207 |
+
x_flat.float(), quantized.detach().float()
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# ββ Step 4: EMA update ββ
|
| 211 |
+
with torch.no_grad():
|
| 212 |
+
self._ema_update(x_flat, indices)
|
| 213 |
+
|
| 214 |
+
# ββ Step 5: Dead code reset ββ
|
| 215 |
+
self._dead_code_reset(x_flat)
|
| 216 |
+
|
| 217 |
+
return quantized, indices, commitment_loss
|
| 218 |
+
|
| 219 |
+
def _ema_update(self, x_flat: torch.Tensor, indices: torch.Tensor):
|
| 220 |
+
"""
|
| 221 |
+
Exponential moving average codebook update.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
x_flat: [N, D] input vectors
|
| 225 |
+
indices: [N] codebook indices for each input vector
|
| 226 |
+
"""
|
| 227 |
+
one_hot = F.one_hot(indices, num_classes=self.codebook_size).float() # [N, codebook_size]
|
| 228 |
+
n_assign = one_hot.sum(dim=0) # [codebook_size]
|
| 229 |
+
|
| 230 |
+
# EMA on cluster_size (how many inputs assigned to each code)
|
| 231 |
+
self.cluster_size.mul_(self.decay).add_(n_assign * (1 - self.decay))
|
| 232 |
+
|
| 233 |
+
# EMA on embed_avg: weighted sum of assigned inputs
|
| 234 |
+
# embed_avg[c] = decay * embed_avg[c] + (1 - decay) * sum(x assigned to c)
|
| 235 |
+
x_float = x_flat.float()
|
| 236 |
+
for c in range(self.codebook_size):
|
| 237 |
+
mask = indices == c
|
| 238 |
+
count = mask.sum().item()
|
| 239 |
+
if count > 0:
|
| 240 |
+
assigned_sum = x_float[mask].sum(dim=0)
|
| 241 |
+
self.embed_avg[c].mul_(self.decay).add_(assigned_sum * (1 - self.decay))
|
| 242 |
+
|
| 243 |
+
# Normalize: embed = embed_avg / cluster_size (with epsilon)
|
| 244 |
+
cluster_size_safe = self.cluster_size.clamp(min=1e-5)
|
| 245 |
+
self.embed.copy_(self.embed_avg / cluster_size_safe.unsqueeze(1))
|
| 246 |
+
|
| 247 |
+
def _dead_code_reset(self, x_flat: torch.Tensor):
|
| 248 |
+
"""
|
| 249 |
+
Replace dead codebook entries (cluster_size < threshold) with
|
| 250 |
+
random vectors from the current input batch.
|
| 251 |
+
"""
|
| 252 |
+
dead_mask = self.cluster_size < self.threshold_ema_dead_code
|
| 253 |
+
n_dead = dead_mask.sum().item()
|
| 254 |
+
if n_dead == 0:
|
| 255 |
+
return
|
| 256 |
+
dead_indices = torch.where(dead_mask)[0]
|
| 257 |
+
# Replace with random input vectors
|
| 258 |
+
rand_idx = torch.randint(0, x_flat.shape[0], (n_dead,), device=x_flat.device)
|
| 259 |
+
self.embed[dead_indices] = x_flat[rand_idx].detach()
|
| 260 |
+
self.cluster_size[dead_indices] = 0.0
|
| 261 |
+
self.embed_avg[dead_indices] = 0.0
|
| 262 |
+
|
| 263 |
+
@torch.no_grad()
|
| 264 |
+
def kmeans_init_codebook(self, x: torch.Tensor):
|
| 265 |
+
"""Initialize codebook via k-means on first batch."""
|
| 266 |
+
x_flat = x.reshape(-1, self.codebook_dim).float()
|
| 267 |
+
centroids = x_flat[torch.randperm(x_flat.shape[0])[:self.codebook_size]].clone()
|
| 268 |
+
for _ in range(self.kmeans_iters):
|
| 269 |
+
dist = torch.cdist(x_flat, centroids)
|
| 270 |
+
assign = dist.argmin(dim=-1)
|
| 271 |
+
for i in range(self.codebook_size):
|
| 272 |
+
mask = assign == i
|
| 273 |
+
if mask.sum() > 0:
|
| 274 |
+
centroids[i] = x_flat[mask].mean(dim=0)
|
| 275 |
+
self.embed.copy_(centroids)
|
| 276 |
+
|
| 277 |
+
@torch.no_grad()
|
| 278 |
+
def get_codebook_utilization(self) -> float:
|
| 279 |
+
"""Fraction of codebook entries with any usage."""
|
| 280 |
+
return (self.cluster_size > 0).float().mean().item()
|
| 281 |
+
|
| 282 |
+
@torch.no_grad()
|
| 283 |
+
def get_dead_code_count(self) -> int:
|
| 284 |
+
"""Number of codebook entries below EMA dead threshold."""
|
| 285 |
+
return (self.cluster_size < self.threshold_ema_dead_code).sum().item()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# βββ Triton GPU Kernels βββ
|
| 289 |
+
# Only defined when Triton is available
|
| 290 |
+
|
| 291 |
+
if _HAS_TRITON:
|
| 292 |
+
|
| 293 |
+
@triton.jit
|
| 294 |
+
def _triton_flash_vq_lookup_kernel(
|
| 295 |
+
x_ptr, codebook_ptr, indices_ptr,
|
| 296 |
+
stride_xb, stride_xd,
|
| 297 |
+
stride_cb, stride_cd,
|
| 298 |
+
N_CTX: tl.constexpr,
|
| 299 |
+
CODEBOOK_SIZE: tl.constexpr,
|
| 300 |
+
CODEBOOK_DIM: tl.constexpr,
|
| 301 |
+
BLOCK_BT: tl.constexpr,
|
| 302 |
+
TILE_K: tl.constexpr,
|
| 303 |
+
):
|
| 304 |
+
"""
|
| 305 |
+
Tiled cosine similarity + argmax lookup for VQ codebook.
|
| 306 |
+
|
| 307 |
+
Architecture:
|
| 308 |
+
pid = batch tile index
|
| 309 |
+
Load input tile [BLOCK_BT, CODEBOOK_DIM]
|
| 310 |
+
Normalize in fp32
|
| 311 |
+
Tile over codebook in TILE_K chunks:
|
| 312 |
+
Load codebook tile [TILE_K, CODEBOOK_DIM]
|
| 313 |
+
Normalize in fp32
|
| 314 |
+
Compute dot product via tl.dot β [BLOCK_BT, TILE_K]
|
| 315 |
+
Update running argmax
|
| 316 |
+
Store best indices
|
| 317 |
+
|
| 318 |
+
SRAM: all arithmetic in fp32 with small tiles to fit 99KB budget.
|
| 319 |
+
"""
|
| 320 |
+
pid = tl.program_id(0)
|
| 321 |
+
offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
|
| 322 |
+
offs_d = tl.arange(0, CODEBOOK_DIM)
|
| 323 |
+
|
| 324 |
+
# ββ Load input tile ββ
|
| 325 |
+
x_ptrs = x_ptr + offs_bt[:, None] * stride_xb + offs_d[None, :] * stride_xd
|
| 326 |
+
x = tl.load(x_ptrs, mask=offs_bt[:, None] < N_CTX, other=0.0)
|
| 327 |
+
|
| 328 |
+
# ββ Normalize input in fp32 (no keepdims in Triton tl.sum) ββ
|
| 329 |
+
x_f32 = x.to(tl.float32)
|
| 330 |
+
x_sq = tl.sum(x_f32 * x_f32, axis=1) # [BLOCK_BT]
|
| 331 |
+
x_norm_f32 = x_f32 / tl.sqrt(x_sq[:, None] + 1e-8)
|
| 332 |
+
|
| 333 |
+
# ββ Running argmax over tiled codebook ββ
|
| 334 |
+
best_sim = tl.full([BLOCK_BT], -float('inf'), dtype=tl.float32)
|
| 335 |
+
best_idx = tl.zeros([BLOCK_BT], dtype=tl.int32)
|
| 336 |
+
|
| 337 |
+
for k_start in range(0, CODEBOOK_SIZE, TILE_K):
|
| 338 |
+
offs_k = k_start + tl.arange(0, TILE_K)
|
| 339 |
+
k_mask = offs_k < CODEBOOK_SIZE
|
| 340 |
+
|
| 341 |
+
# Load codebook tile into fp32 directly for normalization
|
| 342 |
+
cb_ptrs = (codebook_ptr
|
| 343 |
+
+ offs_k[:, None] * stride_cb
|
| 344 |
+
+ offs_d[None, :] * stride_cd)
|
| 345 |
+
cb = tl.load(cb_ptrs, mask=k_mask[:, None], other=0.0)
|
| 346 |
+
|
| 347 |
+
# Normalize codebook tile in fp32
|
| 348 |
+
cb_f32 = cb.to(tl.float32)
|
| 349 |
+
cb_sq = tl.sum(cb_f32 * cb_f32, axis=1) # [TILE_K]
|
| 350 |
+
cb_norm_f32 = cb_f32 / tl.sqrt(cb_sq[:, None] + 1e-8)
|
| 351 |
+
|
| 352 |
+
# Cosine similarity via tl.dot (tf32 on sm_89)
|
| 353 |
+
sim = tl.dot(x_norm_f32, tl.trans(cb_norm_f32)) # [BLOCK_BT, TILE_K]
|
| 354 |
+
|
| 355 |
+
# Running argmax within this tile
|
| 356 |
+
tile_max = tl.max(sim, axis=1)
|
| 357 |
+
tile_argmax = tl.argmax(sim, axis=1)
|
| 358 |
+
tile_idx = k_start + tile_argmax
|
| 359 |
+
|
| 360 |
+
# Merge with best across tiles using element-wise mask
|
| 361 |
+
update_mask = tile_max > best_sim
|
| 362 |
+
best_sim = tl.where(update_mask, tile_max, best_sim)
|
| 363 |
+
best_idx = tl.where(update_mask, tile_idx, best_idx)
|
| 364 |
+
|
| 365 |
+
# ββ Store results ββ
|
| 366 |
+
tl.store(indices_ptr + offs_bt, best_idx, mask=offs_bt < N_CTX)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@triton.jit
|
| 370 |
+
def _triton_flash_vq_quantize_kernel(
|
| 371 |
+
codebook_ptr, indices_ptr, quantized_ptr,
|
| 372 |
+
stride_cb, stride_cd,
|
| 373 |
+
stride_qb, stride_qd,
|
| 374 |
+
N_CTX: tl.constexpr,
|
| 375 |
+
CODEBOOK_DIM: tl.constexpr,
|
| 376 |
+
BLOCK_BT: tl.constexpr,
|
| 377 |
+
):
|
| 378 |
+
"""
|
| 379 |
+
Gather quantized vectors from codebook at given indices.
|
| 380 |
+
Kernel form of: quantized[i] = codebook[indices[i]]
|
| 381 |
+
"""
|
| 382 |
+
pid = tl.program_id(0)
|
| 383 |
+
offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
|
| 384 |
+
offs_d = tl.arange(0, CODEBOOK_DIM)
|
| 385 |
+
|
| 386 |
+
# Load indices for this batch tile
|
| 387 |
+
idx = tl.load(indices_ptr + offs_bt, mask=offs_bt < N_CTX, other=0)
|
| 388 |
+
|
| 389 |
+
# Gather: for each i in BLOCK_BT, load codebook[idx[i], :]
|
| 390 |
+
# Pointer arithmetic with broadcasting
|
| 391 |
+
gather_ptrs = (codebook_ptr
|
| 392 |
+
+ idx[:, None] * stride_cb
|
| 393 |
+
+ offs_d[None, :] * stride_cd)
|
| 394 |
+
quantized = tl.load(gather_ptrs,
|
| 395 |
+
mask=offs_bt[:, None] < N_CTX,
|
| 396 |
+
other=0.0)
|
| 397 |
+
|
| 398 |
+
# Store quantized output
|
| 399 |
+
out_ptrs = (quantized_ptr
|
| 400 |
+
+ offs_bt[:, None] * stride_qb
|
| 401 |
+
+ offs_d[None, :] * stride_qd)
|
| 402 |
+
tl.store(out_ptrs, quantized, mask=offs_bt[:, None] < N_CTX)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def _triton_lookup(x, embed, block_bt=None, tile_k=None):
|
| 406 |
+
"""
|
| 407 |
+
Launch Triton VQ lookup kernel with SRAM-safe tile sizes.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
x: [N, D] input tensor (cuda, contiguous)
|
| 411 |
+
embed: [codebook_size, D] codebook (cuda, contiguous)
|
| 412 |
+
block_bt: BLOCK_BT tile size (auto-computed if None)
|
| 413 |
+
tile_k: TILE_K tile size (auto-computed if None)
|
| 414 |
+
Returns:
|
| 415 |
+
indices: [N] int64 tensor of argmax indices
|
| 416 |
+
"""
|
| 417 |
+
N, D = x.shape
|
| 418 |
+
codebook_size = embed.shape[0]
|
| 419 |
+
assert embed.shape[1] == D, f"Codebook dim {embed.shape[1]} != input dim {D}"
|
| 420 |
+
|
| 421 |
+
# SRAM-safe tile sizes: kernel uses tf32 (fp32 math), and Triton
|
| 422 |
+
# pipelines data through shared memory. Conservative sizing ensures
|
| 423 |
+
# fits within ~99KB (sm_89) even with default num_stages=3.
|
| 424 |
+
#
|
| 425 |
+
# fp32 codebook tile: TILE_K * D * 4 β 128*32*4 = 16KB
|
| 426 |
+
# fp32 input tile: BLOCK_BT * D * 4 β 8*32*4 = 1KB
|
| 427 |
+
# Accumulator: BLOCK_BT*TILE_K*4 β 8*128*4 = 4KB
|
| 428 |
+
# Per stage: ~21KB. With 3 pipeline stages: ~63KB (fits in 99KB).
|
| 429 |
+
#
|
| 430 |
+
# Larger tiles oversubscribe SRAM (tested: TILE_K=1024 β 321KB needed).
|
| 431 |
+
if block_bt is None or tile_k is None:
|
| 432 |
+
BLOCK_BT = 8
|
| 433 |
+
TILE_K = 128
|
| 434 |
+
else:
|
| 435 |
+
BLOCK_BT, TILE_K = block_bt, tile_k
|
| 436 |
+
|
| 437 |
+
grid = (triton.cdiv(N, BLOCK_BT),)
|
| 438 |
+
|
| 439 |
+
indices = torch.empty(N, dtype=torch.int32, device=x.device)
|
| 440 |
+
|
| 441 |
+
_triton_flash_vq_lookup_kernel[grid](
|
| 442 |
+
x, embed, indices,
|
| 443 |
+
x.stride(0), x.stride(1),
|
| 444 |
+
embed.stride(0), embed.stride(1),
|
| 445 |
+
N, codebook_size, D,
|
| 446 |
+
BLOCK_BT=BLOCK_BT, TILE_K=TILE_K,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
return indices.long()
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class _TritonFlashVQFn(torch.autograd.Function):
|
| 453 |
+
"""
|
| 454 |
+
Custom autograd Function wrapping Triton VQ kernels.
|
| 455 |
+
|
| 456 |
+
Forward: Triton tiled cosine similarity + argmax lookup
|
| 457 |
+
Backward: Rotation trick gradient or straight-through estimator
|
| 458 |
+
"""
|
| 459 |
+
|
| 460 |
+
@staticmethod
|
| 461 |
+
def forward(ctx, x_flat, embed, cluster_size, embed_avg,
|
| 462 |
+
codebook_size, codebook_dim,
|
| 463 |
+
commitment_weight, rotation_trick):
|
| 464 |
+
# Triton tiled lookup for indices
|
| 465 |
+
with torch.no_grad():
|
| 466 |
+
indices = _triton_lookup(x_flat.contiguous(), embed.contiguous())
|
| 467 |
+
|
| 468 |
+
quantized = embed[indices]
|
| 469 |
+
commitment_loss = commitment_weight * F.mse_loss(x_flat.float(), quantized.detach().float())
|
| 470 |
+
|
| 471 |
+
# Clone saved tensors to avoid version conflicts with in-place EMA updates
|
| 472 |
+
ctx.save_for_backward(
|
| 473 |
+
x_flat.detach().clone(),
|
| 474 |
+
quantized.detach().clone(),
|
| 475 |
+
embed.detach().clone(),
|
| 476 |
+
)
|
| 477 |
+
ctx.codebook_dim = codebook_dim
|
| 478 |
+
ctx.rotation_trick = rotation_trick
|
| 479 |
+
|
| 480 |
+
return quantized, indices, commitment_loss
|
| 481 |
+
|
| 482 |
+
@staticmethod
|
| 483 |
+
def backward(ctx, grad_quantized, grad_indices, grad_commitment):
|
| 484 |
+
x_flat, quantized, embed = ctx.saved_tensors
|
| 485 |
+
|
| 486 |
+
if ctx.rotation_trick:
|
| 487 |
+
# Rotation trick gradient
|
| 488 |
+
x_norm = F.normalize(x_flat.float(), dim=-1)
|
| 489 |
+
q_norm = F.normalize(quantized.float(), dim=-1)
|
| 490 |
+
diff = x_norm - q_norm
|
| 491 |
+
proj = (grad_quantized.float() * x_norm).sum(dim=-1, keepdim=True)
|
| 492 |
+
grad_x = grad_quantized.float() - proj * diff
|
| 493 |
+
else:
|
| 494 |
+
# Straight-through estimator
|
| 495 |
+
grad_x = grad_quantized.float()
|
| 496 |
+
|
| 497 |
+
return grad_x.to(grad_quantized.dtype), None, None, None, None, None, None, None
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# When Triton is not available, define a fallback lookup
|
| 501 |
+
if not _HAS_TRITON:
|
| 502 |
+
|
| 503 |
+
def _triton_lookup(x, embed):
|
| 504 |
+
"""Fallback: torch-based cosine similarity lookup (CPU or CUDA without Triton)."""
|
| 505 |
+
with torch.no_grad():
|
| 506 |
+
x_norm = F.normalize(x.float(), dim=-1)
|
| 507 |
+
embed_norm = F.normalize(embed.float(), dim=-1)
|
| 508 |
+
sim = x_norm @ embed_norm.T
|
| 509 |
+
indices = sim.argmax(dim=-1)
|
| 510 |
+
return indices
|
arbitor/kernel/ternary_audit.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Iterable
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class TensorState:
|
| 11 |
+
name: str
|
| 12 |
+
shape: tuple[int, ...]
|
| 13 |
+
dtype: str
|
| 14 |
+
bytes: int
|
| 15 |
+
trainable: bool = False
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class TernaryAudit:
|
| 20 |
+
logical_ternary_weights: int
|
| 21 |
+
ternary_packed_bytes: int
|
| 22 |
+
ternary_scale_bytes: int
|
| 23 |
+
ternary_scale_accum_bytes: int
|
| 24 |
+
ternary_accum_bytes: int
|
| 25 |
+
ternary_corr_accum_bytes: int
|
| 26 |
+
ternary_step_counter_bytes: int
|
| 27 |
+
trainable_float_params: list[TensorState]
|
| 28 |
+
frozen_float_params: list[TensorState]
|
| 29 |
+
float_buffers: list[TensorState]
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def ternary_training_bytes(self) -> int:
|
| 33 |
+
return (
|
| 34 |
+
self.ternary_packed_bytes
|
| 35 |
+
+ self.ternary_scale_bytes
|
| 36 |
+
+ self.ternary_scale_accum_bytes
|
| 37 |
+
+ self.ternary_accum_bytes
|
| 38 |
+
+ self.ternary_corr_accum_bytes
|
| 39 |
+
+ self.ternary_step_counter_bytes
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
@property
|
| 43 |
+
def trainable_float_bytes(self) -> int:
|
| 44 |
+
return sum(item.bytes for item in self.trainable_float_params)
|
| 45 |
+
|
| 46 |
+
@property
|
| 47 |
+
def frozen_float_bytes(self) -> int:
|
| 48 |
+
return sum(item.bytes for item in self.frozen_float_params)
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def float_buffer_bytes(self) -> int:
|
| 52 |
+
return sum(item.bytes for item in self.float_buffers)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _tensor_bytes(t: torch.Tensor) -> int:
|
| 56 |
+
return t.numel() * t.element_size()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _tensor_state(name: str, t: torch.Tensor, trainable: bool = False) -> TensorState:
|
| 60 |
+
return TensorState(
|
| 61 |
+
name=name,
|
| 62 |
+
shape=tuple(t.shape),
|
| 63 |
+
dtype=str(t.dtype).replace("torch.", ""),
|
| 64 |
+
bytes=_tensor_bytes(t),
|
| 65 |
+
trainable=trainable,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _mb(n_bytes: int) -> float:
|
| 70 |
+
return n_bytes / (1024 * 1024)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def audit_model(model: torch.nn.Module) -> TernaryAudit:
|
| 74 |
+
logical_ternary_weights = 0
|
| 75 |
+
ternary_packed_bytes = 0
|
| 76 |
+
ternary_scale_bytes = 0
|
| 77 |
+
ternary_scale_accum_bytes = 0
|
| 78 |
+
ternary_accum_bytes = 0
|
| 79 |
+
ternary_corr_accum_bytes = 0
|
| 80 |
+
ternary_step_counter_bytes = 0
|
| 81 |
+
|
| 82 |
+
for module in model.modules():
|
| 83 |
+
if hasattr(module, "T_packed") and hasattr(module, "_T_shape"):
|
| 84 |
+
shape = tuple(int(x) for x in module._T_shape.tolist())
|
| 85 |
+
n_weights = 1
|
| 86 |
+
for dim in shape:
|
| 87 |
+
n_weights *= dim
|
| 88 |
+
logical_ternary_weights += n_weights
|
| 89 |
+
ternary_packed_bytes += _tensor_bytes(module.T_packed)
|
| 90 |
+
if hasattr(module, "E"):
|
| 91 |
+
ternary_scale_bytes += _tensor_bytes(module.E)
|
| 92 |
+
if hasattr(module, "E_accum"):
|
| 93 |
+
ternary_scale_accum_bytes += _tensor_bytes(module.E_accum)
|
| 94 |
+
if hasattr(module, "T_accum"):
|
| 95 |
+
ternary_accum_bytes += _tensor_bytes(module.T_accum)
|
| 96 |
+
if hasattr(module, "corr_accum"):
|
| 97 |
+
ternary_corr_accum_bytes += _tensor_bytes(module.corr_accum)
|
| 98 |
+
if hasattr(module, "step_counter"):
|
| 99 |
+
ternary_step_counter_bytes += _tensor_bytes(module.step_counter)
|
| 100 |
+
|
| 101 |
+
trainable_float_params: list[TensorState] = []
|
| 102 |
+
frozen_float_params: list[TensorState] = []
|
| 103 |
+
for name, param in model.named_parameters():
|
| 104 |
+
if not param.dtype.is_floating_point:
|
| 105 |
+
continue
|
| 106 |
+
state = _tensor_state(name, param, trainable=param.requires_grad)
|
| 107 |
+
if param.requires_grad:
|
| 108 |
+
trainable_float_params.append(state)
|
| 109 |
+
else:
|
| 110 |
+
frozen_float_params.append(state)
|
| 111 |
+
|
| 112 |
+
float_buffers = [
|
| 113 |
+
_tensor_state(name, buf)
|
| 114 |
+
for name, buf in model.named_buffers()
|
| 115 |
+
if buf.dtype.is_floating_point
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
return TernaryAudit(
|
| 119 |
+
logical_ternary_weights=logical_ternary_weights,
|
| 120 |
+
ternary_packed_bytes=ternary_packed_bytes,
|
| 121 |
+
ternary_scale_bytes=ternary_scale_bytes,
|
| 122 |
+
ternary_scale_accum_bytes=ternary_scale_accum_bytes,
|
| 123 |
+
ternary_accum_bytes=ternary_accum_bytes,
|
| 124 |
+
ternary_corr_accum_bytes=ternary_corr_accum_bytes,
|
| 125 |
+
ternary_step_counter_bytes=ternary_step_counter_bytes,
|
| 126 |
+
trainable_float_params=trainable_float_params,
|
| 127 |
+
frozen_float_params=frozen_float_params,
|
| 128 |
+
float_buffers=float_buffers,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def format_audit(audit: TernaryAudit, limit: int = 12) -> str:
|
| 133 |
+
lines = [
|
| 134 |
+
"Ternary state audit:",
|
| 135 |
+
f" logical ternary weights: {audit.logical_ternary_weights:,}",
|
| 136 |
+
(
|
| 137 |
+
" ternary training state: "
|
| 138 |
+
f"{_mb(audit.ternary_training_bytes):.2f} MB "
|
| 139 |
+
f"(T={_mb(audit.ternary_packed_bytes):.2f}, "
|
| 140 |
+
f"E={_mb(audit.ternary_scale_bytes):.2f}, "
|
| 141 |
+
f"E_accum={_mb(audit.ternary_scale_accum_bytes):.2f}, "
|
| 142 |
+
f"T_accum={_mb(audit.ternary_accum_bytes):.2f}, "
|
| 143 |
+
f"corr_accum={_mb(audit.ternary_corr_accum_bytes):.2f}, "
|
| 144 |
+
f"steps={_mb(audit.ternary_step_counter_bytes):.4f})"
|
| 145 |
+
),
|
| 146 |
+
(
|
| 147 |
+
" trainable float params: "
|
| 148 |
+
f"{len(audit.trainable_float_params)} tensors, "
|
| 149 |
+
f"{_mb(audit.trainable_float_bytes):.2f} MB"
|
| 150 |
+
),
|
| 151 |
+
(
|
| 152 |
+
" frozen float params: "
|
| 153 |
+
f"{len(audit.frozen_float_params)} tensors, "
|
| 154 |
+
f"{_mb(audit.frozen_float_bytes):.2f} MB"
|
| 155 |
+
),
|
| 156 |
+
(
|
| 157 |
+
" float buffers: "
|
| 158 |
+
f"{len(audit.float_buffers)} tensors, "
|
| 159 |
+
f"{_mb(audit.float_buffer_bytes):.2f} MB"
|
| 160 |
+
),
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
if audit.trainable_float_params:
|
| 164 |
+
lines.append(" largest trainable float params:")
|
| 165 |
+
for item in sorted(audit.trainable_float_params, key=lambda x: x.bytes, reverse=True)[:limit]:
|
| 166 |
+
lines.append(f" {item.name}: {item.shape} {item.dtype} {_mb(item.bytes):.2f} MB")
|
| 167 |
+
|
| 168 |
+
if audit.float_buffers:
|
| 169 |
+
lines.append(" largest float buffers:")
|
| 170 |
+
for item in sorted(audit.float_buffers, key=lambda x: x.bytes, reverse=True)[:limit]:
|
| 171 |
+
lines.append(f" {item.name}: {item.shape} {item.dtype} {_mb(item.bytes):.2f} MB")
|
| 172 |
+
|
| 173 |
+
return "\n".join(lines)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def freeze_float_parameters(
|
| 177 |
+
model: torch.nn.Module,
|
| 178 |
+
allow_prefixes: Iterable[str] = (),
|
| 179 |
+
) -> list[TensorState]:
|
| 180 |
+
allow = tuple(allow_prefixes)
|
| 181 |
+
frozen: list[TensorState] = []
|
| 182 |
+
for name, param in model.named_parameters():
|
| 183 |
+
if allow and name.startswith(allow):
|
| 184 |
+
continue
|
| 185 |
+
if param.dtype.is_floating_point and param.requires_grad:
|
| 186 |
+
frozen.append(_tensor_state(name, param, trainable=True))
|
| 187 |
+
param.requires_grad_(False)
|
| 188 |
+
return frozen
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def trainable_parameters(model: torch.nn.Module) -> list[torch.nn.Parameter]:
|
| 192 |
+
return [p for p in model.parameters() if p.requires_grad]
|
arbitor/kernel/ternary_scale.py
ADDED
|
@@ -0,0 +1,1811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import threading
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from enum import IntEnum
|
| 9 |
+
from math import ceil
|
| 10 |
+
|
| 11 |
+
from ..converters.convert_to_ternary8 import pack_ternary, unpack_ternary
|
| 12 |
+
|
| 13 |
+
_HAS_TILELANG = False
|
| 14 |
+
try:
|
| 15 |
+
import tilelang
|
| 16 |
+
import tilelang.language as T
|
| 17 |
+
_HAS_TILELANG = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
_HAS_TRITON = False
|
| 22 |
+
try:
|
| 23 |
+
import triton
|
| 24 |
+
import triton.language as tl
|
| 25 |
+
_HAS_TRITON = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _backend_preference() -> str:
|
| 31 |
+
backend = os.environ.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
|
| 32 |
+
if backend not in {"auto", "tilelang", "triton", "torch"}:
|
| 33 |
+
warnings.warn(
|
| 34 |
+
f"Unknown ARB_TERNARY_BACKEND={backend!r}; falling back to auto.",
|
| 35 |
+
RuntimeWarning,
|
| 36 |
+
stacklevel=2,
|
| 37 |
+
)
|
| 38 |
+
return "auto"
|
| 39 |
+
return backend
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _rmsnorm_triton_max_dim() -> int:
|
| 43 |
+
raw = os.environ.get("ARB_RMSNORM_TRITON_MAX_DIM", "4096").strip()
|
| 44 |
+
try:
|
| 45 |
+
return max(0, int(raw))
|
| 46 |
+
except ValueError:
|
| 47 |
+
warnings.warn(
|
| 48 |
+
f"Invalid ARB_RMSNORM_TRITON_MAX_DIM={raw!r}; using 4096.",
|
| 49 |
+
RuntimeWarning,
|
| 50 |
+
stacklevel=2,
|
| 51 |
+
)
|
| 52 |
+
return 4096
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _bigint_corr_strength() -> float:
|
| 56 |
+
raw = os.environ.get("ARB_BIGINT_CORR_STRENGTH", "4.0").strip()
|
| 57 |
+
try:
|
| 58 |
+
return float(raw)
|
| 59 |
+
except ValueError:
|
| 60 |
+
warnings.warn(
|
| 61 |
+
f"Invalid ARB_BIGINT_CORR_STRENGTH={raw!r}; using 4.0.",
|
| 62 |
+
RuntimeWarning,
|
| 63 |
+
stacklevel=2,
|
| 64 |
+
)
|
| 65 |
+
return 4.0
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class _ComponentContext:
|
| 69 |
+
_local = threading.local()
|
| 70 |
+
|
| 71 |
+
@classmethod
|
| 72 |
+
def get(cls):
|
| 73 |
+
val = getattr(cls._local, "current", None)
|
| 74 |
+
if val is None:
|
| 75 |
+
return None, 1.0
|
| 76 |
+
return val
|
| 77 |
+
|
| 78 |
+
@classmethod
|
| 79 |
+
def set(cls, name, weight=1.0):
|
| 80 |
+
if name is None:
|
| 81 |
+
cls._local.current = None
|
| 82 |
+
else:
|
| 83 |
+
cls._local.current = (name, weight)
|
| 84 |
+
|
| 85 |
+
@classmethod
|
| 86 |
+
def clear(cls):
|
| 87 |
+
cls._local.current = None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
_COMPONENT_CONTEXT = _ComponentContext
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _tilelang_training_enabled() -> bool:
|
| 94 |
+
return os.environ.get("ARB_TILELANG_TRAINING", "0").strip().lower() in {"1", "true", "yes"}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if _HAS_TILELANG:
|
| 98 |
+
|
| 99 |
+
tilelang_jit = tilelang.jit(pass_configs={"tl.disable_warp_specialized": True})
|
| 100 |
+
|
| 101 |
+
def _ternary_fwd_kernel(
|
| 102 |
+
M: int, N: int, K: int, group_size: int = 12,
|
| 103 |
+
corr_strength: float = 4.0,
|
| 104 |
+
block_M: int = 64, block_N: int = 64, block_K: int = 32,
|
| 105 |
+
threads: int = 128, num_stages: int = 2,
|
| 106 |
+
):
|
| 107 |
+
gpr = (K + group_size - 1) // group_size
|
| 108 |
+
cs = corr_strength
|
| 109 |
+
|
| 110 |
+
@T.prim_func
|
| 111 |
+
def kernel(
|
| 112 |
+
x: T.Tensor((M, K), "float16"),
|
| 113 |
+
T_packed: T.Tensor((N * K + 4) // 5, "uint8"),
|
| 114 |
+
E: T.Tensor((N * gpr), "int8"),
|
| 115 |
+
corr_accum: T.Tensor((N * gpr), "int64"),
|
| 116 |
+
step_counter: T.Tensor((1,), "int64"),
|
| 117 |
+
output: T.Tensor((M, N), "float32"),
|
| 118 |
+
):
|
| 119 |
+
steps = T.cast(step_counter[0], "int32")
|
| 120 |
+
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
|
| 121 |
+
x_shared = T.alloc_shared((block_M, block_K), dtype="float16")
|
| 122 |
+
dq_shared = T.alloc_shared((block_N, block_K), dtype="float16")
|
| 123 |
+
acc = T.alloc_fragment((block_M, block_N), dtype="float32")
|
| 124 |
+
T.use_swizzle(10)
|
| 125 |
+
T.clear(acc)
|
| 126 |
+
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
|
| 127 |
+
T.copy(x[bx * block_M, k * block_K], x_shared)
|
| 128 |
+
for i, j in T.Parallel(block_N, block_K):
|
| 129 |
+
i_glob = by * block_N + i
|
| 130 |
+
j_glob = k * block_K + j
|
| 131 |
+
if i_glob < N and j_glob < K:
|
| 132 |
+
lin_idx = i_glob * K + j_glob
|
| 133 |
+
pack_idx = lin_idx // 5
|
| 134 |
+
trit_pos = lin_idx % 5
|
| 135 |
+
packed_val = T.cast(T_packed[pack_idx], "int32")
|
| 136 |
+
trit = T.if_then_else(
|
| 137 |
+
trit_pos == 0, packed_val % 3,
|
| 138 |
+
T.if_then_else(trit_pos == 1, (packed_val // 3) % 3,
|
| 139 |
+
T.if_then_else(trit_pos == 2, (packed_val // 9) % 3,
|
| 140 |
+
T.if_then_else(trit_pos == 3, (packed_val // 27) % 3,
|
| 141 |
+
(packed_val // 81) % 3))))
|
| 142 |
+
sign_val = T.cast(trit, "int32") - 1
|
| 143 |
+
exp_idx = i_glob * gpr + j_glob // group_size
|
| 144 |
+
exp_val = T.cast(E[exp_idx], "int32")
|
| 145 |
+
ca = T.cast(corr_accum[exp_idx], "int32")
|
| 146 |
+
den = T.max(steps * group_size, 1)
|
| 147 |
+
mc = T.cast(ca, "float32") / T.cast(den, "float32")
|
| 148 |
+
e_adj = T.cast(exp_val, "float32") + mc * cs
|
| 149 |
+
ecl = T.min(T.max(e_adj, -14.0), 15.0)
|
| 150 |
+
dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16")
|
| 151 |
+
T.gemm(x_shared, dq_shared, acc, transpose_B=True)
|
| 152 |
+
T.copy(acc, output[bx * block_M, by * block_N])
|
| 153 |
+
return tilelang_jit(kernel)
|
| 154 |
+
|
| 155 |
+
def _ternary_grad_x_kernel(
|
| 156 |
+
M: int, N: int, K: int, group_size: int = 12,
|
| 157 |
+
corr_strength: float = 4.0,
|
| 158 |
+
block_M: int = 64, block_N: int = 64, block_K: int = 32,
|
| 159 |
+
threads: int = 128, num_stages: int = 2,
|
| 160 |
+
):
|
| 161 |
+
gpr = (K + group_size - 1) // group_size
|
| 162 |
+
cs = corr_strength
|
| 163 |
+
|
| 164 |
+
@T.prim_func
|
| 165 |
+
def kernel(
|
| 166 |
+
grad_y: T.Tensor((M, N), "float16"),
|
| 167 |
+
T_packed: T.Tensor((N * K + 4) // 5, "uint8"),
|
| 168 |
+
E: T.Tensor((N * gpr), "int8"),
|
| 169 |
+
corr_accum: T.Tensor((N * gpr), "int64"),
|
| 170 |
+
step_counter: T.Tensor((1,), "int64"),
|
| 171 |
+
output: T.Tensor((M, K), "float32"),
|
| 172 |
+
):
|
| 173 |
+
steps = T.cast(step_counter[0], "int32")
|
| 174 |
+
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=threads) as (bx, by):
|
| 175 |
+
gy_shared = T.alloc_shared((block_M, block_N), dtype="float16")
|
| 176 |
+
dq_shared = T.alloc_shared((block_N, block_K), dtype="float16")
|
| 177 |
+
acc = T.alloc_fragment((block_M, block_K), dtype="float32")
|
| 178 |
+
T.use_swizzle(10)
|
| 179 |
+
T.clear(acc)
|
| 180 |
+
for n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
|
| 181 |
+
T.copy(grad_y[bx * block_M, n * block_N], gy_shared)
|
| 182 |
+
for i, j in T.Parallel(block_N, block_K):
|
| 183 |
+
i_glob = n * block_N + i
|
| 184 |
+
j_glob = by * block_K + j
|
| 185 |
+
if i_glob < N and j_glob < K:
|
| 186 |
+
lin_idx = i_glob * K + j_glob
|
| 187 |
+
pack_idx = lin_idx // 5
|
| 188 |
+
trit_pos = lin_idx % 5
|
| 189 |
+
packed_val = T.cast(T_packed[pack_idx], "int32")
|
| 190 |
+
trit = T.if_then_else(
|
| 191 |
+
trit_pos == 0, packed_val % 3,
|
| 192 |
+
T.if_then_else(trit_pos == 1, (packed_val // 3) % 3,
|
| 193 |
+
T.if_then_else(trit_pos == 2, (packed_val // 9) % 3,
|
| 194 |
+
T.if_then_else(trit_pos == 3, (packed_val // 27) % 3,
|
| 195 |
+
(packed_val // 81) % 3))))
|
| 196 |
+
sign_val = T.cast(trit, "int32") - 1
|
| 197 |
+
exp_idx = i_glob * gpr + j_glob // group_size
|
| 198 |
+
exp_val = T.cast(E[exp_idx], "int32")
|
| 199 |
+
ca = T.cast(corr_accum[exp_idx], "int32")
|
| 200 |
+
den = T.max(steps * group_size, 1)
|
| 201 |
+
mc = T.cast(ca, "float32") / T.cast(den, "float32")
|
| 202 |
+
e_adj = T.cast(exp_val, "float32") + mc * cs
|
| 203 |
+
ecl = T.min(T.max(e_adj, -14.0), 15.0)
|
| 204 |
+
dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16")
|
| 205 |
+
T.gemm(gy_shared, dq_shared, acc)
|
| 206 |
+
T.copy(acc, output[bx * block_M, by * block_K])
|
| 207 |
+
return tilelang_jit(kernel)
|
| 208 |
+
|
| 209 |
+
_KERNEL_CACHE_FWD = {}
|
| 210 |
+
_KERNEL_CACHE_GX = {}
|
| 211 |
+
|
| 212 |
+
def _get_kernel(M, N, K, group_size, mode, corr_strength=4.0):
|
| 213 |
+
cs = corr_strength
|
| 214 |
+
if mode == "fwd":
|
| 215 |
+
cache = _KERNEL_CACHE_FWD
|
| 216 |
+
key = (M, N, K, group_size, cs)
|
| 217 |
+
if key not in cache:
|
| 218 |
+
cache[key] = _ternary_fwd_kernel(M, N, K, group_size, corr_strength=cs)
|
| 219 |
+
return cache[key]
|
| 220 |
+
elif mode == "grad_x":
|
| 221 |
+
cache = _KERNEL_CACHE_GX
|
| 222 |
+
key = (M, N, K, group_size)
|
| 223 |
+
if key not in cache:
|
| 224 |
+
cache[key] = _ternary_grad_x_kernel(M, N, K, group_size)
|
| 225 |
+
return cache[key]
|
| 226 |
+
raise ValueError(f"Unknown TileLang kernel mode: {mode}")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _get_grad_kernels(M, N, K, group_size):
|
| 230 |
+
return _get_kernel(M, N, K, group_size, "grad_x")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class _TernaryLinearFn(torch.autograd.Function):
|
| 234 |
+
@staticmethod
|
| 235 |
+
def forward(ctx, x, module, fwd_kernel):
|
| 236 |
+
ctx.module = module
|
| 237 |
+
T_packed = module.T_packed
|
| 238 |
+
E = module.E
|
| 239 |
+
shape = tuple(module._T_shape.tolist())
|
| 240 |
+
N, K = shape
|
| 241 |
+
x_2d = x.reshape(-1, K).contiguous()
|
| 242 |
+
ctx.group_size = module.group_size
|
| 243 |
+
ctx.shape = shape
|
| 244 |
+
ctx.x_shape = x.shape
|
| 245 |
+
comp_name, _ = _COMPONENT_CONTEXT.get()
|
| 246 |
+
ctx.comp_name = comp_name
|
| 247 |
+
ctx.x_dtype = x.dtype
|
| 248 |
+
has_corr = hasattr(module, "corr_accum") and hasattr(module, "step_counter")
|
| 249 |
+
ctx.save_for_backward(x_2d, T_packed, E)
|
| 250 |
+
ctx.has_corr = has_corr
|
| 251 |
+
ctx.step_snapshot = int(module.step_counter.item()) if has_corr else 0
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
M = x_2d.shape[0]
|
| 254 |
+
output = torch.empty(M, N, device=x.device, dtype=torch.float32)
|
| 255 |
+
if has_corr:
|
| 256 |
+
fwd_kernel(x_2d.half(), T_packed, E,
|
| 257 |
+
module.corr_accum.contiguous(),
|
| 258 |
+
module.step_counter.contiguous(), output)
|
| 259 |
+
else:
|
| 260 |
+
fwd_kernel(x_2d.half(), T_packed, E,
|
| 261 |
+
torch.zeros(N * ((K + module.group_size - 1) // module.group_size),
|
| 262 |
+
dtype=torch.int64, device=x.device),
|
| 263 |
+
torch.zeros(1, dtype=torch.int64, device=x.device), output)
|
| 264 |
+
return output.reshape(*x.shape[:-1], N)
|
| 265 |
+
|
| 266 |
+
@staticmethod
|
| 267 |
+
def backward(ctx, grad_output):
|
| 268 |
+
x_2d, T_packed, E = ctx.saved_tensors
|
| 269 |
+
group_size = ctx.group_size
|
| 270 |
+
N, K = ctx.shape
|
| 271 |
+
M = x_2d.shape[0]
|
| 272 |
+
grad_2d = grad_output.reshape(-1, N).contiguous()
|
| 273 |
+
if ctx.has_corr:
|
| 274 |
+
corr_accum = ctx.module.corr_accum.contiguous()
|
| 275 |
+
step_counter = torch.tensor([ctx.step_snapshot], dtype=torch.int64, device=x_2d.device)
|
| 276 |
+
else:
|
| 277 |
+
corr_accum = torch.zeros(N * ((K + group_size - 1) // group_size),
|
| 278 |
+
dtype=torch.int64, device=x_2d.device)
|
| 279 |
+
step_counter = torch.zeros(1, dtype=torch.int64, device=x_2d.device)
|
| 280 |
+
grad_x_kernel = _get_grad_kernels(M, N, K, group_size)
|
| 281 |
+
with torch.no_grad():
|
| 282 |
+
grad_x = torch.empty(M, K, device=x_2d.device, dtype=torch.float32)
|
| 283 |
+
grad_x_kernel(grad_2d.half(), T_packed, E, corr_accum, step_counter, grad_x)
|
| 284 |
+
comp_name = ctx.comp_name
|
| 285 |
+
if _HAS_TRITON and ctx.has_corr and getattr(ctx.module, "_stream_backward_updates", True):
|
| 286 |
+
bwd_name, bwd_weight = _COMPONENT_CONTEXT.get()
|
| 287 |
+
if bwd_name is None:
|
| 288 |
+
bwd_weight = 1.0
|
| 289 |
+
base_step = int(getattr(ctx.module, "_backward_t_accum_step", 1))
|
| 290 |
+
corr_step = max(1, int(round(abs(float(bwd_weight)) * base_step)))
|
| 291 |
+
if bwd_weight < 0:
|
| 292 |
+
corr_step = -corr_step
|
| 293 |
+
_triton_accumulate_corr_direct(
|
| 294 |
+
T_packed, grad_2d, x_2d, ctx.module.corr_accum,
|
| 295 |
+
N, K, group_size, corr_step=corr_step,
|
| 296 |
+
)
|
| 297 |
+
ctx.module.step_counter.add_(abs(corr_step))
|
| 298 |
+
ctx.module._streamed_bigint_backward = True
|
| 299 |
+
elif _HAS_TRITON:
|
| 300 |
+
grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, N, K)
|
| 301 |
+
if comp_name is not None:
|
| 302 |
+
setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
|
| 303 |
+
else:
|
| 304 |
+
ctx.module._hook_grad_T_sign = grad_sign.detach()
|
| 305 |
+
elif comp_name is not None:
|
| 306 |
+
setattr(ctx.module, f"_hook_grad_2d_{comp_name}", grad_2d.detach())
|
| 307 |
+
setattr(ctx.module, f"_hook_x_2d_{comp_name}", x_2d.detach())
|
| 308 |
+
else:
|
| 309 |
+
ctx.module._hook_grad_2d = grad_2d.detach()
|
| 310 |
+
ctx.module._hook_x_2d = x_2d.detach()
|
| 311 |
+
grad_x_reshaped = grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype)
|
| 312 |
+
return grad_x_reshaped, None, None
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
if _HAS_TRITON:
|
| 316 |
+
|
| 317 |
+
@triton.jit
|
| 318 |
+
def _triton_ternary_fwd_kernel(
|
| 319 |
+
x_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
|
| 320 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 321 |
+
GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 322 |
+
CORR_STRENGTH: tl.constexpr,
|
| 323 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 324 |
+
):
|
| 325 |
+
pid_m = tl.program_id(0)
|
| 326 |
+
pid_n = tl.program_id(1)
|
| 327 |
+
|
| 328 |
+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 329 |
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 330 |
+
offs_k = tl.arange(0, BLOCK_K)
|
| 331 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
| 332 |
+
|
| 333 |
+
for k0 in range(0, K, BLOCK_K):
|
| 334 |
+
k = k0 + offs_k
|
| 335 |
+
x = tl.load(
|
| 336 |
+
x_ptr + offs_m[:, None] * K + k[None, :],
|
| 337 |
+
mask=(offs_m[:, None] < M) & (k[None, :] < K),
|
| 338 |
+
other=0.0,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
lin = offs_n[:, None] * K + k[None, :]
|
| 342 |
+
pack_idx = lin // 5
|
| 343 |
+
trit_pos = lin - pack_idx * 5
|
| 344 |
+
packed = tl.load(
|
| 345 |
+
packed_ptr + pack_idx,
|
| 346 |
+
mask=(offs_n[:, None] < N) & (k[None, :] < K),
|
| 347 |
+
other=0,
|
| 348 |
+
).to(tl.int32)
|
| 349 |
+
divisor = tl.where(
|
| 350 |
+
trit_pos == 0, 1,
|
| 351 |
+
tl.where(trit_pos == 1, 3,
|
| 352 |
+
tl.where(trit_pos == 2, 9,
|
| 353 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 354 |
+
)
|
| 355 |
+
trit = (packed // divisor) % 3
|
| 356 |
+
sign = trit.to(tl.int32) - 1
|
| 357 |
+
|
| 358 |
+
e_idx = offs_n[:, None] * GPR + k[None, :] // GROUP_SIZE
|
| 359 |
+
e_val = tl.load(
|
| 360 |
+
e_ptr + e_idx,
|
| 361 |
+
mask=(offs_n[:, None] < N) & (k[None, :] < K),
|
| 362 |
+
other=0,
|
| 363 |
+
).to(tl.float32)
|
| 364 |
+
corr_val = tl.load(
|
| 365 |
+
corr_ptr + e_idx,
|
| 366 |
+
mask=(offs_n[:, None] < N) & (k[None, :] < K),
|
| 367 |
+
other=0,
|
| 368 |
+
).to(tl.float32)
|
| 369 |
+
step_val = tl.load(step_ptr).to(tl.float32)
|
| 370 |
+
denom = tl.maximum(step_val * GROUP_SIZE, 1.0)
|
| 371 |
+
e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
|
| 372 |
+
w = sign.to(tl.float32) * tl.exp2(e_adj)
|
| 373 |
+
w = tl.where((offs_n[:, None] < N) & (k[None, :] < K), w, 0.0)
|
| 374 |
+
acc += tl.dot(x, tl.trans(w))
|
| 375 |
+
|
| 376 |
+
tl.store(
|
| 377 |
+
out_ptr + offs_m[:, None] * N + offs_n[None, :],
|
| 378 |
+
acc,
|
| 379 |
+
mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
@triton.jit
|
| 384 |
+
def _triton_ternary_grad_x_kernel(
|
| 385 |
+
grad_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
|
| 386 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 387 |
+
GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 388 |
+
CORR_STRENGTH: tl.constexpr,
|
| 389 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 390 |
+
):
|
| 391 |
+
pid_m = tl.program_id(0)
|
| 392 |
+
pid_k = tl.program_id(1)
|
| 393 |
+
|
| 394 |
+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 395 |
+
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
|
| 396 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 397 |
+
acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
|
| 398 |
+
|
| 399 |
+
for n0 in range(0, N, BLOCK_N):
|
| 400 |
+
n = n0 + offs_n
|
| 401 |
+
grad = tl.load(
|
| 402 |
+
grad_ptr + offs_m[:, None] * N + n[None, :],
|
| 403 |
+
mask=(offs_m[:, None] < M) & (n[None, :] < N),
|
| 404 |
+
other=0.0,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
lin = n[:, None] * K + offs_k[None, :]
|
| 408 |
+
pack_idx = lin // 5
|
| 409 |
+
trit_pos = lin - pack_idx * 5
|
| 410 |
+
packed = tl.load(
|
| 411 |
+
packed_ptr + pack_idx,
|
| 412 |
+
mask=(n[:, None] < N) & (offs_k[None, :] < K),
|
| 413 |
+
other=0,
|
| 414 |
+
).to(tl.int32)
|
| 415 |
+
divisor = tl.where(
|
| 416 |
+
trit_pos == 0, 1,
|
| 417 |
+
tl.where(trit_pos == 1, 3,
|
| 418 |
+
tl.where(trit_pos == 2, 9,
|
| 419 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 420 |
+
)
|
| 421 |
+
trit = (packed // divisor) % 3
|
| 422 |
+
sign = trit.to(tl.int32) - 1
|
| 423 |
+
|
| 424 |
+
e_idx = n[:, None] * GPR + offs_k[None, :] // GROUP_SIZE
|
| 425 |
+
e_val = tl.load(
|
| 426 |
+
e_ptr + e_idx,
|
| 427 |
+
mask=(n[:, None] < N) & (offs_k[None, :] < K),
|
| 428 |
+
other=0,
|
| 429 |
+
).to(tl.float32)
|
| 430 |
+
corr_val = tl.load(
|
| 431 |
+
corr_ptr + e_idx,
|
| 432 |
+
mask=(n[:, None] < N) & (offs_k[None, :] < K),
|
| 433 |
+
other=0,
|
| 434 |
+
).to(tl.float32)
|
| 435 |
+
step_val = tl.load(step_ptr).to(tl.float32)
|
| 436 |
+
denom = tl.maximum(step_val * GROUP_SIZE, 1.0)
|
| 437 |
+
e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
|
| 438 |
+
w = sign.to(tl.float32) * tl.exp2(e_adj)
|
| 439 |
+
w = tl.where((n[:, None] < N) & (offs_k[None, :] < K), w, 0.0)
|
| 440 |
+
acc += tl.dot(grad, w)
|
| 441 |
+
|
| 442 |
+
tl.store(
|
| 443 |
+
out_ptr + offs_m[:, None] * K + offs_k[None, :],
|
| 444 |
+
acc,
|
| 445 |
+
mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
@triton.jit
|
| 450 |
+
def _triton_ternary_grad_sign_kernel(
|
| 451 |
+
grad_ptr, x_ptr, sign_ptr,
|
| 452 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 453 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 454 |
+
):
|
| 455 |
+
pid_n = tl.program_id(0)
|
| 456 |
+
pid_k = tl.program_id(1)
|
| 457 |
+
|
| 458 |
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 459 |
+
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
|
| 460 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 461 |
+
acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
|
| 462 |
+
|
| 463 |
+
for m0 in range(0, M, BLOCK_M):
|
| 464 |
+
m = m0 + offs_m
|
| 465 |
+
grad = tl.load(
|
| 466 |
+
grad_ptr + m[:, None] * N + offs_n[None, :],
|
| 467 |
+
mask=(m[:, None] < M) & (offs_n[None, :] < N),
|
| 468 |
+
other=0.0,
|
| 469 |
+
)
|
| 470 |
+
x = tl.load(
|
| 471 |
+
x_ptr + m[:, None] * K + offs_k[None, :],
|
| 472 |
+
mask=(m[:, None] < M) & (offs_k[None, :] < K),
|
| 473 |
+
other=0.0,
|
| 474 |
+
)
|
| 475 |
+
acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
|
| 476 |
+
|
| 477 |
+
sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0))
|
| 478 |
+
tl.store(
|
| 479 |
+
sign_ptr + offs_n[:, None] * K + offs_k[None, :],
|
| 480 |
+
sign.to(tl.int8),
|
| 481 |
+
mask=(offs_n[:, None] < N) & (offs_k[None, :] < K),
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
@triton.jit
|
| 486 |
+
def _triton_update_e_kernel(
|
| 487 |
+
packed_ptr, grad_sign_ptr, e_ptr, e_accum_ptr,
|
| 488 |
+
N: tl.constexpr, K: tl.constexpr,
|
| 489 |
+
GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
|
| 490 |
+
E_ACCUM_THRESHOLD: tl.constexpr,
|
| 491 |
+
BLOCK_N: tl.constexpr, BLOCK_G: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 492 |
+
):
|
| 493 |
+
pid_n = tl.program_id(0)
|
| 494 |
+
pid_g = tl.program_id(1)
|
| 495 |
+
|
| 496 |
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 497 |
+
offs_g = pid_g * BLOCK_G + tl.arange(0, BLOCK_G)
|
| 498 |
+
offs_r = tl.arange(0, BLOCK_K)
|
| 499 |
+
k = offs_g[:, None] * GROUP_SIZE + offs_r[None, :]
|
| 500 |
+
valid_group = offs_g < GPR
|
| 501 |
+
|
| 502 |
+
lin = offs_n[:, None, None] * K + k[None, :, :]
|
| 503 |
+
pack_idx = lin // 5
|
| 504 |
+
trit_pos = lin - pack_idx * 5
|
| 505 |
+
packed = tl.load(
|
| 506 |
+
packed_ptr + pack_idx,
|
| 507 |
+
mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K),
|
| 508 |
+
other=0,
|
| 509 |
+
).to(tl.int32)
|
| 510 |
+
divisor = tl.where(
|
| 511 |
+
trit_pos == 0, 1,
|
| 512 |
+
tl.where(trit_pos == 1, 3,
|
| 513 |
+
tl.where(trit_pos == 2, 9,
|
| 514 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 515 |
+
)
|
| 516 |
+
trit = (packed // divisor) % 3
|
| 517 |
+
ternary = trit.to(tl.int32) - 1
|
| 518 |
+
|
| 519 |
+
grad_sign = tl.load(
|
| 520 |
+
grad_sign_ptr + offs_n[:, None, None] * K + k[None, :, :],
|
| 521 |
+
mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K),
|
| 522 |
+
other=0,
|
| 523 |
+
).to(tl.int32)
|
| 524 |
+
contrib = grad_sign * ternary
|
| 525 |
+
score = tl.sum(contrib, axis=2)
|
| 526 |
+
delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
|
| 527 |
+
|
| 528 |
+
e_idx = offs_n[:, None] * GPR + offs_g[None, :]
|
| 529 |
+
old_accum = tl.load(
|
| 530 |
+
e_accum_ptr + e_idx,
|
| 531 |
+
mask=(offs_n[:, None] < N) & valid_group[None, :],
|
| 532 |
+
other=0,
|
| 533 |
+
).to(tl.int32)
|
| 534 |
+
new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta))
|
| 535 |
+
step_up = new_accum >= E_ACCUM_THRESHOLD
|
| 536 |
+
step_down = new_accum <= -E_ACCUM_THRESHOLD
|
| 537 |
+
e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0))
|
| 538 |
+
stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD
|
| 539 |
+
|
| 540 |
+
old_e = tl.load(
|
| 541 |
+
e_ptr + e_idx,
|
| 542 |
+
mask=(offs_n[:, None] < N) & valid_group[None, :],
|
| 543 |
+
other=0,
|
| 544 |
+
).to(tl.int32)
|
| 545 |
+
new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step))
|
| 546 |
+
tl.store(
|
| 547 |
+
e_ptr + e_idx,
|
| 548 |
+
new_e.to(tl.int8),
|
| 549 |
+
mask=(offs_n[:, None] < N) & valid_group[None, :],
|
| 550 |
+
)
|
| 551 |
+
tl.store(
|
| 552 |
+
e_accum_ptr + e_idx,
|
| 553 |
+
stored_accum.to(tl.int8),
|
| 554 |
+
mask=(offs_n[:, None] < N) & valid_group[None, :],
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
@triton.jit
|
| 559 |
+
def _triton_ternary_step_kernel(
|
| 560 |
+
packed_ptr, grad_sign_ptr, accum_ptr, per_group_threshold_ptr,
|
| 561 |
+
TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
|
| 562 |
+
T_ACCUM_STEP: tl.constexpr,
|
| 563 |
+
K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 564 |
+
HAS_PER_GROUP_THRESHOLD: tl.constexpr,
|
| 565 |
+
BLOCK_T: tl.constexpr,
|
| 566 |
+
):
|
| 567 |
+
pack_idx = tl.program_id(0)
|
| 568 |
+
offs_t = tl.arange(0, BLOCK_T)
|
| 569 |
+
valid_trit = offs_t < 5
|
| 570 |
+
lin = pack_idx * 5 + offs_t
|
| 571 |
+
valid = valid_trit & (lin < TOTAL)
|
| 572 |
+
|
| 573 |
+
old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
|
| 574 |
+
divisor = tl.where(
|
| 575 |
+
offs_t == 0, 1,
|
| 576 |
+
tl.where(offs_t == 1, 3,
|
| 577 |
+
tl.where(offs_t == 2, 9,
|
| 578 |
+
tl.where(offs_t == 3, 27, 81))),
|
| 579 |
+
)
|
| 580 |
+
old_code = (old_packed // divisor) % 3
|
| 581 |
+
old_sign = old_code.to(tl.int32) - 1
|
| 582 |
+
|
| 583 |
+
grad_sign = tl.load(grad_sign_ptr + lin, mask=valid, other=0).to(tl.int32)
|
| 584 |
+
old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
|
| 585 |
+
new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
|
| 586 |
+
|
| 587 |
+
if HAS_PER_GROUP_THRESHOLD:
|
| 588 |
+
n = lin // K
|
| 589 |
+
k = lin - n * K
|
| 590 |
+
g_idx = n * GPR + k // GROUP_SIZE
|
| 591 |
+
threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
|
| 592 |
+
else:
|
| 593 |
+
threshold = ACCUM_THRESHOLD
|
| 594 |
+
|
| 595 |
+
flip_up = new_accum > threshold
|
| 596 |
+
flip_down = new_accum < -threshold
|
| 597 |
+
did_flip = valid & (flip_up | flip_down)
|
| 598 |
+
new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
|
| 599 |
+
stored_accum = tl.where(did_flip, 0, new_accum)
|
| 600 |
+
tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)
|
| 601 |
+
|
| 602 |
+
new_code = tl.where(valid, new_sign + 1, 0)
|
| 603 |
+
packed_val = tl.sum(new_code * divisor, axis=0)
|
| 604 |
+
tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
@triton.jit
|
| 608 |
+
def _triton_update_e_direct_kernel(
|
| 609 |
+
packed_ptr, grad_ptr, x_ptr, e_ptr, e_accum_ptr,
|
| 610 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 611 |
+
GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
|
| 612 |
+
E_ACCUM_THRESHOLD: tl.constexpr,
|
| 613 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 614 |
+
):
|
| 615 |
+
pid_n = tl.program_id(0)
|
| 616 |
+
pid_g = tl.program_id(1)
|
| 617 |
+
|
| 618 |
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 619 |
+
offs_r = tl.arange(0, BLOCK_K)
|
| 620 |
+
k = pid_g * GROUP_SIZE + offs_r
|
| 621 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 622 |
+
acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
|
| 623 |
+
|
| 624 |
+
for m0 in range(0, M, BLOCK_M):
|
| 625 |
+
m = m0 + offs_m
|
| 626 |
+
grad = tl.load(
|
| 627 |
+
grad_ptr + m[:, None] * N + offs_n[None, :],
|
| 628 |
+
mask=(m[:, None] < M) & (offs_n[None, :] < N),
|
| 629 |
+
other=0.0,
|
| 630 |
+
)
|
| 631 |
+
x = tl.load(
|
| 632 |
+
x_ptr + m[:, None] * K + k[None, :],
|
| 633 |
+
mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 634 |
+
other=0.0,
|
| 635 |
+
)
|
| 636 |
+
acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
|
| 637 |
+
|
| 638 |
+
grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
|
| 639 |
+
lin = offs_n[:, None] * K + k[None, :]
|
| 640 |
+
pack_idx = lin // 5
|
| 641 |
+
trit_pos = lin - pack_idx * 5
|
| 642 |
+
packed = tl.load(
|
| 643 |
+
packed_ptr + pack_idx,
|
| 644 |
+
mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 645 |
+
other=0,
|
| 646 |
+
).to(tl.int32)
|
| 647 |
+
divisor = tl.where(
|
| 648 |
+
trit_pos == 0, 1,
|
| 649 |
+
tl.where(trit_pos == 1, 3,
|
| 650 |
+
tl.where(trit_pos == 2, 9,
|
| 651 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 652 |
+
)
|
| 653 |
+
trit = (packed // divisor) % 3
|
| 654 |
+
ternary = trit.to(tl.int32) - 1
|
| 655 |
+
contrib = tl.where(
|
| 656 |
+
(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 657 |
+
grad_sign * ternary,
|
| 658 |
+
0,
|
| 659 |
+
)
|
| 660 |
+
score = tl.sum(contrib, axis=1)
|
| 661 |
+
delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
|
| 662 |
+
|
| 663 |
+
e_idx = offs_n * GPR + pid_g
|
| 664 |
+
old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
|
| 665 |
+
new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta))
|
| 666 |
+
step_up = new_accum >= E_ACCUM_THRESHOLD
|
| 667 |
+
step_down = new_accum <= -E_ACCUM_THRESHOLD
|
| 668 |
+
e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0))
|
| 669 |
+
stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD
|
| 670 |
+
|
| 671 |
+
old_e = tl.load(e_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
|
| 672 |
+
new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step))
|
| 673 |
+
tl.store(e_ptr + e_idx, new_e.to(tl.int8), mask=offs_n < N)
|
| 674 |
+
tl.store(e_accum_ptr + e_idx, stored_accum.to(tl.int8), mask=offs_n < N)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
@triton.jit
|
| 678 |
+
def _triton_ternary_step_direct_kernel(
|
| 679 |
+
packed_ptr, grad_ptr, x_ptr, accum_ptr, per_group_threshold_ptr,
|
| 680 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 681 |
+
TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
|
| 682 |
+
T_ACCUM_STEP: tl.constexpr,
|
| 683 |
+
GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 684 |
+
HAS_PER_GROUP_THRESHOLD: tl.constexpr,
|
| 685 |
+
BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
|
| 686 |
+
):
|
| 687 |
+
pack_idx = tl.program_id(0)
|
| 688 |
+
offs_t = tl.arange(0, BLOCK_T)
|
| 689 |
+
lin = pack_idx * 5 + offs_t
|
| 690 |
+
valid_trit = offs_t < 5
|
| 691 |
+
valid = valid_trit & (lin < TOTAL)
|
| 692 |
+
n = lin // K
|
| 693 |
+
k = lin - n * K
|
| 694 |
+
|
| 695 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 696 |
+
acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
|
| 697 |
+
for m0 in range(0, M, BLOCK_M):
|
| 698 |
+
m = m0 + offs_m
|
| 699 |
+
grad = tl.load(
|
| 700 |
+
grad_ptr + m[:, None] * N + n[None, :],
|
| 701 |
+
mask=(m[:, None] < M) & valid[None, :],
|
| 702 |
+
other=0.0,
|
| 703 |
+
)
|
| 704 |
+
x = tl.load(
|
| 705 |
+
x_ptr + m[:, None] * K + k[None, :],
|
| 706 |
+
mask=(m[:, None] < M) & valid[None, :],
|
| 707 |
+
other=0.0,
|
| 708 |
+
)
|
| 709 |
+
acc += tl.sum(grad * x, axis=0)
|
| 710 |
+
|
| 711 |
+
grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
|
| 712 |
+
|
| 713 |
+
old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
|
| 714 |
+
divisor = tl.where(
|
| 715 |
+
offs_t == 0, 1,
|
| 716 |
+
tl.where(offs_t == 1, 3,
|
| 717 |
+
tl.where(offs_t == 2, 9,
|
| 718 |
+
tl.where(offs_t == 3, 27, 81))),
|
| 719 |
+
)
|
| 720 |
+
old_code = (old_packed // divisor) % 3
|
| 721 |
+
old_sign = old_code.to(tl.int32) - 1
|
| 722 |
+
|
| 723 |
+
old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
|
| 724 |
+
new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
|
| 725 |
+
|
| 726 |
+
if HAS_PER_GROUP_THRESHOLD:
|
| 727 |
+
g_idx = n * GPR + k // GROUP_SIZE
|
| 728 |
+
threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
|
| 729 |
+
else:
|
| 730 |
+
threshold = ACCUM_THRESHOLD
|
| 731 |
+
|
| 732 |
+
flip_up = new_accum > threshold
|
| 733 |
+
flip_down = new_accum < -threshold
|
| 734 |
+
did_flip = valid & (flip_up | flip_down)
|
| 735 |
+
new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
|
| 736 |
+
stored_accum = tl.where(did_flip, 0, new_accum)
|
| 737 |
+
tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)
|
| 738 |
+
|
| 739 |
+
new_code = tl.where(valid, new_sign + 1, 0)
|
| 740 |
+
packed_val = tl.sum(new_code * divisor, axis=0)
|
| 741 |
+
tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
@triton.jit
|
| 745 |
+
def _triton_accumulate_t_direct_kernel(
|
| 746 |
+
grad_ptr, x_ptr, accum_ptr,
|
| 747 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 748 |
+
TOTAL: tl.constexpr, T_ACCUM_STEP: tl.constexpr,
|
| 749 |
+
BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
|
| 750 |
+
):
|
| 751 |
+
pack_idx = tl.program_id(0)
|
| 752 |
+
offs_t = tl.arange(0, BLOCK_T)
|
| 753 |
+
lin = pack_idx * 5 + offs_t
|
| 754 |
+
valid_trit = offs_t < 5
|
| 755 |
+
valid = valid_trit & (lin < TOTAL)
|
| 756 |
+
n = lin // K
|
| 757 |
+
k = lin - n * K
|
| 758 |
+
|
| 759 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 760 |
+
acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
|
| 761 |
+
for m0 in range(0, M, BLOCK_M):
|
| 762 |
+
m = m0 + offs_m
|
| 763 |
+
grad = tl.load(
|
| 764 |
+
grad_ptr + m[:, None] * N + n[None, :],
|
| 765 |
+
mask=(m[:, None] < M) & valid[None, :],
|
| 766 |
+
other=0.0,
|
| 767 |
+
)
|
| 768 |
+
x = tl.load(
|
| 769 |
+
x_ptr + m[:, None] * K + k[None, :],
|
| 770 |
+
mask=(m[:, None] < M) & valid[None, :],
|
| 771 |
+
other=0.0,
|
| 772 |
+
)
|
| 773 |
+
acc += tl.sum(grad * x, axis=0)
|
| 774 |
+
|
| 775 |
+
grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
|
| 776 |
+
old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
|
| 777 |
+
new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
|
| 778 |
+
tl.store(accum_ptr + lin, new_accum.to(tl.int8), mask=valid)
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
@triton.jit
|
| 782 |
+
def _triton_accumulate_e_direct_kernel(
|
| 783 |
+
packed_ptr, grad_ptr, x_ptr, e_accum_ptr,
|
| 784 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 785 |
+
GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
|
| 786 |
+
E_ACCUM_STEP: tl.constexpr,
|
| 787 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 788 |
+
):
|
| 789 |
+
pid_n = tl.program_id(0)
|
| 790 |
+
pid_g = tl.program_id(1)
|
| 791 |
+
|
| 792 |
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 793 |
+
offs_r = tl.arange(0, BLOCK_K)
|
| 794 |
+
k = pid_g * GROUP_SIZE + offs_r
|
| 795 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 796 |
+
acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
|
| 797 |
+
|
| 798 |
+
for m0 in range(0, M, BLOCK_M):
|
| 799 |
+
m = m0 + offs_m
|
| 800 |
+
grad = tl.load(
|
| 801 |
+
grad_ptr + m[:, None] * N + offs_n[None, :],
|
| 802 |
+
mask=(m[:, None] < M) & (offs_n[None, :] < N),
|
| 803 |
+
other=0.0,
|
| 804 |
+
)
|
| 805 |
+
x = tl.load(
|
| 806 |
+
x_ptr + m[:, None] * K + k[None, :],
|
| 807 |
+
mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 808 |
+
other=0.0,
|
| 809 |
+
)
|
| 810 |
+
acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
|
| 811 |
+
|
| 812 |
+
grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
|
| 813 |
+
lin = offs_n[:, None] * K + k[None, :]
|
| 814 |
+
pack_idx = lin // 5
|
| 815 |
+
trit_pos = lin - pack_idx * 5
|
| 816 |
+
packed = tl.load(
|
| 817 |
+
packed_ptr + pack_idx,
|
| 818 |
+
mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 819 |
+
other=0,
|
| 820 |
+
).to(tl.int32)
|
| 821 |
+
divisor = tl.where(
|
| 822 |
+
trit_pos == 0, 1,
|
| 823 |
+
tl.where(trit_pos == 1, 3,
|
| 824 |
+
tl.where(trit_pos == 2, 9,
|
| 825 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 826 |
+
)
|
| 827 |
+
trit = (packed // divisor) % 3
|
| 828 |
+
ternary = trit.to(tl.int32) - 1
|
| 829 |
+
contrib = tl.where(
|
| 830 |
+
(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 831 |
+
grad_sign * ternary,
|
| 832 |
+
0,
|
| 833 |
+
)
|
| 834 |
+
score = tl.sum(contrib, axis=1)
|
| 835 |
+
delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
|
| 836 |
+
|
| 837 |
+
e_idx = offs_n * GPR + pid_g
|
| 838 |
+
old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
|
| 839 |
+
new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta * E_ACCUM_STEP))
|
| 840 |
+
tl.store(e_accum_ptr + e_idx, new_accum.to(tl.int8), mask=offs_n < N)
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
@triton.jit
|
| 844 |
+
def _triton_accumulate_corr_direct_kernel(
|
| 845 |
+
packed_ptr, grad_ptr, x_ptr, corr_ptr,
|
| 846 |
+
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
| 847 |
+
GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
|
| 848 |
+
CORR_STEP: tl.constexpr,
|
| 849 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
| 850 |
+
):
|
| 851 |
+
pid_n = tl.program_id(0)
|
| 852 |
+
pid_g = tl.program_id(1)
|
| 853 |
+
|
| 854 |
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 855 |
+
offs_r = tl.arange(0, BLOCK_K)
|
| 856 |
+
k = pid_g * GROUP_SIZE + offs_r
|
| 857 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 858 |
+
acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
|
| 859 |
+
|
| 860 |
+
for m0 in range(0, M, BLOCK_M):
|
| 861 |
+
m = m0 + offs_m
|
| 862 |
+
grad = tl.load(
|
| 863 |
+
grad_ptr + m[:, None] * N + offs_n[None, :],
|
| 864 |
+
mask=(m[:, None] < M) & (offs_n[None, :] < N),
|
| 865 |
+
other=0.0,
|
| 866 |
+
)
|
| 867 |
+
x = tl.load(
|
| 868 |
+
x_ptr + m[:, None] * K + k[None, :],
|
| 869 |
+
mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 870 |
+
other=0.0,
|
| 871 |
+
)
|
| 872 |
+
acc += tl.dot(tl.trans(grad), x, input_precision="ieee")
|
| 873 |
+
|
| 874 |
+
grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
|
| 875 |
+
lin = offs_n[:, None] * K + k[None, :]
|
| 876 |
+
pack_idx = lin // 5
|
| 877 |
+
trit_pos = lin - pack_idx * 5
|
| 878 |
+
packed = tl.load(
|
| 879 |
+
packed_ptr + pack_idx,
|
| 880 |
+
mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 881 |
+
other=0,
|
| 882 |
+
).to(tl.int32)
|
| 883 |
+
divisor = tl.where(
|
| 884 |
+
trit_pos == 0, 1,
|
| 885 |
+
tl.where(trit_pos == 1, 3,
|
| 886 |
+
tl.where(trit_pos == 2, 9,
|
| 887 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 888 |
+
)
|
| 889 |
+
trit = (packed // divisor) % 3
|
| 890 |
+
ternary = trit.to(tl.int32) - 1
|
| 891 |
+
contrib = tl.where(
|
| 892 |
+
(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
|
| 893 |
+
grad_sign * ternary,
|
| 894 |
+
0,
|
| 895 |
+
)
|
| 896 |
+
score = tl.sum(contrib, axis=1)
|
| 897 |
+
|
| 898 |
+
corr_idx = offs_n * GPR + pid_g
|
| 899 |
+
old_corr = tl.load(corr_ptr + corr_idx, mask=offs_n < N, other=0).to(tl.int64)
|
| 900 |
+
new_corr = old_corr - score.to(tl.int64) * CORR_STEP
|
| 901 |
+
tl.store(corr_ptr + corr_idx, new_corr, mask=offs_n < N)
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
@triton.jit
|
| 905 |
+
def _triton_apply_accumulated_flips_kernel(
|
| 906 |
+
packed_ptr, accum_ptr, per_group_threshold_ptr,
|
| 907 |
+
TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
|
| 908 |
+
K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 909 |
+
HAS_PER_GROUP_THRESHOLD: tl.constexpr,
|
| 910 |
+
BLOCK_T: tl.constexpr,
|
| 911 |
+
):
|
| 912 |
+
pack_idx = tl.program_id(0)
|
| 913 |
+
offs_t = tl.arange(0, BLOCK_T)
|
| 914 |
+
valid_trit = offs_t < 5
|
| 915 |
+
lin = pack_idx * 5 + offs_t
|
| 916 |
+
valid = valid_trit & (lin < TOTAL)
|
| 917 |
+
|
| 918 |
+
old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
|
| 919 |
+
divisor = tl.where(
|
| 920 |
+
offs_t == 0, 1,
|
| 921 |
+
tl.where(offs_t == 1, 3,
|
| 922 |
+
tl.where(offs_t == 2, 9,
|
| 923 |
+
tl.where(offs_t == 3, 27, 81))),
|
| 924 |
+
)
|
| 925 |
+
old_code = (old_packed // divisor) % 3
|
| 926 |
+
old_sign = old_code.to(tl.int32) - 1
|
| 927 |
+
|
| 928 |
+
old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
|
| 929 |
+
if HAS_PER_GROUP_THRESHOLD:
|
| 930 |
+
n = lin // K
|
| 931 |
+
k = lin - n * K
|
| 932 |
+
g_idx = n * GPR + k // GROUP_SIZE
|
| 933 |
+
threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
|
| 934 |
+
else:
|
| 935 |
+
threshold = ACCUM_THRESHOLD
|
| 936 |
+
|
| 937 |
+
flip_up = old_accum > threshold
|
| 938 |
+
flip_down = old_accum < -threshold
|
| 939 |
+
did_flip = valid & (flip_up | flip_down)
|
| 940 |
+
new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
|
| 941 |
+
stored_accum = tl.where(did_flip, 0, old_accum)
|
| 942 |
+
tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)
|
| 943 |
+
|
| 944 |
+
new_code = tl.where(valid, new_sign + 1, 0)
|
| 945 |
+
packed_val = tl.sum(new_code * divisor, axis=0)
|
| 946 |
+
tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
def _triton_ternary_forward(x_2d, packed, e, corr_accum, step_counter, n_out, k_in, group_size):
|
| 950 |
+
block_m, block_n, block_k = 16, 16, 32
|
| 951 |
+
out = torch.empty((x_2d.shape[0], n_out), device=x_2d.device, dtype=torch.float32)
|
| 952 |
+
grid = (triton.cdiv(x_2d.shape[0], block_m), triton.cdiv(n_out, block_n))
|
| 953 |
+
_triton_ternary_fwd_kernel[grid](
|
| 954 |
+
x_2d, packed, e, corr_accum, step_counter, out,
|
| 955 |
+
x_2d.shape[0], n_out, k_in, ceil(k_in / group_size), group_size,
|
| 956 |
+
_bigint_corr_strength(),
|
| 957 |
+
BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
|
| 958 |
+
)
|
| 959 |
+
return out
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def _triton_ternary_grad_x(grad_2d, packed, e, corr_accum, step_counter, m_rows, n_out, k_in, group_size):
|
| 963 |
+
block_m, block_n, block_k = 16, 16, 32
|
| 964 |
+
out = torch.empty((m_rows, k_in), device=grad_2d.device, dtype=torch.float32)
|
| 965 |
+
grid = (triton.cdiv(m_rows, block_m), triton.cdiv(k_in, block_k))
|
| 966 |
+
_triton_ternary_grad_x_kernel[grid](
|
| 967 |
+
grad_2d, packed, e, corr_accum, step_counter, out,
|
| 968 |
+
m_rows, n_out, k_in, ceil(k_in / group_size), group_size,
|
| 969 |
+
_bigint_corr_strength(),
|
| 970 |
+
BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
|
| 971 |
+
)
|
| 972 |
+
return out
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
def _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in):
|
| 976 |
+
block_m, block_n, block_k = 32, 16, 32
|
| 977 |
+
out = torch.empty((n_out, k_in), device=grad_2d.device, dtype=torch.int8)
|
| 978 |
+
grid = (triton.cdiv(n_out, block_n), triton.cdiv(k_in, block_k))
|
| 979 |
+
_triton_ternary_grad_sign_kernel[grid](
|
| 980 |
+
grad_2d, x_2d, out,
|
| 981 |
+
x_2d.shape[0], n_out, k_in,
|
| 982 |
+
BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
|
| 983 |
+
)
|
| 984 |
+
return out
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
def _triton_update_e(packed, grad_sign, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
|
| 988 |
+
block_n, block_g = 8, 4
|
| 989 |
+
gpr = ceil(k_in / group_size)
|
| 990 |
+
block_k = 1 << (group_size - 1).bit_length()
|
| 991 |
+
grid = (triton.cdiv(n_out, block_n), triton.cdiv(gpr, block_g))
|
| 992 |
+
_triton_update_e_kernel[grid](
|
| 993 |
+
packed, grad_sign, e, e_accum,
|
| 994 |
+
n_out, k_in, group_size, gpr, int(e_accum_threshold),
|
| 995 |
+
BLOCK_N=block_n, BLOCK_G=block_g, BLOCK_K=block_k,
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def _triton_update_e_direct(packed, grad_2d, x_2d, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
|
| 1000 |
+
block_m, block_n = 32, 8
|
| 1001 |
+
block_k = 1 << (group_size - 1).bit_length()
|
| 1002 |
+
gpr = ceil(k_in / group_size)
|
| 1003 |
+
grid = (triton.cdiv(n_out, block_n), gpr)
|
| 1004 |
+
_triton_update_e_direct_kernel[grid](
|
| 1005 |
+
packed, grad_2d, x_2d, e, e_accum,
|
| 1006 |
+
x_2d.shape[0], n_out, k_in, group_size, gpr, int(e_accum_threshold),
|
| 1007 |
+
BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
|
| 1008 |
+
)
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
def _triton_ternary_step(packed, grad_sign, accum, total, accum_threshold, t_accum_step=1,
|
| 1012 |
+
per_group_threshold=None, n_out=0, k_in=0, group_size=0):
|
| 1013 |
+
block_t = 8
|
| 1014 |
+
grid = (triton.cdiv(total, 5),)
|
| 1015 |
+
has_pgt = per_group_threshold is not None
|
| 1016 |
+
dummy = torch.empty(1, device=accum.device, dtype=torch.int8)
|
| 1017 |
+
gpr = (k_in + group_size - 1) // group_size if has_pgt else 0
|
| 1018 |
+
_triton_ternary_step_kernel[grid](
|
| 1019 |
+
packed, grad_sign, accum,
|
| 1020 |
+
per_group_threshold if has_pgt else dummy,
|
| 1021 |
+
total, accum_threshold, int(t_accum_step),
|
| 1022 |
+
k_in if has_pgt else 0, gpr, group_size if has_pgt else 0,
|
| 1023 |
+
has_pgt,
|
| 1024 |
+
BLOCK_T=block_t,
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
|
| 1028 |
+
def _triton_ternary_step_direct(packed, grad_2d, x_2d, accum, n_out, k_in, total, accum_threshold, t_accum_step=1,
|
| 1029 |
+
per_group_threshold=None, group_size=0):
|
| 1030 |
+
block_m, block_t = 32, 8
|
| 1031 |
+
grid = (triton.cdiv(total, 5),)
|
| 1032 |
+
has_pgt = per_group_threshold is not None
|
| 1033 |
+
dummy = torch.empty(1, device=accum.device, dtype=torch.int8)
|
| 1034 |
+
gpr = (k_in + group_size - 1) // group_size if has_pgt else 0
|
| 1035 |
+
_triton_ternary_step_direct_kernel[grid](
|
| 1036 |
+
packed, grad_2d, x_2d, accum,
|
| 1037 |
+
per_group_threshold if has_pgt else dummy,
|
| 1038 |
+
x_2d.shape[0], n_out, k_in,
|
| 1039 |
+
total, accum_threshold, int(t_accum_step),
|
| 1040 |
+
gpr, group_size if has_pgt else 0,
|
| 1041 |
+
has_pgt,
|
| 1042 |
+
BLOCK_M=block_m, BLOCK_T=block_t,
|
| 1043 |
+
)
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
def _triton_accumulate_direct(packed, grad_2d, x_2d, t_accum, e_accum,
|
| 1047 |
+
n_out, k_in, group_size,
|
| 1048 |
+
t_accum_step=1, e_accum_step=1,
|
| 1049 |
+
update_scales=True):
|
| 1050 |
+
block_m, block_t = 32, 8
|
| 1051 |
+
total = n_out * k_in
|
| 1052 |
+
grid = (triton.cdiv(total, 5),)
|
| 1053 |
+
_triton_accumulate_t_direct_kernel[grid](
|
| 1054 |
+
grad_2d, x_2d, t_accum,
|
| 1055 |
+
grad_2d.shape[0], n_out, k_in, total, int(t_accum_step),
|
| 1056 |
+
BLOCK_M=block_m, BLOCK_T=block_t,
|
| 1057 |
+
)
|
| 1058 |
+
if update_scales and e_accum is not None:
|
| 1059 |
+
block_n = 8
|
| 1060 |
+
block_k = 1 << (group_size - 1).bit_length()
|
| 1061 |
+
gpr = ceil(k_in / group_size)
|
| 1062 |
+
grid_e = (triton.cdiv(n_out, block_n), gpr)
|
| 1063 |
+
_triton_accumulate_e_direct_kernel[grid_e](
|
| 1064 |
+
packed, grad_2d, x_2d, e_accum,
|
| 1065 |
+
grad_2d.shape[0], n_out, k_in, group_size, gpr, int(e_accum_step),
|
| 1066 |
+
BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
def _triton_accumulate_corr_direct(packed, grad_2d, x_2d, corr_accum,
|
| 1071 |
+
n_out, k_in, group_size, corr_step=1):
|
| 1072 |
+
block_m, block_n = 32, 8
|
| 1073 |
+
block_k = 1 << (group_size - 1).bit_length()
|
| 1074 |
+
gpr = ceil(k_in / group_size)
|
| 1075 |
+
grid = (triton.cdiv(n_out, block_n), gpr)
|
| 1076 |
+
_triton_accumulate_corr_direct_kernel[grid](
|
| 1077 |
+
packed, grad_2d, x_2d, corr_accum,
|
| 1078 |
+
grad_2d.shape[0], n_out, k_in, group_size, gpr, int(corr_step),
|
| 1079 |
+
BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
def _triton_apply_accumulated_flips(packed, accum, total, accum_threshold,
|
| 1084 |
+
per_group_threshold=None,
|
| 1085 |
+
k_in=0, group_size=0):
|
| 1086 |
+
block_t = 8
|
| 1087 |
+
grid = (triton.cdiv(total, 5),)
|
| 1088 |
+
has_pgt = per_group_threshold is not None
|
| 1089 |
+
dummy = torch.empty(1, device=accum.device, dtype=torch.int8)
|
| 1090 |
+
gpr = (k_in + group_size - 1) // group_size if has_pgt else 0
|
| 1091 |
+
_triton_apply_accumulated_flips_kernel[grid](
|
| 1092 |
+
packed, accum,
|
| 1093 |
+
per_group_threshold if has_pgt else dummy,
|
| 1094 |
+
total, accum_threshold,
|
| 1095 |
+
k_in if has_pgt else 0, gpr, group_size if has_pgt else 0,
|
| 1096 |
+
has_pgt,
|
| 1097 |
+
BLOCK_T=block_t,
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
@triton.jit
|
| 1102 |
+
def _triton_ternary_embed_fwd_kernel(
|
| 1103 |
+
idx_ptr, packed_ptr, e_ptr, out_ptr,
|
| 1104 |
+
NUM_IDX: tl.constexpr, DIM: tl.constexpr,
|
| 1105 |
+
VOCAB: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 1106 |
+
BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
|
| 1107 |
+
):
|
| 1108 |
+
pid = tl.program_id(0)
|
| 1109 |
+
offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 1110 |
+
offs_d = tl.arange(0, BLOCK_D)
|
| 1111 |
+
|
| 1112 |
+
idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32)
|
| 1113 |
+
|
| 1114 |
+
lin = idx[:, None] * DIM + offs_d[None, :]
|
| 1115 |
+
pack_idx = lin // 5
|
| 1116 |
+
trit_pos = lin - pack_idx * 5
|
| 1117 |
+
packed = tl.load(packed_ptr + pack_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.int32)
|
| 1118 |
+
divisor = tl.where(
|
| 1119 |
+
trit_pos == 0, 1,
|
| 1120 |
+
tl.where(trit_pos == 1, 3,
|
| 1121 |
+
tl.where(trit_pos == 2, 9,
|
| 1122 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 1123 |
+
)
|
| 1124 |
+
trit = (packed // divisor) % 3
|
| 1125 |
+
sign = trit.to(tl.int32) - 1
|
| 1126 |
+
|
| 1127 |
+
e_idx = idx[:, None] * GPR + offs_d[None, :] // GROUP_SIZE
|
| 1128 |
+
e_val = tl.load(e_ptr + e_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.float32)
|
| 1129 |
+
w = sign.to(tl.float32) * tl.exp2(e_val)
|
| 1130 |
+
w = tl.where((offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), w, 0.0)
|
| 1131 |
+
|
| 1132 |
+
tl.store(
|
| 1133 |
+
out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
|
| 1134 |
+
w,
|
| 1135 |
+
mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM),
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
@triton.jit
|
| 1140 |
+
def _triton_ternary_embed_bwd_accum_kernel(
|
| 1141 |
+
idx_ptr, grad_ptr, accum_ptr,
|
| 1142 |
+
NUM_IDX: tl.constexpr, DIM: tl.constexpr,
|
| 1143 |
+
BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
|
| 1144 |
+
):
|
| 1145 |
+
pid = tl.program_id(0)
|
| 1146 |
+
offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 1147 |
+
offs_d = tl.arange(0, BLOCK_D)
|
| 1148 |
+
valid = (offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM)
|
| 1149 |
+
idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32)
|
| 1150 |
+
g = tl.load(grad_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0)
|
| 1151 |
+
dst = idx[:, None] * DIM + offs_d[None, :]
|
| 1152 |
+
tl.atomic_add(accum_ptr + dst, g, mask=valid)
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
@triton.jit
|
| 1156 |
+
def _triton_ternary_embed_bwd_sign_kernel(
|
| 1157 |
+
accum_ptr, sign_ptr,
|
| 1158 |
+
VOCAB: tl.constexpr, DIM: tl.constexpr,
|
| 1159 |
+
BLOCK_V: tl.constexpr, BLOCK_D: tl.constexpr,
|
| 1160 |
+
):
|
| 1161 |
+
pid_v = tl.program_id(0)
|
| 1162 |
+
offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V)
|
| 1163 |
+
offs_d = tl.arange(0, BLOCK_D)
|
| 1164 |
+
valid = (offs_v[:, None] < VOCAB) & (offs_d[None, :] < DIM)
|
| 1165 |
+
acc = tl.load(accum_ptr + offs_v[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0)
|
| 1166 |
+
sign_val = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int8)
|
| 1167 |
+
tl.store(sign_ptr + offs_v[:, None] * DIM + offs_d[None, :], sign_val, mask=valid)
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
def _triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim):
|
| 1171 |
+
flat_idx = indices.reshape(-1).contiguous().to(torch.int32)
|
| 1172 |
+
grad_2d = grad_output.reshape(-1, dim).contiguous()
|
| 1173 |
+
num_idx = flat_idx.shape[0]
|
| 1174 |
+
accum = torch.zeros(vocab, dim, device=grad_output.device, dtype=torch.float32)
|
| 1175 |
+
block_b = 64
|
| 1176 |
+
grid = (triton.cdiv(num_idx, block_b),)
|
| 1177 |
+
_triton_ternary_embed_bwd_accum_kernel[grid](
|
| 1178 |
+
flat_idx, grad_2d, accum,
|
| 1179 |
+
num_idx, dim,
|
| 1180 |
+
BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
|
| 1181 |
+
)
|
| 1182 |
+
sign_out = torch.empty(vocab, dim, device=grad_output.device, dtype=torch.int8)
|
| 1183 |
+
block_v = 32
|
| 1184 |
+
grid2 = (triton.cdiv(vocab, block_v),)
|
| 1185 |
+
_triton_ternary_embed_bwd_sign_kernel[grid2](
|
| 1186 |
+
accum, sign_out,
|
| 1187 |
+
vocab, dim,
|
| 1188 |
+
BLOCK_V=block_v, BLOCK_D=triton.next_power_of_2(dim),
|
| 1189 |
+
)
|
| 1190 |
+
return sign_out
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
def _triton_ternary_embed(indices, packed, e, vocab, dim, group_size):
|
| 1194 |
+
flat_idx = indices.reshape(-1).contiguous().to(torch.int32)
|
| 1195 |
+
num_idx = flat_idx.shape[0]
|
| 1196 |
+
out = torch.empty((num_idx, dim), device=indices.device, dtype=torch.float32)
|
| 1197 |
+
block_b, block_d = 32, triton.next_power_of_2(dim)
|
| 1198 |
+
gpr = ceil(dim / group_size)
|
| 1199 |
+
grid = (triton.cdiv(num_idx, block_b),)
|
| 1200 |
+
_triton_ternary_embed_fwd_kernel[grid](
|
| 1201 |
+
flat_idx, packed, e, out,
|
| 1202 |
+
num_idx, dim, vocab, gpr, group_size,
|
| 1203 |
+
BLOCK_B=block_b, BLOCK_D=block_d,
|
| 1204 |
+
)
|
| 1205 |
+
return out.reshape(*indices.shape, dim)
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
class _TritonTernaryEmbedFn(torch.autograd.Function):
|
| 1209 |
+
@staticmethod
|
| 1210 |
+
def forward(ctx, indices, _dummy, module):
|
| 1211 |
+
shape = tuple(module._T_shape.tolist())
|
| 1212 |
+
vocab, dim = shape
|
| 1213 |
+
packed = module.T_packed.contiguous()
|
| 1214 |
+
e = module.E.contiguous()
|
| 1215 |
+
ctx.save_for_backward(indices, packed, e)
|
| 1216 |
+
ctx.module = module
|
| 1217 |
+
ctx.shape = shape
|
| 1218 |
+
ctx.group_size = module.group_size
|
| 1219 |
+
comp_name, _ = _COMPONENT_CONTEXT.get()
|
| 1220 |
+
ctx.comp_name = comp_name
|
| 1221 |
+
return _triton_ternary_embed(indices, packed, e, vocab, dim, module.group_size)
|
| 1222 |
+
|
| 1223 |
+
@staticmethod
|
| 1224 |
+
def backward(ctx, grad_output):
|
| 1225 |
+
indices, packed, e = ctx.saved_tensors
|
| 1226 |
+
vocab, dim = ctx.shape
|
| 1227 |
+
grad_2d = grad_output.reshape(-1, dim).contiguous()
|
| 1228 |
+
comp_name = ctx.comp_name
|
| 1229 |
+
has_corr = hasattr(ctx.module, "corr_accum") and hasattr(ctx.module, "_accumulate_corr_from_grad_sign")
|
| 1230 |
+
if getattr(ctx.module, "_stream_backward_updates", True) and has_corr:
|
| 1231 |
+
# BigInt streaming: accumulate correlation directly
|
| 1232 |
+
grad_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)
|
| 1233 |
+
T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item())).to(device=grad_sign.device)
|
| 1234 |
+
signed = grad_sign.to(torch.int16) * T.to(torch.int16)
|
| 1235 |
+
ctx.module._accumulate_corr_from_grad_sign(grad_sign)
|
| 1236 |
+
ctx.module._streamed_bigint_backward = True
|
| 1237 |
+
elif comp_name is not None:
|
| 1238 |
+
setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim))
|
| 1239 |
+
T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item()))
|
| 1240 |
+
setattr(ctx.module, f"_hook_T_{comp_name}", T.to(device=grad_2d.device))
|
| 1241 |
+
else:
|
| 1242 |
+
ctx.module._hook_grad_T_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)
|
| 1243 |
+
T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item()))
|
| 1244 |
+
ctx.module._hook_T = T.to(device=grad_2d.device)
|
| 1245 |
+
return None, None, None
|
| 1246 |
+
|
| 1247 |
+
|
| 1248 |
+
class _TritonTernaryLinearFn(torch.autograd.Function):
|
| 1249 |
+
@staticmethod
|
| 1250 |
+
def forward(ctx, x, module):
|
| 1251 |
+
shape = tuple(module._T_shape.tolist())
|
| 1252 |
+
n_out, k_in = shape
|
| 1253 |
+
x_2d = x.reshape(-1, k_in).contiguous()
|
| 1254 |
+
packed = module.T_packed.contiguous()
|
| 1255 |
+
e = module.E.contiguous()
|
| 1256 |
+
ctx.save_for_backward(x_2d, packed, e)
|
| 1257 |
+
ctx.step_snapshot = int(module.step_counter.item())
|
| 1258 |
+
ctx.x_shape = x.shape
|
| 1259 |
+
ctx.shape = shape
|
| 1260 |
+
ctx.group_size = module.group_size
|
| 1261 |
+
ctx.module = module
|
| 1262 |
+
comp_name, _ = _COMPONENT_CONTEXT.get()
|
| 1263 |
+
ctx.comp_name = comp_name
|
| 1264 |
+
corr = module.corr_accum.contiguous()
|
| 1265 |
+
step = module.step_counter.contiguous()
|
| 1266 |
+
out = _triton_ternary_forward(x_2d, packed, e, corr, step, n_out, k_in, module.group_size)
|
| 1267 |
+
return out.reshape(*x.shape[:-1], n_out)
|
| 1268 |
+
|
| 1269 |
+
@staticmethod
|
| 1270 |
+
def backward(ctx, grad_output):
|
| 1271 |
+
x_2d, packed, e = ctx.saved_tensors
|
| 1272 |
+
n_out, k_in = ctx.shape
|
| 1273 |
+
grad_2d = grad_output.reshape(-1, n_out).contiguous()
|
| 1274 |
+
corr = ctx.module.corr_accum.contiguous()
|
| 1275 |
+
step = torch.tensor([ctx.step_snapshot], device=e.device, dtype=torch.int64)
|
| 1276 |
+
grad_x = _triton_ternary_grad_x(
|
| 1277 |
+
grad_2d, packed, e, corr, step, x_2d.shape[0], n_out, k_in, ctx.group_size
|
| 1278 |
+
)
|
| 1279 |
+
with torch.no_grad():
|
| 1280 |
+
if getattr(ctx.module, "_stream_backward_updates", True):
|
| 1281 |
+
_, bwd_weight = _COMPONENT_CONTEXT.get()
|
| 1282 |
+
corr_step = max(1, int(round(abs(float(bwd_weight)))))
|
| 1283 |
+
if bwd_weight < 0:
|
| 1284 |
+
corr_step = -corr_step
|
| 1285 |
+
_triton_accumulate_corr_direct(
|
| 1286 |
+
packed, grad_2d, x_2d, ctx.module.corr_accum,
|
| 1287 |
+
n_out, k_in, ctx.group_size, corr_step=corr_step,
|
| 1288 |
+
)
|
| 1289 |
+
ctx.module.step_counter.add_(abs(corr_step))
|
| 1290 |
+
ctx.module._streamed_bigint_backward = True
|
| 1291 |
+
else:
|
| 1292 |
+
grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in)
|
| 1293 |
+
comp_name = ctx.comp_name
|
| 1294 |
+
if comp_name is not None:
|
| 1295 |
+
setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
|
| 1296 |
+
else:
|
| 1297 |
+
ctx.module._hook_grad_T_sign = grad_sign.detach()
|
| 1298 |
+
return grad_x.reshape(*ctx.x_shape), None
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
+
class _BigIntTernaryLinearFn(torch.autograd.Function):
|
| 1302 |
+
@staticmethod
|
| 1303 |
+
def forward(ctx, x, module):
|
| 1304 |
+
shape = tuple(module._T_shape.tolist())
|
| 1305 |
+
n_out, k_in = shape
|
| 1306 |
+
x_2d = x.reshape(-1, k_in).contiguous()
|
| 1307 |
+
ctx.module = module
|
| 1308 |
+
ctx.x_shape = x.shape
|
| 1309 |
+
ctx.shape = shape
|
| 1310 |
+
ctx.x_dtype = x.dtype
|
| 1311 |
+
ctx.save_for_backward(x_2d)
|
| 1312 |
+
with torch.no_grad():
|
| 1313 |
+
w_eff = module.dequantize().to(device=x.device, dtype=torch.float32)
|
| 1314 |
+
out = F.linear(x_2d.float(), w_eff, module.bias.float() if module.bias is not None else None)
|
| 1315 |
+
return out.reshape(*x.shape[:-1], n_out)
|
| 1316 |
+
|
| 1317 |
+
@staticmethod
|
| 1318 |
+
def backward(ctx, grad_output):
|
| 1319 |
+
(x_2d,) = ctx.saved_tensors
|
| 1320 |
+
module = ctx.module
|
| 1321 |
+
n_out, k_in = ctx.shape
|
| 1322 |
+
grad_2d = grad_output.reshape(-1, n_out).contiguous()
|
| 1323 |
+
with torch.no_grad():
|
| 1324 |
+
w_eff = module.dequantize().to(device=grad_2d.device, dtype=torch.float32)
|
| 1325 |
+
grad_x = grad_2d.float() @ w_eff
|
| 1326 |
+
grad_sign = (grad_2d.float().transpose(0, 1) @ x_2d.float()).sign().to(torch.int8)
|
| 1327 |
+
module._accumulate_corr_from_grad_sign(grad_sign)
|
| 1328 |
+
module._streamed_bigint_backward = True
|
| 1329 |
+
return grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype), None
|
| 1330 |
+
|
| 1331 |
+
|
| 1332 |
+
"""
|
| 1333 |
+
Log-Space Group Scale Representation
|
| 1334 |
+
|
| 1335 |
+
Convention (matching agents' Option B recommendation):
|
| 1336 |
+
S = 2^E where S = scale, E = int8 log-space exponent
|
| 1337 |
+
W_eff = T * 2^E
|
| 1338 |
+
|
| 1339 |
+
Key log-space properties exploited:
|
| 1340 |
+
Multiplication β addition: S1 * S2 = 2^(E1 + E2)
|
| 1341 |
+
Division β subtraction: S1 / S2 = 2^(E1 - E2)
|
| 1342 |
+
Dequant β integer shift: 2^E * T = T << E (for E >= 0)
|
| 1343 |
+
|
| 1344 |
+
No IEEE floats in persistent state. E is stored as int8.
|
| 1345 |
+
Ephemeral float only exists in autograd's computation graph.
|
| 1346 |
+
"""
|
| 1347 |
+
|
| 1348 |
+
class TScaleType(IntEnum):
|
| 1349 |
+
T4 = 4
|
| 1350 |
+
T6 = 6
|
| 1351 |
+
T8 = 8
|
| 1352 |
+
T16 = 16
|
| 1353 |
+
T32 = 32
|
| 1354 |
+
T64 = 64
|
| 1355 |
+
T96 = 96
|
| 1356 |
+
|
| 1357 |
+
GROUP_SIZES = {
|
| 1358 |
+
TScaleType.T4: 4,
|
| 1359 |
+
TScaleType.T6: 6,
|
| 1360 |
+
TScaleType.T8: 8,
|
| 1361 |
+
TScaleType.T16: 16,
|
| 1362 |
+
TScaleType.T32: 32,
|
| 1363 |
+
TScaleType.T64: 64,
|
| 1364 |
+
TScaleType.T96: 96,
|
| 1365 |
+
}
|
| 1366 |
+
TILE_SIZE = 384
|
| 1367 |
+
|
| 1368 |
+
def _n_groups(shape, group_size):
|
| 1369 |
+
out_dim, in_dim = shape
|
| 1370 |
+
return out_dim * ceil(in_dim / group_size)
|
| 1371 |
+
|
| 1372 |
+
def _expand_E(E, shape, group_size):
|
| 1373 |
+
out_dim, in_dim = shape
|
| 1374 |
+
gpr = ceil(in_dim / group_size)
|
| 1375 |
+
E_2d = E.view(out_dim, gpr)
|
| 1376 |
+
E_exp = E_2d.repeat_interleave(group_size, dim=1)
|
| 1377 |
+
if E_exp.shape[1] > in_dim:
|
| 1378 |
+
E_exp = E_exp[:, :in_dim]
|
| 1379 |
+
return E_exp
|
| 1380 |
+
|
| 1381 |
+
def _ternarize(x, threshold=0.05):
|
| 1382 |
+
return x.sign() * (x.abs() > threshold).to(x.dtype)
|
| 1383 |
+
|
| 1384 |
+
|
| 1385 |
+
def _scaled_init_threshold(threshold: float, init_std: float) -> float:
|
| 1386 |
+
if init_std <= 0:
|
| 1387 |
+
return threshold
|
| 1388 |
+
return min(float(threshold), 0.5 * float(init_std))
|
| 1389 |
+
|
| 1390 |
+
class TernaryScaleTensor(nn.Module):
|
| 1391 |
+
def __init__(
|
| 1392 |
+
self,
|
| 1393 |
+
in_dim: int,
|
| 1394 |
+
out_dim: int,
|
| 1395 |
+
threshold: float = 0.05,
|
| 1396 |
+
weight_init_std: float | None = None,
|
| 1397 |
+
tscale_type: TScaleType = TScaleType.T32,
|
| 1398 |
+
bias: bool = False,
|
| 1399 |
+
):
|
| 1400 |
+
super().__init__()
|
| 1401 |
+
self.in_dim = in_dim
|
| 1402 |
+
self.out_dim = out_dim
|
| 1403 |
+
init_std = min(0.1, in_dim ** -0.5) if weight_init_std is None else float(weight_init_std)
|
| 1404 |
+
init_threshold = _scaled_init_threshold(threshold, init_std)
|
| 1405 |
+
self.threshold = init_threshold
|
| 1406 |
+
self.tscale_type = tscale_type
|
| 1407 |
+
self.group_size = GROUP_SIZES[tscale_type]
|
| 1408 |
+
shape = (out_dim, in_dim)
|
| 1409 |
+
n_grp = _n_groups(shape, self.group_size)
|
| 1410 |
+
|
| 1411 |
+
w_init = torch.randn(out_dim, in_dim) * init_std
|
| 1412 |
+
T_init = _ternarize(w_init, init_threshold)
|
| 1413 |
+
packed_T, T_shape, T_pad = pack_ternary(T_init)
|
| 1414 |
+
|
| 1415 |
+
self.register_buffer("T_packed", packed_T)
|
| 1416 |
+
self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
|
| 1417 |
+
self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
|
| 1418 |
+
|
| 1419 |
+
gpr = ceil(in_dim / self.group_size)
|
| 1420 |
+
total_in = gpr * self.group_size
|
| 1421 |
+
padded = torch.zeros(out_dim, total_in)
|
| 1422 |
+
abs_w = w_init.abs()
|
| 1423 |
+
padded[:, :in_dim] = abs_w
|
| 1424 |
+
grouped = padded.view(out_dim, gpr, self.group_size)
|
| 1425 |
+
grp_means = grouped.mean(dim=2)
|
| 1426 |
+
E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
|
| 1427 |
+
E_int = E_vals.log2().clamp(-128, 127).to(torch.int8)
|
| 1428 |
+
self.register_buffer("E", E_int.flatten())
|
| 1429 |
+
self.register_buffer("corr_accum", torch.zeros_like(self.E, dtype=torch.int64))
|
| 1430 |
+
self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
|
| 1431 |
+
|
| 1432 |
+
if bias:
|
| 1433 |
+
self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
|
| 1434 |
+
else:
|
| 1435 |
+
self.bias = None
|
| 1436 |
+
|
| 1437 |
+
def _get_T(self):
|
| 1438 |
+
return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
|
| 1439 |
+
|
| 1440 |
+
def _get_S(self):
|
| 1441 |
+
gpr = ceil(self.in_dim / self.group_size)
|
| 1442 |
+
e_adj = self.E.float()
|
| 1443 |
+
if hasattr(self, "corr_accum") and hasattr(self, "step_counter"):
|
| 1444 |
+
step = int(self.step_counter.item())
|
| 1445 |
+
if step > 0:
|
| 1446 |
+
denom = max(step * self.group_size, 1)
|
| 1447 |
+
e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
|
| 1448 |
+
E_exp = _expand_E(e_adj, (self.out_dim, self.in_dim), self.group_size)
|
| 1449 |
+
return torch.exp2(E_exp)
|
| 1450 |
+
|
| 1451 |
+
def _ensure_group_lr(self):
|
| 1452 |
+
if not hasattr(self, "group_lr"):
|
| 1453 |
+
self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
|
| 1454 |
+
elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device:
|
| 1455 |
+
self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
|
| 1456 |
+
return self.group_lr
|
| 1457 |
+
|
| 1458 |
+
def precompile_kernels(self, M: int):
|
| 1459 |
+
pass
|
| 1460 |
+
|
| 1461 |
+
def forward(self, x):
|
| 1462 |
+
backend = _backend_preference()
|
| 1463 |
+
if backend == "tilelang" and _HAS_TILELANG:
|
| 1464 |
+
if torch.is_grad_enabled() and not _tilelang_training_enabled():
|
| 1465 |
+
raise RuntimeError(
|
| 1466 |
+
"ARB_TERNARY_BACKEND='tilelang' is inference-only by default. "
|
| 1467 |
+
"BigInt ternary training should use ARB_TERNARY_BACKEND='triton'. "
|
| 1468 |
+
"Set ARB_TILELANG_TRAINING=1 only for experimental TileLang training."
|
| 1469 |
+
)
|
| 1470 |
+
x_for_grad = x
|
| 1471 |
+
if torch.is_grad_enabled() and not x.requires_grad:
|
| 1472 |
+
x_for_grad = x.detach().requires_grad_(True)
|
| 1473 |
+
N, K = tuple(self._T_shape.tolist())
|
| 1474 |
+
x_2d = x_for_grad.reshape(-1, K)
|
| 1475 |
+
M = x_2d.shape[0]
|
| 1476 |
+
try:
|
| 1477 |
+
fwd_kernel = _get_kernel(M, N, K, self.group_size, "fwd")
|
| 1478 |
+
y = _TernaryLinearFn.apply(x_for_grad, self, fwd_kernel)
|
| 1479 |
+
if self.bias is not None:
|
| 1480 |
+
y = y + self.bias.float()
|
| 1481 |
+
return y
|
| 1482 |
+
except Exception as e:
|
| 1483 |
+
warnings.warn(f"TileLang forward failed for {self._T_shape.tolist()}: {e}")
|
| 1484 |
+
if _HAS_TRITON:
|
| 1485 |
+
backend = "triton"
|
| 1486 |
+
else:
|
| 1487 |
+
backend = "torch"
|
| 1488 |
+
if x.is_cuda and _HAS_TRITON and backend in {"auto", "triton"}:
|
| 1489 |
+
x_for_grad = x
|
| 1490 |
+
if torch.is_grad_enabled() and not x.requires_grad:
|
| 1491 |
+
x_for_grad = x.detach().requires_grad_(True)
|
| 1492 |
+
y = _TritonTernaryLinearFn.apply(x_for_grad, self)
|
| 1493 |
+
if self.bias is not None:
|
| 1494 |
+
y = y + self.bias.float()
|
| 1495 |
+
return y
|
| 1496 |
+
if backend == "triton":
|
| 1497 |
+
raise RuntimeError("ARB_TERNARY_BACKEND='triton' requested, but Triton is unavailable for this input.")
|
| 1498 |
+
x_for_grad = x
|
| 1499 |
+
if torch.is_grad_enabled() and not x.requires_grad:
|
| 1500 |
+
x_for_grad = x.detach().requires_grad_(True)
|
| 1501 |
+
return _BigIntTernaryLinearFn.apply(x_for_grad, self)
|
| 1502 |
+
|
| 1503 |
+
@torch.no_grad()
|
| 1504 |
+
def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
|
| 1505 |
+
shape = tuple(self._T_shape.tolist())
|
| 1506 |
+
out_dim, in_dim = shape
|
| 1507 |
+
if tuple(grad_sign.shape) != shape:
|
| 1508 |
+
return
|
| 1509 |
+
T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
|
| 1510 |
+
signed = grad_sign.to(torch.int16) * T
|
| 1511 |
+
gpr = ceil(in_dim / self.group_size)
|
| 1512 |
+
total_in = gpr * self.group_size
|
| 1513 |
+
if total_in > in_dim:
|
| 1514 |
+
signed = F.pad(signed, (0, total_in - in_dim))
|
| 1515 |
+
score = signed.view(out_dim, gpr, self.group_size).sum(dim=2, dtype=torch.int16)
|
| 1516 |
+
self.corr_accum -= score.flatten().to(device=self.corr_accum.device, dtype=torch.int64) * int(corr_step)
|
| 1517 |
+
self.step_counter += abs(int(corr_step))
|
| 1518 |
+
|
| 1519 |
+
def ternary_step(self, lr=1, accum_threshold=None):
|
| 1520 |
+
self._had_flip = False
|
| 1521 |
+
if hasattr(self, "_hook_grad_T_sign"):
|
| 1522 |
+
self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
|
| 1523 |
+
del self._hook_grad_T_sign
|
| 1524 |
+
|
| 1525 |
+
def update_E(self, lr=1, loss_signal=None):
|
| 1526 |
+
has_dense_grad = hasattr(self, "_hook_grad_T_sign")
|
| 1527 |
+
has_direct_grad = hasattr(self, "_hook_grad_2d") and hasattr(self, "_hook_x_2d")
|
| 1528 |
+
if not has_dense_grad and not has_direct_grad:
|
| 1529 |
+
return
|
| 1530 |
+
if has_dense_grad:
|
| 1531 |
+
self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
|
| 1532 |
+
del self._hook_grad_T_sign
|
| 1533 |
+
else:
|
| 1534 |
+
grad = self._hook_grad_2d.to(device=self.E.device, dtype=torch.float32)
|
| 1535 |
+
x = self._hook_x_2d.to(device=self.E.device, dtype=torch.float32)
|
| 1536 |
+
grad_sign = (grad.transpose(0, 1) @ x).sign().to(torch.int8)
|
| 1537 |
+
self._accumulate_corr_from_grad_sign(grad_sign)
|
| 1538 |
+
del self._hook_grad_2d
|
| 1539 |
+
del self._hook_x_2d
|
| 1540 |
+
if hasattr(self, "_hook_T"):
|
| 1541 |
+
del self._hook_T
|
| 1542 |
+
|
| 1543 |
+
@property
|
| 1544 |
+
def effective_bpw(self) -> float:
|
| 1545 |
+
group_size = self.group_size
|
| 1546 |
+
total = self._T_shape[0].item() * self._T_shape[1].item()
|
| 1547 |
+
n_grp = _n_groups(tuple(self._T_shape.tolist()), group_size)
|
| 1548 |
+
sign_bits = total * (8 / 5)
|
| 1549 |
+
scale_bits = n_grp * 8.0
|
| 1550 |
+
corr_bits = n_grp * 64.0
|
| 1551 |
+
bias_bits = self.bias.numel() * 32.0 if self.bias is not None else 0.0
|
| 1552 |
+
return (sign_bits + scale_bits + corr_bits + bias_bits) / total
|
| 1553 |
+
|
| 1554 |
+
def dequantize(self) -> torch.Tensor:
|
| 1555 |
+
T = self._get_T().float()
|
| 1556 |
+
S = self._get_S()
|
| 1557 |
+
return S * T
|
| 1558 |
+
|
| 1559 |
+
def tscale_to(self, tscale_type: TScaleType):
|
| 1560 |
+
self.tscale_type = tscale_type
|
| 1561 |
+
old_group_size = self.group_size
|
| 1562 |
+
self.group_size = GROUP_SIZES[tscale_type]
|
| 1563 |
+
shape = tuple(self._T_shape.tolist())
|
| 1564 |
+
out_dim, in_dim = shape
|
| 1565 |
+
new_gpr = ceil(in_dim / self.group_size)
|
| 1566 |
+
new_n_grp = out_dim * new_gpr
|
| 1567 |
+
if self.E.shape[0] != new_n_grp:
|
| 1568 |
+
T = self._get_T().float()
|
| 1569 |
+
total_in = new_gpr * self.group_size
|
| 1570 |
+
padded = torch.zeros(out_dim, total_in, device=self.T_packed.device)
|
| 1571 |
+
abs_w = T.abs()
|
| 1572 |
+
padded[:, :in_dim] = abs_w
|
| 1573 |
+
grouped = padded.view(out_dim, new_gpr, self.group_size)
|
| 1574 |
+
grp_means = grouped.mean(dim=2)
|
| 1575 |
+
E_new = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
|
| 1576 |
+
E_int = E_new.log2().clamp(-128, 127).to(torch.int8)
|
| 1577 |
+
self.E = E_int.flatten()
|
| 1578 |
+
self.corr_accum = torch.zeros_like(self.E, dtype=torch.int64)
|
| 1579 |
+
self.step_counter = torch.zeros(1, dtype=torch.int64, device=self.E.device)
|
| 1580 |
+
return self
|
| 1581 |
+
|
| 1582 |
+
tscale_cast = tscale_to
|
| 1583 |
+
|
| 1584 |
+
def extra_repr(self) -> str:
|
| 1585 |
+
return (
|
| 1586 |
+
f"in_dim={self.in_dim}, out_dim={self.out_dim}, "
|
| 1587 |
+
f"tscale_type={self.tscale_type.name}, group_size={self.group_size}, "
|
| 1588 |
+
f"effective_bpw={self.effective_bpw:.2f}"
|
| 1589 |
+
)
|
| 1590 |
+
|
| 1591 |
+
if _HAS_TRITON:
|
| 1592 |
+
|
| 1593 |
+
@triton.jit
|
| 1594 |
+
def _triton_rmsnorm_fwd_kernel(
|
| 1595 |
+
x_ptr, packed_ptr, e_ptr, out_ptr,
|
| 1596 |
+
BATCH: tl.constexpr, DIM: tl.constexpr,
|
| 1597 |
+
GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 1598 |
+
BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
|
| 1599 |
+
):
|
| 1600 |
+
pid_b = tl.program_id(0)
|
| 1601 |
+
offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 1602 |
+
offs_d = tl.arange(0, BLOCK_D)
|
| 1603 |
+
|
| 1604 |
+
x = tl.load(
|
| 1605 |
+
x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
|
| 1606 |
+
mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
|
| 1607 |
+
other=0.0,
|
| 1608 |
+
)
|
| 1609 |
+
sq = x * x
|
| 1610 |
+
msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
|
| 1611 |
+
rms = tl.sqrt(msq + 1e-5)
|
| 1612 |
+
x_norm = x / rms
|
| 1613 |
+
|
| 1614 |
+
pack_idx = offs_d // 5
|
| 1615 |
+
trit_pos = offs_d - pack_idx * 5
|
| 1616 |
+
packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32)
|
| 1617 |
+
divisor = tl.where(
|
| 1618 |
+
trit_pos == 0, 1,
|
| 1619 |
+
tl.where(trit_pos == 1, 3,
|
| 1620 |
+
tl.where(trit_pos == 2, 9,
|
| 1621 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 1622 |
+
)
|
| 1623 |
+
trit = (packed // divisor) % 3
|
| 1624 |
+
sign = trit.to(tl.int32) - 1
|
| 1625 |
+
|
| 1626 |
+
e_idx = offs_d // GROUP_SIZE
|
| 1627 |
+
e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32)
|
| 1628 |
+
w = sign.to(tl.float32) * tl.exp2(e_val)
|
| 1629 |
+
w = tl.where(offs_d < DIM, w, 0.0)
|
| 1630 |
+
|
| 1631 |
+
out = x_norm * w[None, :]
|
| 1632 |
+
tl.store(
|
| 1633 |
+
out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
|
| 1634 |
+
out,
|
| 1635 |
+
mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
|
| 1636 |
+
)
|
| 1637 |
+
|
| 1638 |
+
|
| 1639 |
+
@triton.jit
|
| 1640 |
+
def _triton_rmsnorm_bwd_kernel(
|
| 1641 |
+
grad_out_ptr, x_ptr, packed_ptr, e_ptr,
|
| 1642 |
+
grad_x_ptr,
|
| 1643 |
+
BATCH: tl.constexpr, DIM: tl.constexpr,
|
| 1644 |
+
GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
|
| 1645 |
+
BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
|
| 1646 |
+
):
|
| 1647 |
+
pid_b = tl.program_id(0)
|
| 1648 |
+
offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
|
| 1649 |
+
offs_d = tl.arange(0, BLOCK_D)
|
| 1650 |
+
|
| 1651 |
+
x = tl.load(
|
| 1652 |
+
x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
|
| 1653 |
+
mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
|
| 1654 |
+
other=0.0,
|
| 1655 |
+
)
|
| 1656 |
+
sq = x * x
|
| 1657 |
+
msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
|
| 1658 |
+
rms = tl.sqrt(msq + 1e-5)
|
| 1659 |
+
x_norm = x / rms
|
| 1660 |
+
|
| 1661 |
+
pack_idx = offs_d // 5
|
| 1662 |
+
trit_pos = offs_d - pack_idx * 5
|
| 1663 |
+
packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32)
|
| 1664 |
+
divisor = tl.where(
|
| 1665 |
+
trit_pos == 0, 1,
|
| 1666 |
+
tl.where(trit_pos == 1, 3,
|
| 1667 |
+
tl.where(trit_pos == 2, 9,
|
| 1668 |
+
tl.where(trit_pos == 3, 27, 81))),
|
| 1669 |
+
)
|
| 1670 |
+
trit = (packed // divisor) % 3
|
| 1671 |
+
sign = trit.to(tl.int32) - 1
|
| 1672 |
+
|
| 1673 |
+
e_idx = offs_d // GROUP_SIZE
|
| 1674 |
+
e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32)
|
| 1675 |
+
w = sign.to(tl.float32) * tl.exp2(e_val)
|
| 1676 |
+
w = tl.where(offs_d < DIM, w, 0.0)
|
| 1677 |
+
|
| 1678 |
+
dy = tl.load(
|
| 1679 |
+
grad_out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
|
| 1680 |
+
mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
|
| 1681 |
+
other=0.0,
|
| 1682 |
+
)
|
| 1683 |
+
dyw = dy * w[None, :]
|
| 1684 |
+
|
| 1685 |
+
c1 = tl.sum(x_norm * dyw, axis=1, keep_dims=True) / DIM
|
| 1686 |
+
dx = (dyw - x_norm * c1) / rms
|
| 1687 |
+
|
| 1688 |
+
tl.store(
|
| 1689 |
+
grad_x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
|
| 1690 |
+
dx,
|
| 1691 |
+
mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
|
| 1692 |
+
)
|
| 1693 |
+
|
| 1694 |
+
|
| 1695 |
+
class _TritonRMSNormFn(torch.autograd.Function):
|
| 1696 |
+
@staticmethod
|
| 1697 |
+
def forward(ctx, x, module, packed, e, dim, group_size):
|
| 1698 |
+
ctx.module = module
|
| 1699 |
+
x_2d = x.reshape(-1, dim).contiguous()
|
| 1700 |
+
batch = x_2d.shape[0]
|
| 1701 |
+
out = torch.empty_like(x_2d)
|
| 1702 |
+
block_b = 16
|
| 1703 |
+
grid = (triton.cdiv(batch, block_b),)
|
| 1704 |
+
_triton_rmsnorm_fwd_kernel[grid](
|
| 1705 |
+
x_2d, packed, e, out,
|
| 1706 |
+
batch, dim, ceil(dim / group_size), group_size,
|
| 1707 |
+
BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
|
| 1708 |
+
)
|
| 1709 |
+
ctx.save_for_backward(x_2d, packed, e)
|
| 1710 |
+
ctx.dim = dim
|
| 1711 |
+
ctx.group_size = group_size
|
| 1712 |
+
comp_name, _ = _COMPONENT_CONTEXT.get()
|
| 1713 |
+
ctx.comp_name = comp_name
|
| 1714 |
+
return out.reshape(*x.shape)
|
| 1715 |
+
|
| 1716 |
+
@staticmethod
|
| 1717 |
+
def backward(ctx, grad_output):
|
| 1718 |
+
x_2d, packed, e = ctx.saved_tensors
|
| 1719 |
+
dim = ctx.dim
|
| 1720 |
+
group_size = ctx.group_size
|
| 1721 |
+
grad_2d = grad_output.reshape(-1, dim).contiguous()
|
| 1722 |
+
batch = grad_2d.shape[0]
|
| 1723 |
+
grad_x = torch.empty_like(x_2d)
|
| 1724 |
+
block_b = 16
|
| 1725 |
+
grid = (triton.cdiv(batch, block_b),)
|
| 1726 |
+
_triton_rmsnorm_bwd_kernel[grid](
|
| 1727 |
+
grad_2d, x_2d, packed, e, grad_x,
|
| 1728 |
+
batch, dim, ceil(dim / group_size), group_size,
|
| 1729 |
+
BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
|
| 1730 |
+
)
|
| 1731 |
+
return grad_x.reshape(*grad_output.shape), None, None, None, None, None
|
| 1732 |
+
|
| 1733 |
+
|
| 1734 |
+
class TernaryRMSNorm(nn.Module):
|
| 1735 |
+
def __init__(self, dim, eps=1e-5, threshold=0.05, tscale_type=TScaleType.T64):
|
| 1736 |
+
super().__init__()
|
| 1737 |
+
self.dim = dim
|
| 1738 |
+
self.eps = eps
|
| 1739 |
+
self.threshold = threshold
|
| 1740 |
+
self.tscale_type = tscale_type
|
| 1741 |
+
self.group_size = GROUP_SIZES[tscale_type]
|
| 1742 |
+
shape = (1, dim)
|
| 1743 |
+
n_grp = _n_groups(shape, self.group_size)
|
| 1744 |
+
|
| 1745 |
+
w_init = torch.ones(1, dim)
|
| 1746 |
+
T_init = _ternarize(w_init, threshold)
|
| 1747 |
+
packed_T, T_shape, T_pad = pack_ternary(T_init)
|
| 1748 |
+
|
| 1749 |
+
self.register_buffer("T_packed", packed_T)
|
| 1750 |
+
self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
|
| 1751 |
+
self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
|
| 1752 |
+
|
| 1753 |
+
gpr = ceil(dim / self.group_size)
|
| 1754 |
+
total_in = gpr * self.group_size
|
| 1755 |
+
padded = torch.zeros(1, total_in)
|
| 1756 |
+
abs_w = w_init.abs()
|
| 1757 |
+
padded[:, :dim] = abs_w
|
| 1758 |
+
grouped = padded.view(1, gpr, self.group_size)
|
| 1759 |
+
grp_means = grouped.mean(dim=2)
|
| 1760 |
+
E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
|
| 1761 |
+
self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
|
| 1762 |
+
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
|
| 1763 |
+
self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
|
| 1764 |
+
|
| 1765 |
+
self.register_buffer("T_accum", torch.zeros(1, dim, dtype=torch.int8))
|
| 1766 |
+
|
| 1767 |
+
def _ensure_E_accum(self):
|
| 1768 |
+
if not hasattr(self, "E_accum"):
|
| 1769 |
+
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
|
| 1770 |
+
elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device:
|
| 1771 |
+
self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
|
| 1772 |
+
return self.E_accum
|
| 1773 |
+
|
| 1774 |
+
def _ensure_group_lr(self):
|
| 1775 |
+
if not hasattr(self, "group_lr"):
|
| 1776 |
+
self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
|
| 1777 |
+
elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device:
|
| 1778 |
+
self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
|
| 1779 |
+
return self.group_lr
|
| 1780 |
+
|
| 1781 |
+
def _get_T(self):
|
| 1782 |
+
return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item())).squeeze(0)
|
| 1783 |
+
|
| 1784 |
+
def forward(self, x):
|
| 1785 |
+
if x.is_cuda and _HAS_TRITON and self.dim <= _rmsnorm_triton_max_dim():
|
| 1786 |
+
return _TritonRMSNormFn.apply(
|
| 1787 |
+
x, self, self.T_packed.contiguous(), self.E.contiguous(),
|
| 1788 |
+
self.dim, self.group_size,
|
| 1789 |
+
)
|
| 1790 |
+
|
| 1791 |
+
inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
| 1792 |
+
if x.is_cuda:
|
| 1793 |
+
# TernaryRMSNorm is initialized as an identity scale and does not
|
| 1794 |
+
# train E/T. Avoid unpacking a full large-dim weight or launching
|
| 1795 |
+
# the high-register Triton backward kernel on 8GB GPUs.
|
| 1796 |
+
return x * inv_rms
|
| 1797 |
+
|
| 1798 |
+
T = self._get_T()
|
| 1799 |
+
E_exp = _expand_E(self.E, tuple(self._T_shape.tolist()), self.group_size).squeeze(0)
|
| 1800 |
+
S = torch.exp2(E_exp.float())
|
| 1801 |
+
weight = S * T.float()
|
| 1802 |
+
return weight * (x * inv_rms)
|
| 1803 |
+
|
| 1804 |
+
def ternary_step(self, lr=1, accum_threshold=3):
|
| 1805 |
+
pass
|
| 1806 |
+
|
| 1807 |
+
def update_E(self, lr=1, loss_signal=None):
|
| 1808 |
+
pass
|
| 1809 |
+
|
| 1810 |
+
def extra_repr(self):
|
| 1811 |
+
return f"dim={self.dim}, tscale_type={self.tscale_type.name}"
|
arbitor/kernel/triton_video.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton kernels for video denoising (used by VideoHead)."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from math import ceil as _ceil
|
| 5 |
+
|
| 6 |
+
from .ternary_scale import _HAS_TRITON
|
| 7 |
+
|
| 8 |
+
if _HAS_TRITON:
|
| 9 |
+
import triton
|
| 10 |
+
import triton.language as tl
|
| 11 |
+
|
| 12 |
+
@triton.jit
|
| 13 |
+
def _triton_video_denoise_fwd_kernel(
|
| 14 |
+
latent, pred_noise, out,
|
| 15 |
+
TOTAL: tl.constexpr, ALPHA: tl.constexpr, BLOCK: tl.constexpr,
|
| 16 |
+
):
|
| 17 |
+
offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
| 18 |
+
mask = offsets < TOTAL
|
| 19 |
+
l = tl.load(latent + offsets, mask=mask, other=0.0)
|
| 20 |
+
p = tl.load(pred_noise + offsets, mask=mask, other=0.0)
|
| 21 |
+
beta = 1.0 - ALPHA
|
| 22 |
+
inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
|
| 23 |
+
tl.store(out + offsets, (l - beta * p) * inv_sqrt, mask=mask)
|
| 24 |
+
|
| 25 |
+
@triton.jit
|
| 26 |
+
def _triton_video_denoise_bwd_kernel(
|
| 27 |
+
grad_out, grad_latent, grad_pred,
|
| 28 |
+
TOTAL: tl.constexpr, ALPHA: tl.constexpr, BLOCK: tl.constexpr,
|
| 29 |
+
):
|
| 30 |
+
offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
| 31 |
+
mask = offsets < TOTAL
|
| 32 |
+
g = tl.load(grad_out + offsets, mask=mask, other=0.0)
|
| 33 |
+
beta = 1.0 - ALPHA
|
| 34 |
+
inv_sqrt = 1.0 / tl.sqrt(ALPHA + 0.00000001)
|
| 35 |
+
tl.store(grad_latent + offsets, g * inv_sqrt, mask=mask)
|
| 36 |
+
tl.store(grad_pred + offsets, -beta * g * inv_sqrt, mask=mask)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class _TritonVideoDenoiseFn(torch.autograd.Function):
|
| 40 |
+
@staticmethod
|
| 41 |
+
def forward(ctx, latent, pred_noise, alpha):
|
| 42 |
+
latent_c = latent.contiguous()
|
| 43 |
+
pred_c = pred_noise.contiguous()
|
| 44 |
+
out = torch.empty_like(latent_c)
|
| 45 |
+
total = latent_c.numel()
|
| 46 |
+
block = 256
|
| 47 |
+
grid = (_ceil_div(total, block),)
|
| 48 |
+
alpha_f = float(alpha)
|
| 49 |
+
_triton_video_denoise_fwd_kernel[grid](
|
| 50 |
+
latent_c, pred_c, out,
|
| 51 |
+
total, alpha_f, BLOCK=block,
|
| 52 |
+
)
|
| 53 |
+
ctx.alpha = alpha_f
|
| 54 |
+
ctx.shape = latent.shape
|
| 55 |
+
return out.reshape_as(latent)
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def backward(ctx, grad_out):
|
| 59 |
+
grad_c = grad_out.contiguous()
|
| 60 |
+
grad_latent = torch.empty_like(grad_c)
|
| 61 |
+
grad_pred = torch.empty_like(grad_c)
|
| 62 |
+
total = grad_c.numel()
|
| 63 |
+
block = 256
|
| 64 |
+
grid = (_ceil_div(total, block),)
|
| 65 |
+
_triton_video_denoise_bwd_kernel[grid](
|
| 66 |
+
grad_c, grad_latent, grad_pred,
|
| 67 |
+
total, ctx.alpha, BLOCK=block,
|
| 68 |
+
)
|
| 69 |
+
return grad_latent.reshape(ctx.shape), grad_pred.reshape(ctx.shape), None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def video_denoise_step(latent, pred_noise, alpha):
|
| 73 |
+
if _HAS_TRITON and latent.is_cuda and pred_noise.is_cuda and _TritonVideoDenoiseFn is not None:
|
| 74 |
+
return _TritonVideoDenoiseFn.apply(latent, pred_noise, alpha)
|
| 75 |
+
return (latent - (1 - alpha) * pred_noise) / (alpha ** 0.5 + 1e-8)
|
arbitor/main.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ARB β Any Relational Bit. Core model assembly."""
|
| 2 |
+
import warnings
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from math import ceil as _ceil
|
| 7 |
+
|
| 8 |
+
_ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
|
| 9 |
+
|
| 10 |
+
from .config import VOCAB, HIDDEN_DIM, SPECIAL_VOCAB, CTX, THRESHOLD, CODEBOOK_DIM, CODEBOOK_SIZE, KV_LEDGER_SIZE, KQ_CACHE_SIZE, MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES, MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES, MG_TOP_K
|
| 11 |
+
from .kernel.ternary_scale import TScaleType, TernaryScaleTensor, TernaryRMSNorm, _HAS_TRITON
|
| 12 |
+
try:
|
| 13 |
+
from .kernel.ternary_scale import _triton_apply_accumulated_flips
|
| 14 |
+
except ImportError:
|
| 15 |
+
_triton_apply_accumulated_flips = None
|
| 16 |
+
from .converters.convert_to_ternary8 import pack_ternary
|
| 17 |
+
try:
|
| 18 |
+
from .kernel.ternary_scale import _TritonTernaryEmbedFn
|
| 19 |
+
except ImportError:
|
| 20 |
+
_TritonTernaryEmbedFn = None
|
| 21 |
+
from .sequencers import ByteEmbedding, MultimodalSequencer
|
| 22 |
+
from .vq import SharedVQ
|
| 23 |
+
from .components import (
|
| 24 |
+
ByteHead, OutputRouter,
|
| 25 |
+
MemGram, LossComponents, LossWeights,
|
| 26 |
+
CompositeProposalHead, MoEGraph,
|
| 27 |
+
)
|
| 28 |
+
from .decoders import VideoHead, TalkerHead
|
| 29 |
+
from .components import _BOUNDARY_TOKEN_MAP as _BOUNDARY_MAP
|
| 30 |
+
from .attention import KVLedger, KQCache, ContextAttentionScheduler
|
| 31 |
+
from .kernel.flash_vq import FlashVQCodebook
|
| 32 |
+
def _extract_boundary_from_input(x):
|
| 33 |
+
if x.dim() != 2:
|
| 34 |
+
return None
|
| 35 |
+
first_token = x[0, 0].item()
|
| 36 |
+
if first_token in _BOUNDARY_MAP:
|
| 37 |
+
return first_token
|
| 38 |
+
for tok in x[0].tolist():
|
| 39 |
+
if tok in _BOUNDARY_MAP:
|
| 40 |
+
return tok
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ARBModel(nn.Module):
|
| 45 |
+
def __init__(self, tscale_type=TScaleType.T32, threshold=THRESHOLD,
|
| 46 |
+
max_graph_hops=4, max_moe_iters=4, halt_threshold=0.99,
|
| 47 |
+
enable_image=False, enable_audio=False, enable_vq=True, enable_graph=True,
|
| 48 |
+
enable_memory_modules=False, enable_moe=True,
|
| 49 |
+
shared_vq_size=None, kgvq_codebook_size=None,
|
| 50 |
+
enable_attention=True, enable_output_router=True,
|
| 51 |
+
enable_video_output=True, enable_talker_output=True):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.image_enabled = enable_image
|
| 54 |
+
self.audio_enabled = enable_audio
|
| 55 |
+
self.embedding = ByteEmbedding(tscale_type=tscale_type)
|
| 56 |
+
self.multimodal_sequencer = MultimodalSequencer(
|
| 57 |
+
tscale_type=tscale_type,
|
| 58 |
+
enable_text=True, enable_image=enable_image, enable_audio=enable_audio,
|
| 59 |
+
)
|
| 60 |
+
self.text_sequencer = self.multimodal_sequencer.text
|
| 61 |
+
self.image_sequencer = self.multimodal_sequencer.image
|
| 62 |
+
self.audio_sequencer = self.multimodal_sequencer.audio
|
| 63 |
+
self.vq_enabled = enable_vq
|
| 64 |
+
self.bridge = SharedVQ(
|
| 65 |
+
codebook_size=shared_vq_size,
|
| 66 |
+
tscale_type=tscale_type, enable_image=enable_image, enable_audio=enable_audio,
|
| 67 |
+
) if enable_vq else None
|
| 68 |
+
self.vq_to_trigram = TernaryScaleTensor(CODEBOOK_DIM, HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
|
| 69 |
+
self.vq_to_trigram_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
|
| 70 |
+
self.graph_enabled = enable_graph and enable_vq
|
| 71 |
+
graph_vocab_size = self.bridge.total_codebook_size if self.graph_enabled else None
|
| 72 |
+
self.threshold = threshold
|
| 73 |
+
self.moegraph = MoEGraph(
|
| 74 |
+
trigram_dim=HIDDEN_DIM, codebook_size=graph_vocab_size or CODEBOOK_SIZE,
|
| 75 |
+
max_iters=max_moe_iters, halt_threshold=halt_threshold,
|
| 76 |
+
top_k=MG_TOP_K,
|
| 77 |
+
) if self.graph_enabled else None
|
| 78 |
+
self.byte_head = ByteHead(tscale_type=tscale_type)
|
| 79 |
+
# Composite motif generation (Phase 17)
|
| 80 |
+
self.composite_head = CompositeProposalHead(
|
| 81 |
+
dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
|
| 82 |
+
k_max=K_MAX_COMPOSITES, codebook_size=kgvq_codebook_size or KGVQ_CODEBOOK_SIZE,
|
| 83 |
+
tscale_type=tscale_type,
|
| 84 |
+
) if self.graph_enabled else None
|
| 85 |
+
self.output_router = OutputRouter(tscale_type=tscale_type, depth=3) if enable_output_router else None
|
| 86 |
+
self.video_head = VideoHead(tscale_type=tscale_type) if enable_video_output else None
|
| 87 |
+
self.talker_head = TalkerHead(tscale_type=tscale_type) if enable_talker_output else None
|
| 88 |
+
self.memgram = MemGram(
|
| 89 |
+
struct_primes=MEMGRAM_STRUCT_PRIMES,
|
| 90 |
+
conv_primes=MEMGRAM_CONV_PRIMES,
|
| 91 |
+
embed_dim=MEMGRAM_EMBED_DIM, key_dim=MEMGRAM_KEY_DIM, hidden_dim=HIDDEN_DIM,
|
| 92 |
+
) if enable_memory_modules else None
|
| 93 |
+
self.memgram_enabled = self.memgram is not None
|
| 94 |
+
|
| 95 |
+
# KV Ledger + Attention (Phase 16 β replaces LSTM)
|
| 96 |
+
self.kv_ledger = KVLedger(max_size=KV_LEDGER_SIZE) if enable_attention else None
|
| 97 |
+
self.kq_cache = KQCache(max_size=KQ_CACHE_SIZE) if enable_attention else None
|
| 98 |
+
self.attention = ContextAttentionScheduler(dim=HIDDEN_DIM) if enable_attention else None
|
| 99 |
+
self.attention_enabled = bool(enable_attention)
|
| 100 |
+
|
| 101 |
+
def forward(self, x, targets=None, commitment_warmup_weight=1.0,
|
| 102 |
+
act_warmup_mode=False, ponder_lambda=0.01, images=None,
|
| 103 |
+
audio=None, timestep=0, loss_weights=None, output_mode=None):
|
| 104 |
+
has_image = images is not None
|
| 105 |
+
has_audio = audio is not None
|
| 106 |
+
if has_image and (not self.image_enabled or self.image_sequencer is None):
|
| 107 |
+
raise ValueError("images provided but model has enable_image=False")
|
| 108 |
+
if has_audio and (not self.audio_enabled or self.audio_sequencer is None):
|
| 109 |
+
raise ValueError("audio provided but model has enable_audio=False")
|
| 110 |
+
|
| 111 |
+
embedded = self.embedding(x)
|
| 112 |
+
seq_inputs = {'text': embedded}
|
| 113 |
+
if has_image:
|
| 114 |
+
seq_inputs['image'] = images
|
| 115 |
+
if has_audio:
|
| 116 |
+
seq_inputs['audio'] = audio
|
| 117 |
+
seq_outputs = self.multimodal_sequencer(seq_inputs)
|
| 118 |
+
relational = seq_outputs['text']
|
| 119 |
+
|
| 120 |
+
indices_dict = {}
|
| 121 |
+
if self.vq_enabled:
|
| 122 |
+
bridge_inputs = {'text': relational}
|
| 123 |
+
if 'image' in seq_outputs:
|
| 124 |
+
bridge_inputs['image'] = seq_outputs['image']
|
| 125 |
+
if 'audio' in seq_outputs:
|
| 126 |
+
bridge_inputs['audio'] = seq_outputs['audio']
|
| 127 |
+
|
| 128 |
+
combined, vq_losses, indices_dict = self.bridge(bridge_inputs, timestep=timestep)
|
| 129 |
+
if combined is None:
|
| 130 |
+
combined = relational
|
| 131 |
+
elif combined.shape[-1] == CODEBOOK_DIM:
|
| 132 |
+
combined = self.vq_to_trigram_norm(self.vq_to_trigram(combined))
|
| 133 |
+
vq_loss = vq_losses.get('text_vq', torch.zeros((), device=x.device))
|
| 134 |
+
if 'image_vq' in vq_losses:
|
| 135 |
+
vq_loss = vq_loss + vq_losses['image_vq']
|
| 136 |
+
if 'audio_vq' in vq_losses:
|
| 137 |
+
vq_loss = vq_loss + vq_losses['audio_vq']
|
| 138 |
+
else:
|
| 139 |
+
combined = relational
|
| 140 |
+
vq_loss = torch.zeros((), device=x.device)
|
| 141 |
+
|
| 142 |
+
active_mods = ['text']
|
| 143 |
+
if has_image:
|
| 144 |
+
active_mods.append('image')
|
| 145 |
+
if has_audio:
|
| 146 |
+
active_mods.append('audio')
|
| 147 |
+
active_count = len(active_mods)
|
| 148 |
+
|
| 149 |
+
# MemGram injection (after VQ, before Graph β D92)
|
| 150 |
+
memgram_decay_reg = torch.tensor(0.0, device=x.device)
|
| 151 |
+
|
| 152 |
+
if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
|
| 153 |
+
vq_indices = indices_dict.get('text', torch.zeros(combined.shape[0], combined.shape[1], dtype=torch.long, device=x.device))
|
| 154 |
+
combined = self.memgram(
|
| 155 |
+
vq_indices=vq_indices,
|
| 156 |
+
hidden_state=combined,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
all_indices = None
|
| 160 |
+
composite_ids = None
|
| 161 |
+
composite_vq_loss = None
|
| 162 |
+
processed = combined
|
| 163 |
+
moegraph_ponder_loss = torch.tensor(0.0, device=x.device)
|
| 164 |
+
|
| 165 |
+
if self.graph_enabled and self.moegraph is not None and self.vq_enabled and vq_loss is not None:
|
| 166 |
+
self.moegraph._codebook_table = self.bridge.vq.table
|
| 167 |
+
self.moegraph._codebook_embed = None
|
| 168 |
+
|
| 169 |
+
all_indices = indices_dict.get('text', combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long))
|
| 170 |
+
if has_image and 'image' in indices_dict:
|
| 171 |
+
all_indices = torch.cat([all_indices, indices_dict['image']], dim=1)
|
| 172 |
+
if has_audio and 'audio' in indices_dict:
|
| 173 |
+
all_indices = torch.cat([all_indices, indices_dict['audio']], dim=1)
|
| 174 |
+
|
| 175 |
+
# MemGram retrieval for MoEGraph injection
|
| 176 |
+
memgram_cb = None
|
| 177 |
+
if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
|
| 178 |
+
vq_idx = indices_dict.get('text', combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long))
|
| 179 |
+
memgram_cb = self.memgram.retrieve_cb(vq_idx)
|
| 180 |
+
|
| 181 |
+
# Attention output for KV conditioning
|
| 182 |
+
attn_out = None
|
| 183 |
+
if self.attention_enabled and self.attention is not None and self.kv_ledger is not None:
|
| 184 |
+
attn_out = self.attention(combined, self.kv_ledger, kq_cache=self.kq_cache)
|
| 185 |
+
|
| 186 |
+
# MoEGraph forward (unified ACT loop)
|
| 187 |
+
processed, moegraph_ponder_loss = self.moegraph(
|
| 188 |
+
combined, all_indices,
|
| 189 |
+
attention_output=attn_out,
|
| 190 |
+
memgram_cb_output=memgram_cb,
|
| 191 |
+
threshold=self.threshold,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Composite motif generation (Phase 17)
|
| 195 |
+
if self.composite_head is not None:
|
| 196 |
+
composite_ids, composite_vq_loss, _ = self.composite_head(processed.mean(dim=1))
|
| 197 |
+
|
| 198 |
+
# Update bounded int-only KG co-occurrence state.
|
| 199 |
+
self.moegraph.update_kg_edges(all_indices)
|
| 200 |
+
|
| 201 |
+
# OutputRouter: route to appropriate head
|
| 202 |
+
if targets is not None or output_mode == "text":
|
| 203 |
+
logits = self.byte_head(processed)
|
| 204 |
+
elif output_mode == "video":
|
| 205 |
+
if self.video_head is None:
|
| 206 |
+
raise ValueError("output_mode='video' requested but video output is disabled")
|
| 207 |
+
logits = self.video_head(processed)
|
| 208 |
+
elif output_mode in {"audio", "talker"}:
|
| 209 |
+
if self.talker_head is None:
|
| 210 |
+
raise ValueError("audio/talker output requested but talker output is disabled")
|
| 211 |
+
logits = self.talker_head(processed)
|
| 212 |
+
elif self.training and self.output_router is not None:
|
| 213 |
+
route = self.output_router(processed, training=True)
|
| 214 |
+
route_weights, route_logits = route
|
| 215 |
+
logits = self.byte_head(processed)
|
| 216 |
+
elif self.output_router is not None:
|
| 217 |
+
route = self.output_router(processed, training=False)
|
| 218 |
+
if isinstance(route, torch.Tensor) and route.numel() > 0:
|
| 219 |
+
use_video = (route == 2).any() and self.video_head is not None
|
| 220 |
+
use_talk = (route == 3).any() and self.talker_head is not None
|
| 221 |
+
logits = self.video_head(processed) if use_video else \
|
| 222 |
+
self.talker_head(processed) if use_talk else \
|
| 223 |
+
self.byte_head(processed)
|
| 224 |
+
else:
|
| 225 |
+
logits = self.byte_head(processed)
|
| 226 |
+
else:
|
| 227 |
+
logits = self.byte_head(processed)
|
| 228 |
+
|
| 229 |
+
T_text = relational.shape[1]
|
| 230 |
+
if logits.dim() == 3 and logits.shape[-1] == VOCAB:
|
| 231 |
+
logits = logits[:, :T_text, :]
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
self._append_predictions_to_kv(logits.argmax(dim=-1), composite_ids=composite_ids)
|
| 234 |
+
losses = None
|
| 235 |
+
if targets is not None:
|
| 236 |
+
next_byte_logits = logits[:, :-1, :].contiguous()
|
| 237 |
+
lm_loss = F.cross_entropy(
|
| 238 |
+
next_byte_logits.view(-1, VOCAB),
|
| 239 |
+
targets.contiguous().view(-1),
|
| 240 |
+
ignore_index=SPECIAL_VOCAB["PAD"]
|
| 241 |
+
)
|
| 242 |
+
vq_component = commitment_warmup_weight * vq_loss if self.vq_enabled else None
|
| 243 |
+
losses = LossComponents(
|
| 244 |
+
lm=lm_loss,
|
| 245 |
+
vq_commitment=vq_component,
|
| 246 |
+
graph_l1=None,
|
| 247 |
+
moegraph_ponder=moegraph_ponder_loss,
|
| 248 |
+
memgram_decay_reg=memgram_decay_reg if self.memgram_enabled else None,
|
| 249 |
+
composite_vq=composite_vq_loss if self.composite_head is not None and composite_ids is not None else None,
|
| 250 |
+
weights=loss_weights if loss_weights is not None else LossWeights(),
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
return logits, losses, all_indices, None
|
| 254 |
+
|
| 255 |
+
@torch.no_grad()
|
| 256 |
+
def _append_predictions_to_kv(self, pred_ids, composite_ids=None):
|
| 257 |
+
if self.kv_ledger is None or self.kq_cache is None:
|
| 258 |
+
return
|
| 259 |
+
for b in range(pred_ids.shape[0]):
|
| 260 |
+
for t in range(pred_ids.shape[1]):
|
| 261 |
+
token_id = int(pred_ids[b, t])
|
| 262 |
+
self.kv_ledger.append(token_id)
|
| 263 |
+
self.kq_cache.append(token_id)
|
| 264 |
+
if composite_ids is None:
|
| 265 |
+
continue
|
| 266 |
+
composite_offset = self.bridge.total_codebook_size if self.vq_enabled and self.bridge is not None else 0
|
| 267 |
+
for k in range(composite_ids.shape[1]):
|
| 268 |
+
cid = int(composite_ids[b, k])
|
| 269 |
+
if cid >= 0:
|
| 270 |
+
self.kv_ledger.append(composite_offset + cid)
|
| 271 |
+
|
| 272 |
+
def _ternary_update_memory(self, accum_threshold=8, update_scales=True,
|
| 273 |
+
loss_components=None, loss_signal=None):
|
| 274 |
+
signal = loss_components.total if loss_components is not None else loss_signal
|
| 275 |
+
t_step = self._ternary_t_step(signal)
|
| 276 |
+
if signal is not None and not torch.isfinite(signal.detach()).all():
|
| 277 |
+
warnings.warn("Non-finite loss detected β skipping ternary state update",
|
| 278 |
+
RuntimeWarning, stacklevel=2)
|
| 279 |
+
self._clear_ternary_hooks()
|
| 280 |
+
self.zero_grad(set_to_none=True)
|
| 281 |
+
return
|
| 282 |
+
|
| 283 |
+
if loss_components is not None:
|
| 284 |
+
self._componentwise_ternary_backward(loss_components, t_step, update_scales, accum_threshold)
|
| 285 |
+
else:
|
| 286 |
+
self._apply_regular_ternary_hooks(accum_threshold, update_scales, t_step, loss_signal)
|
| 287 |
+
self._clear_ternary_hooks()
|
| 288 |
+
self._clear_backward_update_flags()
|
| 289 |
+
|
| 290 |
+
def prepare_ternary_backward(self, loss_signal=None, update_scales=True):
|
| 291 |
+
"""Configure streaming CUDA ternary updates before `loss.backward()`.
|
| 292 |
+
|
| 293 |
+
BigInt-scaled dense linear backward accumulates directly into int64
|
| 294 |
+
`corr_accum`, while legacy sparse tables still use int8 `T_accum`.
|
| 295 |
+
Calling this before backward lets the streaming path use the same
|
| 296 |
+
loss-scaled step that `_ternary_update_memory()` will finalize.
|
| 297 |
+
"""
|
| 298 |
+
t_step = self._ternary_t_step(loss_signal)
|
| 299 |
+
for module in self.modules():
|
| 300 |
+
if hasattr(module, "T_accum") or hasattr(module, "corr_accum"):
|
| 301 |
+
module._backward_t_accum_step = t_step
|
| 302 |
+
module._backward_update_scales = bool(update_scales)
|
| 303 |
+
module._stream_backward_updates = True
|
| 304 |
+
|
| 305 |
+
def _clear_backward_update_flags(self):
|
| 306 |
+
for module in self.modules():
|
| 307 |
+
for attr in (
|
| 308 |
+
"_backward_t_accum_step",
|
| 309 |
+
"_backward_update_scales",
|
| 310 |
+
"_stream_backward_updates",
|
| 311 |
+
"_streamed_ternary_backward",
|
| 312 |
+
"_streamed_bigint_backward",
|
| 313 |
+
):
|
| 314 |
+
if hasattr(module, attr):
|
| 315 |
+
delattr(module, attr)
|
| 316 |
+
|
| 317 |
+
@staticmethod
|
| 318 |
+
def _ternary_t_step(loss_signal):
|
| 319 |
+
return 1
|
| 320 |
+
|
| 321 |
+
def _clear_ternary_hooks(self):
|
| 322 |
+
base_names = [
|
| 323 |
+
"_hook_grad_T_sign", "_hook_grad_2d", "_hook_x_2d", "_hook_T",
|
| 324 |
+
"_hook_sparse_indices", "_hook_sparse_grad_sign", "_hook_sparse_T",
|
| 325 |
+
]
|
| 326 |
+
for module in self.modules():
|
| 327 |
+
if hasattr(module, "_T_accum_fp"):
|
| 328 |
+
delattr(module, "_T_accum_fp")
|
| 329 |
+
for hook_name in base_names:
|
| 330 |
+
if hasattr(module, hook_name):
|
| 331 |
+
delattr(module, hook_name)
|
| 332 |
+
for hook_name in list(vars(module).keys()):
|
| 333 |
+
if hook_name.startswith((
|
| 334 |
+
"_hook_grad_T_sign_", "_hook_grad_2d_", "_hook_x_2d_", "_hook_T_",
|
| 335 |
+
"_hook_sparse_indices_", "_hook_sparse_grad_sign_", "_hook_sparse_T_",
|
| 336 |
+
)):
|
| 337 |
+
delattr(module, hook_name)
|
| 338 |
+
|
| 339 |
+
def _componentwise_ternary_backward(self, loss_components, t_step, update_scales, accum_threshold):
|
| 340 |
+
from arbitor.kernel.ternary_scale import _COMPONENT_CONTEXT
|
| 341 |
+
|
| 342 |
+
self.prepare_ternary_backward(loss_components.total, update_scales=update_scales)
|
| 343 |
+
active = [(n, t, w) for n, t, w in loss_components.active_fields
|
| 344 |
+
if t is not None and t.dim() == 0 and t.requires_grad and float(w) != 0.0]
|
| 345 |
+
for idx, (name, comp_tensor, weight) in enumerate(active):
|
| 346 |
+
retain = idx < len(active) - 1
|
| 347 |
+
_COMPONENT_CONTEXT.set(name, weight)
|
| 348 |
+
try:
|
| 349 |
+
comp_tensor.backward(retain_graph=retain)
|
| 350 |
+
finally:
|
| 351 |
+
_COMPONENT_CONTEXT.clear()
|
| 352 |
+
self._consume_component_hooks(name, weight, t_step, update_scales, accum_threshold)
|
| 353 |
+
|
| 354 |
+
with torch.no_grad():
|
| 355 |
+
for module in self.modules():
|
| 356 |
+
if self._is_large_sparse_embedding(module):
|
| 357 |
+
continue
|
| 358 |
+
if update_scales:
|
| 359 |
+
self._step_E_from_accum(module)
|
| 360 |
+
self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
|
| 361 |
+
|
| 362 |
+
def _consume_component_hooks(self, name, weight, t_step, update_scales, accum_threshold):
|
| 363 |
+
for module in self.modules():
|
| 364 |
+
sparse_idx_key = f"_hook_sparse_indices_{name}"
|
| 365 |
+
sparse_grad_key = f"_hook_sparse_grad_sign_{name}"
|
| 366 |
+
sparse_t_key = f"_hook_sparse_T_{name}"
|
| 367 |
+
if hasattr(module, sparse_idx_key) and hasattr(module, sparse_grad_key):
|
| 368 |
+
setattr(module, "_hook_sparse_indices", getattr(module, sparse_idx_key))
|
| 369 |
+
setattr(module, "_hook_sparse_grad_sign", getattr(module, sparse_grad_key))
|
| 370 |
+
if hasattr(module, sparse_t_key):
|
| 371 |
+
setattr(module, "_hook_sparse_T", getattr(module, sparse_t_key))
|
| 372 |
+
if update_scales and hasattr(module, "update_E"):
|
| 373 |
+
module._e_accum_threshold = 8
|
| 374 |
+
module.update_E()
|
| 375 |
+
if hasattr(module, "T_accum"):
|
| 376 |
+
module._t_accum_step = max(1, int(round(abs(float(weight)) * t_step)))
|
| 377 |
+
if hasattr(module, "ternary_step"):
|
| 378 |
+
module.ternary_step(accum_threshold=accum_threshold)
|
| 379 |
+
for key in (sparse_idx_key, sparse_grad_key, sparse_t_key):
|
| 380 |
+
if hasattr(module, key):
|
| 381 |
+
delattr(module, key)
|
| 382 |
+
continue
|
| 383 |
+
|
| 384 |
+
dense_key = f"_hook_grad_T_sign_{name}"
|
| 385 |
+
dense_t_key = f"_hook_T_{name}"
|
| 386 |
+
if hasattr(module, dense_key):
|
| 387 |
+
grad_sign = getattr(module, dense_key)
|
| 388 |
+
hook_t = getattr(module, dense_t_key, None)
|
| 389 |
+
self._accumulate_component_grad_continuous(
|
| 390 |
+
module, grad_sign, weight, t_step,
|
| 391 |
+
)
|
| 392 |
+
delattr(module, dense_key)
|
| 393 |
+
if hasattr(module, dense_t_key):
|
| 394 |
+
delattr(module, dense_t_key)
|
| 395 |
+
|
| 396 |
+
grad_key = f"_hook_grad_2d_{name}"
|
| 397 |
+
x_key = f"_hook_x_2d_{name}"
|
| 398 |
+
if not hasattr(module, grad_key) or not hasattr(module, x_key):
|
| 399 |
+
continue
|
| 400 |
+
comp_grad = getattr(module, grad_key)
|
| 401 |
+
comp_x = getattr(module, x_key)
|
| 402 |
+
if torch.isfinite(comp_grad).all() and torch.isfinite(comp_x).all():
|
| 403 |
+
raw_grad = torch.clamp(comp_grad.transpose(0, 1) @ comp_x, -10.0, 10.0)
|
| 404 |
+
self._accumulate_component_grad_continuous(
|
| 405 |
+
module, raw_grad, weight, t_step,
|
| 406 |
+
)
|
| 407 |
+
delattr(module, grad_key)
|
| 408 |
+
delattr(module, x_key)
|
| 409 |
+
|
| 410 |
+
def _accumulate_component_grad_continuous(self, module, raw_grad, weight, t_step):
|
| 411 |
+
"""Component loss accumulation without persistent float optimizer state."""
|
| 412 |
+
if not hasattr(module, "_T_shape"):
|
| 413 |
+
return
|
| 414 |
+
shape = tuple(int(x) for x in module._T_shape.tolist())
|
| 415 |
+
if tuple(raw_grad.shape) != shape:
|
| 416 |
+
return
|
| 417 |
+
with torch.no_grad():
|
| 418 |
+
step = max(1, int(round(abs(float(weight)) * t_step)))
|
| 419 |
+
if float(weight) < 0:
|
| 420 |
+
step = -step
|
| 421 |
+
if hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign"):
|
| 422 |
+
signed = raw_grad.sign().to(device=module.corr_accum.device, dtype=torch.int8)
|
| 423 |
+
module._accumulate_corr_from_grad_sign(signed, corr_step=step)
|
| 424 |
+
return
|
| 425 |
+
if not hasattr(module, "T_accum") or tuple(module.T_accum.shape) != shape:
|
| 426 |
+
return
|
| 427 |
+
if hasattr(module, "_T_accum_fp"):
|
| 428 |
+
delattr(module, "_T_accum_fp")
|
| 429 |
+
signed = raw_grad.sign().to(device=module.T_accum.device, dtype=torch.int8)
|
| 430 |
+
module.T_accum.copy_(
|
| 431 |
+
torch.clamp(
|
| 432 |
+
module.T_accum.to(torch.int16) - signed.to(torch.int16) * step,
|
| 433 |
+
-127,
|
| 434 |
+
127,
|
| 435 |
+
).to(torch.int8)
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
def _apply_regular_ternary_hooks(self, accum_threshold, update_scales, t_step, loss_signal):
|
| 439 |
+
for module in self.modules():
|
| 440 |
+
is_bigint = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
|
| 441 |
+
is_legacy = hasattr(module, "T_accum") or hasattr(module, "E_accum")
|
| 442 |
+
if is_bigint or is_legacy:
|
| 443 |
+
self._prepare_per_group_threshold(module)
|
| 444 |
+
streamed = bool(getattr(module, "_streamed_ternary_backward", False))
|
| 445 |
+
has_hook = (
|
| 446 |
+
hasattr(module, "_hook_grad_T_sign")
|
| 447 |
+
or (hasattr(module, "_hook_grad_2d") and hasattr(module, "_hook_x_2d"))
|
| 448 |
+
or (hasattr(module, "_hook_sparse_indices") and hasattr(module, "_hook_sparse_grad_sign"))
|
| 449 |
+
)
|
| 450 |
+
bigint_streamed = bool(getattr(module, "_streamed_bigint_backward", False))
|
| 451 |
+
if (streamed or bigint_streamed) and not has_hook:
|
| 452 |
+
if streamed and update_scales:
|
| 453 |
+
self._step_E_from_accum(module)
|
| 454 |
+
if streamed:
|
| 455 |
+
had_flip = self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
|
| 456 |
+
self._record_flip_health(module, had_flip)
|
| 457 |
+
if hasattr(module, "per_group_threshold"):
|
| 458 |
+
del module.per_group_threshold
|
| 459 |
+
continue
|
| 460 |
+
if has_hook:
|
| 461 |
+
if hasattr(module, "_hook_grad_T_sign") and hasattr(module, "_accumulate_corr_from_grad_sign"):
|
| 462 |
+
module._accumulate_corr_from_grad_sign(module._hook_grad_T_sign)
|
| 463 |
+
del module._hook_grad_T_sign
|
| 464 |
+
if hasattr(module, "ternary_step"):
|
| 465 |
+
module.ternary_step(accum_threshold=accum_threshold)
|
| 466 |
+
if hasattr(module, "per_group_threshold"):
|
| 467 |
+
del module.per_group_threshold
|
| 468 |
+
|
| 469 |
+
def _prepare_per_group_threshold(self, module):
|
| 470 |
+
if self._is_large_sparse_embedding(module):
|
| 471 |
+
module.per_group_threshold = None
|
| 472 |
+
return
|
| 473 |
+
if hasattr(module, "corr_accum") and not hasattr(module, "T_accum"):
|
| 474 |
+
module.per_group_threshold = None
|
| 475 |
+
return
|
| 476 |
+
if not hasattr(module, "E") or not hasattr(module, "_T_shape"):
|
| 477 |
+
module.per_group_threshold = None
|
| 478 |
+
return
|
| 479 |
+
shape = tuple(int(x) for x in module._T_shape.tolist())
|
| 480 |
+
out_dim, in_dim = shape
|
| 481 |
+
gpr = _ceil_div(in_dim, module.group_size)
|
| 482 |
+
E_view = module.E.view(out_dim, gpr).float()
|
| 483 |
+
threshold_g = 8.0 + 0.25 * torch.min(E_view.abs(), torch.tensor(32.0, device=E_view.device))
|
| 484 |
+
module.per_group_threshold = torch.clamp(threshold_g, max=16.0).to(torch.int8).reshape(-1)
|
| 485 |
+
|
| 486 |
+
@staticmethod
|
| 487 |
+
def _is_large_sparse_embedding(module):
|
| 488 |
+
return (
|
| 489 |
+
hasattr(module, "num_embeddings")
|
| 490 |
+
and hasattr(module, "sparse_threshold")
|
| 491 |
+
and module.num_embeddings >= module.sparse_threshold
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
@staticmethod
|
| 495 |
+
def _step_E_from_accum(module):
|
| 496 |
+
if hasattr(module, "corr_accum"):
|
| 497 |
+
return # BigInt modules don't use E_accum threshold flips
|
| 498 |
+
if not hasattr(module, "E") or not hasattr(module, "E_accum"):
|
| 499 |
+
return
|
| 500 |
+
threshold = int(getattr(module, "_e_accum_threshold", 8))
|
| 501 |
+
accum = module.E_accum.to(torch.int16)
|
| 502 |
+
step = torch.where(
|
| 503 |
+
accum >= threshold,
|
| 504 |
+
torch.ones_like(accum, dtype=torch.int16),
|
| 505 |
+
torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
|
| 506 |
+
)
|
| 507 |
+
if step.any():
|
| 508 |
+
module.E = torch.clamp(module.E.to(torch.int16) + step, -128, 127).to(torch.int8)
|
| 509 |
+
module.E_accum = (accum - step * threshold).to(torch.int8)
|
| 510 |
+
|
| 511 |
+
@staticmethod
|
| 512 |
+
def _apply_accumulated_flips(module, accum_threshold=3):
|
| 513 |
+
"""Packed-byte carry: when T_accum crosses Β±1, move trit by Β±1 via Β±3^pos."""
|
| 514 |
+
if not hasattr(module, "T_accum") or not hasattr(module, "T_packed") or not hasattr(module, "_T_shape"):
|
| 515 |
+
return False
|
| 516 |
+
shape = tuple(int(x) for x in module._T_shape.tolist())
|
| 517 |
+
if tuple(module.T_accum.shape) != shape:
|
| 518 |
+
return False
|
| 519 |
+
carry_up = module.T_accum > 1
|
| 520 |
+
carry_down = module.T_accum < -1
|
| 521 |
+
if not carry_up.any() and not carry_down.any():
|
| 522 |
+
return False
|
| 523 |
+
dev = module.T_packed.device
|
| 524 |
+
out_dim, in_dim = shape
|
| 525 |
+
pows = torch.tensor([1, 3, 9, 27, 81], device=dev, dtype=torch.int16)
|
| 526 |
+
pk = module.T_packed.to(torch.int16).clone()
|
| 527 |
+
for p in range(5):
|
| 528 |
+
if p >= in_dim:
|
| 529 |
+
continue
|
| 530 |
+
cols = torch.arange(p, in_dim, 5, device=dev)
|
| 531 |
+
if cols.numel() == 0:
|
| 532 |
+
continue
|
| 533 |
+
is_up = carry_up[:, cols]
|
| 534 |
+
is_dn = carry_down[:, cols]
|
| 535 |
+
if not is_up.any() and not is_dn.any():
|
| 536 |
+
continue
|
| 537 |
+
rows_2d = torch.arange(out_dim, device=dev)[:, None]
|
| 538 |
+
lin_idx = rows_2d * in_dim + cols[None, :]
|
| 539 |
+
byte_idx = lin_idx // 5
|
| 540 |
+
pv = pk[byte_idx]
|
| 541 |
+
p_up = (pv + pows[p]).clamp(0, 242)
|
| 542 |
+
p_dn = (pv - pows[p]).clamp(0, 242)
|
| 543 |
+
pk[byte_idx] = torch.where(is_up, p_up, torch.where(is_dn, p_dn, pv))
|
| 544 |
+
module.T_packed = pk.to(torch.uint8)
|
| 545 |
+
# Reset T_accum to 0 on carry so W = T_accum Γ T doesn't jump
|
| 546 |
+
mask = carry_up | carry_down
|
| 547 |
+
module.T_accum[mask] = torch.zeros_like(module.T_accum[mask])
|
| 548 |
+
return True
|
| 549 |
+
|
| 550 |
+
@staticmethod
|
| 551 |
+
def _record_flip_health(module, had_flip):
|
| 552 |
+
if not hasattr(module, "T_accum"):
|
| 553 |
+
return
|
| 554 |
+
steps_since = getattr(module, "_steps_since_flip", 0)
|
| 555 |
+
module._steps_since_flip = 0 if had_flip else steps_since + 1
|
| 556 |
+
module._had_flip = False
|
| 557 |
+
|
| 558 |
+
def generate(self, idx, max_new_token, temperature=1.0, images=None, audio=None,
|
| 559 |
+
conversation_id=None, top_k=None, min_new_tokens=0, return_metadata=False):
|
| 560 |
+
if self.kv_ledger is not None and self.kv_ledger.size == 0:
|
| 561 |
+
with torch.no_grad():
|
| 562 |
+
for token_id in idx.reshape(-1).tolist():
|
| 563 |
+
self.kv_ledger.append(int(token_id))
|
| 564 |
+
self.kq_cache.append(int(token_id))
|
| 565 |
+
for i in range(max_new_token):
|
| 566 |
+
idx_cond = idx[:, -CTX:]
|
| 567 |
+
logits, _, _, _ = self(idx_cond, images=images, audio=audio, timestep=i, output_mode="text")
|
| 568 |
+
last_logits = logits[:, -1, :] / temperature
|
| 569 |
+
# top-k filtering
|
| 570 |
+
if top_k is not None and top_k > 0:
|
| 571 |
+
v, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
|
| 572 |
+
kth = v[:, -1].unsqueeze(-1).expand_as(last_logits)
|
| 573 |
+
last_logits = last_logits.where(last_logits >= kth, float('-inf'))
|
| 574 |
+
probs = F.softmax(last_logits, dim=-1)
|
| 575 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 576 |
+
idx = torch.cat([idx, idx_next], dim=1)
|
| 577 |
+
# Enforce min_new_tokens (only relevant if caller truncates after generation)
|
| 578 |
+
generated = idx.shape[1] - (min_new_tokens if return_metadata else 0)
|
| 579 |
+
if return_metadata:
|
| 580 |
+
return {
|
| 581 |
+
"tokens": idx,
|
| 582 |
+
"n_generated": generated,
|
| 583 |
+
"temperature": temperature,
|
| 584 |
+
}
|
| 585 |
+
return idx
|
arbitor/optim/__init__.py
ADDED
|
File without changes
|
arbitor/optim/sign_sgd.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.optim import Optimizer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SignSGD(Optimizer):
|
| 6 |
+
def __init__(self, params, lr=1e-2, weight_decay=0.0):
|
| 7 |
+
defaults = dict(lr=lr, weight_decay=weight_decay)
|
| 8 |
+
super().__init__(params, defaults)
|
| 9 |
+
|
| 10 |
+
@torch.no_grad()
|
| 11 |
+
def step(self, closure=None):
|
| 12 |
+
loss = None
|
| 13 |
+
if closure is not None:
|
| 14 |
+
with torch.enable_grad():
|
| 15 |
+
loss = closure()
|
| 16 |
+
|
| 17 |
+
for group in self.param_groups:
|
| 18 |
+
lr = group["lr"]
|
| 19 |
+
wd = group["weight_decay"]
|
| 20 |
+
|
| 21 |
+
for p in group["params"]:
|
| 22 |
+
if p.grad is None:
|
| 23 |
+
continue
|
| 24 |
+
|
| 25 |
+
grad = p.grad
|
| 26 |
+
if grad.is_sparse:
|
| 27 |
+
grad = grad.to_dense()
|
| 28 |
+
|
| 29 |
+
update = grad.sign()
|
| 30 |
+
|
| 31 |
+
if wd > 0:
|
| 32 |
+
update = update + wd * p.sign()
|
| 33 |
+
|
| 34 |
+
p.add_(-lr * update)
|
| 35 |
+
|
| 36 |
+
return loss
|
| 37 |
+
|
| 38 |
+
@torch.no_grad()
|
| 39 |
+
def get_memory_mb(self, params=None) -> float:
|
| 40 |
+
if params is None:
|
| 41 |
+
params = []
|
| 42 |
+
for group in self.param_groups:
|
| 43 |
+
params.extend(group["params"])
|
| 44 |
+
total_bytes = sum(p.numel() * p.element_size() for p in params)
|
| 45 |
+
return total_bytes / (1024 * 1024)
|
arbitor/profiling.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Profiling utilities: torch.profiler wrapper and analysis tools.
|
| 3 |
+
|
| 4 |
+
Following D-103: profile first, optimize only hot paths.
|
| 5 |
+
Uses torch.profiler to identify training loop bottlenecks.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 15 |
+
|
| 16 |
+
from .main import ARBModel
|
| 17 |
+
from .config import VOCAB, CTX
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def profile_training(model, train_data, device, n_steps=20, warmup_steps=5,
|
| 21 |
+
top_k=10, batch_size=64, ctx=CTX):
|
| 22 |
+
"""
|
| 23 |
+
Profile N training steps using torch.profiler.
|
| 24 |
+
|
| 25 |
+
Runs profiling with CUDA + CPU activity tracing, warmup steps (no profiling),
|
| 26 |
+
then profiled steps. Returns list of top-K hot path tuples and saves JSON.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
model: ARBModel instance
|
| 30 |
+
train_data: 1D byte tensor of training data
|
| 31 |
+
device: 'cuda' or 'cpu'
|
| 32 |
+
n_steps: Number of profiled training steps
|
| 33 |
+
warmup_steps: Steps before profiling begins (no tracing)
|
| 34 |
+
top_k: Number of top operations to return
|
| 35 |
+
batch_size: Batch size for each training step
|
| 36 |
+
ctx: Context window length
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
List of dicts with keys: op_name, cuda_time_us, cpu_time_us, calls
|
| 40 |
+
"""
|
| 41 |
+
model.train()
|
| 42 |
+
prof = None
|
| 43 |
+
|
| 44 |
+
if device == "cuda":
|
| 45 |
+
prof = torch.profiler.profile(
|
| 46 |
+
activities=[
|
| 47 |
+
torch.profiler.ProfilerActivity.CPU,
|
| 48 |
+
torch.profiler.ProfilerActivity.CUDA,
|
| 49 |
+
],
|
| 50 |
+
record_shapes=True,
|
| 51 |
+
with_stack=True,
|
| 52 |
+
with_flops=True,
|
| 53 |
+
)
|
| 54 |
+
else:
|
| 55 |
+
prof = torch.profiler.profile(
|
| 56 |
+
activities=[torch.profiler.ProfilerActivity.CPU],
|
| 57 |
+
record_shapes=True,
|
| 58 |
+
with_stack=False,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Warmup steps (no profiling)
|
| 62 |
+
for _ in range(warmup_steps):
|
| 63 |
+
ix = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
|
| 64 |
+
x = torch.stack([train_data[j: j + ctx] for j in ix])
|
| 65 |
+
targets = x[:, 3:]
|
| 66 |
+
x = x.to(device)
|
| 67 |
+
targets = targets.to(device)
|
| 68 |
+
with torch.no_grad():
|
| 69 |
+
model(x, targets=targets)
|
| 70 |
+
|
| 71 |
+
# Profiled steps
|
| 72 |
+
prof.start()
|
| 73 |
+
for _ in range(n_steps):
|
| 74 |
+
ix = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
|
| 75 |
+
x = torch.stack([train_data[j: j + ctx] for j in ix])
|
| 76 |
+
targets = x[:, 3:]
|
| 77 |
+
x = x.to(device)
|
| 78 |
+
targets = targets.to(device)
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
model(x, targets=targets)
|
| 81 |
+
if device == "cuda":
|
| 82 |
+
torch.cuda.synchronize()
|
| 83 |
+
prof.stop()
|
| 84 |
+
|
| 85 |
+
# Process profiler output
|
| 86 |
+
if device == "cuda":
|
| 87 |
+
key_avg = prof.key_averages()
|
| 88 |
+
table = key_avg.table(sort_by="cuda_time_total", row_limit=top_k)
|
| 89 |
+
else:
|
| 90 |
+
key_avg = prof.key_averages()
|
| 91 |
+
table = key_avg.table(sort_by="cpu_time_total", row_limit=top_k)
|
| 92 |
+
|
| 93 |
+
# Extract top-K entries
|
| 94 |
+
events = key_avg.events() if hasattr(key_avg, 'events') else key_avg[:top_k]
|
| 95 |
+
top_results = []
|
| 96 |
+
for evt in events[:top_k]:
|
| 97 |
+
# device_time replaces deprecated cuda_time in recent PyTorch
|
| 98 |
+
cuda_t = (evt.device_time if hasattr(evt, 'device_time') and evt.device_time is not None
|
| 99 |
+
else evt.cuda_time if hasattr(evt, 'cuda_time') else 0)
|
| 100 |
+
entry = {
|
| 101 |
+
"op_name": evt.key if hasattr(evt, 'key') else str(evt),
|
| 102 |
+
"cuda_time_us": cuda_t,
|
| 103 |
+
"cpu_time_us": evt.cpu_time if hasattr(evt, 'cpu_time') else 0,
|
| 104 |
+
"calls": evt.count if hasattr(evt, 'count') else 1,
|
| 105 |
+
}
|
| 106 |
+
top_results.append(entry)
|
| 107 |
+
|
| 108 |
+
# Print summary
|
| 109 |
+
print("\n=== Profiling Results (Top-{} Hot Paths) ===".format(top_k))
|
| 110 |
+
print(table)
|
| 111 |
+
print("============================================\n")
|
| 112 |
+
|
| 113 |
+
# Save profiler output as JSON
|
| 114 |
+
prof.export_chrome_trace("/tmp/profiler_trace.json")
|
| 115 |
+
|
| 116 |
+
return top_results
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def analyze_profiler_output(prof_path):
|
| 120 |
+
"""
|
| 121 |
+
Load saved profiler JSON output and extract key insights.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
prof_path: Path to saved profiler JSON file
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
List of dicts with op_name, cuda_time_us, cpu_time_us, calls
|
| 128 |
+
"""
|
| 129 |
+
with open(prof_path, "r") as f:
|
| 130 |
+
data = json.load(f)
|
| 131 |
+
|
| 132 |
+
# Profiler JSON can be a dict with 'traceEvents' or a flat list
|
| 133 |
+
if isinstance(data, dict) and "traceEvents" in data:
|
| 134 |
+
events = data["traceEvents"]
|
| 135 |
+
elif isinstance(data, list):
|
| 136 |
+
events = data
|
| 137 |
+
else:
|
| 138 |
+
events = []
|
| 139 |
+
|
| 140 |
+
# Aggregate events by name
|
| 141 |
+
op_stats = {}
|
| 142 |
+
for evt in events:
|
| 143 |
+
if isinstance(evt, dict):
|
| 144 |
+
name = evt.get("name", "unknown")
|
| 145 |
+
dur = evt.get("dur", 0) # microseconds
|
| 146 |
+
cat = evt.get("cat", "")
|
| 147 |
+
if name not in op_stats:
|
| 148 |
+
op_stats[name] = {"cuda_time_us": 0, "cpu_time_us": 0, "calls": 0}
|
| 149 |
+
if "gpu" in cat.lower():
|
| 150 |
+
op_stats[name]["cuda_time_us"] += dur
|
| 151 |
+
elif "cpu" in cat.lower() or cat == "":
|
| 152 |
+
op_stats[name]["cpu_time_us"] += dur
|
| 153 |
+
op_stats[name]["calls"] += 1
|
| 154 |
+
|
| 155 |
+
# Sort by CUDA time descending
|
| 156 |
+
sorted_ops = sorted(
|
| 157 |
+
op_stats.items(),
|
| 158 |
+
key=lambda x: x[1]["cuda_time_us"],
|
| 159 |
+
reverse=True,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
results = []
|
| 163 |
+
for name, stats in sorted_ops:
|
| 164 |
+
results.append({
|
| 165 |
+
"op_name": name,
|
| 166 |
+
"cuda_time_us": stats["cuda_time_us"],
|
| 167 |
+
"cpu_time_us": stats["cpu_time_us"],
|
| 168 |
+
"calls": stats["calls"],
|
| 169 |
+
})
|
| 170 |
+
|
| 171 |
+
# Print formatted summary
|
| 172 |
+
print("\n=== Profiler Analysis ===")
|
| 173 |
+
print(f"{'Operation':<40} {'CUDA Time (us)':>15} {'CPU Time (us)':>15} {'Calls':>8}")
|
| 174 |
+
print("-" * 80)
|
| 175 |
+
for r in results[:20]:
|
| 176 |
+
print(f"{r['op_name']:<40} {r['cuda_time_us']:>15.0f} {r['cpu_time_us']:>15.0f} {r['calls']:>8}")
|
| 177 |
+
|
| 178 |
+
# Identify dominating patterns
|
| 179 |
+
total_cuda = sum(r["cuda_time_us"] for r in results)
|
| 180 |
+
if total_cuda > 0:
|
| 181 |
+
print("\n=== Hot Path Analysis ===")
|
| 182 |
+
for r in results[:5]:
|
| 183 |
+
pct = (r["cuda_time_us"] / total_cuda) * 100 if total_cuda > 0 else 0
|
| 184 |
+
label = ""
|
| 185 |
+
if "vq" in r["op_name"].lower() or "flash_vq" in r["op_name"].lower():
|
| 186 |
+
label = " β VQ candidate for Triton kernel"
|
| 187 |
+
elif "moe" in r["op_name"].lower() or "scatter" in r["op_name"].lower():
|
| 188 |
+
label = " β MoE dispatch candidate"
|
| 189 |
+
elif "embed" in r["op_name"].lower() or "gather" in r["op_name"].lower():
|
| 190 |
+
label = " β Embedding gather (existing Triton kernel)"
|
| 191 |
+
elif "mm" in r["op_name"].lower() or "linear" in r["op_name"].lower():
|
| 192 |
+
label = " β General matmul (torch.compile candidate)"
|
| 193 |
+
print(f" {r['op_name']:<40} {pct:>5.1f}%{label}")
|
| 194 |
+
|
| 195 |
+
print("============================================\n")
|
| 196 |
+
return results
|
arbitor/sequencers.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sequencer modules β input processing for all modalities."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG
|
| 7 |
+
if _HAS_TRITON:
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
else:
|
| 11 |
+
triton = None
|
| 12 |
+
tl = None
|
| 13 |
+
try:
|
| 14 |
+
from .kernel.ternary_scale import _TritonTernaryEmbedFn
|
| 15 |
+
except ImportError:
|
| 16 |
+
_TritonTernaryEmbedFn = None
|
| 17 |
+
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
|
| 18 |
+
from math import ceil as _ceil
|
| 19 |
+
|
| 20 |
+
_ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
|
| 21 |
+
from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ByteEmbedding(nn.Module):
|
| 25 |
+
"""Byte-level embedding via packed ternary + BigInt correlation.
|
| 26 |
+
|
| 27 |
+
All training state is integer. T_accum/E_accum replaced by
|
| 28 |
+
corr_accum (int64 per group, never clips or resets).
|
| 29 |
+
|
| 30 |
+
S = 2^(E + K Γ mean_corr) where mean_corr = corr_accum / (step Γ gs)
|
| 31 |
+
"""
|
| 32 |
+
def __init__(self, tscale_type=TScaleType.T32):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.tscale_type = tscale_type
|
| 35 |
+
self.threshold = 0.05
|
| 36 |
+
self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
|
| 37 |
+
shape = (VOCAB, EMBEDDING_DIM)
|
| 38 |
+
|
| 39 |
+
init_std = 0.02
|
| 40 |
+
init_threshold = min(self.threshold, 0.5 * init_std)
|
| 41 |
+
self.threshold = init_threshold
|
| 42 |
+
w_init = torch.randn(VOCAB, EMBEDDING_DIM) * init_std
|
| 43 |
+
T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
|
| 44 |
+
packed_T, T_shape, T_pad = pack_ternary(T_init)
|
| 45 |
+
|
| 46 |
+
self.register_buffer("T_packed", packed_T)
|
| 47 |
+
self.register_buffer("_T_shape", torch.tensor([VOCAB, EMBEDDING_DIM], dtype=torch.long))
|
| 48 |
+
self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
|
| 49 |
+
|
| 50 |
+
out_dim, in_dim = shape
|
| 51 |
+
gpr = _ceil_div(in_dim, self.group_size)
|
| 52 |
+
total_in = gpr * self.group_size
|
| 53 |
+
padded = torch.zeros(out_dim, total_in)
|
| 54 |
+
abs_w = w_init.abs()
|
| 55 |
+
padded[:, :in_dim] = abs_w
|
| 56 |
+
grouped = padded.view(out_dim, gpr, self.group_size)
|
| 57 |
+
grp_means = grouped.mean(dim=2)
|
| 58 |
+
E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
|
| 59 |
+
self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
|
| 60 |
+
|
| 61 |
+
# BigInt correlation accumulator (replaces T_accum + E_accum)
|
| 62 |
+
n_grp = out_dim * gpr
|
| 63 |
+
self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
|
| 64 |
+
self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
|
| 65 |
+
|
| 66 |
+
self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type)
|
| 67 |
+
|
| 68 |
+
def _get_T(self):
|
| 69 |
+
return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
|
| 70 |
+
|
| 71 |
+
def _get_S(self):
|
| 72 |
+
gpr = _ceil_div(EMBEDDING_DIM, self.group_size)
|
| 73 |
+
e_adj = self.E.float()
|
| 74 |
+
step = int(self.step_counter.item())
|
| 75 |
+
if step > 0:
|
| 76 |
+
from .kernel.ternary_scale import _bigint_corr_strength
|
| 77 |
+
denom = max(step * self.group_size, 1)
|
| 78 |
+
e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
|
| 79 |
+
E_exp = e_adj.view(VOCAB, gpr).repeat_interleave(self.group_size, dim=1)
|
| 80 |
+
if E_exp.shape[1] > EMBEDDING_DIM:
|
| 81 |
+
E_exp = E_exp[:, :EMBEDDING_DIM]
|
| 82 |
+
return torch.exp2(E_exp)
|
| 83 |
+
|
| 84 |
+
@torch.no_grad()
|
| 85 |
+
def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
|
| 86 |
+
if grad_sign is None:
|
| 87 |
+
return
|
| 88 |
+
shape = tuple(self._T_shape.tolist())
|
| 89 |
+
out_dim, in_dim = shape
|
| 90 |
+
if tuple(grad_sign.shape) != shape:
|
| 91 |
+
return
|
| 92 |
+
gs = self.group_size
|
| 93 |
+
T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
|
| 94 |
+
signed = grad_sign.to(torch.int16) * T
|
| 95 |
+
gpr = _ceil_div(in_dim, gs)
|
| 96 |
+
total_in = gpr * gs
|
| 97 |
+
if total_in > in_dim:
|
| 98 |
+
signed = F.pad(signed, (0, total_in - in_dim))
|
| 99 |
+
score = signed.view(out_dim, gpr, gs).sum(dim=2, dtype=torch.int16)
|
| 100 |
+
self.corr_accum -= score.flatten().to(dtype=torch.int64) * int(corr_step)
|
| 101 |
+
self.step_counter += abs(int(corr_step))
|
| 102 |
+
|
| 103 |
+
def forward(self, x):
|
| 104 |
+
if x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
|
| 105 |
+
_dummy = torch.zeros(1, device=x.device, requires_grad=True)
|
| 106 |
+
emb = _TritonTernaryEmbedFn.apply(x, _dummy, self)
|
| 107 |
+
return self.norm(emb)
|
| 108 |
+
T = self._get_T()
|
| 109 |
+
S = self._get_S()
|
| 110 |
+
w_eff = S * T.float()
|
| 111 |
+
w_eff_grad = w_eff.detach().requires_grad_(True)
|
| 112 |
+
|
| 113 |
+
def capture_w_grad(grad_w):
|
| 114 |
+
self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
|
| 115 |
+
|
| 116 |
+
w_eff_grad.register_hook(capture_w_grad)
|
| 117 |
+
out = self.norm(F.embedding(x, w_eff_grad))
|
| 118 |
+
return out
|
| 119 |
+
|
| 120 |
+
def ternary_step(self, accum_threshold=3):
|
| 121 |
+
if hasattr(self, "_hook_grad_T_sign"):
|
| 122 |
+
if hasattr(self, "_accumulate_corr_from_grad_sign"):
|
| 123 |
+
self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
|
| 124 |
+
del self._hook_grad_T_sign
|
| 125 |
+
|
| 126 |
+
def update_E(self, loss_signal=None):
|
| 127 |
+
pass # E is fixed; S adjusted via corr_accum
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class Sequencer(nn.Module):
|
| 131 |
+
def __init__(self, modality, window_size, tscale_type=TScaleType.T32):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.modality = modality
|
| 134 |
+
self.window_size = window_size
|
| 135 |
+
self.tscale_type = tscale_type
|
| 136 |
+
|
| 137 |
+
def forward(self, x):
|
| 138 |
+
raise NotImplementedError
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class TextSequencer(Sequencer):
|
| 142 |
+
def __init__(self, tscale_type=TScaleType.T32):
|
| 143 |
+
super().__init__(modality='text', window_size=3, tscale_type=tscale_type)
|
| 144 |
+
self.projection = TernaryScaleTensor(EMBEDDING_DIM * self.window_size, HIDDEN_DIM, tscale_type=tscale_type)
|
| 145 |
+
self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
|
| 146 |
+
|
| 147 |
+
def forward(self, x):
|
| 148 |
+
trigrams = x.unfold(dimension=1, size=self.window_size, step=1)
|
| 149 |
+
trigrams = rearrange(trigrams, 'b t d w -> b t (d w)')
|
| 150 |
+
relational = self.projection(trigrams)
|
| 151 |
+
return self.norm(relational)
|
| 152 |
+
class VAE2DSequencer(Sequencer):
|
| 153 |
+
def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
|
| 154 |
+
super().__init__(modality='image', window_size=1, tscale_type=tscale_type)
|
| 155 |
+
from .encoders.vae2d import load_vae2d as _load_vae2d
|
| 156 |
+
self.vae = _load_vae2d(device=device, quantize=quantize)
|
| 157 |
+
self.vae_device = torch.device(device)
|
| 158 |
+
self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
|
| 159 |
+
self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
|
| 160 |
+
|
| 161 |
+
def forward(self, x):
|
| 162 |
+
if x.device != self.vae_device:
|
| 163 |
+
x = x.to(self.vae_device)
|
| 164 |
+
latent = self.vae(x)
|
| 165 |
+
tokens = rearrange(latent, 'b c h w -> b (h w) c')
|
| 166 |
+
out = self.project(tokens)
|
| 167 |
+
return self.norm(out)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class VAEAudioSequencer(Sequencer):
|
| 171 |
+
def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
|
| 172 |
+
super().__init__(modality='audio', window_size=1, tscale_type=tscale_type)
|
| 173 |
+
from .encoders.vae2d import load_vae2d as _load_vae2d
|
| 174 |
+
from .encoders.mel_frontend import MelSpectrogram3Band as _Mel3Band
|
| 175 |
+
self.vae = _load_vae2d(device=device, quantize=quantize)
|
| 176 |
+
self.vae_device = torch.device(device)
|
| 177 |
+
self.mel = _Mel3Band(sample_rate=AUDIO_SR)
|
| 178 |
+
self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
|
| 179 |
+
self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
|
| 180 |
+
|
| 181 |
+
def forward(self, waveform):
|
| 182 |
+
if waveform.dim() == 1:
|
| 183 |
+
waveform = waveform.unsqueeze(0)
|
| 184 |
+
elif waveform.dim() == 3:
|
| 185 |
+
if waveform.shape[1] == 1:
|
| 186 |
+
waveform = waveform.squeeze(1)
|
| 187 |
+
else:
|
| 188 |
+
waveform = waveform.mean(dim=1)
|
| 189 |
+
spec = self.mel(waveform)
|
| 190 |
+
if spec.device != self.vae_device:
|
| 191 |
+
spec = spec.to(self.vae_device)
|
| 192 |
+
latent = self.vae(spec)
|
| 193 |
+
tokens = rearrange(latent, 'b c h w -> b (h w) c')
|
| 194 |
+
out = self.project(tokens)
|
| 195 |
+
return self.norm(out)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class MultimodalSequencer(nn.Module):
|
| 199 |
+
def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True, enable_audio=True):
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.text = TextSequencer(tscale_type=tscale_type) if enable_text else None
|
| 202 |
+
self.image = VAE2DSequencer(tscale_type=tscale_type) if enable_image else None
|
| 203 |
+
self.audio = VAEAudioSequencer(tscale_type=tscale_type) if enable_audio else None
|
| 204 |
+
self.enabled_modalities = []
|
| 205 |
+
if enable_text:
|
| 206 |
+
self.enabled_modalities.append('text')
|
| 207 |
+
if enable_image:
|
| 208 |
+
self.enabled_modalities.append('image')
|
| 209 |
+
if enable_audio:
|
| 210 |
+
self.enabled_modalities.append('audio')
|
| 211 |
+
|
| 212 |
+
def forward(self, modality_inputs):
|
| 213 |
+
outputs = {}
|
| 214 |
+
for mod in self.enabled_modalities:
|
| 215 |
+
seq = getattr(self, mod)
|
| 216 |
+
if mod in modality_inputs and modality_inputs[mod] is not None and seq is not None:
|
| 217 |
+
outputs[mod] = seq(modality_inputs[mod])
|
| 218 |
+
return outputs
|
arbitor/vq.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""VQ modules β vector quantization adapters."""
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
|
| 7 |
+
from .components import TernaryVQCodebook
|
| 8 |
+
from .config import EMBEDDING_DIM, HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE, TIMESTAMP_MAX_PERIOD
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SharedVQ(nn.Module):
|
| 12 |
+
"""Single shared VQ codebook for all modalities (10M entries).
|
| 13 |
+
|
| 14 |
+
Each modality projects to the shared CODEBOOK_DIM=64 space, then
|
| 15 |
+
quantizes independently through the shared codebook. Text uses
|
| 16 |
+
CODEBOOK_DIM directly.
|
| 17 |
+
|
| 18 |
+
IDs are globally unique: all modalities share the same range [0, 10M).
|
| 19 |
+
"""
|
| 20 |
+
def __init__(self, codebook_size=SHARED_VQ_SIZE, codebook_dim=CODEBOOK_DIM,
|
| 21 |
+
tscale_type=TScaleType.T32, enable_image=True, enable_audio=True):
|
| 22 |
+
super().__init__()
|
| 23 |
+
codebook_size = SHARED_VQ_SIZE if codebook_size is None else codebook_size
|
| 24 |
+
self.codebook_size = codebook_size
|
| 25 |
+
self.codebook_dim = codebook_dim
|
| 26 |
+
|
| 27 |
+
# Per-modality input projections (their_dim β CODEBOOK_DIM)
|
| 28 |
+
self.text_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
|
| 29 |
+
if enable_image:
|
| 30 |
+
self.image_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
|
| 31 |
+
if enable_audio:
|
| 32 |
+
self.audio_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
|
| 33 |
+
|
| 34 |
+
# Shared VQ codebook
|
| 35 |
+
self.vq = TernaryVQCodebook(
|
| 36 |
+
codebook_size=codebook_size,
|
| 37 |
+
codebook_dim=codebook_dim,
|
| 38 |
+
commitment_weight=1.0,
|
| 39 |
+
tscale_type=tscale_type,
|
| 40 |
+
)
|
| 41 |
+
self.modalities = ['text']
|
| 42 |
+
if enable_image:
|
| 43 |
+
self.modalities.append('image')
|
| 44 |
+
if enable_audio:
|
| 45 |
+
self.modalities.append('audio')
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def _sinusoidal_timestamp(seq_len, dim, max_period=TIMESTAMP_MAX_PERIOD, device=None):
|
| 49 |
+
freqs = torch.exp(-torch.arange(0, dim, 2, device=device).float() * (math.log(max_period) / dim))
|
| 50 |
+
t = torch.arange(seq_len, device=device).float().unsqueeze(1)
|
| 51 |
+
pe = torch.zeros(seq_len, dim, device=device)
|
| 52 |
+
pe[:, 0::2] = torch.sin(t * freqs)
|
| 53 |
+
pe[:, 1::2] = torch.cos(t * freqs)
|
| 54 |
+
return pe
|
| 55 |
+
|
| 56 |
+
def forward(self, modality_inputs, timestep=0):
|
| 57 |
+
outputs = []
|
| 58 |
+
vq_losses = {}
|
| 59 |
+
indices_dict = {}
|
| 60 |
+
for mod in self.modalities:
|
| 61 |
+
if mod not in modality_inputs or modality_inputs[mod] is None:
|
| 62 |
+
continue
|
| 63 |
+
x = modality_inputs[mod]
|
| 64 |
+
proj = getattr(self, f'{mod}_proj')
|
| 65 |
+
x_proj = proj(x)
|
| 66 |
+
quantized, idx, loss = self.vq(x_proj)
|
| 67 |
+
outputs.append(quantized)
|
| 68 |
+
vq_losses[f'{mod}_vq'] = loss
|
| 69 |
+
indices_dict[mod] = idx
|
| 70 |
+
|
| 71 |
+
combined = torch.cat(outputs, dim=1) if outputs else modality_inputs.get('text', None)
|
| 72 |
+
if combined is not None and timestep > 0:
|
| 73 |
+
ts_enc = self._sinusoidal_timestamp(combined.shape[1], combined.shape[2], device=combined.device)
|
| 74 |
+
combined = combined + ts_enc.unsqueeze(0)
|
| 75 |
+
return combined, vq_losses, indices_dict
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def total_codebook_size(self):
|
| 79 |
+
return self.codebook_size
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
|
| 82 |
+
def get_codebook_utilization(self):
|
| 83 |
+
cluster_size = self.vq.cluster_size
|
| 84 |
+
return (cluster_size > 0).float().mean().item()
|
| 85 |
+
|
| 86 |
+
@torch.no_grad()
|
| 87 |
+
def get_dead_code_count(self):
|
| 88 |
+
cluster_size = self.vq.cluster_size
|
| 89 |
+
return (cluster_size < self.vq.threshold_ema_dead_code).sum().item()
|
docs/ARB-RENAME-NOTE.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ARB System Rename
|
| 3 |
+
date: 2026-05-18
|
| 4 |
+
context: System renamed from MORPH to ARB (Any Relational Bit)
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# ARB System β Rename Note
|
| 8 |
+
|
| 9 |
+
## New Name
|
| 10 |
+
|
| 11 |
+
The system has been renamed from **MORPH** to **ARB** (Any Relational Bit).
|
| 12 |
+
|
| 13 |
+
- **ARB** = Any Relational Bit β the core ternary architecture
|
| 14 |
+
- **ARBS** = ARB System β the full software system
|
| 15 |
+
- **ARBitor** = The Python package name (`arbitor/`)
|
| 16 |
+
|
| 17 |
+
## Package Structure
|
| 18 |
+
|
| 19 |
+
All core system files now live under `arbitor/`:
|
| 20 |
+
|
| 21 |
+
```
|
| 22 |
+
models/Trigram/
|
| 23 |
+
βββ arbitor/ # Core ARB system package
|
| 24 |
+
β βββ __init__.py # Public API exports
|
| 25 |
+
β βββ trigram.py # Core model (ARBModel replaces MORPHTernaryModel)
|
| 26 |
+
β βββ tscale.py # Ternary scale tensors
|
| 27 |
+
β βββ convert_to_ternary.py # 5-trit packing
|
| 28 |
+
β βββ convert_to_ternary*.py # Legacy converters
|
| 29 |
+
β βββ flash_vq.py # FlashVQ codebook
|
| 30 |
+
β βββ ternary_audit.py # Model state auditor
|
| 31 |
+
β βββ profiling.py # Profiling utilities
|
| 32 |
+
β βββ train.py # Training pipeline
|
| 33 |
+
β βββ optim/
|
| 34 |
+
β β βββ sign_sgd.py # SignSGD optimizer
|
| 35 |
+
β βββ encoders/ # Float sidecar encoders
|
| 36 |
+
β βββ __init__.py
|
| 37 |
+
β βββ audio_codec.py
|
| 38 |
+
β βββ audio_vq_encoder.py
|
| 39 |
+
β βββ video_vae.py
|
| 40 |
+
βββ testing/ # Tests (import from arbitor)
|
| 41 |
+
βββ .planning/ # Planning docs (P0-P10 complete)
|
| 42 |
+
βββ TRUE-TERNARY-REFACTOR*.md # Architecture refactor notes
|
| 43 |
+
βββ BENCHMARK.md # Benchmark docs
|
| 44 |
+
βββ benchmark_true_ternary.py # Benchmark scripts
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## Import Changes
|
| 48 |
+
|
| 49 |
+
| Before | After |
|
| 50 |
+
|--------|-------|
|
| 51 |
+
| `from trigram import ARBModel` | `from arbitor.trigram import ARBModel` |
|
| 52 |
+
| `from tscale import TernaryScaleTensor` | `from arbitor.tscale import TernaryScaleTensor` |
|
| 53 |
+
| `from optim.sign_sgd import SignSGD` | `from arbitor.optim.sign_sgd import SignSGD` |
|
| 54 |
+
| `from encoders.video_vae import load_vae` | `from arbitor.encoders.video_vae import load_vae` |
|
| 55 |
+
| `from arbitor import ARBModel` | Shorthand via `arbitor/__init__.py` |
|
| 56 |
+
| `import trigram` | `from arbitor import trigram` |
|
| 57 |
+
|
| 58 |
+
## Class Rename
|
| 59 |
+
|
| 60 |
+
| Before | After |
|
| 61 |
+
|--------|-------|
|
| 62 |
+
| `MORPHTernaryModel` | `ARBModel` |
|
docs/arbs-tts/README.md
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ARBS Ternary Training System (TTS)
|
| 2 |
+
|
| 3 |
+
## E1TM Format β Exponent-1 Ternary Mantissa
|
| 4 |
+
|
| 5 |
+
E1TM encodes each weight group as **one int8 exponent shared across N ternary mantissas**.
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
W_eff[i] = S Γ T[i] where T[i] β {-1, 0, +1}, S = 2^{E + Ξ}
|
| 9 |
+
|
| 10 |
+
E = int8 logβ scale (persistent, per group)
|
| 11 |
+
Ξ = 4 Γ corr_accum / (step Γ gs) (from BigInt accumulator)
|
| 12 |
+
S = 2^{E+Ξ} (float32, ephemeral β created per forward, discarded)
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
### Format variants
|
| 16 |
+
|
| 17 |
+
| Name | TScaleType | T per E | gs | E bpw | T bpw | Total bpw (inf) | Precision |
|
| 18 |
+
|---|---|---|---|---|---|---|---|
|
| 19 |
+
| E1TM4 | T4 | 4 | 4 | 2.000 | 1.58 | 3.58 | Highest |
|
| 20 |
+
| E1TM6 | T6 | 6 | 6 | 1.333 | 1.58 | 2.91 | |
|
| 21 |
+
| E1TM8 | T8 | 8 | 8 | 1.000 | 1.58 | 2.58 | |
|
| 22 |
+
| E1TM16 | T16 | 16 | 16 | 0.500 | 1.58 | 2.08 | |
|
| 23 |
+
| **E1TM32** | **T32** | **32** | **32** | **0.250** | **1.58** | **1.85** | **Default** |
|
| 24 |
+
| E1TM64 | T64 | 64 | 64 | 0.125 | 1.58 | 1.71 | |
|
| 25 |
+
| E1TM96 | T96 | 96 | 96 | 0.083 | 1.58 | 1.67 | Most packed |
|
| 26 |
+
|
| 27 |
+
Higher T number = more T per E = less storage = coarser per-weight magnitude.
|
| 28 |
+
|
| 29 |
+
### Group sizes
|
| 30 |
+
|
| 31 |
+
The TScaleType name is the group size:
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
TScaleType.T4 β gs = 4 β E shared across 4 ternary mantissas
|
| 35 |
+
TScaleType.T32 β gs = 32 β E shared across 32 ternary mantissas
|
| 36 |
+
TScaleType.T96 β gs = 96 β E shared across 96 ternary mantissas
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
### Persistent training state (all integer)
|
| 40 |
+
|
| 41 |
+
| Buffer | Type | Size/weight | Role |
|
| 42 |
+
|---|---|---|---|
|
| 43 |
+
| T_packed | uint8 | 1.58 bpw | Base-3 packed ternary {-1,0,+1}, 5 trits/byte |
|
| 44 |
+
| E | int8 | 8/N bpw | Logβ scale, one per N-weight group |
|
| 45 |
+
| corr_accum | int64 | 64/N bpw | BigInt accumulator for gradient sign votes |
|
| 46 |
+
| step_counter | int64 | 0 bpw | Total steps processed |
|
| 47 |
+
|
| 48 |
+
**No float32/16 anywhere in persistent state.** Float32 ephemeral `W_eff` is created per-forward and discarded after backward.
|
| 49 |
+
|
| 50 |
+
### Why ternary over binary or int4
|
| 51 |
+
|
| 52 |
+
| Format | Values/weight | Packing efficiency | Null state |
|
| 53 |
+
|---|---|---|---|
|
| 54 |
+
| Binary | 2 | 1 bit/bw (100%) | No |
|
| 55 |
+
| Ternary | 3 | 1.58 bpw (logβ3 β 95%) | **Yes** (T=0 = null) |
|
| 56 |
+
| Int4 | 16 | 4 bpw (100%) | No |
|
| 57 |
+
|
| 58 |
+
Ternary's null state (T=0) provides structural sparsity β β38% of weights are zero, skipping matmul tiles. No other low-bit format has this property at equivalent bpw.
|
| 59 |
+
|
| 60 |
+
### The BigInt difference
|
| 61 |
+
|
| 62 |
+
Unlike conventional quantization where E is static after conversion, ARBS TTS trains **through** E via a BigInt correlation accumulator:
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
corr_accum[g] -= Ξ£ (grad_sign Γ T) # int64, never clips or resets
|
| 66 |
+
Ξ = 4 Γ corr_accum / (step Γ gs) # continuous adjustment from integer division
|
| 67 |
+
S = 2^{E + Ξ} # effective scale (ephemeral float32)
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
The division `corr_accum / (step Γ gs)` is the **Big Number Calculator** operation β it converts the accumulated integer evidence into a continuous ratio with arbitrary precision. No threshold flips, no discrete steps, no information loss.
|
| 71 |
+
|
| 72 |
+
### Training vs inference
|
| 73 |
+
|
| 74 |
+
| Phase | T_packed | E | corr_accum | step | S |
|
| 75 |
+
|---|---|---|---|---|---|
|
| 76 |
+
| Training | Read-only | Read-only | **Accumulates** | **Increments** | Computed from corr/step |
|
| 77 |
+
| Inference (Option A) | Frozen | Frozen | Frozen | Frozen | Burned into checkpoint |
|
| 78 |
+
| Inference (Option B) | Frozen | **Fused** | Discarded | Discarded | Static 2^{E_fused} |
|
| 79 |
+
|
| 80 |
+
**Option A** (export): keep corr_accum + step for continuous S.
|
| 81 |
+
**Option B** (fuse): `E_fused = round(E + 4 Γ corr_accum / (step Γ gs))` β discards corr_accum, drops to 2.6 bpw.
|
| 82 |
+
|
| 83 |
+
### Relationship to IEEE float
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
IEEE FP32: 1 sign + 8 exponent + 23 mantissa β per value
|
| 87 |
+
E1TM32: 1 exponent (int8) + 32 ternary signs β per group of 32
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
In IEEE, the exponent and mantissa belong to the same value. In E1TM, the exponent is **shared** β the mantissa is split into N independent ternary signs. The corr_accum provides sub-exponent precision beyond the int8 E, making the effective scale continuous rather than constrained to the 256 discrete `2^E` values.
|
docs/benchmarks/BENCHMARK.md
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TrueTernary Benchmark
|
| 2 |
+
|
| 3 |
+
Results from `benchmark_true_ternary.py` β comparing pure ternary training against standard methods on MORPHTernaryModel.
|
| 4 |
+
|
| 5 |
+
## Quick Start
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
cd models/Trigram
|
| 9 |
+
|
| 10 |
+
# TrueTernary (strict, 0 float params, 14M ternary weights)
|
| 11 |
+
python benchmark_true_ternary.py --configs TrueTernary --steps 200 --batch 4 --ctx 33
|
| 12 |
+
|
| 13 |
+
# Adam baseline (full model, 102M float params)
|
| 14 |
+
python benchmark_true_ternary.py --configs Adam_FP32 --steps 200 --batch 4 --ctx 33
|
| 15 |
+
|
| 16 |
+
# Compare both
|
| 17 |
+
python benchmark_true_ternary.py --configs Adam_FP32,TrueTernary --steps 200 --batch 4 --ctx 33 --reuse-base
|
| 18 |
+
|
| 19 |
+
# Training script (strict ternary)
|
| 20 |
+
python train.py --max_steps 1000 --batch_size 4 --ctx 33 --strict_ternary
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## Head-to-Head: TrueTernary vs Adam (200 steps, B=4, C=33)
|
| 24 |
+
|
| 25 |
+
| Metric | Adam_FP32 | TrueTernary |
|
| 26 |
+
|--------|-----------|-------------|
|
| 27 |
+
| **Trainable params** | 102,629,376 (float32) | **0** (pure ternary) |
|
| 28 |
+
| **Model weights** | 473.6 MB | **0.0 MB** |
|
| 29 |
+
| **Optimizer state** | 391.5 MB | **0.0 MB** |
|
| 30 |
+
| **Training state** | 473.6 + 391.5 = **865 MB** | **18.3 MB** (buffers only) |
|
| 31 |
+
| **Peak VRAM** | ~2,548 MB | **~232 MB** (includes CUDA context) |
|
| 32 |
+
| **Step time** | ~200 ms | **~131 ms** |
|
| 33 |
+
| **Final loss** | ~12.3 | **5.75** β |
|
| 34 |
+
| **Min loss** | β | **4.49** |
|
| 35 |
+
| **Converges?** | Yes (to high loss) | **Yes (near optimal: ln(288)β5.66)** |
|
| 36 |
+
|
| 37 |
+
### Key Takeaways
|
| 38 |
+
|
| 39 |
+
- **VRAM**: TrueTernary uses **~40Γ less** persistent state (18 MB vs 865 MB)
|
| 40 |
+
- **Speed**: 1.5Γ faster per step (131 ms vs 200 ms) β pure add/sub/skip, no float GEMM
|
| 41 |
+
- **Convergence**: TrueTernary reaches **5.75** (near theoretical minimum ln(288) β 5.66) β Adam stalls at **12.3**
|
| 42 |
+
- **No float params**: TrueTernary has 0 trainable float params, 0 float buffers
|
| 43 |
+
|
| 44 |
+
## TrueTernary Training Dynamics (200 steps)
|
| 45 |
+
|
| 46 |
+
The loss curve follows a characteristic 3-phase pattern:
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
Phase 1 (steps 0-15): Mass T flips from random init, loss spikes to ~90
|
| 50 |
+
Phase 2 (steps 15-80): Recovery and convergence, loss drops from ~15 to ~6
|
| 51 |
+
Phase 3 (steps 80-200): Stable convergence, loss hovers at 5.7-6.0
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
**Convergence evidence:**
|
| 55 |
+
|
| 56 |
+
| Segment | Mean Loss | Min Loss | Trend |
|
| 57 |
+
|---------|-----------|----------|-------|
|
| 58 |
+
| Steps 0-50 | 13.4 | 4.49 | High variance (T flips) |
|
| 59 |
+
| Steps 50-100 | 8.7 | 6.03 | Monotonic decline |
|
| 60 |
+
| Steps 100-150 | 6.4 | 5.69 | Approaching optimum |
|
| 61 |
+
| Steps 150-200 | **5.82** | **5.64** | **Converged** |
|
| 62 |
+
|
| 63 |
+
The minimum loss of **4.49** is well below the uniform-distribution baseline (ln(288) β 5.66), indicating the model captures meaningful byte-level patterns.
|
| 64 |
+
|
| 65 |
+
## Training State Breakdown (14M ternary weights)
|
| 66 |
+
|
| 67 |
+
| Component | Storage | Size | Role |
|
| 68 |
+
|-----------|---------|------|------|
|
| 69 |
+
| T_packed | 5-trit/byte uint8 | 2.67 MB | Packed {-1, 0, +1} weights |
|
| 70 |
+
| E | int8 per group | 1.12 MB | Logβ scale exponent |
|
| 71 |
+
| E_accum | int8 per group | 1.12 MB | Residual E accumulator |
|
| 72 |
+
| T_accum | int8 per weight | 13.36 MB | Gradient sign accumulator |
|
| 73 |
+
| **Total** | | **18.27 MB** | |
|
| 74 |
+
|
| 75 |
+
All int8 or packed ternary β no IEEE float anywhere in weight state.
|
| 76 |
+
|
| 77 |
+
## Scale Projection to 3B Parameters
|
| 78 |
+
|
| 79 |
+
| Component | 14M | 3B (projected) |
|
| 80 |
+
|-----------|-----|----------------|
|
| 81 |
+
| T_packed | 2.67 MB | **~572 MB** |
|
| 82 |
+
| E | 1.12 MB | **~240 MB** |
|
| 83 |
+
| E_accum | 1.12 MB | **~240 MB** |
|
| 84 |
+
| T_accum | 13.36 MB | **~2.86 GB** |
|
| 85 |
+
| **Total training** | **18.27 MB** | **~3.9 GB** |
|
| 86 |
+
| Inference (T+E only) | ~3.8 MB | **~812 MB** |
|
| 87 |
+
|
| 88 |
+
At 3B: **~3.9 GB** training VRAM fits on a single RTX 4060 (8 GB). Compare to BF16 Adam: **~18 GB** (requires server GPU).
|
| 89 |
+
|
| 90 |
+
## Architecture Components
|
| 91 |
+
|
| 92 |
+
All internal trainable components are now ternary or integer buffers (REFACTOR6+):
|
| 93 |
+
|
| 94 |
+
- `TernaryScaleTensor` β packed ternary linear layers
|
| 95 |
+
- `TernaryEmbeddingTable` β packed ternary embedding lookup
|
| 96 |
+
- `TernaryLSTMCell` β LSTM with ternary projections
|
| 97 |
+
- `TernaryVQCodebook` β VQ with ternary embedding table
|
| 98 |
+
- Graph: Triton-backed edge aggregation + gather-add kernels (REFACTOR8)
|
| 99 |
+
- MoE: Triton-backed dense combine kernel (REFACTOR8)
|
| 100 |
+
|
| 101 |
+
The only remaining float parameters are imported frozen encoders (ViT, Whisper).
|
| 102 |
+
|
| 103 |
+
## Running the Benchmark
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
# Default: 200 steps, batch=4, ctx=33
|
| 107 |
+
python benchmark_true_ternary.py
|
| 108 |
+
|
| 109 |
+
# Custom config
|
| 110 |
+
python benchmark_true_ternary.py \
|
| 111 |
+
--configs TrueTernary \
|
| 112 |
+
--steps 500 \
|
| 113 |
+
--batch 8 \
|
| 114 |
+
--ctx 66 \
|
| 115 |
+
--update-backend gpu \
|
| 116 |
+
--scale-update-interval 1
|
| 117 |
+
|
| 118 |
+
# Compare with Adam
|
| 119 |
+
python benchmark_true_ternary.py \
|
| 120 |
+
--configs Adam_FP32,TrueTernary \
|
| 121 |
+
--steps 200 \
|
| 122 |
+
--batch 4 \
|
| 123 |
+
--ctx 33 \
|
| 124 |
+
--reuse-base
|
| 125 |
+
|
| 126 |
+
# Change T_accum threshold (higher = less frequent flips)
|
| 127 |
+
python benchmark_true_ternary.py \
|
| 128 |
+
--accum-threshold 5
|
| 129 |
+
|
| 130 |
+
# Full training pipeline
|
| 131 |
+
python train.py \
|
| 132 |
+
--max_steps 5000 \
|
| 133 |
+
--batch_size 8 \
|
| 134 |
+
--ctx 66 \
|
| 135 |
+
--strict_ternary \
|
| 136 |
+
--scale_update_interval 1 \
|
| 137 |
+
--run_name my_ternary_run
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## Benchmark CLI Arguments
|
| 141 |
+
|
| 142 |
+
| Argument | Default | Description |
|
| 143 |
+
|----------|---------|-------------|
|
| 144 |
+
| `--configs` | `TrueTernary` | Comma-separated: `Adam_FP32`, `SignSGD_Old`, `TrueTernary` |
|
| 145 |
+
| `--steps` | 200 | Training steps |
|
| 146 |
+
| `--batch` | 4 | Batch size |
|
| 147 |
+
| `--ctx` | 33 | Context length |
|
| 148 |
+
| `--update-backend` | `gpu` | `gpu`, `gpu-signcache`, `dense-fallback`, `none` |
|
| 149 |
+
| `--scale-update-interval` | 1 | E update frequency (0 = disable) |
|
| 150 |
+
| `--accum-threshold` | 3 | T_accum flip threshold |
|
| 151 |
+
| `--print-every` | 50 | Logging frequency |
|