Spaces:
Sleeping
Sleeping
"""
ClauseGuard — FastAPI Backend (Production)
Clause classification + explanations + history + JWT auth.
FastAPI 0.136, Pydantic 2.13, Python 3.12 (April 2026)
"""
import asyncio
import logging
import os
import re
import time
from contextlib import asynccontextmanager
from typing import Optional

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from auth import get_current_user, require_auth
# --- Config ---
# All settings come from environment variables; empty strings mean "not configured".
MODEL_PATH = os.environ.get("MODEL_PATH", "./clauseguard-model/final")
ONNX_MODEL_PATH = os.environ.get("ONNX_MODEL_PATH", "./clauseguard-model-onnx")
# ONNX is on by default; only the value "true" (any case) enables it.
USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true"
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")

# The unfair-clause categories. The list index is the key used by PATTERNS
# (the regex fallback looks names up as LABEL_NAMES[index]).
LABEL_NAMES = [
    "Limitation of liability", "Unilateral termination", "Unilateral change",
    "Content removal", "Contract by using", "Choice of law", "Jurisdiction", "Arbitration",
]

# Plain-English, user-facing description of each category.
LABEL_DESCRIPTIONS = {
    "Limitation of liability": "Company limits or excludes liability for losses, data breaches, or service failures.",
    "Unilateral termination": "Company can terminate your account at any time without reason.",
    "Unilateral change": "Company can change terms at any time without your consent.",
    "Content removal": "Company can delete your content without notice or justification.",
    "Contract by using": "You are bound to the contract simply by using the service.",
    "Choice of law": "Governing law may differ from your country, reducing your legal protections.",
    "Jurisdiction": "Disputes must be resolved in a jurisdiction that may disadvantage you.",
    "Arbitration": "Forces disputes to arbitration instead of court. You waive your right to sue.",
}

# Severity per category; drives the risk score in the analyze endpoint.
SEVERITY_MAP = {
    "Limitation of liability": "HIGH", "Unilateral termination": "HIGH", "Arbitration": "HIGH",
    "Unilateral change": "MEDIUM", "Content removal": "MEDIUM", "Choice of law": "MEDIUM",
    "Jurisdiction": "MEDIUM", "Contract by using": "LOW",
}

# Static legal references returned by the explain endpoint when no LLM output is used.
# These strings are shown to users; the separators are em dashes (they had been
# mis-encoded as "β").
LEGAL_BASIS = {
    "Arbitration": "EU Directive 93/13/EEC Art. 3; CFPB arbitration rule (US).",
    "Unilateral change": "EU Directive 93/13/EEC Annex 1(j) — unilateral alteration.",
    "Content removal": "EU Digital Services Act Art. 17 — statement of reasons required.",
    "Jurisdiction": "EU Regulation 1215/2012 Art. 18 — consumer domicile prevails.",
    "Choice of law": "EU Regulation 593/2008 Art. 6 — consumer protection of habitual residence.",
    "Limitation of liability": "EU Directive 93/13/EEC Annex 1(a) — excluding statutory rights.",
    "Unilateral termination": "EU Directive 93/13/EEC Annex 1(f)(g) — termination without notice.",
    "Contract by using": "EU Directive 2011/83/EU Art. 8 — active consent required.",
}
# --- Model ---
# Text-classification pipeline set by load_model(). While it is None,
# classify_clause uses the regex fallback and health() reports "regex".
classifier = None


