import { streamText } from 'ai'
import { openai } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { AI_MODELS, ModelId, DEFAULT_MODEL, isOllamaModel } from '@/types/prompt'
import { checkRateLimit, getClientIdentifier, getRateLimitHeaders } from '@/lib/rate-limit'
import { getCachedResponse } from '@/lib/cache'
import prisma from '@/lib/prisma'
import { getAuthUser } from '@/lib/auth'
// ── SSRF Protection ─────────────────────────────────────────────────────────
// Hostnames that are always accepted as Ollama targets (loopback variants).
// IPv6 entries are stored WITHOUT brackets; hostnames are normalized to that
// form before lookup.
const ALLOWED_OLLAMA_HOSTS = new Set([
  'localhost',
  '127.0.0.1',
  '0.0.0.0',
  '::1',
])

/**
 * Normalize a hostname for allowlist comparison: trim, lowercase, and strip the
 * square brackets that `URL.hostname` keeps around IPv6 literals (`[::1]`).
 */
function normalizeOllamaHost(host: string): string {
  const cleaned = host.trim().toLowerCase()
  return cleaned.startsWith('[') && cleaned.endsWith(']')
    ? cleaned.slice(1, -1)
    : cleaned
}

/**
 * Validate that an Ollama URL points to a safe host.
 * Prevents SSRF attacks where a malicious user could probe internal services.
 *
 * A URL is accepted only when it parses, uses the http or https scheme, and its
 * hostname is either one of the built-in loopback names or listed in the
 * comma-separated `ALLOWED_OLLAMA_HOSTS` environment variable.
 *
 * @param urlStr - Candidate Ollama base URL (typically user supplied).
 * @returns `true` when the URL may be contacted, `false` otherwise (including
 *          when the string is not a valid URL).
 */
function isAllowedOllamaUrl(urlStr: string): boolean {
  let url: URL
  try {
    url = new URL(urlStr)
  } catch {
    return false
  }

  // Ollama is only ever reached over HTTP(S); reject every other scheme up front.
  if (url.protocol !== 'http:' && url.protocol !== 'https:') return false

  // `URL.hostname` reports IPv6 literals with brackets, so normalize before lookup;
  // otherwise the '::1' entry could never match.
  const hostname = normalizeOllamaHost(url.hostname)

  // Allow localhost variants
  if (ALLOWED_OLLAMA_HOSTS.has(hostname)) return true

  // Allow custom hosts from the environment whitelist (blank entries are ignored)
  const extraHosts = (process.env.ALLOWED_OLLAMA_HOSTS ?? '')
    .split(',')
    .map(entry => normalizeOllamaHost(entry))
    .filter(entry => entry.length > 0)
  return extraHosts.includes(hostname)
}
// ── Standardized error response helper ──────────────────────────────────────
/**
 * Build a JSON error response whose body is `{ "error": message }`.
 *
 * @param message - Text placed in the `error` field of the JSON body.
 * @param status - HTTP status code of the response.
 * @param headers - Optional extra headers copied onto the response next to
 *                  the JSON `Content-Type`.
 * @returns The constructed `Response`.
 */
function errorResponse(message: string, status: number, headers?: Headers): Response {
  const merged: Record<string, string> = { 'Content-Type': 'application/json' }

  // Copy any caller-supplied headers on top of the JSON content type.
  headers?.forEach((headerValue, headerName) => {
    merged[headerName] = headerValue
  })

  const payload = JSON.stringify({ error: message })
  return new Response(payload, { status, headers: merged })
}
// ── Ollama streaming helper ─────────────────────────────────────────────────
/**
 * Call Ollama's `/api/generate` endpoint in streaming mode and return a byte
 * stream containing only the generated text.
 *
 * The upstream body is consumed as newline-delimited JSON; for every line whose
 * `response` field is a non-empty string, that string is re-encoded as UTF-8 and
 * emitted. Lines that are not valid JSON are skipped.
 *
 * Network chunks do not respect line or character boundaries, so the transform
 * keeps the trailing partial line between chunks and decodes in streaming mode.
 *
 * @param prompt - Prompt text forwarded to Ollama.
 * @param model - Ollama model name.
 * @param ollamaUrl - Base URL of the Ollama server (trailing slashes are tolerated).
 * @returns The transformed stream, or `undefined` when the response has no body.
 * @throws Error when Ollama answers with a non-2xx status (a 404 is reported as
 *         a missing model with a pull hint).
 */
