#!/usr/bin/env python3
"""Run a nonuniform-pruned Qwen3-Coder-Next variant with vLLM.
Nonuniform variants have different expert counts per layer. Stock vLLM
assumes a single `config.num_experts` for all layers. This script applies
a monkey-patch before loading the model so each layer gets its own count
from `config.per_layer_num_experts`.
Usage:
# Interactive generation
python run_vllm_nonuniform.py --model ./coding-25-nonuniform --tp 4
# With a custom prompt
python run_vllm_nonuniform.py --model ./coding-50-nonuniform --tp 2 \
--prompt "Write a Python function to merge two sorted lists."
# As a library (import the patch before importing vLLM)
import run_vllm_nonuniform # applies patch on import
from vllm import LLM, SamplingParams
llm = LLM(model="./coding-25-nonuniform", ...)
How it works:
vLLM spawns worker subprocesses that re-import all modules. A plain
monkey-patch in the parent process would be lost. This script places
the variant folder on PYTHONPATH so that `sitecustomize.py` (bundled
in each nonuniform variant) auto-applies the patch in every Python
process, including workers.
"""
import argparse
import json
import os
import sys
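

# setup() reads pruned_metadata.json from the variant folder. Based on the
# keys accessed below, the file is expected to look roughly like this
# (illustrative sketch; the values are hypothetical):
#
#   {
#     "allocation": "nonuniform",
#     "sparsity": 0.25,
#     "total_pruned": 384,
#     "per_layer_num_experts": [40, 48, 64, ...]
#   }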
def setup(model_path: str):
    """Apply the nonuniform-expert patch for vLLM worker processes."""
    model_path = os.path.abspath(os.path.expanduser(model_path))

    # Verify this is a nonuniform variant
    meta_path = os.path.join(model_path, "pruned_metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            meta = json.load(f)
        if meta.get("allocation") != "nonuniform":
            print(f"Note: {model_path} is a uniform variant; patch not needed.")
            return
        per = meta["per_layer_num_experts"]
        print(f"Nonuniform variant: {meta['total_pruned']} experts pruned "
              f"({meta['sparsity']:.0%}), kept range [{min(per)}, {max(per)}]")

    # Add the model dir to PYTHONPATH so workers pick up sitecustomize.py
    cur = os.environ.get("PYTHONPATH", "")
    if model_path not in cur:
        os.environ["PYTHONPATH"] = model_path + (os.pathsep + cur if cur else "")

    # Also add to sys.path for the current process
    if model_path not in sys.path:
        sys.path.insert(0, model_path)

    # Apply the patch in this process as well
    import vllm_pruned_patch
    vllm_pruned_patch.apply()
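

# For a 25%-pruned variant, the banner printed by setup() looks like this
# (the numbers are hypothetical; the format comes from the f-string above):
#
#   Nonuniform variant: 384 experts pruned (25%), kept range [40, 64]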

# No auto-apply on import: when this file is used as a library, the caller
# must invoke setup(model_path) explicitly before importing vLLM.


def main():
    parser = argparse.ArgumentParser(
        description="Run a nonuniform-pruned model with vLLM")
    parser.add_argument("--model", required=True,
                        help="Path to a nonuniform variant folder")
    parser.add_argument("--tp", type=int, default=4,
                        help="Tensor parallel size")
    parser.add_argument("--prompt", type=str,
                        default="def fibonacci(n):\n"
                                '    """Return the n-th Fibonacci number."""\n',
                        help="Prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--gpu-mem-util", type=float, default=0.85)
    parser.add_argument("--max-model-len", type=int, default=4096)
    args = parser.parse_args()

    # Apply the patch BEFORE importing vLLM
    setup(args.model)
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tp,
        dtype="bfloat16",
        gpu_memory_utilization=args.gpu_mem_util,
        max_model_len=args.max_model_len,
        trust_remote_code=True,
        enforce_eager=True,
    )
    sp = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
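    # With the default temperature of 0.0, vLLM decodes greedily, so the
    # run is deterministic.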

    outputs = llm.generate([args.prompt], sp)
    text = outputs[0].outputs[0].text
    n_tok = len(outputs[0].outputs[0].token_ids)

    print("=" * 60)
    print("PROMPT:")
    print(args.prompt)
    print("=" * 60)
    print("COMPLETION:")
    print(text)
    print("=" * 60)
    print(f"({n_tok} tokens)")


if __name__ == "__main__":
    main()