""" API Key Manager - 多Key随机管理器 支持线程安全的Key随机选择,提供良好的负载分散 """ import os import random import threading from typing import List, Optional from dataclasses import dataclass, field from collections import defaultdict import logging logger = logging.getLogger(__name__) @dataclass class KeyStats: """单个Key的统计信息""" key_hash: str # Key的hash(用于日志,不泄露真实Key) total_requests: int = 0 success_count: int = 0 error_count: int = 0 last_used: Optional[float] = None last_error: Optional[str] = None @property def success_rate(self) -> float: """成功率""" if self.total_requests == 0: return 1.0 return self.success_count / self.total_requests class ApiKeyManager: """API Key 管理器 - 支持多Key随机选择""" def __init__(self, api_keys: Optional[str | List[str]] = None): """ 初始化API Key管理器 Args: api_keys: 可以是单个key、逗号分隔的字符串,或key列表 """ self._keys: List[str] = [] self._lock = threading.Lock() # 线程安全锁 self._stats: dict[str, KeyStats] = {} # key_hash -> KeyStats self._key_to_hash: dict[str, str] = {} # key -> key_hash # 加载keys if api_keys: self._load_keys(api_keys) else: # 从环境变量加载 env_keys = os.getenv("OPENAI_API_KEY", "") if env_keys: self._load_keys(env_keys) if not self._keys: logger.warning("未配置任何API Key") logger.info(f"API Key管理器初始化完成,共加载 {len(self._keys)} 个Key") def _load_keys(self, keys: str | List[str]): """加载API Keys""" if isinstance(keys, str): # 支持逗号分隔的多个key keys_list = [k.strip() for k in keys.split(",") if k.strip()] else: keys_list = keys # 去重 seen = set() unique_keys = [] for key in keys_list: if key and key not in seen: seen.add(key) unique_keys.append(key) self._keys = unique_keys # 初始化统计信息 import time import hashlib for key in self._keys: key_hash = hashlib.sha256(key.encode()).hexdigest()[:8] self._key_to_hash[key] = key_hash self._stats[key_hash] = KeyStats(key_hash=key_hash) @property def key_count(self) -> int: """获取Key数量""" return len(self._keys) @property def has_keys(self) -> bool: """是否有可用的Key""" return len(self._keys) > 0 def get_key(self) -> Optional[str]: """ 获取下一个可用的API Key(随机选择) Returns: API Key字符串,如果没有可用Key返回None """ if not self._keys: return None with self._lock: # 随机选择一个 key key = random.choice(self._keys) # 更新统计 import time key_hash = self._key_to_hash[key] stats = self._stats[key_hash] stats.total_requests += 1 stats.last_used = time.time() return key def record_success(self, key: str): """记录请求成功""" if key not in self._key_to_hash: return with self._lock: key_hash = self._key_to_hash[key] self._stats[key_hash].success_count += 1 def record_error(self, key: str, error: str): """记录请求失败""" if key not in self._key_to_hash: return with self._lock: key_hash = self._key_to_hash[key] stats = self._stats[key_hash] stats.error_count += 1 stats.last_error = error[:200] # 限制错误信息长度 def get_stats(self) -> dict: """获取所有Key的统计信息""" with self._lock: return { "total_keys": len(self._keys), "keys": [ { "key_hash": stats.key_hash, "total_requests": stats.total_requests, "success_count": stats.success_count, "error_count": stats.error_count, "success_rate": f"{stats.success_rate:.2%}", "last_used": stats.last_used, "last_error": stats.last_error, } for stats in self._stats.values() ], } def reset_stats(self): """重置统计信息""" with self._lock: for stats in self._stats.values(): stats.total_requests = 0 stats.success_count = 0 stats.error_count = 0 stats.last_used = None stats.last_error = None # 全局单例实例 _global_manager: Optional[ApiKeyManager] = None _global_lock = threading.Lock() def get_global_manager() -> ApiKeyManager: """获取全局单例""" if _global_manager is None: with _global_lock: if _global_manager is None: _global_manager = ApiKeyManager() return _global_manager def init_manager(api_keys: Optional[str | List[str]] = None) -> ApiKeyManager: """初始化全局管理器""" global _global_manager with _global_lock: _global_manager = ApiKeyManager(api_keys) return _global_manager