| import argparse |
| from safetensors import safe_open |
| from collections import Counter |
| import os |
| import math |
|
|
| |
| |
| |
| |
| |
# Storage size, in bytes per element, keyed by upper-case safetensors dtype
# string. Only dtypes that occupy a whole number of bytes are listed; any
# dtype missing from this table is reported as unknown by the lookup helper.
DTYPE_TO_BYTES = {
    "BOOL": 1,
    # 8-bit floats. Both the underscored and the compact spelling are
    # accepted for each variant so that either naming convention resolves.
    "F8_E5M2": 1,
    "F8E5M2": 1,
    "F8_E4M3FN": 1,
    "F8E4M3FN": 1,
    "F8_E4M3": 1,
    "F8E4M3": 1,
    "F8_E5M2FNUZ": 1,
    "F8E5M2FNUZ": 1,
    "F8_E4M3FNUZ": 1,
    "F8E4M3FNUZ": 1,
    "F8_E8M0": 1,
    # Standard floating point.
    "F16": 2,
    "BF16": 2,
    "F32": 4,
    "F64": 8,
    # Complex number stored as two 32-bit floats.
    "C64": 8,
    # Signed integers.
    "I8": 1,
    "I16": 2,
    "I32": 4,
    "I64": 8,
    # Unsigned integers.
    "U8": 1,
    "U16": 2,
    "U32": 4,
    "U64": 8,
}
|
|
def get_bytes_per_element(dtype_str):
    """Look up the per-element storage size for a safetensors dtype string.

    The lookup is case-insensitive: the name is upper-cased before it is
    matched against ``DTYPE_TO_BYTES``.

    Args:
        dtype_str: Dtype name, e.g. ``"F16"`` or ``"bf16"``.

    Returns:
        The number of bytes per element, or ``None`` when the dtype is not
        present in the table.
    """
    normalized = dtype_str.upper()
    if normalized in DTYPE_TO_BYTES:
        return DTYPE_TO_BYTES[normalized]
    return None
|
|
def calculate_num_elements(shape):
    """Return the total number of elements described by a tensor shape.

    Args:
        shape: Sequence of dimension sizes. An empty (or ``None``) shape is
            treated as a scalar.

    Returns:
        The product of all dimension sizes: 1 for an empty shape, and 0 as
        soon as any dimension is 0.
    """
    if not shape:
        # A scalar has no dimensions but still holds exactly one element.
        return 1
    # math.prod already yields 0 when any dimension is 0, so no separate
    # zero-dimension check is needed.
    return math.prod(shape)
|
|
def inspect_safetensors_precision_and_size(filepath):
    """Print a per-tensor precision and size report for a .safetensors file.

    For every tensor key in the file this prints one table row with the
    tensor name, its dtype, the size implied by its shape and dtype, and the
    size the same number of elements would occupy at 4 bytes each (FP32).
    A summary follows with the dtype distribution, both totals, and the
    relative size reduction compared to FP32.

    Only the dtype and shape of each tensor are queried (through
    ``get_slice``); sizes are computed from those, not measured on disk.
    Tensors whose dtype is not recognised by ``get_bytes_per_element`` are
    listed with ``N/A`` sizes and are left out of both totals.

    Args:
        filepath: Path to the file to inspect. A missing ``.safetensors``
            extension only triggers a warning.

    Returns:
        None. All results and all errors are reported by printing; failures
        while reading the file are caught and printed rather than raised.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    if not filepath.lower().endswith(".safetensors"):
        print(f"Warning: File '{filepath}' does not have a .safetensors extension. Attempting to read anyway.")

    bytes_per_mb = 1024 * 1024
    fp32_bytes_per_element = 4

    dtype_counts = Counter()
    # Totals are kept as exact integer byte counts and converted to MB once
    # at the end, so no floating-point error accumulates across tensors.
    total_actual_bytes = 0
    total_fp32_equiv_bytes = 0

    try:
        print(f"Inspecting tensors in: {filepath}\n")
        with safe_open(filepath, framework="pt", device="cpu") as f:
            tensor_keys = list(f.keys())
            if not tensor_keys:
                print("No tensors found in the file.")
                return

            # The name column is as wide as the longest key, but never
            # narrower than its own header text.
            max_key_len = max(
                len("Tensor Name"),
                max(len(k) for k in tensor_keys),
            )

            header = (
                f"{'Tensor Name':<{max_key_len}} | "
                f"{'Precision (dtype)':<17} | "
                f"{'Actual Size (MB)':>16} | "
                f"{'FP32 Equiv. (MB)':>18}"
            )
            print(header)
            print(
                f"{'-' * max_key_len}-|-------------------|------------------|-------------------"
            )

            for key in tensor_keys:
                tensor_slice = f.get_slice(key)
                dtype_str = tensor_slice.get_dtype()
                shape = tensor_slice.get_shape()

                num_elements = calculate_num_elements(shape)
                bytes_per_el_actual = get_bytes_per_element(dtype_str)

                if bytes_per_el_actual is None:
                    print(f"Warning: Unknown dtype '{dtype_str}' for tensor '{key}'. Cannot calculate size.")
                    actual_size_mb_str = "N/A"
                    fp32_equiv_size_mb_str = "N/A"
                else:
                    actual_bytes = num_elements * bytes_per_el_actual
                    fp32_equiv_bytes = num_elements * fp32_bytes_per_element
                    total_actual_bytes += actual_bytes
                    total_fp32_equiv_bytes += fp32_equiv_bytes
                    actual_size_mb_str = f"{actual_bytes / bytes_per_mb:.3f}"
                    fp32_equiv_size_mb_str = f"{fp32_equiv_bytes / bytes_per_mb:.3f}"

                print(
                    f"{key:<{max_key_len}} | "
                    f"{dtype_str:<17} | "
                    f"{actual_size_mb_str:>16} | "
                    f"{fp32_equiv_size_mb_str:>18}"
                )
                dtype_counts[dtype_str] += 1

        total_actual_mb = total_actual_bytes / bytes_per_mb
        total_fp32_equiv_mb = total_fp32_equiv_bytes / bytes_per_mb

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensor_keys)}")
        # At least one tensor was processed to get here, so the counter is
        # never empty.
        print("Precision distribution:")
        for dtype, count in dtype_counts.most_common():
            print(f" - {dtype:<12}: {count} tensor(s)")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        # Guard against dividing by a zero (or negligible) FP32 total, which
        # happens when every tensor is empty or has an unknown dtype.
        if total_fp32_equiv_mb > 0.00001:
            savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")

    except Exception as e:
        # Deliberate catch-all at the CLI boundary: report the problem to
        # the user instead of surfacing a traceback.
        print(f"An error occurred while processing '{filepath}':")
        print(f" {e}")
        print("Please ensure it's a valid .safetensors file and the 'safetensors' (and 'torch') libraries are installed correctly.")
|
|
# Command-line entry point: takes one positional argument, the path of the
# file to inspect, and prints the report for it.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a .safetensors file.",
    )
    cli.add_argument("filepath", help="Path to the .safetensors file to inspect.")
    cli_args = cli.parse_args()
    inspect_safetensors_precision_and_size(cli_args.filepath)