def load_model():
    """Load the clause classifier into the module-level ``classifier``.

    Tries the ONNX export at ``ONNX_MODEL_PATH`` first (when ``USE_ONNX`` is
    enabled and that path exists), then the transformers checkpoint at
    ``MODEL_PATH`` (when it exists). A failure in one backend is logged and
    the next one is tried. If nothing loads, ``classifier`` stays ``None`` and
    the service runs on the regex fallback.
    """
    global classifier
    log = logging.getLogger(__name__)

    if USE_ONNX and os.path.exists(ONNX_MODEL_PATH):
        try:
            # Imported inside the try so a missing package is handled like any
            # other load failure instead of crashing startup.
            from optimum.onnxruntime import ORTModelForSequenceClassification
            from transformers import AutoTokenizer, pipeline

            model = ORTModelForSequenceClassification.from_pretrained(ONNX_MODEL_PATH)
            tokenizer = AutoTokenizer.from_pretrained(ONNX_MODEL_PATH)
            classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
            log.info("Loaded ONNX classifier from %s", ONNX_MODEL_PATH)
            return
        except Exception:
            # Best effort: fall through to the transformers checkpoint below.
            log.exception("ONNX model load from %s failed", ONNX_MODEL_PATH)

    if os.path.exists(MODEL_PATH):
        try:
            from transformers import pipeline

            # device=-1 runs the pipeline on the CPU.
            classifier = pipeline("text-classification", model=MODEL_PATH, top_k=None, device=-1)
            log.info("Loaded transformers classifier from %s", MODEL_PATH)
            return
        except Exception:
            log.exception("Model load from %s failed", MODEL_PATH)

    log.warning("No classifier loaded; using regex fallback")
# --- Regex fallback ---
# Keyword patterns for each category, keyed by the category's index in
# LABEL_NAMES. classify_clause searches the lower-cased clause with them when
# no ML classifier is available.
PATTERNS = dict(enumerate([
    # 0 - Limitation of liability
    [
        r"not liable",
        r"shall not be (liable|responsible)",
        r"in no event.*liable",
        r"limitation of liability",
        r"without warranty",
        r"disclaim",
    ],
    # 1 - Unilateral termination
    [
        r"terminat.*at any time",
        r"suspend.*account.*without",
        r"we may (terminat|suspend|discontinu)",
        r"right to (terminat|suspend)",
    ],
    # 2 - Unilateral change
    [
        r"sole discretion",
        r"reserves? the right to (modify|change|update|amend)",
        r"at any time.*without (prior )?notice",
        r"we may (modify|change|update)",
    ],
    # 3 - Content removal
    [
        r"remove.*content.*without",
        r"right to remove",
        r"we may.*remove",
    ],
    # 4 - Contract by using
    [
        r"by (using|accessing).*you agree",
        r"continued use.*constitutes? acceptance",
    ],
    # 5 - Choice of law
    [
        r"governed by.*laws? of",
        r"shall be governed",
        r"laws of the state of",
    ],
    # 6 - Jurisdiction
    [
        r"exclusive jurisdiction",
        r"courts? of.*(california|delaware|new york|ireland|england)",
        r"submit to.*jurisdiction",
    ],
    # 7 - Arbitration
    [
        r"arbitrat",
        r"binding arbitration",
        r"waive.*right.*court",
        r"class action waiver",
    ],
]))
def classify_clause(text: str, *, threshold: float = 0.5) -> list[dict]:
    """Return the unfair-clause categories detected in ``text``.

    When the ML ``classifier`` is loaded, it keeps every known category
    (present in LABEL_DESCRIPTIONS) whose score is strictly above
    ``threshold``. If no classifier is loaded, or the classifier raises, it
    falls back to the keyword regexes in PATTERNS. The fallback reports at most
    one hit per category, with a fixed confidence of 0.7.

    Args:
        text: The clause to classify.
        threshold: Minimum ML score (exclusive) for a category to be reported.

    Returns:
        A list of dicts with keys ``name``, ``severity``, ``description`` and
        ``confidence``; empty when nothing is detected.
    """
    if classifier:
        try:
            preds = classifier(text, truncation=True, max_length=512)
            # The pipeline may return either a nested list or a flat list of
            # {label, score} dicts for a single input; handle both.
            items = preds[0] if isinstance(preds[0], list) else preds
            return [
                {
                    "name": p["label"],
                    "severity": SEVERITY_MAP.get(p["label"], "MEDIUM"),
                    "description": LABEL_DESCRIPTIONS.get(p["label"], ""),
                    # float() keeps the value JSON-serializable even if a
                    # backend hands back a numpy scalar.
                    "confidence": round(float(p["score"]), 3),
                }
                for p in items
                if p["score"] > threshold and p["label"] in LABEL_DESCRIPTIONS
            ]
        except Exception:
            # Degrade to the regex fallback rather than failing the request,
            # but leave a trace so model problems are visible.
            logging.getLogger(__name__).exception("Classifier failed; using regex fallback")

    # Regex fallback: one result per category whose patterns match.
    lowered = text.lower()
    results = []
    for label_id, pats in PATTERNS.items():
        if any(re.search(p, lowered) for p in pats):
            name = LABEL_NAMES[label_id]
            results.append({
                "name": name,
                "severity": SEVERITY_MAP[name],
                "description": LABEL_DESCRIPTIONS[name],
                "confidence": 0.7,
            })
    return results
