diff --git "a/app/main.py" "b/app/main.py" --- "a/app/main.py" +++ "b/app/main.py" @@ -1,1319 +1,1394 @@ -from __future__ import annotations - -import json -import os -import sqlite3 -import time -import uuid -from contextlib import asynccontextmanager -from datetime import UTC, datetime, timedelta -from pathlib import Path -from typing import Any - -import httpx -from apscheduler.schedulers.asyncio import AsyncIOScheduler -from fastapi import Depends, FastAPI, Header, HTTPException, Request, Response, status -from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse -from fastapi.staticfiles import StaticFiles -from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer - - -BASE_DIR = Path(__file__).resolve().parent.parent -STATIC_DIR = BASE_DIR / "static" -DB_PATH = Path(os.getenv("DATABASE_PATH", BASE_DIR / "data.sqlite3")) -RAW_NVIDIA_API_BASE = os.getenv("NVIDIA_API_BASE", os.getenv("NIM_BASE_URL", "https://integrate.api.nvidia.com/v1")).rstrip("/") -NVIDIA_API_BASE = RAW_NVIDIA_API_BASE if RAW_NVIDIA_API_BASE.endswith("/v1") else f"{RAW_NVIDIA_API_BASE}/v1" -CHAT_COMPLETIONS_URL = f"{NVIDIA_API_BASE}/chat/completions" -MODELS_URL = f"{NVIDIA_API_BASE}/models" -ADMIN_PASSWORD = os.getenv("PASSWORD") -SESSION_SECRET = os.getenv("SESSION_SECRET") or ADMIN_PASSWORD or "nim-responses-dev-secret" -COOKIE_NAME = os.getenv("COOKIE_NAME", "nim_admin_session") -GATEWAY_API_KEY = os.getenv("GATEWAY_API_KEY") -DEFAULT_ENV_KEY = os.getenv("NVIDIA_NIM_API_KEY") or os.getenv("NVIDIA_API_KEY") -REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "90")) -DEFAULT_HEALTH_INTERVAL_MINUTES = int(os.getenv("HEALTHCHECK_INTERVAL_MINUTES", "60")) -DEFAULT_HEALTH_PROMPT = os.getenv("HEALTHCHECK_PROMPT", "请只回复 OK。") -PUBLIC_HISTORY_HOURS = int(os.getenv("PUBLIC_HISTORY_HOURS", "48")) - -DEFAULT_MODELS = [ - ("z-ai/glm5", "GLM-5", "Reasoning and general assistant model from Z.ai", 10, 1), - ("minimaxai/minimax-m2.5", "MiniMax 
M2.5", "Long-context assistant model from MiniMax", 20, 1), - ("moonshotai/kimi-k2.5", "Kimi K2.5", "Kimi family model tuned for tool use and code", 30, 1), - ("deepseek-ai/deepseek-v3.2", "DeepSeek V3.2", "DeepSeek production general-purpose model", 40, 1), - ("google/gemma-4-31b-it", "Gemma 4 31B IT", "Instruction-tuned Gemma model", 50, 0), - ("qwen/qwen3.5-397b-a17b", "Qwen 3.5 397B A17B", "Large-scale Qwen model with broad capabilities", 60, 0), -] - -scheduler = AsyncIOScheduler(timezone="UTC") - - -def utcnow() -> datetime: - return datetime.now(UTC) - - -def utcnow_iso() -> str: - return utcnow().isoformat() - - -def parse_datetime(value: str | None) -> datetime | None: - if not value: - return None - try: - return datetime.fromisoformat(value) - except ValueError: - return None - - -def bool_value(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if value is None: - return False - return str(value).strip().lower() in {"1", "true", "yes", "on", "enabled"} - - -def json_dumps(value: Any) -> str: - return json.dumps(value, ensure_ascii=False) - - -def get_db_connection() -> sqlite3.Connection: - conn = sqlite3.connect(DB_PATH, check_same_thread=False) - conn.row_factory = sqlite3.Row - return conn - - -def init_db() -> None: - DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = get_db_connection() - try: - conn.executescript( - """ - CREATE TABLE IF NOT EXISTS proxy_models ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - model_id TEXT UNIQUE NOT NULL, - display_name TEXT NOT NULL, - provider TEXT NOT NULL DEFAULT 'nvidia-nim', - description TEXT, - enabled INTEGER NOT NULL DEFAULT 1, - featured INTEGER NOT NULL DEFAULT 0, - sort_order INTEGER NOT NULL DEFAULT 0, - request_count INTEGER NOT NULL DEFAULT 0, - success_count INTEGER NOT NULL DEFAULT 0, - failure_count INTEGER NOT NULL DEFAULT 0, - healthcheck_count INTEGER NOT NULL DEFAULT 0, - healthcheck_success_count INTEGER NOT 
NULL DEFAULT 0, - last_used_at TEXT, - last_healthcheck_at TEXT, - last_health_status INTEGER, - last_latency_ms REAL, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL - ); - - CREATE TABLE IF NOT EXISTS api_keys ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE NOT NULL, - api_key TEXT NOT NULL, - enabled INTEGER NOT NULL DEFAULT 1, - request_count INTEGER NOT NULL DEFAULT 0, - success_count INTEGER NOT NULL DEFAULT 0, - failure_count INTEGER NOT NULL DEFAULT 0, - healthcheck_count INTEGER NOT NULL DEFAULT 0, - healthcheck_success_count INTEGER NOT NULL DEFAULT 0, - last_used_at TEXT, - last_tested_at TEXT, - last_latency_ms REAL, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL - ); - - CREATE TABLE IF NOT EXISTS response_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - response_id TEXT UNIQUE NOT NULL, - parent_response_id TEXT, - model_id INTEGER, - api_key_id INTEGER, - request_json TEXT NOT NULL, - input_items_json TEXT NOT NULL, - output_json TEXT NOT NULL, - output_items_json TEXT NOT NULL, - status TEXT NOT NULL, - created_at TEXT NOT NULL - ); - - CREATE TABLE IF NOT EXISTS health_check_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - model_id INTEGER NOT NULL, - api_key_id INTEGER, - ok INTEGER NOT NULL, - status_code INTEGER, - latency_ms REAL, - error_message TEXT, - response_excerpt TEXT, - checked_at TEXT NOT NULL - ); - - CREATE TABLE IF NOT EXISTS settings ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - ); - """ - ) - - now = utcnow_iso() - for model_id, display_name, description, sort_order, featured in DEFAULT_MODELS: - conn.execute( - """ - INSERT OR IGNORE INTO proxy_models ( - model_id, display_name, provider, description, enabled, featured, sort_order, created_at, updated_at - ) VALUES (?, ?, 'nvidia-nim', ?, 1, ?, ?, ?, ?) 
- """, - (model_id, display_name, description, featured, sort_order, now, now), - ) - - defaults = { - "healthcheck_enabled": "true", - "healthcheck_interval_minutes": str(DEFAULT_HEALTH_INTERVAL_MINUTES), - "healthcheck_prompt": DEFAULT_HEALTH_PROMPT, - "public_history_hours": str(PUBLIC_HISTORY_HOURS), - } - for key, value in defaults.items(): - conn.execute("INSERT OR IGNORE INTO settings (key, value) VALUES (?, ?)", (key, value)) - - if DEFAULT_ENV_KEY: - conn.execute( - """ - INSERT OR IGNORE INTO api_keys (name, api_key, enabled, created_at, updated_at) - VALUES ('env-default', ?, 1, ?, ?) - """, - (DEFAULT_ENV_KEY, now, now), - ) - - conn.commit() - finally: - conn.close() - - -def get_setting(conn: sqlite3.Connection, key: str, default: str) -> str: - row = conn.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone() - return row["value"] if row else default - - -def set_setting(conn: sqlite3.Connection, key: str, value: str) -> None: - conn.execute( - """ - INSERT INTO settings (key, value) VALUES (?, ?) 
- ON CONFLICT(key) DO UPDATE SET value = excluded.value - """, - (key, value), - ) - - -def get_settings_payload(conn: sqlite3.Connection) -> dict[str, Any]: - return { - "healthcheck_enabled": bool_value(get_setting(conn, "healthcheck_enabled", "true")), - "healthcheck_interval_minutes": int(get_setting(conn, "healthcheck_interval_minutes", str(DEFAULT_HEALTH_INTERVAL_MINUTES))), - "healthcheck_prompt": get_setting(conn, "healthcheck_prompt", DEFAULT_HEALTH_PROMPT), - "public_history_hours": int(get_setting(conn, "public_history_hours", str(PUBLIC_HISTORY_HOURS))), - } - - -def mask_secret(secret: str) -> str: - if len(secret) <= 8: - return f"{secret[:2]}***" - return f"{secret[:4]}...{secret[-4:]}" - - -def create_admin_token() -> str: - serializer = URLSafeTimedSerializer(SESSION_SECRET, salt="nim-admin-auth") - return serializer.dumps({"role": "admin"}) - - -def verify_admin_token(token: str) -> bool: - serializer = URLSafeTimedSerializer(SESSION_SECRET, salt="nim-admin-auth") - try: - payload = serializer.loads(token, max_age=60 * 60 * 24 * 7) - except (BadSignature, SignatureExpired): - return False - return payload.get("role") == "admin" - - -def require_admin(request: Request, authorization: str | None = Header(default=None)) -> bool: - token: str | None = None - if authorization and authorization.startswith("Bearer "): - token = authorization.removeprefix("Bearer ").strip() - if not token: - token = request.cookies.get(COOKIE_NAME) - if not token or not verify_admin_token(token): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="需要管理员登录。") - return True - - -def require_proxy_token_if_configured(authorization: str | None = Header(default=None)) -> bool: - if not GATEWAY_API_KEY: - return True - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing bearer token.") - token = authorization.removeprefix("Bearer ").strip() - if token != 
GATEWAY_API_KEY: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid bearer token.") - return True - - -def fetch_model_by_identifier(conn: sqlite3.Connection, identifier: str | int, enabled_only: bool = False) -> sqlite3.Row | None: - clause = "AND enabled = 1" if enabled_only else "" - if isinstance(identifier, int) or (isinstance(identifier, str) and identifier.isdigit()): - row = conn.execute(f"SELECT * FROM proxy_models WHERE id = ? {clause}", (int(identifier),)).fetchone() - if row: - return row - return conn.execute(f"SELECT * FROM proxy_models WHERE model_id = ? {clause}", (str(identifier),)).fetchone() - - -def fetch_key_by_identifier(conn: sqlite3.Connection, identifier: str | int, enabled_only: bool = False) -> sqlite3.Row | None: - clause = "AND enabled = 1" if enabled_only else "" - if isinstance(identifier, int) or (isinstance(identifier, str) and str(identifier).isdigit()): - row = conn.execute(f"SELECT * FROM api_keys WHERE id = ? {clause}", (int(identifier),)).fetchone() - if row: - return row - return conn.execute(f"SELECT * FROM api_keys WHERE name = ? 
{clause}", (str(identifier),)).fetchone() - - -def select_api_key(conn: sqlite3.Connection, explicit_id: int | None = None) -> sqlite3.Row: - if explicit_id is not None: - row = fetch_key_by_identifier(conn, explicit_id, enabled_only=True) - if row: - return row - row = conn.execute( - """ - SELECT * FROM api_keys - WHERE enabled = 1 - ORDER BY CASE WHEN last_used_at IS NULL THEN 0 ELSE 1 END, last_used_at ASC, id ASC - LIMIT 1 - """ - ).fetchone() - if not row: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前没有可用的 NVIDIA NIM Key。") - return row - - -def row_to_model_item(row: sqlite3.Row) -> dict[str, Any]: - status_name = "unknown" - if row["last_health_status"] is not None: - status_name = "healthy" if bool(row["last_health_status"]) else "down" - return { - "id": row["id"], - "model_id": row["model_id"], - "name": row["model_id"], - "display_name": row["display_name"], - "endpoint": "/v1/responses", - "provider": row["provider"], - "description": row["description"], - "enabled": bool(row["enabled"]), - "featured": bool(row["featured"]), - "sort_order": row["sort_order"], - "status": status_name, - "request_count": row["request_count"], - "success_count": row["success_count"], - "failure_count": row["failure_count"], - "healthcheck_count": row["healthcheck_count"], - "healthcheck_success_count": row["healthcheck_success_count"], - "last_used_at": row["last_used_at"], - "last_healthcheck_at": row["last_healthcheck_at"], - "last_health_status": None if row["last_health_status"] is None else bool(row["last_health_status"]), - "last_latency_ms": row["last_latency_ms"], - "created_at": row["created_at"], - "updated_at": row["updated_at"], - } - - -def row_to_key_item(row: sqlite3.Row) -> dict[str, Any]: - total_checks = row["healthcheck_count"] or 0 - ok_checks = row["healthcheck_success_count"] or 0 - success_ratio = (ok_checks / total_checks) if total_checks else None - status_name = "healthy" if success_ratio and success_ratio >= 0.8 else 
"unknown" - return { - "id": row["id"], - "name": row["name"], - "label": row["name"], - "masked_key": mask_secret(row["api_key"]), - "enabled": bool(row["enabled"]), - "status": status_name, - "request_count": row["request_count"], - "success_count": row["success_count"], - "failure_count": row["failure_count"], - "healthcheck_count": row["healthcheck_count"], - "healthcheck_success_count": row["healthcheck_success_count"], - "last_used_at": row["last_used_at"], - "last_tested": row["last_tested_at"], - "last_tested_at": row["last_tested_at"], - "last_latency_ms": row["last_latency_ms"], - "created_at": row["created_at"], - "updated_at": row["updated_at"], - } - - -def make_error(status_code: int, message: str, error_type: str = "invalid_request_error") -> JSONResponse: - return JSONResponse( - status_code=status_code, - content={"error": {"message": message, "type": error_type, "code": status_code}}, +from __future__ import annotations + +import asyncio +import json +import os +import sqlite3 +import time +import uuid +from contextlib import asynccontextmanager +from datetime import UTC, datetime, timedelta +from pathlib import Path +from typing import Any + +import httpx +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from fastapi import Depends, FastAPI, Header, HTTPException, Request, Response, status +from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse +from fastapi.staticfiles import StaticFiles +from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer + + +BASE_DIR = Path(__file__).resolve().parent.parent +STATIC_DIR = BASE_DIR / "static" +DB_PATH = Path(os.getenv("DATABASE_PATH", BASE_DIR / "data.sqlite3")) +RAW_NVIDIA_API_BASE = os.getenv("NVIDIA_API_BASE", os.getenv("NIM_BASE_URL", "https://integrate.api.nvidia.com/v1")).rstrip("/") +NVIDIA_API_BASE = RAW_NVIDIA_API_BASE if RAW_NVIDIA_API_BASE.endswith("/v1") else f"{RAW_NVIDIA_API_BASE}/v1" +CHAT_COMPLETIONS_URL = 
f"{NVIDIA_API_BASE}/chat/completions" +MODELS_URL = f"{NVIDIA_API_BASE}/models" +ADMIN_PASSWORD = os.getenv("PASSWORD") +SESSION_SECRET = os.getenv("SESSION_SECRET") or ADMIN_PASSWORD or "nim-responses-dev-secret" +COOKIE_NAME = os.getenv("COOKIE_NAME", "nim_admin_session") +PASS_API_KEY = os.getenv("PASS_APIKEY") or os.getenv("GATEWAY_API_KEY") +DEFAULT_ENV_KEY = os.getenv("NVIDIA_NIM_API_KEY") or os.getenv("NVIDIA_API_KEY") +REQUEST_TIMEOUT_SECONDS = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "90")) +DEFAULT_HEALTH_INTERVAL_MINUTES = int(os.getenv("HEALTHCHECK_INTERVAL_MINUTES", "60")) +DEFAULT_HEALTH_PROMPT = os.getenv("HEALTHCHECK_PROMPT", "请只回复 OK。") +PUBLIC_HISTORY_HOURS = int(os.getenv("PUBLIC_HISTORY_HOURS", "48")) +MAX_UPSTREAM_CONNECTIONS = int(os.getenv("MAX_UPSTREAM_CONNECTIONS", "256")) +MAX_KEEPALIVE_CONNECTIONS = int(os.getenv("MAX_KEEPALIVE_CONNECTIONS", "64")) + +DEFAULT_MODELS = [ + ("z-ai/glm5", "GLM-5", "Reasoning and general assistant model from Z.ai", 10, 1), + ("minimaxai/minimax-m2.5", "MiniMax M2.5", "Long-context assistant model from MiniMax", 20, 1), + ("moonshotai/kimi-k2.5", "Kimi K2.5", "Kimi family model tuned for tool use and code", 30, 1), + ("deepseek-ai/deepseek-v3.2", "DeepSeek V3.2", "DeepSeek production general-purpose model", 40, 1), + ("google/gemma-4-31b-it", "Gemma 4 31B IT", "Instruction-tuned Gemma model", 50, 0), + ("qwen/qwen3.5-397b-a17b", "Qwen 3.5 397B A17B", "Large-scale Qwen model with broad capabilities", 60, 0), +] + +scheduler = AsyncIOScheduler(timezone="UTC") +http_client: httpx.AsyncClient | None = None +api_key_selection_lock: asyncio.Lock | None = None +api_key_rr_index = 0 + + +def utcnow() -> datetime: + return datetime.now(UTC) + + +def utcnow_iso() -> str: + return utcnow().isoformat() + + +def parse_datetime(value: str | None) -> datetime | None: + if not value: + return None + try: + return datetime.fromisoformat(value) + except ValueError: + return None + + +def bool_value(value: Any) -> bool: + if 
isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on", "enabled"} + + +def json_dumps(value: Any) -> str: + return json.dumps(value, ensure_ascii=False) + + +async def get_http_client() -> httpx.AsyncClient: + global http_client + if http_client is None or http_client.is_closed: + limits = httpx.Limits( + max_connections=MAX_UPSTREAM_CONNECTIONS, + max_keepalive_connections=MAX_KEEPALIVE_CONNECTIONS, + ) + http_client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT_SECONDS, limits=limits) + return http_client + + +async def get_api_key_selection_lock() -> asyncio.Lock: + global api_key_selection_lock + if api_key_selection_lock is None: + api_key_selection_lock = asyncio.Lock() + return api_key_selection_lock + + +def get_db_connection() -> sqlite3.Connection: + conn = sqlite3.connect(DB_PATH, check_same_thread=False, timeout=30.0) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute("PRAGMA foreign_keys=ON") + conn.execute("PRAGMA busy_timeout=30000") + return conn + + +def init_db() -> None: + DB_PATH.parent.mkdir(parents=True, exist_ok=True) + conn = get_db_connection() + try: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS proxy_models ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_id TEXT UNIQUE NOT NULL, + display_name TEXT NOT NULL, + provider TEXT NOT NULL DEFAULT 'nvidia-nim', + description TEXT, + enabled INTEGER NOT NULL DEFAULT 1, + featured INTEGER NOT NULL DEFAULT 0, + sort_order INTEGER NOT NULL DEFAULT 0, + request_count INTEGER NOT NULL DEFAULT 0, + success_count INTEGER NOT NULL DEFAULT 0, + failure_count INTEGER NOT NULL DEFAULT 0, + healthcheck_count INTEGER NOT NULL DEFAULT 0, + healthcheck_success_count INTEGER NOT NULL DEFAULT 0, + last_used_at TEXT, + last_healthcheck_at TEXT, + last_health_status INTEGER, + 
last_latency_ms REAL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + api_key TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1, + request_count INTEGER NOT NULL DEFAULT 0, + success_count INTEGER NOT NULL DEFAULT 0, + failure_count INTEGER NOT NULL DEFAULT 0, + healthcheck_count INTEGER NOT NULL DEFAULT 0, + healthcheck_success_count INTEGER NOT NULL DEFAULT 0, + last_used_at TEXT, + last_tested_at TEXT, + last_latency_ms REAL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS response_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + response_id TEXT UNIQUE NOT NULL, + parent_response_id TEXT, + model_id INTEGER, + api_key_id INTEGER, + request_json TEXT NOT NULL, + input_items_json TEXT NOT NULL, + output_json TEXT NOT NULL, + output_items_json TEXT NOT NULL, + status TEXT NOT NULL, + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS health_check_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_id INTEGER NOT NULL, + api_key_id INTEGER, + ok INTEGER NOT NULL, + status_code INTEGER, + latency_ms REAL, + error_message TEXT, + response_excerpt TEXT, + checked_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """ + ) + + now = utcnow_iso() + for model_id, display_name, description, sort_order, featured in DEFAULT_MODELS: + conn.execute( + """ + INSERT OR IGNORE INTO proxy_models ( + model_id, display_name, provider, description, enabled, featured, sort_order, created_at, updated_at + ) VALUES (?, ?, 'nvidia-nim', ?, 1, ?, ?, ?, ?) 
+ """, + (model_id, display_name, description, featured, sort_order, now, now), + ) + + defaults = { + "healthcheck_enabled": "true", + "healthcheck_interval_minutes": str(DEFAULT_HEALTH_INTERVAL_MINUTES), + "healthcheck_prompt": DEFAULT_HEALTH_PROMPT, + "public_history_hours": str(PUBLIC_HISTORY_HOURS), + } + for key, value in defaults.items(): + conn.execute("INSERT OR IGNORE INTO settings (key, value) VALUES (?, ?)", (key, value)) + + if DEFAULT_ENV_KEY: + conn.execute( + """ + INSERT OR IGNORE INTO api_keys (name, api_key, enabled, created_at, updated_at) + VALUES ('env-default', ?, 1, ?, ?) + """, + (DEFAULT_ENV_KEY, now, now), + ) + + conn.commit() + finally: + conn.close() + + +def get_setting(conn: sqlite3.Connection, key: str, default: str) -> str: + row = conn.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone() + return row["value"] if row else default + + +def set_setting(conn: sqlite3.Connection, key: str, value: str) -> None: + conn.execute( + """ + INSERT INTO settings (key, value) VALUES (?, ?) 
+ ON CONFLICT(key) DO UPDATE SET value = excluded.value + """, + (key, value), + ) + + +def get_settings_payload(conn: sqlite3.Connection) -> dict[str, Any]: + return { + "healthcheck_enabled": bool_value(get_setting(conn, "healthcheck_enabled", "true")), + "healthcheck_interval_minutes": int(get_setting(conn, "healthcheck_interval_minutes", str(DEFAULT_HEALTH_INTERVAL_MINUTES))), + "healthcheck_prompt": get_setting(conn, "healthcheck_prompt", DEFAULT_HEALTH_PROMPT), + "public_history_hours": int(get_setting(conn, "public_history_hours", str(PUBLIC_HISTORY_HOURS))), + } + + +def mask_secret(secret: str) -> str: + if len(secret) <= 8: + return f"{secret[:2]}***" + return f"{secret[:4]}...{secret[-4:]}" + + +def create_admin_token() -> str: + serializer = URLSafeTimedSerializer(SESSION_SECRET, salt="nim-admin-auth") + return serializer.dumps({"role": "admin"}) + + +def verify_admin_token(token: str) -> bool: + serializer = URLSafeTimedSerializer(SESSION_SECRET, salt="nim-admin-auth") + try: + payload = serializer.loads(token, max_age=60 * 60 * 24 * 7) + except (BadSignature, SignatureExpired): + return False + return payload.get("role") == "admin" + + +def require_admin(request: Request, authorization: str | None = Header(default=None)) -> bool: + token: str | None = None + if authorization and authorization.startswith("Bearer "): + token = authorization.removeprefix("Bearer ").strip() + if not token: + token = request.cookies.get(COOKIE_NAME) + if not token or not verify_admin_token(token): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="需要管理员登录。") + return True + + +def require_proxy_token_if_configured(authorization: str | None = Header(default=None), x_api_key: str | None = Header(default=None)) -> bool: + if not PASS_API_KEY: + return True + token: str | None = None + if authorization and authorization.startswith("Bearer "): + token = authorization.removeprefix("Bearer ").strip() + elif x_api_key: + token = x_api_key.strip() + if not 
token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="缺少 API 鉴权信息。") + if token != PASS_API_KEY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="API 鉴权失败。") + return True + + +def fetch_model_by_identifier(conn: sqlite3.Connection, identifier: str | int, enabled_only: bool = False) -> sqlite3.Row | None: + clause = "AND enabled = 1" if enabled_only else "" + if isinstance(identifier, int) or (isinstance(identifier, str) and identifier.isdigit()): + row = conn.execute(f"SELECT * FROM proxy_models WHERE id = ? {clause}", (int(identifier),)).fetchone() + if row: + return row + return conn.execute(f"SELECT * FROM proxy_models WHERE model_id = ? {clause}", (str(identifier),)).fetchone() + + +def fetch_key_by_identifier(conn: sqlite3.Connection, identifier: str | int, enabled_only: bool = False) -> sqlite3.Row | None: + clause = "AND enabled = 1" if enabled_only else "" + if isinstance(identifier, int) or (isinstance(identifier, str) and str(identifier).isdigit()): + row = conn.execute(f"SELECT * FROM api_keys WHERE id = ? {clause}", (int(identifier),)).fetchone() + if row: + return row + return conn.execute(f"SELECT * FROM api_keys WHERE name = ? {clause}", (str(identifier),)).fetchone() + + +async def select_api_key(conn: sqlite3.Connection, explicit_id: int | None = None) -> sqlite3.Row: + if explicit_id is not None: + row = fetch_key_by_identifier(conn, explicit_id, enabled_only=True) + if row: + return row + + key_rows = conn.execute( + """ + SELECT * FROM api_keys + WHERE enabled = 1 + ORDER BY id ASC + """ + ).fetchall() + if not key_rows: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前没有可用的 
NVIDIA NIM Key。") + + global api_key_rr_index + lock = await get_api_key_selection_lock() + async with lock: + selected = key_rows[api_key_rr_index % len(key_rows)] + api_key_rr_index = (api_key_rr_index + 1) % len(key_rows) + return selected + +def row_to_model_item(row: sqlite3.Row) -> dict[str, Any]: + status_name = "unknown" + if row["last_health_status"] is not None: + status_name = "healthy" if bool(row["last_health_status"]) else "down" + return { + "id": row["id"], + "model_id": row["model_id"], + "name": row["model_id"], + "display_name": row["display_name"], + "endpoint": "/v1/responses", + "provider": row["provider"], + "description": row["description"], + "enabled": bool(row["enabled"]), + "featured": bool(row["featured"]), + "sort_order": row["sort_order"], + "status": status_name, + "request_count": row["request_count"], + "success_count": row["success_count"], + "failure_count": row["failure_count"], + "healthcheck_count": row["healthcheck_count"], + "healthcheck_success_count": row["healthcheck_success_count"], + "last_used_at": row["last_used_at"], + "last_healthcheck_at": row["last_healthcheck_at"], + "last_health_status": None if row["last_health_status"] is None else bool(row["last_health_status"]), + "last_latency_ms": row["last_latency_ms"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + } + + +def row_to_key_item(row: sqlite3.Row) -> dict[str, Any]: + total_checks = row["healthcheck_count"] or 0 + ok_checks = row["healthcheck_success_count"] or 0 + success_ratio = (ok_checks / total_checks) if total_checks else None + status_name = "healthy" if success_ratio and success_ratio >= 0.8 else "unknown" + return { + "id": row["id"], + "name": row["name"], + "label": row["name"], + "masked_key": mask_secret(row["api_key"]), + "enabled": bool(row["enabled"]), + "status": status_name, + "request_count": row["request_count"], + "success_count": row["success_count"], + "failure_count": row["failure_count"], + 
"healthcheck_count": row["healthcheck_count"], + "healthcheck_success_count": row["healthcheck_success_count"], + "last_used_at": row["last_used_at"], + "last_tested": row["last_tested_at"], + "last_tested_at": row["last_tested_at"], + "last_latency_ms": row["last_latency_ms"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + } + + +def make_error(status_code: int, message: str, error_type: str = "invalid_request_error") -> JSONResponse: + return JSONResponse( + status_code=status_code, + content={"error": {"message": message, "type": error_type, "code": status_code}}, ) - -def normalize_content(content: Any, role: str) -> list[dict[str, Any]]: - if content is None: - return [] - if isinstance(content, str): - return [{"type": "output_text" if role == "assistant" else "input_text", "text": content}] - if isinstance(content, list): - normalized: list[dict[str, Any]] = [] - for part in content: - if isinstance(part, str): - normalized.append({"type": "output_text" if role == "assistant" else "input_text", "text": part}) - continue - if not isinstance(part, dict): - normalized.append({"type": "input_text", "text": str(part)}) - continue - if part.get("type") in {"input_text", "output_text", "text", "tool_call", "function_call"}: - normalized.append(part) - continue - if "text" in part: - normalized.append({"type": part.get("type", "input_text"), "text": part.get("text", "")}) - return normalized - if isinstance(content, dict): - if "text" in content: - return [{"type": content.get("type", "input_text"), "text": content.get("text", "")}] - return [{"type": "input_text", "text": json_dumps(content)}] - return [{"type": "input_text", "text": str(content)}] - - -def normalize_input_items(value: Any) -> list[dict[str, Any]]: - if value is None: - return [] - if isinstance(value, str): - return [{"type": "message", "role": "user", "content": [{"type": "input_text", "text": value}]}] - if isinstance(value, dict): - value = [value] - - items: 
list[dict[str, Any]] = [] - for item in value: - if isinstance(item, str): - items.append({"type": "message", "role": "user", "content": [{"type": "input_text", "text": item}]}) - continue - if not isinstance(item, dict): - items.append({"type": "message", "role": "user", "content": [{"type": "input_text", "text": str(item)}]}) - continue - - item_type = item.get("type") - if item_type == "message" or item.get("role"): - role = item.get("role", "user") - items.append({"type": "message", "role": role, "content": normalize_content(item.get("content"), role)}) - continue - if item_type == "function_call_output": - output = item.get("output") - if not isinstance(output, str): - output = json_dumps(output) if output is not None else "" - items.append({"type": "function_call_output", "call_id": item.get("call_id"), "output": output}) - continue - if item_type == "function_call": - arguments = item.get("arguments", "{}") - if not isinstance(arguments, str): - arguments = json_dumps(arguments) - items.append( - { - "type": "function_call", - "call_id": item.get("call_id") or f"call_{uuid.uuid4().hex[:12]}", - "name": item.get("name"), - "arguments": arguments, - } - ) - continue - if item_type in {"input_text", "output_text", "text"}: - items.append({"type": "message", "role": "user", "content": [{"type": "input_text", "text": item.get("text", "")}]}) - continue - items.append({"type": "message", "role": "user", "content": [{"type": "input_text", "text": json_dumps(item)}]}) - return items - - -def extract_text_from_content(content: Any) -> str: - if content is None: - return "" - if isinstance(content, str): - return content - if isinstance(content, dict): - if "text" in content: - return str(content.get("text", "")) - return json_dumps(content) - if isinstance(content, list): - chunks: list[str] = [] - for part in content: - if isinstance(part, str): - chunks.append(part) - elif isinstance(part, dict) and part.get("type") in {"input_text", "output_text", "text"}: - 
chunks.append(str(part.get("text", ""))) - return "\n".join(filter(None, chunks)) - return str(content) - - -def load_previous_conversation_items(conn: sqlite3.Connection, previous_response_id: str | None) -> list[dict[str, Any]]: - if not previous_response_id: - return [] - records: list[sqlite3.Row] = [] - current = previous_response_id - while current: - row = conn.execute("SELECT * FROM response_records WHERE response_id = ?", (current,)).fetchone() - if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"previous_response_id '{current}' was not found.") - records.append(row) - current = row["parent_response_id"] - - items: list[dict[str, Any]] = [] - for row in reversed(records): - items.extend(json.loads(row["input_items_json"])) - items.extend(json.loads(row["output_items_json"])) - return items - - -def items_to_chat_messages(items: list[dict[str, Any]]) -> list[dict[str, Any]]: - messages: list[dict[str, Any]] = [] - pending_tool_calls: list[dict[str, Any]] = [] - - def flush_pending_tool_calls() -> None: - nonlocal pending_tool_calls - if pending_tool_calls: - messages.append({"role": "assistant", "content": "", "tool_calls": pending_tool_calls}) - pending_tool_calls = [] - - for item in items: - item_type = item.get("type") - if item_type == "function_call": - pending_tool_calls.append( - { - "id": item.get("call_id") or f"call_{uuid.uuid4().hex[:12]}", - "type": "function", - "function": {"name": item.get("name"), "arguments": item.get("arguments", "{}")}, - } - ) - continue - if item_type == "function_call_output": - flush_pending_tool_calls() - messages.append({"role": "tool", "tool_call_id": item.get("call_id"), "content": item.get("output", "")}) - continue - if item_type != "message": - continue - flush_pending_tool_calls() - role = item.get("role", "user") - text_value = extract_text_from_content(item.get("content")) - if role in {"system", "developer"}: - messages.append({"role": "system", "content": text_value}) - elif 
def response_tools_to_chat_tools(tools: Any) -> list[dict[str, Any]]:
    """Convert Responses-API tool definitions into Chat-Completions specs.

    Accepts both the flat Responses shape (name/description/parameters at
    the top level) and the nested shape with a ``function`` sub-object.
    Non-function entries and tools without a name are dropped.
    """
    chat_tools: list[dict[str, Any]] = []
    for entry in tools or []:
        if not isinstance(entry, dict) or entry.get("type") != "function":
            continue
        spec = entry.get("function")
        if not isinstance(spec, dict):
            spec = entry  # flat shape: fields live on the tool itself
        tool_name = spec.get("name")
        if not tool_name:
            continue
        chat_tools.append(
            {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": spec.get("description"),
                    "parameters": spec.get("parameters") or {"type": "object", "properties": {}},
                },
            }
        )
    return chat_tools
