Spaces:
Sleeping
Sleeping
File size: 5,878 Bytes
d347708 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | """
API Key Manager - 多Key随机管理器
支持线程安全的Key随机选择,提供良好的负载分散
"""
import os
import random
import threading
from typing import List, Optional
from dataclasses import dataclass, field
from collections import defaultdict
import logging
logger = logging.getLogger(__name__)
@dataclass
class KeyStats:
    """Usage counters for a single API key, identified only by a short hash.

    Attributes are plain mutable counters; ApiKeyManager updates them while
    holding its lock.
    """

    key_hash: str  # short hash of the key, safe to log (never the raw key)
    total_requests: int = 0  # times the key was handed out
    success_count: int = 0  # requests reported as successful
    error_count: int = 0  # requests reported as failed
    last_used: Optional[float] = None  # time of the most recent hand-out, if any
    last_error: Optional[str] = None  # most recent error message, if any

    @property
    def success_rate(self) -> float:
        """Return success_count / total_requests, or 1.0 before any request."""
        requests = self.total_requests
        return 1.0 if requests == 0 else self.success_count / requests
class ApiKeyManager:
    """Pool of API keys that hands out a randomly chosen key on each request.

    Keys are supplied to the constructor or, when none are given, read from
    the ``OPENAI_API_KEY`` environment variable. Per-key statistics are kept
    under the first 8 hex characters of the key's SHA-256 digest, so the
    stats output never contains a raw key. Statistic updates are serialized
    with an internal ``threading.Lock``.
    """

    def __init__(self, api_keys: Optional[str | List[str]] = None):
        """Create the manager and load its keys.

        Args:
            api_keys: A single key, a comma-separated string of keys, or a
                list of keys. If falsy (None, "" or an empty list), keys are
                read from the ``OPENAI_API_KEY`` environment variable, which
                may itself be comma-separated.
        """
        self._keys: List[str] = []
        self._lock = threading.Lock()  # guards the mutable KeyStats entries
        self._stats: dict[str, KeyStats] = {}  # key_hash -> KeyStats
        self._key_to_hash: dict[str, str] = {}  # key -> key_hash

        if api_keys:
            self._load_keys(api_keys)
        else:
            # Nothing passed explicitly: fall back to the environment variable.
            env_keys = os.getenv("OPENAI_API_KEY", "")
            if env_keys:
                self._load_keys(env_keys)

        if not self._keys:
            logger.warning("未配置任何API Key")
        logger.info(f"API Key管理器初始化完成,共加载 {len(self._keys)} 个Key")

    def _load_keys(self, keys: str | List[str]):
        """Parse, de-duplicate and register keys, replacing any loaded before.

        Args:
            keys: A comma-separated string or a list of keys. Surrounding
                whitespace is stripped from string entries in both forms,
                empty entries are dropped, and duplicates keep the position
                of their first occurrence.
        """
        import hashlib

        if isinstance(keys, str):
            candidates = keys.split(",")
        else:
            candidates = keys
        # Strip list entries too, so " sk-a" behaves the same whether it came
        # from a list or from a comma-separated string.
        cleaned = [k.strip() if isinstance(k, str) else k for k in candidates]
        # dict preserves insertion order, so this is an order-preserving dedup.
        self._keys = list(dict.fromkeys(k for k in cleaned if k))

        # Rebuild the lookup tables so no entries survive for dropped keys.
        self._key_to_hash.clear()
        self._stats.clear()
        for key in self._keys:
            key_hash = hashlib.sha256(key.encode()).hexdigest()[:8]
            self._key_to_hash[key] = key_hash
            self._stats[key_hash] = KeyStats(key_hash=key_hash)

    @property
    def key_count(self) -> int:
        """Number of distinct keys loaded."""
        return len(self._keys)

    @property
    def has_keys(self) -> bool:
        """True if at least one key is loaded."""
        return len(self._keys) > 0

    def get_key(self) -> Optional[str]:
        """Pick a key uniformly at random and count it as one request.

        Returns:
            One of the loaded keys, or None if no keys are configured.
        """
        import time

        if not self._keys:
            return None
        with self._lock:
            key = random.choice(self._keys)
            stats = self._stats[self._key_to_hash[key]]
            stats.total_requests += 1
            stats.last_used = time.time()
        return key

    def record_success(self, key: str):
        """Record a successful request made with ``key``; unknown keys are ignored."""
        key_hash = self._key_to_hash.get(key)
        if key_hash is None:
            return
        with self._lock:
            self._stats[key_hash].success_count += 1

    def record_error(self, key: str, error: str | BaseException):
        """Record a failed request made with ``key``; unknown keys are ignored.

        Args:
            key: The key that was used for the failed request.
            error: Error description. Non-string values such as exception
                instances are converted with ``str()``. Only the first 200
                characters are stored.
        """
        key_hash = self._key_to_hash.get(key)
        if key_hash is None:
            return
        # str() first: slicing an exception object directly would raise TypeError.
        message = str(error)[:200]
        with self._lock:
            stats = self._stats[key_hash]
            stats.error_count += 1
            stats.last_error = message

    def get_stats(self) -> dict:
        """Return a snapshot of per-key statistics, keyed by hash (never raw keys)."""
        with self._lock:
            return {
                "total_keys": len(self._keys),
                "keys": [
                    {
                        "key_hash": stats.key_hash,
                        "total_requests": stats.total_requests,
                        "success_count": stats.success_count,
                        "error_count": stats.error_count,
                        "success_rate": f"{stats.success_rate:.2%}",
                        "last_used": stats.last_used,
                        "last_error": stats.last_error,
                    }
                    for stats in self._stats.values()
                ],
            }

    def reset_stats(self):
        """Zero all counters and clear timestamps/errors for every key."""
        with self._lock:
            for stats in self._stats.values():
                stats.total_requests = 0
                stats.success_count = 0
                stats.error_count = 0
                stats.last_used = None
                stats.last_error = None
# Process-wide singleton state, read by get_global_manager() and replaced by init_manager().
_global_lock = threading.Lock()  # serializes creation/replacement of the shared manager
_global_manager: Optional[ApiKeyManager] = None  # None until first created
def get_global_manager() -> ApiKeyManager:
    """Return the process-wide ApiKeyManager, creating it on first use.

    The manager is created with no arguments, so it loads keys from the
    ``OPENAI_API_KEY`` environment variable. The None check is repeated
    under ``_global_lock`` so that concurrent first callers create only one
    instance.
    """
    # Required: without it the assignment below makes _global_manager a local
    # name and the first read raises UnboundLocalError.
    global _global_manager
    if _global_manager is None:
        with _global_lock:
            if _global_manager is None:
                _global_manager = ApiKeyManager()
    return _global_manager
def init_manager(api_keys: Optional[str | List[str]] = None) -> ApiKeyManager:
    """Replace the process-wide ApiKeyManager with a fresh one and return it.

    Args:
        api_keys: Passed unchanged to ``ApiKeyManager`` (single key,
            comma-separated string, list, or None to use the environment).

    Returns:
        The manager created by this call.
    """
    global _global_manager
    with _global_lock:
        manager = ApiKeyManager(api_keys)
        _global_manager = manager
    # Return our own instance rather than re-reading the global after the lock
    # is released; a concurrent init_manager() may already have replaced it.
    return manager
|