import math
import os
from collections import defaultdict

from safetensors import safe_open


# Substring patterns (matched against the lower-cased key) used to guess which
# model component a tensor belongs to. Rules are tried in order and the first
# hit wins. Text-encoder patterns come first because text-encoder keys often
# contain "encoder" (e.g. "...text_model.encoder.layers..."), which the broad
# VAE rule would otherwise capture.
_CATEGORY_RULES = (
    ('Text Encoder', ('text_encoder', 'cond_stage', 'clip', 'transformer.text_model')),
    ('VAE', ('vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant')),
    ('UNet/Transformer', ('model.diffusion', 'diffusion_model', 'transformer',
                          'double_blocks', 'single_blocks')),
)


def _banner(title, leading_newline=True):
    """Print *title* between two 80-character '=' rules."""
    rule = "=" * 80
    print(("\n" + rule) if leading_newline else rule)
    print(title)
    print(rule)


def _categorize_key(key):
    """Return the heuristic component category for *key* ('Other' if no rule matches)."""
    lowered = key.lower()
    for category, patterns in _CATEGORY_RULES:
        if any(pattern in lowered for pattern in patterns):
            return category
    return 'Other'


def _read_tensor_info(checkpoint_path):
    """Return ``{key: (shape, dtype)}`` for every tensor in a safetensors file.

    Uses ``get_slice`` so only header metadata is consulted and tensor data is
    not loaded. *shape* is a tuple of ints; *dtype* is the dtype name string
    reported by safetensors. Key order follows ``safe_open(...).keys()``.
    """
    info = {}
    with safe_open(checkpoint_path, framework="pt") as f:
        for key in f.keys():
            tensor_slice = f.get_slice(key)
            info[key] = (tuple(tensor_slice.get_shape()), tensor_slice.get_dtype())
    return info


