| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import re |
| import shutil |
| from pathlib import Path |
| from types import MethodType |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors |
| from quark.torch.utils.llm import ( |
| get_calib_dataloader, |
| get_model, |
| get_tokenizer, |
| ) |
| from quark.common.utils.log import ScreenLogger |
|
|
| try: |
| |
| from accelerate.hooks import AlignDevicesHook, add_hook_to_module |
| from accelerate.utils import PrefixedDataset |
|
|
| _ACCELERATE_AVAILABLE = True |
| except Exception: |
| AlignDevicesHook = None |
| add_hook_to_module = None |
| PrefixedDataset = None |
| _ACCELERATE_AVAILABLE = False |
|
|
|
|
# Default HF hub id (input) and local output directory; both overridable via CLI.
DEFAULT_INPUT_MODEL_PATH = "stepfun-ai/Step-3.5-Flash"
DEFAULT_OUTPUT_MODEL_PATH = "quantized_models/Step-3.5-Flash-MXFP4"


# Module-level logger (Quark's console logger wrapper).
logger = ScreenLogger(__name__)
|
|
| def _step35_template_exclude_layers() -> list[str]: |
| return [ |
| |
| "model.embed_tokens*", |
| "*embed_tokens*", |
| "*lm_head*", |
| "*layernorm*", |
| "*norm*", |
| |
| "*moe.gate", |
| "*moe.router_bias*", |
| |
| "model.layers.0.mlp.*", |
| "model.layers.1.mlp.*", |
| "model.layers.2.mlp.*", |
| |
| "*share_expert*", |
| "*self_attn*", |
| ] |
|
|
# Named quantization presets selectable via --preset. Each entry supplies a
# quant scheme plus the fixed Step-3.5 exclude list; an optional "quant_algo"
# key may provide a default algorithm (applied in main() only when the user
# did not pass --quant_algo explicitly).
PRESETS: dict[str, dict[str, Any]] = {
    # MXFP4 weights on MoE expert layers only; attention and KV-cache stay
    # unquantized (kv_cache_scheme/attention_scheme are None in main()).
    "mxfp4_moe_only_no_kvcache": {
        "quant_scheme": "mxfp4",
        "exclude_layers": _step35_template_exclude_layers(),
    },
}
|
|
|
|
| def _copy_non_weight_files(src_dir: str, dst_dir: str) -> None: |
| """ |
| Copy non-weight files from an HF model directory (json/jinja/tokenizer, etc.), |
| while skipping *.safetensors and model.safetensors.index.json. |
| |
| Note: `export_safetensors` exports the essential HF weights and config, but the |
| original model directory may contain extra assets (e.g. chat_template.jinja). |
| We do a conservative copy here so offline inference keeps those auxiliary files. |
| """ |
| src = Path(src_dir) |
| dst = Path(dst_dir) |
| dst.mkdir(parents=True, exist_ok=True) |
|
|
| for p in src.iterdir(): |
| if p.is_dir(): |
| continue |
| name = p.name |
| if name.endswith(".safetensors"): |
| continue |
| if name == "model.safetensors.index.json": |
| continue |
| |
| |
| shutil.copy2(p, dst / name) |
|
|
|
|
def _register_step35_flash_template() -> None:
    """
    Register a Quark LLMTemplate for Step-3.5-Flash (config.model_type = step3p5).

    Idempotent: does nothing when the template is already registered.
    """
    model_type = "step3p5"
    if model_type not in LLMTemplate.list_available():
        template = LLMTemplate(
            model_type=model_type,
            kv_layers_name=["*k_proj", "*v_proj"],
            q_layer_name="*q_proj",
            exclude_layers_name=_step35_template_exclude_layers(),
        )
        LLMTemplate.register_template(template)
        logger.info("Registered LLMTemplate: %s", model_type)