# --- Supabase helper ---
def _supabase_headers(extra: Optional[dict] = None) -> dict:
    """Return the service-role auth headers for the Supabase REST API, plus ``extra``."""
    headers = {"apikey": SUPABASE_SERVICE_KEY, "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}"}
    if extra:
        headers.update(extra)
    return headers


async def supabase_insert(table: str, data: dict):
    """Insert ``data`` as one row of ``table`` via the Supabase REST API.

    Best effort. It does nothing when Supabase is not configured. Transport
    errors and error responses are logged rather than raised, so a
    persistence problem never fails the caller's request.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return
    log = logging.getLogger(__name__)
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{SUPABASE_URL}/rest/v1/{table}",
                json=data,
                headers=_supabase_headers({"Content-Type": "application/json", "Prefer": "return=minimal"}),
            )
    except httpx.HTTPError:
        log.exception("Supabase insert into %s failed", table)
        return
    if resp.is_error:
        log.warning("Supabase insert into %s returned %s: %s", table, resp.status_code, resp.text)


async def supabase_query(table: str, params: dict, headers_extra: Optional[dict] = None):
    """Select rows from ``table`` via the Supabase REST API.

    Args:
        table: Table name under ``/rest/v1/``.
        params: Query-string parameters (filters, select, order, paging).
        headers_extra: Extra request headers merged over the auth headers.

    Returns:
        The decoded JSON body on HTTP 200. Returns ``[]`` when Supabase is not
        configured or the response status is not 200 (the latter is logged).
        Transport errors propagate to the caller.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return []
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{SUPABASE_URL}/rest/v1/{table}",
            params=params,
            headers=_supabase_headers(headers_extra),
        )
    if resp.status_code != 200:
        logging.getLogger(__name__).warning(
            "Supabase query on %s returned %s: %s", table, resp.status_code, resp.text
        )
        return []
    return resp.json()
# --- Request / response schemas ---
class AnalyzeRequest(BaseModel):
    """Request body for clause analysis."""

    # 1 to 500 clause strings, each classified independently.
    clauses: list[str] = Field(min_length=1, max_length=500)
    # Where the clauses came from; saved alongside the analysis when present.
    source_url: str | None = None


class AnalyzeResponse(BaseModel):
    """Result of analysing a batch of clauses."""

    risk_score: int
    grade: str
    total_clauses: int
    flagged_count: int
    # One {"text", "categories"} entry per submitted clause.
    results: list[dict]
    # "ml" or "regex", depending on which classifier produced the results.
    model: str
    latency_ms: int


class ExplainRequest(BaseModel):
    """Request body asking for an explanation of one flagged clause."""

    clause: str = Field(min_length=10, max_length=2000)
    category: str


class ExplainResponse(BaseModel):
    """Explanation, legal basis and advice for one clause."""

    clause: str
    category: str
    explanation: str
    legal_basis: str
    recommendation: str
# --- App ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classifier once at startup. There is nothing to release at shutdown."""
    load_model()
    yield


