// open-prompt — src/app/api/run/route.ts
// fix: SSRF protection for Ollama URL + standardize error responses in /api/run
// (commit d70eff5, verified; author: anky2002)
import { streamText } from 'ai'
import { openai } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { AI_MODELS, ModelId, DEFAULT_MODEL, isOllamaModel } from '@/types/prompt'
import { checkRateLimit, getClientIdentifier, getRateLimitHeaders } from '@/lib/rate-limit'
import { getCachedResponse } from '@/lib/cache'
import prisma from '@/lib/prisma'
import { getAuthUser } from '@/lib/auth'
// ── SSRF Protection ────────────────────────────────────────────────────────

// Hostnames that always count as the local machine.
const ALLOWED_OLLAMA_HOSTS = new Set([
  'localhost',
  '127.0.0.1',
  '0.0.0.0',
  '::1',
])

/**
 * Validate that an Ollama URL points to a safe host (localhost only).
 * Prevents SSRF attacks where a malicious user could probe internal services.
 *
 * Accepts a URL only when:
 *  - it parses as a URL,
 *  - its scheme is http/https (rejects file:, ftp:, gopher:, etc.),
 *  - its hostname is a localhost variant, or appears in the comma-separated
 *    ALLOWED_OLLAMA_HOSTS environment whitelist.
 */
function isAllowedOllamaUrl(urlStr: string): boolean {
  try {
    const url = new URL(urlStr)
    // Only plain HTTP(S) may reach the Ollama API — other schemes could be
    // abused to read local files or speak to non-HTTP internal services.
    if (url.protocol !== 'http:' && url.protocol !== 'https:') return false
    // WHATWG URL serializes IPv6 hosts in brackets ("[::1]"); strip them so
    // the "::1" allowlist entry matches.
    const hostname = url.hostname.toLowerCase().replace(/^\[(.*)\]$/, '$1')
    // Allow localhost variants only
    if (ALLOWED_OLLAMA_HOSTS.has(hostname)) return true
    // Allow custom hosts from environment whitelist; drop empty entries so an
    // empty env var cannot accidentally whitelist URLs with an empty hostname.
    const extraHosts = (process.env.ALLOWED_OLLAMA_HOSTS?.split(',') ?? [])
      .map(h => h.trim().toLowerCase())
      .filter(h => h.length > 0)
    return extraHosts.includes(hostname)
  } catch {
    return false
  }
}
// ── Standardized error response helper ─────────────────────────────────────
/**
 * Build a JSON error Response of the shape { "error": message }.
 * Any extra headers (e.g. rate-limit metadata) are merged on top of the
 * JSON content type.
 */
function errorResponse(message: string, status: number, headers?: Headers): Response {
  const merged: Record<string, string> = { 'Content-Type': 'application/json' }
  for (const [key, value] of headers ?? []) {
    merged[key] = value
  }
  return new Response(JSON.stringify({ error: message }), { status, headers: merged })
}
// ── Ollama streaming helper ────────────────────────────────────────────────
/**
 * POST a prompt to an Ollama server and return a ReadableStream of plain
 * response text, decoded from Ollama's newline-delimited-JSON stream.
 *
 * Fixes over the naive approach: decodes with { stream: true } so multi-byte
 * UTF-8 sequences split across network chunks are not mangled, and buffers a
 * partial trailing line so NDJSON records split across chunks are not
 * silently dropped.
 *
 * @throws Error when the Ollama server responds with a non-OK status
 *         (404 gets a dedicated "model not found" message).
 */
