| import time |
| import logging |
| import functools |
| import requests |
| from abc import ABC, abstractmethod |
| from typing import Callable, Any, Dict, Optional, Type, Union, TypeVar, cast |
|
|
| |
| import config |
|
|
| |
# Generic result type returned by operations wrapped in RetryHandler.retry_operation.
T = TypeVar("T")
|
|
class RetryStrategy(ABC):
    """Abstract interface for a retry strategy.

    A strategy decides whether a given exception is retryable, how long to wait
    before the next attempt, how to log the attempt, and what extra work to do
    right before retrying.
    """

    @abstractmethod
    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """
        Decide whether the failed operation should be attempted again.

        Args:
            exception: The exception that was caught.
            retry_count: Number of retries performed so far.
            max_retries: Maximum number of retries allowed.

        Returns:
            bool: True if the operation should be retried.
        """

    @abstractmethod
    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """
        Compute how long to wait before the next attempt.

        Args:
            retry_count: Current retry number.
            base_delay: Base delay in seconds.

        Returns:
            float: Delay in seconds.
        """

    @abstractmethod
    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """
        Record that a retry is about to happen.

        Args:
            logger: A logger (or logging callable) to write to.
            exception: The exception that was caught.
            retry_count: Current retry number.
            max_retries: Maximum number of retries allowed.
            delay: Delay in seconds before the retry.
        """

    @abstractmethod
    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """
        Hook invoked before retrying; may perform additional side work.

        Args:
            exception: The exception that was caught.
            retry_count: Current retry number.
        """
|
|
|
|
class ExponentialBackoffStrategy(RetryStrategy):
    """Retry strategy for connection errors, with exponentially growing delays."""

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Retry ``requests`` connection errors while the retry budget remains."""
        if not isinstance(exception, requests.exceptions.ConnectionError):
            return False
        return retry_count < max_retries

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return ``base_delay * 2**retry_count`` seconds."""
        growth = 2 ** retry_count
        return base_delay * growth

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """Emit a WARNING via a Logger, or via a ``log(message, level)`` callable."""
        text = f"连接错误,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}"
        if isinstance(logger, logging.Logger) or not callable(logger):
            logger.warning(text)
        else:
            # Plain logging function with a (message, level) signature.
            logger(text, "WARNING")

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """No extra work is needed before retrying a connection error."""
        return None
|
|
|
|
class LinearBackoffStrategy(RetryStrategy):
    """Retry strategy for request timeouts, with linearly growing delays."""

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Retry ``requests`` timeouts while ``retry_count`` is below ``max_retries``."""
        timed_out = isinstance(exception, requests.exceptions.Timeout)
        if timed_out:
            return retry_count < max_retries
        return False

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return ``base_delay * retry_count`` seconds."""
        delay = base_delay * retry_count
        return delay

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """Emit a WARNING via a Logger, or via a ``log(message, level)`` callable."""
        text = f"请求超时,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}"
        uses_log_function = callable(logger) and not isinstance(logger, logging.Logger)
        if uses_log_function:
            logger(text, "WARNING")
            return
        logger.warning(text)

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """No extra work is needed before retrying a timeout."""
        return None
