File size: 1,467 Bytes
fed4003
21a6bd7
 
aea3cf9
fed4003
 
 
21a6bd7
 
 
fed4003
aea3cf9
fed4003
 
21a6bd7
 
 
 
 
 
 
 
 
 
 
fed4003
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import glob
import os
from typing import Optional

from huggingface_hub import snapshot_download

from civitai import download_civitai_model

def _has_mergeable_weights(path: str) -> bool:
    """Return True if *path* contains a ``.safetensors`` or ``.bin`` file at any depth."""
    # iglob + any() stops at the first match instead of materialising the
    # full recursive listing of the snapshot directory.
    # NOTE(review): any *.bin is accepted, not only pytorch_model*.bin, to keep
    # the original permissive behaviour.
    for pattern in ("*.safetensors", "*.bin"):
        if any(glob.iglob(os.path.join(path, "**", pattern), recursive=True)):
            return True
    return False


def prepare_model(
    source: str,
    model_id: str,
    civitai_key: Optional[str] = None,
    *,
    hf_cache_dir: str = "/app/data/hf_cache",
) -> str:
    """Fetch a model from the given source and return what the downloader returns.

    Args:
        source: ``"hf"`` for the Hugging Face Hub or ``"civitai"`` for Civitai.
        model_id: For ``"hf"``, a repo id such as ``"repo_name"`` or
            ``"namespace/repo_name"`` (not a URL). For ``"civitai"``, the id is
            passed unchanged to ``download_civitai_model``.
        civitai_key: Civitai API key; required when ``source == "civitai"``.
        hf_cache_dir: ``cache_dir`` passed to ``snapshot_download`` (keyword-only;
            defaults to the previously hard-coded location).

    Returns:
        For ``"hf"``, the snapshot path returned by ``snapshot_download``.
        For ``"civitai"``, the return value of ``download_civitai_model``
        (presumably a local path — verify against that helper).

    Raises:
        ValueError: ``model_id`` looks like a URL (hf), the Civitai key is
            missing, or ``source`` is not recognised.
        RuntimeError: The Hugging Face download failed, or the downloaded
            snapshot contains no ``.safetensors`` / ``.bin`` files.
    """
    if source == "hf":
        # Lightweight sanity check. HF repo ids never contain a URL scheme.
        # Checking for "://" (instead of startswith("http")) avoids rejecting
        # valid ids whose namespace merely starts with "http", e.g. "httpx-org/m".
        if "://" in model_id:
            raise ValueError("HuggingFace model_id should be 'repo_name' or 'namespace/repo_name', not a URL.")
        try:
            path = snapshot_download(model_id, cache_dir=hf_cache_dir)
        except Exception as e:
            raise RuntimeError(f"Failed to download HF model {model_id}: {e}") from e

        # Make sure mergeable weights exist. This guards against repos that
        # only ship GGUF or other non-mergeable formats.
        if not _has_mergeable_weights(path):
            raise RuntimeError(
                f"No safetensors or pytorch_model.bin files found in {model_id}. "
                "This repository may only contain GGUF or other non-mergeable formats."
            )
        return path

    elif source == "civitai":
        if not civitai_key:
            raise ValueError("Civitai API key required")
        return download_civitai_model(model_id, civitai_key)
    else:
        raise ValueError(f"Unknown source: {source}")