app = FastAPI(title="ClauseGuard API", version="1.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://clauseguardweb.netlify.app", "http://localhost:3000"],
    # allow_origins entries are compared literally (only a bare "*" is a
    # wildcard), so "chrome-extension://*" never matched any extension.
    # Extension origins are matched by pattern instead; Chrome extension IDs
    # are 32 characters in the range a-p.
    allow_origin_regex=r"chrome-extension://[a-p]{32}",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): SOURCE had no route decorator here, so this endpoint was never
# registered. The path is reconstructed from the function name; confirm it
# against the frontend and extension clients.
@app.get("/health")
async def health():
    """Report liveness and the active classifier backend ("ml" or "regex")."""
    return {"status": "ok", "model": "ml" if classifier else "regex"}
# Inference runs in a worker thread so a large batch does not block the event
# loop. This lock keeps model calls one request at a time, because the shared
# pipeline is not assumed to be thread-safe.
_classify_lock = asyncio.Lock()


def _classify_all(clauses: list[str]) -> list[dict]:
    """Classify each clause and return [{"text": clause, "categories": [...]}, ...]."""
    return [{"text": c, "categories": classify_clause(c)} for c in clauses]


def _risk_and_grade(flagged: list[dict], total: int) -> tuple[int, str]:
    """Score flagged results from 0 to 100 and map the score to a letter grade (A-F).

    Each flagged category adds 20 (HIGH), 10 (MEDIUM) or 5 (LOW) points. The
    sum is divided by the clause count (at least 1), multiplied by 100 and
    capped at 100.
    """
    sev = {"HIGH": 0, "MEDIUM": 0, "LOW": 0}
    for r in flagged:
        for c in r["categories"]:
            sev[c.get("severity", "LOW")] += 1
    # NOTE(review): because of the x100 factor, one HIGH flag alone reaches the
    # cap whenever total <= 20 clauses. Confirm this is the intended scale.
    points = sev["HIGH"] * 20 + sev["MEDIUM"] * 10 + sev["LOW"] * 5
    risk = min(100, round(points / max(1, total) * 100))
    grade = "F" if risk >= 60 else "D" if risk >= 40 else "C" if risk >= 20 else "B" if risk >= 10 else "A"
    return risk, grade


# NOTE(review): SOURCE had no route decorator here, so this endpoint was never
# registered. The path is reconstructed from the function name; confirm it
# against the clients.
@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze(req: AnalyzeRequest, user: Optional[dict] = Depends(get_current_user)):
    """Classify each clause, score the document, and save it for signed-in users.

    Returns the per-clause categories, a 0-100 risk score and an A-F grade.
    When ``user`` is set (authenticated caller), the analysis is written to
    the ``analyses`` table. A failed write is logged and does not fail the
    request.
    """
    start = time.perf_counter()  # monotonic clock for elapsed time
    async with _classify_lock:
        results = await asyncio.to_thread(_classify_all, req.clauses)
    flagged = [r for r in results if r["categories"]]
    total = len(req.clauses)
    risk, grade = _risk_and_grade(flagged, total)
    latency = int((time.perf_counter() - start) * 1000)

    if user:
        try:
            await supabase_insert("analyses", {
                "user_id": user["id"], "source_url": req.source_url, "total_clauses": total,
                "flagged_count": len(flagged), "risk_score": risk, "grade": grade, "clauses": results,
            })
        except httpx.HTTPError:
            # Persistence is best effort; the caller still gets the analysis.
            logging.getLogger(__name__).exception("Saving analysis for user %s failed", user["id"])

    return AnalyzeResponse(risk_score=risk, grade=grade, total_clauses=total,
                           flagged_count=len(flagged), results=results,
                           model="ml" if classifier else "regex", latency_ms=latency)
def _build_explain_prompt(clause: str, category: str) -> str:
    """Return the instruction prompt sent to the SaulLM endpoint for ``clause``."""
    return (
        "You are a consumer protection legal analyst. "
        "Analyze this clause and explain why it may be unfair.\n"
        f'Clause: "{clause}"\n'
        f"Category: {category}\n"
        "Provide:\n"
        "1. A plain-English explanation of why this is problematic\n"
        "2. The specific legal basis (EU/US consumer protection law)\n"
        "3. A practical recommendation for the consumer\n"
        "Be concise. 3-4 sentences maximum per section."
    )


