import { streamText } from 'ai'
import { openai } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { AI_MODELS, ModelId, DEFAULT_MODEL, isOllamaModel } from '@/types/prompt'
import { checkRateLimit, getClientIdentifier, getRateLimitHeaders } from '@/lib/rate-limit'
import { getCachedResponse } from '@/lib/cache'
import prisma from '@/lib/prisma'
import { getAuthUser } from '@/lib/auth'

// ── SSRF Protection ────────────────────────────────────────────────────────

/** Hostnames that are always accepted as Ollama targets (loopback variants). */
const ALLOWED_OLLAMA_HOSTS = new Set([
  'localhost',
  '127.0.0.1',
  '0.0.0.0',
  '::1',
])

/**
 * Lower-case a hostname and strip the square brackets that `URL.hostname`
 * keeps around IPv6 literals (e.g. `[::1]` -> `::1`), so it can be compared
 * against the plain entries of the allow-lists.
 */
function normalizeHost(host: string): string {
  return host.trim().toLowerCase().replace(/^\[|\]$/g, '')
}

/**
 * Validate that an Ollama URL points to a safe host.
 * Prevents SSRF attacks where a malicious user could probe internal services.
 *
 * A URL is accepted when it parses, uses the http or https scheme, and its
 * hostname is either one of the built-in loopback names or listed in the
 * comma-separated `ALLOWED_OLLAMA_HOSTS` environment variable.
 *
 * @param urlStr - Candidate base URL supplied by the client.
 * @returns `true` if a request to this URL may be made, `false` otherwise
 *          (including when the string is not a valid URL).
 */
function isAllowedOllamaUrl(urlStr: string): boolean {
  let url: URL
  try {
    url = new URL(urlStr)
  } catch {
    return false
  }

  // Only plain HTTP(S) endpoints are meaningful for the fetch call made later.
  if (url.protocol !== 'http:' && url.protocol !== 'https:') return false

  const hostname = normalizeHost(url.hostname)

  // Allow localhost variants
  if (ALLOWED_OLLAMA_HOSTS.has(hostname)) return true

  // Allow custom hosts from environment whitelist (blank entries are ignored)
  const extraHosts = (process.env.ALLOWED_OLLAMA_HOSTS ?? '')
    .split(',')
    .map(normalizeHost)
    .filter(Boolean)

  return extraHosts.includes(hostname)
}

// ── Standardized error response helper ─────────────────────────────────────

/**
 * Build a JSON error response of the shape `{ "error": message }`.
 *
 * @param message - Text placed in the `error` field of the body.
 * @param status - HTTP status code of the response.
 * @param headers - Optional extra headers copied onto the response.
 * @returns The constructed `Response`.
 */
function errorResponse(message: string, status: number, headers?: Headers): Response {
  const responseHeaders: Record<string, string> = { 'Content-Type': 'application/json' }

  if (headers) {
    headers.forEach((value, key) => {
      responseHeaders[key] = value
    })
  }

  return new Response(
    JSON.stringify({ error: message }),
    { status, headers: responseHeaders }
  )
}

// ── Ollama streaming helper ────────────────────────────────────────────────

/**
 * POST a prompt to `<ollamaUrl>/api/generate` with `stream: true` and convert
 * the newline-delimited JSON reply into a stream of the plain `response` text.
 *
 * @param prompt - Prompt text forwarded to Ollama.
 * @param model - Model name forwarded to Ollama.
 * @param ollamaUrl - Base URL of the Ollama server; trailing slashes are ignored.
 * @returns A byte stream of the generated text, or `undefined` when the
 *          upstream response has no body.
 * @throws Error when Ollama answers with a non-2xx status (a 404 is reported
 *         as a missing model), or when the fetch itself rejects.
 */
