import os

import torch
from safetensors.torch import save_file, load_file
|
|
def merge_model_components(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
):
    """
    Merge UNet, VAE, and text encoder weights into a single safetensors file.

    UNet keys are copied through unchanged; VAE and text-encoder keys are
    namespaced under ``vae.`` / ``text_encoder.`` unless they already carry
    that prefix. A warning is printed if a key would overwrite an existing
    entry, since a silent clobber would corrupt the merged checkpoint.

    Args:
        unet_path: Path to the main model/UNet safetensors file.
        vae_path: Path to the VAE safetensors file.
        text_encoder_path: Path to the text encoder/CLIP safetensors file.
        output_path: Path where the merged file will be saved.
    """
    print("Loading UNet/Model weights...")
    unet_state = load_file(unet_path)

    print("Loading VAE weights...")
    vae_state = load_file(vae_path)

    print("Loading Text Encoder weights...")
    text_encoder_state = load_file(text_encoder_path)

    print("Merging state dictionaries...")
    # UNet keys go in unprefixed; the other components get namespaced.
    merged_state = dict(unet_state)
    _merge_prefixed(merged_state, vae_state, "vae.")
    _merge_prefixed(merged_state, text_encoder_state, "text_encoder.")

    print(f"Total parameters in merged model: {len(merged_state)}")
    print(f"Saving merged model to {output_path}...")

    save_file(merged_state, output_path)

    print("✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")


def _merge_prefixed(dest, src, prefix):
    """Copy tensors from *src* into *dest*, namespacing keys under *prefix*.

    Keys that already start with *prefix* are kept as-is; all others get it
    prepended. Warns (rather than silently clobbering) when a key already
    exists in *dest*.
    """
    for key, value in src.items():
        new_key = key if key.startswith(prefix) else f"{prefix}{key}"
        if new_key in dest:
            print(f"⚠️ Overwriting existing key: {new_key}")
        dest[new_key] = value
|
|
|
|
| |
| if __name__ == "__main__": |
| merge_model_components( |
| unet_path="flux1-depth-dev.safetensors", |
| vae_path="vae/diffusion_pytorch_model.safetensors", |
| text_encoder_path="text_encoder/model.safetensors", |
| output_path="flux1-depth-dev_merged_model.safetensors" |
| ) |
|
|