import psutil
import torch
from typing import Optional
AVAILABLE_RAM = 18 * 1024**3  # static assumption: 18 GiB of RAM on the target machine
TARGET_MAX_RAM = int(AVAILABLE_RAM * 0.95)  # budget: use at most 95% of that RAM
def get_available_ram() -> int:
    """Return currently available system RAM in bytes, queried live via psutil."""
    return psutil.virtual_memory().available
def estimate_model_memory(model_path: str, dtype: str = "fp32") -> int:
    """
    Rough estimate of a model's in-memory footprint, computed by summing the
    sizes of all tensors found in *.safetensors shards under model_path.
    Falls back to a flat 5 GiB guess when no shards are present.
    Note: `dtype` is currently unused; sizes reflect each tensor's stored dtype.
    """
    import os, glob

    total_bytes = 0
    shard_files = glob.glob(os.path.join(model_path, "*.safetensors"))
    if shard_files:
        from safetensors import safe_open

        for shard in shard_files:
            with safe_open(shard, framework="pt", device="cpu") as f:
                for key in f.keys():
                    # get_tensor() loads the tensor; its size in bytes is
                    # element count times per-element width.
                    tensor = f.get_tensor(key)
                    total_bytes += tensor.numel() * tensor.element_size()
        return total_bytes
    return 5 * 1024**3  # fallback: assume 5 GiB when nothing can be measured
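
# Illustrative usage (a sketch, not part of the module's call flow): compare an
# estimated footprint against the 95% budget before loading a model whole.
# "/models/example-7b" is a hypothetical path.
#
#     needed = estimate_model_memory("/models/example-7b")
#     if needed > TARGET_MAX_RAM:
#         print(f"Needs {needed / 1024**3:.1f} GiB; fall back to chunked processing.")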
def check_memory_safe(model_a_id: str, model_b_id: str, evo_params: Optional[dict] = None) -> bool:
    # Always report "safe" and rely on chunked processing to stay within RAM.
    return True
def adaptive_chunk_size(tensor_shape, dtype_bytes, safety_factor=0.8):
    """Return how many leading-dimension rows fit within a fraction of available RAM."""
    import math

    avail = get_available_ram()
    max_bytes = int(avail * safety_factor)
    # Bytes per "row" = product of all trailing dimensions (not just the last
    # one, which would undercount for tensors with more than two dimensions).
    numel_per_row = math.prod(tensor_shape[1:]) if len(tensor_shape) >= 2 else 1
    chunk_rows = max(1, max_bytes // (numel_per_row * dtype_bytes))
    return min(chunk_rows, tensor_shape[0])
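
if __name__ == "__main__":
    # Minimal demo (illustrative only): drive adaptive_chunk_size() to merge two
    # tensors block-by-block instead of materializing everything at once. The
    # shapes and the 50/50 average are arbitrary stand-ins for a real merge step.
    a = torch.randn(4096, 1024)
    b = torch.randn(4096, 1024)
    out = torch.empty_like(a)
    rows = adaptive_chunk_size(a.shape, a.element_size())
    for start in range(0, a.shape[0], rows):
        end = min(start + rows, a.shape[0])
        out[start:end] = 0.5 * (a[start:end] + b[start:end])
    print(f"chunk size: {rows} rows; available RAM: "
          f"{get_available_ram() / 1024**3:.1f} GiB")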