| import random |
| from fastapi import HTTPException, Request |
| import time |
| import re |
| from datetime import datetime, timedelta |
| from apscheduler.schedulers.background import BackgroundScheduler |
| import os |
| import requests |
| import httpx |
| from threading import Lock |
| import logging |
| import sys |
|
|
# Set DEBUG=true (case-insensitive) in the environment to get the verbose log layout.
DEBUG = os.getenv("DEBUG", "false").lower() == "true"

# Verbose layout: timestamp, level, request context, message and error detail.
LOG_FORMAT_DEBUG = (
    '%(asctime)s - %(levelname)s - '
    '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: '
    '%(message)s - %(error_message)s'
)
# Compact layout: request context and message only.
LOG_FORMAT_NORMAL = (
    '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: '
    '%(message)s'
)
|
|
| |
# Application logger. Call sites pass fully pre-rendered strings (built by
# format_log_message), so the handler keeps logging's default "%(message)s" formatter.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)

# Attach the console handler only once: re-importing or reloading this module
# must not stack a second handler and print every log line twice.
if logger.handlers:
    handler = logger.handlers[0]
else:
    handler = logging.StreamHandler()
    logger.addHandler(handler)
|
|
def format_log_message(level, message, extra=None):
    """Render one log line using LOG_FORMAT_DEBUG or LOG_FORMAT_NORMAL.

    Args:
        level: Level name substituted for ``%(levelname)s`` (e.g. 'INFO').
        message: Main log text, substituted for ``%(message)s``.
        extra: Optional mapping with any of 'key', 'request_type', 'model',
            'status_code', 'error_message'. Missing fields fall back to
            'N/A' ('' for 'error_message').

    Returns:
        The formatted string; the layout depends on the module-level DEBUG flag.
    """
    context = extra or {}
    now_text = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    fallbacks = (
        ('key', 'N/A'),
        ('request_type', 'N/A'),
        ('model', 'N/A'),
        ('status_code', 'N/A'),
        ('error_message', ''),
    )
    fields = {name: context.get(name, default) for name, default in fallbacks}
    fields['asctime'] = now_text
    fields['levelname'] = level
    fields['message'] = message
    template = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return template % fields
