"""
Extract vision tower + projector weights from moonshotai/Kimi-K2.5.

Downloads only the 2 shards that contain vision weights (~900MB total),
extracts the relevant keys, and saves them as a single compact file.

Usage:
    python extract_vision_weights.py [--output-dir ./weights]

Requirements:
    pip install safetensors huggingface_hub
"""
|
|
import argparse
import json
import os
from pathlib import Path

from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import save_file
|
# HuggingFace repo the weights are pulled from.
REPO_ID = "moonshotai/Kimi-K2.5"
# The two final shards of the 64-shard checkpoint that hold the
# vision_tower / mm_projector tensors (~900 MB combined, per module docs).
VISION_SHARDS = [
    "model-00063-of-000064.safetensors",
    "model-00064-of-000064.safetensors",
]

# Repo-level index mapping every weight key to the shard file storing it.
INDEX_FILE = "model.safetensors.index.json"
|
|
|
|
def is_vision_key(key: str) -> bool:
    """Return True when *key* names a vision-tower or projector weight."""
    return any(marker in key for marker in ("vision_tower", "mm_projector"))
|
|
|
|
def main():
    """CLI entry point: download the vision shards, extract the vision
    tower + projector tensors, and save them as one safetensors file.

    Side effects: downloads from the HuggingFace Hub (honors --cache-dir),
    creates --output-dir, and writes kimi_k25_vision.safetensors there.
    Exits early (no download) if the output file already exists or on
    --dry-run.

    Raises:
        RuntimeError: if the repo's shard layout no longer matches
            VISION_SHARDS (i.e. the model was resharded upstream).
    """
    parser = argparse.ArgumentParser(description="Extract Kimi-K2.5 vision weights")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent.parent / "weights"),
        help="Directory to save extracted weights",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="HuggingFace cache directory for downloaded shards",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Just show what would be downloaded, don't actually download",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "kimi_k25_vision.safetensors"

    # Idempotency guard: never clobber a previous extraction.
    if output_path.exists():
        print(f"Output already exists: {output_path}")
        print(f"  Size: {output_path.stat().st_size / (1024**2):.1f} MB")
        print("Delete it to re-extract.")
        return

    # The index is tiny; fetch it first to learn which shards hold vision keys.
    print(f"Downloading safetensors index from {REPO_ID}...")
    index_path = hf_hub_download(
        REPO_ID, INDEX_FILE, cache_dir=args.cache_dir
    )

    with open(index_path) as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    vision_keys = {k: v for k, v in weight_map.items() if is_vision_key(k)}
    needed_shards = sorted(set(vision_keys.values()))

    print(f"Found {len(vision_keys)} vision/projector keys across {len(needed_shards)} shards:")
    for shard in needed_shards:
        n = sum(1 for v in vision_keys.values() if v == shard)
        print(f"  {shard}: {n} keys")

    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # and this is real input validation against upstream resharding.
    if set(needed_shards) != set(VISION_SHARDS):
        raise RuntimeError(
            f"Expected shards {VISION_SHARDS}, got {needed_shards}. "
            "The model may have been resharded — update VISION_SHARDS."
        )

    if args.dry_run:
        print("\n[DRY RUN] Would download:")
        for shard in needed_shards:
            print(f"  {REPO_ID}/{shard}")
        print(f"\nWould extract {len(vision_keys)} keys to {output_path}")
        print("\nSample keys:")
        sample = sorted(vision_keys.keys())[:10]
        for k in sample:
            print(f"  {k}")
        # Only mention a remainder when there actually is one (the old code
        # printed a negative count for <=10 keys).
        remaining = len(vision_keys) - len(sample)
        if remaining > 0:
            print(f"  ... and {remaining} more")
        return

    all_tensors = {}

    for shard_name in needed_shards:
        print(f"\nDownloading {shard_name}...")
        shard_path = hf_hub_download(
            REPO_ID, shard_name, cache_dir=args.cache_dir
        )

        print("  Extracting vision keys...")
        # safe_open memory-maps the shard, so only the requested tensors
        # are actually materialized.
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if is_vision_key(key):
                    all_tensors[key] = f.get_tensor(key)

    print(f"\nExtracted {len(all_tensors)} tensors")

    # Quick sanity summary of what landed in each sub-module.
    vt_keys = [k for k in all_tensors if "vision_tower" in k]
    proj_keys = [k for k in all_tensors if "mm_projector" in k]
    print(f"  vision_tower: {len(vt_keys)} keys")
    print(f"  mm_projector: {len(proj_keys)} keys")

    print("\nProjector weight shapes:")
    for k in sorted(proj_keys):
        print(f"  {k}: {list(all_tensors[k].shape)}")

    print("\nSample vision tower shapes:")
    for k in sorted(vt_keys)[:5]:
        print(f"  {k}: {list(all_tensors[k].shape)}")

    print(f"\nSaving to {output_path}...")
    save_file(all_tensors, str(output_path))

    size_mb = output_path.stat().st_size / (1024**2)
    print(f"Done! Saved {size_mb:.1f} MB ({len(all_tensors)} tensors)")
|
|
# Script entry point: only run the extraction when executed directly.
if __name__ == "__main__":
    main()
|
|