#!/usr/bin/env python3 # # Copyright (C) 2023 - 2026 Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT # # Quantization script for Step-3.5-Flash with MoE layer replacement from __future__ import annotations import argparse import os import re import shutil from pathlib import Path from types import MethodType from typing import Any import torch import torch.nn as nn from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors from quark.torch.utils.llm import ( get_calib_dataloader, get_model, get_tokenizer, ) from quark.common.utils.log import ScreenLogger try: # Needed only when the model is loaded with accelerate offload (meta tensors). from accelerate.hooks import AlignDevicesHook, add_hook_to_module # type: ignore from accelerate.utils import PrefixedDataset # type: ignore _ACCELERATE_AVAILABLE = True except Exception: AlignDevicesHook = None # type: ignore[assignment] add_hook_to_module = None # type: ignore[assignment] PrefixedDataset = None # type: ignore[assignment] _ACCELERATE_AVAILABLE = False DEFAULT_INPUT_MODEL_PATH = "stepfun-ai/Step-3.5-Flash" DEFAULT_OUTPUT_MODEL_PATH = "quantized_models/Step-3.5-Flash-MXFP4" logger = ScreenLogger(__name__) def _step35_template_exclude_layers() -> list[str]: return [ # embeddings / lm head / norms "model.embed_tokens*", "*embed_tokens*", "*lm_head*", "*layernorm*", "*norm*", # Router gate "*moe.gate", "*moe.router_bias*", # The first three blocks use dense FFNs "model.layers.0.mlp.*", "model.layers.1.mlp.*", "model.layers.2.mlp.*", # Shared Experts "*share_expert*", "*self_attn*", ] PRESETS: dict[str, dict[str, Any]] = { "mxfp4_moe_only_no_kvcache": { "quant_scheme": "mxfp4", "exclude_layers": _step35_template_exclude_layers(), }, } def _copy_non_weight_files(src_dir: str, dst_dir: str) -> None: """ Copy non-weight files from an HF model directory (json/jinja/tokenizer, etc.), while skipping *.safetensors and model.safetensors.index.json. Note: `export_safetensors` exports the essential HF weights and config, but the original model directory may contain extra assets (e.g. chat_template.jinja). We do a conservative copy here so offline inference keeps those auxiliary files. """ src = Path(src_dir) dst = Path(dst_dir) dst.mkdir(parents=True, exist_ok=True) for p in src.iterdir(): if p.is_dir(): continue name = p.name if name.endswith(".safetensors"): continue if name == "model.safetensors.index.json": continue # Export will (re-)write config / generation_config; copying them here is harmless # (later writes will overwrite). shutil.copy2(p, dst / name) def _register_step35_flash_template() -> None: """ Register a Quark LLMTemplate for Step-3.5-Flash (config.model_type = step3p5). """ model_type = "step3p5" if model_type in LLMTemplate.list_available(): return step35_flash_template = LLMTemplate( model_type=model_type, kv_layers_name=["*k_proj", "*v_proj"], q_layer_name="*q_proj", exclude_layers_name=_step35_template_exclude_layers(), ) LLMTemplate.register_template(step35_flash_template) logger.info("Registered LLMTemplate: %s", model_type) @torch.no_grad() def replace_step35_moelinear_with_linear(moe_module: Any) -> None: """ Convert Step3p5MoEMLP's MoELinear modules into separate Linear layers per expert. """ if getattr(moe_module, "_step35_replaced", False): return logger.debug("Converting Step3p5MoEMLP experts to separate gate/up/down Linear layers...") # Get dimensions from the module num_experts: int = int(getattr(moe_module, "moe_num_experts", 288)) hidden_size: int = int(getattr(moe_module, "hidden_size", 4096)) moe_intermediate_size: int = int(getattr(moe_module, "moe_intermediate_size", 1280)) # Store original device and dtype from one of the MoELinear modules original_device = moe_module.gate_proj.weight.device original_dtype = moe_module.gate_proj.weight.dtype # [num_experts, in, out] # Expose common attribute names for the forward helper moe_module.hidden_size = hidden_size moe_module.expert_dim = moe_intermediate_size moe_module.num_experts = num_experts is_meta: bool = original_device == torch.device("meta") target_device_for_new = original_device if not is_meta else torch.device("meta") # Create individual expert modules, each containing gate_proj, up_proj, down_proj for expert_index in range(num_experts): expert_module = nn.Module() expert_module.gate_proj = nn.Linear( hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype ) expert_module.up_proj = nn.Linear( hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype ) expert_module.down_proj = nn.Linear( moe_intermediate_size, hidden_size, bias=False, device=target_device_for_new, dtype=original_dtype ) setattr(moe_module, str(expert_index), expert_module) # Sync weights from MoELinear to individual Linear modules weights_synced = _step35_sync_weights_to_linear(moe_module) # Replace forward method moe_module.forward = MethodType(_step35_moe_forward, moe_module) if weights_synced: _step35_cleanup_fused(moe_module) moe_module._step35_replaced = True @torch.no_grad() def _step35_sync_weights_to_linear(module: Any) -> bool: """ Split MoELinear weights and copy into per-expert Linear layers. Returns True if synced; returns False if fused weights are still on 'meta' (not materialized). MoELinear tensors in Step3p5MoEMLP are expected to be: - gate_proj.weight: [num_experts, moe_intermediate_size, hidden_size] - up_proj.weight: [num_experts, moe_intermediate_size, hidden_size] - down_proj.weight: [num_experts, hidden_size, moe_intermediate_size] """ if getattr(module, "_weights_synced", False): return True W_gate = getattr(module, "gate_proj", None) W_up = getattr(module, "up_proj", None) W_down = getattr(module, "down_proj", None) if W_gate is None or W_up is None or W_down is None: return False is_offload = getattr(W_gate.weight, "is_meta", False) or W_gate.weight.device == torch.device("meta") if is_offload: # Loaded with accelerate offload: tensors live in module._hf_hook.weights_map on CPU. if not _ACCELERATE_AVAILABLE: raise RuntimeError( "Model appears to be loaded with accelerate offload (meta tensors), but accelerate is not available." ) if not hasattr(module, "_hf_hook"): return False W_gate = module._hf_hook.weights_map["gate_proj.weight"] W_up = module._hf_hook.weights_map["up_proj.weight"] W_down = module._hf_hook.weights_map["down_proj.weight"] try: for expert_index in range(int(module.num_experts)): expert_module = getattr(module, str(expert_index)) W_gate_current = W_gate.weight[expert_index] # [moe_intermediate_size, hidden_size] W_up_current = W_up.weight[expert_index] # [moe_intermediate_size, hidden_size] W_down_current = W_down.weight[expert_index] # [hidden_size, moe_intermediate_size] if is_offload: hook = module._hf_hook dataset = hook.weights_map.dataset layer_value = [W_gate_current, W_up_current, W_down_current] for idx, layer_name in enumerate(["gate_proj", "up_proj", "down_proj"]): prefix = f"{hook.weights_map.prefix}{expert_index}.{layer_name}." prefixed_weights_map = PrefixedDataset(dataset, prefix) full_name = f"{prefix}weight" dataset.all_keys.append(full_name) dataset.state_dict[full_name] = layer_value[idx] quark_hook = AlignDevicesHook( execution_device=hook.execution_device, offload=hook.offload, io_same_device=hook.io_same_device, weights_map=prefixed_weights_map, offload_buffers=hook.offload_buffers, place_submodules=hook.place_submodules, skip_keys=hook.skip_keys, tied_params_map=hook.tied_params_map, ) linear_module = getattr(expert_module, layer_name) add_hook_to_module(linear_module, quark_hook) else: # No transpose needed: nn.Linear expects [out_features, in_features], which matches MoELinear tensors. expert_module.gate_proj.weight.data.copy_(W_gate_current.to(W_gate.weight.device)) expert_module.up_proj.weight.data.copy_(W_up_current.to(W_up.weight.device)) expert_module.down_proj.weight.data.copy_(W_down_current.to(W_down.weight.device)) if is_offload: prefix = module._hf_hook.weights_map.prefix del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}gate_proj.weight"] del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}up_proj.weight"] del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}down_proj.weight"] module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}gate_proj.weight") module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}up_proj.weight") module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}down_proj.weight") module._weights_synced = True return True except Exception as e: logger.warning("Failed to sync Step3.5 MoE weights: %s", e) return False @torch.no_grad() def _step35_cleanup_fused(module: Any) -> None: """Optionally remove fused MoELinear modules after replacement.""" # The original MoELinear modules should be garbage collected # when they're replaced, but we can explicitly clear references for proj_name in ["gate_proj", "up_proj", "down_proj"]: # Clear any remaining references to original MoELinear if hasattr(module, proj_name): delattr(module, proj_name) torch.cuda.empty_cache() logger.debug(f"Cleaned up original MoELinear modules") def _step35_moe_forward(self: Any, hidden_states: torch.Tensor) -> torch.Tensor: """ Forward using per-expert gate_proj, up_proj, down_proj (nn.Linear), matching the original Step3p5MoEMLP.forward semantics but without MoELinear. """ synced = _step35_sync_weights_to_linear(self) if not synced: raise RuntimeError( "Step3p5MoEMLP weights are on 'meta' (not materialized). " "Move fused parameters to a real device first, then call forward." ) batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # Router/gating if self.need_fp32_gate: router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32)) else: # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) # Custom routing or standard softmax + top-k if hasattr(self, 'custom_routing_function') and self.custom_routing_function: routing_weights, selected_experts = self.custom_routing_function( router_logits, self.top_k, renormalize=True) else: routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # Apply scaling factor routing_weights = routing_weights * self.routed_scaling_factor # Initialize output final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be solicited expert_mask = torch.nn.functional.one_hot( selected_experts, num_classes=self.num_experts).permute(2, 1, 0) limit = getattr(self, 'limit', None) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) expert_module = getattr(self, str(expert_idx)) up = expert_module.up_proj(current_state) gate = self.act_fn(expert_module.gate_proj(current_state)) if limit is not None: gate = gate.clamp(min=None, max=limit) up = up.clamp(min=-limit, max=limit) current_hidden_states = expert_module.down_proj(gate * up) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states @torch.no_grad() def patch_step35_moe(model: nn.Module) -> int: """ Apply Step-3.5-Flash MoE replacement to all Step3p5MoEMLP modules in the model. """ patched = 0 for name, module in model.named_modules(remove_duplicate=False): if module.__class__.__name__ == "Step3p5MoEMLP": replace_step35_moelinear_with_linear(module) patched += 1 logger.debug(f"Patched MoE module: {name}") if patched > 0: logger.info("Patched %d Step3p5MoEMLP module(s) for quantization.", patched) return patched def _resolve_calib_device(device: str, model: nn.Module) -> str: """ Resolve a torch-compatible device string for calibration inputs. """ if device != "auto": return str(device) hf_map = getattr(model, "hf_device_map", None) if isinstance(hf_map, dict): cuda_ids: list[int] = [] for v in hf_map.values(): m = re.match(r"^cuda:(\d+)$", str(v)) if m: cuda_ids.append(int(m.group(1))) if cuda_ids: return f"cuda:{min(cuda_ids)}" if torch.cuda.is_available(): return "cuda:0" return "cpu" def main(args: argparse.Namespace) -> None: os.makedirs(args.output_quantized_hf_path, exist_ok=True) _register_step35_flash_template() if getattr(args, "preset", None): preset_cfg = PRESETS[args.preset] args.quant_scheme = preset_cfg["quant_scheme"] if getattr(args, "quant_algo", None) is None and "quant_algo" in preset_cfg: args.quant_algo = preset_cfg["quant_algo"] logger.info("Using preset: %s", args.preset) logger.info("Input model: %s", args.model_dir) logger.info("Output dir: %s", args.output_quantized_hf_path) logger.info("Step 1/4: Loading model and tokenizer ...") model, _ = get_model( args.model_dir, data_type=args.data_type, device=args.device, multi_gpu=args.multi_gpu, multi_device=args.multi_device, attn_implementation=args.model_attn_implementation, trust_remote_code=args.trust_remote_code, ) patch_step35_moe(model) model_type = model.config.model_type if hasattr(model.config, "model_type") else model.config.architectures[0] tokenizer = get_tokenizer( args.model_dir, max_seq_len=args.seq_len, model_type=model_type, trust_remote_code=args.trust_remote_code ) logger.info("Step 2/4: Building calibration dataloader ...") base_device = str(model.device) if (args.multi_gpu or args.multi_device) else str(args.device) main_device = _resolve_calib_device(base_device, model) logger.info("Calibration dataset: %s", args.dataset) calib_dataloader = get_calib_dataloader( dataset_name=args.dataset, tokenizer=tokenizer, batch_size=args.batch_size, num_calib_data=args.num_calib_data, seqlen=args.seq_len, device=main_device, ) logger.info("Step 3/4: Quantizing ...") template = LLMTemplate.get(model_type) if args.exclude_layers is not None: logger.warning( "Ignoring --exclude_layers (%s). This script always uses " "_register_step35_flash_template excludes for Step-3.5-Flash.", args.exclude_layers, ) exclude_layers = _step35_template_exclude_layers() logger.info("Exclude layers (template): %s", exclude_layers) if getattr(args, "quant_algo", None): logger.info("Quantization algorithm(s): %s", args.quant_algo) quant_config = template.get_config( scheme=args.quant_scheme, algorithm=args.quant_algo, kv_cache_scheme=None, min_kv_scale=0.0, layer_config={}, attention_scheme=None, exclude_layers=exclude_layers, algo_configs=None, ) quantizer = ModelQuantizer(quant_config, args.multi_device) model = quantizer.quantize_model(model, calib_dataloader) model = quantizer.freeze(model) logger.info("Step 4/4: Exporting HF safetensors ...") _copy_non_weight_files(args.model_dir, args.output_quantized_hf_path) with torch.no_grad(): export_safetensors( model=model, output_dir=args.output_quantized_hf_path, custom_mode="quark", weight_format=args.export_weight_format, pack_method=args.pack_method, ) tokenizer.save_pretrained(args.output_quantized_hf_path) logger.info("Export completed.") logger.info("========== Quantization Completed Successfully ==========") if __name__ == "__main__": parser = argparse.ArgumentParser( description="Offline quantization for Step-3.5-Flash with MoE layer replacement" ) parser.add_argument("--model_dir", dest="model_dir", type=str, default=DEFAULT_INPUT_MODEL_PATH) parser.add_argument("--output_dir", dest="output_quantized_hf_path", type=str, default=DEFAULT_OUTPUT_MODEL_PATH) # Model loading parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"]) parser.add_argument("--multi_gpu", dest="multi_gpu", action="store_true") parser.add_argument("--multi_device", dest="multi_device", action="store_true") parser.add_argument( "--model_attn_implementation", dest="model_attn_implementation", type=str, default="eager", choices=["eager", "sdpa", "flash_attention_2"], ) parser.add_argument( "--data_type", dest="data_type", type=str, default="auto", choices=["auto", "float16", "bfloat16", "float32"], ) # Calibration parser.add_argument( "--dataset", dest="dataset", type=str, default="pileval", help="Calibration dataset name. Default is 'pileval'.", ) parser.add_argument("--seq_len", dest="seq_len", type=int, default=512) parser.add_argument("--batch_size", dest="batch_size", type=int, default=1) parser.add_argument("--num_calib_data", dest="num_calib_data", type=int, default=128) # Quantization parser.add_argument( "--preset", dest="preset", type=str, choices=sorted(PRESETS.keys()), default="mxfp4_moe_only_no_kvcache", help="Convenience preset for quantization settings.", ) parser.add_argument( "--quant_algo", dest="quant_algo", type=str, default=None, help="Optional quantization algorithm(s) to apply.", ) parser.add_argument( "--exclude_layers", type=str, nargs="*", default=None, help="Layer wildcard patterns to exclude from quantization.", ) # Export parser.add_argument("--pack_method", dest="pack_method", type=str, default="reorder", choices=["order", "reorder"]) parser.add_argument( "--export_weight_format", dest="export_weight_format", type=str, default="real_quantized", choices=["fake_quantized", "real_quantized"], ) group = parser.add_mutually_exclusive_group() group.add_argument( "--trust_remote_code", action="store_true", dest="trust_remote_code", help="Enable execution of custom model code from the Hub (use only with repositories you fully trust).", ) group.add_argument( "--no_trust_remote_code", action="store_false", dest="trust_remote_code", help="Disable execution of custom model code from the Hub (safer, recommended if unsure).", ) parser.set_defaults(trust_remote_code=True) main(parser.parse_args())