async function streamFromOllama(prompt: string, model: string, ollamaUrl: string) {
  const response = await fetch(`${ollamaUrl}/api/generate`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model,
      prompt,
      stream: true,
    }),
  })
  if (!response.ok) {
    if (response.status === 404) {
      throw new Error(`Ollama model "${model}" not found. Please run: ollama pull ${model}`)
    }
    const errorText = await response.text().catch(() => '')
    throw new Error(`Ollama error: ${response.status}${errorText ? ` - ${errorText}` : ''}`)
  }
  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  // Partial NDJSON line carried over between chunks.
  let buffered = ''
  // Parse one NDJSON line; return its "response" text, or null when the line
  // is blank, malformed, or carries no text (e.g. the final "done" record).
  const extractResponse = (line: string): string | null => {
    if (!line.trim()) return null
    try {
      const json = JSON.parse(line)
      return typeof json.response === 'string' && json.response ? json.response : null
    } catch {
      // Skip malformed JSON
      return null
    }
  }
  const transformStream = new TransformStream({
    transform(chunk, controller) {
      // stream:true keeps multi-byte UTF-8 sequences intact across chunks.
      buffered += decoder.decode(chunk, { stream: true })
      const lines = buffered.split('\n')
      // The last element may be an incomplete line — keep it for next time.
      buffered = lines.pop() ?? ''
      for (const line of lines) {
        const text = extractResponse(line)
        if (text) controller.enqueue(encoder.encode(text))
      }
    },
    flush(controller) {
      // Flush any bytes still held by the decoder, then the final line.
      const tail = buffered + decoder.decode()
      const text = extractResponse(tail)
      if (text) controller.enqueue(encoder.encode(text))
    },
  })
  return response.body?.pipeThrough(transformStream)
}
// ── Turnstile verification ─────────────────────────────────────────────────
/**
 * Verify a Cloudflare Turnstile token against the siteverify endpoint.
 * With no secret configured: fails closed in production, passes in
 * development. Network/parse failures count as verification failure.
 */
async function verifyTurnstile(token: string): Promise<boolean> {
  const secret = process.env.TURNSTILE_SECRET_KEY
  if (!secret) {
    const inProduction = process.env.NODE_ENV === 'production'
    if (inProduction) {
      console.error('Turnstile secret key not configured in production')
    } else {
      console.warn('Turnstile not configured, skipping verification in development')
    }
    // Fail closed in production, open in development.
    return !inProduction
  }
  try {
    const res = await fetch('https://challenges.cloudflare.com/turnstile/v0/siteverify', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ secret, response: token }),
    })
    const result = await res.json()
    return result.success === true
  } catch (error) {
    console.error('Turnstile verification error:', error)
    return false
  }
}
// ── Main POST handler ──────────────────────────────────────────────────────
/**
 * POST /api/run — run a prompt against a cloud provider or a local Ollama
 * instance and stream the completion back as plain text.
 *
 * Pipeline, in order:
 *  1. Zod-validate the request body.
 *  2. Rate-limit (keyed by user id when signed in, else by client identifier).
 *  3. Enforce the prompt's allowed-model list (cloud models only).
 *  4. Serve from cache when promptId + variables match a prior run.
 *  5. Require Turnstile bot verification for unauthenticated callers.
 *  6. Ollama models: SSRF-check the URL, then proxy the local stream.
 *  7. Cloud models: dispatch to the matching AI SDK provider and stream.
 * Successful runs are tracked fire-and-forget via trackRun().
 */
