import { redisIncr, redisExpire, redisTTL } from './redis'

/** A fixed-window limit: at most `max` requests per `window` seconds. */
interface RateLimitConfig {
  max: number
  /** Window length in seconds. */
  window: number
}

// Rate limit configurations per caller tier.
const RATE_LIMITS: Record<'guest' | 'user' | 'pro', RateLimitConfig> = {
  guest: {
    max: 10,
    window: 3600, // 1 hour in seconds
  },
  user: {
    max: 50,
    window: 3600,
  },
  pro: {
    max: 500,
    window: 3600,
  },
}

// Prefix-specific overrides (for specialized endpoints). When an identifier
// starts with "<prefix>:" and that prefix is listed here, this limit replaces
// the tier limit from RATE_LIMITS.
const PREFIX_LIMITS: Record<string, RateLimitConfig> = {
  'search-semantic': { max: 10, window: 3600 }, // 10 semantic searches/hr
  'search-semantic-burst': { max: 3, window: 60 }, // 3 per minute burst
  'search': { max: 60, window: 3600 }, // 60 regular searches/hr
  'prompts': { max: 30, window: 3600 }, // 30 prompt creations/hr
}

/** In-memory counter state for one key. */
interface MemoryCounter {
  count: number
  /** Epoch milliseconds at which this counter's window ends. */
  resetAt: number
}

// In-memory fallback when Redis is unavailable. This is per-process state, so
// limits are only enforced per instance while Redis is down.
const memoryCounters = new Map<string, MemoryCounter>()
let lastCleanup = Date.now()

// Sweep expired counters at least this often.
const MEMORY_CLEANUP_INTERVAL_MS = 300_000
// Also sweep early once the map grows past this many keys...
const MEMORY_CLEANUP_SIZE_THRESHOLD = 5000
// ...but no more than once per this interval, so a map full of *live* keys does
// not trigger a full O(n) scan on every single request.
const MEMORY_CLEANUP_MIN_GAP_MS = 1_000

/** Remove every counter whose window has ended and record the sweep time. */
function sweepExpiredMemoryCounters(now: number): void {
  for (const [k, v] of memoryCounters.entries()) {
    // Deleting the current entry while iterating a Map is safe in JS.
    if (v.resetAt <= now) memoryCounters.delete(k)
  }
  lastCleanup = now
}

/**
 * Increment the in-memory counter for `key` using a fixed window.
 *
 * @param key Counter key.
 * @param max Maximum allowed count per window.
 * @param window Window length in seconds.
 * @returns The count after this increment and the epoch-ms time the current
 *   window ends.
 */
function checkMemoryLimit(
  key: string,
  max: number,
  window: number,
): { count: number; allowed: boolean; resetAt: number } {
  const now = Date.now()

  const sinceCleanup = now - lastCleanup
  const sizeTriggered =
    memoryCounters.size > MEMORY_CLEANUP_SIZE_THRESHOLD && sinceCleanup > MEMORY_CLEANUP_MIN_GAP_MS
  if (sinceCleanup > MEMORY_CLEANUP_INTERVAL_MS || sizeTriggered) {
    sweepExpiredMemoryCounters(now)
  }

  const entry = memoryCounters.get(key)
  if (!entry || entry.resetAt <= now) {
    // Start a fresh window.
    const resetAt = now + window * 1000
    memoryCounters.set(key, { count: 1, resetAt })
    return { count: 1, allowed: 1 <= max, resetAt }
  }

  entry.count += 1
  return { count: entry.count, allowed: entry.count <= max, resetAt: entry.resetAt }
}

export interface RateLimitResult {
  success: boolean
  limit: number
  remaining: number
  /** Epoch milliseconds at which the current window resets. */
  reset: number
  error?: string
}

/**
 * Pick the limit for an identifier: a PREFIX_LIMITS override if the text before
 * the first ":" matches one, otherwise the tier limit.
 */
function resolveConfig(identifier: string, isAuthenticated: boolean, isPro: boolean): RateLimitConfig {
  // Extract prefix (e.g., "search-semantic:user:abc123" → "search-semantic")
  const colonIndex = identifier.indexOf(':')
  if (colonIndex > 0) {
    const prefix = identifier.substring(0, colonIndex)
    // Own-property check: a plain index would also match Object.prototype members
    // such as "constructor", yielding a config with no numeric max.
    const override = Object.prototype.hasOwnProperty.call(PREFIX_LIMITS, prefix)
      ? PREFIX_LIMITS[prefix]
      : undefined
    if (override) return override
  }
  if (isPro) return RATE_LIMITS.pro
  return isAuthenticated ? RATE_LIMITS.user : RATE_LIMITS.guest
}

