| import os |
| import json |
| import time |
| from collections import defaultdict |
| import threading |
| from typing import Dict, List, Any, Optional, Union, get_type_hints |
| from datetime import datetime, timedelta |
| from utils import logger, load_config |
|
|
|
|
class Config:
    """Configuration manager that stores and manages all settings.

    Values start from ``_defaults``. ``init`` then overrides them first from the
    config file (``load_from_file``) and then from environment variables
    (``load_from_env``). The instance also holds usage statistics, client
    sessions, the account list and per-account cooldown deadlines, plus the
    locks used to guard them across threads.
    """

    # Built-in defaults; copied per instance in __init__ so overrides never
    # mutate this class-level dict.
    _defaults = {
        "ondemand_session_timeout_minutes": 30,
        # NOTE(review): 3600 *minutes* is 60 hours -- confirm the unit is intended.
        "session_timeout_minutes": 3600,
        "max_retries": 5,
        "retry_delay": 3,
        "request_timeout": 45,
        "stream_timeout": 180,
        "rate_limit": 30,
        "account_cooldown_seconds": 300,
        "debug_mode": False,
        # Publicly visible fallback token; init() warns while it is still in use.
        "api_access_token": "sk-2api-ondemand-access-token-2025",
        "stats_file_path": "stats_data.json",
        # Temporary file written first, then atomically moved onto stats_file_path.
        "stats_backup_path": "stats_data_backup.json",
        "stats_save_interval": 300,
        "max_history_items": 1000,
        "default_endpoint_id": "predefined-claude-3.7-sonnet",
    }

    # Client-facing model name -> OnDemand endpoint ID. Some aliases
    # (e.g. gpt-3.5-turbo, gemini-1.5-pro) map to a different underlying model.
    _model_mapping = {
        "gpt-3.5-turbo": "predefined-openai-gpto3-mini",
        "gpto3-mini": "predefined-openai-gpto3-mini",
        "gpt-4o": "predefined-openai-gpt4o",
        "gpt-4o-mini": "predefined-openai-gpt4o-mini",
        "gpt-4-turbo": "predefined-openai-gpt4.1",
        "gpt-4.1": "predefined-openai-gpt4.1",
        "gpt-4.1-mini": "predefined-openai-gpt4.1-mini",
        "gpt-4.1-nano": "predefined-openai-gpt4.1-nano",
        "deepseek-v3": "predefined-deepseek-v3",
        "deepseek-r1": "predefined-deepseek-r1",
        "claude-3.5-sonnet": "predefined-claude-3.5-sonnet",
        "claude-3.7-sonnet": "predefined-claude-3.7-sonnet",
        "claude-3-opus": "predefined-claude-3-opus",
        "claude-3-haiku": "predefined-claude-3-haiku",
        "gemini-1.5-pro": "predefined-gemini-2.0-flash",
        "gemini-2.0-flash": "predefined-gemini-2.0-flash",
    }

    # Lower-cased spellings accepted as True for boolean environment variables.
    _TRUTHY_STRINGS = frozenset({"true", "1", "yes", "on"})

    # Scalar usage totals (default 0 when absent from the stats file).
    _TOTAL_KEYS = (
        "total_requests",
        "successful_requests",
        "failed_requests",
        "total_prompt_tokens",
        "total_completion_tokens",
        "total_tokens",
    )

    # Usage counters held as defaultdict(int) in memory and plain dicts on disk.
    _COUNTER_KEYS = (
        "model_usage",
        "account_usage",
        "daily_usage",
        "hourly_usage",
        "model_tokens",
        "daily_tokens",
        "hourly_tokens",
    )

    def __init__(self):
        """Create a config populated with defaults and empty runtime state."""
        # Per-instance copy so set()/update() never touch the class defaults.
        self._config = self._defaults.copy()

        # Usage statistics; counters are defaultdict(int) so new keys start at 0.
        self.usage_stats = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "model_usage": defaultdict(int),
            "account_usage": defaultdict(int),
            "daily_usage": defaultdict(int),
            "hourly_usage": defaultdict(int),
            "request_history": [],
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "model_tokens": defaultdict(int),
            "daily_tokens": defaultdict(int),
            "hourly_tokens": defaultdict(int),
            "last_saved": datetime.now().isoformat(),
        }

        self.usage_stats_lock = threading.Lock()
        # Guards current_account_index and account_cooldowns.
        self.account_index_lock = threading.Lock()
        self.client_sessions_lock = threading.Lock()
        # Serialises writers of the stats files so two saves cannot interleave
        # on the temporary file.
        self._stats_save_lock = threading.Lock()
        # Background saver created by start_stats_save_thread(), if started.
        self._stats_save_thread = None

        # Round-robin cursor into self.accounts.
        self.current_account_index = 0

        self.client_sessions = {}

        self.accounts = []

        # email -> datetime at which that account's cooldown ends.
        self.account_cooldowns = {}

    def get(self, key: str, default: Any = None) -> Any:
        """Return config value *key*, or *default* if it is not set."""
        return self._config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set config value *key* to *value*."""
        self._config[key] = value

    def update(self, config_dict: Dict[str, Any]) -> None:
        """Merge every entry of *config_dict* into the config."""
        self._config.update(config_dict)

    def get_model_endpoint(self, model_name: str) -> str:
        """Return the endpoint ID for *model_name*, or ``default_endpoint_id`` if unmapped."""
        return self._model_mapping.get(model_name, self.get("default_endpoint_id"))

    def load_from_file(self) -> bool:
        """Load settings and accounts from the config file via ``load_config()``.

        Every key except ``accounts`` is stored as a config value. ``accounts``,
        if present, replaces the account list.

        Returns:
            True if the loader returned non-empty data; False if it returned
            nothing or raised (the error is logged, not propagated).
        """
        try:
            config_data = load_config()
            if not config_data:
                return False

            for key, value in config_data.items():
                if key != "accounts":
                    self.set(key, value)

            if "accounts" in config_data:
                self.accounts = config_data["accounts"]

            logger.info("已从配置文件加载配置")
            return True
        except Exception as e:
            logger.error(f"加载配置文件时出错: {e}")
            return False

    @staticmethod
    def _extract_accounts(parsed: Any) -> List[Dict[str, str]]:
        """Return the account list from decoded ``ONDEMAND_ACCOUNTS`` JSON.

        Accepts either ``{"accounts": [...]}`` or a bare list of accounts.
        Any other shape yields an empty list.
        """
        accounts = parsed.get("accounts", []) if isinstance(parsed, dict) else parsed
        return accounts if isinstance(accounts, list) else []

    @classmethod
    def _parse_env_value(cls, raw: str, template: Any) -> Any:
        """Convert the environment string *raw* to the type of *template*.

        Booleans accept true/1/yes/on (case-insensitive). Ints and floats are
        parsed with ``int()``/``float()``. Anything else is returned unchanged.

        Raises:
            ValueError: if *raw* is not a valid number for a numeric template.
        """
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(template, bool):
            return raw.strip().lower() in cls._TRUTHY_STRINGS
        if isinstance(template, int):
            return int(raw)
        if isinstance(template, float):
            return float(raw)
        return raw

    def load_from_env(self) -> None:
        """Override settings from environment variables.

        ``ONDEMAND_ACCOUNTS`` (JSON) is used only when no accounts were loaded
        yet. Each mapped variable that is set is converted to the type of the
        current value. An unparsable value is logged and ignored rather than
        aborting startup.
        """
        if not self.accounts:
            accounts_env = os.getenv("ONDEMAND_ACCOUNTS", "")
            if accounts_env:
                try:
                    self.accounts = self._extract_accounts(json.loads(accounts_env))
                    logger.info("已从环境变量加载账户信息")
                except json.JSONDecodeError:
                    logger.error("解码 ONDEMAND_ACCOUNTS 环境变量失败")

        env_mappings = {
            "ondemand_session_timeout_minutes": "ONDEMAND_SESSION_TIMEOUT_MINUTES",
            "session_timeout_minutes": "SESSION_TIMEOUT_MINUTES",
            "max_retries": "MAX_RETRIES",
            "retry_delay": "RETRY_DELAY",
            "request_timeout": "REQUEST_TIMEOUT",
            "stream_timeout": "STREAM_TIMEOUT",
            "rate_limit": "RATE_LIMIT",
            "debug_mode": "DEBUG_MODE",
            "api_access_token": "API_ACCESS_TOKEN",
        }

        for config_key, env_key in env_mappings.items():
            env_value = os.getenv(env_key)
            if env_value is None:
                continue
            try:
                self.set(config_key, self._parse_env_value(env_value, self.get(config_key)))
            except ValueError:
                # Keep the previous value instead of crashing init().
                logger.error(f"环境变量 {env_key} 的值无效: {env_value!r},已忽略")

    def save_stats_to_file(self):
        """Persist usage statistics to ``stats_file_path`` as JSON.

        A snapshot is serialised while holding ``usage_stats_lock``. The text is
        then written to ``stats_backup_path`` and moved over the stats file with
        ``os.replace``, so the target is either the old or the new complete
        file. Errors are logged, not raised.
        """
        try:
            saved_at = datetime.now().isoformat()
            with self.usage_stats_lock:
                stats_copy = {
                    "total_requests": self.usage_stats["total_requests"],
                    "successful_requests": self.usage_stats["successful_requests"],
                    "failed_requests": self.usage_stats["failed_requests"],
                    "model_usage": dict(self.usage_stats["model_usage"]),
                    "account_usage": dict(self.usage_stats["account_usage"]),
                    "daily_usage": dict(self.usage_stats["daily_usage"]),
                    "hourly_usage": dict(self.usage_stats["hourly_usage"]),
                    "request_history": list(self.usage_stats["request_history"]),
                    "total_prompt_tokens": self.usage_stats["total_prompt_tokens"],
                    "total_completion_tokens": self.usage_stats["total_completion_tokens"],
                    "total_tokens": self.usage_stats["total_tokens"],
                    "model_tokens": dict(self.usage_stats["model_tokens"]),
                    "daily_tokens": dict(self.usage_stats["daily_tokens"]),
                    "hourly_tokens": dict(self.usage_stats["hourly_tokens"]),
                    "last_saved": saved_at,
                }
                # Encode while locked: history entries are shallow-copied, so
                # they must not be mutated by other threads while serialising.
                payload = json.dumps(stats_copy, ensure_ascii=False, indent=2)

            stats_file_path = self.get("stats_file_path")
            stats_backup_path = self.get("stats_backup_path")

            with self._stats_save_lock:
                with open(stats_backup_path, 'w', encoding='utf-8') as f:
                    f.write(payload)
                # Atomic replace: no window in which the stats file is missing.
                os.replace(stats_backup_path, stats_file_path)

            logger.info(f"统计数据已保存到 {stats_file_path}")
            with self.usage_stats_lock:
                self.usage_stats["last_saved"] = saved_at
        except Exception as e:
            logger.error(f"保存统计数据时出错: {e}")

    def load_stats_from_file(self):
        """Load usage statistics from ``stats_file_path`` into ``usage_stats``.

        Missing keys default to 0 or empty. Request history is trimmed to the
        newest ``max_history_items`` entries.

        Returns:
            True if the file was loaded; False if it is absent or loading
            failed (the error is logged).
        """
        try:
            stats_file_path = self.get("stats_file_path")
            if not os.path.exists(stats_file_path):
                logger.info(f"未找到统计数据文件 {stats_file_path},将使用默认值")
                return False

            with open(stats_file_path, 'r', encoding='utf-8') as f:
                saved_stats = json.load(f)

            with self.usage_stats_lock:
                for key in self._TOTAL_KEYS:
                    self.usage_stats[key] = saved_stats.get(key, 0)

                # update() keeps the defaultdict(int) containers intact.
                for key in self._COUNTER_KEYS:
                    self.usage_stats[key].update(saved_stats.get(key, {}))

                history = saved_stats.get("request_history", [])
                max_history_items = self.get("max_history_items")
                if len(history) > max_history_items:
                    # history[-0:] would keep everything, so a limit of 0 (or a
                    # negative value) is handled explicitly.
                    history = history[-max_history_items:] if max_history_items > 0 else []
                self.usage_stats["request_history"] = history

            logger.info(f"已从 {stats_file_path} 加载统计数据")
            return True
        except Exception as e:
            logger.error(f"加载统计数据时出错: {e}")
            return False

    def start_stats_save_thread(self):
        """Start a daemon thread that calls ``save_stats_to_file`` every
        ``stats_save_interval`` seconds.

        Does nothing if a saver thread started by this instance is still alive.
        """
        existing = getattr(self, "_stats_save_thread", None)
        if existing is not None and existing.is_alive():
            logger.info("统计数据保存线程已在运行,跳过重复启动")
            return

        def save_stats_periodically():
            while True:
                # The interval is re-read each cycle, so runtime changes apply.
                time.sleep(self.get("stats_save_interval"))
                self.save_stats_to_file()

        save_thread = threading.Thread(target=save_stats_periodically, daemon=True)
        self._stats_save_thread = save_thread
        save_thread.start()
        logger.info(f"统计数据保存线程已启动,每 {self.get('stats_save_interval')} 秒保存一次")

    def init(self):
        """Initialise from the config file, then environment variables.

        After loading settings it loads saved stats and starts the periodic
        saver. Missing accounts are logged but do not stop startup.
        """
        self.load_from_file()
        self.load_from_env()

        if not self.accounts:
            error_msg = "在 config.json 或环境变量 ONDEMAND_ACCOUNTS 中未找到账户信息"
            logger.critical(error_msg)
            logger.warning("将继续运行,但没有账户信息,可能会导致功能受限")

        # The default token is embedded in source, so anyone can use it.
        if self.get("api_access_token") == self._defaults["api_access_token"]:
            logger.warning("正在使用默认的 API 访问 Token,请通过配置文件或 API_ACCESS_TOKEN 环境变量设置自定义 Token")
        logger.info("已加载API访问Token")

        self.load_stats_from_file()
        self.start_stats_save_thread()

    def get_next_ondemand_account_details(self):
        """Return ``(email, password)`` of the next OnDemand account, round-robin.

        Expired cooldowns are cleared first, and accounts still cooling down
        are skipped. If every account is cooling down, the first account is
        returned anyway. If no accounts are configured, an error is logged and
        ``(None, None)`` is returned.
        """
        with self.account_index_lock:
            current_time = datetime.now()

            expired_cooldowns = [email for email, end_time in self.account_cooldowns.items()
                                 if end_time < current_time]
            for email in expired_cooldowns:
                del self.account_cooldowns[email]
                logger.info(f"账户 {email} 的冷却期已结束,现在可用")

            account_count = len(self.accounts)
            if account_count == 0:
                logger.error("没有可用的账户信息,无法分配 OnDemand 账户")
                return None, None

            for _ in range(account_count):
                # Normalise the cursor in case the account list shrank since the last call.
                index = self.current_account_index % account_count
                account_details = self.accounts[index]
                email = account_details.get('email')

                self.current_account_index = (index + 1) % account_count

                if email in self.account_cooldowns:
                    cooldown_end = self.account_cooldowns[email]
                    remaining_seconds = (cooldown_end - current_time).total_seconds()
                    logger.warning(f"账户 {email} 仍在冷却期中,还剩 {remaining_seconds:.1f} 秒")
                    continue

                logger.info(f"[系统] 新会话将使用账户: {email}")
                return email, account_details.get('password')

            logger.warning("所有账户都在冷却期!使用第一个账户,尽管它可能会触发速率限制")
            account_details = self.accounts[0]
            return account_details.get('email'), account_details.get('password')
|
|
|
|
| |
# Process-wide Config singleton; the module-level helper functions that follow
# delegate to this instance.
config_instance = Config()
|
|
def init_config():
    """Backward-compatible entry point: run ``init()`` on the shared config instance."""
    shared = config_instance
    shared.init()
|
|
|
|
def get_config_value(name: str, default: Any = None) -> Any:
    """
    Return the current value of config entry *name*, or *default* if it is unset.

    External code should read settings via ``config.get_config_value('<name>')``.
    For accounts, model_mapping, usage_stats and client_sessions, use the
    dedicated getter functions instead.
    """
    value = config_instance.get(name, default)
    return value
|
|
| |
def get_accounts() -> List[Dict[str, str]]:
    """Return the shared config's account list (the live list, not a copy)."""
    accounts = config_instance.accounts
    return accounts
