import glob
import os
from typing import Optional

from huggingface_hub import snapshot_download

from civitai import download_civitai_model
|
|
def prepare_model(source: str, model_id: str, civitai_key: Optional[str] = None) -> str:
    """Download a model from the given source and return its local path.

    Args:
        source: Either ``"hf"`` (HuggingFace Hub) or ``"civitai"``.
        model_id: HF repo id (``repo_name`` or ``namespace/repo_name``) or a
            Civitai model identifier.
        civitai_key: Civitai API key; required when ``source == "civitai"``.

    Returns:
        Filesystem path of the downloaded model snapshot.

    Raises:
        ValueError: Unknown source, a URL passed where an HF repo id is
            expected, or a missing Civitai API key.
        RuntimeError: The HF download failed, or the snapshot contains no
            mergeable weight files.
    """
    if source == "hf":
        # Users sometimes paste the full Hub URL; reject it early with a
        # clearer message than the Hub client's own error would give.
        if model_id.startswith("http"):
            raise ValueError("HuggingFace model_id should be 'repo_name' or 'namespace/repo_name', not a URL.")
        try:
            path = snapshot_download(model_id, cache_dir="/app/data/hf_cache")
        except Exception as e:
            # Chain the original exception so the root cause of the download
            # failure remains visible in the traceback.
            raise RuntimeError(f"Failed to download HF model {model_id}: {e}") from e

        # Sanity-check that the snapshot actually contains weight files we can
        # work with (some repos ship only GGUF or other non-mergeable formats).
        safetensors_files = glob.glob(os.path.join(path, "**", "*.safetensors"), recursive=True)
        bin_files = glob.glob(os.path.join(path, "**", "*.bin"), recursive=True)
        if not safetensors_files and not bin_files:
            raise RuntimeError(
                f"No safetensors or pytorch_model.bin files found in {model_id}. "
                "This repository may only contain GGUF or other non-mergeable formats."
            )
        return path
    elif source == "civitai":
        if not civitai_key:
            raise ValueError("Civitai API key required")
        return download_civitai_model(model_id, civitai_key)
    else:
        raise ValueError(f"Unknown source: {source}")