async def _generate_explanation(clause: str, category: str) -> list[str]:
    """Ask the configured SaulLM endpoint to explain ``clause``.

    Returns:
        The non-empty, stripped paragraphs (split on blank lines) of the
        generated text. Returns ``[]`` when the endpoint or token is not
        configured, the call fails or returns a non-200 status, or the output
        is 50 characters or shorter.
    """
    if not (SAULLM_ENDPOINT and HF_API_TOKEN):
        return []
    log = logging.getLogger(__name__)
    prompt = _build_explain_prompt(clause, category)
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                SAULLM_ENDPOINT,
                json={
                    "inputs": prompt,
                    # Request only the completion. Hugging Face text-generation
                    # endpoints can otherwise prepend the prompt to
                    # generated_text, and it would then be parsed as the
                    # explanation.
                    "parameters": {"max_new_tokens": 300, "temperature": 0.3, "return_full_text": False},
                },
                headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
            )
        if resp.status_code != 200:
            log.warning("SaulLM endpoint returned %s; using static explanation", resp.status_code)
            return []
        output = resp.json()
        generated = output[0]["generated_text"] if isinstance(output, list) else output.get("generated_text", "")
    except Exception:
        # Best effort: the caller falls back to the static texts.
        log.exception("SaulLM explanation failed; using static explanation")
        return []
    if not isinstance(generated, str):
        return []
    # Defensive: drop an echoed prompt if the server ignored return_full_text.
    generated = generated.removeprefix(prompt)
    if len(generated) <= 50:
        return []
    return [part.strip() for part in generated.split("\n\n") if part.strip()]


# NOTE(review): SOURCE had no route decorator here, so this endpoint was never
# registered. The path is reconstructed from the function name; confirm it
# against the clients.
@app.post("/explain", response_model=ExplainResponse)
async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
    """Explain why a clause may be unfair, with legal basis and a recommendation.

    Starts from the static texts for ``req.category``. If the SaulLM endpoint
    produces usable output, its first three paragraphs replace the
    explanation, legal basis and recommendation, in that order. ``user`` is
    unused in the body; the dependency exists to require authentication.
    """
    desc = LABEL_DESCRIPTIONS.get(req.category, "Unknown category.")
    legal = LEGAL_BASIS.get(req.category, "Consult local consumer protection laws.")
    recommendation = "Review this clause carefully. Consider negotiating or seeking legal advice before agreeing."

    parts = await _generate_explanation(req.clause, req.category)
    if len(parts) > 0:
        desc = parts[0]
    if len(parts) > 1:
        legal = parts[1]
    if len(parts) > 2:
        recommendation = parts[2]

    return ExplainResponse(clause=req.clause, category=req.category,
                           explanation=desc, legal_basis=legal, recommendation=recommendation)
# NOTE(review): SOURCE had no route decorator here, so this endpoint was never
# registered. The path is reconstructed from the function name; confirm it
# against the clients.
@app.get("/history")
async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
    """Return one page of the caller's saved analyses, newest first.

    ``limit`` is clamped to 0..100 and ``offset`` to at least 0, because
    negative values are not meaningful as a page size or offset. The clamped
    values are echoed back in the response.
    """
    limit = max(0, min(limit, 100))
    offset = max(0, offset)
    data = await supabase_query("analyses", {
        # PostgREST filter: only rows whose user_id equals the caller's id.
        "user_id": f"eq.{user['id']}",
        "select": "*",
        "order": "created_at.desc",
        "limit": str(limit),
        "offset": str(offset),
    })
    return {"analyses": data, "limit": limit, "offset": offset}
if __name__ == "__main__":
    import uvicorn

    # Bind address and port can be overridden via HOST / PORT (defaults: 0.0.0.0:8000).
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )