Spaces:
Sleeping
Sleeping
| """ | |
| API Key Manager - 多Key随机管理器 | |
| 支持线程安全的Key随机选择,提供良好的负载分散 | |
| """ | |
| import os | |
| import random | |
| import threading | |
| from typing import List, Optional | |
| from dataclasses import dataclass, field | |
| from collections import defaultdict | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class KeyStats:
    """Usage statistics for a single API key.

    The key itself is not stored here; an entry is identified by a short
    hash so it can be logged without leaking the real key.

    The class is a dataclass because the manager constructs it with
    ``KeyStats(key_hash=...)``, and ``success_rate`` is a property because
    the manager formats it directly as a number.
    """

    key_hash: str  # short hash of the key (used for logging, hides the real key)
    total_requests: int = 0
    success_count: int = 0
    error_count: int = 0
    last_used: Optional[float] = None  # timestamp of the most recent use, if any
    last_error: Optional[str] = None  # most recent recorded error message, if any

    @property
    def success_rate(self) -> float:
        """Return the fraction of requests that succeeded.

        Returns:
            ``success_count / total_requests``, or ``1.0`` when no request
            has been counted yet (avoids a division by zero).
        """
        if self.total_requests == 0:
            return 1.0
        return self.success_count / self.total_requests
| class ApiKeyManager: | |
| """API Key 管理器 - 支持多Key随机选择""" | |
| def __init__(self, api_keys: Optional[str | List[str]] = None): | |
| """ | |
| 初始化API Key管理器 | |
| Args: | |
| api_keys: 可以是单个key、逗号分隔的字符串,或key列表 | |
| """ | |
| self._keys: List[str] = [] | |
| self._lock = threading.Lock() # 线程安全锁 | |
| self._stats: dict[str, KeyStats] = {} # key_hash -> KeyStats | |
| self._key_to_hash: dict[str, str] = {} # key -> key_hash | |
| # 加载keys | |
| if api_keys: | |
| self._load_keys(api_keys) | |
| else: | |
| # 从环境变量加载 | |
| env_keys = os.getenv("OPENAI_API_KEY", "") | |
| if env_keys: | |
| self._load_keys(env_keys) | |
| if not self._keys: | |
| logger.warning("未配置任何API Key") | |
| logger.info(f"API Key管理器初始化完成,共加载 {len(self._keys)} 个Key") | |
| def _load_keys(self, keys: str | List[str]): | |
| """加载API Keys""" | |
| if isinstance(keys, str): | |
| # 支持逗号分隔的多个key | |
| keys_list = [k.strip() for k in keys.split(",") if k.strip()] | |
| else: | |
| keys_list = keys | |
| # 去重 | |
| seen = set() | |
| unique_keys = [] | |
| for key in keys_list: | |
| if key and key not in seen: | |
| seen.add(key) | |
| unique_keys.append(key) | |
| self._keys = unique_keys | |
| # 初始化统计信息 | |
| import time | |
| import hashlib | |
| for key in self._keys: | |
| key_hash = hashlib.sha256(key.encode()).hexdigest()[:8] | |
| self._key_to_hash[key] = key_hash | |
| self._stats[key_hash] = KeyStats(key_hash=key_hash) | |
| def key_count(self) -> int: | |
| """获取Key数量""" | |
| return len(self._keys) | |
| def has_keys(self) -> bool: | |
| """是否有可用的Key""" | |
| return len(self._keys) > 0 | |
| def get_key(self) -> Optional[str]: | |
| """ | |
| 获取下一个可用的API Key(随机选择) | |
| Returns: | |
| API Key字符串,如果没有可用Key返回None | |
| """ | |
| if not self._keys: | |
| return None | |
| with self._lock: | |
| # 随机选择一个 key | |
| key = random.choice(self._keys) | |
| # 更新统计 | |
| import time | |
| key_hash = self._key_to_hash[key] | |
| stats = self._stats[key_hash] | |
| stats.total_requests += 1 | |
| stats.last_used = time.time() | |
| return key | |
| def record_success(self, key: str): | |
| """记录请求成功""" | |
| if key not in self._key_to_hash: | |
| return | |
| with self._lock: | |
| key_hash = self._key_to_hash[key] | |
| self._stats[key_hash].success_count += 1 | |
| def record_error(self, key: str, error: str): | |
| """记录请求失败""" | |
| if key not in self._key_to_hash: | |
| return | |
| with self._lock: | |
| key_hash = self._key_to_hash[key] | |
| stats = self._stats[key_hash] | |
| stats.error_count += 1 | |
| stats.last_error = error[:200] # 限制错误信息长度 | |
| def get_stats(self) -> dict: | |
| """获取所有Key的统计信息""" | |
| with self._lock: | |
| return { | |
| "total_keys": len(self._keys), | |
| "keys": [ | |
| { | |
| "key_hash": stats.key_hash, | |
| "total_requests": stats.total_requests, | |
| "success_count": stats.success_count, | |
| "error_count": stats.error_count, | |
| "success_rate": f"{stats.success_rate:.2%}", | |
| "last_used": stats.last_used, | |
| "last_error": stats.last_error, | |
| } | |
| for stats in self._stats.values() | |
| ], | |
| } | |
| def reset_stats(self): | |
| """重置统计信息""" | |
| with self._lock: | |
| for stats in self._stats.values(): | |
| stats.total_requests = 0 | |
| stats.success_count = 0 | |
| stats.error_count = 0 | |
| stats.last_used = None | |
| stats.last_error = None | |
# Process-wide singleton state: the lock serialises creation and replacement
# of the shared manager instance, which starts out unset.
_global_lock = threading.Lock()
_global_manager: Optional[ApiKeyManager] = None
def get_global_manager() -> ApiKeyManager:
    """Return the shared manager, creating it on first use.

    The instance is created with no arguments, so its keys come from the
    environment. The check is repeated after acquiring the lock so that
    concurrent first callers create only one instance.

    Returns:
        The module-wide ``ApiKeyManager`` instance.
    """
    # Without this declaration the assignment below makes the name local and
    # the first comparison raises UnboundLocalError.
    global _global_manager
    if _global_manager is None:
        with _global_lock:
            if _global_manager is None:
                _global_manager = ApiKeyManager()
    return _global_manager
def init_manager(api_keys: Optional[str | List[str]] = None) -> ApiKeyManager:
    """Create a new manager and install it as the shared instance.

    Any previously installed shared manager is replaced.

    Args:
        api_keys: Forwarded unchanged to ``ApiKeyManager``.

    Returns:
        The newly installed manager.
    """
    global _global_manager
    with _global_lock:
        fresh = ApiKeyManager(api_keys)
        _global_manager = fresh
        return fresh