async function streamFromOllama(prompt: string, model: string, ollamaUrl: string) {
  // Avoid "//api/generate" when the configured base URL ends with a slash.
  const endpoint = `${ollamaUrl.replace(/\/+$/, '')}/api/generate`

  const response = await fetch(endpoint, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model,
      prompt,
      stream: true,
    }),
  })

  if (!response.ok) {
    if (response.status === 404) {
      throw new Error(`Ollama model "${model}" not found. Please run: ollama pull ${model}`)
    }
    const errorText = await response.text().catch(() => '')
    throw new Error(`Ollama error: ${response.status}${errorText ? ` - ${errorText}` : ''}`)
  }

  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  // Text received so far that is not yet terminated by a newline.
  let pending = ''

  // Parse one complete NDJSON line and forward its `response` text, if any.
  const emitLine = (line: string, controller: TransformStreamDefaultController): void => {
    if (!line.trim()) return
    try {
      const parsed = JSON.parse(line) as { response?: unknown }
      if (typeof parsed.response === 'string' && parsed.response.length > 0) {
        controller.enqueue(encoder.encode(parsed.response))
      }
    } catch {
      // Skip malformed JSON
    }
  }

  const transformStream = new TransformStream({
    transform(chunk, controller) {
      // `stream: true` keeps an incomplete multi-byte sequence buffered in the
      // decoder instead of turning it into U+FFFD.
      pending += decoder.decode(chunk, { stream: true })
      const lines = pending.split('\n')
      // The last element is whatever follows the final newline: an unfinished line.
      pending = lines.pop() ?? ''
      for (const line of lines) {
        emitLine(line, controller)
      }
    },
    flush(controller) {
      // Drain the decoder and handle a final line that had no trailing newline.
      pending += decoder.decode()
      emitLine(pending, controller)
      pending = ''
    },
  })

  return response.body?.pipeThrough(transformStream)
}
// ── Turnstile verification ──────────────────────────────────────────────────
/**
 * Check a Cloudflare Turnstile token against the siteverify endpoint.
 *
 * When `TURNSTILE_SECRET_KEY` is not set, verification fails in production
 * (`NODE_ENV === 'production'`) and is skipped (treated as passed) otherwise.
 *
 * @param token - Turnstile token supplied by the client.
 * @returns `true` only when the endpoint reports `success === true` (or when
 *          verification is skipped outside production); `false` on any failure,
 *          including request or parsing errors.
 */
async function verifyTurnstile(token: string): Promise<boolean> {
  const secretKey = process.env.TURNSTILE_SECRET_KEY

  if (secretKey) {
    try {
      const payload = JSON.stringify({
        secret: secretKey,
        response: token,
      })
      const verification = await fetch('https://challenges.cloudflare.com/turnstile/v0/siteverify', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: payload,
      })
      const outcome = await verification.json()
      return outcome.success === true
    } catch (error) {
      console.error('Turnstile verification error:', error)
      return false
    }
  }

  // No secret configured: hard-fail in production, allow everywhere else.
  const inProduction = process.env.NODE_ENV === 'production'
  if (inProduction) {
    console.error('Turnstile secret key not configured in production')
  } else {
    console.warn('Turnstile not configured, skipping verification in development')
  }
  return !inProduction
}
// ── Main POST handler ───────────────────────────────────────────────────────
/**
 * Run a prompt against the requested model and stream the generated text back.
 *
 * Steps, in order:
 *  1. Parse and validate the JSON body with `runSchema`.
 *  2. Resolve the authenticated user (if any) and apply rate limiting.
 *  3. For non-Ollama models with a `promptId`, reject models that are not in
 *     the prompt's `modelAllowed` list.
 *  4. Return a cached response when `promptId` and `variables` are present and
 *     a cache entry exists.
 *  5. Require a valid Turnstile token for unauthenticated callers.
 *  6. Stream from Ollama (after the SSRF URL check) or from a cloud provider.
 *
 * Every response produced after the rate-limit check that carries headers
 * includes the rate-limit headers; streamed and cached responses also set
 * `X-Cache`.
 *
 * @param req - Incoming request whose body is the JSON run payload.
 * @returns A streaming text response on success, or a JSON `{ error }` response
 *          with status 400 / 403 / 429 / 500 on failure.
 */