|
|
|
|
@torch.no_grad()
def replace_step35_moelinear_with_linear(moe_module: Any) -> None:
    """
    Convert Step3p5MoEMLP's MoELinear modules into separate Linear layers per expert.

    The fused module stores each projection as one stacked tensor per projection
    (expected layout documented in _step35_sync_weights_to_linear); Quark
    quantizes plain nn.Linear layers, so each expert gets its own
    gate_proj/up_proj/down_proj Linear. Idempotent via the `_step35_replaced` flag.
    """
    if getattr(moe_module, "_step35_replaced", False):
        return

    logger.debug("Converting Step3p5MoEMLP experts to separate gate/up/down Linear layers...")

    # Fallback constants look like Step-3.5-Flash defaults — TODO confirm against model config.
    num_experts: int = int(getattr(moe_module, "moe_num_experts", 288))
    hidden_size: int = int(getattr(moe_module, "hidden_size", 4096))
    moe_intermediate_size: int = int(getattr(moe_module, "moe_intermediate_size", 1280))

    # Mirror the fused projection's placement/dtype for the new Linear layers.
    original_device = moe_module.gate_proj.weight.device
    original_dtype = moe_module.gate_proj.weight.dtype

    # These attributes are read later by _step35_moe_forward.
    moe_module.hidden_size = hidden_size
    moe_module.expert_dim = moe_intermediate_size
    moe_module.num_experts = num_experts

    # 'meta' device indicates accelerate offload: parameters are placeholders only.
    is_meta: bool = original_device == torch.device("meta")
    target_device_for_new = original_device if not is_meta else torch.device("meta")

    # Register one child module per expert, addressable as getattr(moe_module, "0"), "1", ...
    for expert_index in range(num_experts):
        expert_module = nn.Module()
        expert_module.gate_proj = nn.Linear(
            hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype
        )
        expert_module.up_proj = nn.Linear(
            hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype
        )
        expert_module.down_proj = nn.Linear(
            moe_intermediate_size, hidden_size, bias=False, device=target_device_for_new, dtype=original_dtype
        )
        setattr(moe_module, str(expert_index), expert_module)

    # Copy fused weights into the per-expert Linears (False when still on 'meta';
    # _step35_moe_forward retries the sync lazily in that case).
    weights_synced = _step35_sync_weights_to_linear(moe_module)

    # Swap in a forward that routes tokens through the per-expert Linear layers.
    moe_module.forward = MethodType(_step35_moe_forward, moe_module)

    if weights_synced:
        # Only drop the fused tensors once the copies are materialized.
        _step35_cleanup_fused(moe_module)

    moe_module._step35_replaced = True
|
|
|
|
@torch.no_grad()
def _step35_sync_weights_to_linear(module: Any) -> bool:
    """
    Split MoELinear weights and copy into per-expert Linear layers.
    Returns True if synced; returns False if fused weights are still on 'meta' (not materialized).
    MoELinear tensors in Step3p5MoEMLP are expected to be:
        - gate_proj.weight: [num_experts, moe_intermediate_size, hidden_size]
        - up_proj.weight: [num_experts, moe_intermediate_size, hidden_size]
        - down_proj.weight: [num_experts, hidden_size, moe_intermediate_size]
    """
    # Idempotence guard: set to True at the end of a successful sync.
    if getattr(module, "_weights_synced", False):
        return True

    W_gate = getattr(module, "gate_proj", None)
    W_up = getattr(module, "up_proj", None)
    W_down = getattr(module, "down_proj", None)

    # Fused projections already removed (e.g. by _step35_cleanup_fused).
    if W_gate is None or W_up is None or W_down is None:
        return False

    # 'meta' parameters mean the real weights live in accelerate's offload map.
    is_offload = getattr(W_gate.weight, "is_meta", False) or W_gate.weight.device == torch.device("meta")
    if is_offload:
        if not _ACCELERATE_AVAILABLE:
            raise RuntimeError(
                "Model appears to be loaded with accelerate offload (meta tensors), but accelerate is not available."
            )
        if not hasattr(module, "_hf_hook"):
            return False
        # NOTE(review): weights_map entries are typically plain tensors, yet the
        # loop below reads `W_gate.weight[expert_index]` — verify these entries
        # actually expose a `.weight` attribute; otherwise the offload path
        # always falls into the except branch and returns False.
        W_gate = module._hf_hook.weights_map["gate_proj.weight"]
        W_up = module._hf_hook.weights_map["up_proj.weight"]
        W_down = module._hf_hook.weights_map["down_proj.weight"]

    try:
        for expert_index in range(int(module.num_experts)):
            # Per-expert container created by replace_step35_moelinear_with_linear.
            expert_module = getattr(module, str(expert_index))

            # Slice this expert's 2-D weight out of the stacked tensor.
            W_gate_current = W_gate.weight[expert_index]
            W_up_current = W_up.weight[expert_index]
            W_down_current = W_down.weight[expert_index]

            if is_offload:
                # Register each sliced weight under a per-expert key in the
                # offload dataset, then attach an AlignDevicesHook that mirrors
                # the parent hook so accelerate pages the weight in on use.
                hook = module._hf_hook
                dataset = hook.weights_map.dataset
                layer_value = [W_gate_current, W_up_current, W_down_current]
                for idx, layer_name in enumerate(["gate_proj", "up_proj", "down_proj"]):
                    prefix = f"{hook.weights_map.prefix}{expert_index}.{layer_name}."
                    prefixed_weights_map = PrefixedDataset(dataset, prefix)
                    full_name = f"{prefix}weight"
                    dataset.all_keys.append(full_name)
                    dataset.state_dict[full_name] = layer_value[idx]

                    quark_hook = AlignDevicesHook(
                        execution_device=hook.execution_device,
                        offload=hook.offload,
                        io_same_device=hook.io_same_device,
                        weights_map=prefixed_weights_map,
                        offload_buffers=hook.offload_buffers,
                        place_submodules=hook.place_submodules,
                        skip_keys=hook.skip_keys,
                        tied_params_map=hook.tied_params_map,
                    )
                    linear_module = getattr(expert_module, layer_name)
                    add_hook_to_module(linear_module, quark_hook)
            else:
                # Materialized weights: copy slices straight into the new Linears.
                expert_module.gate_proj.weight.data.copy_(W_gate_current.to(W_gate.weight.device))
                expert_module.up_proj.weight.data.copy_(W_up_current.to(W_up.weight.device))
                expert_module.down_proj.weight.data.copy_(W_down_current.to(W_down.weight.device))

        # Remove the fused entries from the offload map so only per-expert
        # keys remain (the fused tensors would otherwise be paged in twice).
        if is_offload:
            prefix = module._hf_hook.weights_map.prefix
            del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}gate_proj.weight"]
            del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}up_proj.weight"]
            del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}down_proj.weight"]
            module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}gate_proj.weight")
            module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}up_proj.weight")
            module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}down_proj.weight")

        module._weights_synced = True
        return True
    except Exception as e:
        # Best-effort by design: callers treat False as "not synced yet".
        logger.warning("Failed to sync Step3.5 MoE weights: %s", e)
        return False
