vn6295337's picture
Add schema normalizers for MCP output consistency
26c5c2f
"""
Rate Limiter - Token bucket and sliding window implementations for API rate limiting.
Prevents self-DoS during stress testing by enforcing per-API rate limits.
"""
import time
import asyncio
from typing import Dict, Optional
from dataclasses import dataclass, field
from collections import deque
import threading
@dataclass
class TokenBucket:
"""Token bucket rate limiter.
Allows bursting up to capacity, refills at steady rate.
Thread-safe implementation.
"""
rate: float # Tokens per second
capacity: int # Maximum tokens (burst capacity)
tokens: float = field(init=False)
last_update: float = field(init=False)
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
def __post_init__(self):
self.tokens = float(self.capacity)
self.last_update = time.monotonic()
def _refill(self):
"""Refill tokens based on elapsed time."""
now = time.monotonic()
elapsed = now - self.last_update
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
self.last_update = now
def acquire(self, tokens: int = 1) -> bool:
"""Try to acquire tokens. Returns True if successful."""
with self._lock:
self._refill()
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
async def acquire_async(self, tokens: int = 1, timeout: float = 30.0) -> bool:
"""Async version - waits until tokens available or timeout."""
start = time.monotonic()
while time.monotonic() - start < timeout:
if self.acquire(tokens):
return True
# Wait for estimated refill time
wait_time = min(0.1, (tokens - self.tokens) / self.rate)
await asyncio.sleep(max(0.01, wait_time))
return False
def tokens_available(self) -> float:
"""Get current available tokens (without modifying state)."""
with self._lock:
self._refill()
return self.tokens
@dataclass
class SlidingWindowLimiter:
"""Sliding window rate limiter.
Tracks requests in a time window, more accurate than token bucket
for strict rate limits.
"""
max_requests: int # Maximum requests in window
window_seconds: float # Window duration
_requests: deque = field(default_factory=deque, repr=False)
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
def _cleanup(self):
"""Remove expired timestamps from window."""
cutoff = time.monotonic() - self.window_seconds
while self._requests and self._requests[0] < cutoff:
self._requests.popleft()
def acquire(self) -> bool:
"""Try to acquire a request slot. Returns True if allowed."""
with self._lock:
self._cleanup()
if len(self._requests) < self.max_requests:
self._requests.append(time.monotonic())
return True
return False
async def acquire_async(self, timeout: float = 30.0) -> bool:
"""Async version - waits until slot available or timeout."""
start = time.monotonic()
while time.monotonic() - start < timeout:
if self.acquire():
return True
# Estimate wait time until oldest request expires
with self._lock:
if self._requests:
oldest = self._requests[0]
wait_time = max(0.01, oldest + self.window_seconds - time.monotonic())
else:
wait_time = 0.01
await asyncio.sleep(min(0.5, wait_time))
return False
def requests_in_window(self) -> int:
"""Get current request count in window."""
with self._lock:
self._cleanup()
return len(self._requests)
class DailyQuotaTracker:
"""Tracks daily API quota usage.
For APIs with daily limits (NYT: 500/day, NewsAPI: 100/day).
"""
def __init__(self, daily_limit: int, name: str = "api"):
self.daily_limit = daily_limit
self.name = name
self.used = 0
self.reset_date = self._current_date()
self._lock = threading.Lock()
def _current_date(self) -> str:
return time.strftime("%Y-%m-%d")
def _check_reset(self):
"""Reset counter if day changed."""
today = self._current_date()
if today != self.reset_date:
self.used = 0
self.reset_date = today
def acquire(self, count: int = 1) -> bool:
"""Try to use quota. Returns True if within limit."""
with self._lock:
self._check_reset()
if self.used + count <= self.daily_limit:
self.used += count
return True
return False
def remaining(self) -> int:
"""Get remaining quota for today."""
with self._lock:
self._check_reset()
return max(0, self.daily_limit - self.used)
class RateLimiterRegistry:
"""Registry of rate limiters for different APIs.
Centralizes rate limit configuration and provides unified access.
"""
def __init__(self):
self.limiters: Dict[str, TokenBucket | SlidingWindowLimiter] = {}
self.quotas: Dict[str, DailyQuotaTracker] = {}
self._setup_defaults()
def _setup_defaults(self):
"""Configure default rate limiters based on known API limits."""
# Token bucket limiters (burst-friendly)
self.limiters["sec_edgar"] = TokenBucket(rate=10, capacity=10)
self.limiters["yahoo_finance"] = TokenBucket(rate=5, capacity=20)
self.limiters["finnhub"] = TokenBucket(rate=1, capacity=5)
# Sliding window limiters (strict limits)
self.limiters["fred"] = SlidingWindowLimiter(max_requests=120, window_seconds=60)
self.limiters["reddit"] = SlidingWindowLimiter(max_requests=100, window_seconds=60)
# Daily quota trackers
self.quotas["nyt"] = DailyQuotaTracker(daily_limit=500, name="NYT")
self.quotas["newsapi"] = DailyQuotaTracker(daily_limit=100, name="NewsAPI")
self.quotas["tavily"] = DailyQuotaTracker(daily_limit=33, name="Tavily") # ~1000/month
def get_limiter(self, api: str) -> Optional[TokenBucket | SlidingWindowLimiter]:
"""Get rate limiter for an API."""
return self.limiters.get(api.lower())
def get_quota(self, api: str) -> Optional[DailyQuotaTracker]:
"""Get quota tracker for an API."""
return self.quotas.get(api.lower())
async def acquire(self, api: str, timeout: float = 30.0) -> bool:
"""Acquire rate limit and quota for an API.
Returns True if both rate limit and quota allow the request.
"""
api_lower = api.lower()
# Check daily quota first (faster to reject)
quota = self.quotas.get(api_lower)
if quota and not quota.acquire():
return False
# Then check rate limiter
limiter = self.limiters.get(api_lower)
if limiter:
return await limiter.acquire_async(timeout=timeout)
# No limiter configured - allow by default
return True
def status(self) -> Dict:
"""Get status of all rate limiters and quotas."""
status = {"limiters": {}, "quotas": {}}
for name, limiter in self.limiters.items():
if isinstance(limiter, TokenBucket):
status["limiters"][name] = {
"type": "token_bucket",
"available": limiter.tokens_available(),
"capacity": limiter.capacity
}
elif isinstance(limiter, SlidingWindowLimiter):
status["limiters"][name] = {
"type": "sliding_window",
"used": limiter.requests_in_window(),
"max": limiter.max_requests
}
for name, quota in self.quotas.items():
status["quotas"][name] = {
"remaining": quota.remaining(),
"daily_limit": quota.daily_limit
}
return status
# Global registry instance
_registry: Optional[RateLimiterRegistry] = None
def get_rate_limiter_registry() -> RateLimiterRegistry:
"""Get the global rate limiter registry."""
global _registry
if _registry is None:
_registry = RateLimiterRegistry()
return _registry
if __name__ == "__main__":
import asyncio
async def demo():
registry = get_rate_limiter_registry()
print("Initial status:", registry.status())
# Simulate some API calls
for api in ["sec_edgar", "fred", "nyt"]:
result = await registry.acquire(api)
print(f"{api}: {'allowed' if result else 'blocked'}")
print("\nAfter requests:", registry.status())
asyncio.run(demo())