import glob
import math
import os
from typing import Optional

import psutil

# Static RAM budget assumed for this host (18 GiB); cap usage at 95% of it.
AVAILABLE_RAM = 18 * 1024**3
TARGET_MAX_RAM = int(AVAILABLE_RAM * 0.95)

def get_available_ram() -> int:
    return psutil.virtual_memory().available

def estimate_model_memory(model_path: str, dtype: str = "fp32") -> int:
    """
    Rough estimate of the bytes needed to load a model, obtained by summing
    the on-disk sizes of its *.safetensors shard files. The `dtype` argument
    is kept for API compatibility but unused: safetensors shards already
    store weights in their serialized dtype, so file size tracks memory use.
    """
    shard_files = glob.glob(os.path.join(model_path, "*.safetensors"))
    if shard_files:
        # Summing file sizes avoids loading every tensor just to measure it;
        # the safetensors header adds only negligible overhead per shard.
        return sum(os.path.getsize(shard) for shard in shard_files)
    return 5 * 1024**3  # fallback: assume 5 GiB when no shards are found

def check_memory_safe(model_a_id: str, model_b_id: str, evo_params: Optional[dict] = None) -> bool:
    # Always return True for safety and defer to chunked processing;
    # the parameters are currently unused.
    return True

def adaptive_chunk_size(tensor_shape, dtype_bytes, safety_factor=0.8):
    """Return how many leading-dimension rows fit in available RAM."""
    avail = get_available_ram()
    max_bytes = int(avail * safety_factor)
    # Elements per row along dim 0 = product of all remaining dimensions.
    numel_per_row = math.prod(tensor_shape[1:]) if len(tensor_shape) >= 2 else 1
    chunk_rows = max(1, max_bytes // (numel_per_row * dtype_bytes))
    return min(chunk_rows, tensor_shape[0])
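
if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: "/models/example" is a hypothetical
    # directory of *.safetensors shards, and weights are fp16 (2 bytes/elem).
    print(f"Available RAM: {get_available_ram() / 1024**3:.1f} GiB")
    est = estimate_model_memory("/models/example", dtype="fp16")
    print(f"Estimated model footprint: {est / 1024**3:.1f} GiB")
    if check_memory_safe("model-a", "model-b"):
        # A hypothetical (4096, 4096) fp16 matrix: rows processable per chunk.
        rows = adaptive_chunk_size((4096, 4096), dtype_bytes=2)
        print(f"Chunk size: {rows} rows")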