""" Security utilities for the Enterprise AI Gateway """ import os import re import logging import requests from fastapi import HTTPException, Depends, status from fastapi.security import APIKeyHeader from ..config import TOXICITY_THRESHOLD_DEFAULT, TOXICITY_THRESHOLD_HATE logger = logging.getLogger(__name__) # --- Security Configuration --- API_KEY_NAME = "X-API-Key" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) SERVICE_API_KEY = os.getenv("SERVICE_API_KEY") ENABLE_PROMPT_INJECTION_CHECK = os.getenv("ENABLE_PROMPT_INJECTION_CHECK", "true").lower() == "true" # --- Prompt Injection Detection --- INJECTION_PATTERNS = [ # Existing patterns - direct override attempts r"ignore\s+(all\s+)?(previous|above|prior)\s+instructions?", r"disregard\s+(all\s+)?(previous|above|prior)\s+instructions?", r"you\s+are\s+now", r"system\s*:\s*", # Demonstration/example requests for attacks r"(demonstrate|show|give\s+me|provide).{0,30}(injection|jailbreak|exploit|attack|bypass|hack)", r"(example|sample|demo).{0,20}(of|for).{0,20}(injection|jailbreak|exploit|attack|bypass)", r"(how\s+to|how\s+would|how\s+do\s+i|how\s+do\s+you).{0,30}(inject|jailbreak|exploit|bypass|hack)", r"by\s+demonstrating.{0,20}(injection|jailbreak|exploit|attack|bypass|one)", r"(injection|jailbreak|exploit|attack).{0,20}by\s+demonstrating", # Simulate/role-play attack requests r"(simulate|emulate|replicate|recreate).{0,30}(attack|injection|exploit|breach|hack)", r"(pretend|act\s+as\s+if|imagine).{0,30}(hacked|breached|compromised|no.{0,10}filter)", # Ignore/disable safety patterns (embedded in any context) r"ignore\s+(all\s+)?(safety|filter|restriction|guideline|rule|moderation)", r"(disable|remove|bypass|skip)\s+(all\s+)?(safety|filter|restriction|moderation)", # Completion attacks (asking to complete sensitive phrases) r"complete\s+(this|the)?\s*(sentence|phrase|text).{0,30}(password|secret|credential|key)", # Repeat-after-me attacks r"(repeat\s+after|say\s+after|copy\s+this).{0,30}(no\s+restriction|no\s+filter|no\s+limit|no\s+rule)", r"(repeat\s+after|say\s+after).{0,10}(me|this)", # Role-play without safety r"role.?play.{0,30}(without|no)\s+(any\s+)?(safety|filter|guideline|restriction|rule)", r"(act|behave|respond)\s+as\s+(if|though).{0,20}(no|without).{0,10}(filter|safety|restriction)", r"ai\s+without\s+(any\s+)?(safety|filter|guideline|restriction)", # Hypothetical bypass requests r"(hypothetically|theoretically|in\s+theory).{0,30}(bypass|hack|inject|jailbreak|exploit)", # Encoded instruction attacks r"(base64|encoded|decrypt|decode).{0,20}(instruction|command|message|this)", ] def detect_prompt_injection(prompt: str) -> bool: """Detect potential prompt injection attacks""" if not ENABLE_PROMPT_INJECTION_CHECK: return False prompt_lower = prompt.lower() for pattern in INJECTION_PATTERNS: if re.search(pattern, prompt_lower, re.IGNORECASE): return True return False # --- PII Detection --- PII_PATTERNS = { # Existing patterns "email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", "credit_card": r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b", "ssn": r"\b\d{3}-\d{2}-\d{4}\b", "tax_id": r"\b\d{2}-\d{7}\b", # API Keys - OpenAI/Anthropic style (sk-proj-xxx, sk-ant-xxx) "api_key_openai": r"\bsk-[a-zA-Z0-9\-_]{20,}\b", # API Keys - AWS access keys "api_key_aws": r"\bAKIA[0-9A-Z]{16}\b", # API Keys - Generic patterns with labels "api_key_labeled": r"(api[_\-]?key|apikey|api[_\-]?secret|secret[_\-]?key|access[_\-]?token|bearer)[\s:=]+\S+", # Passport Numbers - 1-2 letters followed by 6-9 digits "passport": r"\b[A-Z]{1,2}[\s\-]?\d{6,9}\b", # Driver's License - Letter followed by digit groups "drivers_license_dashed": r"\b[A-Z]\d{3}[\-\s]?\d{4}[\-\s]?\d{4}\b", "drivers_license_simple": r"\b[A-Z]\d{6,12}\b", # Medical Records - DOB patterns "medical_dob": r"(DOB|date\s+of\s+birth|d\.o\.b\.?)[\s:]+\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}", # Medical Records - MRN patterns "medical_mrn": r"(MRN|medical\s+record|patient\s+id|patient\s+number)[\s:#\-]+\d{4,}", # Phone Numbers - US formats "phone_us": r"\b(\+?1[\s\-.]?)?\(?\d{3}\)?[\s\-.]?\d{3}[\s\-.]?\d{4}\b", # Phone Numbers - International with + prefix "phone_intl": r"\+\d{1,3}[\s\-.]?\d{2,4}[\s\-.]?\d{3,4}[\s\-.]?\d{3,4}\b", # IBAN - International Bank Account Number "iban": r"\b[A-Z]{2}\d{2}[\s]?[A-Z0-9]{4}[\s]?[A-Z0-9]{4}[\s]?[A-Z0-9]{4}[\s]?[A-Z0-9]{0,14}\b", # Password - Labeled patterns "password_labeled": r"(password|passwd|pwd|pass)[\s:=]+\S+", # Database connection strings with credentials "db_connection": r"(postgres|mysql|mongodb|redis)://[^:]+:[^@]+@", } def detect_pii(prompt: str) -> dict: """Detect PII in prompt, returns {has_pii: bool, pii_types: list, matches: dict}""" matches = {} pii_types = [] for pii_type, pattern in PII_PATTERNS.items(): found = re.findall(pattern, prompt, re.IGNORECASE) if found: pii_types.append(pii_type) matches[pii_type] = len(found) return { "has_pii": len(pii_types) > 0, "pii_types": pii_types, "matches": matches } # --- Hate Speech Pre-Screening (Approach A) --- # Regex-based pre-screening for hate speech indicators # Runs BEFORE AI safety classifier to catch subtle hate speech HATE_SPEECH_PATTERNS = [ # Hate verbs targeting people/groups who are different r"(hate|despise|loathe|detest|can'?t\s+stand|disgust).{0,30}(people|persons|those|them|everyone|anybody|anyone).{0,30}(who\s+are|who\s+look|who\s+come|different|foreign|other|not\s+like\s+me|unlike\s+me)", r"(people|persons|those|them).{0,20}(who\s+are|who\s+look).{0,20}(different|foreign|other).{0,20}(are\s+)?(disgust|repuls|sicken|hate)", # Dehumanizing language r"(people|they|them|those).{0,20}(are\s+animals|are\s+subhuman|are\s+vermin|are\s+parasites|are\s+cockroaches)", r"(people|they|them|those).{0,20}(don'?t\s+belong|should\s+go\s+back|should\s+be\s+removed|should\s+be\s+deported|have\s+no\s+place)", # Supremacist framing r"(superior|inferior|pure|impure).{0,20}(race|blood|people|kind|breed|stock)", r"(our\s+kind|my\s+kind|our\s+people).{0,20}(better|superior|pure)", # Direct expressions of hatred toward groups r"(i\s+really\s+)?(hate|despise|loathe).{0,20}(people|those|them).{0,20}(different|like\s+them|foreign)", r"(don'?t\s+look\s+like\s+me|different\s+from\s+me).{0,20}(disgust|hate|despise|loathe)", ] # Patterns that indicate EDUCATIONAL context (should NOT block) HATE_SPEECH_EDUCATIONAL_PATTERNS = [ r"(explain|history|overcome|causes|why\s+do|what\s+causes|how\s+can\s+i|how\s+to\s+combat|how\s+to\s+fight|prevent|understand)", r"(civil\s+rights|discrimination|prejudice|bias|racism).{0,20}(movement|history|explained|education)", ] def detect_hate_speech(prompt: str) -> dict: """ Pre-screen for hate speech indicators before AI safety classifier. Returns: {is_hate_speech: bool, matched_pattern: str|None, is_educational: bool} """ prompt_lower = prompt.lower() # First check if this is educational context is_educational = any( re.search(pattern, prompt_lower, re.IGNORECASE) for pattern in HATE_SPEECH_EDUCATIONAL_PATTERNS ) # If educational, don't flag as hate speech if is_educational: return { "is_hate_speech": False, "matched_pattern": None, "is_educational": True } # Check for hate speech patterns for pattern in HATE_SPEECH_PATTERNS: match = re.search(pattern, prompt_lower, re.IGNORECASE) if match: return { "is_hate_speech": True, "matched_pattern": match.group(), "is_educational": False } return { "is_hate_speech": False, "matched_pattern": None, "is_educational": False } # --- API Key Validation --- async def validate_api_key(api_key: str = Depends(api_key_header)): """Validate API key for request authentication""" if not SERVICE_API_KEY: raise HTTPException(status_code=500, detail="Server misconfiguration: API Key missing") if api_key != SERVICE_API_KEY: raise HTTPException(status_code=401, detail="Invalid or missing API key") return api_key # --- Gemini Safety Filter for Toxicity Detection --- # Uses Gemini's built-in content safety that returns HARM_CATEGORY_* ratings # Uses GEMINI_MODEL env var or defaults to gemini-2.5-flash def get_gemini_safety_url(): model = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") return f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent" # --- Lakera Guard API (Fallback) --- LAKERA_API_URL = "https://api.lakera.ai/v2/guard" # Gemini harm categories (all available categories) HARM_CATEGORIES = [ "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_CIVIC_INTEGRITY", ] def detect_toxicity(text: str) -> dict: """ Detect toxic content using Gemini's built-in safety filters. Uses GEMINI_API_KEY environment variable for authentication. Gemini 2.5 models handle safety by refusing harmful content. Returns: {is_toxic: bool, scores: dict, blocked_categories: list, error: str|None} """ # Read API key at runtime to pick up HF Spaces secrets api_key = os.getenv("GEMINI_API_KEY") if not api_key: logger.warning("GEMINI_API_KEY not configured, falling back to Lakera") return detect_toxicity_lakera(text) try: # Ask Gemini to classify if the content is harmful classification_prompt = f"""Analyze this text and determine if it requests harmful content. Respond with ONLY one of these categories, or "SAFE" if harmless: - SEXUALLY_EXPLICIT (nude, porn, explicit sexual content) - HATE_SPEECH (racism, discrimination, slurs) - HARASSMENT (threats, bullying, intimidation) - DANGEROUS_CONTENT (weapons, drugs, violence, self-harm) - CIVIC_INTEGRITY (election fraud, voter suppression) Text to analyze: "{text}" Category:""" payload = { "contents": [{"parts": [{"text": classification_prompt}]}], } response = requests.post( f"{get_gemini_safety_url()}?key={api_key}", json=payload, headers={"Content-Type": "application/json"}, timeout=10 ) if response.status_code != 200: error_detail = "" try: error_detail = response.json().get("error", {}).get("message", "") except: pass logger.warning(f"Gemini API error {response.status_code}: {error_detail}, falling back to Lakera") return detect_toxicity_lakera(text) data = response.json() blocked_categories = [] scores = {} # Check if request was blocked at prompt level if "promptFeedback" in data: feedback = data["promptFeedback"] if feedback.get("blockReason"): blocked_categories.append(feedback["blockReason"]) return { "is_toxic": True, "scores": {"BLOCKED": 1.0}, "blocked_categories": blocked_categories, "error": None } # Parse Gemini's classification response if "candidates" in data and data["candidates"]: response_text = "" for part in data["candidates"][0].get("content", {}).get("parts", []): response_text += part.get("text", "") response_text = response_text.strip().upper() # Check for harmful categories harmful_categories = [ "SEXUALLY_EXPLICIT", "HATE_SPEECH", "HARASSMENT", "DANGEROUS_CONTENT", "CIVIC_INTEGRITY" ] for category in harmful_categories: if category in response_text: blocked_categories.append(f"HARM_CATEGORY_{category}") scores[f"HARM_CATEGORY_{category}"] = 0.9 # If Gemini says SAFE or doesn't match categories if not blocked_categories: scores["SAFE"] = 1.0 return { "is_toxic": len(blocked_categories) > 0, "scores": scores, "blocked_categories": blocked_categories, "error": None } except requests.exceptions.Timeout: logger.warning("Gemini API timeout, falling back to Lakera") return detect_toxicity_lakera(text) except Exception as e: logger.warning(f"Gemini API exception: {e}, falling back to Lakera") return detect_toxicity_lakera(text) def detect_toxicity_lakera(text: str) -> dict: """ Fallback toxicity detection using Lakera Guard API. Uses LAKERA_API_KEY environment variable for authentication. Returns: {is_toxic: bool, scores: dict, blocked_categories: list, error: str|None} """ api_key = os.getenv("LAKERA_API_KEY") if not api_key: logger.warning("LAKERA_API_KEY not configured, skipping toxicity check") # Both Gemini and Lakera unavailable - allow request to proceed return { "is_toxic": False, "scores": {}, "blocked_categories": [], "error": None # Don't error, just skip check } try: payload = { "messages": [{"content": text, "role": "user"}] } response = requests.post( LAKERA_API_URL, json=payload, headers={ "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" }, timeout=10 ) if response.status_code != 200: error_detail = "" try: error_detail = response.json().get("error", response.text) except: error_detail = response.text logger.warning(f"Lakera API error {response.status_code}: {error_detail}") # Both APIs failed - allow request to proceed return { "is_toxic": False, "scores": {}, "blocked_categories": [], "error": None # Don't block user, just skip check } data = response.json() blocked_categories = [] scores = {} # Lakera returns categories with flagged status # Check for flagged content in results results = data.get("results", []) for result in results: categories = result.get("categories", {}) for category, flagged in categories.items(): if flagged: blocked_categories.append(f"LAKERA_{category.upper()}") scores[f"LAKERA_{category.upper()}"] = 1.0 else: scores[f"LAKERA_{category.upper()}"] = 0.0 # Also check category_scores for more detail category_scores = result.get("category_scores", {}) for category, score in category_scores.items(): scores[f"LAKERA_{category.upper()}"] = score # Check top-level flagged status is_flagged = data.get("flagged", False) if is_flagged and not blocked_categories: blocked_categories.append("LAKERA_FLAGGED") return { "is_toxic": is_flagged or len(blocked_categories) > 0, "scores": scores, "blocked_categories": blocked_categories, "error": None } except requests.exceptions.Timeout: logger.warning("Lakera API timeout") return { "is_toxic": False, "scores": {}, "blocked_categories": [], "error": None # Don't block user } except Exception as e: logger.warning(f"Lakera API exception: {e}") return { "is_toxic": False, "scores": {}, "blocked_categories": [], "error": None # Don't block user }