# Step-3.5-Flash-MXFP4 / step3p5_quantize_quark.py
# Colin Zeng
# Model Upload
# bd26087
#!/usr/bin/env python3
#
# Copyright (C) 2023 - 2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Quantization script for Step-3.5-Flash with MoE layer replacement
from __future__ import annotations
import argparse
import os
import re
import shutil
from pathlib import Path
from types import MethodType
from typing import Any
import torch
import torch.nn as nn
from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors
from quark.torch.utils.llm import (
get_calib_dataloader,
get_model,
get_tokenizer,
)
from quark.common.utils.log import ScreenLogger
try:
# Needed only when the model is loaded with accelerate offload (meta tensors).
from accelerate.hooks import AlignDevicesHook, add_hook_to_module # type: ignore
from accelerate.utils import PrefixedDataset # type: ignore
_ACCELERATE_AVAILABLE = True
except Exception:
AlignDevicesHook = None # type: ignore[assignment]
add_hook_to_module = None # type: ignore[assignment]
PrefixedDataset = None # type: ignore[assignment]
_ACCELERATE_AVAILABLE = False
DEFAULT_INPUT_MODEL_PATH = "stepfun-ai/Step-3.5-Flash"
DEFAULT_OUTPUT_MODEL_PATH = "quantized_models/Step-3.5-Flash-MXFP4"
logger = ScreenLogger(__name__)
def _step35_template_exclude_layers() -> list[str]:
return [
# embeddings / lm head / norms
"model.embed_tokens*",
"*embed_tokens*",
"*lm_head*",
"*layernorm*",
"*norm*",
# Router gate
"*moe.gate",
"*moe.router_bias*",
# The first three blocks use dense FFNs
"model.layers.0.mlp.*",
"model.layers.1.mlp.*",
"model.layers.2.mlp.*",
# Shared Experts
"*share_expert*",
"*self_attn*",
]
# Named presets selectable via --preset. `main` reads "quant_scheme" and, if
# present, "quant_algo"; the "exclude_layers" entry is informational here since
# the script always applies `_step35_template_exclude_layers()` directly.
PRESETS: dict[str, dict[str, Any]] = {
    "mxfp4_moe_only_no_kvcache": {
        "quant_scheme": "mxfp4",
        "exclude_layers": _step35_template_exclude_layers(),
    },
}
def _copy_non_weight_files(src_dir: str, dst_dir: str) -> None:
"""
Copy non-weight files from an HF model directory (json/jinja/tokenizer, etc.),
while skipping *.safetensors and model.safetensors.index.json.
Note: `export_safetensors` exports the essential HF weights and config, but the
original model directory may contain extra assets (e.g. chat_template.jinja).
We do a conservative copy here so offline inference keeps those auxiliary files.
"""
src = Path(src_dir)
dst = Path(dst_dir)
dst.mkdir(parents=True, exist_ok=True)
for p in src.iterdir():
if p.is_dir():
continue
name = p.name
if name.endswith(".safetensors"):
continue
if name == "model.safetensors.index.json":
continue
# Export will (re-)write config / generation_config; copying them here is harmless
# (later writes will overwrite).
shutil.copy2(p, dst / name)
def _register_step35_flash_template() -> None:
    """
    Register a Quark LLMTemplate for Step-3.5-Flash (config.model_type = step3p5).

    No-op if a template for that model type is already registered.
    """
    model_type = "step3p5"
    if model_type in LLMTemplate.list_available():
        return
    LLMTemplate.register_template(
        LLMTemplate(
            model_type=model_type,
            kv_layers_name=["*k_proj", "*v_proj"],
            q_layer_name="*q_proj",
            exclude_layers_name=_step35_template_exclude_layers(),
        )
    )
    logger.info("Registered LLMTemplate: %s", model_type)
