Spaces:
Sleeping
Sleeping
| """ | |
| ClauseGuard β FastAPI Backend v4.1 | |
| ββββββββββββββββββββββββββββββββββ | |
| Fixes in v4.1: | |
| β’ FIX: Rate limiter uses sliding window with proper IP extraction (X-Forwarded-For) | |
| β’ FIX: RAG sessions have TTL-based expiry (1 hour) instead of just count-based | |
| β’ FIX: Input text size validation (max 200KB) | |
| β’ FIX: Proper error handling for all endpoints | |
| """ | |
| import os | |
| import re | |
| import json | |
| import time | |
| import uuid | |
| import tempfile | |
| from contextlib import asynccontextmanager | |
| from typing import Optional | |
| from collections import defaultdict | |
| from datetime import datetime | |
| import httpx | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException, Depends, Body, Request, UploadFile, File as FastAPIFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from auth import get_current_user, require_auth | |
# -- Import shared modules --
# Make this file's directory and its parent importable so the shared
# analysis modules (app, obligations, compliance, ...) resolve regardless
# of the working directory the server is launched from.
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
    from app import (
        split_clauses, classify_cuad, extract_entities,
        detect_contradictions, compute_risk_score, analyze_contract,
        CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
        cuad_model, cuad_tokenizer
    )
    from obligations import extract_obligations
    from compliance import check_compliance
    from compare import compare_contracts
    from redlining import generate_redlines
    from chatbot import index_contract, chat_respond
    from ocr_engine import parse_pdf_smart, get_ocr_status
    _SHARED_MODULES = True
except ImportError as e:
    # Degrade gracefully: endpoints consult _SHARED_MODULES and report
    # reduced capability instead of crashing at import time.
    _SHARED_MODULES = False
    print(f"[API] WARNING: Could not import shared modules: {e}")
# --- Config ---
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
# Service-role key: server-side only, must never be exposed to clients.
SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
# Optional hosted-LLM endpoint used by the explain handler for richer answers.
SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
# Input size cap (characters, ~200KB by default); enforced by the endpoints.
MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "200000"))
# --- FIX v4.2: Improved sliding window rate limiter with periodic cleanup ---
_rate_limits: dict[str, list[float]] = {}   # per-IP request timestamps (epoch seconds)
_rate_limits_last_cleanup: float = 0.0      # when the last stale-IP sweep ran
RATE_LIMIT_REQUESTS = 30
RATE_LIMIT_WINDOW = 60  # seconds

def _get_client_ip(request: Request) -> str:
    """Return the originating client IP, honouring reverse-proxy headers.

    When the app sits behind a proxy, the first entry of ``X-Forwarded-For``
    is the original client; otherwise fall back to the socket peer address.
    """
    forwarded_chain = request.headers.get("x-forwarded-for", "")
    if forwarded_chain:
        return forwarded_chain.split(",")[0].strip()
    if request.client:
        return request.client.host
    return "unknown"

def _check_rate_limit(client_ip: str) -> bool:
    """Record one request for *client_ip*; return False once over the limit.

    Sliding-window limiter: keeps per-IP timestamps from the last
    RATE_LIMIT_WINDOW seconds, and roughly once a minute sweeps IPs that have
    been idle for two full windows so the table cannot grow without bound.
    """
    global _rate_limits_last_cleanup
    now = time.time()
    # FIX v4.2: Periodic cleanup every 60s regardless of dict size.
    if now - _rate_limits_last_cleanup > 60:
        idle_ips = [
            ip for ip, stamps in _rate_limits.items()
            if not stamps or now - stamps[-1] > RATE_LIMIT_WINDOW * 2
        ]
        for ip in idle_ips:
            del _rate_limits[ip]
        _rate_limits_last_cleanup = now
    # Keep only timestamps still inside the window, then test the budget.
    recent = [t for t in _rate_limits.get(client_ip, []) if now - t < RATE_LIMIT_WINDOW]
    _rate_limits[client_ip] = recent
    if len(recent) >= RATE_LIMIT_REQUESTS:
        return False
    recent.append(now)
    return True
