Spaces:
Sleeping
Sleeping
| """ | |
| Rate Limiter - Token bucket and sliding window implementations for API rate limiting. | |
| Prevents self-DoS during stress testing by enforcing per-API rate limits. | |
| """ | |
| import time | |
| import asyncio | |
| from typing import Dict, Optional | |
| from dataclasses import dataclass, field | |
| from collections import deque | |
| import threading | |
| class TokenBucket: | |
| """Token bucket rate limiter. | |
| Allows bursting up to capacity, refills at steady rate. | |
| Thread-safe implementation. | |
| """ | |
| rate: float # Tokens per second | |
| capacity: int # Maximum tokens (burst capacity) | |
| tokens: float = field(init=False) | |
| last_update: float = field(init=False) | |
| _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) | |
| def __post_init__(self): | |
| self.tokens = float(self.capacity) | |
| self.last_update = time.monotonic() | |
| def _refill(self): | |
| """Refill tokens based on elapsed time.""" | |
| now = time.monotonic() | |
| elapsed = now - self.last_update | |
| self.tokens = min(self.capacity, self.tokens + elapsed * self.rate) | |
| self.last_update = now | |
| def acquire(self, tokens: int = 1) -> bool: | |
| """Try to acquire tokens. Returns True if successful.""" | |
| with self._lock: | |
| self._refill() | |
| if self.tokens >= tokens: | |
| self.tokens -= tokens | |
| return True | |
| return False | |
| async def acquire_async(self, tokens: int = 1, timeout: float = 30.0) -> bool: | |
| """Async version - waits until tokens available or timeout.""" | |
| start = time.monotonic() | |
| while time.monotonic() - start < timeout: | |
| if self.acquire(tokens): | |
| return True | |
| # Wait for estimated refill time | |
| wait_time = min(0.1, (tokens - self.tokens) / self.rate) | |
| await asyncio.sleep(max(0.01, wait_time)) | |
| return False | |
| def tokens_available(self) -> float: | |
| """Get current available tokens (without modifying state).""" | |
| with self._lock: | |
| self._refill() | |
| return self.tokens | |
| class SlidingWindowLimiter: | |
| """Sliding window rate limiter. | |
| Tracks requests in a time window, more accurate than token bucket | |
| for strict rate limits. | |
| """ | |
| max_requests: int # Maximum requests in window | |
| window_seconds: float # Window duration | |
| _requests: deque = field(default_factory=deque, repr=False) | |
| _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) | |
| def _cleanup(self): | |
| """Remove expired timestamps from window.""" | |
| cutoff = time.monotonic() - self.window_seconds | |
| while self._requests and self._requests[0] < cutoff: | |
| self._requests.popleft() | |
| def acquire(self) -> bool: | |
| """Try to acquire a request slot. Returns True if allowed.""" | |
| with self._lock: | |
| self._cleanup() | |
| if len(self._requests) < self.max_requests: | |
| self._requests.append(time.monotonic()) | |
| return True | |
| return False | |
| async def acquire_async(self, timeout: float = 30.0) -> bool: | |
| """Async version - waits until slot available or timeout.""" | |
| start = time.monotonic() | |
| while time.monotonic() - start < timeout: | |
| if self.acquire(): | |
| return True | |
| # Estimate wait time until oldest request expires | |
| with self._lock: | |
| if self._requests: | |
| oldest = self._requests[0] | |
| wait_time = max(0.01, oldest + self.window_seconds - time.monotonic()) | |
| else: | |
| wait_time = 0.01 | |
| await asyncio.sleep(min(0.5, wait_time)) | |
| return False | |
| def requests_in_window(self) -> int: | |
| """Get current request count in window.""" | |
| with self._lock: | |
| self._cleanup() | |
| return len(self._requests) | |
| class DailyQuotaTracker: | |
| """Tracks daily API quota usage. | |
| For APIs with daily limits (NYT: 500/day, NewsAPI: 100/day). | |
| """ | |
| def __init__(self, daily_limit: int, name: str = "api"): | |
| self.daily_limit = daily_limit | |
| self.name = name | |
| self.used = 0 | |
| self.reset_date = self._current_date() | |
| self._lock = threading.Lock() | |
| def _current_date(self) -> str: | |
| return time.strftime("%Y-%m-%d") | |
| def _check_reset(self): | |
| """Reset counter if day changed.""" | |
| today = self._current_date() | |
| if today != self.reset_date: | |
| self.used = 0 | |
| self.reset_date = today | |
| def acquire(self, count: int = 1) -> bool: | |
| """Try to use quota. Returns True if within limit.""" | |
| with self._lock: | |
| self._check_reset() | |
| if self.used + count <= self.daily_limit: | |
| self.used += count | |
| return True | |
| return False | |
| def remaining(self) -> int: | |
| """Get remaining quota for today.""" | |
| with self._lock: | |
| self._check_reset() | |
| return max(0, self.daily_limit - self.used) | |
| class RateLimiterRegistry: | |
| """Registry of rate limiters for different APIs. | |
| Centralizes rate limit configuration and provides unified access. | |
| """ | |
| def __init__(self): | |
| self.limiters: Dict[str, TokenBucket | SlidingWindowLimiter] = {} | |
| self.quotas: Dict[str, DailyQuotaTracker] = {} | |
| self._setup_defaults() | |
| def _setup_defaults(self): | |
| """Configure default rate limiters based on known API limits.""" | |
| # Token bucket limiters (burst-friendly) | |
| self.limiters["sec_edgar"] = TokenBucket(rate=10, capacity=10) | |
| self.limiters["yahoo_finance"] = TokenBucket(rate=5, capacity=20) | |
| self.limiters["finnhub"] = TokenBucket(rate=1, capacity=5) | |
| # Sliding window limiters (strict limits) | |
| self.limiters["fred"] = SlidingWindowLimiter(max_requests=120, window_seconds=60) | |
| self.limiters["reddit"] = SlidingWindowLimiter(max_requests=100, window_seconds=60) | |
| # Daily quota trackers | |
| self.quotas["nyt"] = DailyQuotaTracker(daily_limit=500, name="NYT") | |
| self.quotas["newsapi"] = DailyQuotaTracker(daily_limit=100, name="NewsAPI") | |
| self.quotas["tavily"] = DailyQuotaTracker(daily_limit=33, name="Tavily") # ~1000/month | |
| def get_limiter(self, api: str) -> Optional[TokenBucket | SlidingWindowLimiter]: | |
| """Get rate limiter for an API.""" | |
| return self.limiters.get(api.lower()) | |
| def get_quota(self, api: str) -> Optional[DailyQuotaTracker]: | |
| """Get quota tracker for an API.""" | |
| return self.quotas.get(api.lower()) | |
| async def acquire(self, api: str, timeout: float = 30.0) -> bool: | |
| """Acquire rate limit and quota for an API. | |
| Returns True if both rate limit and quota allow the request. | |
| """ | |
| api_lower = api.lower() | |
| # Check daily quota first (faster to reject) | |
| quota = self.quotas.get(api_lower) | |
| if quota and not quota.acquire(): | |
| return False | |
| # Then check rate limiter | |
| limiter = self.limiters.get(api_lower) | |
| if limiter: | |
| return await limiter.acquire_async(timeout=timeout) | |
| # No limiter configured - allow by default | |
| return True | |
| def status(self) -> Dict: | |
| """Get status of all rate limiters and quotas.""" | |
| status = {"limiters": {}, "quotas": {}} | |
| for name, limiter in self.limiters.items(): | |
| if isinstance(limiter, TokenBucket): | |
| status["limiters"][name] = { | |
| "type": "token_bucket", | |
| "available": limiter.tokens_available(), | |
| "capacity": limiter.capacity | |
| } | |
| elif isinstance(limiter, SlidingWindowLimiter): | |
| status["limiters"][name] = { | |
| "type": "sliding_window", | |
| "used": limiter.requests_in_window(), | |
| "max": limiter.max_requests | |
| } | |
| for name, quota in self.quotas.items(): | |
| status["quotas"][name] = { | |
| "remaining": quota.remaining(), | |
| "daily_limit": quota.daily_limit | |
| } | |
| return status | |
| # Global registry instance | |
| _registry: Optional[RateLimiterRegistry] = None | |
| def get_rate_limiter_registry() -> RateLimiterRegistry: | |
| """Get the global rate limiter registry.""" | |
| global _registry | |
| if _registry is None: | |
| _registry = RateLimiterRegistry() | |
| return _registry | |
| if __name__ == "__main__": | |
| import asyncio | |
| async def demo(): | |
| registry = get_rate_limiter_registry() | |
| print("Initial status:", registry.status()) | |
| # Simulate some API calls | |
| for api in ["sec_edgar", "fred", "nyt"]: | |
| result = await registry.acquire(api) | |
| print(f"{api}: {'allowed' if result else 'blocked'}") | |
| print("\nAfter requests:", registry.status()) | |
| asyncio.run(demo()) | |