"""Convert MTP expert weights from FP8 to INT4 compressed-tensors (Marlin format).

Key fix: pack_factor=4 (4 INT4 values per INT32), matching K2.5 base model format.

Pipeline: every FP8 projection ("<base>.weight" accompanied by
"<base>.weight_scale") is dequantized to bfloat16. Projections whose name
contains ".mlp.experts." are then re-quantized to group-wise symmetric INT4
and packed into INT32 words (weight_packed / weight_scale / weight_shape);
all other FP8 projections are stored as plain bfloat16 weights.

NOTE(review): compressed-tensors' pack-quantized format normally packs
32 // num_bits = 8 INT4 values per INT32 word; packing only 4 leaves the
upper 16 bits of every word zero. Confirm against the base checkpoint's
``weight_packed`` shapes before trusting PACK_FACTOR = 4.
"""

from collections import OrderedDict

import torch
from safetensors import safe_open
from safetensors.torch import save_file

MTP_PATH = "/data/models/Kimi-K2.5-MTP/mtp_fp8_orig.safetensors"
OUTPUT_PATH = "/data/models/Kimi-K2.5-MTP/mtp.safetensors"
GROUP_SIZE = 32  # INT4 quantization group size (input columns per scale)
FP8_BLOCK_SIZE = 8  # input columns sharing one source block scale
PACK_FACTOR = 4  # 4 INT4 values per INT32 (matching base model Marlin format)


