| from safetensors.torch import save_file, load_file |
| import torch |
| import os |
|
|
def inspect_keys(file_path, max_keys=10):
    """Print a short summary of the tensor names stored in a safetensors file.

    Only the key names are read; the tensors themselves are not loaded, so
    this stays cheap even for multi-gigabyte checkpoints.

    Args:
        file_path: Path to the ``.safetensors`` file to inspect.
        max_keys: Maximum number of key names to print.

    Returns:
        A list of every key name in the file, in the order ``safe_open``
        reports them (this may differ from the dict order of ``load_file``).
    """
    # Function-scope import: the file's top-level import only pulls in
    # ``save_file``/``load_file`` from ``safetensors.torch``.
    from safetensors import safe_open

    with safe_open(file_path, framework="pt") as handle:
        keys = list(handle.keys())

    preview = keys[:max_keys]
    print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
    # Report how many keys are actually listed, which can be fewer than
    # ``max_keys`` for small files.
    print(f"First {len(preview)} keys:")
    for name in preview:
        print(f"  {name}")
    return keys
|
|
def _remap_unet_key(key):
    """Return the merged-checkpoint key for a UNet/transformer tensor."""
    # NOTE(review): keys starting with 'diffusion_model.' are kept as-is and do
    # not receive the 'model.' prefix that bare keys get - confirm intended.
    if key.startswith(("model.", "diffusion_model.")):
        return key
    return f"model.diffusion_model.{key}"


def _remap_vae_key(key):
    """Return the merged-checkpoint key for a VAE tensor."""
    if key.startswith(("first_stage_model.", "vae.")):
        return key
    if key.startswith(("decoder.", "encoder.")):
        return f"first_stage_model.{key}"
    # NOTE(review): any other key is assumed to belong to the decoder -
    # confirm this holds for the VAE files being merged.
    return f"first_stage_model.decoder.{key}"


def _remap_text_encoder_key(key, is_flux):
    """Return the merged-checkpoint key for a text-encoder tensor."""
    target_prefix = "text_encoders." if is_flux else "cond_stage_model.transformer."
    # The target prefix itself counts as "already prefixed"; otherwise a flux
    # key that already starts with 'text_encoders.' would be prefixed twice.
    if key.startswith(("cond_stage_model.", "text_encoder.", target_prefix)):
        return key
    return f"{target_prefix}{key}"


def _merge_component(merged_state, state, remap, label):
    """Copy ``state`` into ``merged_state`` under remapped keys, warning on clashes."""
    clashes = []
    for key, tensor in state.items():
        target = remap(key)
        if target in merged_state:
            clashes.append(target)
        # Later components still win, as before; the clash is only reported.
        merged_state[target] = tensor
    if clashes:
        print(
            f"WARNING: {label}: {len(clashes)} key(s) overwrote existing "
            f"entries, e.g. {clashes[0]}"
        )


def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux"
):
    """
    Merge components into ComfyUI-compatible safetensors checkpoint.

    Each input is inspected, loaded, and copied into a single state dict with
    its keys prefixed per component (unless a key already carries a recognised
    prefix). The result is written to ``output_path`` and re-inspected.
    Progress is reported with ``print``. Errors raised by ``load_file``,
    ``save_file`` or ``os.path.getsize`` propagate unchanged.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl). Only "flux"
            (case-insensitive) changes the mapping: it selects the
            ``text_encoders.`` prefix; every other value uses
            ``cond_stage_model.transformer.``.

    Returns:
        None.
    """
    banner = "=" * 60

    print(banner)
    print("STEP 1: Inspecting input files...")
    print(banner)

    for path in (unet_path, vae_path, text_encoder_path):
        inspect_keys(path)

    print("\n" + banner)
    print("STEP 2: Loading weights...")
    print(banner)

    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("\n" + banner)
    print("STEP 3: Merging with proper key structure...")
    print(banner)

    # Show one sample key per component. A placeholder is used for an empty
    # file so the summary does not fail with IndexError.
    placeholder = "<no keys>"
    print("\nDetected key patterns:")
    print(f"  UNet: {next(iter(unet_state), placeholder)}")
    print(f"  VAE: {next(iter(vae_state), placeholder)}")
    print(f"  Text Encoder: {next(iter(text_encoder_state), placeholder)}")

    is_flux = model_type.lower() == "flux"
    merged_state = {}
    _merge_component(merged_state, unet_state, _remap_unet_key, "UNet")
    _merge_component(merged_state, vae_state, _remap_vae_key, "VAE")
    _merge_component(
        merged_state,
        text_encoder_state,
        lambda key: _remap_text_encoder_key(key, is_flux),
        "Text Encoder",
    )

    print(f"\nMerged state contains {len(merged_state)} parameters")

    print("\n" + banner)
    print("STEP 4: Saving merged checkpoint...")
    print(banner)

    save_file(merged_state, output_path)

    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    print("\n" + banner)
    print("STEP 5: Verifying merged file...")
    print(banner)
    inspect_keys(output_path, max_keys=20)
|
|
|
|
def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Simple merge that preserves original key structure.
    Use this if the files already have proper ComfyUI keys.

    The three state dicts are combined in the order UNet, VAE, text encoder.
    If a later component reuses a key from an earlier one, the later tensor
    replaces the earlier one and a warning is printed, since the earlier
    tensor would otherwise be dropped without notice.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint

    Returns:
        None.
    """
    print("Loading all components...")

    components = (
        ("UNet", load_file(unet_path)),
        ("VAE", load_file(vae_path)),
        ("Text Encoder", load_file(text_encoder_path)),
    )

    print("Merging...")
    merged_state = {}
    for label, state in components:
        # Keys present in both dicts would be overwritten by update().
        duplicates = merged_state.keys() & state.keys()
        if duplicates:
            print(
                f"WARNING: {label}: {len(duplicates)} key(s) already present "
                f"and will be overwritten, e.g. {min(duplicates)}"
            )
        merged_state.update(state)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")
|
|
|
|
| |
if __name__ == "__main__":
    # Script entry point: run one merge with the hard-coded paths below.
    # The relative paths are resolved against the current working directory.
    job = {
        "unet_path": "../flux1-depth-dev.safetensors",
        "vae_path": "../vae/diffusion_pytorch_model.safetensors",
        "text_encoder_path": "../text_encoder/model.safetensors",
        "output_path": "../flux1-depth-dev_merged_model.safetensors",
        "model_type": "flux",
    }
    merge_for_comfyui(**job)
| |
| |
| |
| |
| |
| |
| |
| |
|
|