def build_chat_payload(body: dict[str, Any], items: list[dict[str, Any]]) -> dict[str, Any]:
    """Translate a Responses-API request body into a Chat-Completions payload.

    Maps tools/tool_choice, forwards sampling options under their chat-side
    names (``max_output_tokens`` -> ``max_tokens``), prepends ``instructions``
    as a system message, and converts the ``text.format`` block into a
    ``response_format`` entry when it requests JSON output.
    """
    chat_tools = response_tools_to_chat_tools(body.get("tools"))
    tool_choice, chat_tools = normalize_tool_choice(body.get("tool_choice"), chat_tools)
    payload: dict[str, Any] = {
        "model": body.get("model"),
        "messages": items_to_chat_messages(items),
    }
    if chat_tools:
        payload["tools"] = chat_tools
    if tool_choice is not None:
        payload["tool_choice"] = tool_choice
    # Options copied straight through when present (Responses name -> Chat name).
    for source_key, target_key in (
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("parallel_tool_calls", "parallel_tool_calls"),
        ("max_output_tokens", "max_tokens"),
    ):
        value = body.get(source_key)
        if value is not None:
            payload[target_key] = value
    if body.get("instructions"):
        payload["messages"] = [{"role": "system", "content": body["instructions"]}, *payload["messages"]]
    text_config = body.get("text") or {}
    text_format = text_config.get("format") if isinstance(text_config, dict) else None
    if isinstance(text_format, dict):
        format_type = text_format.get("type")
        if format_type == "json_object":
            payload["response_format"] = {"type": "json_object"}
        elif format_type == "json_schema":
            payload["response_format"] = {"type": "json_schema", "json_schema": text_format.get("json_schema") or {}}
    return payload
