astrbot_help / src /api_key_manager.py
qa1145's picture
Upload 28 files
d347708 verified
"""
API Key Manager - 多Key随机管理器
支持线程安全的Key随机选择,提供良好的负载分散
"""
import hashlib
import logging
import os
import random
import threading
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class KeyStats:
    """Usage counters and most-recent details recorded for a single API key."""

    key_hash: str  # hash of the key; used in logs so the real key is never exposed
    total_requests: int = 0
    success_count: int = 0
    error_count: int = 0
    last_used: Optional[float] = None
    last_error: Optional[str] = None

    @property
    def success_rate(self) -> float:
        """Return success_count / total_requests, or 1.0 when there are no requests."""
        made = self.total_requests
        return self.success_count / made if made != 0 else 1.0
class ApiKeyManager:
"""API Key 管理器 - 支持多Key随机选择"""
def __init__(self, api_keys: Optional[str | List[str]] = None):
"""
初始化API Key管理器
Args:
api_keys: 可以是单个key、逗号分隔的字符串,或key列表
"""
self._keys: List[str] = []
self._lock = threading.Lock() # 线程安全锁
self._stats: dict[str, KeyStats] = {} # key_hash -> KeyStats
self._key_to_hash: dict[str, str] = {} # key -> key_hash
# 加载keys
if api_keys:
self._load_keys(api_keys)
else:
# 从环境变量加载
env_keys = os.getenv("OPENAI_API_KEY", "")
if env_keys:
self._load_keys(env_keys)
if not self._keys:
logger.warning("未配置任何API Key")
logger.info(f"API Key管理器初始化完成,共加载 {len(self._keys)} 个Key")
def _load_keys(self, keys: str | List[str]):
"""加载API Keys"""
if isinstance(keys, str):
# 支持逗号分隔的多个key
keys_list = [k.strip() for k in keys.split(",") if k.strip()]
else:
keys_list = keys
# 去重
seen = set()
unique_keys = []
for key in keys_list:
if key and key not in seen:
seen.add(key)
unique_keys.append(key)
self._keys = unique_keys
# 初始化统计信息
import time
import hashlib
for key in self._keys:
key_hash = hashlib.sha256(key.encode()).hexdigest()[:8]
self._key_to_hash[key] = key_hash
self._stats[key_hash] = KeyStats(key_hash=key_hash)
@property
def key_count(self) -> int:
"""获取Key数量"""
return len(self._keys)
@property
def has_keys(self) -> bool:
"""是否有可用的Key"""
return len(self._keys) > 0
def get_key(self) -> Optional[str]:
"""
获取下一个可用的API Key(随机选择)
Returns:
API Key字符串,如果没有可用Key返回None
"""
if not self._keys:
return None
with self._lock:
# 随机选择一个 key
key = random.choice(self._keys)
# 更新统计
import time
key_hash = self._key_to_hash[key]
stats = self._stats[key_hash]
stats.total_requests += 1
stats.last_used = time.time()
return key
def record_success(self, key: str):
"""记录请求成功"""
if key not in self._key_to_hash:
return
with self._lock:
key_hash = self._key_to_hash[key]
self._stats[key_hash].success_count += 1
def record_error(self, key: str, error: str):
"""记录请求失败"""
if key not in self._key_to_hash:
return
with self._lock:
key_hash = self._key_to_hash[key]
stats = self._stats[key_hash]
stats.error_count += 1
stats.last_error = error[:200] # 限制错误信息长度
def get_stats(self) -> dict:
"""获取所有Key的统计信息"""
with self._lock:
return {
"total_keys": len(self._keys),
"keys": [
{
"key_hash": stats.key_hash,
"total_requests": stats.total_requests,
"success_count": stats.success_count,
"error_count": stats.error_count,
"success_rate": f"{stats.success_rate:.2%}",
"last_used": stats.last_used,
"last_error": stats.last_error,
}
for stats in self._stats.values()
],
}
def reset_stats(self):
"""重置统计信息"""
with self._lock:
for stats in self._stats.values():
stats.total_requests = 0
stats.success_count = 0
stats.error_count = 0
stats.last_used = None
stats.last_error = None
# Global singleton state: the shared manager instance (created on demand or
# set explicitly) and the lock taken while it is created or replaced.
_global_lock = threading.Lock()
_global_manager: Optional[ApiKeyManager] = None
def get_global_manager() -> ApiKeyManager:
    """Return the global manager, creating it on first use.

    The instance is created with no explicit keys, so ``ApiKeyManager`` falls
    back to its environment-based configuration.

    Returns:
        The shared ``ApiKeyManager`` instance.
    """
    # Without this declaration the assignment below makes the name local and
    # the very first read raises UnboundLocalError.
    global _global_manager
    if _global_manager is None:
        with _global_lock:
            # Re-check under the lock: another thread may have created the
            # instance between the first check and acquiring the lock.
            if _global_manager is None:
                _global_manager = ApiKeyManager()
    return _global_manager
def init_manager(api_keys: Optional[str | List[str]] = None) -> ApiKeyManager:
    """Create a new manager and install it as the global instance.

    Any previously installed global manager is replaced.

    Args:
        api_keys: Forwarded unchanged to ``ApiKeyManager``.

    Returns:
        The manager created by this call.
    """
    global _global_manager
    with _global_lock:
        manager = ApiKeyManager(api_keys)
        _global_manager = manager
    # Return the local reference instead of re-reading the global after the
    # lock is released, where a concurrent init_manager() may have replaced it.
    return manager