| import argparse |
| import torch |
| from collections import Counter |
| import os |
| import math |
|
|
| |
| TORCH_DTYPE_TO_BYTES = { |
| |
| torch.bool: 1, |
| |
| torch.float16: 2, |
| torch.half: 2, |
| torch.bfloat16: 2, |
| torch.float32: 4, |
| torch.float: 4, |
| torch.float64: 8, |
| torch.double: 8, |
| |
| torch.complex64: 8, |
| torch.complex128: 16, |
| torch.cfloat: 8, |
| torch.cdouble: 16, |
| |
| torch.int8: 1, |
| torch.int16: 2, |
| torch.short: 2, |
| torch.int32: 4, |
| torch.int: 4, |
| torch.int64: 8, |
| torch.long: 8, |
| |
| torch.uint8: 1, |
| torch.uint16: 2, |
| torch.uint32: 4, |
| torch.uint64: 8, |
| |
| torch.qint8: 1, |
| torch.quint8: 1, |
| torch.qint32: 4, |
| torch.quint4x2: 1, |
| } |
|
|
def get_bytes_per_element(dtype):
    """Return the storage size in bytes for a given PyTorch dtype.

    Looks the dtype up in ``TORCH_DTYPE_TO_BYTES`` first; for dtypes missing
    from the table it falls back to the dtype's own ``itemsize`` attribute
    (available on modern PyTorch), returning None only when the size cannot
    be determined at all.
    """
    size = TORCH_DTYPE_TO_BYTES.get(dtype)
    if size is not None:
        return size
    # torch.dtype.itemsize exists on newer PyTorch; guard for older versions.
    return getattr(dtype, "itemsize", None)
|
|
def get_dtype_name(dtype):
    """Returns a readable string for a PyTorch dtype."""
    text = str(dtype)
    prefix = 'torch.'
    # str(torch.float32) -> "torch.float32"; strip the module prefix.
    if text.startswith(prefix):
        return text[len(prefix):]
    return text
