"""
GLM-4.6 AWQ Quantization Script

Quantizes GLM-4.6 (357B MoE) to 4-bit AWQ for efficient inference with vLLM.

Requirements:
- 1× GPU with 48GB+ VRAM (a single GPU is sufficient; additional GPUs provide no speedup)
- 768GB+ system RAM (DDR4/DDR5)
- 300GB+ swap space (it will be actively used)
- PyTorch with CUDA support
- llm-compressor
- transformers
- datasets

Hardware Notes:
- Multi-GPU provides NO quantization speedup (the process is RAM-bound, not GPU-bound)
- The full BF16 model (~714GB) is offloaded to system RAM/swap
- Quantized using: 1× RTX PRO 6000 Blackwell Max-Q (96GB) + 768GB RAM
- Quantization time: ~5 hours (calibration, smoothing, compression, and saving)

Usage:
    python quantize_glm46_awq.py --model zai-org/GLM-4.6 --output ./GLM-4.6-AWQ

Advanced options:
    python quantize_glm46_awq.py \
        --model zai-org/GLM-4.6 \
        --output ./GLM-4.6-AWQ \
        --device-map sequential \
        --max-cpu-memory 750GiB \
        --cal-samples 512
"""
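# Preflight note: the full BF16 checkpoint (~714 GB) is materialized in system RAM
# and swap during quantization. A rough capacity check before launching (hypothetical
# snippet, requires the optional psutil package) could look like:
#
#   import psutil
#   total_bytes = psutil.virtual_memory().total + psutil.swap_memory().total
#   assert total_bytes > 900 * 1024**3, "GLM-4.6 AWQ needs roughly 1 TiB of RAM + swap"
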
import os
import argparse
import json
import shutil
import pathlib
from typing import List

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier


def add_no_split(cfg: AutoConfig, classes: List[str]) -> AutoConfig:
    """Prevent splitting specific module classes across devices."""
    ns = set(getattr(cfg, "no_split_module_classes", []) or [])
    ns.update(classes)
    cfg.no_split_module_classes = list(ns)
    return cfg


def compute_batch_size(seq_len: int, target_tokens: int) -> int:
    """Calculate the batch size that achieves the target tokens per calibration step.

    E.g., with the defaults seq_len=2048 and target_tokens=131072 this yields 64.
    """
    return max(1, target_tokens // seq_len)


def clone_and_fix_index(src_dir: str) -> str:
    """
    Clone the model directory and drop an empty-string key from weight_map if present.
    This prevents device_map='auto' errors with some sharded checkpoints.
    """
    src = pathlib.Path(src_dir)
    dst = src.parent / (src.name + "_fixed_index")
    if dst.exists():
        shutil.rmtree(dst)
    shutil.copytree(src, dst)

    candidates = ["model.safetensors.index.json", "pytorch_model.bin.index.json"]
    found = None
    for c in candidates:
        p = dst / c
        if p.exists():
            found = p
            break
    if not found:
        return str(dst)

    with open(found, "r") as f:
        idx = json.load(f)
    wm = idx.get("weight_map", {})
    if "" in wm:
        del wm[""]
        idx["weight_map"] = wm
        with open(found, "w") as f:
            json.dump(idx, f)
    return str(dst)


def main():
    parser = argparse.ArgumentParser(description="Quantize GLM-4.6 to 4-bit AWQ")
    parser.add_argument("--model", required=True, help="Path or HF ID of GLM-4.6 model (e.g., zai-org/GLM-4.6)")
    parser.add_argument("--output", required=True, help="Output directory for quantized model")
    parser.add_argument("--cal-samples", type=int, default=512, help="Number of calibration samples (default: 512)")
    parser.add_argument("--cal-seq-len", type=int, default=2048, help="Calibration sequence length (default: 2048)")
    parser.add_argument("--batch-tokens", type=int, default=131072, help="Tokens per calibration step (default: 131072)")
    parser.add_argument("--dataset", default="neuralmagic/LLM_compression_calibration", help="Calibration dataset")
    parser.add_argument("--dataset-split", default="train", help="Dataset split to use")
    parser.add_argument("--device-map", choices=["auto", "sequential"], default="auto",
                        help="Device placement strategy: 'auto' (recommended) or 'sequential' (fallback, more robust)")
    parser.add_argument("--max-memory-per-gpu", type=str, default="92GiB",
                        help="Max memory per GPU (default: 92GiB for 96GB GPUs)")
    parser.add_argument("--max-cpu-memory", type=str, default="500GiB",
                        help="Max CPU memory for offloading (default: 500GiB)")
    args = parser.parse_args()

    # Keep tokenizer workers quiet and reduce CUDA allocator fragmentation.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512")

    # Restrict to a single GPU unless overridden (must be set before the first CUDA call).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    # Enable TF32 matmuls for faster calibration on Ampere and newer GPUs.
    try:
        torch.backends.cuda.matmul.fp32_precision = "tf32"
        torch.backends.cudnn.conv.fp32_precision = "tf32"
    except Exception:
        # Older PyTorch releases expose the equivalent switches as allow_tf32.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Limit CPU threads used for the offloaded tensor work.
    torch.set_num_threads(8)

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This script requires GPU(s).")

    num_gpus = torch.cuda.device_count()
    print(f"✓ Found {num_gpus} CUDA device(s)")
    print(f"✓ Using GPU 0 for quantization (CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'all')})")
    print("\nNote: Multi-GPU provides NO speedup for quantization - the process is RAM-bound.")
    print("      The full BF16 model (~714GB) will be offloaded to system RAM/swap.")

    print(f"Loading config from: {args.model}")
    cfg = AutoConfig.from_pretrained(args.model, trust_remote_code=True)

    # Keep merged/fused linear modules on a single device when sharding.
    cfg = add_no_split(cfg, ["MergedColumnParallelLinear"])

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)

    print(f"Loading model weights from: {args.model}")
    load_dir = args.model

    # device_map='auto' can fail on checkpoints whose weight_map contains an empty key,
    # so load from a sanitized copy of the index in that mode.
    if args.device_map == "auto":
        try:
            load_dir = clone_and_fix_index(args.model)
        except Exception as e:
            print(f"Index sanitization skipped: {e}")

    # Memory budget for dispatch: fill the GPUs first, offload the remainder to CPU RAM.
    max_mem = {i: args.max_memory_per_gpu for i in range(num_gpus)}
    max_mem["cpu"] = args.max_cpu_memory

    try:
        model = AutoModelForCausalLM.from_pretrained(
            load_dir,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            device_map=args.device_map,
            config=cfg,
            max_memory=max_mem,
            offload_folder=None,
            offload_state_dict=False,
        )
    except KeyError as e:
        if args.device_map == "auto":
            print(f"Auto device_map failed with {e}; falling back to sequential...")
            model = AutoModelForCausalLM.from_pretrained(
                load_dir,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="sequential",
                config=cfg,
                max_memory=max_mem,
            )
        else:
            raise

    print("✓ Model loaded successfully")

    print("\nGPU Memory Usage:")
    for i in range(num_gpus):
        allocated = torch.cuda.memory_allocated(i) / 1e9
        peak = torch.cuda.max_memory_allocated(i) / 1e9
        print(f"  GPU {i}: {allocated:.2f} GB allocated / {peak:.2f} GB peak")

    print(f"\nLoading calibration dataset: {args.dataset}")
    ds = load_dataset(args.dataset, split=args.dataset_split)
    ds = ds.shuffle(seed=42).select(range(args.cal_samples))
    print(f"✓ Selected {len(ds)} calibration samples")

    seq_len = args.cal_seq_len
    batch_size = compute_batch_size(seq_len, args.batch_tokens)
    print(f"Calibration config: seq_len={seq_len}, batch_size={batch_size}")

    # Keep sensitive modules in full precision: embeddings, lm_head, norms,
    # the MoE router and shared experts, and the first three decoder layers.
    ignore_patterns = [
        "lm_head",
        "model.embed_tokens",
        "re:.*input_layernorm$",
        "re:.*post_attention_layernorm$",
        "model.norm",
        "re:.*q_norm$",
        "re:.*k_norm$",
        "re:.*shared_experts.*",
        "re:.*mlp\\.gate\\.weight$",
        "re:.*mlp\\.gate\\..*bias$",
        "re:model.layers.[0-2]\\.",
    ]

    # Quantize all attention and MLP/expert projection layers.
    targets = [
        "re:.*gate_proj.*",
        "re:.*up_proj.*",
        "re:.*down_proj.*",
        "re:.*k_proj.*",
        "re:.*q_proj.*",
        "re:.*v_proj.*",
        "re:.*o_proj.*",
    ]

    # AWQ recipe: 4-bit symmetric weight-only quantization with group size 128.
    recipe = [
        AWQModifier(
            ignore=ignore_patterns,
            config_groups={
                "group_0": {
                    "targets": targets,
                    "weights": {
                        "num_bits": 4,
                        "type": "int",
                        "symmetric": True,
                        "group_size": 128,
                        "strategy": "group",
                        "dynamic": False,
                    },
                    "input_activations": None,
                    "output_activations": None,
                    "format": None,
                }
            },
        )
    ]

    print("\n" + "=" * 80)
    print("Starting AWQ quantization...")
    print("=" * 80)

    with torch.inference_mode():
        oneshot_args = {
            "model": model,
            "dataset": ds,
            "recipe": recipe,
            "max_seq_length": seq_len,
            "num_calibration_samples": len(ds),
        }

        # Pass batch_size only if this llm-compressor version's oneshot() supports it.
        try:
            from inspect import signature
            if "batch_size" in signature(oneshot).parameters:
                oneshot_args["batch_size"] = batch_size
        except Exception:
            pass

        oneshot(**oneshot_args)

    print("\n✓ AWQ quantization completed successfully")

    print(f"\nSaving quantized model to: {args.output}")
    os.makedirs(args.output, exist_ok=True)

    # save_compressed=True writes the weights in compressed-tensors format for vLLM.
    model.save_pretrained(args.output, save_compressed=True)
    tokenizer.save_pretrained(args.output)

    print("\n" + "=" * 80)
    print("QUANTIZATION COMPLETE")
    print("=" * 80)
    print(f"Quantized model saved to: {args.output}")
    print("\nModel size on disk: ~176 GB (39 safetensors files)")
    print("\nTo use with vLLM:")
    print(f"  vllm serve {args.output} \\")
    print("    --tensor-parallel-size 4 \\")
    print("    --enable-expert-parallel \\")
    print("    --trust-remote-code")
    print("=" * 80)


if __name__ == "__main__":
    main()
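
# Optional post-run check: a minimal sketch of loading the quantized output with
# vLLM's Python API, mirroring the serve command printed above. The parallelism
# settings are illustrative and depend on the target hardware.
#
#   from vllm import LLM, SamplingParams
#
#   llm = LLM(
#       model="./GLM-4.6-AWQ",
#       tensor_parallel_size=4,
#       enable_expert_parallel=True,
#       trust_remote_code=True,
#   )
#   out = llm.generate(["Hello"], SamplingParams(max_tokens=32))
#   print(out[0].outputs[0].text)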