| """Load balancing module""" |
| import random |
| from typing import Optional |
| from ..core.models import Token |
| from ..core.config import config |
| from .token_manager import TokenManager |
| from .token_lock import TokenLock |
|
|
class LoadBalancer:
    """Token load balancer with random selection and image generation lock."""

    # Active tokens whose expiry_time is at most this many hours away (including
    # tokens that are already past expiry) are sent for auto-refresh.
    _AUTO_REFRESH_WINDOW_HOURS = 24

    def __init__(self, token_manager: TokenManager):
        """Create a balancer that selects among tokens from *token_manager*.

        Args:
            token_manager: Supplies the token lists and performs auto-refresh
                and Sora2 quota refresh on request.
        """
        self.token_manager = token_manager
        # Lock used to exclude tokens busy with image generation; its timeout
        # is read from config.image_timeout once, at construction time.
        self.token_lock = TokenLock(lock_timeout=config.image_timeout)

    async def select_token(
        self,
        for_image_generation: bool = False,
        for_video_generation: bool = False,
    ) -> Optional[Token]:
        """Select a token using random load balancing.

        If ``config.at_auto_refresh_enabled`` is set, expiring tokens are
        refreshed first. Candidates are then the token manager's active
        tokens, narrowed by the requested filters. When both flags are set,
        the video filter runs first and the image filter runs on its result.

        Args:
            for_image_generation: If True, only select tokens with
                ``image_enabled`` set that are not currently locked for image
                generation. This only checks the lock; it does not acquire it.
            for_video_generation: If True, only select tokens with
                ``video_enabled`` and ``sora2_supported`` set whose
                ``sora2_cooldown_until`` is unset or already passed (a lapsed
                cooldown triggers a quota refresh and a re-read of the token).

        Returns:
            A uniformly random token from the remaining candidates, or None
            if no candidate remains.
        """
        if config.at_auto_refresh_enabled:
            await self._refresh_expiring_tokens()

        candidates = await self.token_manager.get_active_tokens()
        if not candidates:
            return None

        if for_video_generation:
            candidates = await self._filter_video_tokens(candidates)
            if not candidates:
                return None

        if for_image_generation:
            candidates = await self._filter_image_tokens(candidates)
            if not candidates:
                return None

        return random.choice(candidates)

    async def _refresh_expiring_tokens(self) -> None:
        """Request auto-refresh for active tokens close to (or past) expiry."""
        from datetime import datetime

        for token in await self.token_manager.get_all_tokens():
            if not (token.is_active and token.expiry_time):
                continue
            # NOTE(review): assumes expiry_time is a naive datetime comparable
            # with datetime.now(); an aware value would raise TypeError here.
            time_until_expiry = token.expiry_time - datetime.now()
            hours_until_expiry = time_until_expiry.total_seconds() / 3600
            if hours_until_expiry <= self._AUTO_REFRESH_WINDOW_HOURS:
                await self.token_manager.auto_refresh_expiring_token(token.id)

    async def _filter_video_tokens(self, tokens: list) -> list:
        """Return the tokens from *tokens* usable for Sora2 video generation."""
        from datetime import datetime

        eligible = []
        for token in tokens:
            if not token.video_enabled or not token.sora2_supported:
                continue

            cooldown = token.sora2_cooldown_until
            if cooldown and cooldown <= datetime.now():
                # The cooldown has lapsed: let the manager refresh the remaining
                # quota, then re-read the token so the check below sees the
                # updated record (a missing record is skipped).
                await self.token_manager.refresh_sora2_remaining_if_cooldown_expired(token.id)
                token = await self.token_manager.db.get_token(token.id)
                if not token:
                    continue

            if token.sora2_cooldown_until and token.sora2_cooldown_until > datetime.now():
                continue  # still cooling down (quota exhausted)

            eligible.append(token)
        return eligible

    async def _filter_image_tokens(self, tokens: list) -> list:
        """Return image-enabled tokens from *tokens* not locked for image generation."""
        # image_enabled is checked first so the lock is only queried when needed.
        return [
            token
            for token in tokens
            if token.image_enabled and not await self.token_lock.is_locked(token.id)
        ]
|
|