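"""
Convert MTP expert weights from NVFP4 packed format to INT4 compressed-tensors.

Input layout (NVFP4: FP4 E2M1 values, 2 packed per U8 byte), per projection:
    weight:         [out, in/2] U8 (2 FP4 values per byte; one block of 16 FP4 values = 8 bytes)
    weight_scale:   [out, in/16] F8E4M3 (one scale per block of 16 FP4 values)
    weight_scale_2: scalar F32 (global scale)
    input_scale:    scalar F32 (activation scale; not used when dequantizing weights)

Output:
    FP4 projections under ".mlp.experts." -> INT4 weight_packed (int32, 8 values per word),
                                             weight_scale (BF16, one per GROUP_SIZE inputs),
                                             weight_shape
    other FP4 projections                 -> dequantized BF16 weight
    all remaining tensors                 -> copied unchanged
"""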
import os
from collections import OrderedDict

import numpy as np
import torch
from safetensors import safe_open
from safetensors.torch import save_file
|
|
MTP_PATH = "/data/models/Kimi-K2.5-MTP/mtp_fp8_orig.safetensors"
OUTPUT_PATH = "/data/models/Kimi-K2.5-MTP/mtp.safetensors"
GROUP_SIZE = 32  # input columns that share one INT4 scale
PACK_FACTOR = 8  # 4-bit values packed into each int32 word (32 // 4)

# E2M1 code -> value lookup. Bit 3 is the sign bit, so codes 8..15 are the
# negatives of codes 0..7 (code 8 is negative zero).
_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_TABLE = torch.tensor(
    [*_E2M1_MAGNITUDES, *(-m for m in _E2M1_MAGNITUDES)],
    dtype=torch.float32,
)
|
|
def dequant_fp4_block(weight_u8, weight_scale_fp8e4m3, weight_scale_2_f32):
    """Dequantize an NVFP4-packed weight to BF16.

    Each output value is ``FP4_TABLE[code] * block_scale * global_scale``.

    Args:
        weight_u8: [out, in/2] uint8 tensor. Each byte holds two E2M1 codes:
            the low nibble is input column 2j, the high nibble column 2j+1.
        weight_scale_fp8e4m3: [out, in/16] F8E4M3 tensor, one scale per block
            of 16 consecutive FP4 values along the input dimension.
        weight_scale_2_f32: F32 global scale, either a single element or one
            value per output row.

    Returns:
        [out, in] BF16 tensor on the same device as ``weight_u8``.

    Raises:
        ValueError: if a scale tensor's shape does not match the weight.
    """
    block = 16  # FP4 values that share one F8E4M3 block scale
    out_f, in_packed = weight_u8.shape
    in_fp4 = in_packed * 2
    device = weight_u8.device

    scale = weight_scale_fp8e4m3.to(device=device, dtype=torch.float32)
    if in_fp4 % block or tuple(scale.shape) != (out_f, in_fp4 // block):
        raise ValueError(
            f"weight_scale shape {tuple(scale.shape)} does not match "
            f"[{out_f}, {in_fp4} / {block}] for packed weight {tuple(weight_u8.shape)}"
        )
    n_blocks = in_fp4 // block

    # Never fall back to 1.0: a non-scalar global scale would otherwise be
    # dropped silently and every value would come out mis-scaled.
    n_global = weight_scale_2_f32.numel()
    if n_global == 1:
        global_scale = weight_scale_2_f32.item()
    elif n_global == out_f:
        global_scale = weight_scale_2_f32.to(device=device, dtype=torch.float32).reshape(out_f, 1, 1)
    else:
        raise ValueError(
            f"weight_scale_2 has {n_global} elements; expected 1 or {out_f} (one per output row)"
        )

    codes = weight_u8.to(torch.int64)
    # Interleave (low, high) nibbles so byte j expands to columns 2j, 2j+1.
    unpacked = torch.stack([codes & 0x0F, (codes >> 4) & 0x0F], dim=-1).reshape(out_f, in_fp4)
    # int64 indices on the weight's own device: no host round-trip, and the
    # result lives on the same device as the scales it is multiplied with.
    decoded = FP4_TABLE.to(device)[unpacked]

    # Broadcast each block scale over its 16 values instead of materializing a
    # repeat_interleave'd [out, in] copy of the scales.
    result = decoded.view(out_f, n_blocks, block) * scale.unsqueeze(-1) * global_scale
    return result.reshape(out_f, in_fp4).to(torch.bfloat16)
|
|
def quantize_int4_gptq(weight_bf16, group_size=32):
    """Quantize a 2-D weight to symmetric INT4, packed 8 values per int32.

    The input dimension is zero-padded up to a multiple of ``group_size``. Each
    group of ``group_size`` input columns gets scale = max|w| / 7. Values are
    rounded to [-8, 7] and shifted by +8 into unsigned nibbles. They are then
    packed so that column ``PACK_FACTOR * j + i`` occupies bits ``4*i .. 4*i+3``
    of word ``j``.

    Args:
        weight_bf16: [out, in] floating-point tensor.
        group_size: number of input columns sharing one scale.

    Returns:
        (packed, scales, shape):
            packed: [out, padded_in / PACK_FACTOR] int32
            scales: [out, padded_in / group_size] BF16
            shape:  int32 tensor ``[out, in]`` (the unpadded shape)

    Raises:
        ValueError: if the padded input width is not a multiple of PACK_FACTOR.
    """
    out_f, in_f = weight_bf16.shape
    w = weight_bf16.to(torch.float32)

    pad = -in_f % group_size
    if pad:
        w = torch.nn.functional.pad(w, (0, pad))
    in_padded = w.shape[1]
    if in_padded % PACK_FACTOR:
        raise ValueError(
            f"padded input width {in_padded} (group_size={group_size}) "
            f"is not a multiple of PACK_FACTOR={PACK_FACTOR}"
        )

    w_grouped = w.reshape(out_f, -1, group_size)
    # Round the scales to their stored dtype *before* quantizing. The packed
    # integers are then computed against exactly the scale a loader multiplies
    # by, rather than the unrounded fp32 scale.
    scales = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-10).to(torch.bfloat16)
    w_int = torch.round(w_grouped / scales.to(torch.float32).unsqueeze(-1)).clamp(-8, 7)

    nibbles = (w_int.to(torch.int64) + 8).reshape(out_f, -1, PACK_FACTOR)
    shifts = torch.arange(PACK_FACTOR, dtype=torch.int64, device=nibbles.device) * 4
    # The nibbles occupy disjoint bit ranges, so summing them equals OR-ing them.
    # Working in int64 avoids signed overflow from the top nibble; the 32-bit
    # pattern is then reinterpreted as two's-complement int32 explicitly.
    words = (nibbles << shifts).sum(dim=-1)
    packed = torch.where(words >= 2**31, words - 2**32, words).to(torch.int32)

    shape = torch.tensor([out_f, in_f], dtype=torch.int32)
    return packed, scales, shape
|
|
print("Loading original FP4-packed MTP weights...")
new_tensors = OrderedDict()
converted_expert = 0
converted_shared = 0
passed = 0
skipped = 0

# Leaf names that accompany an NVFP4 "<base>.weight". weight_scale and
# weight_scale_2 are folded into the dequantized weight. input_scale is an
# activation scale that dequant_fp4_block never uses. None of them is copied to
# the output as-is.
FP4_COMPANION_LEAVES = ("weight_scale", "weight_scale_2", "input_scale")

with safe_open(MTP_PATH, framework="pt", device="cpu") as f:
    all_keys = sorted(f.keys())
    key_set = set(all_keys)  # constant-time membership tests below

    # A projection counts as NVFP4-packed when its ".weight" is stored as U8 and
    # it has a ".weight_scale" companion. get_slice() exposes the dtype lazily,
    # so this scan does not load any weight data.
    fp4_bases = set()
    for k in all_keys:
        if not k.endswith(".weight"):
            continue
        base = k[: -len(".weight")]
        if f"{base}.weight_scale" in key_set and f.get_slice(k).get_dtype() == "U8":
            fp4_bases.add(base)

    print(f"FP4-packed projections: {len(fp4_bases)}")

    for k in all_keys:
        base, _, leaf = k.rpartition(".")

        if base in fp4_bases and leaf in FP4_COMPANION_LEAVES:
            # Consumed together with "<base>.weight". Matching by name (not by
            # sort position) means input_scale, which sorts before ".weight",
            # can never slip through as a passthrough tensor.
            skipped += 1
            continue

        if base in fp4_bases and leaf == "weight":
            w_bf16 = dequant_fp4_block(
                f.get_tensor(k),
                f.get_tensor(f"{base}.weight_scale"),
                f.get_tensor(f"{base}.weight_scale_2"),
            )

            if ".mlp.experts." in base:
                # Routed experts: re-quantize into the INT4 packed layout.
                packed, scales, shape = quantize_int4_gptq(w_bf16, GROUP_SIZE)
                new_tensors[f"{base}.weight_packed"] = packed
                new_tensors[f"{base}.weight_scale"] = scales
                new_tensors[f"{base}.weight_shape"] = shape
                converted_expert += 1
                if converted_expert == 1:
                    print(f" Sample: {base}.weight_packed: {list(packed.shape)}, scale: {list(scales.shape)}")
            else:
                # Any other FP4 projection is stored as a plain BF16 weight.
                new_tensors[f"{base}.weight"] = w_bf16
                converted_shared += 1
            continue

        # Everything else (including non-companion leaves of an FP4 projection)
        # is copied unchanged.
        new_tensors[k] = f.get_tensor(k)
        passed += 1


print(f"Expert→INT4: {converted_expert}, Shared→BF16: {converted_shared}, Passthrough: {passed}")
print(f"FP4 companion tensors not copied: {skipped}")
print(f"Total: {len(new_tensors)}")
print("Saving...")
save_file(new_tensors, OUTPUT_PATH)
print(f"Saved: {os.path.getsize(OUTPUT_PATH)/1024/1024:.1f} MB")
|
|