export async function POST(req: Request) {
  try {
    // Parse the body on its own so malformed JSON is reported as a client
    // error (400) instead of falling into the generic 500 handler below.
    let body: unknown
    try {
      body = await req.json()
    } catch {
      return errorResponse('Invalid JSON body', 400)
    }

    // Zod validation
    const { runSchema } = await import('@/lib/validations')
    const parsed = runSchema.safeParse(body)
    if (!parsed.success) {
      return errorResponse('Validation failed', 400)
    }
    const { prompt, promptId, model, variables, turnstileToken, ollamaUrl } = parsed.data

    // Get authenticated user server-side
    const authUser = await getAuthUser()
    const userId = authUser?.id ?? null

    // 1. Rate Limiting
    const identifier = getClientIdentifier(req, userId ?? undefined)
    const rateLimit = await checkRateLimit(identifier, !!userId)
    if (!rateLimit.success) {
      const headers = getRateLimitHeaders(rateLimit)
      return errorResponse(rateLimit.error || 'Rate limit exceeded', 429, headers)
    }

    // Resolve the model, falling back to the default when none was supplied
    const modelId = (model as ModelId) || DEFAULT_MODEL

    // 2. Validate model against prompt's allowed list (skipped for Ollama models)
    if (promptId && !isOllamaModel(modelId)) {
      const promptRecord = await prisma.prompt.findUnique({
        where: { id: promptId },
        select: { modelAllowed: true },
      })
      // An empty or missing allow-list means "no restriction".
      if (promptRecord?.modelAllowed?.length && !promptRecord.modelAllowed.includes(modelId)) {
        return errorResponse(
          `Model "${modelId}" is not allowed for this prompt. Allowed: ${promptRecord.modelAllowed.join(', ')}`,
          400
        )
      }
    }

    // 2.5 Check Cache
    if (promptId && variables) {
      const cached = await getCachedResponse(promptId, variables, modelId)
      if (cached) {
        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'HIT')
        return new Response(cached, {
          headers: {
            'Content-Type': 'text/plain',
            ...Object.fromEntries(headers.entries()),
          },
        })
      }
    }

    // 3. Verify Turnstile (bot protection) — require for unauthenticated users
    if (!userId) {
      if (!turnstileToken) {
        return errorResponse('Bot verification required', 403)
      }
      const isValid = await verifyTurnstile(turnstileToken)
      if (!isValid) {
        return errorResponse('Bot verification failed', 403)
      }
    }

    // 4. Check if this is an Ollama model
    const isOllama = isOllamaModel(modelId)
    if (isOllama) {
      const resolvedOllamaUrl = ollamaUrl || "http://localhost:11434"

      // SSRF protection: validate the URL before making any request
      if (!isAllowedOllamaUrl(resolvedOllamaUrl)) {
        return errorResponse(
          'Invalid Ollama URL. Only localhost connections are allowed for security.',
          400
        )
      }

      try {
        const stream = await streamFromOllama(prompt, modelId, resolvedOllamaUrl)
        if (!stream) {
          return errorResponse('Failed to get Ollama stream', 500)
        }

        // Track run in database (non-blocking)
        if (promptId) {
          trackRun(promptId, userId, modelId, identifier).catch(
            (err: unknown) => console.error('Failed to track run:', err)
          )
        }

        const headers = getRateLimitHeaders(rateLimit)
        headers.set('X-Cache', 'MISS')
        headers.set('X-Model-Type', 'ollama')
        return new Response(stream, {
          headers: {
            'Content-Type': 'text/plain; charset=utf-8',
            ...Object.fromEntries(headers.entries()),
          },
        })
      } catch (error) {
        console.error('Ollama error:', error)
        return errorResponse(
          `Ollama error: ${error instanceof Error ? error.message : 'Connection failed'}. Make sure Ollama is running.`,
          500
        )
      }
    }

    // 5. Validate cloud model
    const modelConfig = AI_MODELS[modelId as ModelId]
    if (!modelConfig) {
      return errorResponse('Invalid model', 400)
    }

    // 6. Select the appropriate AI provider
    let aiModel
    switch (modelConfig.provider) {
      case 'openai':
        aiModel = openai(modelId)
        break
      case 'anthropic':
        aiModel = anthropic(modelId)
        break
      case 'google':
        aiModel = google(modelId)
        break
      default:
        // Unknown provider: fall back to the default model via the Google provider
        aiModel = google(DEFAULT_MODEL)
    }

    // 7. Stream the response
    const result = await streamText({
      model: aiModel,
      prompt,
    })

    // 8. Track run in database (non-blocking)
    if (promptId) {
      trackRun(promptId, userId, modelId, identifier).catch(
        (err: unknown) => console.error('Failed to track run:', err)
      )
    }

    // 9. Return streaming response with rate limit headers
    const headers = getRateLimitHeaders(rateLimit)
    headers.set('X-Cache', 'MISS')
    headers.set('X-Model-Type', 'cloud')
    return result.toTextStreamResponse({
      headers: Object.fromEntries(headers.entries()),
    })
  } catch (error) {
    console.error('AI Run Error:', error)

    // A thrown Response is passed through unchanged
    if (error instanceof Response) {
      return error
    }

    if (error instanceof Error) {
      // Never leak API key values in error messages
      if (error.message.includes('API key') || error.message.includes('api_key')) {
        return errorResponse('API key not configured. Please add your API keys in settings.', 500)
      }
      return errorResponse(error.message, 500)
    }

    return errorResponse('An unexpected error occurred', 500)
  }
}
/**
 * Persist one run and bump the owning prompt's `totalRuns` counter, with both
 * writes submitted together through a single `prisma.$transaction` call.
 *
 * @param promptId - Id of the prompt that was executed.
 * @param userId - Id of the authenticated user, or `null` for anonymous runs
 *                 (an empty string is stored as `null` as well).
 * @param model - Identifier of the model that served the run.
 * @param ipHash - Value stored in the run's `ipHash` column.
 */
async function trackRun(promptId: string, userId: string | null, model: string, ipHash: string) {
  const recordRun = prisma.run.create({
    data: {
      promptId,
      userId: userId ? userId : null,
      model,
      cached: false,
      ipHash,
    },
  })

  const bumpTotalRuns = prisma.prompt.update({
    where: { id: promptId },
    data: { totalRuns: { increment: 1 } },
  })

  await prisma.$transaction([recordRun, bumpTotalRuns])
}
|