async function streamFromOllama(prompt: string, model: string, ollamaUrl: string) {
  // Trim trailing slashes so "http://host:11434/" does not yield "//api/generate".
  const endpoint = `${ollamaUrl.replace(/\/+$/, '')}/api/generate`

  const response = await fetch(endpoint, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model,
      prompt,
      stream: true,
    }),
  })

  if (!response.ok) {
    if (response.status === 404) {
      throw new Error(`Ollama model "${model}" not found. Please run: ollama pull ${model}`)
    }
    const errorText = await response.text().catch(() => '')
    throw new Error(`Ollama error: ${response.status}${errorText ? ` - ${errorText}` : ''}`)
  }

  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  // Holds the trailing, not-yet-terminated line between chunks. Network chunks
  // do not align with line boundaries, so a JSON object may arrive in pieces.
  let pending = ''

  /** Parse one NDJSON line and forward its `response` text, if any. */
  const emitLine = (
    line: string,
    controller: TransformStreamDefaultController<Uint8Array>
  ): void => {
    if (!line.trim()) return
    try {
      const parsed = JSON.parse(line) as { response?: unknown }
      if (typeof parsed.response === 'string' && parsed.response) {
        controller.enqueue(encoder.encode(parsed.response))
      }
    } catch {
      // Skip malformed JSON
    }
  }

  const transformStream = new TransformStream<Uint8Array, Uint8Array>({
    transform(chunk, controller) {
      // `stream: true` keeps multi-byte characters split across chunks intact.
      pending += decoder.decode(chunk, { stream: true })
      const lines = pending.split('\n')
      // The last element is either '' (chunk ended on a newline) or a partial line.
      pending = lines.pop() ?? ''
      for (const line of lines) {
        emitLine(line, controller)
      }
    },
    flush(controller) {
      // Drain any bytes the decoder is still holding plus the final unterminated line.
      pending += decoder.decode()
      emitLine(pending, controller)
      pending = ''
    },
  })

  return response.body?.pipeThrough(transformStream)
}

// ── Turnstile verification ─────────────────────────────────────────────────

/**
 * Verify a Cloudflare Turnstile token against the siteverify endpoint.
 *
 * When `TURNSTILE_SECRET_KEY` is not set, verification fails in production
 * and is skipped (treated as valid) in every other environment.
 *
 * @param token - Token submitted by the client.
 * @returns `true` only if the siteverify response has `success === true`
 *          (or verification was skipped outside production); `false` on any
 *          failure, including network or parsing errors.
 */
async function verifyTurnstile(token: string): Promise<boolean> {
  const secretKey = process.env.TURNSTILE_SECRET_KEY

  if (!secretKey) {
    if (process.env.NODE_ENV === 'production') {
      console.error('Turnstile secret key not configured in production')
      return false
    }
    console.warn('Turnstile not configured, skipping verification in development')
    return true
  }

  try {
    const response = await fetch('https://challenges.cloudflare.com/turnstile/v0/siteverify', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        secret: secretKey,
        response: token,
      }),
    })
    const data = (await response.json()) as { success?: unknown }
    return data.success === true
  } catch (error) {
    console.error('Turnstile verification error:', error)
    return false
  }
}

// ── Main POST handler ──────────────────────────────────────────────────────

/**
 * Run a prompt against a cloud model or a local Ollama model and stream the
 * generated text back to the caller.
 *
 * Steps, in order: parse and validate the body, rate-limit, check the model
 * against the prompt's allow-list, serve from cache if possible, require bot
 * verification for unauthenticated callers, then stream from Ollama or from
 * the selected cloud provider. Successful runs with a `promptId` are recorded
 * without blocking the response.
 *
 * @param req - Incoming request whose JSON body must satisfy `runSchema`.
 * @returns A streaming text response on success, or a JSON error response
 *          (`{ error }`) with status 400, 403, 429 or 500.
 */