/**
 * Build the result for a request that brought the counter to `count`.
 *
 * @param resetAt Epoch ms at which the window ends.
 * @param now Current epoch ms, used to compute the retry hint.
 */
function buildResult(config: RateLimitConfig, count: number, resetAt: number, now: number): RateLimitResult {
  if (count > config.max) {
    // Round up so the hint never under-reports, and never say "0 seconds".
    const retryAfterSeconds = Math.max(1, Math.ceil((resetAt - now) / 1000))
    return {
      success: false,
      limit: config.max,
      remaining: 0,
      reset: resetAt,
      error: `Rate limit exceeded. Try again in ${retryAfterSeconds} seconds.`,
    }
  }
  return {
    success: true,
    limit: config.max,
    remaining: Math.max(0, config.max - count),
    reset: resetAt,
  }
}

/**
 * Check rate limit for a given identifier.
 * Supports prefix-based overrides for specialized endpoints.
 *
 * Every call counts as one request (the counter is incremented before the
 * check), so call it once per request.
 *
 * @param identifier Counter identity, e.g. "user:abc" or
 *   "search-semantic:user:abc". If the text before the first ":" is a key of
 *   PREFIX_LIMITS, that limit is used instead of the tier limit.
 * @param isAuthenticated Selects the `user` tier (unless `isPro` or a prefix
 *   override applies).
 * @param isPro Selects the `pro` tier (unless a prefix override applies).
 * @returns The outcome; `reset` is an epoch timestamp in milliseconds.
 */
export async function checkRateLimit(
  identifier: string,
  isAuthenticated: boolean = false,
  isPro: boolean = false,
): Promise<RateLimitResult> {
  const config = resolveConfig(identifier, isAuthenticated, isPro)
  const key = `ratelimit:${identifier}`

  // Try Redis first; a null count means Redis is unavailable.
  const currentCount = await redisIncr(key)

  if (currentCount === null) {
    // Redis unavailable — use in-memory fallback, reporting the counter's real
    // window end rather than a fresh full window.
    const { count, resetAt } = checkMemoryLimit(key, config.max, config.window)
    return buildResult(config, count, resetAt, Date.now())
  }

  // Set expiration on first request (starts the fixed window).
  if (currentCount === 1) {
    await redisExpire(key, config.window)
  }

  const ttl = await redisTTL(key)
  let secondsLeft = config.window
  if (typeof ttl === 'number' && ttl > 0) {
    secondsLeft = ttl
  } else if (ttl === -1) {
    // Redis TTL replies -1 for a key that exists but has no expiry. That happens
    // if the EXPIRE above was lost (e.g. the process died between INCR and
    // EXPIRE); without this repair the key would never reset and the caller
    // would stay locked out. NOTE(review): assumes redisTTL passes through the
    // raw Redis TTL reply — confirm against ./redis.
    await redisExpire(key, config.window)
  }

  const now = Date.now()
  return buildResult(config, currentCount, now + secondsLeft * 1000, now)
}

/**
 * Get client identifier from request (user ID if known, otherwise IP address).
 *
 * Returns "user:<id>" or "ip:<address>" ("ip:unknown" when no header is set).
 *
 * NOTE(review): cf-connecting-ip, x-forwarded-for and x-real-ip are
 * client-controllable unless a trusted proxy overwrites them. If this service
 * is reachable directly, anonymous callers can rotate these headers to evade
 * limits — confirm the deployment topology.
 */
export function getClientIdentifier(request: Request, userId?: string): string {
  if (userId) {
    return `user:${userId}`
  }

  const cfConnectingIp = request.headers.get('cf-connecting-ip')
  const forwarded = request.headers.get('x-forwarded-for')
  const realIp = request.headers.get('x-real-ip')

  // x-forwarded-for is a comma-separated list; the first entry is the
  // originating client as reported by the proxy chain.
  const ip = cfConnectingIp || forwarded?.split(',')[0]?.trim() || realIp || 'unknown'

  return `ip:${ip}`
}

/**
 * Create rate limit headers for response.
 *
 * X-RateLimit-Reset is emitted in epoch seconds (`result.reset` is epoch ms).
 */
export function getRateLimitHeaders(result: RateLimitResult): Headers {
  const headers = new Headers()
  headers.set('X-RateLimit-Limit', result.limit.toString())
  headers.set('X-RateLimit-Remaining', Math.max(0, result.remaining).toString())
  headers.set('X-RateLimit-Reset', Math.floor(result.reset / 1000).toString())
  return headers
}