def extract_text_and_tool_calls(message: dict[str, Any]) -> tuple[str, list[dict[str, Any]]]:
    """Pull assistant text and tool calls out of an upstream chat message.

    Text may arrive as a plain string or as a list of content parts; tool
    calls may appear inline in the content list and/or under ``tool_calls``.
    Returns the newline-joined, stripped text plus tool calls deduplicated
    by id (first occurrence wins, order preserved).
    """
    texts: list[str] = []
    calls: list[dict[str, Any]] = []

    def as_argument_string(raw: Any) -> str:
        # Arguments must be a JSON string on the wire; serialize anything else.
        value = raw or "{}"
        return value if isinstance(value, str) else json_dumps(value)

    content = message.get("content")
    if isinstance(content, str):
        texts.append(content)
    elif isinstance(content, list):
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif not isinstance(part, dict):
                texts.append(str(part))
            elif part.get("type") in {"input_text", "output_text", "text"}:
                texts.append(str(part.get("text", "")))
            elif part.get("type") in {"tool_call", "function_call"}:
                calls.append(
                    {
                        "id": part.get("id") or part.get("call_id") or f"call_{uuid.uuid4().hex[:12]}",
                        "name": part.get("name"),
                        "arguments": as_argument_string(part.get("arguments")),
                    }
                )

    for raw_call in message.get("tool_calls") or []:
        if not isinstance(raw_call, dict):
            continue
        fn = raw_call.get("function") or {}
        calls.append(
            {
                "id": raw_call.get("id") or f"call_{uuid.uuid4().hex[:12]}",
                "name": fn.get("name") or raw_call.get("name"),
                "arguments": as_argument_string(fn.get("arguments") or raw_call.get("arguments")),
            }
        )

    unique_calls: list[dict[str, Any]] = []
    seen_ids: set[str] = set()
    for call in calls:
        if call["id"] not in seen_ids:
            seen_ids.add(call["id"])
            unique_calls.append(call)
    return "\n".join(filter(None, texts)).strip(), unique_calls
def normalize_content(content: Any, role: str) -> list[dict[str, Any]]:
    """Coerce message content into a list of Responses-style content parts.

    Bare strings become a single text part whose type depends on *role*
    (``output_text`` for the assistant, ``input_text`` otherwise). Lists are
    normalized part-by-part; dict parts with an unknown type and no ``text``
    field are dropped. Any other value is stringified into an input part.
    """
    text_type = "output_text" if role == "assistant" else "input_text"

    if content is None:
        return []
    if isinstance(content, str):
        return [{"type": text_type, "text": content}]
    if isinstance(content, dict):
        if "text" in content:
            return [{"type": content.get("type", "input_text"), "text": content.get("text", "")}]
        return [{"type": "input_text", "text": json_dumps(content)}]
    if not isinstance(content, list):
        return [{"type": "input_text", "text": str(content)}]

    known_types = {"input_text", "output_text", "text", "tool_call", "function_call"}
    parts: list[dict[str, Any]] = []
    for part in content:
        if isinstance(part, str):
            parts.append({"type": text_type, "text": part})
        elif not isinstance(part, dict):
            parts.append({"type": "input_text", "text": str(part)})
        elif part.get("type") in known_types:
            parts.append(part)  # already a recognized part; forward untouched
        elif "text" in part:
            parts.append({"type": part.get("type", "input_text"), "text": part.get("text", "")})
        # dict parts with no known type and no "text" are intentionally dropped
    return parts
def normalize_input_items(value: Any) -> list[dict[str, Any]]:
    """Normalize the Responses-API ``input`` field into a flat item list.

    Accepts a bare string, a single item dict, or a list of mixed entries and
    returns uniform ``message`` / ``function_call`` / ``function_call_output``
    item dicts. Unrecognized entries are wrapped as user text messages.
    """

    def user_message(text: str) -> dict[str, Any]:
        return {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]}

    if value is None:
        return []
    if isinstance(value, str):
        return [user_message(value)]
    entries = [value] if isinstance(value, dict) else value

    normalized: list[dict[str, Any]] = []
    for entry in entries:
        if isinstance(entry, str):
            normalized.append(user_message(entry))
            continue
        if not isinstance(entry, dict):
            normalized.append(user_message(str(entry)))
            continue
        entry_type = entry.get("type")
        if entry_type == "message" or entry.get("role"):
            role = entry.get("role", "user")
            normalized.append({"type": "message", "role": role, "content": normalize_content(entry.get("content"), role)})
        elif entry_type == "function_call_output":
            output = entry.get("output")
            if not isinstance(output, str):
                output = "" if output is None else json_dumps(output)
            normalized.append({"type": "function_call_output", "call_id": entry.get("call_id"), "output": output})
        elif entry_type == "function_call":
            arguments = entry.get("arguments", "{}")
            if not isinstance(arguments, str):
                arguments = json_dumps(arguments)
            normalized.append(
                {
                    "type": "function_call",
                    "call_id": entry.get("call_id") or f"call_{uuid.uuid4().hex[:12]}",
                    "name": entry.get("name"),
                    "arguments": arguments,
                }
            )
        elif entry_type in {"input_text", "output_text", "text"}:
            normalized.append(user_message(entry.get("text", "")))
        else:
            normalized.append(user_message(json_dumps(entry)))
    return normalized
def extract_text_from_content(content: Any) -> str:
    """Flatten message content into a single plain-text string.

    Strings pass through; dicts yield their ``text`` field (or a JSON dump
    when absent); lists are newline-joined from string entries and
    text-typed parts, skipping empty chunks. Anything else is stringified.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        if "text" in content:
            return str(content.get("text", ""))
        return json_dumps(content)
    if not isinstance(content, list):
        return str(content)
    pieces: list[str] = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and part.get("type") in {"input_text", "output_text", "text"}:
            pieces.append(str(part.get("text", "")))
    return "\n".join(piece for piece in pieces if piece)
def items_to_chat_messages(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert Responses-API conversation items into Chat-Completions messages.

    Consecutive ``function_call`` items are buffered and emitted as a single
    assistant message carrying ``tool_calls``; ``function_call_output`` items
    become ``tool`` messages; ``message`` items are flattened to text, with
    ``system``/``developer`` both mapped to the ``system`` role. Messages
    with neither content nor tool calls are filtered out at the end.
    """
    chat_messages: list[dict[str, Any]] = []
    buffered_calls: list[dict[str, Any]] = []

    def emit_buffered_calls() -> None:
        nonlocal buffered_calls
        if buffered_calls:
            chat_messages.append({"role": "assistant", "content": "", "tool_calls": buffered_calls})
            buffered_calls = []

    for item in items:
        kind = item.get("type")
        if kind == "function_call":
            buffered_calls.append(
                {
                    "id": item.get("call_id") or f"call_{uuid.uuid4().hex[:12]}",
                    "type": "function",
                    "function": {"name": item.get("name"), "arguments": item.get("arguments", "{}")},
                }
            )
        elif kind == "function_call_output":
            emit_buffered_calls()
            chat_messages.append({"role": "tool", "tool_call_id": item.get("call_id"), "content": item.get("output", "")})
        elif kind == "message":
            emit_buffered_calls()
            role = item.get("role", "user")
            text_value = extract_text_from_content(item.get("content"))
            chat_role = "system" if role in {"system", "developer"} else role
            chat_messages.append({"role": chat_role, "content": text_value})
        # other item types (e.g. reasoning) are ignored without flushing

    emit_buffered_calls()
    return [m for m in chat_messages if m.get("content") is not None or m.get("tool_calls")]
def normalize_tool_choice(tool_choice: Any, tools: list[dict[str, Any]]) -> tuple[Any, list[dict[str, Any]]]:
    """Map a Responses-API ``tool_choice`` onto the Chat-Completions form.

    Strings pass through unchanged. A ``function`` choice is rewritten to
    the nested ``{"type": "function", "function": {...}}`` shape. An
    ``allowed_tools`` choice filters *tools* down to the allowed names and
    returns its mode (defaulting to ``"auto"``). Anything else clears the
    choice. Returns the (possibly rewritten) choice and the tool list.
    """
    if isinstance(tool_choice, str):
        return tool_choice, tools
    if not isinstance(tool_choice, dict):
        # Covers None and any other unsupported value.
        return None, tools

    choice_type = tool_choice.get("type")
    if choice_type == "function":
        name = tool_choice.get("name") or (tool_choice.get("function") or {}).get("name")
        if name:
            return {"type": "function", "function": {"name": name}}, tools
        return None, tools
    if choice_type == "allowed_tools":
        allowed_names = {
            entry if isinstance(entry, str) else entry.get("name")
            for entry in tool_choice.get("tools") or []
            if entry is not None
        }
        kept = [tool for tool in tools if tool["function"]["name"] in allowed_names]
        mode = tool_choice.get("mode", "auto")
        return (mode if isinstance(mode, str) else "auto"), kept
    return None, tools