export async function POST(req: Request) {
  try {
    const body = await req.json()
    // Zod validation (schema imported lazily; keeps it out of the cold path
    // for requests that never reach parsing)
    const { runSchema } = await import('@/lib/validations')
    const parsed = runSchema.safeParse(body)
    if (!parsed.success) {
      return errorResponse('Validation failed', 400)
    }
    const { prompt, promptId, model, variables, turnstileToken, ollamaUrl } = parsed.data
    // Get authenticated user server-side — never trust a client-sent user id
    const authUser = await getAuthUser()
    const userId = authUser?.id ?? null
    // 1. Rate Limiting — authenticated users get their own bucket (the
    // boolean flag presumably selects a higher limit; see lib/rate-limit)
    const identifier = getClientIdentifier(req, userId ?? undefined)
    const rateLimit = await checkRateLimit(identifier, !!userId)
    if (!rateLimit.success) {
      const headers = getRateLimitHeaders(rateLimit)
      return errorResponse(rateLimit.error || 'Rate limit exceeded', 429, headers)
    }
    // 2. Check Cache
    const modelId = (model as ModelId) || DEFAULT_MODEL
    // 2.5 Validate model against prompt's allowed list (Ollama models are
    // exempt — they are local and not part of the prompt's cloud allowlist)
    if (promptId && !isOllamaModel(modelId)) {
      const promptRecord = await prisma.prompt.findUnique({
        where: { id: promptId },
        select: { modelAllowed: true },
      })
      // An empty/missing allowlist means "any model"
      if (promptRecord?.modelAllowed?.length && !promptRecord.modelAllowed.includes(modelId)) {
        return errorResponse(
          `Model "${modelId}" is not allowed for this prompt. Allowed: ${promptRecord.modelAllowed.join(', ')}`,
          400
        )
      }
    }
    // Cache hits short-circuit before Turnstile: cached responses cost no
    // provider tokens, so bot verification is skipped for them
    if (promptId && variables) {
      const cached = await getCachedResponse(promptId, variables, modelId)
      if (cached) {
        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'HIT')
        return new Response(cached, {
          headers: {
            'Content-Type': 'text/plain',
            ...Object.fromEntries(headers.entries()),
          },
        })
      }
    }
    // 3. Verify Turnstile (bot protection) — require for unauthenticated users
    if (!userId) {
      if (!turnstileToken) {
        return errorResponse('Bot verification required', 403)
      }
      const isValid = await verifyTurnstile(turnstileToken)
      if (!isValid) {
        return errorResponse('Bot verification failed', 403)
      }
    }
    // 4. Check if this is an Ollama model
    const isOllama = isOllamaModel(modelId)
    if (isOllama) {
      // Default to the standard local Ollama port when no URL was supplied
      const resolvedOllamaUrl = ollamaUrl || "http://localhost:11434"
      // SSRF protection: validate the URL before making any request
      if (!isAllowedOllamaUrl(resolvedOllamaUrl)) {
        return errorResponse(
          'Invalid Ollama URL. Only localhost connections are allowed for security.',
          400
        )
      }
      try {
        const stream = await streamFromOllama(prompt, modelId, resolvedOllamaUrl)
        if (!stream) {
          return errorResponse('Failed to get Ollama stream', 500)
        }
        // Track run in database (non-blocking; failures are logged, never
        // surfaced to the client)
        if (promptId) {
          trackRun(promptId, userId, modelId, identifier).catch(
            (err: unknown) => console.error('Failed to track run:', err)
          )
        }
        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'MISS')
        headers.set('X-Model-Type', 'ollama')
        return new Response(stream, {
          headers: {
            'Content-Type': 'text/plain; charset=utf-8',
            ...Object.fromEntries(headers.entries()),
          },
        })
      } catch (error) {
        console.error('Ollama error:', error)
        return errorResponse(
          `Ollama error: ${error instanceof Error ? error.message : 'Connection failed'}. Make sure Ollama is running.`,
          500
        )
      }
    }
    // 5. Validate cloud model
    const modelConfig = AI_MODELS[modelId as ModelId]
    if (!modelConfig) {
      return errorResponse('Invalid model', 400)
    }
    // 6. Select the appropriate AI provider
    let aiModel
    switch (modelConfig.provider) {
      case 'openai':
        aiModel = openai(modelId)
        break
      case 'anthropic':
        aiModel = anthropic(modelId)
        break
      case 'google':
        aiModel = google(modelId)
        break
      default:
        // Unknown provider falls back to the default Google model rather
        // than erroring — NOTE(review): confirm this silent fallback is
        // intended over a 400 response
        aiModel = google(DEFAULT_MODEL)
    }
    // 7. Stream the response
    const result = await streamText({
      model: aiModel,
      prompt,
    })
    // 8. Track run in database (non-blocking)
    if (promptId) {
      trackRun(promptId, userId, modelId, identifier).catch(
        (err: unknown) => console.error('Failed to track run:', err)
      )
    }
    // 9. Return streaming response with rate limit headers
    const headers = getRateLimitHeaders(rateLimit)
    headers.set('X-Cache', 'MISS')
    headers.set('X-Model-Type', 'cloud')
    return result.toTextStreamResponse({
      headers: Object.fromEntries(headers.entries()),
    })
  } catch (error) {
    console.error('AI Run Error:', error)
    if (error instanceof Response) {
      return error
    }
    if (error instanceof Error) {
      // Never leak API key values in error messages
      if (error.message.includes('API key') || error.message.includes('api_key')) {
        return errorResponse('API key not configured. Please add your API keys in settings.', 500)
      }
      // NOTE(review): raw error.message is returned to the client here — it
      // may expose provider/internal details; confirm this is acceptable
      return errorResponse(error.message, 500)
    }
    return errorResponse('An unexpected error occurred', 500)
  }
}
/**
 * Record a single prompt execution: insert a Run row and bump the prompt's
 * totalRuns counter atomically in one database transaction.
 */
async function trackRun(promptId: string, userId: string | null, model: string, ipHash: string) {
  // PrismaPromises are lazy — they only execute inside $transaction below.
  const createRun = prisma.run.create({
    data: {
      promptId,
      userId: userId || null,
      model,
      cached: false,
      ipHash,
    },
  })
  const bumpCounter = prisma.prompt.update({
    where: { id: promptId },
    data: { totalRuns: { increment: 1 } },
  })
  await prisma.$transaction([createRun, bumpCounter])
}