|
|
|
|
class APIKeyManager:
    """Hand out Gemini API keys from a shuffled stack.

    Keys are extracted from the GEMINI_API_KEYS environment variable. Each call
    to ``get_available_key`` pops the next key that has not been tried for the
    current request; when the stack is exhausted it is refilled with a fresh
    shuffle of all keys. ``reset_tried_keys_for_request`` clears the
    per-request "already tried" set.
    """

    def __init__(self):
        env_value = os.environ.get('GEMINI_API_KEYS', "")
        # Pick out every substring shaped like a Google API key: "AIzaSy" + 33 chars.
        self.api_keys = re.findall(r"AIzaSy[a-zA-Z0-9_-]{33}", env_value)
        self.key_stack = []
        self._reset_key_stack()

        # Background scheduler is started here; no jobs are added in this class
        # (presumably scheduled by callers — TODO confirm it is still needed).
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        self.tried_keys_for_request = set()

    def _reset_key_stack(self):
        """Replace the key stack with a freshly shuffled copy of all configured keys."""
        refill = list(self.api_keys)
        random.shuffle(refill)
        self.key_stack = refill

    def _pop_untried_key(self):
        """Pop keys until one not yet tried for this request turns up.

        The found key is recorded as tried and returned. Already-tried keys met
        on the way are dropped from the stack. Returns None once the stack is empty.
        """
        while self.key_stack:
            candidate = self.key_stack.pop()
            if candidate in self.tried_keys_for_request:
                continue
            self.tried_keys_for_request.add(candidate)
            return candidate
        return None

    def get_available_key(self):
        """Return the next key not yet tried for the current request, or None.

        If the stack runs dry it is reshuffled once from all keys. When no keys
        are configured at all, an error is logged and None is returned.
        """
        found = self._pop_untried_key()
        if found is not None:
            return found

        if not self.api_keys:
            logger.error(format_log_message('ERROR', "没有配置任何 API 密钥!"))
            return None

        self._reset_key_stack()
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log the number of configured keys followed by a masked form of each key."""
        logger.info(format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} "))
        for position, secret in enumerate(self.api_keys):
            masked = f"{secret[:8]}...{secret[-3:]}"
            logger.info(format_log_message('INFO', f"API Key{position}: {masked}"))

    def reset_tried_keys_for_request(self):
        """Start a new request attempt: forget which keys have already been tried."""
        self.tried_keys_for_request = set()
|
|
|
|
def _key_prefix(api_key):
    """Return the first 8 characters of *api_key* for log context ('N/A' if None)."""
    return 'N/A' if api_key is None else api_key[:8]


def _masked_key(api_key):
    """Return a log-safe "prefix ... suffix" rendering of *api_key* ('N/A' if None)."""
    if api_key is None:
        return 'N/A'
    return f"{api_key[:8]} ... {api_key[-3:]}"


def _log_api_error(level, api_key, status_code, error_message, text):
    """Log *text* at *level* ('ERROR' or 'WARNING') with key/status/error context."""
    extra = {'key': _key_prefix(api_key), 'status_code': status_code, 'error_message': error_message}
    log = logger.error if level == 'ERROR' else logger.warning
    log(format_log_message(level, text, extra=extra))


def _is_invalid_api_key(error_info):
    """Return True if a Google API ``error`` object reports a rejected API key.

    In Google's error model ``code`` is the numeric HTTP status (e.g. 400), and
    a bad key is identified by an ErrorInfo detail with
    ``reason == "API_KEY_INVALID"``. The old string comparison on ``code`` is
    kept for backward compatibility.
    """
    if error_info.get('code') == "invalid_argument":
        return True
    details = error_info.get('details')
    if not isinstance(details, list):
        return False
    return any(isinstance(d, dict) and d.get('reason') == 'API_KEY_INVALID' for d in details)


def _describe_bad_request(response, api_key, status_code):
    """Classify an HTTP 400 response, log it, and return the client-facing message."""
    try:
        error_data = response.json()
    except ValueError:
        error_message = "400 错误请求:响应不是有效的JSON格式"
        _log_api_error('WARNING', api_key, status_code, error_message, error_message)
        return error_message

    # Only a dict-shaped {"error": {...}} envelope carries usable details; other
    # JSON shapes previously crashed (TypeError/AttributeError) or returned None.
    error_info = error_data.get('error') if isinstance(error_data, dict) else None
    if not isinstance(error_info, dict):
        error_info = {}

    if _is_invalid_api_key(error_info):
        error_message = "无效的 API 密钥"
        _log_api_error('ERROR', api_key, status_code, error_message,
                       f"{_masked_key(api_key)} → 无效,可能已过期或被删除")
        return error_message

    error_message = error_info.get('message', 'Bad Request')
    _log_api_error('WARNING', api_key, status_code, error_message, f"400 错误请求: {error_message}")
    return f"400 错误请求: {error_message}"


def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Log a failed Gemini API call and return a short (Chinese) error description.

    Args:
        error: Exception raised while calling the API. ``requests`` HTTP,
            connection and timeout errors get specific messages; anything else
            is reported as an unknown error.
        current_api_key: Key used for the failed call. May be None; it only
            appears in logs, masked.
        key_manager: Accepted for interface compatibility; not used here.

    Returns:
        A description string; always a ``str``, never None.
    """
    # Compare to None explicitly: a requests.Response for a 4xx/5xx status is
    # falsy, so a truthiness test would skip exactly the responses we handle.
    if isinstance(error, requests.exceptions.HTTPError) and error.response is not None:
        status_code = error.response.status_code
        if status_code == 400:
            return _describe_bad_request(error.response, current_api_key, status_code)

        masked = _masked_key(current_api_key)
        if status_code == 429:
            error_message = "API 密钥配额已用尽或其他原因"
            _log_api_error('WARNING', current_api_key, status_code, error_message,
                           f"{masked} → 429 官方资源耗尽或其他原因")
            return error_message
        if status_code == 403:
            error_message = "权限被拒绝"
            _log_api_error('ERROR', current_api_key, status_code, error_message,
                           f"{masked} → 403 权限被拒绝")
            return error_message
        if status_code == 500:
            error_message = "服务器内部错误"
            _log_api_error('WARNING', current_api_key, status_code, error_message,
                           f"{masked} → 500 服务器内部错误")
            return "Gemini API 内部错误"
        if status_code == 503:
            error_message = "服务不可用"
            _log_api_error('WARNING', current_api_key, status_code, error_message,
                           f"{masked} → 503 服务不可用")
            return "Gemini API 服务不可用"

        error_message = f"未知错误: {status_code}"
        _log_api_error('WARNING', current_api_key, status_code, error_message,
                       f"{masked} → {status_code} 未知错误")
        return f"未知错误/模型不可用: {status_code}"

    # ConnectionError is checked first, so ConnectTimeout (a subclass of both) is a connection error.
    if isinstance(error, requests.exceptions.ConnectionError):
        error_message = "连接错误"
        logger.warning(format_log_message('WARNING', error_message, extra={'error_message': error_message}))
        return error_message

    if isinstance(error, requests.exceptions.Timeout):
        error_message = "请求超时"
        logger.warning(format_log_message('WARNING', error_message, extra={'error_message': error_message}))
        return error_message

    error_message = f"发生未知错误: {error}"
    logger.error(format_log_message('ERROR', error_message, extra={'error_message': error_message}))
    return error_message
