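"""Mixed-precision AutoRound quantization for a local checkpoint.

Policy: modules matching FP16_PATTERNS stay in 16-bit, self-attention
projections are quantized to INT8, and MLP projections get INT4 except in
the first/last few transformer blocks, which stay at INT8.
"""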
import re
from collections import Counter
from typing import Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

MODEL_PATH = "."
OUTPUT_DIR = "./Qwen3.6-27B-mixed-autoround"
DATASET_NAME = "NeelNanda/pile-10k"

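# Module-name substrings kept in 16-bit, plus how many transformer blocks at
# each end of the stack keep INT8 MLPs instead of INT4.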
FP16_PATTERNS = ("lm_head", "linear_attn", "visual", "mtp", "embed_tokens")
PROTECT_FIRST = 3
PROTECT_LAST = 3

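# Calibration data: number of samples and tokens per sample.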
MAX_SAMPLES = 512
SEQ_LEN = 2048


def get_layer_idx(module_name: str) -> Optional[int]:
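    """Parse the transformer block index out of a module name, or return None."""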
    match = re.search(r"model\.layers\.(\d+)\.", module_name)
    return int(match.group(1)) if match else None


def build_layer_config(model: torch.nn.Module) -> dict:
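    """Map every torch.nn.Linear module name to its AutoRound quantization config."""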
    indices = {get_layer_idx(name) for name, _ in model.named_modules()}
    indices.discard(None)

    num_layers = max(indices) + 1 if indices else 0
    print(f"language_model layer count: {num_layers}")

    boundary_layers = set(range(PROTECT_FIRST)) | set(range(num_layers - PROTECT_LAST, num_layers))
    print(f"Boundary layers (MLP -> INT8): {sorted(boundary_layers)}")

    layer_config = {}
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue

        if any(p in name for p in FP16_PATTERNS):
            layer_config[name] = {"bits": 16}

        elif "self_attn" in name:
            layer_config[name] = {"bits": 8, "group_size": 128, "sym": True}

        elif "mlp" in name:
            idx = get_layer_idx(name)
            bits = 8 if (idx is not None and idx in boundary_layers) else 4
            layer_config[name] = {"bits": bits, "group_size": 128, "sym": True}

        else:
            layer_config[name] = {"bits": 8, "group_size": 128, "sym": True}
            print(f"[fallback to int8] {name}")

    return layer_config


def collect_calibration_samples(tokenizer) -> list:
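    """Tokenize pile-10k texts and keep only samples that reach the full SEQ_LEN."""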
    dataset = load_dataset(DATASET_NAME, split="train")
    samples = []

    for item in dataset:
        tokenized = tokenizer(
            item["text"],
            truncation=True,
            max_length=SEQ_LEN,
            return_tensors="pt",
        )

        if tokenized["input_ids"].shape[-1] >= SEQ_LEN:
            samples.append(tokenized.data)

        if len(samples) >= MAX_SAMPLES:
            break

    return samples


def main():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )

    layer_config = build_layer_config(model)

    bits_counter = Counter(cfg["bits"] for cfg in layer_config.values())
    print(f"Layer count by bits: {dict(bits_counter)}")
    print(layer_config)

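    # The loaded model was only needed to enumerate module names for the
    # per-layer config; free it before quantization.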
    del model
    torch.cuda.empty_cache()

    tokens_list = collect_calibration_samples(tokenizer)

    print(f"collected calibration samples: {len(tokens_list)}")
    assert len(tokens_list) == MAX_SAMPLES, f"expected {MAX_SAMPLES} calibration samples, got {len(tokens_list)}"

    print(f"len(tokens_list) = {len(tokens_list)}")
    print(f"first input_ids shape  = {tokens_list[0]['input_ids'].shape}")
    print(f"last  input_ids shape  = {tokens_list[-1]['input_ids'].shape}")
    print(f"first dtype = {tokens_list[0]['input_ids'].dtype}")

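    # AutoRound gets the checkpoint path (the in-memory model was freed above),
    # the per-layer bit widths via layer_config, and the pre-tokenized
    # calibration samples via dataset.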
    ar = AutoRound(
        model=MODEL_PATH,
        tokenizer=tokenizer,
        scheme="W4A16",
        enable_torch_compile=True,
        group_size=128,
        sym=True,
        layer_config=layer_config,
        dataset=tokens_list,
        device_map="0,1",
        batch_size=8,
        seqlen=SEQ_LEN,
        iters=1000,
        nsamples=MAX_SAMPLES,
    )
    ar.quantize_and_save(OUTPUT_DIR, format="auto_round")


if __name__ == "__main__":
    main()