# --- Supabase helper ---
async def supabase_insert(table: str, data: dict):
    """Best-effort insert of *data* into a Supabase REST table.

    Silently no-ops when Supabase is not configured, and swallows all
    transport errors: persistence must never break the serving path.
    """
    if not (SUPABASE_URL and SUPABASE_SERVICE_KEY):
        return
    endpoint = f"{SUPABASE_URL}/rest/v1/{table}"
    request_headers = {
        "apikey": SUPABASE_SERVICE_KEY,
        "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
        "Content-Type": "application/json",
        "Prefer": "return=minimal",
    }
    try:
        async with httpx.AsyncClient() as client:
            await client.post(endpoint, json=data, headers=request_headers, timeout=10.0)
    except Exception:
        # Deliberate best-effort write; failures are ignored.
        pass
async def supabase_query(table: str, params: dict, headers_extra: Optional[dict] = None):
    """Query a Supabase REST table and return the decoded JSON rows.

    Args:
        table: Supabase table name (appended to ``/rest/v1/``).
        params: PostgREST query params (filters, select, order, limit, ...).
        headers_extra: optional extra headers merged over the auth headers.

    Returns:
        The list of rows on HTTP 200; ``[]`` when Supabase is unconfigured,
        on any transport error, or on a non-200 response.
    """
    # FIX: replaced mutable default argument (`headers_extra: dict = {}`)
    # with the None-sentinel idiom; default behavior is unchanged.
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return []
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{SUPABASE_URL}/rest/v1/{table}",
                params=params,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    **(headers_extra or {}),
                },
                timeout=10.0,
            )
        return resp.json() if resp.status_code == 200 else []
    except Exception:
        return []
# --- FIX v4.1: RAG sessions with TTL-based expiry ---
_rag_sessions: dict[str, dict] = {}
_RAG_SESSION_MAX = 100      # hard cap on concurrently stored sessions
_RAG_SESSION_TTL = 3600     # 1 hour

def _cleanup_rag_sessions():
    """Drop every RAG session older than the TTL."""
    cutoff = time.time() - _RAG_SESSION_TTL
    stale_ids = [
        sid for sid, sess in _rag_sessions.items()
        if sess.get("created_at", 0) < cutoff
    ]
    for sid in stale_ids:
        del _rag_sessions[sid]

def _store_rag_session(session_id: str, data: dict):
    """Stamp *data* with a creation time and store it under *session_id*.

    Expired sessions are swept first; if the store is still at capacity,
    the oldest surviving session is evicted to make room.
    """
    _cleanup_rag_sessions()
    if len(_rag_sessions) >= _RAG_SESSION_MAX:
        oldest_sid = min(_rag_sessions, key=lambda sid: _rag_sessions[sid].get("created_at", 0))
        del _rag_sessions[oldest_sid]
    data["created_at"] = time.time()
    _rag_sessions[session_id] = data
# --- Request/Response Models ---
class AnalyzeRequest(BaseModel):
    """Payload for the analyze endpoint: raw text and/or pre-split clauses."""
    # Full contract text; optional when `clauses` is supplied instead.
    text: Optional[str] = Field(None, min_length=50)
    # Pre-split clause strings; joined with blank lines when `text` is absent.
    clauses: Optional[list] = None
    # Origin URL of the contract, persisted alongside the analysis if provided.
    source_url: Optional[str] = None
class CompareRequest(BaseModel):
    """Payload for the contract-comparison endpoint: the two texts to diff."""
    text_a: str = Field(..., min_length=50)
    text_b: str = Field(..., min_length=50)
class ExplainRequest(BaseModel):
    """Payload for the clause-explanation endpoint."""
    # The flagged clause text to explain.
    clause: str = Field(..., min_length=10, max_length=2000)
    # Category label the clause was classified under (key into DESC_MAP).
    category: str