@torch.no_grad()
def replace_step35_moelinear_with_linear(moe_module: Any) -> None:
    """
    Convert Step3p5MoEMLP's fused MoELinear modules into separate nn.Linear
    layers, one (gate_proj, up_proj, down_proj) triple per expert.

    Idempotent via the `_step35_replaced` flag. The module's forward is swapped
    for `_step35_moe_forward`, which routes through the per-expert Linears.
    Fused tensors are removed only after a successful weight sync, so models
    loaded with accelerate offload (weights on 'meta') can sync lazily on the
    first forward instead.
    """
    # Skip modules that were already converted.
    if getattr(moe_module, "_step35_replaced", False):
        return
    logger.debug("Converting Step3p5MoEMLP experts to separate gate/up/down Linear layers...")
    # Get dimensions from the module; the fallbacks mirror Step-3.5-Flash defaults.
    num_experts: int = int(getattr(moe_module, "moe_num_experts", 288))
    hidden_size: int = int(getattr(moe_module, "hidden_size", 4096))
    moe_intermediate_size: int = int(getattr(moe_module, "moe_intermediate_size", 1280))
    # Store original device and dtype from one of the MoELinear modules.
    original_device = moe_module.gate_proj.weight.device
    original_dtype = moe_module.gate_proj.weight.dtype
    # Expose common attribute names for the forward helper.
    moe_module.hidden_size = hidden_size
    moe_module.expert_dim = moe_intermediate_size
    moe_module.num_experts = num_experts
    # With accelerate offload the fused weights live on 'meta'; the new Linears
    # must then also be created on 'meta' and filled in during the later sync.
    is_meta: bool = original_device == torch.device("meta")
    target_device_for_new = original_device if not is_meta else torch.device("meta")
    # Create individual expert modules, each containing gate_proj, up_proj, down_proj.
    for expert_index in range(num_experts):
        expert_module = nn.Module()
        expert_module.gate_proj = nn.Linear(
            hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype
        )
        expert_module.up_proj = nn.Linear(
            hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype
        )
        expert_module.down_proj = nn.Linear(
            moe_intermediate_size, hidden_size, bias=False, device=target_device_for_new, dtype=original_dtype
        )
        # Experts become numeric attributes: getattr(moe_module, "0"), "1", ...
        setattr(moe_module, str(expert_index), expert_module)
    # Sync weights from MoELinear to the individual Linear modules
    # (may legitimately return False for not-yet-materialized meta tensors).
    weights_synced = _step35_sync_weights_to_linear(moe_module)
    # Replace forward so routing uses the per-expert Linears.
    moe_module.forward = MethodType(_step35_moe_forward, moe_module)
    if weights_synced:
        # Only drop the fused tensors once their data has been copied out.
        _step35_cleanup_fused(moe_module)
    moe_module._step35_replaced = True
@torch.no_grad()
def _step35_sync_weights_to_linear(module: Any) -> bool:
"""
Split MoELinear weights and copy into per-expert Linear layers.
Returns True if synced; returns False if fused weights are still on 'meta' (not materialized).
MoELinear tensors in Step3p5MoEMLP are expected to be:
- gate_proj.weight: [num_experts, moe_intermediate_size, hidden_size]
- up_proj.weight: [num_experts, moe_intermediate_size, hidden_size]
- down_proj.weight: [num_experts, hidden_size, moe_intermediate_size]
"""
if getattr(module, "_weights_synced", False):
return True
W_gate = getattr(module, "gate_proj", None)
W_up = getattr(module, "up_proj", None)
W_down = getattr(module, "down_proj", None)
if W_gate is None or W_up is None or W_down is None:
return False
is_offload = getattr(W_gate.weight, "is_meta", False) or W_gate.weight.device == torch.device("meta")
if is_offload:
# Loaded with accelerate offload: tensors live in module._hf_hook.weights_map on CPU.
if not _ACCELERATE_AVAILABLE:
raise RuntimeError(
"Model appears to be loaded with accelerate offload (meta tensors), but accelerate is not available."
)
if not hasattr(module, "_hf_hook"):
return False
W_gate = module._hf_hook.weights_map["gate_proj.weight"]
W_up = module._hf_hook.weights_map["up_proj.weight"]
W_down = module._hf_hook.weights_map["down_proj.weight"]
try:
for expert_index in range(int(module.num_experts)):
expert_module = getattr(module, str(expert_index))
W_gate_current = W_gate.weight[expert_index] # [moe_intermediate_size, hidden_size]
W_up_current = W_up.weight[expert_index] # [moe_intermediate_size, hidden_size]
W_down_current = W_down.weight[expert_index] # [hidden_size, moe_intermediate_size]
if is_offload:
hook = module._hf_hook
dataset = hook.weights_map.dataset
layer_value = [W_gate_current, W_up_current, W_down_current]
for idx, layer_name in enumerate(["gate_proj", "up_proj", "down_proj"]):
prefix = f"{hook.weights_map.prefix}{expert_index}.{layer_name}."
prefixed_weights_map = PrefixedDataset(dataset, prefix)
full_name = f"{prefix}weight"
dataset.all_keys.append(full_name)
dataset.state_dict[full_name] = layer_value[idx]
quark_hook = AlignDevicesHook(
execution_device=hook.execution_device,
offload=hook.offload,
io_same_device=hook.io_same_device,
weights_map=prefixed_weights_map,
offload_buffers=hook.offload_buffers,
place_submodules=hook.place_submodules,
skip_keys=hook.skip_keys,
tied_params_map=hook.tied_params_map,
)
linear_module = getattr(expert_module, layer_name)
add_hook_to_module(linear_module, quark_hook)
else:
# No transpose needed: nn.Linear expects [out_features, in_features], which matches MoELinear tensors.
expert_module.gate_proj.weight.data.copy_(W_gate_current.to(W_gate.weight.device))
expert_module.up_proj.weight.data.copy_(W_up_current.to(W_up.weight.device))
expert_module.down_proj.weight.data.copy_(W_down_current.to(W_down.weight.device))
if is_offload:
prefix = module._hf_hook.weights_map.prefix
del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}gate_proj.weight"]
del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}up_proj.weight"]
del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}down_proj.weight"]
module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}gate_proj.weight")
module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}up_proj.weight")
module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}down_proj.weight")
module._weights_synced = True
return True
except Exception as e:
logger.warning("Failed to sync Step3.5 MoE weights: %s", e)
return False
@torch.no_grad()
def _step35_cleanup_fused(module: Any) -> None:
    """Remove the fused MoELinear attributes once per-expert copies exist."""
    # Deleting the attributes drops the last references to the fused tensors
    # so they can be garbage collected.
    for fused_name in ("gate_proj", "up_proj", "down_proj"):
        if hasattr(module, fused_name):
            delattr(module, fused_name)
    torch.cuda.empty_cache()
    logger.debug("Cleaned up original MoELinear modules")