body["instructions"]}] + payload["messages"] + text_config = body.get("text") or {} + text_format = text_config.get("format") if isinstance(text_config, dict) else None + if isinstance(text_format, dict): + if text_format.get("type") == "json_object": + payload["response_format"] = {"type": "json_object"} + elif text_format.get("type") == "json_schema": + payload["response_format"] = {"type": "json_schema", "json_schema": text_format.get("json_schema") or {}} + return payload + + +def extract_upstream_message(upstream_json: dict[str, Any]) -> tuple[dict[str, Any], str | None]: + choices = upstream_json.get("choices") or [] + if not choices: + return {}, None + choice = choices[0] or {} + return choice.get("message") or {}, choice.get("finish_reason") + + +def extract_text_and_tool_calls(message: dict[str, Any]) -> tuple[str, list[dict[str, Any]]]: + content = message.get("content") + text_chunks: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + if isinstance(content, str): + text_chunks.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, str): + text_chunks.append(part) + continue + if not isinstance(part, dict): + text_chunks.append(str(part)) + continue + if part.get("type") in {"input_text", "output_text", "text"}: + text_chunks.append(str(part.get("text", ""))) + continue + if part.get("type") in {"tool_call", "function_call"}: + arguments = part.get("arguments") or "{}" + if not isinstance(arguments, str): + arguments = json_dumps(arguments) + tool_calls.append({"id": part.get("id") or part.get("call_id") or f"call_{uuid.uuid4().hex[:12]}", "name": part.get("name"), "arguments": arguments}) + + for tool_call in message.get("tool_calls") or []: + if not isinstance(tool_call, dict): + continue + function_data = tool_call.get("function") or {} + arguments = function_data.get("arguments") or tool_call.get("arguments") or "{}" + if not isinstance(arguments, str): + arguments = json_dumps(arguments) + 
tool_calls.append({"id": tool_call.get("id") or f"call_{uuid.uuid4().hex[:12]}", "name": function_data.get("name") or tool_call.get("name"), "arguments": arguments}) + + deduped: list[dict[str, Any]] = [] + seen_ids: set[str] = set() + for tool_call in tool_calls: + if tool_call["id"] in seen_ids: + continue + seen_ids.add(tool_call["id"]) + deduped.append(tool_call) + return "\n".join(filter(None, text_chunks)).strip(), deduped + + +def build_choice_alias(output_items: list[dict[str, Any]], finish_reason: str | None) -> list[dict[str, Any]]: + content_parts: list[dict[str, Any]] = [] + for item in output_items: + if item.get("type") == "message": + for part in item.get("content", []): + content_parts.append({"type": part.get("type", "output_text"), "text": part.get("text", "")}) + elif item.get("type") == "function_call": + arguments = item.get("arguments") or "{}" + try: + parsed_arguments = json.loads(arguments) + except Exception: + parsed_arguments = arguments + content_parts.append({"type": "tool_call", "id": item.get("call_id"), "name": item.get("name"), "arguments": parsed_arguments}) + return [{"index": 0, "message": {"role": "assistant", "content": content_parts}, "finish_reason": finish_reason or "stop"}] + + +def chat_completion_to_response(body: dict[str, Any], upstream_json: dict[str, Any], previous_response_id: str | None) -> dict[str, Any]: + upstream_message, finish_reason = extract_upstream_message(upstream_json) + assistant_text, tool_calls = extract_text_and_tool_calls(upstream_message) + response_id = upstream_json.get("id") or f"resp_{uuid.uuid4().hex}" + output_items: list[dict[str, Any]] = [] + if assistant_text: + output_items.append({"id": f"msg_{uuid.uuid4().hex[:24]}", "type": "message", "status": "completed", "role": "assistant", "content": [{"type": "output_text", "text": assistant_text, "annotations": []}]}) + for tool_call in tool_calls: + output_items.append({"id": f"fc_{uuid.uuid4().hex[:24]}", "type": "function_call", "status": 
def chat_completion_to_response(body: dict[str, Any], upstream_json: dict[str, Any], previous_response_id: str | None) -> dict[str, Any]:
    """Wrap an upstream chat-completion result in a Responses-API envelope.

    Extracts the assistant text and tool calls from the first upstream
    choice, rebuilds them as Responses output items, and returns the full
    response payload including a ``choices`` alias and an ``upstream`` block
    for clients that want the raw chat-completion view.
    """
    message, finish_reason = extract_upstream_message(upstream_json)
    assistant_text, tool_calls = extract_text_and_tool_calls(message)

    output_items: list[dict[str, Any]] = []
    if assistant_text:
        output_items.append(
            {
                "id": f"msg_{uuid.uuid4().hex[:24]}",
                "type": "message",
                "status": "completed",
                "role": "assistant",
                "content": [{"type": "output_text", "text": assistant_text, "annotations": []}],
            }
        )
    output_items.extend(
        {
            "id": f"fc_{uuid.uuid4().hex[:24]}",
            "type": "function_call",
            "status": "completed",
            "call_id": call["id"],
            "name": call.get("name"),
            "arguments": call.get("arguments", "{}"),
        }
        for call in tool_calls
    )

    usage = upstream_json.get("usage") or {}
    return {
        "id": upstream_json.get("id") or f"resp_{uuid.uuid4().hex}",
        "object": "response",
        "created_at": int(time.time()),
        "status": "completed",
        "model": body.get("model"),
        "output": output_items,
        "output_text": assistant_text,
        "parallel_tool_calls": bool(body.get("parallel_tool_calls", True)),
        "previous_response_id": previous_response_id,
        "store": True,
        "text": body.get("text") or {"format": {"type": "text"}},
        "usage": {
            "input_tokens": usage.get("prompt_tokens"),
            "output_tokens": usage.get("completion_tokens"),
            "total_tokens": usage.get("total_tokens"),
        },
        "choices": build_choice_alias(output_items, finish_reason),
        "upstream": {
            "id": upstream_json.get("id"),
            "object": upstream_json.get("object", "chat.completion"),
            "finish_reason": finish_reason or "stop",
        },
    }
def update_usage_stats(conn: sqlite3.Connection, model_row: sqlite3.Row, api_key_row: sqlite3.Row, *, ok: bool, latency_ms: float | None, is_healthcheck: bool) -> None:
    """Bump per-model and per-key usage counters after an upstream call.

    Health checks update the ``healthcheck_*`` counters plus the model's
    last health status; normal traffic updates request/success/failure
    counters. Does not commit; the caller owns the transaction.
    """
    now = utcnow_iso()
    hit = 1 if ok else 0
    miss = 0 if ok else 1
    if is_healthcheck:
        conn.execute(
            """
            UPDATE proxy_models
            SET healthcheck_count = healthcheck_count + 1,
                healthcheck_success_count = healthcheck_success_count + ?,
                last_healthcheck_at = ?,
                last_health_status = ?,
                last_latency_ms = ?,
                updated_at = ?
            WHERE id = ?
            """,
            (hit, now, hit, latency_ms, now, model_row["id"]),
        )
        conn.execute(
            """
            UPDATE api_keys
            SET healthcheck_count = healthcheck_count + 1,
                healthcheck_success_count = healthcheck_success_count + ?,
                last_tested_at = ?,
                last_latency_ms = ?,
                updated_at = ?
            WHERE id = ?
            """,
            (hit, now, latency_ms, now, api_key_row["id"]),
        )
        return
    conn.execute(
        """
        UPDATE proxy_models
        SET request_count = request_count + 1,
            success_count = success_count + ?,
            failure_count = failure_count + ?,
            last_used_at = ?,
            last_latency_ms = ?,
            updated_at = ?
        WHERE id = ?
        """,
        (hit, miss, now, latency_ms, now, model_row["id"]),
    )
    conn.execute(
        """
        UPDATE api_keys
        SET request_count = request_count + 1,
            success_count = success_count + ?,
            failure_count = failure_count + ?,
            last_used_at = ?,
            last_latency_ms = ?,
            updated_at = ?
        WHERE id = ?
        """,
        (hit, miss, now, latency_ms, now, api_key_row["id"]),
    )
def insert_health_record(conn: sqlite3.Connection, model_row: sqlite3.Row, api_key_row: sqlite3.Row, *, ok: bool, status_code: int | None, latency_ms: float | None, error_message: str | None, response_excerpt: str | None) -> None:
    """Append one row to ``health_check_records`` for a completed probe.

    Does not commit; the caller owns the transaction boundary.
    """
    record = (
        model_row["id"],
        api_key_row["id"],
        1 if ok else 0,
        status_code,
        latency_ms,
        error_message,
        response_excerpt,
        utcnow_iso(),
    )
    conn.execute(
        """
        INSERT INTO health_check_records (
            model_id, api_key_id, ok, status_code, latency_ms, error_message, response_excerpt, checked_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """,
        record,
    )
async def perform_healthcheck(conn: sqlite3.Connection, model_row: sqlite3.Row, api_key_row: sqlite3.Row, prompt: str) -> dict[str, Any]:
    """Probe one model with one API key and record the outcome.

    Sends a tiny deterministic chat request (32 tokens, temperature 0),
    updates usage counters, inserts a health-check record, commits, and
    returns a summary dict for the caller/UI.
    """
    probe_payload = {
        "model": model_row["model_id"],
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 32,
        "temperature": 0,
    }
    try:
        upstream_json, latency_ms = await post_nvidia_chat_completion(api_key_row["api_key"], probe_payload)
        upstream_message, _finish_reason = extract_upstream_message(upstream_json)
        assistant_text, _tool_calls = extract_text_and_tool_calls(upstream_message)
        ok = True
        detail = assistant_text or "Model responded successfully."
        status_code = 200
        error_message = None
        response_excerpt = detail[:200]  # keep stored excerpt bounded
    except HTTPException as exc:
        ok = False
        latency_ms = None
        detail = exc.detail
        status_code = exc.status_code
        error_message = exc.detail
        response_excerpt = None
    update_usage_stats(conn, model_row, api_key_row, ok=ok, latency_ms=latency_ms, is_healthcheck=True)
    insert_health_record(conn, model_row, api_key_row, ok=ok, status_code=status_code, latency_ms=latency_ms, error_message=error_message, response_excerpt=response_excerpt)
    conn.commit()
    return {
        "model": model_row["model_id"],
        "display_name": model_row["display_name"],
        "api_key": api_key_row["name"],
        "status": "healthy" if ok else "down",
        "ok": ok,
        "latency": latency_ms,
        "status_code": status_code,
        "detail": detail,
        "checked_at": utcnow_iso(),
    }
def build_public_health_payload(hours: int | None = None) -> dict[str, Any]:
    """Build the public status-page payload for all enabled models.

    Collects health-check records newer than *hours* (defaulting to the
    configured ``public_history_hours`` window) and summarizes them per
    model: raw hourly entries, render-ready points, success rate, and the
    model's last known status.

    Fix over the previous version: each timestamp is parsed once per point
    instead of calling ``parse_datetime`` twice for the same value.

    Args:
        hours: History window in hours; falls back to settings when falsy.

    Returns:
        Dict with ``generated_at``, ``last_updated`` (latest check seen),
        ``hours``, and the per-model summaries under ``models``.
    """
    conn = get_db_connection()
    try:
        settings_payload = get_settings_payload(conn)
        effective_hours = hours or settings_payload["public_history_hours"]
        since = utcnow() - timedelta(hours=effective_hours)
        models = conn.execute("SELECT * FROM proxy_models WHERE enabled = 1 ORDER BY sort_order ASC, model_id ASC").fetchall()
        result_models: list[dict[str, Any]] = []
        last_updated: str | None = None
        for model in models:
            rows = conn.execute(
                "SELECT * FROM health_check_records WHERE model_id = ? AND checked_at >= ? ORDER BY checked_at ASC",
                (model["id"], since.isoformat()),
            ).fetchall()
            hourly: list[dict[str, Any]] = []
            points: list[dict[str, Any]] = []
            ok_count = 0
            for row in rows:
                checked_at = row["checked_at"]
                is_ok = bool(row["ok"])
                hourly.append({"time": checked_at, "status": "healthy" if is_ok else "down", "latency": row["latency_ms"]})
                # Parse once and reuse (previously parsed twice per point).
                parsed = parse_datetime(checked_at)
                points.append(
                    {
                        "hour": checked_at,
                        "label": parsed.strftime("%H:%M") if parsed else checked_at,
                        "ok": is_ok,
                        "latency_ms": row["latency_ms"],
                    }
                )
                ok_count += 1 if is_ok else 0
                last_updated = checked_at  # rows are ASC, so this ends up newest
            total = len(rows)
            success_rate = round((ok_count / total) * 100, 1) if total else 0.0
            model_status = "unknown" if model["last_health_status"] is None else ("healthy" if model["last_health_status"] else "down")
            result_models.append(
                {
                    "id": model["id"],
                    "model_id": model["model_id"],
                    "name": model["display_name"],
                    "display_name": model["display_name"],
                    "endpoint": "/v1/responses",
                    "status": model_status,
                    "beat": f"{success_rate}%",
                    "hourly": hourly,
                    "last_health_status": None if model["last_health_status"] is None else bool(model["last_health_status"]),
                    "last_healthcheck_at": model["last_healthcheck_at"],
                    "success_rate": success_rate,
                    "points": points,
                }
            )
        return {"generated_at": utcnow_iso(), "last_updated": last_updated, "hours": effective_hours, "models": result_models}
    finally:
        conn.close()
def schedule_healthchecks() -> None:
    """(Re)install the recurring health-check job from current settings.

    Reads the interval and enabled flag from the settings table, removes any
    previously scheduled job, and re-adds it only when enabled. The interval
    is clamped to a 5-minute minimum and the first run fires ~10s from now.
    """
    conn = get_db_connection()
    try:
        settings_payload = get_settings_payload(conn)
    finally:
        conn.close()
    interval_minutes = max(5, int(settings_payload["healthcheck_interval_minutes"]))
    job_id = "nim-hourly-healthcheck"
    if scheduler.get_job(job_id):
        scheduler.remove_job(job_id)
    if bool(settings_payload["healthcheck_enabled"]):
        scheduler.add_job(
            run_healthchecks,
            "interval",
            minutes=interval_minutes,
            id=job_id,
            replace_existing=True,
            next_run_time=utcnow() + timedelta(seconds=10),
        )
def store_response_record(conn: sqlite3.Connection, response_payload: dict[str, Any], request_body: dict[str, Any], input_items: list[dict[str, Any]], model_row: sqlite3.Row, api_key_row: sqlite3.Row) -> None:
    """Persist a completed response (upsert on ``response_id``) so later
    requests can chain onto it via ``previous_response_id``.

    Does not commit; the caller owns the transaction.
    """
    record = (
        response_payload["id"],
        request_body.get("previous_response_id"),
        model_row["id"],
        api_key_row["id"],
        json_dumps(request_body),
        json_dumps(input_items),
        json_dumps(response_payload),
        json_dumps(response_payload.get("output") or []),
        response_payload.get("status", "completed"),
        utcnow_iso(),
    )
    conn.execute(
        """
        INSERT OR REPLACE INTO response_records (
            response_id, parent_response_id, model_id, api_key_id, request_json,
            input_items_json, output_json, output_items_json, status, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        record,
    )
+ """, + (1 if ok else 0, 0 if ok else 1, now, latency_ms, now, api_key_row["id"]), + ) + + +def insert_health_record(conn: sqlite3.Connection, model_row: sqlite3.Row, api_key_row: sqlite3.Row, *, ok: bool, status_code: int | None, latency_ms: float | None, error_message: str | None, response_excerpt: str | None) -> None: + conn.execute( + """ + INSERT INTO health_check_records ( + model_id, api_key_id, ok, status_code, latency_ms, error_message, response_excerpt, checked_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (model_row["id"], api_key_row["id"], 1 if ok else 0, status_code, latency_ms, error_message, response_excerpt, utcnow_iso()), + ) + + +async def post_nvidia_chat_completion(api_key: str, payload: dict[str, Any]) -> tuple[dict[str, Any], float]: + started = time.perf_counter() + client = await get_http_client() + response = await client.post( + CHAT_COMPLETIONS_URL, + headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, + json=payload, + ) + latency_ms = round((time.perf_counter() - started) * 1000, 2) + if response.status_code >= 400: + try: + error_payload = response.json() + detail = error_payload.get("error", {}).get("message") or json_dumps(error_payload) + except Exception: + detail = response.text + raise HTTPException(status_code=response.status_code, detail=f"NVIDIA NIM 请求失败:{detail}") + return response.json(), latency_ms + + +async def perform_healthcheck(conn: sqlite3.Connection, model_row: sqlite3.Row, api_key_row: sqlite3.Row, prompt: str) -> dict[str, Any]: + payload = {"model": model_row["model_id"], "messages": [{"role": "user", "content": prompt}], "max_tokens": 32, "temperature": 0} + try: + upstream_json, latency_ms = await post_nvidia_chat_completion(api_key_row["api_key"], payload) + message, _finish_reason = extract_upstream_message(upstream_json) + assistant_text, _tool_calls = extract_text_and_tool_calls(message) + ok = True + detail = assistant_text or "模型响应正常。" + status_code = 200 + 
error_message = None + response_excerpt = detail[:200] + except HTTPException as exc: + ok = False + latency_ms = None + detail = exc.detail + status_code = exc.status_code + error_message = exc.detail + response_excerpt = None + update_usage_stats(conn, model_row, api_key_row, ok=ok, latency_ms=latency_ms, is_healthcheck=True) + insert_health_record(conn, model_row, api_key_row, ok=ok, status_code=status_code, latency_ms=latency_ms, error_message=error_message, response_excerpt=response_excerpt) + conn.commit() + return {"model": model_row["model_id"], "display_name": model_row["display_name"], "api_key": api_key_row["name"], "status": "healthy" if ok else "down", "ok": ok, "latency": latency_ms, "status_code": status_code, "detail": detail, "checked_at": utcnow_iso()} + + +async def run_healthchecks(model_identifier: str | int | None = None, api_key_identifier: str | int | None = None, prompt: str | None = None) -> list[dict[str, Any]]: + conn = get_db_connection() + try: + settings_payload = get_settings_payload(conn) + effective_prompt = prompt or settings_payload["healthcheck_prompt"] + if api_key_identifier is not None: + api_key_row = fetch_key_by_identifier(conn, api_key_identifier, enabled_only=True) + if not api_key_row: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到 API Key。") + key_rows = [api_key_row] + else: + key_rows = conn.execute("SELECT * FROM api_keys WHERE enabled = 1 ORDER BY id ASC").fetchall() + if not key_rows: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="??????? 
NVIDIA NIM Key?") + if model_identifier is not None: + model_row = fetch_model_by_identifier(conn, model_identifier, enabled_only=True) + if not model_row: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到模型。") + model_rows = [model_row] + else: + model_rows = conn.execute("SELECT * FROM proxy_models WHERE enabled = 1 ORDER BY sort_order ASC, model_id ASC").fetchall() + results: list[dict[str, Any]] = [] + for index, model_row in enumerate(model_rows): + api_key_row = key_rows[index % len(key_rows)] + results.append(await perform_healthcheck(conn, model_row, api_key_row, effective_prompt)) + return results + finally: + conn.close() + + +def build_public_health_payload(hours: int | None = None) -> dict[str, Any]: + conn = get_db_connection() + try: + settings_payload = get_settings_payload(conn) + effective_hours = hours or settings_payload["public_history_hours"] + since = utcnow() - timedelta(hours=effective_hours) + models = conn.execute("SELECT * FROM proxy_models WHERE enabled = 1 ORDER BY sort_order ASC, model_id ASC").fetchall() + result_models: list[dict[str, Any]] = [] + last_updated: str | None = None + for model in models: + rows = conn.execute("SELECT * FROM health_check_records WHERE model_id = ? AND checked_at >= ? 
ORDER BY checked_at ASC", (model["id"], since.isoformat())).fetchall() + hourly = [] + ok_count = 0 + for row in rows: + status_name = "healthy" if row["ok"] else "down" + hourly.append({"time": row["checked_at"], "status": status_name, "latency": row["latency_ms"]}) + ok_count += 1 if row["ok"] else 0 + last_updated = row["checked_at"] + total = len(rows) + success_rate = round((ok_count / total) * 100, 1) if total else 0.0 + model_status = "unknown" if model["last_health_status"] is None else ("healthy" if model["last_health_status"] else "down") + result_models.append({"id": model["id"], "model_id": model["model_id"], "name": model["display_name"], "display_name": model["display_name"], "endpoint": "/v1/responses", "status": model_status, "beat": f"{success_rate}%", "hourly": hourly, "last_health_status": None if model["last_health_status"] is None else bool(model["last_health_status"]), "last_healthcheck_at": model["last_healthcheck_at"], "success_rate": success_rate, "points": [{"hour": entry["time"], "label": parse_datetime(entry["time"]).strftime("%H:%M") if parse_datetime(entry["time"]) else entry["time"], "ok": entry["status"] == "healthy", "latency_ms": entry["latency"]} for entry in hourly]}) + return {"generated_at": utcnow_iso(), "last_updated": last_updated, "hours": effective_hours, "models": result_models} + finally: + conn.close() + + +def schedule_healthchecks() -> None: + conn = get_db_connection() + try: + settings_payload = get_settings_payload(conn) + finally: + conn.close() + interval = max(5, int(settings_payload["healthcheck_interval_minutes"])) + enabled = bool(settings_payload["healthcheck_enabled"]) + if scheduler.get_job("nim-hourly-healthcheck"): + scheduler.remove_job("nim-hourly-healthcheck") + if enabled: + scheduler.add_job(run_healthchecks, "interval", minutes=interval, id="nim-hourly-healthcheck", replace_existing=True, next_run_time=utcnow() + timedelta(seconds=10)) + + init_db() -@asynccontextmanager -async def lifespan(_app: 
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Application lifecycle: set up shared async state before serving, tear it down after.

    Startup order matters: the DB schema and the shared httpx client must exist
    before the scheduler starts, because the scheduled health checks use both.
    """
    # NOTE(review): assumes `asyncio`, `http_client`, `api_key_selection_lock`,
    # `api_key_rr_index` and `get_http_client` are defined earlier in the file — confirm.
    global http_client, api_key_selection_lock, api_key_rr_index
    init_db()
    api_key_selection_lock = asyncio.Lock()
    api_key_rr_index = 0
    http_client = await get_http_client()
    if not scheduler.running:
        scheduler.start()
    schedule_healthchecks()
    try:
        yield
    finally:
        # Stop scheduling new jobs first, then close the shared HTTP client.
        if scheduler.running:
            scheduler.shutdown(wait=False)
        if http_client is not None and not http_client.is_closed:
            await http_client.aclose()
        http_client = None
        api_key_selection_lock = None


app = FastAPI(title="NIM 响应网关", lifespan=lifespan)
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


def render_html(filename: str) -> HTMLResponse:
    """Serve a static HTML page from STATIC_DIR with an explicit UTF-8 content type."""
    content = (STATIC_DIR / filename).read_text(encoding="utf-8")
    return HTMLResponse(content=content, media_type="text/html; charset=utf-8")


@app.get("/")
async def public_dashboard() -> HTMLResponse:
    """Public status-page UI."""
    return render_html("index.html")


@app.get("/admin")
async def admin_dashboard() -> HTMLResponse:
    """Admin console UI (auth is enforced by the /admin/api endpoints, not here)."""
    return render_html("admin.html")


@app.get("/api/health/public")
async def public_health(hours: int | None = None) -> dict[str, Any]:
    """Unauthenticated health history feed for the public dashboard."""
    return build_public_health_payload(hours)
@app.get("/v1/models")
async def list_models() -> dict[str, Any]:
    """List enabled proxy models in an OpenAI-style payload, annotated with last health status."""
    # NOTE(review): unlike /v1/responses this endpoint carries no
    # require_proxy_token_if_configured dependency, so the model list is public
    # even when a gateway key is configured — confirm this is intentional.
    conn = get_db_connection()
    try:
        rows = conn.execute("SELECT * FROM proxy_models WHERE enabled = 1 ORDER BY sort_order ASC, model_id ASC").fetchall()
        data = [{"id": row["model_id"], "object": "model", "created": 0, "owned_by": "nvidia-nim", "display_name": row["display_name"], "status": ("unknown" if row["last_health_status"] is None else ("healthy" if row["last_health_status"] else "down"))} for row in rows]
        return {"object": "list", "data": data, "models": data}
    finally:
        conn.close()


@app.get("/v1/responses/{response_id}")
async def get_response(response_id: str, _: bool = Depends(require_proxy_token_if_configured)):
    """Return a previously stored response payload verbatim."""
    conn = get_db_connection()
    try:
        row = conn.execute("SELECT output_json FROM response_records WHERE response_id = ?", (response_id,)).fetchone()
        if not row:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Response not found.")
        return json.loads(row["output_json"])
    finally:
        conn.close()


@app.post("/v1/responses")
async def create_response(request: Request, _: bool = Depends(require_proxy_token_if_configured)):
    """Translate an OpenAI Responses-API request into an NVIDIA NIM chat completion.

    Supports previous_response_id chaining, records usage stats and the stored
    response, and can replay the finished result as an SSE stream.
    """
    body = await request.json()
    # The three error strings below were mojibake ("??????") in the previous
    # revision; reconstructed Chinese consistent with the file's other messages.
    if not isinstance(body, dict):
        return make_error(status.HTTP_400_BAD_REQUEST, "请求体必须是 JSON 对象。")
    if not body.get("model"):
        return make_error(status.HTTP_400_BAD_REQUEST, "缺少 model 字段。")
    if body.get("input") is None:
        return make_error(status.HTTP_400_BAD_REQUEST, "缺少 input 字段。")

    conn = get_db_connection()
    try:
        model_row = fetch_model_by_identifier(conn, body["model"], enabled_only=True)
        if not model_row:
            return make_error(status.HTTP_404_NOT_FOUND, f"模型 {body['model']} 未配置或已禁用。")
        api_key_row = await select_api_key(conn)
        previous_items = load_previous_conversation_items(conn, body.get("previous_response_id"))
        input_items = normalize_input_items(body.get("input"))
        merged_items = previous_items + input_items
        chat_payload = build_chat_payload(body, merged_items)
        try:
            upstream_json, latency_ms = await post_nvidia_chat_completion(api_key_row["api_key"], chat_payload)
        except HTTPException as exc:
            # Record the failure before propagating the upstream error.
            update_usage_stats(conn, model_row, api_key_row, ok=False, latency_ms=None, is_healthcheck=False)
            conn.commit()
            raise exc
        response_payload = chat_completion_to_response(body, upstream_json, body.get("previous_response_id"))
        update_usage_stats(conn, model_row, api_key_row, ok=True, latency_ms=latency_ms, is_healthcheck=False)
        store_response_record(conn, response_payload, body, input_items, model_row, api_key_row)
        conn.commit()

        if body.get("stream"):
            # The upstream call is not streamed; replay the completed response as
            # a minimal Responses-API SSE sequence.
            async def event_stream() -> Any:
                yield f"event: response.created\ndata: {json_dumps({'type': 'response.created', 'response': {'id': response_payload['id'], 'model': response_payload['model'], 'status': 'in_progress'}})}\n\n"
                for index, item in enumerate(response_payload.get("output") or []):
                    yield f"event: response.output_item.added\ndata: {json_dumps({'type': 'response.output_item.added', 'output_index': index, 'item': item})}\n\n"
                    if item.get("type") == "message":
                        text_value = extract_text_from_content(item.get("content"))
                        if text_value:
                            yield f"event: response.output_text.delta\ndata: {json_dumps({'type': 'response.output_text.delta', 'output_index': index, 'delta': text_value})}\n\n"
                            yield f"event: response.output_text.done\ndata: {json_dumps({'type': 'response.output_text.done', 'output_index': index, 'text': text_value})}\n\n"
                    if item.get("type") == "function_call":
                        yield f"event: response.function_call_arguments.done\ndata: {json_dumps({'type': 'response.function_call_arguments.done', 'output_index': index, 'arguments': item.get('arguments', '{}'), 'call_id': item.get('call_id')})}\n\n"
                    yield f"event: response.output_item.done\ndata: {json_dumps({'type': 'response.output_item.done', 'output_index': index, 'item': item})}\n\n"
                yield f"event: response.completed\ndata: {json_dumps({'type': 'response.completed', 'response': response_payload})}\n\n"
            return StreamingResponse(event_stream(), media_type="text/event-stream")
        return response_payload
    finally:
        conn.close()


@app.post("/admin/api/login")
async def admin_login(request: Request, response: Response):
    """Exchange the admin password for a signed session token (cookie + bearer)."""
    if not ADMIN_PASSWORD:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="尚未配置 PASSWORD 环境变量。")
    body = await request.json()
    password = body.get("password") if isinstance(body, dict) else None
    if password != ADMIN_PASSWORD:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="密码错误。")
    token = create_admin_token()
    # NOTE(review): secure=False means the cookie is sent over plain HTTP —
    # acceptable behind a TLS-terminating proxy; confirm deployment assumptions.
    response.set_cookie(COOKIE_NAME, token, httponly=True, samesite="lax", secure=False, max_age=60 * 60 * 24 * 7)
    return {"token": token, "access_token": token, "token_type": "bearer"}


@app.post("/admin/api/logout")
async def admin_logout(response: Response, _: bool = Depends(require_admin)):
    """Clear the admin session cookie."""
    response.delete_cookie(COOKIE_NAME)
    return {"message": "已退出登录。"}


@app.get("/admin/api/session")
async def admin_session(_: bool = Depends(require_admin)):
    """Cheap probe used by the admin UI to test whether the session is still valid."""
    return {"ok": True}


@app.get("/admin/api/overview")
async def admin_overview(_: bool = Depends(require_admin)):
    """Aggregate counters and the eight most recent health checks for the admin dashboard."""
    conn = get_db_connection()
    try:
        total_models = conn.execute("SELECT COUNT(*) AS count FROM proxy_models").fetchone()["count"]
        enabled_models = conn.execute("SELECT COUNT(*) AS count FROM proxy_models WHERE enabled = 1").fetchone()["count"]
        total_keys = conn.execute("SELECT COUNT(*) AS count FROM api_keys").fetchone()["count"]
        enabled_keys = conn.execute("SELECT COUNT(*) AS count FROM api_keys WHERE enabled = 1").fetchone()["count"]
        usage = conn.execute("SELECT COALESCE(SUM(request_count), 0) AS total_requests, COALESCE(SUM(success_count), 0) AS total_success, COALESCE(SUM(failure_count), 0) AS total_failures FROM proxy_models").fetchone()
        recent_rows = conn.execute("SELECT h.checked_at, h.ok, h.latency_ms, m.model_id FROM health_check_records h JOIN proxy_models m ON m.id = h.model_id ORDER BY h.checked_at DESC LIMIT 8").fetchall()
        return {
            "metrics": [
                {"label": "Enabled Models", "value": enabled_models},
                {"label": "Enabled Keys", "value": enabled_keys},
                {"label": "Proxy Requests", "value": usage["total_requests"]},
                {"label": "Failures", "value": usage["total_failures"]},
            ],
            "recent_checks": [{"time": row["checked_at"], "model": row["model_id"], "status": "healthy" if row["ok"] else "down", "latency": row["latency_ms"]} for row in recent_rows],
            "totals": {
                "total_models": total_models,
                "enabled_models": enabled_models,
                "total_keys": total_keys,
                "enabled_keys": enabled_keys,
                "total_requests": usage["total_requests"],
                "total_success": usage["total_success"],
                "total_failures": usage["total_failures"],
            },
        }
    finally:
        conn.close()


@app.get("/admin/api/models")
async def admin_models(_: bool = Depends(require_admin)):
    """All configured models (enabled or not), in display order."""
    conn = get_db_connection()
    try:
        rows = conn.execute("SELECT * FROM proxy_models ORDER BY sort_order ASC, model_id ASC").fetchall()
        return {"items": [row_to_model_item(row) for row in rows]}
    finally:
        conn.close()


@app.get("/admin/api/models/usage")
async def admin_models_usage(_: bool = Depends(require_admin)):
    """Same model list as /admin/api/models but sorted by traffic volume."""
    conn = get_db_connection()
    try:
        rows = conn.execute("SELECT * FROM proxy_models ORDER BY request_count DESC, model_id ASC").fetchall()
        return {"items": [row_to_model_item(row) for row in rows]}
    finally:
        conn.close()
@app.post("/admin/api/models")
async def admin_add_model(request: Request, _: bool = Depends(require_admin)):
    """Create a model entry, or upsert it if the model_id already exists."""
    body = await request.json()
    model_id = (body.get("model_id") or body.get("name") or "").strip()
    display_name = (body.get("display_name") or model_id).strip()
    if not model_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="model_id is required.")
    conn = get_db_connection()
    try:
        now = utcnow_iso()
        conn.execute(
            """
            INSERT INTO proxy_models (model_id, display_name, provider, description, enabled, featured, sort_order, created_at, updated_at)
            VALUES (?, ?, 'nvidia-nim', ?, ?, ?, ?, ?, ?)
            ON CONFLICT(model_id) DO UPDATE SET
                display_name = excluded.display_name,
                description = excluded.description,
                enabled = excluded.enabled,
                featured = excluded.featured,
                sort_order = excluded.sort_order,
                updated_at = excluded.updated_at
            """,
            (model_id, display_name, body.get("description"), 1 if body.get("enabled", True) else 0, 1 if body.get("featured", False) else 0, int(body.get("sort_order", 0)), now, now),
        )
        conn.commit()
        row = fetch_model_by_identifier(conn, model_id)
        return {"item": row_to_model_item(row)}
    finally:
        conn.close()