def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file and print a report.

    The report covers file size, tensor and parameter counts, a heuristic
    component breakdown (VAE / Text Encoder / UNet-Transformer / Other),
    top-level key prefixes, sample keys with shape and dtype, a guess at the
    model family, and whether the checkpoint looks complete.

    Args:
        checkpoint_path: Path to the .safetensors file
        detailed: If True, also list every tensor with its shape, dtype and
            element count

    Returns:
        None. If *checkpoint_path* is not an existing regular file, an error
        line is printed and the function returns early.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"❌ File not found: {checkpoint_path}")
        return

    _banner(f"INSPECTING: {os.path.basename(checkpoint_path)}", leading_newline=False)

    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024**3)
    print(f"\n📦 File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    tensor_info = _read_tensor_info(checkpoint_path)
    keys = list(tensor_info)
    # math.prod(()) == 1, so 0-d (scalar) tensors count as one element.
    total_elements = sum(math.prod(shape) for shape, _ in tensor_info.values())

    # Each key is one tensor; the parameter count is the sum of element counts.
    print(f"\n📊 Total Tensors: {len(keys):,}")
    print(f"📊 Total Parameters: {total_elements:,}")

    _banner("COMPONENT BREAKDOWN")
    categories = defaultdict(list)
    for key in keys:
        categories[_categorize_key(key)].append(key)

    for category, cat_keys in sorted(categories.items()):
        print(f"\n{category}: {len(cat_keys)} tensors")

    _banner("KEY PATTERNS")
    prefix_counts = defaultdict(int)
    for key in keys:
        # split() on a key without '.' yields [key], so no special case needed.
        prefix_counts[key.split('.')[0]] += 1

    print("\nTop-level prefixes:")
    for prefix, count in sorted(prefix_counts.items(), key=lambda item: -item[1]):
        print(f"  {prefix}: {count} tensors")

    _banner("SAMPLE KEYS FROM EACH COMPONENT")
    # categories only holds non-empty lists (entries are created on append).
    for category, cat_keys in sorted(categories.items()):
        print(f"\n{category} (showing first 5):")
        for key in cat_keys[:5]:
            shape, dtype = tensor_info[key]
            print(f"  {key}")
            print(f"    └─ shape: {shape}, dtype: {dtype}")

    if detailed:
        _banner("ALL KEYS (DETAILED)")
        for i, key in enumerate(keys, 1):
            shape, dtype = tensor_info[key]
            print(f"\n{i}. {key}")
            print(f"   Shape: {shape}")
            print(f"   Dtype: {dtype}")
            print(f"   Size: {math.prod(shape):,} elements")

    _banner("MODEL TYPE DETECTION")
    has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
    has_sd_unet = any('model.diffusion_model' in k for k in keys)
    # Derived from the breakdown above so both sections of the report agree.
    # These remain substring heuristics, not a format check.
    has_vae = bool(categories.get('VAE'))
    has_text_encoder = bool(categories.get('Text Encoder'))
    has_denoiser = bool(categories.get('UNet/Transformer'))

    def yes_no(flag):
        return '✅ YES' if flag else '❌ NO'

    print(f"\n✓ FLUX-style blocks: {yes_no(has_flux_blocks)}")
    print(f"✓ SD-style UNet: {yes_no(has_sd_unet)}")
    print(f"✓ VAE included: {yes_no(has_vae)}")
    print(f"✓ Text Encoder included: {yes_no(has_text_encoder)}")

    # FLUX is checked first: merged FLUX checkpoints can also carry the
    # "model.diffusion_model" prefix.
    if has_flux_blocks:
        print("\n🔍 Likely model type: FLUX")
    elif has_sd_unet:
        print("\n🔍 Likely model type: Stable Diffusion")
    else:
        print("\n⚠️ Could not determine model type")

    _banner("CHECKPOINT COMPLETENESS")
    components = (
        ('UNet/Transformer', has_denoiser),
        ('VAE', has_vae),
        ('Text Encoder', has_text_encoder),
    )
    missing = [name for name, present in components if not present]
    if not missing:
        print("\n✅ This appears to be a COMPLETE checkpoint")
        print("   (Contains UNet/Transformer + VAE + Text Encoder)")
    else:
        print("\n⚠️ This appears to be a PARTIAL checkpoint")
        for name in missing:
            print(f"   Missing: {name}")

    _banner("INSPECTION COMPLETE")


def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare two safetensors checkpoints and print how they differ.

    The report lists keys present in only one file, keys present in both
    whose shape or dtype differ, and a per-top-level-prefix tensor count
    table. Only header metadata is read via ``get_slice``; tensor data is
    not loaded.

    Args:
        working_checkpoint: Path to checkpoint that works
        broken_checkpoint: Path to checkpoint that doesn't work

    Returns:
        None. If either path is not an existing regular file, an error line
        is printed for each missing path and nothing is compared.
    """
    missing = [p for p in (working_checkpoint, broken_checkpoint) if not os.path.isfile(p)]
    if missing:
        for path in missing:
            print(f"❌ File not found: {path}")
        return

    def read_tensor_info(path):
        # {key: (shape tuple, dtype name)}; get_slice exposes header metadata
        # without materializing the tensor.
        info = {}
        with safe_open(path, framework="pt") as f:
            for key in f.keys():
                tensor_slice = f.get_slice(key)
                info[key] = (tuple(tensor_slice.get_shape()), tensor_slice.get_dtype())
        return info

    def get_prefixes(keys):
        # Tensor count per top-level (first dot-separated component) prefix.
        prefixes = defaultdict(int)
        for key in keys:
            prefixes[key.split('.')[0]] += 1
        return prefixes

    print("=" * 80)
    print("COMPARING CHECKPOINTS")
    print("=" * 80)

    info1 = read_tensor_info(working_checkpoint)
    info2 = read_tensor_info(broken_checkpoint)
    keys1 = set(info1)
    keys2 = set(info2)

    print(f"\nWorking checkpoint: {len(keys1)} keys")
    print(f"Broken checkpoint: {len(keys2)} keys")

    only_in_working = keys1 - keys2
    only_in_broken = keys2 - keys1
    common = keys1 & keys2
    # Same key name but a different shape/dtype can explain a broken
    # checkpoint even when the key sets are identical.
    mismatched = sorted(key for key in common if info1[key] != info2[key])

    print(f"\nCommon keys: {len(common)}")
    print(f"Only in working: {len(only_in_working)}")
    print(f"Only in broken: {len(only_in_broken)}")
    print(f"Common keys with different shape/dtype: {len(mismatched)}")

    if only_in_working:
        print("\n📋 Keys present in WORKING but missing in BROKEN (first 20):")
        for key in sorted(only_in_working)[:20]:
            print(f"  - {key}")

    if only_in_broken:
        print("\n📋 Keys present in BROKEN but missing in WORKING (first 20):")
        for key in sorted(only_in_broken)[:20]:
            print(f"  + {key}")

    if mismatched:
        print("\n⚠️ Common keys whose shape or dtype differ (first 20):")
        for key in mismatched[:20]:
            (shape1, dtype1), (shape2, dtype2) = info1[key], info2[key]
            print(f"  ~ {key}")
            print(f"      working: shape={shape1}, dtype={dtype1}")
            print(f"      broken:  shape={shape2}, dtype={dtype2}")

    print("\n" + "=" * 80)
    print("KEY PATTERN COMPARISON")
    print("=" * 80)

    prefixes1 = get_prefixes(keys1)
    prefixes2 = get_prefixes(keys2)
    all_prefixes = set(prefixes1) | set(prefixes2)

    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for prefix in sorted(all_prefixes):
        count1 = prefixes1.get(prefix, 0)
        count2 = prefixes2.get(prefix, 0)
        status = "✅ " if count1 == count2 else "⚠️ "
        print(f"{status} {prefix:<28} {count1:<15} {count2:<15}")


if __name__ == "__main__":
    # Command-line entry point. With no arguments it behaves exactly like the
    # original script: inspect the default checkpoint, non-detailed.
    import argparse

    parser = argparse.ArgumentParser(
        description="Inspect (and optionally compare) safetensors checkpoints."
    )
    parser.add_argument(
        "checkpoint",
        nargs="?",
        default="../flux1-depth-dev_ComfyMerged.safetensors",
        help="checkpoint to inspect (default: %(default)s)",
    )
    parser.add_argument(
        "--detailed",
        action="store_true",
        help="list every tensor with shape, dtype and element count",
    )
    parser.add_argument(
        "--compare",
        metavar="BROKEN",
        help="also compare CHECKPOINT (treated as working) against BROKEN",
    )
    args = parser.parse_args()

    print("OPTION 1: Inspect your working checkpoint")
    print("-" * 80)
    inspect_checkpoint(args.checkpoint, detailed=args.detailed)

    print("\n\n")

    if args.compare:
        print("OPTION 2: Compare working vs broken checkpoint")
        print("-" * 80)
        compare_checkpoints(args.checkpoint, args.compare)