| from safetensors.torch import load_file, save_file |
| import torch |
| from typing import List, Dict, Optional |
| import logging |
| from tqdm import tqdm |
| import os |
| import hashlib |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
| |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") |
| logger = logging.getLogger(__name__) |
|
|
def calculate_checksum(file_path: str) -> str:
    """Compute the hex-encoded SHA-256 digest of a file.

    The file is streamed in 4 KiB chunks so arbitrarily large files can be
    hashed without loading them fully into memory.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: Hexadecimal SHA-256 digest of the file contents.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
|
|
def verify_checksums(model_parts: List[str], expected_checksums: List[str]) -> None:
    """
    Verify the checksums of model part files.

    Args:
        model_parts (list): List of model part file paths.
        expected_checksums (list): List of expected checksums for each part.

    Raises:
        RuntimeError: If the number of checksums does not match the number of
            parts, or if any checksum does not match.
    """
    # zip() silently truncates to the shorter sequence, which would let any
    # part beyond len(expected_checksums) pass completely unverified.
    if len(model_parts) != len(expected_checksums):
        raise RuntimeError(
            f"Expected {len(model_parts)} checksums, got {len(expected_checksums)}"
        )
    for part, expected_checksum in zip(model_parts, expected_checksums):
        actual_checksum = calculate_checksum(part)
        if actual_checksum != expected_checksum:
            raise RuntimeError(f"Checksum mismatch for {part}: expected {expected_checksum}, got {actual_checksum}")
|
|
def load_part(part: str) -> Dict[str, torch.Tensor]:
    """Read a single .safetensors shard from disk.

    Args:
        part (str): Path of the shard file to read.

    Returns:
        dict: Mapping of tensor name to tensor for this shard.
    """
    state_dict = load_file(part)
    return state_dict
|
|
def load_charm_model(model_parts: List[str], expected_checksums: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
    """
    Load and merge multiple .safetensors model files.

    Parts are loaded concurrently in a thread pool and their state dicts are
    merged into one dictionary.

    Args:
        model_parts (list): List of model part file paths (e.g., ["model-1-of-10.safetensors", ...]).
        expected_checksums (list, optional): List of expected checksums for each part.

    Returns:
        dict: Merged model state dictionary.

    Raises:
        FileNotFoundError: If any model part file is missing.
        RuntimeError: If there is an issue loading or merging the model parts.
    """
    merged_state_dict: Dict[str, torch.Tensor] = {}

    # Fail fast before any expensive hashing or loading.
    for part in model_parts:
        if not os.path.exists(part):
            raise FileNotFoundError(f"Model part not found: {part}")

    if expected_checksums:
        logger.info("Verifying checksums...")
        verify_checksums(model_parts, expected_checksums)
        logger.info("Checksums verified successfully.")

    logger.info("Loading and merging model parts...")
    try:
        with ThreadPoolExecutor() as executor:
            futures = {executor.submit(load_part, part): part for part in model_parts}
            for future in tqdm(as_completed(futures), total=len(futures), desc="Loading model parts"):
                part = futures[future]
                try:
                    state_dict = future.result()
                except Exception as e:
                    logger.error(f"Error loading part {part}: {e}")
                    # Chain the original exception so the root cause survives.
                    raise RuntimeError(f"Failed to load part: {part}") from e
                # NOTE(review): duplicate keys across parts silently overwrite
                # earlier entries — assumes shards have disjoint key sets; confirm.
                merged_state_dict.update(state_dict)
                logger.debug(f"Loaded part: {part}")
    except RuntimeError:
        # Already logged and wrapped with a per-part message above; don't
        # re-wrap it into a generic error that hides which part failed.
        raise
    except Exception as e:
        logger.error(f"Error loading or merging model parts: {e}")
        raise RuntimeError("Failed to load or merge model parts.") from e

    logger.info("Model parts loaded and merged successfully.")
    return merged_state_dict
|
|
| |
if __name__ == "__main__":
    try:
        # Expect ten shards: model-1-of-10.safetensors ... model-10-of-10.safetensors.
        model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]

        # Placeholder digests — replace with the real SHA-256 checksums.
        # One entry per shard so every part gets verified; a shorter list
        # would leave the trailing shards unchecked.
        expected_checksums = [
            f"checksum_for_model-{i}-of-10.safetensors" for i in range(1, 11)
        ]

        charm_model = load_charm_model(model_files, expected_checksums)

        output_file = "merged_model.safetensors"
        save_file(charm_model, output_file)
        logger.info(f"Merged model saved as '{output_file}'.")
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception(f"An error occurred: {e}")