File size: 10,357 Bytes
bb0c63f
 
 
 
 
 
 
 
 
 
 
 
 
1c6736a
bb0c63f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c6736a
bb0c63f
1c6736a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f094f6
 
 
 
 
 
 
 
 
 
 
 
1c6736a
 
bb0c63f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""
API routes for the Enterprise AI Gateway
"""

import time
from typing import List
from fastapi import APIRouter, HTTPException, Depends, Request, status
from fastapi.responses import HTMLResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
from pydantic import BaseModel

from ..models import QueryRequest, QueryResponse, HealthResponse
from ..security import validate_api_key, detect_pii, detect_prompt_injection, detect_toxicity, detect_hate_speech
from ..llm.client import llm_client
from ..config import RATE_LIMIT, SERVICE_API_KEY
from ..metrics import metrics
from ..providers import PROVIDER_CONFIG, estimate_cost


# --- Request Models for Batch Endpoints ---
class BatchRequest(BaseModel):
    """Request body shared by the batch endpoints: a list of prompt strings."""

    prompts: List[str]

# --- Router Setup ---
router = APIRouter()
# Rate-limit buckets are keyed by the client address returned by
# get_remote_address; RATE_LIMIT is passed as the limiter's default limit.
# NOTE(review): behind a reverse proxy the remote address may be the proxy's,
# so all clients could share one bucket -- confirm deployment topology.
limiter = Limiter(default_limits=[RATE_LIMIT], key_func=get_remote_address)

@router.get("/", include_in_schema=False)
async def read_root():
    """Serve the interactive gateway demo dashboard.

    Reads ``static/index.html`` (located three directory levels above this
    module) and replaces the placeholder API-key input value with the
    configured ``SERVICE_API_KEY`` so the demo page can call authenticated
    endpoints.

    Returns:
        HTMLResponse: the dashboard HTML with the key injected.

    Raises:
        OSError: if the static dashboard file cannot be read.
    """
    import html
    import os

    # Resolve the dashboard relative to this module so it works from any CWD.
    static_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "static",
        "index.html",
    )

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # mis-decode a UTF-8 HTML file.
    with open(static_path, "r", encoding="utf-8") as f:
        html_content = f.read()

    # SECURITY NOTE(review): this route has no authentication, so anyone who
    # can load "/" receives SERVICE_API_KEY. Acceptable only for a demo
    # deployment -- confirm before running with a production key.
    # Escape the key so characters such as `"`, `<` or `&` cannot break out
    # of the attribute value; browsers decode the entities back, so the input
    # still holds the exact key.
    safe_key = html.escape(str(SERVICE_API_KEY), quote=True)
    html_with_key = html_content.replace(
        'value="secure-demo-ak7x9..."', f'value="{safe_key}"'
    )
    return HTMLResponse(content=html_with_key, media_type="text/html")

@router.get("/health", response_model=HealthResponse)
async def health_check(request: Request):
    """Report gateway liveness.

    ``status`` is always ``"healthy"``; ``provider`` is the name of the first
    entry in ``llm_client.providers``, or ``None`` when that list is empty.
    """
    providers = llm_client.providers
    first_provider_name = providers[0]["name"] if providers else None
    return HealthResponse(
        timestamp=time.time(),
        provider=first_provider_name,
        status="healthy",
    )

def _estimate_request_cost(provider_name, prompt: str, max_tokens: int):
    """Return a rough USD cost estimate for one request, or ``None``.

    Heuristic only: input tokens are approximated as twice the
    whitespace-separated word count of *prompt*, output tokens as half of
    *max_tokens*. Returns ``None`` when *provider_name* is falsy or does not
    match any entry in ``llm_client.providers``.
    """
    if not provider_name:
        return None
    provider = next(
        (p for p in llm_client.providers if p["name"] == provider_name), None
    )
    if provider is None:
        return None
    return estimate_cost(
        provider_name,
        provider["model"],
        len(prompt.split()) * 2,  # rough input token estimate
        max_tokens // 2,  # assume half of max_tokens is used
    )


@router.post("/query", response_model=QueryResponse)
@limiter.limit(RATE_LIMIT)
async def query_llm(request: Request, query: QueryRequest, api_key: str = Depends(validate_api_key)):
    """Query the LLM cascade behind layered safety checks.

    Layer 1 runs regex-based screens (prompt injection, PII, hate speech).
    Layer 2 runs the AI toxicity classifier, unless the hate-speech screen
    marked the prompt as educational. Layer 3 sends the prompt through
    ``llm_client.query_llm_cascade``.

    Args:
        request: The incoming request (slowapi's limiter needs it).
        query: Prompt plus generation parameters (max_tokens, temperature).
        api_key: Caller's key, checked by ``validate_api_key``.

    Returns:
        QueryResponse with the completion, provider used, latency, cascade
        path and a rough cost estimate.

    Raises:
        HTTPException: 400 when a safety layer blocks the prompt; 500 when the
            cascade returns no content.
    """

    # ========== LAYER 1: Regex-based pre-screening ==========

    # 1a. Prompt injection check (already done in Pydantic model, but double-check)
    if detect_prompt_injection(query.prompt):
        # Flag the detection type, as batch_security_test does, so live
        # traffic feeds the same metrics counters.
        metrics.record_request(blocked=True, injection_detected=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Security Alert: Prompt injection pattern detected"
        )

    # 1b. PII detection
    pii_result = detect_pii(query.prompt)
    if pii_result["has_pii"]:
        metrics.record_request(blocked=True, pii_detected=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Security Alert: PII detected ({', '.join(pii_result['pii_types'])})"
        )

    # 1c. Hate speech pre-screening
    hate_result = detect_hate_speech(query.prompt)
    if hate_result["is_hate_speech"]:
        metrics.record_request(blocked=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Security Alert: Hate speech detected"
        )

    # ========== LAYER 2: AI-based safety check (Lakera/Gemini) ==========
    # Only runs if the regex layer passes. Skipped for educational content to
    # avoid false positives on questions about hate/prejudice.
    if not hate_result.get("is_educational", False):
        # NOTE(review): detect_toxicity is called synchronously inside an async
        # handler; if it performs network I/O it blocks the event loop --
        # consider running it in a threadpool.
        toxicity_result = detect_toxicity(query.prompt)
        if toxicity_result["is_toxic"]:
            categories = ", ".join(toxicity_result["blocked_categories"]) or "harmful content"
            metrics.record_request(blocked=True)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Security Alert: Content flagged by AI safety ({categories})"
            )

    # ========== LAYER 3: LLM Execution ==========
    response_content, provider_used, latency_ms, error_message, cascade_path = await llm_client.query_llm_cascade(
        prompt=query.prompt,
        max_tokens=query.max_tokens,
        temperature=query.temperature
    )

    if not response_content:
        # No provider produced content (empty content is treated as failure).
        metrics.record_request(cascade_failed=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=error_message or "All LLM providers failed."
        )

    cost_estimate = _estimate_request_cost(provider_used, query.prompt, query.max_tokens)

    metrics.record_request(
        provider=provider_used,
        latency_ms=latency_ms,
        blocked=False
    )

    return QueryResponse(
        response=response_content,
        provider=provider_used,
        latency_ms=latency_ms,
        status="success",
        error=None,
        cascade_path=cascade_path,
        cost_estimate_usd=cost_estimate
    )


@router.get("/metrics")
async def get_metrics():
    """Return the gateway metrics snapshot produced by ``metrics.to_dict()``."""
    snapshot = metrics.to_dict()
    return snapshot


@router.get("/providers")
async def get_providers():
    """List the provider pricing config alongside the active providers.

    ``active_providers`` keeps the order of ``llm_client.providers``;
    ``active_models`` maps each active provider name to its model.
    """
    names = []
    models_by_name = {}
    for entry in llm_client.providers:
        names.append(entry["name"])
        models_by_name[entry["name"]] = entry["model"]
    return {
        "providers": PROVIDER_CONFIG,
        "active_providers": names,
        "active_models": models_by_name,
    }


@router.post("/batch/resilience")
async def batch_resilience_test(
    request: Request,
    batch: BatchRequest,
    api_key: str = Depends(validate_api_key)
):
    """Run a batch of prompts through the LLM cascade and aggregate results.

    Only the first 10 prompts are executed (PoC limit), each with
    ``max_tokens=256`` and ``temperature=0.7``. Each executed prompt is
    recorded in ``metrics`` once: as a success with its provider and latency,
    or as a cascade failure (including when the cascade call raises).

    Returns:
        dict with the number of prompts run, successful/failed counts, the
        total number of failed cascade steps, the mean latency of successful
        prompts, an illustrative "downtime prevented" figure (4 minutes per
        failed step) and per-prompt ``results``.
    """
    # NOTE(review): unlike /query this endpoint has no rate limit, yet each
    # call can trigger up to 10 LLM requests -- consider @limiter.limit.
    results = []
    total_failures = 0
    total_latency = 0

    # Limit to 10 prompts for PoC
    for prompt in batch.prompts[:10]:
        preview = prompt[:50] + "..." if len(prompt) > 50 else prompt

        # Keep the try narrow: only the cascade call is expected to raise.
        try:
            response, provider, latency, _error, cascade_path = await llm_client.query_llm_cascade(
                prompt=prompt,
                max_tokens=256,
                temperature=0.7
            )
        except Exception as e:
            # Record it like any other cascade failure so metrics match the
            # reported results.
            metrics.record_request(cascade_failed=True)
            results.append({
                "prompt": preview,
                "success": False,
                "error": str(e)
            })
            continue

        cascade_path = cascade_path or []
        failures_in_cascade = sum(1 for step in cascade_path if step.get("status") == "failed")
        total_failures += failures_in_cascade

        # Same success criterion as /query: empty content counts as a failure.
        # This keeps `successful` and `total_latency` in agreement.
        success = bool(response)
        if success:
            total_latency += latency
            metrics.record_request(provider=provider, latency_ms=latency)
        else:
            metrics.record_request(cascade_failed=True)

        results.append({
            "prompt": preview,
            "success": success,
            "provider": provider,
            "latency_ms": latency,
            "cascade_path": cascade_path,
            "failures_in_cascade": failures_in_cascade
        })

    successful = sum(1 for r in results if r["success"])
    avg_latency = total_latency / successful if successful else 0

    return {
        "total": len(results),
        "successful": successful,
        "failed": len(results) - successful,
        "total_cascade_failures": total_failures,
        "average_latency_ms": round(avg_latency, 2),
        "downtime_prevented_minutes": round(total_failures * 4, 1),  # 4 min per failure
        "results": results
    }


@router.post("/batch/security")
async def batch_security_test(batch: BatchRequest):
    """Screen a batch of prompts for PII and prompt injection (no LLM calls).

    Only the first 20 prompts are checked. A prompt is "blocked" when it
    contains PII or matches a prompt-injection pattern. Each blocked prompt is
    recorded in ``metrics`` exactly once, flagged with whichever detections
    fired.

    Returns:
        dict with the number of prompts checked, blocked/passed counts, the
        number of PII types found (``pii_leaks_prevented``), injection
        attempts, an illustrative compliance-fine figure and per-prompt
        ``results``.
    """
    # NOTE(review): unlike /batch/resilience this endpoint has no API-key
    # dependency, yet it writes to the shared metrics -- confirm intended.
    results = []
    total_blocked = 0
    pii_leaks = 0
    injection_attempts = 0

    for prompt in batch.prompts[:20]:  # Limit to 20
        pii_result = detect_pii(prompt)
        has_pii = pii_result["has_pii"]
        injection_detected = detect_prompt_injection(prompt)

        blocked = has_pii or injection_detected

        if has_pii:
            pii_leaks += len(pii_result["pii_types"])
        if injection_detected:
            injection_attempts += 1
        if blocked:
            total_blocked += 1
            # A single record per prompt. Recording PII and injection
            # separately counted a prompt with both as two blocked requests.
            metrics.record_request(
                blocked=True,
                pii_detected=bool(has_pii),
                injection_detected=bool(injection_detected),
            )

        results.append({
            "prompt": prompt[:50] + "..." if len(prompt) > 50 else prompt,
            "blocked": blocked,
            "pii_detected": pii_result["pii_types"] if has_pii else [],
            "pii_matches": pii_result["matches"] if has_pii else {},
            "injection_detected": injection_detected
        })

    # Illustrative figure: ~$28K per PII type found. This is roughly the
    # average of a typical GDPR fine (~$50K) and a CCPA fine (~$7.5K).
    compliance_fines_avoided = pii_leaks * 28000

    return {
        "total": len(results),
        "blocked": total_blocked,
        "passed": len(results) - total_blocked,
        "pii_leaks_prevented": pii_leaks,
        "injection_attempts_blocked": injection_attempts,
        "compliance_fines_avoided_usd": compliance_fines_avoided,
        "results": results
    }


class ToxicityRequest(BaseModel):
    """Request body for ``/check-toxicity``: the text to classify."""

    text: str


@router.post("/check-toxicity")
async def check_toxicity(request: ToxicityRequest):
    """
    Classify text with the AI safety checker.

    Returns whether the text is toxic, the per-category scores and the blocked
    categories. Any error reported by the checker is replaced with a generic
    message so internal details are not exposed to callers.
    """
    outcome = detect_toxicity(request.text)
    # Only signal that an error occurred; never echo the underlying text.
    sanitized_error = None
    if outcome["error"] is not None:
        sanitized_error = "Safety check encountered an issue"
    return {
        "is_toxic": outcome["is_toxic"],
        "scores": outcome["scores"],
        "blocked_categories": outcome["blocked_categories"],
        "error": sanitized_error,
    }