|
|
|
|
|
|
@torch.no_grad()
def _step35_cleanup_fused(module: Any) -> None:
    """Remove fused MoELinear modules after the per-expert Linears hold the weights.

    Frees the now-duplicated fused gate/up/down tensors to bound peak memory.
    Safe to call repeatedly: already-removed projections are skipped.
    """
    for proj_name in ("gate_proj", "up_proj", "down_proj"):
        if hasattr(module, proj_name):
            delattr(module, proj_name)
    # Return freed blocks to the CUDA caching allocator; no-op without CUDA.
    torch.cuda.empty_cache()
    # Fix: plain string instead of a placeholder-free f-string (ruff F541/G004).
    logger.debug("Cleaned up original MoELinear modules")
|
|
|
|
def _step35_moe_forward(self: Any, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Forward using per-expert gate_proj, up_proj, down_proj (nn.Linear),
    matching the original Step3p5MoEMLP.forward semantics but without MoELinear.

    Expects hidden_states of shape (batch, seq, hidden) and returns the same shape.
    """
    # Lazy sync: weights may not have been materialized at patch time.
    synced = _step35_sync_weights_to_linear(self)
    if not synced:
        raise RuntimeError(
            "Step3p5MoEMLP weights are on 'meta' (not materialized). "
            "Move fused parameters to a real device first, then call forward."
        )

    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # Flatten to (tokens, hidden) for token-level expert dispatch.
    hidden_states = hidden_states.view(-1, hidden_dim)

    # Router logits; optionally computed in fp32 for numerical stability.
    if self.need_fp32_gate:
        router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
    else:
        router_logits = self.gate(hidden_states)

    # Top-k expert selection, via a model-provided routing function when present.
    if hasattr(self, 'custom_routing_function') and self.custom_routing_function:
        routing_weights, selected_experts = self.custom_routing_function(
            router_logits, self.top_k, renormalize=True)
    else:
        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

    # Scale routing weights (DeepSeek-style routed scaling factor).
    routing_weights = routing_weights * self.routed_scaling_factor

    # Accumulator for the weighted expert outputs, one row per token.
    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device)

    # expert_mask[e, k, t] == 1 iff token t selected expert e at top-k slot k.
    expert_mask = torch.nn.functional.one_hot(
        selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    # Optional activation clamp bound — presumably a GPT-OSS-style limit; None disables it.
    limit = getattr(self, 'limit', None)

    # Dense loop over all experts (no skip for experts with zero assigned tokens).
    for expert_idx in range(self.num_experts):
        # idx: top-k slot per hit; top_x: token indices routed to this expert.
        idx, top_x = torch.where(expert_mask[expert_idx])

        # Gather this expert's tokens: (num_hits, hidden).
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

        expert_module = getattr(self, str(expert_idx))

        # SwiGLU-style expert MLP: down(act(gate(x)) * up(x)).
        up = expert_module.up_proj(current_state)
        gate = self.act_fn(expert_module.gate_proj(current_state))

        if limit is not None:
            # Gate is clamped only from above, up symmetrically — keep as-is.
            gate = gate.clamp(min=None, max=limit)
            up = up.clamp(min=-limit, max=limit)

        current_hidden_states = expert_module.down_proj(gate * up) * routing_weights[top_x, idx, None]

        # Scatter-add each token's weighted expert output back into place.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states
|
|
@torch.no_grad()
def patch_step35_moe(model: nn.Module) -> int:
    """
    Apply Step-3.5-Flash MoE replacement to all Step3p5MoEMLP modules in the model.

    Returns the number of modules converted (0 when none matched).
    """
    target_class_name = "Step3p5MoEMLP"
    patched_count = 0
    for name, submodule in model.named_modules(remove_duplicate=False):
        # Match by class name to avoid importing the remote-code model class.
        if type(submodule).__name__ != target_class_name:
            continue
        replace_step35_moelinear_with_linear(submodule)
        patched_count += 1
        logger.debug(f"Patched MoE module: {name}")

    if patched_count > 0:
        logger.info("Patched %d Step3p5MoEMLP module(s) for quantization.", patched_count)
    return patched_count
|
|
|
|
| def _resolve_calib_device(device: str, model: nn.Module) -> str: |
| """ |
| Resolve a torch-compatible device string for calibration inputs. |
| """ |
| if device != "auto": |
| return str(device) |
|
|
| hf_map = getattr(model, "hf_device_map", None) |
| if isinstance(hf_map, dict): |
| cuda_ids: list[int] = [] |
| for v in hf_map.values(): |
| m = re.match(r"^cuda:(\d+)$", str(v)) |
| if m: |
| cuda_ids.append(int(m.group(1))) |
| if cuda_ids: |
| return f"cuda:{min(cuda_ids)}" |
|
|
| if torch.cuda.is_available(): |
| return "cuda:0" |
| return "cpu" |
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """
    End-to-end pipeline: load model -> patch MoE -> build calibration data ->
    quantize with Quark -> export HF safetensors plus auxiliary assets.
    """
    os.makedirs(args.output_quantized_hf_path, exist_ok=True)

    _register_step35_flash_template()

    # A preset fixes the scheme; an explicit --quant_algo still wins over the preset's.
    if getattr(args, "preset", None):
        preset_cfg = PRESETS[args.preset]
        args.quant_scheme = preset_cfg["quant_scheme"]
        if getattr(args, "quant_algo", None) is None and "quant_algo" in preset_cfg:
            args.quant_algo = preset_cfg["quant_algo"]
        logger.info("Using preset: %s", args.preset)

    logger.info("Input model: %s", args.model_dir)
    logger.info("Output dir: %s", args.output_quantized_hf_path)

    logger.info("Step 1/4: Loading model and tokenizer ...")
    model, _ = get_model(
        args.model_dir,
        data_type=args.data_type,
        device=args.device,
        multi_gpu=args.multi_gpu,
        multi_device=args.multi_device,
        attn_implementation=args.model_attn_implementation,
        trust_remote_code=args.trust_remote_code,
    )
    # Replace fused MoELinear layers with per-expert nn.Linear before quantization.
    patch_step35_moe(model)
    # Fall back to the first architecture name when the config lacks model_type.
    model_type = model.config.model_type if hasattr(model.config, "model_type") else model.config.architectures[0]
    tokenizer = get_tokenizer(
        args.model_dir, max_seq_len=args.seq_len, model_type=model_type, trust_remote_code=args.trust_remote_code
    )

    logger.info("Step 2/4: Building calibration dataloader ...")
    # Under accelerate sharding, calibration tensors go to the model's primary device.
    base_device = str(model.device) if (args.multi_gpu or args.multi_device) else str(args.device)
    main_device = _resolve_calib_device(base_device, model)
    logger.info("Calibration dataset: %s", args.dataset)
    calib_dataloader = get_calib_dataloader(
        dataset_name=args.dataset,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_calib_data=args.num_calib_data,
        seqlen=args.seq_len,
        device=main_device,
    )

    logger.info("Step 3/4: Quantizing ...")
    template = LLMTemplate.get(model_type)
    # User excludes are deliberately ignored: the fixed Step-3.5 template
    # excludes are required for a working MoE-only quantization.
    if args.exclude_layers is not None:
        logger.warning(
            "Ignoring --exclude_layers (%s). This script always uses "
            "_register_step35_flash_template excludes for Step-3.5-Flash.",
            args.exclude_layers,
        )
    exclude_layers = _step35_template_exclude_layers()
    logger.info("Exclude layers (template): %s", exclude_layers)
    if getattr(args, "quant_algo", None):
        logger.info("Quantization algorithm(s): %s", args.quant_algo)

    # KV-cache and attention schemes are disabled: expert weights only.
    quant_config = template.get_config(
        scheme=args.quant_scheme,
        algorithm=args.quant_algo,
        kv_cache_scheme=None,
        min_kv_scale=0.0,
        layer_config={},
        attention_scheme=None,
        exclude_layers=exclude_layers,
        algo_configs=None,
    )

    quantizer = ModelQuantizer(quant_config, args.multi_device)
    model = quantizer.quantize_model(model, calib_dataloader)
    # freeze() finalizes quantization parameters before export — see Quark docs.
    model = quantizer.freeze(model)

    logger.info("Step 4/4: Exporting HF safetensors ...")
    # Copy auxiliary assets first; export_safetensors writes the weights/config.
    _copy_non_weight_files(args.model_dir, args.output_quantized_hf_path)
    with torch.no_grad():
        export_safetensors(
            model=model,
            output_dir=args.output_quantized_hf_path,
            custom_mode="quark",
            weight_format=args.export_weight_format,
            pack_method=args.pack_method,
        )
    tokenizer.save_pretrained(args.output_quantized_hf_path)

    logger.info("Export completed.")
    logger.info("========== Quantization Completed Successfully ==========")
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Offline quantization for Step-3.5-Flash with MoE layer replacement"
    )
    # Model input/output locations.
    parser.add_argument("--model_dir", dest="model_dir", type=str, default=DEFAULT_INPUT_MODEL_PATH)
    parser.add_argument("--output_dir", dest="output_quantized_hf_path", type=str, default=DEFAULT_OUTPUT_MODEL_PATH)

    # Device placement and model-loading options.
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
    parser.add_argument("--multi_gpu", dest="multi_gpu", action="store_true")
    parser.add_argument("--multi_device", dest="multi_device", action="store_true")
    parser.add_argument(
        "--model_attn_implementation",
        dest="model_attn_implementation",
        type=str,
        default="eager",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    parser.add_argument(
        "--data_type",
        dest="data_type",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
    )

    # Calibration options.
    parser.add_argument(
        "--dataset",
        dest="dataset",
        type=str,
        default="pileval",
        help="Calibration dataset name. Default is 'pileval'.",
    )
    parser.add_argument("--seq_len", dest="seq_len", type=int, default=512)
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=1)
    parser.add_argument("--num_calib_data", dest="num_calib_data", type=int, default=128)

    # Quantization options (a preset always applies by default; see main()).
    parser.add_argument(
        "--preset",
        dest="preset",
        type=str,
        choices=sorted(PRESETS.keys()),
        default="mxfp4_moe_only_no_kvcache",
        help="Convenience preset for quantization settings.",
    )
    parser.add_argument(
        "--quant_algo",
        dest="quant_algo",
        type=str,
        default=None,
        help="Optional quantization algorithm(s) to apply.",
    )
    # NOTE: main() warns and ignores this flag — template excludes always win.
    parser.add_argument(
        "--exclude_layers",
        type=str,
        nargs="*",
        default=None,
        help="Layer wildcard patterns to exclude from quantization.",
    )

    # Export options.
    parser.add_argument("--pack_method", dest="pack_method", type=str, default="reorder", choices=["order", "reorder"])
    parser.add_argument(
        "--export_weight_format",
        dest="export_weight_format",
        type=str,
        default="real_quantized",
        choices=["fake_quantized", "real_quantized"],
    )
    # Trust-remote-code toggle (defaults to True; Step-3.5 ships custom model code).
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--trust_remote_code",
        action="store_true",
        dest="trust_remote_code",
        help="Enable execution of custom model code from the Hub (use only with repositories you fully trust).",
    )
    group.add_argument(
        "--no_trust_remote_code",
        action="store_false",
        dest="trust_remote_code",
        help="Disable execution of custom model code from the Hub (safer, recommended if unsure).",
    )
    parser.set_defaults(trust_remote_code=True)

    main(parser.parse_args())
|
|
|
|