def dequantize_fp8_block(weight_u8, weight_scale_fp8, weight_scale_2, block_size=FP8_BLOCK_SIZE):
    """Dequantize a block-scaled 8-bit weight matrix to bfloat16.

    Each byte is decoded as the integer ``byte - 128``. It is then multiplied
    by the scale of its block (one scale per ``block_size`` consecutive input
    columns of a row) and by the single global ``weight_scale_2``.

    NOTE(review): this is an offset-binary integer decode, not an IEEE
    float8 decode. If the source checkpoint actually stores float8_e4m3fn
    values, or two packed FP4 values per byte (a ModelOpt-style
    weight / weight_scale / weight_scale_2 / input_scale layout), this
    produces wrong weights -- confirm against the exporter.

    Args:
        weight_u8: 2-D ``(out_f, in_f)`` tensor of raw bytes.
        weight_scale_fp8: block scales, broadcastable against
            ``(out_f, in_f // block_size)``.
        weight_scale_2: single-element global scale, or ``None`` for 1.0.
        block_size: input columns covered by one block scale.

    Returns:
        bfloat16 tensor of shape ``(out_f, in_f)``.

    Raises:
        ValueError: if ``weight_u8`` is not 2-D, ``in_f`` is not a multiple
            of ``block_size``, or ``weight_scale_2`` has more than one element.
    """
    if weight_u8.dim() != 2:
        raise ValueError(f"expected a 2-D weight, got shape {tuple(weight_u8.shape)}")
    out_f, in_f = weight_u8.shape
    if in_f % block_size:
        raise ValueError(f"in_f={in_f} is not a multiple of block_size={block_size}")

    if weight_scale_2 is None:
        global_scale = 1.0
    elif weight_scale_2.numel() == 1:
        global_scale = weight_scale_2.item()
    else:
        # Falling back to 1.0 here would silently drop a real scale factor.
        raise ValueError(
            "weight_scale_2 must have exactly one element, "
            f"got shape {tuple(weight_scale_2.shape)}"
        )

    w = weight_u8.to(torch.float32).reshape(out_f, in_f // block_size, block_size) - 128.0
    # Trailing singleton axis broadcasts each block scale over its block.
    block_scales = weight_scale_fp8.to(torch.float32).unsqueeze(-1)
    return (w * block_scales * global_scale).reshape(out_f, in_f).to(torch.bfloat16)


def quantize_int4_marlin(weight_bf16, group_size=32, pack_factor=PACK_FACTOR):
    """Quantize a 2-D weight to symmetric group-wise INT4 and pack into INT32.

    Each row is split into groups of ``group_size`` input columns. The row is
    zero-padded at the end when ``in_f`` is not a multiple of ``group_size``.
    A group's scale is ``max|w| / 7``, and values are rounded to integers in
    [-8, 7]. They are stored offset by +8 as unsigned nibbles, and
    ``pack_factor`` consecutive nibbles are packed into one INT32 word,
    lowest nibble first.

    Args:
        weight_bf16: 2-D ``(out_f, in_f)`` floating-point tensor.
        group_size: input columns per quantization scale.
        pack_factor: INT4 values per INT32 word (1..8); defaults to PACK_FACTOR.

    Returns:
        ``(packed, scales, shape)``:

        - ``packed`` is int32 of shape ``(out_f, in_padded // pack_factor)``.
        - ``scales`` is bfloat16 of shape ``(out_f, in_padded // group_size)``.
        - ``shape`` is an int32 tensor ``[out_f, in_f]`` holding the unpadded
          shape.

    Raises:
        ValueError: for a non-positive ``group_size``, a ``pack_factor``
            outside 1..8, or a padded width not divisible by ``pack_factor``.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if not 1 <= pack_factor <= 8:
        raise ValueError(f"pack_factor must be in [1, 8], got {pack_factor}")

    out_f, in_f = weight_bf16.shape
    w = weight_bf16.to(torch.float32)

    pad = -in_f % group_size  # zero columns needed to complete the last group
    if pad:
        w = torch.nn.functional.pad(w, (0, pad))
    in_padded = w.shape[1]
    if in_padded % pack_factor:
        raise ValueError(f"in_padded={in_padded} not divisible by {pack_factor}")

    w_grouped = w.reshape(out_f, -1, group_size)
    # Round the scales to their bfloat16 storage dtype *before* quantizing so
    # the integers are computed against exactly the scales that get stored.
    scales = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-10).to(torch.bfloat16)
    w_int = torch.round(w_grouped / scales.to(torch.float32).unsqueeze(-1)).clamp(-8, 7)

    # Offset to unsigned nibbles in [0, 15], then pack pack_factor per word.
    nibbles = (w_int.to(torch.int32) + 8).reshape(out_f, -1, pack_factor)
    packed = torch.zeros(out_f, nibbles.shape[1], dtype=torch.int32, device=nibbles.device)
    for i in range(pack_factor):
        packed |= (nibbles[:, :, i] & 0xF) << (i * 4)

    shape = torch.tensor([out_f, in_f], dtype=torch.int32)
    return packed, scales, shape


print("Loading MTP FP8 weights...")
new_tensors = OrderedDict()  # output tensors, in insertion order
converted_expert = 0  # expert projections re-quantized to INT4
converted_shared = 0  # non-expert FP8 projections stored as bfloat16
passed = 0  # tensors copied through unchanged
dropped = 0  # FP8 companion tensors consumed by dequantization

import os  # used for the final output-size report

# Tensors that accompany an FP8 "<base>.weight". The scales are consumed by
# dequantization, and the input scale has no meaning once the weight is no
# longer FP8, so none of them are copied to the output.
FP8_COMPANION_ATTRS = ("weight_scale", "weight_scale_2", "input_scale")

with safe_open(MTP_PATH, framework="pt", device="cpu") as f:
    all_keys = sorted(f.keys())
    # Set for O(1) membership; testing against the list made detection O(n^2).
    key_set = set(all_keys)

    # "<base>" is an FP8 projection when both "<base>.weight" and
    # "<base>.weight_scale" are present.
    fp8_bases = set()
    for k in all_keys:
        if k.endswith(".weight"):
            candidate = k[: -len(".weight")]
            if f"{candidate}.weight_scale" in key_set:
                fp8_bases.add(candidate)
    print(f"FP8 projections: {len(fp8_bases)}")

    for k in all_keys:
        # FP8 projection tensors are named "<base>.<attr>"; a direct set lookup
        # replaces the original linear scan over every base per key.
        base, _, attr = k.rpartition(".")
        is_fp8 = base in fp8_bases

        if is_fp8 and attr in FP8_COMPANION_ATTRS:
            # Skip regardless of sort position: ".input_scale" sorts *before*
            # ".weight", so suppressing companions only after seeing the
            # weight would let it leak through.
            dropped += 1
            continue

        if is_fp8 and attr == "weight":
            w_u8 = f.get_tensor(k)
            w_scale = f.get_tensor(f"{base}.weight_scale")
            w_scale2 = f.get_tensor(f"{base}.weight_scale_2")
            w_bf16 = dequantize_fp8_block(w_u8, w_scale, w_scale2)
            if ".mlp.experts." in base:
                packed, scales, shape = quantize_int4_marlin(w_bf16, GROUP_SIZE)
                new_tensors[f"{base}.weight_packed"] = packed
                new_tensors[f"{base}.weight_scale"] = scales
                new_tensors[f"{base}.weight_shape"] = shape
                converted_expert += 1
            else:
                new_tensors[f"{base}.weight"] = w_bf16
                converted_shared += 1
            continue

        new_tensors[k] = f.get_tensor(k)
        passed += 1

print(f"Expert→INT4: {converted_expert}, Shared→BF16: {converted_shared}, Passthrough: {passed}")
print(f"Dropped FP8 scale tensors: {dropped}")
print(f"Total: {len(new_tensors)}")

# Verify pack format matches base
sample = "model.layers.61.mlp.experts.0.gate_proj.weight_packed"
expected_sample_shape = [2048, 896]  # per the original note: 3584 / 4 = 896
if sample in new_tensors:
    actual_shape = list(new_tensors[sample].shape)
    print(f"\nVerify: {sample} shape={actual_shape}")
    print("Expected: [2048, 896] (3584/4=896)")
    if actual_shape != expected_sample_shape:
        print("WARNING: packed shape differs from the expected base-model shape")
else:
    print(f"\nVerify: {sample} not present in output; shape check skipped")

save_file(new_tensors, OUTPUT_PATH)
print(f"Saved: {os.path.getsize(OUTPUT_PATH)/1024/1024:.1f} MB")