Commit 4a5537b (verified) · helcig committed · 1 parent: d5386d2

Add top-level README and run_vllm_nonuniform.py

Files changed (2):
  1. README.md +204 -0
  2. run_vllm_nonuniform.py +123 -0
README.md ADDED

---
license: apache-2.0
base_model: Qwen/Qwen3-Coder-Next
tags:
- moe
- expert-pruning
- rco
- qwen3-next
---

# Qwen3-Coder-Next: RCO pruned variants

Expert-pruned checkpoints of [Qwen/Qwen3-Coder-Next](https://huggingface.co/Qwen/Qwen3-Coder-Next) produced by **Riemannian Constrained Optimization (RCO)**.

- Paper: [*Model Compression with Exact Budget Constraints via Riemannian Manifolds*](https://arxiv.org/abs/2605.00649) (Helcig & Alistarh, 2026)
- Code: [github.com/IST-DASLab/RCO](https://github.com/IST-DASLab/RCO)

Eight variants, one per (sparsity × calibration × allocation) combination:

| Sparsity | Calibration | Allocation | Folder |
|---|---|---|---|
| 25% | coding | uniform | [`coding-25-uniform/`](./coding-25-uniform) |
| 25% | coding | nonuniform | [`coding-25-nonuniform/`](./coding-25-nonuniform) |
| 50% | coding | uniform | [`coding-50-uniform/`](./coding-50-uniform) |
| 50% | coding | nonuniform | [`coding-50-nonuniform/`](./coding-50-nonuniform) |
| 25% | general | uniform | [`general-25-uniform/`](./general-25-uniform) |
| 25% | general | nonuniform | [`general-25-nonuniform/`](./general-25-nonuniform) |
| 50% | general | uniform | [`general-50-uniform/`](./general-50-uniform) |
| 50% | general | nonuniform | [`general-50-nonuniform/`](./general-50-nonuniform) |

## What is expert pruning?

Qwen3-Coder-Next is a Mixture-of-Experts (MoE) model with 512 routed experts per layer (top-10 active per token). Most experts are rarely used. Expert pruning permanently removes low-impact experts, shrinking the checkpoint and reducing memory at inference time.
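
A back-of-envelope check with the checkpoint sizes from the results tables below (159 GB full, 121 GB at 25%, 82 GB at 50%) shows how thoroughly expert weights dominate the checkpoint:

```python
# Fraction of total bytes removed per fraction of experts pruned
# (sizes taken from the results tables below).
full, p25, p50 = 159, 121, 82                 # GB
print(f"{(full - p25) / full / 0.25:.2f}")    # ~0.96
print(f"{(full - p50) / full / 0.50:.2f}")    # ~0.97
```

In other words, roughly 96-97% of the checkpoint is expert weights, so expert pruning translates almost one-to-one into size reduction.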

## Uniform vs nonuniform allocation

Each variant prunes a fixed fraction (25% or 50%) of the total experts across all 48 layers. The key design choice is **how to distribute that budget across layers**:

- **Uniform:** every layer keeps the same number of experts (e.g., 384 per layer at 25%, 256 at 50%). This is simple and compatible with stock inference frameworks: `config.num_experts` is a single integer, so vLLM, HuggingFace, and SGLang load the checkpoint without any code changes. However, forcing the same budget on every layer is suboptimal, because some layers are more sensitive to pruning than others.

- **Nonuniform:** the optimizer distributes the pruning budget across layers based on calibration loss. Critical layers keep more experts; redundant layers are pruned more aggressively. At the same total sparsity, this recovers more of the base model's quality (e.g., 97% HumanEval recovery vs 55% at 50% sparsity). The trade-off: each layer has a different expert count, which stock frameworks don't support out of the box. Nonuniform variants include a bundled `vllm_pruned_patch.py` that monkey-patches vLLM to handle per-layer expert counts (setup in the section below, and a one-page reference in each variant's `LOAD_VLLM.md`).

The gap grows with sparsity. At 25%, uniform trails nonuniform by about 8 recovery points on HumanEval (92% vs 100%). At 50%, the gap is 42 recovery points (55% vs 97%, i.e., 0.409 vs 0.720 pass@1). On general benchmarks, uniform and nonuniform perform comparably at both sparsity levels (within 1 point on MC-8).
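
Each nonuniform variant ships a `pruned_metadata.json` describing the chosen allocation. The keys below are the ones `run_vllm_nonuniform.py` reads, so a quick way to inspect a variant's per-layer budget is:

```python
import json

# Inspect how the optimizer spread the pruning budget across layers.
with open("coding-50-nonuniform/pruned_metadata.json") as f:
    meta = json.load(f)

per = meta["per_layer_num_experts"]   # kept-expert count for each of the 48 layers
print(meta["allocation"])             # "nonuniform"
print(f"{meta['total_pruned']} experts pruned ({meta['sparsity']:.0%})")
print(f"kept per layer: min {min(per)}, max {max(per)}")
```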

## Coding vs general calibration

RCO optimizes which experts to prune by minimizing KL divergence on a calibration dataset. The choice of calibration data determines what the pruned model preserves:

- **Coding** (evol-codealpaca): preserves code-generation ability (HumanEval, MBPP, LiveCodeBench) at the cost of general knowledge (MC-8).
- **General** (FineWeb-Edu): preserves general reasoning and knowledge benchmarks (ARC, HellaSwag, MMLU, etc.) but loses coding ability almost entirely.

This is not a limitation of the method; it reflects how specialized the base model's experts are. Pick the calibration that matches your deployment use case.
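
Schematically, the calibration objective is a distillation-style KL between the full and pruned models' next-token distributions under an exact expert budget (a sketch only; the paper gives the precise Riemannian constrained formulation):

$$
\min_{S} \; \mathbb{E}_{x \sim \mathcal{D}_{\text{cal}}} \Big[ \mathrm{KL}\big( p_{\text{full}}(\cdot \mid x) \,\big\|\, p_{S}(\cdot \mid x) \big) \Big] \quad \text{s.t.} \quad |S| = (1 - s) \cdot N_{\text{experts}},
$$

where \\(S\\) is the set of kept experts and \\(s\\) the target sparsity (25% or 50%).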

## Which variant should I pick?

| Use case | Recommended variant |
|---|---|
| Coding, easy deployment | `coding-25-uniform` (92% HE, stock vLLM) |
| Coding, best quality | `coding-25-nonuniform` (100% HE, needs patch) |
| Coding, max compression | `coding-50-nonuniform` (97% HE, needs patch) |
| General, easy deployment | `general-25-uniform` (99% MC-8, stock vLLM) |
| General, best quality | `general-25-nonuniform` (100% MC-8, needs patch) |
| General, max compression | `general-50-uniform` (92% MC-8, stock vLLM) |

## How to run

### Uniform variants (stock vLLM / HuggingFace)

Uniform variants have a single `config.num_experts` value. They load with zero code changes:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="./coding-25-uniform", tensor_parallel_size=4,
          dtype="bfloat16", trust_remote_code=True)
out = llm.generate(["def fib(n):"], SamplingParams(max_tokens=256))
print(out[0].outputs[0].text)
```

Or with HuggingFace transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "./coding-25-uniform", torch_dtype="bfloat16",
    device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("./coding-25-uniform")
```
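
The snippet above only loads the weights; generation then uses the standard transformers API (greedy decoding shown, matching the evaluation setup):

```python
# Greedy generation with the model/tokenizer loaded above.
inputs = tokenizer("def fib(n):", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```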

### Nonuniform variants (needs monkey-patch)

Nonuniform variants have different expert counts per layer. Stock vLLM builds every layer with `config.num_experts` (a single integer), which causes a shape mismatch on load. The repo provides three files to handle this. Two live inside each nonuniform variant folder:

- `vllm_pruned_patch.py`: overrides `Qwen3NextSparseMoeBlock.__init__` to read per-layer counts from `config.per_layer_num_experts`, and `Qwen3NextForCausalLM.get_expert_mapping` to use the max kept count
- `sitecustomize.py`: auto-applies the patch in every Python process, including vLLM worker subprocesses spawned via `multiprocessing.spawn`

And one at the repo root:

- `run_vllm_nonuniform.py`: wrapper script that sets up `PYTHONPATH` and applies the patch

Note: `enforce_eager=True` is required when loading nonuniform variants. CUDA-graph capture currently does not support the heterogeneous expert layout.

**Option 1: Use the bundled script**

```bash
# From the repo root
python run_vllm_nonuniform.py --model ./coding-25-nonuniform --tp 4

# Custom prompt
python run_vllm_nonuniform.py --model ./coding-50-nonuniform --tp 4 \
    --prompt "Write a Python function to merge two sorted lists."
```

**Option 2: Set PYTHONPATH manually**

The key requirement is that the variant folder (which contains `sitecustomize.py`) is on `PYTHONPATH`, so that vLLM worker subprocesses pick up the patch automatically:

```bash
export PYTHONPATH=/path/to/coding-25-nonuniform:${PYTHONPATH:-}

python -c "
from vllm import LLM, SamplingParams
llm = LLM(model='/path/to/coding-25-nonuniform', tensor_parallel_size=4,
          dtype='bfloat16', trust_remote_code=True, enforce_eager=True)
out = llm.generate(['def fib(n):'], SamplingParams(max_tokens=256))
print(out[0].outputs[0].text)
"
```

**Option 3: From Python (library use)**

```python
import sys, os

# Add the variant folder to the import path BEFORE importing vllm,
# and to PYTHONPATH so spawned worker processes inherit it too.
sys.path.insert(0, "/path/to/coding-25-nonuniform")
os.environ["PYTHONPATH"] = "/path/to/coding-25-nonuniform:" + os.environ.get("PYTHONPATH", "")

import vllm_pruned_patch
vllm_pruned_patch.apply()

from vllm import LLM, SamplingParams
llm = LLM(model="/path/to/coding-25-nonuniform",
          tensor_parallel_size=4, dtype="bfloat16",
          trust_remote_code=True, enforce_eager=True)
```
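
Option 3 stops at constructing the engine; generation is then identical to the other options:

```python
out = llm.generate(["def fib(n):"], SamplingParams(max_tokens=256))
print(out[0].outputs[0].text)
```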

**Why PYTHONPATH?** vLLM uses `multiprocessing.spawn` for worker processes (required with CUDA). Spawned workers re-import all modules from scratch, so a monkey-patch applied only in the parent process is lost. Python's `sitecustomize.py` mechanism runs automatically in every interpreter that has the relevant directory on `sys.path`. Putting the variant folder on `PYTHONPATH` is the simplest way to ensure all workers get the patch.
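
For illustration, a hook of roughly this shape is all `sitecustomize.py` needs to contain (a minimal sketch; the file bundled in each variant folder is the authoritative version):

```python
# sitecustomize.py -- auto-imported by Python at interpreter startup whenever
# this file's directory is on sys.path (e.g., via PYTHONPATH).
try:
    import vllm_pruned_patch  # lives next to this file in the variant folder
    vllm_pruned_patch.apply()
except Exception as exc:
    # Never break startup of unrelated Python processes that inherit PYTHONPATH.
    import sys
    print(f"warning: vllm_pruned_patch not applied: {exc}", file=sys.stderr)
```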

**Note on tensor parallelism:** TP works fine with nonuniform variants (TP shards hidden dimensions inside each expert, not across experts). Expert parallelism (EP) does NOT work with heterogeneous counts; keep `--enable-eplb` off (the default).

## Evaluation results

All evaluations were run with vLLM (bf16, greedy decoding). Coding benchmarks: HumanEval (pass@1) and MBPP (pass@1). General benchmarks: ARC-Challenge, ARC-Easy, BoolQ, HellaSwag, MMLU, OpenBookQA, RTE, WinoGrande (accuracy / acc_norm). MC-8 is the unweighted average of the eight general benchmarks. Recovery is relative to the full (unpruned) model.

### Coding benchmarks

| Variant | Size | HumanEval | rec. | MBPP | rec. |
|---|---|---|---|---|---|
| **Full model** | **159 GB** | **0.744** | n/a | **0.764** | n/a |
| [coding-25-uniform](./coding-25-uniform) | 121 GB | 0.683 | 92% | **0.688** | **90%** |
| [coding-25-nonuniform](./coding-25-nonuniform) | 121 GB | **0.744** | **100%** | 0.678 | 89% |
| [coding-50-uniform](./coding-50-uniform) | 82 GB | 0.409 | 55% | 0.534 | 70% |
| [coding-50-nonuniform](./coding-50-nonuniform) | 82 GB | **0.720** | **97%** | **0.690** | **90%** |
| [general-25-uniform](./general-25-uniform) | 121 GB | 0.043 | 6% | 0.046 | 6% |
| [general-25-nonuniform](./general-25-nonuniform) | 121 GB | 0.061 | 8% | 0.058 | 8% |
| [general-50-uniform](./general-50-uniform) | 82 GB | 0.000 | 0% | 0.018 | 2% |
| [general-50-nonuniform](./general-50-nonuniform) | 82 GB | 0.012 | 2% | 0.010 | 1% |

### General benchmarks (MC-8)

| Variant | MC-8 avg | rec. | ARC-C | ARC-E | BoolQ | HSwag | MMLU | OBQA | RTE | WinoG |
|---|---|---|---|---|---|---|---|---|---|---|
| **Full model** | **0.714** | n/a | 0.606 | 0.821 | 0.885 | 0.775 | 0.767 | 0.430 | 0.765 | 0.666 |
| [coding-25-uniform](./coding-25-uniform) | 0.656 | 92% | 0.501 | 0.722 | 0.864 | 0.690 | 0.710 | 0.380 | 0.729 | 0.655 |
| [coding-25-nonuniform](./coding-25-nonuniform) | 0.638 | 89% | 0.462 | 0.662 | 0.851 | 0.665 | 0.680 | 0.362 | **0.776** | 0.642 |
| [coding-50-uniform](./coding-50-uniform) | 0.577 | 81% | 0.403 | 0.641 | 0.789 | 0.578 | 0.564 | 0.350 | 0.671 | 0.616 |
| [coding-50-nonuniform](./coding-50-nonuniform) | 0.546 | 76% | 0.356 | 0.555 | 0.776 | 0.548 | 0.543 | 0.340 | 0.646 | 0.603 |
| [general-25-uniform](./general-25-uniform) | 0.707 | 99% | 0.600 | 0.807 | 0.876 | **0.785** | 0.704 | **0.452** | 0.751 | 0.677 |
| [general-25-nonuniform](./general-25-nonuniform) | **0.714** | **100%** | **0.618** | **0.822** | **0.882** | 0.776 | **0.712** | 0.442 | 0.762 | **0.699** |
| [general-50-uniform](./general-50-uniform) | **0.654** | **92%** | **0.541** | **0.771** | 0.839 | **0.709** | **0.610** | **0.428** | **0.675** | **0.658** |
| [general-50-nonuniform](./general-50-nonuniform) | 0.644 | 90% | 0.526 | 0.762 | **0.842** | 0.708 | 0.595 | 0.414 | **0.675** | 0.635 |
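
MC-8 is a plain unweighted mean, so any row can be checked directly; for the full model:

```python
# ARC-C, ARC-E, BoolQ, HSwag, MMLU, OBQA, RTE, WinoG (full-model row above)
scores = [0.606, 0.821, 0.885, 0.775, 0.767, 0.430, 0.765, 0.666]
print(round(sum(scores) / len(scores), 3))  # 0.714, matching the MC-8 column
```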

### Key takeaways

- **Calibration domain determines the trade-off.** Coding-calibrated variants preserve code generation (up to 100% HumanEval recovery) but lose general knowledge. General-calibrated variants preserve MC-8 (up to 100% recovery) but lose coding ability almost entirely.
- **Nonuniform allocation matters most at high sparsity.** At 50% sparsity, nonuniform recovers 97% of HumanEval vs 55% for uniform, a 42-point gap in recovery. At 25%, the gap is smaller (100% vs 92%).
- **25% sparsity is nearly lossless for the target domain.** Both coding-25-nonuniform (100% HE) and general-25-nonuniform (100% MC-8) match the full model within noise.
- **Uniform variants load in stock vLLM/HF with no patches.** Nonuniform variants require the bundled `vllm_pruned_patch.py` (see "How to run" above).

## Citation

```bibtex
@misc{helcig2026rco,
  title={Model Compression with Exact Budget Constraints via Riemannian Manifolds},
  author={Helcig, Michael and Alistarh, Dan},
  year={2026},
  eprint={2605.00649},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2605.00649},
}
```
run_vllm_nonuniform.py ADDED

#!/usr/bin/env python3
"""Run a nonuniform-pruned Qwen3-Coder-Next variant with vLLM.

Nonuniform variants have different expert counts per layer. Stock vLLM
assumes a single `config.num_experts` for all layers. This script applies
a monkey-patch before loading the model so each layer gets its own count
from `config.per_layer_num_experts`.

Usage:
    # Interactive generation
    python run_vllm_nonuniform.py --model ./coding-25-nonuniform --tp 4

    # With a custom prompt
    python run_vllm_nonuniform.py --model ./coding-50-nonuniform --tp 2 \
        --prompt "Write a Python function to merge two sorted lists."

    # As a library (call setup() before importing vLLM)
    import run_vllm_nonuniform
    run_vllm_nonuniform.setup("./coding-25-nonuniform")
    from vllm import LLM, SamplingParams
    llm = LLM(model="./coding-25-nonuniform", ...)

How it works:
    vLLM spawns worker subprocesses that re-import all modules. A plain
    monkey-patch in the parent process would be lost. This script places
    the variant folder on PYTHONPATH so that `sitecustomize.py` (bundled
    in each nonuniform variant) auto-applies the patch in every Python
    process, including workers.
"""

import argparse
import json
import os
import sys


def setup(model_path: str):
    """Apply the nonuniform-expert patch for vLLM worker processes."""
    model_path = os.path.abspath(os.path.expanduser(model_path))

    # Verify this is a nonuniform variant before patching
    meta_path = os.path.join(model_path, "pruned_metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            meta = json.load(f)
        if meta.get("allocation") != "nonuniform":
            print(f"Note: {model_path} is a uniform variant; patch not needed.")
            return
        per = meta["per_layer_num_experts"]
        print(f"Nonuniform variant: {meta['total_pruned']} experts pruned "
              f"({meta['sparsity']:.0%}), kept range [{min(per)}, {max(per)}]")

    # Add model dir to PYTHONPATH so workers pick up sitecustomize.py
    cur = os.environ.get("PYTHONPATH", "")
    if model_path not in cur:
        os.environ["PYTHONPATH"] = model_path + (os.pathsep + cur if cur else "")

    # Also add to sys.path for the current process
    if model_path not in sys.path:
        sys.path.insert(0, model_path)

    # Apply patch in this process
    import vllm_pruned_patch
    vllm_pruned_patch.apply()


# When imported as a library, the caller must call setup() explicitly
# before importing vllm (see the module docstring).

def main():
    parser = argparse.ArgumentParser(
        description="Run a nonuniform-pruned model with vLLM")
    parser.add_argument("--model", required=True,
                        help="Path to a nonuniform variant folder")
    parser.add_argument("--tp", type=int, default=4,
                        help="Tensor parallel size")
    parser.add_argument("--prompt", type=str,
                        default="def fibonacci(n):\n"
                                '    """Return the n-th Fibonacci number."""\n',
                        help="Prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--gpu-mem-util", type=float, default=0.85)
    parser.add_argument("--max-model-len", type=int, default=4096)
    args = parser.parse_args()

    # Apply patch BEFORE importing vLLM
    setup(args.model)

    from vllm import LLM, SamplingParams

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tp,
        dtype="bfloat16",
        gpu_memory_utilization=args.gpu_mem_util,
        max_model_len=args.max_model_len,
        trust_remote_code=True,
        enforce_eager=True,
    )

    sp = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )

    outputs = llm.generate([args.prompt], sp)
    text = outputs[0].outputs[0].text
    n_tok = len(outputs[0].outputs[0].token_ids)

    print("=" * 60)
    print("PROMPT:")
    print(args.prompt)
    print("=" * 60)
    print("COMPLETION:")
    print(text)
    print("=" * 60)
    print(f"({n_tok} tokens)")


if __name__ == "__main__":
    main()