class ExplainResponse(BaseModel):
    """Structured explanation returned by the explain endpoint."""
    clause: str
    category: str
    # Plain-English explanation of why the clause may be risky.
    explanation: str
    # Legal grounds supporting the explanation.
    legal_basis: str
    # Suggested action for the reader.
    recommendation: str
class ChatRequest(BaseModel):
    """Payload for the chat endpoints (plain and streaming)."""
    # User question about the analyzed contract.
    message: str = Field(..., min_length=1, max_length=2000)
    # RAG session id returned by a prior analysis call.
    session_id: str
    # Optional prior conversation turns, forwarded to chat_respond.
    history: Optional[list[dict]] = None
class RedlineRequest(BaseModel):
    """Payload for the redline endpoint: supply a session id or raw text."""
    # Existing RAG session whose cached analysis should be redlined.
    session_id: Optional[str] = None
    # Raw contract text to analyze from scratch when no session is given.
    text: Optional[str] = None
    # Forwarded to generate_redlines as use_llm.
    use_llm: bool = True
# --- App ---
# FIX: FastAPI expects `lifespan` to be an async context manager; the bare
# async-generator form is deprecated in Starlette. `asynccontextmanager` is
# already imported at the top of this file.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook (currently a no-op)."""
    yield
app = FastAPI(title="ClauseGuard API", version="4.1.0", lifespan=lifespan)
# FIX v4.2: CORS origins configurable via env var; localhost only in dev.
_extra_origins = os.environ.get("CORS_EXTRA_ORIGINS", "").split(",")
ALLOWED_ORIGINS = [
    "https://clauseguardweb.netlify.app",
]
# Only add localhost origins if explicitly enabled via env.
if os.environ.get("CORS_ALLOW_LOCALHOST", "").lower() == "true":
    ALLOWED_ORIGINS.extend(["http://localhost:3000", "http://localhost:3001"])
ALLOWED_ORIGINS.extend([o.strip() for o in _extra_origins if o.strip()])
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # Browser extensions present a dynamic origin; accept any extension id.
    allow_origin_regex=r"^chrome-extension://.*$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): the route decorator was missing in the source (likely lost in
