| import os |
| import json |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
def merge_safetensors(input_dir, output_file, config_file):
    """Merge every ``.safetensors`` shard in *input_dir* into one file.

    Shards are visited in sorted filename order. All of their tensors are
    loaded onto the CPU and written to *output_file* with
    ``safetensors.torch.save_file``. The header metadata stores:

    * ``format`` — always ``"pt"``;
    * ``_diffusers_version`` and ``_class_name`` — copied from the JSON
      *config_file*; a missing key or JSON null becomes ``""``;
    * ``total_size`` — the combined byte size of all merged tensors, as a
      string.

    Args:
        input_dir: Directory containing the ``.safetensors`` shards.
        output_file: Path of the merged file to write. If this path is
            inside *input_dir*, it is skipped while scanning, so re-running
            the merge does not fold a previous result back in.
        config_file: Path to a JSON config that supplies the two metadata
            fields above.

    Raises:
        OSError: If *input_dir* cannot be listed or *config_file* cannot be
            opened.
        json.JSONDecodeError: If *config_file* is not valid JSON.

    Notes:
        Every tensor is held in memory at once, so peak RAM is roughly the
        size of the whole checkpoint. If the same tensor name appears in
        more than one shard, the shard that sorts last wins and a single
        warning is emitted.
    """
    import warnings  # local import: only used for the duplicate-key report

    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    def _meta_str(key):
        # safetensors header metadata must map str -> str. Coerce present
        # values, and turn a missing key or JSON null into "" so save_file
        # does not reject the header.
        value = config.get(key)
        return "" if value is None else str(value)

    metadata = {
        "format": "pt",
        "total_size": "",
        "_diffusers_version": _meta_str("_diffusers_version"),
        "_class_name": _meta_str("_class_name"),
    }

    output_path = os.path.realpath(output_file)
    merged_tensors = {}
    duplicates = []

    # os.listdir order is arbitrary. Sort it so the merge, and which shard
    # wins on a duplicated tensor name, is reproducible.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(".safetensors"):
            continue
        file_path = os.path.join(input_dir, filename)
        # Skip directories that happen to match the suffix, and the output
        # file itself (it may be left over from a previous run).
        if not os.path.isfile(file_path):
            continue
        if os.path.realpath(file_path) == output_path:
            continue

        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key in merged_tensors:
                    duplicates.append(key)
                merged_tensors[key] = f.get_tensor(key)

    if duplicates:
        warnings.warn(
            f"{len(duplicates)} tensor name(s) appeared in more than one "
            f"shard; later shards overwrote earlier ones: {duplicates[:5]}"
        )

    # f.metadata() returns the shard's own __metadata__ str->str map, which
    # normally carries no size. Compute the byte size from the merged tensors
    # instead, so duplicated names are counted only once.
    total_size = sum(
        tensor.numel() * tensor.element_size()
        for tensor in merged_tensors.values()
    )
    metadata["total_size"] = str(total_size)

    save_file(merged_tensors, output_file, metadata)
|
|
|
|
# Paths for a one-off merge of the shards stored in ./10_1. The merged file
# is written into that same input directory.
input_directory = './10_1'
output_file = './10_1/flux1-merge-S10_D1.safetensors'
config_file = './10_1/config.json'

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse merge_safetensors)
    # does not start a full checkpoint merge as a side effect.
    merge_safetensors(input_directory, output_file, config_file)