import os
import requests
from safetensors.torch import load_file, save_file
import torch
# NOTE(review): called at import time, before any model is loaded — this frees
# nothing useful yet and initializes CUDA as a side effect of importing this
# module. Consider moving it next to the large allocations it is meant to help.
torch.cuda.empty_cache()
import torch.nn.functional as F
from tqdm import tqdm
|
|
def download_file(url, dest_path):
    """Stream a file from *url* to local *dest_path*.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: filesystem path the response body is written to.

    Raises:
        Exception: if the server responds with a non-200 status.
    """
    print(f"Downloading {url} to {dest_path}")
    # `with` closes the connection even on error (the original leaked it);
    # a timeout prevents hanging forever on a dead server.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code == 200:
            with open(dest_path, 'wb') as f:
                # 1 MiB chunks: the original 1 KiB chunks made one Python-level
                # write per kilobyte, which is very slow for multi-GB checkpoints.
                for chunk in response.iter_content(1024 * 1024):
                    f.write(chunk)
        else:
            raise Exception(f"Failed to download file from {url}")
|
|
def load_model(file_path):
    """Read a safetensors checkpoint at *file_path* and return its state dict."""
    state_dict = load_file(file_path)
    return state_dict
|
|
def save_model(merged_model, output_file):
    """Serialize the merged state dict *merged_model* to *output_file* (safetensors)."""
    print(f"Saving merged model to {output_file}")
    save_file(merged_model, output_file)
|
|
def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two same-rank tensors so they share a common shape.

    Each dimension of the result is the max of the two inputs' sizes in that
    dimension; the smaller tensor is right-padded with zeros.

    Bug fixed: the original computed ``max_shape`` over every dimension but
    only padded the LAST one, so tensors differing in any earlier dimension
    still came back with mismatched shapes and crashed the blend.

    Args:
        tensor1, tensor2: tensors of equal rank (assumed — zip truncates
            silently otherwise; TODO confirm callers never mix ranks).

    Returns:
        (tensor1_padded, tensor2_padded) with identical shapes; the originals
        are returned untouched when they already match.
    """
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    target = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def _pad_to_target(t):
        # F.pad takes (left, right) pairs starting from the LAST dimension.
        pad = []
        for dim in reversed(range(t.dim())):
            pad.extend([0, target[dim] - t.size(dim)])
        return F.pad(t, pad)

    return _pad_to_target(tensor1), _pad_to_target(tensor2)
|
|
def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.6, target_size_gb=26):
    """Blend two checkpoints key-by-key into a single state dict.

    Keys present in both are combined as
    ``blend_ratio * t1 + (1 - blend_ratio) * t2`` (shapes reconciled via
    ``resize_tensor_shapes``); keys unique to either checkpoint are copied
    through unchanged.

    Args:
        ckpt1, ckpt2: dicts mapping parameter names to tensors.
        blend_ratio: weight given to ``ckpt1``'s tensors, in [0, 1].
        target_size_gb: size cap forwarded to ``control_output_size``
            (previously hard-coded to 26). NOTE(review): exceeding the cap
            destructively trims tensor data — see that function.

    Returns:
        dict of merged tensors.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is not None and t2 is not None:
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
        elif t1 is not None:
            merged[key] = t1
        else:
            merged[key] = t2

    control_output_size(merged, target_size_gb=target_size_gb)

    return merged
|
|
def control_output_size(merged, target_size_gb):
    """Destructively shrink *merged* (in place) until it fits under a byte budget.

    Bug fixed: the original did ``tensor.flatten()[:n - reduction].view(shape)``,
    which raises RuntimeError whenever any elements are removed (fewer elements
    cannot be viewed as the original shape). It also assumed 4 bytes/element
    (``excess // 4``), which is wrong for fp16/bf16 tensors. This version trims
    whole rows along dim 0 using each tensor's real ``element_size()``, so every
    surviving tensor keeps a valid (if shorter) shape.

    NOTE(review): trimming is still lossy — trailing rows of parameters are
    discarded, which corrupts those layers. Confirm this is intentional.

    Args:
        merged: dict of name -> tensor, modified in place.
        target_size_gb: budget in GiB (may be fractional).

    Returns:
        None.
    """
    target_size_bytes = target_size_gb * 1024**3
    current_size_bytes = sum(t.numel() * t.element_size() for t in merged.values())

    if current_size_bytes <= target_size_bytes:
        return

    excess_bytes = current_size_bytes - target_size_bytes
    print(f"Current size exceeds target by {excess_bytes / (1024**2):.2f} MB. Adjusting...")

    for key, tensor in merged.items():
        if excess_bytes <= 0:
            break
        # Scalars / empty tensors have no rows to trim.
        if tensor.dim() == 0 or tensor.size(0) == 0:
            continue
        row_bytes = (tensor.numel() // tensor.size(0)) * tensor.element_size()
        if row_bytes == 0:
            continue
        # Ceil-divide so we always remove at least enough rows.
        rows_to_drop = min(tensor.size(0), int(-(-excess_bytes // row_bytes)))
        merged[key] = tensor[: tensor.size(0) - rows_to_drop]
        excess_bytes -= rows_to_drop * row_bytes
|
|
def cleanup_files(*file_paths):
    """Delete each path in *file_paths* that exists; silently skip the rest."""
    for path in file_paths:
        if not os.path.exists(path):
            continue
        os.remove(path)
        print(f"Deleted {path}")
|
|
if __name__ == "__main__":
    try:
        # NOTE(review): model2_path and output_file are the SAME file — the
        # second input checkpoint is overwritten by the merge result. Looks
        # like an in-place iterative-merge workflow; confirm it is intentional.
        model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
        model2_path = "output_checkpoint.safetensors"
        blend_ratio = 0.6
        output_file = "output_checkpoint.safetensors"

        # Load both checkpoints fully into memory as name -> tensor dicts.
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)

        # Blend per-key with blend_ratio weight on model1.
        merged_model = merge_checkpoints(model1, model2, blend_ratio)

        save_model(merged_model, output_file)

        # Only model1 is deleted; model2 lives on as the just-written output.
        cleanup_files(model1_path)

    # NOTE(review): broad catch prints and exits 0, so callers/scripts cannot
    # detect failure from the exit code — consider re-raising or sys.exit(1).
    except Exception as e:
        print(f"An error occurred: {e}")