|
|
|
|
async def test_api_key(api_key: str) -> bool:
    """Check whether *api_key* can list Gemini models.

    Any failure (empty key, network error, non-2xx status, ...) is reported as
    False rather than raised.

    Returns:
        True if GET .../v1beta/models succeeded with this key, otherwise False.
    """
    if not api_key:
        # Nothing to test; skip a network round-trip that cannot succeed.
        return False

    url = "https://generativelanguage.googleapis.com/v1beta/models"
    # Send the key in the x-goog-api-key header rather than the query string, so
    # it does not appear in request URLs that httpx may log or put into
    # exception messages.
    headers = {"x-goog-api-key": api_key}
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
    except Exception:
        # Deliberate best-effort probe: any failure just means "not usable".
        return False
    return True
|
|
|
|
# Shared fixed-window counters used by protect_from_abuse:
# window key -> (request count, epoch second when the window was first seen).
rate_limit_data = dict()
# Serializes every read-modify-write of rate_limit_data across threads.
rate_limit_lock = Lock()
|
|
|
|
def _purge_stale_rate_limit_entries(now, max_age=60 * 60 * 24):
    """Drop rate-limit entries first seen at least *max_age* seconds before *now*.

    The longest window is one day, so such entries can never be read again.
    Caller must hold ``rate_limit_lock``.
    """
    stale = [key for key, (_, first_seen) in rate_limit_data.items() if now - first_seen >= max_age]
    for key in stale:
        del rate_limit_data[key]


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Apply fixed-window rate limits and raise HTTP 429 when one is exceeded.

    Two counters live in the module-level ``rate_limit_data``:
      * a per-minute counter keyed by request path, shared by all clients;
      * a per-day counter keyed by client host.
    Both counters are incremented before the limits are checked, so rejected
    requests still count toward their windows.

    Args:
        request: Incoming request; its URL path and client host are used.
        max_requests_per_minute: Limit per path per minute window.
        max_requests_per_day_per_ip: Limit per client host per day window.

    Raises:
        HTTPException: status 429 when either limit is exceeded.
    """
    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)

    # request.client is None when the ASGI scope carries no client address.
    client_host = request.client.host if request.client is not None else "unknown"
    minute_key = f"{request.url.path}:{minute}"
    day_key = f"{client_host}:{day}"

    with rate_limit_lock:
        if minute_key not in rate_limit_data:
            # First hit of a new minute window for this path: sweep expired
            # windows so the dict does not grow without bound.
            _purge_stale_rate_limit_entries(now)

        # The window index is part of each key, so a new window always starts
        # at count 0. The stored timestamp only records when the window was
        # first seen, which is what the purge above uses.
        minute_count, minute_first_seen = rate_limit_data.get(minute_key, (0, now))
        minute_count += 1
        rate_limit_data[minute_key] = (minute_count, minute_first_seen)

        day_count, day_first_seen = rate_limit_data.get(day_key, (0, now))
        day_count += 1
        rate_limit_data[day_key] = (day_count, day_first_seen)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})