| import os |
| import sys |
| from pathlib import Path |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
| from huggingface_hub import login |
| from fastapi import HTTPException |
| from pydantic import BaseModel |
|
|
|
|
class DownloadRequest(BaseModel):
    """Request body for the model-download endpoint."""

    # Hugging Face model identifier — presumably a repo id like
    # "org/model-name", given the "models--org--name" cache-path
    # construction in check_model. TODO confirm against callers.
    model: str
|
|
|
|
def check_model(model_name, cache_dir="./my_model_cache"):
    """Verify that a model is already present in the local cache.

    Args:
        model_name: Hugging Face model identifier (e.g. "org/model"),
            as passed in from the request.
        cache_dir: Local cache directory. Defaults to the same directory
            used by download_model.

    Returns:
        (model_name, cache_dir, success): success is True when cached
        files exist and the tokenizer loads cleanly, False when files
        exist but the tokenizer fails to load.

    Raises:
        HTTPException: 404 when the model is not present in the cache.
    """
    # Hugging Face cache layout: models--{org}--{name}/snapshots/<revision>/
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ 模型 {model_name} 已存在于缓存中")
        try:
            # Load the tokenizer purely to validate that the cached
            # files are usable; the instance itself is discarded.
            AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        except Exception as e:
            print(f"⚠ 加载现有模型失败: {e}")
            return model_name, cache_dir, False
        return model_name, cache_dir, True

    raise HTTPException(status_code=404, detail=f"模型 `{model_name}` 不存在,请先下载")
|
|
|
|
def download_model(model_name, cache_dir="./my_model_cache"):
    """Download a model and its tokenizer into the local cache.

    Args:
        model_name: Hugging Face model identifier to download.
        cache_dir: Target cache directory. Defaults to the same
            directory used by check_model.

    Returns:
        (success, message): True with a success message when both the
        tokenizer and the model weights downloaded, otherwise False
        with the error text.
    """
    print(f"开始下载模型: {model_name}")
    print(f"缓存目录: {cache_dir}")

    # Authenticate only when a token is provided: gated/private models
    # need it, public models download without one.
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("登录 Hugging Face...")
            login(token=token)
            print("✓ HuggingFace 登录成功!")
        except Exception as e:
            # Best-effort: a failed login still allows public downloads.
            print(f"⚠ 登录失败: {e}")
            print("继续使用公开模型")
    else:
        print("ℹ 未设置 HUGGINGFACE_TOKEN - 仅使用公开模型")

    try:
        # from_pretrained populates cache_dir as a side effect; the
        # returned objects are not needed here, so bindings are dropped.
        print("正在下载 tokenizer...")
        AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer 下载成功!")

        print("正在下载模型...")
        AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ 模型下载成功!")

        print(f"✓ 模型和 tokenizer 已成功下载到 {cache_dir}")
        return True, f"模型 {model_name} 下载成功"

    except Exception as e:
        print(f"✗ 下载模型时出错: {e}")
        return False, f"下载失败: {str(e)}"
|
|
|
|
def initialize_pipeline(model_name):
    """Build a text-generation pipeline for a cached model.

    First verifies the model is present in the local cache via
    check_model (which raises HTTPException 404 when it is missing),
    then loads the tokenizer and constructs the pipeline.

    Args:
        model_name: Hugging Face model identifier from the request.

    Returns:
        (pipe, tokenizer, success); (None, None, False) on any failure.
    """
    model_name, cache_dir, ok = check_model(model_name)
    if not ok:
        return None, None, False

    try:
        tok = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print(f"使用 {model_name} 初始化 pipeline...")
        text_gen = pipeline("text-generation", model=model_name, tokenizer=tok)
    except Exception as e:
        print(f"✗ Pipeline 初始化失败: {e}")
        return None, None, False

    print("✓ Pipeline 初始化成功!")
    return text_gen, tok, True
|
|