def _step35_moe_forward(self: Any, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Forward using per-expert gate_proj, up_proj, down_proj (nn.Linear),
    matching the original Step3p5MoEMLP.forward semantics but without MoELinear.

    `hidden_states` is (batch, seq_len, hidden_dim); the return value has the
    same shape. NOTE(review): assumes self.gate, self.need_fp32_gate, self.top_k,
    self.routed_scaling_factor and self.act_fn come from the original
    Step3p5MoEMLP module — confirm against the modeling code.
    """
    # Lazily sync weights for offloaded models whose fused tensors were still
    # on 'meta' when the module was converted.
    synced = _step35_sync_weights_to_linear(self)
    if not synced:
        raise RuntimeError(
            "Step3p5MoEMLP weights are on 'meta' (not materialized). "
            "Move fused parameters to a real device first, then call forward."
        )
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # Flatten tokens: (batch * seq_len, hidden_dim).
    hidden_states = hidden_states.view(-1, hidden_dim)
    # Router/gating; optionally computed in fp32.
    if self.need_fp32_gate:
        router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
    else:
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
    # Custom routing or standard softmax + top-k
    if hasattr(self, 'custom_routing_function') and self.custom_routing_function:
        routing_weights, selected_experts = self.custom_routing_function(
            router_logits, self.top_k, renormalize=True)
    else:
        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    # Apply scaling factor
    routing_weights = routing_weights * self.routed_scaling_factor
    # Initialize output accumulator over flattened tokens.
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device)
    # One hot encode the selected experts to create an expert mask
    # this will be used to easily index which expert is going to be solicited
    expert_mask = torch.nn.functional.one_hot(
        selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
    # Optional activation clamp; absent attribute means no clamping.
    limit = getattr(self, 'limit', None)
    # Loop over all available experts in the model and perform the computation on each expert
    for expert_idx in range(self.num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        # Index the correct hidden states and compute the expert hidden state for
        # the current expert. We need to make sure to multiply the output hidden
        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        expert_module = getattr(self, str(expert_idx))
        up = expert_module.up_proj(current_state)
        gate = self.act_fn(expert_module.gate_proj(current_state))
        if limit is not None:
            # Clamp gate from above only; clamp up symmetrically.
            gate = gate.clamp(min=None, max=limit)
            up = up.clamp(min=-limit, max=limit)
        current_hidden_states = expert_module.down_proj(gate * up) * routing_weights[top_x, idx, None]
        # However `index_add_` only support torch tensors for indexing so we'll use
        # the `top_x` tensor here.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states
@torch.no_grad()
def patch_step35_moe(model: nn.Module) -> int:
    """
    Apply the Step-3.5-Flash MoE replacement to every Step3p5MoEMLP module in
    `model` and return how many modules were patched.
    """
    patched = 0
    for name, submodule in model.named_modules(remove_duplicate=False):
        # Match by class name so the modeling code need not be importable here.
        if type(submodule).__name__ != "Step3p5MoEMLP":
            continue
        replace_step35_moelinear_with_linear(submodule)
        patched += 1
        logger.debug(f"Patched MoE module: {name}")
    if patched > 0:
        logger.info("Patched %d Step3p5MoEMLP module(s) for quantization.", patched)
    return patched
def _resolve_calib_device(device: str, model: nn.Module) -> str:
"""
Resolve a torch-compatible device string for calibration inputs.
"""
if device != "auto":
return str(device)
hf_map = getattr(model, "hf_device_map", None)
if isinstance(hf_map, dict):
cuda_ids: list[int] = []
for v in hf_map.values():
m = re.match(r"^cuda:(\d+)$", str(v))
if m:
cuda_ids.append(int(m.group(1)))
if cuda_ids:
return f"cuda:{min(cuda_ids)}"
if torch.cuda.is_available():
return "cuda:0"
return "cpu"
def main(args: argparse.Namespace) -> None:
    """
    Run the full quantization pipeline: register the Quark template, load the
    model/tokenizer, patch MoE layers, build a calibration dataloader, quantize,
    and export HF safetensors plus auxiliary files to `args.output_quantized_hf_path`.
    """
    os.makedirs(args.output_quantized_hf_path, exist_ok=True)
    _register_step35_flash_template()
    # Apply preset defaults; an explicit --quant_algo is never overridden.
    if getattr(args, "preset", None):
        preset_cfg = PRESETS[args.preset]
        args.quant_scheme = preset_cfg["quant_scheme"]
        if getattr(args, "quant_algo", None) is None and "quant_algo" in preset_cfg:
            args.quant_algo = preset_cfg["quant_algo"]
        logger.info("Using preset: %s", args.preset)
    logger.info("Input model: %s", args.model_dir)
    logger.info("Output dir: %s", args.output_quantized_hf_path)
    logger.info("Step 1/4: Loading model and tokenizer ...")
    model, _ = get_model(
        args.model_dir,
        data_type=args.data_type,
        device=args.device,
        multi_gpu=args.multi_gpu,
        multi_device=args.multi_device,
        attn_implementation=args.model_attn_implementation,
        trust_remote_code=args.trust_remote_code,
    )
    # Replace fused MoELinear experts with per-expert nn.Linear so Quark can
    # quantize them individually.
    patch_step35_moe(model)
    model_type = model.config.model_type if hasattr(model.config, "model_type") else model.config.architectures[0]
    tokenizer = get_tokenizer(
        args.model_dir, max_seq_len=args.seq_len, model_type=model_type, trust_remote_code=args.trust_remote_code
    )
    logger.info("Step 2/4: Building calibration dataloader ...")
    # Calibration tensors must land on a concrete device, not "auto".
    base_device = str(model.device) if (args.multi_gpu or args.multi_device) else str(args.device)
    main_device = _resolve_calib_device(base_device, model)
    logger.info("Calibration dataset: %s", args.dataset)
    calib_dataloader = get_calib_dataloader(
        dataset_name=args.dataset,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_calib_data=args.num_calib_data,
        seqlen=args.seq_len,
        device=main_device,
    )
    logger.info("Step 3/4: Quantizing ...")
    template = LLMTemplate.get(model_type)
    # This script always uses the curated Step-3.5 exclude patterns; a
    # user-supplied --exclude_layers is logged and ignored.
    if args.exclude_layers is not None:
        logger.warning(
            "Ignoring --exclude_layers (%s). This script always uses "
            "_register_step35_flash_template excludes for Step-3.5-Flash.",
            args.exclude_layers,
        )
    exclude_layers = _step35_template_exclude_layers()
    logger.info("Exclude layers (template): %s", exclude_layers)
    if getattr(args, "quant_algo", None):
        logger.info("Quantization algorithm(s): %s", args.quant_algo)
    quant_config = template.get_config(
        scheme=args.quant_scheme,
        algorithm=args.quant_algo,
        kv_cache_scheme=None,
        min_kv_scale=0.0,
        layer_config={},
        attention_scheme=None,
        exclude_layers=exclude_layers,
        algo_configs=None,
    )
    quantizer = ModelQuantizer(quant_config, args.multi_device)
    model = quantizer.quantize_model(model, calib_dataloader)
    model = quantizer.freeze(model)
    logger.info("Step 4/4: Exporting HF safetensors ...")
    # Copy auxiliary assets first; export then (re-)writes weights and config.
    _copy_non_weight_files(args.model_dir, args.output_quantized_hf_path)
    with torch.no_grad():
        export_safetensors(
            model=model,
            output_dir=args.output_quantized_hf_path,
            custom_mode="quark",
            weight_format=args.export_weight_format,
            pack_method=args.pack_method,
        )
    tokenizer.save_pretrained(args.output_quantized_hf_path)
    logger.info("Export completed.")
    logger.info("========== Quantization Completed Successfully ==========")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Offline quantization for Step-3.5-Flash with MoE layer replacement"
)
parser.add_argument("--model_dir", dest="model_dir", type=str, default=DEFAULT_INPUT_MODEL_PATH)
parser.add_argument("--output_dir", dest="output_quantized_hf_path", type=str, default=DEFAULT_OUTPUT_MODEL_PATH)
# Model loading
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
parser.add_argument("--multi_gpu", dest="multi_gpu", action="store_true")
parser.add_argument("--multi_device", dest="multi_device", action="store_true")
parser.add_argument(
"--model_attn_implementation",
dest="model_attn_implementation",
type=str,
default="eager",
choices=["eager", "sdpa", "flash_attention_2"],
)
parser.add_argument(
"--data_type",
dest="data_type",
type=str,
default="auto",
choices=["auto", "float16", "bfloat16", "float32"],
)
# Calibration
parser.add_argument(
"--dataset",
dest="dataset",
type=str,
default="pileval",
help="Calibration dataset name. Default is 'pileval'.",
)
parser.add_argument("--seq_len", dest="seq_len", type=int, default=512)
parser.add_argument("--batch_size", dest="batch_size", type=int, default=1)
parser.add_argument("--num_calib_data", dest="num_calib_data", type=int, default=128)
# Quantization
parser.add_argument(
"--preset",
dest="preset",
type=str,
choices=sorted(PRESETS.keys()),
default="mxfp4_moe_only_no_kvcache",
help="Convenience preset for quantization settings.",
)
parser.add_argument(
"--quant_algo",
dest="quant_algo",
type=str,
default=None,
help="Optional quantization algorithm(s) to apply.",
)
parser.add_argument(
"--exclude_layers",
type=str,
nargs="*",
default=None,
help="Layer wildcard patterns to exclude from quantization.",
)
# Export
parser.add_argument("--pack_method", dest="pack_method", type=str, default="reorder", choices=["order", "reorder"])
parser.add_argument(
"--export_weight_format",
dest="export_weight_format",
type=str,
default="real_quantized",
choices=["fake_quantized", "real_quantized"],
)
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--trust_remote_code",
action="store_true",
dest="trust_remote_code",
help="Enable execution of custom model code from the Hub (use only with repositories you fully trust).",
)
group.add_argument(
"--no_trust_remote_code",
action="store_false",
dest="trust_remote_code",
help="Disable execution of custom model code from the Hub (safer, recommended if unsure).",
)
parser.set_defaults(trust_remote_code=True)
main(parser.parse_args())