import { streamText } from 'ai'
import { openai } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { AI_MODELS, ModelId, DEFAULT_MODEL, isOllamaModel } from '@/types/prompt'
import { checkRateLimit, getClientIdentifier, getRateLimitHeaders } from '@/lib/rate-limit'
import { getCachedResponse } from '@/lib/cache'
import prisma from '@/lib/prisma'
import { getAuthUser } from '@/lib/auth'

// ── SSRF Protection ─────────────────────────────────────────────────────────

const ALLOWED_OLLAMA_HOSTS = new Set([
  'localhost',
  '127.0.0.1',
  '0.0.0.0',
  '::1',
])

/**
 * Validate that an Ollama URL points to a safe host (localhost only).
 * Prevents SSRF attacks where a malicious user could probe internal services.
 */
function isAllowedOllamaUrl(urlStr: string): boolean {
  try {
    const url = new URL(urlStr)
    const hostname = url.hostname.toLowerCase()

    // Allow localhost variants only
    if (ALLOWED_OLLAMA_HOSTS.has(hostname)) return true

    // Allow custom hosts from environment whitelist
    const extraHosts = process.env.ALLOWED_OLLAMA_HOSTS?.split(',').map(h => h.trim().toLowerCase()) || []
    if (extraHosts.includes(hostname)) return true

    return false
  } catch {
    return false
  }
}

// ── Standardized error response helper ──────────────────────────────────────
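/**
 * Build a JSON error response, optionally merging in extra headers
 * (for example, rate limit headers).
 */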
function errorResponse(message: string, status: number, headers?: Headers): Response {
  const responseHeaders: Record<string, string> = { 'Content-Type': 'application/json' }
  if (headers) {
    headers.forEach((value, key) => { responseHeaders[key] = value })
  }
  return new Response(
    JSON.stringify({ error: message }),
    { status, headers: responseHeaders }
  )
}

// ── Ollama streaming helper ─────────────────────────────────────────────────
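/**
 * Call Ollama's /api/generate endpoint and convert its newline-delimited
 * JSON stream into a plain text stream of response tokens.
 */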
async function streamFromOllama(prompt: string, model: string, ollamaUrl: string) {
  const response = await fetch(`${ollamaUrl}/api/generate`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model,
      prompt,
      stream: true,
    }),
  })

  if (!response.ok) {
    if (response.status === 404) {
      throw new Error(`Ollama model "${model}" not found. Please run: ollama pull ${model}`)
    }
    const errorText = await response.text().catch(() => '')
    throw new Error(`Ollama error: ${response.status}${errorText ? ` - ${errorText}` : ''}`)
  }

  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  // Ollama streams newline-delimited JSON. A JSON object can be split across
  // network chunks, so buffer incomplete lines between transform calls instead
  // of silently dropping them.
  let buffer = ''
  const transformStream = new TransformStream<Uint8Array, Uint8Array>({
    transform(chunk, controller) {
      buffer += decoder.decode(chunk, { stream: true })
      const lines = buffer.split('\n')
      buffer = lines.pop() ?? ''
      for (const line of lines) {
        if (!line.trim()) continue
        try {
          const json = JSON.parse(line)
          if (json.response) {
            controller.enqueue(encoder.encode(json.response))
          }
        } catch {
          // Skip malformed JSON lines
        }
      }
    },
    flush(controller) {
      // Emit any complete line left in the buffer when the stream ends
      const line = buffer.trim()
      if (!line) return
      try {
        const json = JSON.parse(line)
        if (json.response) {
          controller.enqueue(encoder.encode(json.response))
        }
      } catch {
        // Skip malformed JSON
      }
    },
  })

  return response.body?.pipeThrough(transformStream)
}

// ── Turnstile verification ──────────────────────────────────────────────────
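/**
 * Verify a Cloudflare Turnstile token against the siteverify endpoint.
 * Fails closed in production when the secret key is missing and skips
 * verification in development.
 */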
async function verifyTurnstile(token: string): Promise<boolean> {
  const secretKey = process.env.TURNSTILE_SECRET_KEY
  if (!secretKey) {
    if (process.env.NODE_ENV === 'production') {
      console.error('Turnstile secret key not configured in production')
      return false
    }
    console.warn('Turnstile not configured, skipping verification in development')
    return true
  }

  try {
    const response = await fetch('https://challenges.cloudflare.com/turnstile/v0/siteverify', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        secret: secretKey,
        response: token,
      }),
    })
    const data = await response.json()
    return data.success === true
  } catch (error) {
    console.error('Turnstile verification error:', error)
    return false
  }
}

// ── Main POST handler ───────────────────────────────────────────────────────
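/**
 * Validates the request body, applies rate limiting, checks the response
 * cache, verifies Turnstile for unauthenticated users, then streams a
 * completion from either a local Ollama instance or a cloud provider.
 */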
export async function POST(req: Request) {
  try {
    const body = await req.json()

    // Zod validation
    const { runSchema } = await import('@/lib/validations')
    const parsed = runSchema.safeParse(body)
    if (!parsed.success) {
      return errorResponse('Validation failed', 400)
    }
    const { prompt, promptId, model, variables, turnstileToken, ollamaUrl } = parsed.data

    // Get authenticated user server-side
    const authUser = await getAuthUser()
    const userId = authUser?.id ?? null

    // 1. Rate Limiting
    const identifier = getClientIdentifier(req, userId ?? undefined)
    const rateLimit = await checkRateLimit(identifier, !!userId)
    if (!rateLimit.success) {
      const headers = getRateLimitHeaders(rateLimit)
      return errorResponse(rateLimit.error || 'Rate limit exceeded', 429, headers)
    }
    // 2. Check Cache
    const modelId = (model as ModelId) || DEFAULT_MODEL

    // 2.5 Validate model against prompt's allowed list
    if (promptId && !isOllamaModel(modelId)) {
      const promptRecord = await prisma.prompt.findUnique({
        where: { id: promptId },
        select: { modelAllowed: true },
      })
      if (promptRecord?.modelAllowed?.length && !promptRecord.modelAllowed.includes(modelId)) {
        return errorResponse(
          `Model "${modelId}" is not allowed for this prompt. Allowed: ${promptRecord.modelAllowed.join(', ')}`,
          400
        )
      }
    }

    if (promptId && variables) {
      const cached = await getCachedResponse(promptId, variables, modelId)
      if (cached) {
        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'HIT')
        return new Response(cached, {
          headers: {
            'Content-Type': 'text/plain',
            ...Object.fromEntries(headers.entries()),
          },
        })
      }
    }
    // 3. Verify Turnstile (bot protection): required for unauthenticated users
    if (!userId) {
      if (!turnstileToken) {
        return errorResponse('Bot verification required', 403)
      }
      const isValid = await verifyTurnstile(turnstileToken)
      if (!isValid) {
        return errorResponse('Bot verification failed', 403)
      }
    }
    // 4. Check if this is an Ollama model
    const isOllama = isOllamaModel(modelId)
    if (isOllama) {
      const resolvedOllamaUrl = ollamaUrl || 'http://localhost:11434'

      // SSRF protection: validate the URL before making any request
      if (!isAllowedOllamaUrl(resolvedOllamaUrl)) {
        return errorResponse(
          'Invalid Ollama URL. Only localhost connections are allowed for security.',
          400
        )
      }

      try {
        const stream = await streamFromOllama(prompt, modelId, resolvedOllamaUrl)
        if (!stream) {
          return errorResponse('Failed to get Ollama stream', 500)
        }

        // Track run in database (non-blocking)
        if (promptId) {
          trackRun(promptId, userId, modelId, identifier).catch(
            (err: unknown) => console.error('Failed to track run:', err)
          )
        }

        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'MISS')
        headers.set('X-Model-Type', 'ollama')
        return new Response(stream, {
          headers: {
            'Content-Type': 'text/plain; charset=utf-8',
            ...Object.fromEntries(headers.entries()),
          },
        })
      } catch (error) {
        console.error('Ollama error:', error)
        return errorResponse(
          `Ollama error: ${error instanceof Error ? error.message : 'Connection failed'}. Make sure Ollama is running.`,
          500
        )
      }
    }
    // 5. Validate cloud model
    const modelConfig = AI_MODELS[modelId as ModelId]
    if (!modelConfig) {
      return errorResponse('Invalid model', 400)
    }

    // 6. Select the appropriate AI provider
    let aiModel
    switch (modelConfig.provider) {
      case 'openai':
        aiModel = openai(modelId)
        break
      case 'anthropic':
        aiModel = anthropic(modelId)
        break
      case 'google':
        aiModel = google(modelId)
        break
      default:
        aiModel = google(DEFAULT_MODEL)
    }

    // 7. Stream the response
    const result = await streamText({
      model: aiModel,
      prompt,
    })
    // 8. Track run in database (non-blocking)
    if (promptId) {
      trackRun(promptId, userId, modelId, identifier).catch(
        (err: unknown) => console.error('Failed to track run:', err)
      )
    }

    // 9. Return streaming response with rate limit headers
    const headers = getRateLimitHeaders(rateLimit)
    headers.set('X-Cache', 'MISS')
    headers.set('X-Model-Type', 'cloud')
    return result.toTextStreamResponse({
      headers: Object.fromEntries(headers.entries()),
    })
  } catch (error) {
    console.error('AI Run Error:', error)
    if (error instanceof Response) {
      return error
    }
    if (error instanceof Error) {
      // Never leak API key values in error messages
      if (error.message.includes('API key') || error.message.includes('api_key')) {
        return errorResponse('API key not configured. Please add your API keys in settings.', 500)
      }
      return errorResponse(error.message, 500)
    }
    return errorResponse('An unexpected error occurred', 500)
  }
}

/**
 * Track a run in the database and increment the prompt's totalRuns counter.
 */
async function trackRun(promptId: string, userId: string | null, model: string, ipHash: string) {
  await prisma.$transaction([
    prisma.run.create({
      data: {
        promptId,
        userId: userId || null,
        model,
        cached: false,
        ipHash,
      },
    }),
    prisma.prompt.update({
      where: { id: promptId },
      data: { totalRuns: { increment: 1 } },
    }),
  ])
}