def delete_model_internal(model_identifier: str) -> dict[str, Any]:
    """Delete a model by id or model_id string; raises 404 when it does not exist."""
    conn = get_db_connection()
    try:
        row = fetch_model_by_identifier(conn, model_identifier)
        if not row:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到模型。")
        conn.execute("DELETE FROM proxy_models WHERE id = ?", (row["id"],))
        conn.commit()
        # Was mojibake ("??????"); reconstructed as the Chinese equivalent of the
        # previous revision's "Model deleted." message.
        return {"message": "模型已删除。"}
    finally:
        conn.close()


@app.delete("/admin/api/models/{model_identifier}")
async def admin_delete_model(model_identifier: str, _: bool = Depends(require_admin)):
    """RESTful delete route; shares logic with the POST alias below."""
    return delete_model_internal(model_identifier)


@app.post("/admin/api/models/remove")
async def admin_remove_model_alias(request: Request, _: bool = Depends(require_admin)):
    """POST-body alias for clients that cannot issue DELETE requests."""
    body = await request.json()
    value = body.get("value") if isinstance(body, dict) else None
    if not value:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="value is required.")
    return delete_model_internal(str(value))


async def test_model_internal(model_identifier: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
    """Run a one-off health check for a model, optionally pinned to a specific key/prompt."""
    conn = get_db_connection()
    try:
        row = fetch_model_by_identifier(conn, model_identifier, enabled_only=True)
        if not row:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到模型。")
        api_key_row = await select_api_key(conn, payload.get("api_key_id") if payload else None)
        return await perform_healthcheck(conn, row, api_key_row, (payload or {}).get("prompt") or DEFAULT_HEALTH_PROMPT)
    finally:
        conn.close()


@app.post("/admin/api/models/test")
async def admin_test_model_alias(request: Request, _: bool = Depends(require_admin)):
    """POST-body alias: identifies the model via 'value' or 'model_id'."""
    body = await request.json()
    identifier = body.get("value") or body.get("model_id")
    if not identifier:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="value is required.")
    return await test_model_internal(str(identifier), body)