|
|
|
|
class ServerErrorStrategy(RetryStrategy):
    """Retry strategy for HTTP 5xx server errors, with linearly growing delays."""

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """
        Retry an ``HTTPError`` whose response status is 5xx, while budget remains.

        Args:
            exception: The exception that was caught.
            retry_count: Number of retries performed so far.
            max_retries: Maximum number of retries allowed.

        Returns:
            bool: True for a 5xx HTTPError when ``retry_count < max_retries``.
        """
        if not isinstance(exception, requests.exceptions.HTTPError):
            return False

        response = getattr(exception, 'response', None)
        if response is None:
            return False

        return 500 <= response.status_code < 600 and retry_count < max_retries

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return ``base_delay * retry_count`` seconds."""
        return base_delay * retry_count

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """
        Log the upcoming retry, including the HTTP status code, at WARNING level.

        ``logger`` may be a ``logging.Logger`` or a callable taking
        ``(message, level)``.
        """
        response = getattr(exception, 'response', None)
        # Compare against None explicitly: requests.Response.__bool__ returns
        # ``response.ok``, which is False for every 4xx/5xx response, so a
        # truthiness test would always yield 'unknown' for the 5xx errors
        # this strategy handles.
        status_code = response.status_code if response is not None else 'unknown'
        message = f"服务器错误 {status_code},{delay:.1f}秒后重试 ({retry_count}/{max_retries})"

        if callable(logger) and not isinstance(logger, logging.Logger):
            logger(message, "WARNING")
        else:
            logger.warning(message)

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """No extra work is needed before retrying a server error."""
        pass
|
|
|
|
class RateLimitStrategy(RetryStrategy):
    """Retry strategy for HTTP 429 (rate limit) errors, with account switching.

    On the first 429, and on every third consecutive 429 after that,
    ``on_retry`` does the following:
    - puts the client's current account into cooldown;
    - moves the client to the next account returned by ``config``;
    - signs in again and creates a new session;
    - updates the per-user entry in ``config.config_instance.client_sessions``.

    Retries are issued without delay.
    """

    def __init__(self, client=None):
        """
        Initialize the rate-limit strategy.

        Args:
            client: API client instance whose account is switched on rate limiting.
        """
        self.client = client
        # Consecutive 429s seen by should_retry; reset by any other HTTPError it sees.
        self.consecutive_429_count = 0

    def _should_switch_account(self) -> bool:
        """Return True on the 1st consecutive 429 and on every 3rd one after that."""
        count = self.consecutive_429_count
        return count == 1 or (count > 0 and count % 3 == 0)

    def _client_log(self, message: str, level: str) -> None:
        """Forward ``message`` to ``client._log(message, level)`` if the client has one."""
        if hasattr(self.client, '_log'):
            self.client._log(message, level)

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """
        Decide whether a rate-limit error should be retried.

        Side effect: updates ``consecutive_429_count`` for every HTTPError that
        carries a response.

        Args:
            exception: The exception that was caught.
            retry_count: Number of retries performed so far.
            max_retries: Maximum number of retries allowed.

        Returns:
            bool: True for a 429 HTTPError while ``retry_count < max_retries``.
        """
        if not isinstance(exception, requests.exceptions.HTTPError):
            return False

        response = getattr(exception, 'response', None)
        if response is None:
            return False

        is_rate_limit = response.status_code == 429
        if is_rate_limit:
            self.consecutive_429_count += 1
        else:
            self.consecutive_429_count = 0

        # Bound 429 retries by max_retries like every other strategy. With a zero
        # delay, an unbounded loop would hammer the API forever whenever every
        # account stays rate limited.
        return is_rate_limit and retry_count < max_retries

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Retry immediately; recovery relies on switching accounts, not waiting."""
        return 0

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """
        Log the upcoming rate-limit retry at WARNING level.

        The message reflects what ``on_retry`` is about to do: switch account, or
        retry immediately with the same account. ``logger`` may be a
        ``logging.Logger`` or a callable taking ``(message, level)``.
        """
        if self._should_switch_account():
            message = "速率限制错误,尝试切换账号"
        else:
            message = f"连续第{self.consecutive_429_count}次速率限制错误,尝试立即重试"

        if callable(logger) and not isinstance(logger, logging.Logger):
            logger(message, "WARNING")
        else:
            logger.warning(message)

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """
        Switch the client to the next account when the 429 policy calls for it.

        Failures while signing in, creating a session or updating the session
        registry are logged through the client and not re-raised. By that point
        the client's credentials have already been replaced.

        Args:
            exception: The exception that was caught.
            retry_count: Current retry number.
        """
        user_identifier = getattr(self.client, '_associated_user_identifier', None)
        request_ip = getattr(self.client, '_associated_request_ip', None)

        if not self._should_switch_account():
            return
        if not (self.client and hasattr(self.client, 'email')):
            return

        current_email = self.client.email
        config.set_account_cooldown(current_email)

        # NOTE(review): assumes this returns an (email, password) pair with a
        # falsy email when no account is available — confirm in config.
        new_email, new_password = config.get_next_ondemand_account_details()
        if not new_email:
            return

        # Drop all credentials/session state tied to the old account.
        self.client.email = new_email
        self.client.password = new_password
        self.client.token = ""
        self.client.refresh_token = ""
        self.client.session_id = ""

        try:
            current_context_hash = getattr(self.client, '_current_request_context_hash', None)
            self.client.sign_in(context=current_context_hash)
            if self.client.create_session(external_context=current_context_hash):
                self._client_log(f"成功切换到账号 {new_email} 并使用上下文哈希 '{current_context_hash}' 重新登录和创建新会话。", "INFO")
                # A fresh session has no server-side history, so the next query must resend it all.
                setattr(self.client, '_new_session_requires_full_history', True)
                self._client_log("已设置 _new_session_requires_full_history = True,下次查询应发送完整历史。", "INFO")
            else:
                self._client_log(f"切换到账号 {new_email} 后,创建新会话失败。", "WARNING")
                setattr(self.client, '_new_session_requires_full_history', False)

            self._update_client_sessions(user_identifier, request_ip, current_email)
        except Exception as e:
            self._client_log(f"切换到账号 {new_email} 后登录或创建会话失败: {e}", "WARNING")

    def _update_client_sessions(self, user_identifier, request_ip, old_email) -> None:
        """Re-key this user's session entry from ``old_email`` to the client's new email."""
        from datetime import datetime

        if not user_identifier:
            self._client_log("RateLimitStrategy: _associated_user_identifier not found on client. Cannot update client_sessions.", "ERROR")
            return

        new_email = self.client.email

        with config.config_instance.client_sessions_lock:
            all_sessions = config.config_instance.client_sessions
            if user_identifier not in all_sessions:
                self._client_log(f"RateLimitStrategy: User '{user_identifier}' not found in client_sessions during update attempt.", "WARNING")
                return

            user_specific_sessions = all_sessions[user_identifier]

            if old_email in user_specific_sessions:
                del user_specific_sessions[old_email]
                self._client_log(f"RateLimitStrategy: Removed session for old email '{old_email}' for user '{user_identifier}'.", "INFO")

            # Prefer the IP captured from the client; otherwise reuse any IP
            # already recorded for the new email.
            ip_to_use = request_ip if request_ip else user_specific_sessions.get(new_email, {}).get("ip", "unknown_ip_in_retry_update")

            active_hash_for_new_session = getattr(self.client, '_current_request_context_hash', None)

            user_specific_sessions[new_email] = {
                "client": self.client,
                "active_context_hash": active_hash_for_new_session,
                "last_time": datetime.now(),
                "ip": ip_to_use
            }
            log_message_hash_part = f"set to '{active_hash_for_new_session}' (from client instance's _current_request_context_hash)" if active_hash_for_new_session is not None else "set to None (_current_request_context_hash not found on client instance)"
            self._client_log(f"RateLimitStrategy: Updated/added session for new email '{new_email}' for user '{user_identifier}'. active_context_hash {log_message_hash_part}.", "INFO")
