OpenOneRec committed · Commit 4f781c9 (verified) · Parent: b91d707

Update README.md
---
license: apache-2.0
language:
- en
- zh
base_model:
- Qwen/Qwen3-4B-Base
pipeline_tag: text-generation
---

<div align="center">
<h1>Kwai Summary Attention (KSA)</h1>
<p align="center">
<strong>Efficient long-context modeling via learnable summary tokens</strong>
</p>
<p align="center">
<a href="https://arxiv.org/abs/2604.24432">
<img alt="Paper" src="https://img.shields.io/badge/Paper-arXiv%3A2604.24432-b31b1b?logo=arxiv" />
</a>
<a href="https://github.com/Kuaishou-OneRec/KSA">
<img alt="GitHub" src="https://img.shields.io/badge/GitHub-Kuaishou--OneRec-black?logo=github" />
</a>
<a href="#-license">
<img alt="License" src="https://img.shields.io/badge/License-Apache%202.0-green" />
</a>
</p>
<p align="center">
<a href="README.md">English</a> | <a href="README_zh.md">中文</a>
</p>
</div>
<br>
## 📖 Introduction

**Kwai Summary Attention (KSA)** is an efficient attention mechanism that compresses historical context into a small set of *learnable summary tokens* inserted at regular chunk boundaries. Unlike GQA/MLA, which keep one cache entry per token, and unlike sliding-window or linear attention, which discard or lossily compress distant history, KSA takes an **intermediate path**: the KV cache scales as **O(N/R)** with a semantic-level compression ratio R, trading a small amount of memory for *complete, referential, and interpretable* retention of long-range dependencies.
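As a rough back-of-envelope illustration of the O(N/R) scaling (a sketch with hypothetical chunk and window sizes, not the repository's actual cache accounting):

```python
def kv_entries_full(seq_len: int) -> int:
    # Full attention caches one KV entry per token: O(N).
    return seq_len

def kv_entries_ksa(seq_len: int, chunk_tokens: int, window_chunks: int) -> int:
    # KSA keeps full text KV only for the recent sliding chunks, plus one
    # summary entry per older chunk, so growth beyond the window is ~N/R.
    n_chunks = seq_len // chunk_tokens
    text_entries = min(n_chunks, window_chunks) * chunk_tokens
    summary_entries = max(n_chunks - window_chunks, 0)
    return text_entries + summary_entries

# Hypothetical sizes: 128k context, 512-token chunks, 128-chunk window.
full = kv_entries_full(131072)           # 131072 entries
ksa = kv_entries_ksa(131072, 512, 128)   # 65536 text + 128 summaries = 65664
```

Past the sliding window, each additional chunk costs only a summary entry, which is what flattens the growth curve relative to full attention.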
This repository contains:

- **Muse** training framework with the Qwen3 + Summary Attention model.
- A block-sparse training / prefill **kernel** for Summary Attention.
- A ring-buffer **KV cache** implementation for decoding, packaged as a HuggingFace `trust_remote_code` template.
- A full end-to-end **pretraining recipe** that trains a 1.9B Qwen3-style model from scratch, progressively extending its context from 8k to 128k.
- Weight conversion utilities (DCP → HuggingFace safetensors) and an inference sanity-check script.
<p align="center"><img src="./assets/figures/mainmodel.png" width="80%" alt="KSA hybrid architecture: summary tokens interleaved with text tokens, with Summary Attention layers and Full Attention layers in a 3:1 hybrid ratio." /></p>
<p align="center"><em>Figure: KSA hybrid architecture. Summary tokens interleave with text tokens; Summary Attention and Full Attention layers are stacked in a 3:1 ratio.</em></p>

## 🔥 News

- **2026-04-28** — KSA technical report is released on arXiv: [arXiv:2604.24432](https://arxiv.org/abs/2604.24432).
- **2026-04-28** — Code, training recipes, block-sparse kernel, and HuggingFace `trust_remote_code` template are open-sourced in this repository.
## ✨ Highlights

- **Sequence-level KV compression.** Summary tokens partition the sequence into fixed-size chunks; the summary of each chunk acts as a compressed prior of distant history. For a sequence of length $N$, the KV cache grows as $O(N/R)$ with compression ratio $R$ instead of $O(N)$, and the mechanism is **orthogonal** to GQA / MLA — the compression ratios multiply.
- **Sliding *chunk* attention, not sliding *window*.** Window boundaries are aligned with chunk boundaries, so every past chunk is either fully visible (text) or summarized (summary token), never partially both. This avoids the information gap that naive SWA introduces at window edges.
- **Hybrid by default.** The released recipe uses a `3:1` *Summary : Full* layer interleaving. A small dose of full attention serves as a cross-chunk integrator and stabilizes long-context retrieval.
- **Summary KV cache for decoding.** KV states are laid out as a single contiguous buffer `[scratch | current chunk | sliding chunks (ring) | summary buffer]`. Every decode step reads one contiguous slice — no `cat`, no `gather`, no dense mask materialization. See [`examples/pretrain/hf_template/modeling_qwen3sa.py`](examples/pretrain/hf_template/modeling_qwen3sa.py).
- **Block-sparse training / prefill kernel.** Only non-empty block pairs are loaded from HBM to SRAM, avoiding the $O(L^2)$ mask materialization that would otherwise be infeasible at 128k. Distributed as a prebuilt wheel under [`summary_attention_kernel/`](summary_attention_kernel/).
- **Three-stage training recipe.** Attention distillation → parameter annealing → sequence-length extension, all reproducible via the `run_pretrain_{8,32,64,128}k.sh` launchers.
## 🤖 Model Zoo

*Coming soon.* Pretrained checkpoints will be published on Hugging Face.

| Model | Backbone | Parameters | Context | Training | Link |
| :------------ | :---------- | :--------- | :------ | :-------------------- | :---- |
| KSA-4B (CPT) | Qwen3-4B | 4B | 128k | Continual pretraining | *TBD* |

The 1.9B *from-scratch* configuration is provided as a reproducible recipe only; no 1.9B weights will be released.
## 🏗️ Method & Architecture

KSA compresses long context at the *semantic* level by inserting a small number of **learnable summary tokens** at fixed chunk boundaries, then treating the past as a sequence of chunks — each exposed either as full text or as its summary state.

### 1. Sliding Chunk Attention

<p align="center"><img src="./assets/figures/sca_vs_swa.png" width="75%" alt="Sliding-window attention may cut through a chunk and lose boundary information; sliding-chunk attention aligns with chunk boundaries and guarantees clean information routing." /></p>
<p align="center"><em>Figure: Sliding-chunk attention aligns windows to chunk boundaries. Naive sliding windows cut through chunks and drop boundary information.</em></p>

If a window boundary cuts through a chunk, that chunk is neither fully covered by text tokens nor wholly summarized — its information falls through the cracks. KSA aligns windows to chunks, so every past chunk is accessed *exclusively* either as full text (inside the window) or via its summary token (outside it), with no double-counting and no gaps.
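The chunk-aligned routing rule can be sketched in a few lines (a simplified illustration; the chunk indexing and `window_chunks` parameter are hypothetical, not the repository's API):

```python
def chunk_view(query_chunk: int, past_chunk: int, window_chunks: int) -> str:
    """How a query in `query_chunk` sees an earlier `past_chunk`
    under sliding-chunk attention (simplified sketch)."""
    assert past_chunk <= query_chunk
    if query_chunk - past_chunk < window_chunks:
        return "text"     # chunk is wholly inside the sliding window
    return "summary"      # chunk is wholly outside: only its summary is visible

# With a 4-chunk window, a query in chunk 10 sees chunks 7-10 as full text
# and every older chunk only through its summary token.
views = [chunk_view(10, c, 4) for c in range(11)]
```

Because the decision is made per whole chunk, no chunk is ever split across the window edge — the failure mode of a token-granular sliding window.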
### 2. Ring-buffer KV Cache

<p align="center"><img src="./assets/figures/buffer_layout.png" width="82%" alt="Contiguous KV cache layout for KSA decoding: scratch slot, current chunk, sliding-chunk ring, and summary token buffer all share a single physical tensor." /></p>
<p align="center"><em>Figure: Decoding KV cache layout. Every logical region is a contiguous slice of a single physical tensor.</em></p>

Every logical region — scratch, current chunk, sliding ring, summary buffer — is a contiguous slice of a single tensor. Text attention and summary attention each read one span. RoPE is applied *before* caching, so a token's physical position in the ring is independent of its logical position. Chunk eviction is an in-place copy into the oldest ring slot; no reallocation, no concatenation, no dense mask.
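A toy sketch of the eviction idea, using plain Python lists in place of the single contiguous tensor (the real implementation lives in `modeling_qwen3sa.py`; the class and method names here are illustrative only):

```python
class RingKVCache:
    """Toy model of the decode-time layout: a fixed ring of sliding-chunk
    slots plus an ever-growing list of summary states."""
    def __init__(self, ring_slots: int):
        self.ring = [None] * ring_slots   # fixed-size sliding-chunk KV slots
        self.summaries = []               # one summary state per retired chunk
        self.next_slot = 0                # oldest slot index (write pointer)

    def evict_chunk(self, chunk_kv, summary_state):
        # In-place overwrite of the oldest slot: no reallocation, no concat.
        evicted = self.ring[self.next_slot]
        self.ring[self.next_slot] = chunk_kv
        self.summaries.append(summary_state)   # compressed trace kept forever
        self.next_slot = (self.next_slot + 1) % len(self.ring)
        return evicted

cache = RingKVCache(ring_slots=4)
for i in range(6):
    cache.evict_chunk(f"chunk{i}", f"sum{i}")
# Ring holds the 4 most recent chunks (wrapped); all 6 summaries survive.
```

The wrap-around write is why physical slot order differs from logical chunk order — which is harmless precisely because RoPE is applied before caching.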
### 3. Sub-linear KV Scaling

<p align="center"><img src="./assets/figures/kv_cache_comparison.png" width="65%" alt="KV cache growth vs. sequence length: Full attention grows linearly, SWA is flat but loses distant history, KSA grows sub-linearly while preserving a compressed trace of all history." /></p>
<p align="center"><em>Figure: KV cache growth vs. sequence length.</em></p>

### 4. Training Recipe

Three stages, repeated at each target sequence length (8k → 32k → 64k → 128k):

1. **Attention distillation** — warm up the summary-attention parameters against a Full-Attention teacher.
2. **Parameter annealing** — unfreeze the full model and jointly optimize.
3. **Sequence-length extension** — scale `max_position_embeddings` and resume with an adjusted RoPE base.

See [`examples/pretrain/README.md`](examples/pretrain/README.md) for per-stage hyperparameters.
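Stage 3 can be pictured as a config transform (the values below are purely illustrative; the actual per-stage settings live in the `run_pretrain_*k.sh` scripts and `examples/pretrain/README.md`):

```python
def extend_context(config: dict, new_len: int, new_rope_base: float) -> dict:
    """Sketch of sequence-length extension: enlarge max_position_embeddings
    and raise the RoPE base so longer distances stay well-conditioned."""
    cfg = dict(config)
    cfg["max_position_embeddings"] = new_len
    cfg["rope_theta"] = new_rope_base
    return cfg

stage_8k = {"max_position_embeddings": 8192, "rope_theta": 1e4}
stage_32k = extend_context(stage_8k, 32768, 5e5)   # hypothetical RoPE base
```

Training then resumes from the previous stage's weights under the enlarged config rather than restarting from scratch.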
### Released model configuration

The release ships two recipes: a 1.9B hybrid model trained from scratch (recipe only — no weights released) and a 4B continual-pretraining variant.

| Configuration | From Scratch (1.9B) | Continual Pretraining (4B) |
| :---------------------------- | :------------------ | :------------------------- |
| Number of layers | 24 | 36 |
| Hidden size | 2048 | 2560 |
| Intermediate size | 6144 | 9728 |
| Attention heads (Q / KV) | 16 / 16 | 32 / 8 |
| Head dimension | 128 | 128 |
| Hybrid ratio (Summary : Full) | 3 : 1 | 3 : 1 |
| Summary chunk size | 8 | 8 |
| Sliding chunk number | 128 | 128 |
| Tied embeddings | False | True |

The config lives at [`examples/pretrain/model_config/model_config_1b9_hybrid.json`](examples/pretrain/model_config/model_config_1b9_hybrid.json) and is loaded via the `Qwen3SummaryAttentionConfig` / `Qwen3SummaryModel` classes registered in `muse/models/`.
## 📈 Performance

We evaluate KSA under two settings — **Continual Pretraining (CPT)** from a Qwen3-4B-Base checkpoint (85B tokens), and **Train-from-Scratch** at 1.9B (400B tokens). Full results are in the [technical report](https://arxiv.org/abs/2604.24432); the highlights below are taken directly from its tables.

### Long-context retrieval — RULER (CPT, 4B)

| Benchmark | Full | Hybrid-SWA | Hybrid-SCA | Hybrid-Linear | KSA | **Hybrid-KSA** |
| :---------- | :-------- | :--------- | :--------- | :------------ | :---- | :------------- |
| RULER-4K | 92.88 | 91.30 | 86.02 | 86.39 | 91.55 | **92.97** |
| RULER-8K | **91.38** | 88.03 | 84.28 | 83.86 | 86.78 | 90.53 |
| RULER-16K | **89.12** | 82.87 | 80.67 | 78.06 | 84.78 | 88.86 |
| RULER-32K | 84.74 | 78.94 | 76.89 | 76.48 | 80.30 | **86.65** |
| RULER-64K | **78.16** | 73.88 | 68.88 | 73.50 | 76.09 | 76.04 |
| RULER-128K | 65.86 | 66.27 | 60.94 | 67.98 | 66.81 | **71.67** |

Hybrid-KSA leads at 4K, 32K, and 128K, and at **128K it surpasses Full attention by +5.81 points** while operating with a substantially smaller KV cache. Across all RULER lengths it is the strongest sub-quadratic alternative to Full attention.

### General benchmarks (CPT, 4B)

| Benchmark | Full | Hybrid-SWA | Hybrid-SCA | Hybrid-Linear | KSA | **Hybrid-KSA** |
| :-------- | :-------- | :--------- | :--------- | :------------ | :---- | :------------- |
| MMLU | **71.83** | 70.57 | 69.83 | 64.33 | 70.73 | 70.50 |
| CMMLU | **75.00** | 73.69 | 72.59 | 68.41 | 73.29 | 72.63 |
| C-Eval | **73.66** | 72.36 | 71.66 | 67.42 | 72.14 | 72.66 |
| MMLU-Pro | **46.36** | 45.23 | 45.11 | 38.83 | 45.70 | 45.39 |
| CMath | 83.41 | **84.84** | 83.16 | 79.09 | 84.58 | 84.25 |
| GSM8K | **82.75** | 81.92 | 80.10 | 72.44 | 81.09 | 79.50 |
| MATH | 47.48 | **48.24** | 47.45 | 42.57 | 48.15 | 47.56 |
| MBPP | 61.30 | 61.70 | 59.60 | 55.30 | 61.50 | **62.20** |
| HumanEval | 58.54 | 61.89 | 61.89 | 54.58 | 60.97 | **62.50** |
| **Avg.** | 73.50 | 72.12 | 69.94 | 67.28 | 72.30 | **73.59** |

KSA preserves full general capability under CPT — Hybrid-KSA's average **(73.59) edges out Full attention (73.50)**, with the smallest gap-to-Full of any sub-quadratic alternative.
### Train-from-scratch headlines (1.9B, 400B tokens)

- **RULER-128K**: Hybrid-KSA **65.35** vs. Full attention **48.75** (**+16.60**). Hybrid-KSA stays robust as length grows (80.65 → 65.35 from 4K to 128K), while Full attention collapses (76.08 → 48.75).
- **GSM8K**: Hybrid-KSA **59.14** vs. Full **48.29** (**+10.85**). **MATH**: **36.92** vs. **23.38** (**+13.54**).
- **MBPP / HumanEval**: best of all configurations at **36.40 / 31.71**.
- **Training loss**: Hybrid-KSA reaches the lowest final loss (**1.524**), below Hybrid-GDN (1.534), Hybrid-SWA (1.550), and Full (1.572).

### Needle-in-a-Haystack & RULER-128K subtasks (CPT)

Hybrid-KSA achieves **near-perfect single-needle retrieval across 4K–128K** at all needle depths, with only a minor dip at 128K. On RULER-128K subtasks it leads on **NIAH-Multivalue (98.75, +10.63 over Full)**, **VT (90.50, +30.0 over Full)**, **FWE (65.84)**, and **SQuAD (42.50)**.

### Inference efficiency (4B, 128K context)

- **KV cache**: 7.5 GB vs. 18.6 GB for Full attention — a **2.5× reduction**.
- **Decode throughput** at 16K prefill: **1.06× of Full attention**, vs. 0.73× for Hybrid-SWA and 0.81× for Hybrid-Ring-Linear.
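As a back-of-envelope check on the Full-attention cache figure (assuming bf16 and the released 4B configuration — 36 layers, 8 KV heads, head dim 128 — at 128K tokens; the report's exact bookkeeping may differ slightly):

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, dtype_bytes: int = 2) -> int:
    # Keys and values: 2 tensors per layer, one entry per token per KV head.
    return 2 * layers * kv_heads * head_dim * seq_len * dtype_bytes

full_bytes = kv_cache_bytes(layers=36, kv_heads=8, head_dim=128, seq_len=131072)
print(f"{full_bytes / 1e9:.1f} GB")  # 19.3 GB, the same order as the reported 18.6 GB
```

The dense estimate lands in the same ballpark as the reported 18.6 GB; KSA's 7.5 GB comes from replacing most of the per-token KV with the sliding ring plus summary buffer.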
## 🚀 Quick Start

### 1. Build the reference image

Ubuntu 24.04 + CUDA 12.6 + Python 3.12 + PyTorch 2.6.0 + FlashAttention 2.7.4.post1, with the block-sparse kernel preinstalled:

```bash
docker build -t ksa-train -f dockerfile/Dockerfile .
```

Versions are pinned from an actual training-host snapshot; see [`dockerfile/requirements.txt`](dockerfile/requirements.txt) for the full list. If you prefer a bare-metal installation, mirror the same pins.

### 2. Configure environment variables

```bash
cp .env.example .env   # then edit paths
bash set_env.sh
```

The run scripts export `PYTHONPATH=$PWD:$PYTHONPATH` automatically, so launching from the repo root is sufficient; no extra `PYTHONPATH` setup is needed.
### 3. Pretrain (progressive length extension)

Four stages, each resuming weights from the previous one:

```bash
bash examples/pretrain/run_pretrain_8k.sh    # 1. from scratch at 8k
bash examples/pretrain/run_pretrain_32k.sh   # 2. extend to 32k
bash examples/pretrain/run_pretrain_64k.sh   # 3. extend to 64k
bash examples/pretrain/run_pretrain_128k.sh  # 4. extend to 128k
```

Edit `CHECKPOINT_DIR` / `OUTPUT_DIR` at the top of each script to match your storage layout. Each stage launches via `mpirun` and writes DCP checkpoints plus dataloader state to `$OUTPUT_DIR/global_stepN/`. See [`examples/pretrain/README.md`](examples/pretrain/README.md) for mid-run resume, chunked-CE toggles, and per-stage hyperparameters.
### 4. Convert a trained checkpoint to HuggingFace

```bash
bash examples/pretrain/convert/convert_muse_to_hf.sh \
    /path/to/muse_outputs/1b9_sa_hybrid_128k \
    global_step5000 \
    examples/pretrain/hf_template
```

The converted HF directory lands at `<OUTPUT_DIR>/<STEP>/hf/` and contains the remapped safetensors plus the `modeling_qwen3sa.py` / `summary_context.py` / tokenizer files from `hf_template/`. See [`examples/pretrain/hf_template/README.md`](examples/pretrain/hf_template/README.md) for the expected template contents.
### 5. Inference — sanity-check a converted model

```bash
python examples/inference/inference.py \
    --model_path /path/to/global_step5000/hf \
    --prompt "介绍一下你自己" \
    --device cuda:0
```

The inference path uses HuggingFace's `AutoModelForCausalLM` with `trust_remote_code=True` and goes through the ring-buffer KV cache defined in `hf_template/modeling_qwen3sa.py` — no framework-specific glue required.
## 📁 Repository Layout

```
.
├── muse/                                  # Training framework (models, layers, training loop)
│   ├── models/qwen3_sa/                   # Qwen3 + Summary Attention model
│   ├── layers/summary_context.py          # SummaryBatchContext + mask helpers
│   └── ...
├── recipes/
│   └── pretrain_kai_summary_unified.py    # Main pretrain entry
├── summary_attention_kernel/
│   ├── summary_attn-*.whl                 # Block-sparse SA kernel (training + prefill)
│   └── flash_attn_cute-*.whl              # CuTe-based FlashAttention build used by the kernel
├── examples/
│   ├── pretrain/                          # Progressive 8k→128k recipe
│   │   ├── model_config/                  # model_config_1b9_hybrid.json
│   │   ├── dataset_config/                # per-seq-length mmap dataset specs
│   │   ├── run_pretrain_{8,32,64,128}k.sh
│   │   ├── convert/                       # DCP → HF safetensors
│   │   └── hf_template/                   # HF-compatible modeling + config template
│   └── inference/
│       └── inference.py                   # Quick chat-style sanity check
├── data/                                  # (User-populated) mmap corpora
├── dockerfile/                            # Reference Dockerfile + requirements.txt
└── README.md / README_zh.md
```
## 🛣️ Roadmap

We are actively working on:

- [x] Technical report on arXiv ([arXiv:2604.24432](https://arxiv.org/abs/2604.24432)).
- [ ] Publish pretrained KSA checkpoints on Hugging Face.
- [ ] Release the 4B continual-pretraining recipe and checkpoint.
- [ ] Expanded evaluation scripts for RULER / NIAH / LongBench v2 reproduction.
- [ ] A reference serving stack with the ring-buffer KV cache.
- [ ] Additional ablations and tutorials.

Contributions are welcome — feel free to open an issue or PR.
267
+
268
+ ## πŸ“œ Citation
269
+
270
+ If you find KSA useful, please cite our technical report:
271
+
272
+ ```bibtex
273
+ @techreport{kwai2026ksa,
274
+ title = {Kwai Summary Attention Technical Report},
275
+ author = {OneRec Team},
276
+ year = {2026},
277
+ institution = {Kuaishou Technology},
278
+ url = {https://arxiv.org/abs/2604.24432}
279
+ }
280
+ ```
281
+
## 🛡️ License

The code in this repository is licensed under the **Apache 2.0 License** (see [`LICENSE`](LICENSE)). Model weights, when released, will be subject to their own license agreements.

## 🙏 Acknowledgements

KSA is built upon and inspired by the open-source ecosystem. We would like to thank:

- **Qwen3** — for the base architecture and tokenizer that KSA extends.
- **FlashAttention** — for the dense-attention primitives our block-sparse kernel composes with.
- **HuggingFace Transformers** — for the model / tokenizer / generation abstractions that make `trust_remote_code` deployment painless.
- **PyTorch distributed** — for FSDP, DCP, and the communication primitives that make large-scale pretraining tractable.

We sincerely thank these projects for their outstanding work.