blog2code-api / codes /rate_limiter.py
srishtichugh's picture
initial commit
2fd8593
raw
history blame
3.46 kB
"""
Rate Limiter for OpenAI API to avoid hitting TPM (tokens per minute) limits.
"""
import time
from typing import List, Tuple
class RateLimiter:
    """Smart rate limiter that tracks token usage and sleeps only when necessary."""

    def __init__(self, max_tokens_per_minute: int = 95000, buffer: int = 5000):
        """
        Initialize the rate limiter.

        Args:
            max_tokens_per_minute: Maximum tokens allowed per minute (default: 95K for safety)
            buffer: Safety buffer to stay under limit (default: 5K)
        """
        # Effective budget for the 60-second sliding window.
        self.max_tokens = max_tokens_per_minute - buffer
        # Recent usage as (timestamp, tokens) pairs, oldest first.
        self.tokens_used: List[Tuple[float, int]] = []
        self.total_waits = 0
        self.total_wait_time = 0.0

    def _prune(self, now: float) -> None:
        """Drop usage records older than the 60-second sliding window."""
        self.tokens_used = [
            (ts, tok) for ts, tok in self.tokens_used
            if now - ts < 60
        ]

    def wait_if_needed(self, tokens_needed: int) -> None:
        """
        Check if we need to wait before making the next API call, sleeping
        just long enough for the sliding window to free capacity, then
        record the upcoming request.

        Args:
            tokens_needed: Estimated tokens for the next API call
        """
        now = time.time()
        self._prune(now)

        # Tokens consumed within the last 60 seconds.
        tokens_in_window = sum(tok for _, tok in self.tokens_used)

        # Wait only when *prior* usage pushes us over budget. If the window
        # is empty, waiting cannot help (a single oversized request), so we
        # proceed and let the API enforce its own limit. This guard also
        # fixes an IndexError the original code hit when tokens_used was
        # empty while tokens_needed alone exceeded max_tokens.
        if self.tokens_used and tokens_in_window + tokens_needed > self.max_tokens:
            # Sleep until the oldest record ages out of the window, +1s for
            # safety; clamp at 0 in case of clock skew.
            oldest_timestamp = self.tokens_used[0][0]
            wait_time = max(0.0, 60 - (now - oldest_timestamp) + 1)
            print(f"⏰ Rate limit approaching ({tokens_in_window + tokens_needed}/{self.max_tokens} tokens)")
            print(f" Waiting {wait_time:.1f}s for rate limit window to reset...")
            time.sleep(wait_time)
            self.total_waits += 1
            self.total_wait_time += wait_time
            # Re-prune: entries have aged while we slept.
            now = time.time()
            self._prune(now)

        # Record this request against the window.
        self.tokens_used.append((now, tokens_needed))

    def get_stats(self) -> dict:
        """Get statistics about rate limiting.

        Note: current_window_tokens sums all recorded entries without
        re-pruning, so it may include entries slightly older than 60s.
        """
        return {
            'total_waits': self.total_waits,
            'total_wait_time': self.total_wait_time,
            'current_window_tokens': sum(tok for _, tok in self.tokens_used)
        }

    def print_stats(self) -> None:
        """Print rate limiting statistics."""
        stats = self.get_stats()
        print("\n" + "="*50)
        print("πŸ“Š Rate Limiter Statistics")
        print("="*50)
        print(f"Total waits: {stats['total_waits']}")
        print(f"Total wait time: {stats['total_wait_time']:.1f}s")
        print(f"Current window usage: {stats['current_window_tokens']} tokens")
        print("="*50 + "\n")
def estimate_tokens(text: str, overhead: int = 800) -> int:
    """
    Estimate tokens for a text string.

    Uses the rough heuristic of one token per four characters, plus a
    fixed allowance for system prompts and formatting.

    Args:
        text: Input text
        overhead: Additional tokens for system prompts, formatting, etc.

    Returns:
        Estimated token count
    """
    # str() guards against non-string inputs; ~4 chars β‰ˆ 1 token.
    return len(str(text)) // 4 + overhead