| |
|
|
|
|
class RetryHandler:
    """Run an operation, retrying failures according to a list of strategies.

    Strategies are consulted in order; the first one whose ``should_retry``
    returns True handles the failure. ``retry_count`` is shared across all
    strategies for a single ``retry_operation`` call.
    """

    def __init__(self, client=None, logger=None,
                 max_retries: Optional[int] = None, base_delay: Optional[int] = None):
        """
        Initialize the retry handler.

        Args:
            client: API client instance, used by RateLimitStrategy to switch accounts.
            logger: A ``logging.Logger`` or a ``log(message, level)`` callable;
                defaults to this module's logger.
            max_retries: Optional override for the maximum retry count; ``None``
                reads ``config.get_config_value('max_retries')`` at call time.
            base_delay: Optional override for the base delay in seconds; ``None``
                reads ``config.get_config_value('retry_delay')`` at call time.
        """
        self.client = client
        self.logger = logger or logging.getLogger(__name__)
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.strategies = [
            ExponentialBackoffStrategy(),
            LinearBackoffStrategy(),
            ServerErrorStrategy(),
            RateLimitStrategy(client)
        ]

    def retry_operation(self, operation: Callable[..., T], *args, **kwargs) -> T:
        """
        Execute ``operation(*args, **kwargs)``, retrying per the strategies.

        Args:
            operation: The callable to execute.
            *args: Positional arguments for the operation.
            **kwargs: Keyword arguments for the operation.

        Returns:
            The operation's result.

        Raises:
            Exception: The last exception, once no strategy agrees to retry it.
        """
        max_retries = (self.max_retries if self.max_retries is not None
                       else config.get_config_value('max_retries'))
        base_delay = (self.base_delay if self.base_delay is not None
                      else config.get_config_value('retry_delay'))
        retry_count = 0

        while True:
            try:
                return operation(*args, **kwargs)
            except Exception as e:
                strategy = next(
                    (s for s in self.strategies if s.should_retry(e, retry_count, max_retries)),
                    None,
                )
                if strategy is None:
                    raise

                # Delay/log/hook receive the 1-based number of the upcoming retry.
                retry_count += 1
                delay = strategy.get_retry_delay(retry_count, base_delay)
                strategy.log_retry_attempt(self.logger, e, retry_count, max_retries, delay)
                strategy.on_retry(e, retry_count)

                if delay > 0:
                    time.sleep(delay)


def with_retry(max_retries: Optional[int] = None, retry_delay: Optional[int] = None):
    """
    Decorator that retries a client method through a RetryHandler.

    Args:
        max_retries: Maximum number of retries; ``None`` uses the config value.
            ``0`` disables retrying.
        retry_delay: Base retry delay in seconds; ``None`` uses the config value.

    Returns:
        The decorator.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # The decorated instance is the client (for account switching), and its
            # optional ``_log`` method serves as the logger. The decorator arguments
            # are forwarded so they actually take effect (None -> config).
            handler = RetryHandler(
                client=self,
                logger=getattr(self, '_log', None),
                max_retries=max_retries,
                base_delay=retry_delay,
            )
            return handler.retry_operation(functools.partial(func, self, *args, **kwargs))

        return wrapper

    return decorator