@app.post("/admin/api/models/{model_identifier}/test")
async def admin_test_model(model_identifier: str, request: Request, _: bool = Depends(require_admin)):
    """Path-parameter variant of the model test endpoint."""
    # The method check is always true on this POST-only route; kept for parity
    # with the other test endpoints.
    body = await request.json() if request.method == "POST" else {}
    return await test_model_internal(model_identifier, body)
await perform_healthcheck(conn, model_row, key_row, (payload or {}).get("prompt") or DEFAULT_HEALTH_PROMPT) - finally: - conn.close() - - -@app.post("/admin/api/keys/test") -async def admin_test_key_alias(request: Request, _: bool = Depends(require_admin)): - body = await request.json() - identifier = body.get("value") or body.get("name") or body.get("label") - if not identifier: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="value is required.") - return await test_key_internal(str(identifier), body) - - -@app.post("/admin/api/keys/{key_identifier}/test") -async def admin_test_key(key_identifier: str, request: Request, _: bool = Depends(require_admin)): - body = await request.json() if request.method == "POST" else {} - return await test_key_internal(key_identifier, body) - - -@app.get("/admin/api/healthchecks") -async def admin_healthchecks(hours: int = 48, _: bool = Depends(require_admin)): - conn = get_db_connection() - try: - since = utcnow() - timedelta(hours=hours) - rows = conn.execute( - """ - SELECT h.*, m.model_id, m.display_name, k.name AS key_name - FROM health_check_records h - JOIN proxy_models m ON m.id = h.model_id - LEFT JOIN api_keys k ON k.id = h.api_key_id - WHERE h.checked_at >= ? 
- ORDER BY h.checked_at DESC - LIMIT 200 - """, - (since.isoformat(),), - ).fetchall() - items = [{"id": row["id"], "model": row["display_name"], "model_id": row["model_id"], "api_key": row["key_name"], "status": "healthy" if row["ok"] else "down", "detail": row["response_excerpt"] or row["error_message"] or "暂无详情。", "latency": row["latency_ms"], "status_code": row["status_code"], "checked_at": row["checked_at"]} for row in rows] - return {"items": items} - finally: - conn.close() - - -@app.post("/admin/api/healthchecks/run") -async def admin_run_healthchecks(request: Request, _: bool = Depends(require_admin)): - body = await request.json() if request.method == "POST" else {} - results = await run_healthchecks(model_identifier=body.get("model_id") or body.get("model"), api_key_identifier=body.get("api_key_id") or body.get("key_id"), prompt=body.get("prompt")) - return {"items": results, "results": results} - - -@app.get("/admin/api/settings") -async def admin_settings(_: bool = Depends(require_admin)): - conn = get_db_connection() - try: - return get_settings_payload(conn) - finally: - conn.close() - - -@app.put("/admin/api/settings") -async def admin_update_settings(request: Request, _: bool = Depends(require_admin)): - body = await request.json() - conn = get_db_connection() - try: - set_setting(conn, "healthcheck_enabled", "true" if body.get("healthcheck_enabled", True) else "false") - set_setting(conn, "healthcheck_interval_minutes", str(max(5, int(body.get("healthcheck_interval_minutes", DEFAULT_HEALTH_INTERVAL_MINUTES))))) - set_setting(conn, "healthcheck_prompt", body.get("healthcheck_prompt") or DEFAULT_HEALTH_PROMPT) - if body.get("public_history_hours"): - set_setting(conn, "public_history_hours", str(max(1, int(body.get("public_history_hours"))))) - conn.commit() - finally: - conn.close() - schedule_healthchecks() - conn = get_db_connection() - try: - return get_settings_payload(conn) - finally: + +@app.get("/admin/api/keys") +async def admin_keys(_: 
bool = Depends(require_admin)):
+    conn = get_db_connection()
+    try:
+        rows = conn.execute("SELECT * FROM api_keys ORDER BY id ASC").fetchall()
+        return {"items": [row_to_key_item(row) for row in rows]}
+    finally:
+        conn.close()
+
+
+@app.get("/admin/api/keys/usage")
+async def admin_keys_usage(_: bool = Depends(require_admin)):
+    conn = get_db_connection()
+    try:
+        rows = conn.execute("SELECT * FROM api_keys ORDER BY request_count DESC, id ASC").fetchall()
+        return {"items": [row_to_key_item(row) for row in rows]}
+    finally:
+        conn.close()
+
+
+@app.post("/admin/api/keys")
+async def admin_add_key(request: Request, _: bool = Depends(require_admin)):
+    body = await request.json()
+    name = (body.get("name") or body.get("label") or "").strip()
+    api_key = (body.get("api_key") or body.get("key") or "").strip()
+    if not name or not api_key:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Both name and api_key are required.")
+    conn = get_db_connection()
+    try:
+        now = utcnow_iso()
+        conn.execute(
+            """
+            INSERT INTO api_keys (name, api_key, enabled, created_at, updated_at)
+            VALUES (?, ?, ?, ?, ?)
+            ON CONFLICT(name) DO UPDATE SET api_key = excluded.api_key, enabled = excluded.enabled, updated_at = excluded.updated_at
+            """,
+            (name, api_key, 1 if body.get("enabled", True) else 0, now, now),
+        )
+        conn.commit()
+        row = fetch_key_by_identifier(conn, name)
+        return {"item": row_to_key_item(row)}
+    finally:
+        conn.close()
+
+
+def delete_key_internal(key_identifier: str) -> dict[str, Any]:
+    conn = get_db_connection()
+    try:
+        row = fetch_key_by_identifier(conn, key_identifier)
+        if not row:
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到 
API Key。")
+        conn.execute("DELETE FROM api_keys WHERE id = ?", (row["id"],))
+        conn.commit()
+        return {"message": "API Key 已删除。"}
+    finally:
+        conn.close()
+
+
+@app.delete("/admin/api/keys/{key_identifier}")
+async def admin_delete_key(key_identifier: str, _: bool = Depends(require_admin)):
+    return delete_key_internal(key_identifier)
+
+
+@app.post("/admin/api/keys/remove")
+async def admin_remove_key_alias(request: Request, _: bool = Depends(require_admin)):
+    body = await request.json()
+    value = body.get("value") if isinstance(body, dict) else None
+    if not value:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="value is required.")
+    return delete_key_internal(str(value))
+
+
+async def test_key_internal(key_identifier: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
+    conn = get_db_connection()
+    try:
+        key_row = fetch_key_by_identifier(conn, key_identifier, enabled_only=True)
+        if not key_row:
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到 API Key。")
+        model_identifier = (payload or {}).get("model_id") or DEFAULT_MODELS[0][0]
+        model_row = fetch_model_by_identifier(conn, model_identifier, enabled_only=True)
+        if not model_row:
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到模型。")
+        return await perform_healthcheck(conn, model_row, key_row, (payload or {}).get("prompt") or DEFAULT_HEALTH_PROMPT)
+    finally:
+        conn.close()
+
+
+async def test_all_keys_internal(payload: dict[str, Any] | None = None) -> list[dict[str, Any]]:
+    conn = get_db_connection()
+    try:
+        key_rows = conn.execute("SELECT * FROM api_keys WHERE enabled = 1 ORDER BY id ASC").fetchall()
+        if not key_rows:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="未找到已启用的 
API Key。")
+        model_identifier = (payload or {}).get("model_id") or DEFAULT_MODELS[0][0]
+        model_row = fetch_model_by_identifier(conn, model_identifier, enabled_only=True)
+        if not model_row:
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到模型。")
+        prompt = (payload or {}).get("prompt") or DEFAULT_HEALTH_PROMPT
+        results: list[dict[str, Any]] = []
+        for key_row in key_rows:
+            results.append(await perform_healthcheck(conn, model_row, key_row, prompt))
+        return results
+    finally:
+        conn.close()
+
+
+@app.post("/admin/api/keys/test")
+async def admin_test_key_alias(request: Request, _: bool = Depends(require_admin)):
+    body = await request.json()
+    identifier = body.get("value") or body.get("name") or body.get("label")
+    if not identifier:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="value is required.")
+    return await test_key_internal(str(identifier), body)
+
+
+@app.post("/admin/api/keys/test-all")
+async def admin_test_all_keys(request: Request, _: bool = Depends(require_admin)):
+    body = await request.json() if request.method == "POST" else {}
+    results = await test_all_keys_internal(body)
+    return {"items": results, "results": results}
+
+
+@app.post("/admin/api/keys/{key_identifier}/test")
+async def admin_test_key(key_identifier: str, request: Request, _: bool = Depends(require_admin)):
+    body = await request.json() if request.method == "POST" else {}
+    return await test_key_internal(key_identifier, body)
+
+
+@app.get("/admin/api/healthchecks")
+async def admin_healthchecks(hours: int = 48, _: bool = Depends(require_admin)):
+    conn = get_db_connection()
+    try:
+        since = utcnow() - timedelta(hours=hours)
+        rows = conn.execute(
+            """
+            SELECT h.*, m.model_id, m.display_name, k.name AS key_name
+            FROM health_check_records h
+            JOIN proxy_models m ON m.id = h.model_id
+            LEFT JOIN api_keys k ON k.id = h.api_key_id
+            WHERE h.checked_at >= ? 
+ ORDER BY h.checked_at DESC + LIMIT 200 + """, + (since.isoformat(),), + ).fetchall() + items = [{"id": row["id"], "model": row["display_name"], "model_id": row["model_id"], "api_key": row["key_name"], "status": "healthy" if row["ok"] else "down", "detail": row["response_excerpt"] or row["error_message"] or "暂无详情。", "latency": row["latency_ms"], "status_code": row["status_code"], "checked_at": row["checked_at"]} for row in rows] + return {"items": items} + finally: + conn.close() + + +@app.post("/admin/api/healthchecks/run") +async def admin_run_healthchecks(request: Request, _: bool = Depends(require_admin)): + body = await request.json() if request.method == "POST" else {} + results = await run_healthchecks(model_identifier=body.get("model_id") or body.get("model"), api_key_identifier=body.get("api_key_id") or body.get("key_id"), prompt=body.get("prompt")) + return {"items": results, "results": results} + + +@app.get("/admin/api/settings") +async def admin_settings(_: bool = Depends(require_admin)): + conn = get_db_connection() + try: + return get_settings_payload(conn) + finally: + conn.close() + + +@app.put("/admin/api/settings") +async def admin_update_settings(request: Request, _: bool = Depends(require_admin)): + body = await request.json() + conn = get_db_connection() + try: + set_setting(conn, "healthcheck_enabled", "true" if body.get("healthcheck_enabled", True) else "false") + set_setting(conn, "healthcheck_interval_minutes", str(max(5, int(body.get("healthcheck_interval_minutes", DEFAULT_HEALTH_INTERVAL_MINUTES))))) + set_setting(conn, "healthcheck_prompt", body.get("healthcheck_prompt") or DEFAULT_HEALTH_PROMPT) + if body.get("public_history_hours"): + set_setting(conn, "public_history_hours", str(max(1, int(body.get("public_history_hours"))))) + conn.commit() + finally: + conn.close() + schedule_healthchecks() + conn = get_db_connection() + try: + return get_settings_payload(conn) + finally: conn.close()