|
|
def get_model_mapping() -> Dict[str, str]:
    """Return the model-name -> endpoint-ID mapping (the Config class's own dict, not a copy)."""
    mapping = config_instance._model_mapping
    return mapping
|
|
def get_usage_stats() -> Dict[str, Any]:
    """Return the shared usage-statistics dict (live object; guard writes with usage_stats_lock)."""
    stats = config_instance.usage_stats
    return stats
|
|
def get_client_sessions() -> Dict[str, Any]:
    """Return the shared client-session dict (live object, not a copy)."""
    sessions = config_instance.client_sessions
    return sessions
|
|
def get_next_ondemand_account_details():
    """Backward-compatible wrapper: delegate account rotation to the shared config."""
    shared = config_instance
    return shared.get_next_ondemand_account_details()
|
|
def set_account_cooldown(email, cooldown_seconds=None):
    """Put an account into cooldown so account rotation skips it until the deadline.

    Args:
        email: Email address identifying the account.
        cooldown_seconds: Cooldown length in seconds; ``None`` means use the
            ``account_cooldown_seconds`` config value.
    """
    cooldown_seconds = (
        config_instance.get('account_cooldown_seconds')
        if cooldown_seconds is None
        else cooldown_seconds
    )

    cooldown_end = datetime.now() + timedelta(seconds=cooldown_seconds)
    cooldown_table = config_instance.account_cooldowns
    with config_instance.account_index_lock:
        cooldown_table[email] = cooldown_end
    logger.warning(f"账户 {email} 已设置冷却期 {cooldown_seconds} 秒,将于 {cooldown_end.strftime('%Y-%m-%d %H:%M:%S')} 结束")
|
|
|
|
| |
| |
| |