# extraction); path assumed from convention -- confirm against clients.
@app.get("/health")
async def health():
    """Liveness probe: reports model mode, OCR status and active RAG sessions."""
    model_status = "ml" if _SHARED_MODULES and cuad_model else "regex"
    ocr_status = get_ocr_status() if _SHARED_MODULES else "unavailable"
    return {
        "status": "ok",
        "model": model_status,
        "version": "4.1.0",
        "shared_modules": _SHARED_MODULES,
        "ocr": ocr_status,
        "features": ["analyze", "compare", "redline", "chat", "ocr"],
        "rag_sessions_active": len(_rag_sessions),
    }
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/analyze")
async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
    """Full contract analysis pipeline.

    Splits the text into clauses, classifies each, extracts entities and
    obligations, checks compliance, computes a risk score, generates redlines,
    indexes the contract for RAG chat, and (for authenticated users) persists
    the result to Supabase.

    Raises:
        HTTPException 429: rate limit exceeded.
        HTTPException 400: text missing/too short/too long, or no clauses found.
    """
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded. Please wait 60 seconds.")
    text = req.text
    if not text and req.clauses:
        text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)
    if not text or len(text.strip()) < 50:
        raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
    # FIX v4.1: Input size validation
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH // 1000}KB)")

    start = time.time()
    clauses = split_clauses(text)
    if not clauses:
        raise HTTPException(status_code=400, detail="No clauses detected")

    # One result row per (clause, predicted category) pair.
    clause_results = []
    for clause in clauses:
        for pred in classify_cuad(clause) or []:
            clause_results.append({
                "text": clause,
                "label": pred["label"],
                "confidence": pred["confidence"],
                "risk": pred["risk"],
                "description": pred["description"],
                "source": pred.get("source", "unknown"),
            })

    entities = extract_entities(text)
    contradictions = detect_contradictions(clause_results, text)
    risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
    obligations = extract_obligations(text)
    compliance = check_compliance(text)

    # v4.0: Redlining -- best-effort, never fails the whole analysis.
    redlines = []
    try:
        redlines = generate_redlines({"clauses": clause_results}, use_llm=True)
    except Exception as e:
        print(f"[API] Redlining error: {e}")

    latency = int((time.time() - start) * 1000)
    # FIX: compute the distinct-flagged-clause count once instead of twice.
    flagged_count = len(set(cr["text"] for cr in clause_results))
    results_for_db = [
        {
            "text": cr["text"],
            "categories": [{
                "name": cr["label"],
                "severity": cr["risk"],
                "confidence": cr["confidence"],
                "description": cr["description"],
            }],
        }
        for cr in clause_results
    ]

    # RAG indexing with TTL-managed sessions (best-effort).
    session_id = None
    try:
        chunks, embeddings, _status = index_contract(text)
        if chunks and embeddings is not None:
            session_id = uuid.uuid4().hex[:12]
            _store_rag_session(session_id, {
                "chunks": chunks,
                "embeddings": embeddings,
                "analysis": {
                    "risk": {"score": risk, "grade": grade, "breakdown": sev_counts},
                    "metadata": {"total_clauses": len(clauses), "flagged_clauses": len(clause_results)},
                    "clauses": clause_results[:30],
                    "entities": entities[:30],
                    "contradictions": contradictions,
                },
            })
    except Exception as e:
        print(f"[API] RAG indexing error: {e}")

    if user:
        await supabase_insert("analyses", {
            "user_id": user["id"],
            "source_url": req.source_url,
            "total_clauses": len(clauses),
            "flagged_count": flagged_count,
            "risk_score": risk,
            "grade": grade,
            "clauses": results_for_db,
            "entities": entities,
            "contradictions": contradictions,
            "obligations": obligations,
            "compliance": compliance,
        })

    return {
        "risk_score": risk,
        "grade": grade,
        "total_clauses": len(clauses),
        "flagged_count": flagged_count,
        "results": results_for_db,
        "entities": entities,
        "contradictions": contradictions,
        "obligations": obligations,
        "compliance": compliance,
        "redlines": redlines,
        "model": "ml" if cuad_model else "regex",
        "latency_ms": latency,
        "session_id": session_id,
    }
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/compare")
async def compare(req: CompareRequest, request: Request):
    """Diff two contracts; rate-limited, with a per-contract size cap."""
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    # FIX v4.1: Input size validation for comparison
    if len(req.text_a) > MAX_TEXT_LENGTH or len(req.text_b) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH // 1000}KB per contract)")
    return compare_contracts(req.text_a, req.text_b)
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/redline")
async def redline(req: RedlineRequest, request: Request):
    """Generate redline suggestions from a cached session or from raw text.

    Prefers the analysis cached in an existing RAG session; otherwise
    re-analyzes the supplied text. Raises 400 when neither is usable and
    429 when rate limited.
    """
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    if req.session_id and req.session_id in _rag_sessions:
        analysis = _rag_sessions[req.session_id]["analysis"]
    elif req.text:
        if len(req.text) > MAX_TEXT_LENGTH:
            raise HTTPException(status_code=400, detail="Text too long")
        result, error = analyze_contract(req.text)
        if error:
            raise HTTPException(status_code=400, detail=error)
        analysis = result
    else:
        raise HTTPException(status_code=400, detail="Provide session_id or text")
    redlines = generate_redlines(analysis, use_llm=req.use_llm)
    return {"redlines": redlines, "count": len(redlines)}
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/chat")
async def chat(req: ChatRequest, request: Request):
    """Non-streaming contract Q&A over a previously indexed RAG session."""
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    # FIX v4.1: Clean up expired sessions before checking
    _cleanup_rag_sessions()
    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session expired or not found. Please analyze a contract first.")
    session = _rag_sessions[req.session_id]
    # chat_respond yields cumulative partials; keep only the final one.
    response_text = ""
    for partial in chat_respond(req.message, req.history or [],
                                session["chunks"], session["embeddings"], session["analysis"]):
        response_text = partial
    return {"response": response_text, "session_id": req.session_id}
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/chat/stream")
async def chat_stream(req: ChatRequest, request: Request):
    """Streaming (SSE) variant of chat: emits JSON deltas, then ``[DONE]``."""
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    _cleanup_rag_sessions()
    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session expired or not found.")
    session = _rag_sessions[req.session_id]

    async def generate():
        # chat_respond yields cumulative text; forward only the new suffix.
        last = ""
        for partial in chat_respond(
            req.message, req.history or [],
            session["chunks"], session["embeddings"], session["analysis"]
        ):
            delta = partial[len(last):]
            last = partial
            if delta:
                yield f"data: {json.dumps({'delta': delta})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/ocr")
async def ocr_endpoint(file: UploadFile = FastAPIFile(...)):
    """Extract text from an uploaded PDF via the smart PDF parser.

    Raises 400 for non-PDF uploads, oversized files (>20MB), or parse errors.
    """
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files supported")
    # FIX v4.1: Limit upload size (20MB)
    content = await file.read()
    if len(content) > 20 * 1024 * 1024:
        raise HTTPException(status_code=400, detail="File too large (max 20MB)")
    # parse_pdf_smart takes a path, so spill the upload to a temp file.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name
    try:
        text, error, method = parse_pdf_smart(tmp_path)
        if error:
            raise HTTPException(status_code=400, detail=error)
        return {"text": text, "method": method, "chars": len(text) if text else 0, "filename": file.filename}
    finally:
        # Always remove the temp file, even when parsing raises.
        os.unlink(tmp_path)
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.post("/explain")
async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
    """Explain why a flagged clause may be risky (auth required).

    Starts from static per-category descriptions; when a SaulLM endpoint is
    configured, asks it for a richer answer and best-effort splits the reply
    into explanation / legal basis / recommendation. Any LLM failure falls
    back to the static text.
    """
    desc = DESC_MAP.get(req.category, "Unknown category.")
    legal = "Consult local consumer protection laws."
    recommendation = "Review this clause carefully."
    if SAULLM_ENDPOINT and HF_API_TOKEN:
        try:
            prompt = (
                f"Analyze this contract clause and explain why it may be risky.\n\n"
                f"Clause: \"{req.clause}\"\nCategory: {req.category}\n\n"
                f"Provide: 1) Plain-English explanation 2) Legal basis 3) Recommendation"
            )
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.post(
                    SAULLM_ENDPOINT,
                    json={"inputs": prompt, "parameters": {"max_new_tokens": 300, "temperature": 0.3}},
                    headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
                )
                if resp.status_code == 200:
                    output = resp.json()
                    generated = output[0]["generated_text"] if isinstance(output, list) else output.get("generated_text", "")
                    if generated and len(generated) > 50:
                        # Reply sections are assumed to be blank-line separated.
                        parts = generated.split("\n\n")
                        desc = parts[0] if len(parts) > 0 else desc
                        legal = parts[1] if len(parts) > 1 else legal
                        recommendation = parts[2] if len(parts) > 2 else recommendation
        except Exception:
            # LLM enrichment is optional; keep the static fallback on failure.
            pass
    return ExplainResponse(clause=req.clause, category=req.category,
                           explanation=desc, legal_basis=legal, recommendation=recommendation)
# NOTE(review): route decorator was missing in the source; path assumed -- confirm.
@app.get("/history")
async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
    """Return the caller's past analyses, newest first (paginated).

    FIX: clamp pagination params so negative or zero values cannot reach the
    Supabase query (previously only an upper bound on limit was applied).
    """
    limit = max(1, min(limit, 100))
    offset = max(0, offset)
    data = await supabase_query("analyses", {
        "user_id": f"eq.{user['id']}", "select": "*",
        "order": "created_at.desc", "limit": str(limit), "offset": str(offset),
    })
    return {"analyses": data, "limit": limit, "offset": offset}
if __name__ == "__main__":
    # Local development entry point: serve the app on all interfaces.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)