export async function POST(req: Request) {
  try {
    // A body that is not valid JSON is a client error, not a server failure.
    let body: unknown
    try {
      body = await req.json()
    } catch {
      return errorResponse('Invalid JSON body', 400)
    }

    // Zod validation
    const { runSchema } = await import('@/lib/validations')
    const parsed = runSchema.safeParse(body)
    if (!parsed.success) {
      return errorResponse('Validation failed', 400)
    }
    const { prompt, promptId, model, variables, turnstileToken, ollamaUrl } = parsed.data

    // Get authenticated user server-side
    const authUser = await getAuthUser()
    const userId = authUser?.id ?? null

    // 1. Rate Limiting
    const identifier = getClientIdentifier(req, userId ?? undefined)
    const rateLimit = await checkRateLimit(identifier, !!userId)

    if (!rateLimit.success) {
      const headers = getRateLimitHeaders(rateLimit)
      return errorResponse(rateLimit.error || 'Rate limit exceeded', 429, headers)
    }

    // Resolve the model, falling back to the default when none was supplied
    const modelId = (model as ModelId) || DEFAULT_MODEL

    // 2. Validate model against prompt's allowed list (skipped for Ollama models)
    if (promptId && !isOllamaModel(modelId)) {
      const promptRecord = await prisma.prompt.findUnique({
        where: { id: promptId },
        select: { modelAllowed: true },
      })
      if (promptRecord?.modelAllowed?.length && !promptRecord.modelAllowed.includes(modelId)) {
        return errorResponse(
          `Model "${modelId}" is not allowed for this prompt. Allowed: ${promptRecord.modelAllowed.join(', ')}`,
          400
        )
      }
    }

    // 2.5 Check Cache
    if (promptId && variables) {
      const cached = await getCachedResponse(promptId, variables, modelId)
      if (cached) {
        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'HIT')
        return new Response(cached, {
          headers: {
            'Content-Type': 'text/plain',
            ...Object.fromEntries(headers.entries()),
          },
        })
      }
    }

    // 3. Verify Turnstile (bot protection) — require for unauthenticated users
    if (!userId) {
      if (!turnstileToken) {
        return errorResponse('Bot verification required', 403)
      }
      const isValid = await verifyTurnstile(turnstileToken)
      if (!isValid) {
        return errorResponse('Bot verification failed', 403)
      }
    }

    // 4. Check if this is an Ollama model
    const isOllama = isOllamaModel(modelId)

    if (isOllama) {
      const resolvedOllamaUrl = ollamaUrl || "http://localhost:11434"

      // SSRF protection: validate the URL before making any request
      if (!isAllowedOllamaUrl(resolvedOllamaUrl)) {
        return errorResponse(
          'Invalid Ollama URL. Only localhost connections are allowed for security.',
          400
        )
      }

      try {
        const stream = await streamFromOllama(prompt, modelId, resolvedOllamaUrl)

        if (!stream) {
          return errorResponse('Failed to get Ollama stream', 500)
        }

        // Track run in database (non-blocking)
        if (promptId) {
          trackRun(promptId, userId, modelId, identifier).catch(
            (err: unknown) => console.error('Failed to track run:', err)
          )
        }

        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'MISS')
        headers.set('X-Model-Type', 'ollama')

        return new Response(stream, {
          headers: {
            'Content-Type': 'text/plain; charset=utf-8',
            ...Object.fromEntries(headers.entries()),
          },
        })
      } catch (error) {
        console.error('Ollama error:', error)
        return errorResponse(
          `Ollama error: ${error instanceof Error ? error.message : 'Connection failed'}. Make sure Ollama is running.`,
          500
        )
      }
    }

    // 5. Validate cloud model
    const modelConfig = AI_MODELS[modelId]
    if (!modelConfig) {
      return errorResponse('Invalid model', 400)
    }

    // 6. Select the appropriate AI provider
    let aiModel
    switch (modelConfig.provider) {
      case 'openai':
        aiModel = openai(modelId)
        break
      case 'anthropic':
        aiModel = anthropic(modelId)
        break
      case 'google':
        aiModel = google(modelId)
        break
      default:
        aiModel = google(DEFAULT_MODEL)
    }

    // 7. Stream the response
    const result = await streamText({
      model: aiModel,
      prompt,
    })

    // 8. Track run in database (non-blocking)
    if (promptId) {
      trackRun(promptId, userId, modelId, identifier).catch(
        (err: unknown) => console.error('Failed to track run:', err)
      )
    }

    // 9. Return streaming response with rate limit headers
    const headers = getRateLimitHeaders(rateLimit)
    headers.set('X-Cache', 'MISS')
    headers.set('X-Model-Type', 'cloud')

    return result.toTextStreamResponse({
      headers: Object.fromEntries(headers.entries()),
    })
  } catch (error) {
    console.error('AI Run Error:', error)

    if (error instanceof Response) {
      return error
    }

    if (error instanceof Error) {
      // Never leak API key values in error messages
      if (error.message.includes('API key') || error.message.includes('api_key')) {
        return errorResponse('API key not configured. Please add your API keys in settings.', 500)
      }
      return errorResponse(error.message, 500)
    }

    return errorResponse('An unexpected error occurred', 500)
  }
}

/**
 * Track a run in the database and increment the prompt's totalRuns counter.
 * Both writes are issued in a single Prisma transaction.
 *
 * @param promptId - Id of the prompt that was run.
 * @param userId - Id of the authenticated user, or `null` for anonymous runs.
 * @param model - Model identifier stored on the run record.
 * @param ipHash - Client identifier stored in the run's `ipHash` field.
 */
async function trackRun(promptId: string, userId: string | null, model: string, ipHash: string) {
  await prisma.$transaction([
    prisma.run.create({
      data: {
        promptId,
        userId: userId || null,
        model,
        cached: false,
        ipHash,
      },
    }),
    prisma.prompt.update({
      where: { id: promptId },
      data: { totalRuns: { increment: 1 } },
    }),
  ])
}