#!/usr/bin/env python3
"""Run a nonuniform-pruned Qwen3-Coder-Next variant with vLLM.

Nonuniform variants have different expert counts per layer. Stock vLLM
assumes a single `config.num_experts` for all layers. This script applies
a monkey-patch before loading the model so each layer gets its own count
from `config.per_layer_num_experts`.
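
Expected `pruned_metadata.json` layout (a sketch: the keys are exactly the
ones this script reads below; the values are purely illustrative):

    {
      "allocation": "nonuniform",
      "sparsity": 0.25,
      "total_pruned": 128,
      "per_layer_num_experts": [96, 104, 112]
    }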

Usage:
    # Interactive generation
    python run_vllm_nonuniform.py --model ./coding-25-nonuniform --tp 4

    # With a custom prompt
    python run_vllm_nonuniform.py --model ./coding-50-nonuniform --tp 2 \
        --prompt "Write a Python function to merge two sorted lists."

    # As a library (call setup() before importing vLLM)
    import run_vllm_nonuniform
    run_vllm_nonuniform.setup("./coding-25-nonuniform")
    from vllm import LLM, SamplingParams
    llm = LLM(model="./coding-25-nonuniform", ...)

How it works:
    vLLM spawns worker subprocesses that re-import all modules. A plain
    monkey-patch in the parent process would be lost. This script places
    the variant folder on PYTHONPATH so that `sitecustomize.py` (bundled
    in each nonuniform variant) auto-applies the patch in every Python
    process, including workers.
"""

import argparse
import json
import os
import sys


def setup(model_path: str):
    """Apply the nonuniform-expert patch for vLLM worker processes."""
    model_path = os.path.abspath(os.path.expanduser(model_path))

    # Verify this is a nonuniform variant
    meta_path = os.path.join(model_path, "pruned_metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            meta = json.load(f)
        if meta.get("allocation") != "nonuniform":
            print(f"Note: {model_path} is a uniform variant; patch not needed.")
            return
        per = meta["per_layer_num_experts"]
        print(f"Nonuniform variant: {meta['total_pruned']} experts pruned "
              f"({meta['sparsity']:.0%}), kept range [{min(per)}, {max(per)}]")

    # Add model dir to PYTHONPATH so workers pick up sitecustomize.py
    cur = os.environ.get("PYTHONPATH", "")
    if model_path not in cur:
        os.environ["PYTHONPATH"] = model_path + (os.pathsep + cur if cur else "")

    # Also add to sys.path for the current process
    if model_path not in sys.path:
        sys.path.insert(0, model_path)

    # Apply patch in this process
    import vllm_pruned_patch
    vllm_pruned_patch.apply()


def main():
    parser = argparse.ArgumentParser(
        description="Run a nonuniform-pruned model with vLLM")
    parser.add_argument("--model", required=True,
                        help="Path to a nonuniform variant folder")
    parser.add_argument("--tp", type=int, default=4,
                        help="Tensor parallel size")
    parser.add_argument("--prompt", type=str,
                        default="def fibonacci(n):\n"
                                '    """Return the n-th Fibonacci number."""\n',
                        help="Prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--gpu-mem-util", type=float, default=0.85)
    parser.add_argument("--max-model-len", type=int, default=4096)
    args = parser.parse_args()

    # Apply patch BEFORE importing vLLM
    setup(args.model)

    from vllm import LLM, SamplingParams

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tp,
        dtype="bfloat16",
        gpu_memory_utilization=args.gpu_mem_util,
        max_model_len=args.max_model_len,
        trust_remote_code=True,
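        # enforce_eager disables CUDA graph capture: a conservative default
        # for a monkey-patched model definition.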
        enforce_eager=True,
    )

    sp = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )

    outputs = llm.generate([args.prompt], sp)
    text = outputs[0].outputs[0].text
    n_tok = len(outputs[0].outputs[0].token_ids)

    print("=" * 60)
    print("PROMPT:")
    print(args.prompt)
    print("=" * 60)
    print("COMPLETION:")
    print(text)
    print("=" * 60)
    print(f"({n_tok} tokens)")


if __name__ == "__main__":
    main()