|
|
def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple.

    An empty (or None) shape denotes a scalar and yields 1; any zero-sized
    dimension yields 0 via the product.
    """
    if not shape:
        return 1
    # math.prod replaces the manual accumulation loop (math is imported
    # at module level).
    return math.prod(shape)
|
|
def extract_tensors_from_obj(obj, prefix=""):
    """
    Recursively extracts tensors from nested dictionaries/objects.
    Returns a dictionary of {key: tensor} pairs.

    Handles, in priority order: a bare tensor, dicts (recursed with a
    dot-joined prefix), lists/tuples (recursed with an indexed "name[i]"
    prefix), objects exposing a callable ``state_dict()``, and finally plain
    objects whose ``__dict__`` holds tensors directly (that last branch does
    not recurse, to avoid walking arbitrary object graphs).
    """
    tensors = {}

    if isinstance(obj, torch.Tensor):
        return {prefix or "tensor": obj}

    elif isinstance(obj, dict):
        for key, value in obj.items():
            new_prefix = f"{prefix}.{key}" if prefix else key
            tensors.update(extract_tensors_from_obj(value, new_prefix))

    elif isinstance(obj, (list, tuple)):
        # Checkpoints frequently store tensors inside lists/tuples
        # (e.g. optimizer param groups); index them as "name[i]".
        for idx, value in enumerate(obj):
            new_prefix = f"{prefix}[{idx}]" if prefix else f"[{idx}]"
            tensors.update(extract_tensors_from_obj(value, new_prefix))

    elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
        # nn.Module-like objects: inspect their state_dict instead.
        state_dict = obj.state_dict()
        new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
        tensors.update(extract_tensors_from_obj(state_dict, new_prefix))

    elif hasattr(obj, '__dict__'):
        # Generic object: pick up only directly-attached tensor attributes.
        for key, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor):
                new_prefix = f"{prefix}.{key}" if prefix else key
                tensors[new_prefix] = value

    return tensors
|
|
def inspect_pth_precision_and_size(filepath):
    """
    Reads a .pth file, extracts tensors from it,
    and reports the precision (dtype), actual size, and theoretical FP32 size.

    Prints a per-tensor table (name, dtype, shape, actual MB, FP32-equivalent
    MB), a dtype distribution summary, total sizes, the size reduction versus
    full FP32 storage, and any non-tensor top-level content. All output goes
    to stdout; the function always returns None and reports (rather than
    raises) errors.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    try:
        print(f"Loading PyTorch file: {filepath}")

        # Prefer weights_only=True (restricted unpickler, safer on untrusted
        # files); older torch.load signatures raise TypeError on the keyword.
        try:
            obj = torch.load(filepath, map_location="cpu", weights_only=True)
            print("(Loaded with weights_only=True for security)\n")
        except TypeError:
            obj = torch.load(filepath, map_location="cpu")
            print("(Warning: Loaded without weights_only=True - older PyTorch version)\n")

        tensors = extract_tensors_from_obj(obj)

        if not tensors:
            print("No tensors found in the file.")
            return

        dtype_counts = Counter()
        total_actual_mb = 0.0
        total_fp32_equiv_mb = 0.0

        # First column is as wide as the longest tensor name (or the header).
        max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors.keys()))

        header = (
            f"{'Tensor Name':<{max_key_len}} | "
            f"{'Precision (dtype)':<17} | "
            f"{'Shape':<20} | "
            f"{'Actual Size (MB)':>16} | "
            f"{'FP32 Equiv. (MB)':>18}"
        )
        print(header)
        print(
            f"{'-' * max_key_len}-|-------------------|{'-' * 20}|------------------|-------------------"
        )

        for key, tensor in tensors.items():
            dtype = tensor.dtype
            dtype_name = get_dtype_name(dtype)
            shape = tuple(tensor.shape)
            shape_str = str(shape)

            num_elements = tensor.numel()
            bytes_per_el_actual = get_bytes_per_element(dtype)

            # Defaults used when the dtype's size is unknown.
            actual_size_mb_str = "N/A"
            fp32_equiv_size_mb_str = "N/A"
            actual_size_mb_val = 0.0

            if bytes_per_el_actual is not None:
                actual_bytes = num_elements * bytes_per_el_actual
                actual_size_mb_val = actual_bytes / (1024 * 1024)
                total_actual_mb += actual_size_mb_val
                actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                # What the same tensor would occupy if stored as float32
                # (4 bytes per element).
                fp32_equiv_bytes = num_elements * 4
                fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
                total_fp32_equiv_mb += fp32_equiv_size_mb_val
                fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
            else:
                print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")

            # Truncate long shape strings so the column stays aligned.
            if len(shape_str) > 18:
                shape_str = shape_str[:15] + "..."

            print(
                f"{key:<{max_key_len}} | "
                f"{dtype_name:<17} | "
                f"{shape_str:<20} | "
                f"{actual_size_mb_str:>16} | "
                f"{fp32_equiv_size_mb_str:>18}"
            )
            dtype_counts[dtype_name] += 1

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensors)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype, count in dtype_counts.most_common():
                print(f" - {dtype:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        # Small epsilon guards the division when no sized tensors were found.
        if total_fp32_equiv_mb > 0.00001:
            savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")

        # Report top-level entries that contributed no tensors. Membership is
        # tested with a prefix match against the full tensor-key set (built
        # once, outside the loop), so top-level keys that themselves contain
        # dots are not misreported, and the per-iteration list rebuild of the
        # original O(n^2) check is avoided.
        non_tensor_keys = []
        if isinstance(obj, dict):
            tensor_keys = set(tensors)
            for key, value in obj.items():
                contributed = key in tensor_keys or any(
                    t.startswith(f"{key}.") or t.startswith(f"{key}[")
                    for t in tensor_keys
                )
                if not contributed:
                    non_tensor_keys.append(f"{key}: {type(value).__name__}")

        if non_tensor_keys:
            print(f"\nNon-tensor content found:")
            for item in non_tensor_keys[:5]:
                print(f" - {item}")
            if len(non_tensor_keys) > 5:
                print(f" ... and {len(non_tensor_keys) - 5} more items")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f" {e}")
        print("Please ensure it's a valid PyTorch .pth file and PyTorch is installed correctly.")
|
|
if __name__ == "__main__":
    # CLI entry point: one positional argument, the checkpoint path.
    arg_parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
    )
    arg_parser.add_argument("filepath", help="Path to the .pth file to inspect.")
    cli_args = arg_parser.parse_args()
    inspect_pth_precision_and_size(cli_args.filepath)