# Kimi-K2.5-vision / extract_vision_weights.py
# Commit 3fe5217 (davehind): "Add extraction script and update README with reproduction steps"
#!/usr/bin/env python3
"""
Extract vision tower + projector weights from moonshotai/Kimi-K2.5.
Downloads only the 2 shards that contain vision weights (~900MB total),
extracts the relevant keys, and saves as a single compact file.
Usage:
python extract_vision_weights.py [--output-dir ./weights]
Requirements:
pip install safetensors huggingface_hub
"""
import argparse
import json
import os
from pathlib import Path

from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import save_file
# HuggingFace repo that hosts the full multimodal checkpoint.
REPO_ID = "moonshotai/Kimi-K2.5"
# Shards expected to contain all vision weights; this hard-coded list is
# cross-checked against the safetensors index at runtime in main().
VISION_SHARDS = [
    "model-00063-of-000064.safetensors",  # mm_projector keys (~103MB)
    "model-00064-of-000064.safetensors",  # vision_tower keys (~795MB)
]
# We also need the safetensors index to verify which keys are in which shard
INDEX_FILE = "model.safetensors.index.json"
def is_vision_key(key: str) -> bool:
    """Return True when *key* names a vision-tower or projector weight."""
    return any(marker in key for marker in ("vision_tower", "mm_projector"))
def main():
    """Download the vision shards from the hub and extract the vision weights.

    Fetches the safetensors index first to confirm which shards actually hold
    the vision_tower / mm_projector keys, downloads only those shards, and
    re-saves the matching tensors as a single compact safetensors file.

    Raises:
        RuntimeError: if the index maps vision keys to shards other than
            VISION_SHARDS (i.e. the upstream model was resharded).
    """
    parser = argparse.ArgumentParser(description="Extract Kimi-K2.5 vision weights")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent.parent / "weights"),
        help="Directory to save extracted weights",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="HuggingFace cache directory for downloaded shards",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Just show what would be downloaded, don't actually download",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "kimi_k25_vision.safetensors"

    # Idempotence guard: never clobber a previous extraction.
    if output_path.exists():
        print(f"Output already exists: {output_path}")
        print(f" Size: {output_path.stat().st_size / (1024**2):.1f} MB")
        print("Delete it to re-extract.")
        return

    # Step 1: Download the index to verify our shard list
    print(f"Downloading safetensors index from {REPO_ID}...")
    index_path = hf_hub_download(
        REPO_ID, INDEX_FILE, cache_dir=args.cache_dir
    )
    with open(index_path) as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    vision_keys = {k: v for k, v in weight_map.items() if is_vision_key(k)}
    needed_shards = sorted(set(vision_keys.values()))

    print(f"Found {len(vision_keys)} vision/projector keys across {len(needed_shards)} shards:")
    for shard in needed_shards:
        n = sum(1 for v in vision_keys.values() if v == shard)
        print(f" {shard}: {n} keys")

    # An explicit raise, not `assert`: asserts are stripped under `python -O`,
    # which would silently disable this resharding safety check.
    if set(needed_shards) != set(VISION_SHARDS):
        raise RuntimeError(
            f"Expected shards {VISION_SHARDS}, got {needed_shards}. "
            "The model may have been resharded — update VISION_SHARDS."
        )

    if args.dry_run:
        print("\n[DRY RUN] Would download:")
        for shard in needed_shards:
            print(f" {REPO_ID}/{shard}")
        print(f"\nWould extract {len(vision_keys)} keys to {output_path}")
        print("\nSample keys:")
        sample = sorted(vision_keys.keys())
        for k in sample[:10]:
            print(f" {k}")
        # Only claim "more" when the sample was actually truncated
        # (the old code printed a negative count for <=10 keys).
        if len(sample) > 10:
            print(f" ... and {len(sample)-10} more")
        return

    # Step 2: Download shards and extract vision weights
    all_tensors = {}
    for shard_name in needed_shards:
        print(f"\nDownloading {shard_name}...")
        shard_path = hf_hub_download(
            REPO_ID, shard_name, cache_dir=args.cache_dir
        )
        print(" Extracting vision keys...")
        # safe_open memory-maps the shard, so only the requested tensors
        # are materialized on CPU.
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if is_vision_key(key):
                    all_tensors[key] = f.get_tensor(key)

    print(f"\nExtracted {len(all_tensors)} tensors")

    # Print summary of what we got, split by component.
    vt_keys = [k for k in all_tensors if "vision_tower" in k]
    proj_keys = [k for k in all_tensors if "mm_projector" in k]
    print(f" vision_tower: {len(vt_keys)} keys")
    print(f" mm_projector: {len(proj_keys)} keys")

    # Projector shapes are few enough to print in full; they are the usual
    # place dimension mismatches surface when re-wiring the weights.
    print("\nProjector weight shapes:")
    for k in sorted(proj_keys):
        print(f" {k}: {list(all_tensors[k].shape)}")

    # Print a few vision tower weight shapes as a spot check.
    print("\nSample vision tower shapes:")
    for k in sorted(vt_keys)[:5]:
        print(f" {k}: {list(all_tensors[k].shape)}")

    # Step 3: Save as single file
    print(f"\nSaving to {output_path}...")
    save_file(all_tensors, str(output_path))
    size_mb = output_path.stat().st_size / (1024**2)
    print(f"Done! Saved {size_mb:.1f} MB ({len(all_tensors)} tensors)")
# Entry-point guard: allows importing this module (e.g. for is_vision_key)
# without triggering the multi-hundred-MB downloads.
if __name__ == "__main__":
    main()