Spaces:
Sleeping
Sleeping
Commit ·
71c1ad2
0
Parent(s):
initial deployment for HF Spaces
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +54 -0
- .env +47 -0
- .env.example +48 -0
- .gitattributes +36 -0
- Dockerfile +34 -0
- README.md +11 -0
- app/__init__.py +2 -0
- app/api/__init__.py +2 -0
- app/api/router.py +27 -0
- app/api/schemas/__init__.py +2 -0
- app/api/schemas/requests.py +43 -0
- app/api/schemas/responses.py +136 -0
- app/api/v1/__init__.py +2 -0
- app/api/v1/alerts.py +151 -0
- app/api/v1/analyze.py +330 -0
- app/api/v1/auth.py +217 -0
- app/api/v1/health.py +51 -0
- app/api/v1/history.py +68 -0
- app/api/v1/scan.py +166 -0
- app/api/v1/users.py +61 -0
- app/config.py +99 -0
- app/core/__init__.py +1 -0
- app/core/dependencies.py +48 -0
- app/core/security.py +79 -0
- app/db/__init__.py +1 -0
- app/db/connection.py +36 -0
- app/db/models/__init__.py +10 -0
- app/db/models/alert.py +75 -0
- app/db/models/scan_result.py +43 -0
- app/db/models/user.py +74 -0
- app/dependencies.py +29 -0
- app/main.py +154 -0
- app/models/__init__.py +2 -0
- app/models/clip_model.py +145 -0
- app/models/image_model.py +180 -0
- app/models/model_registry.py +99 -0
- app/models/onnx_utils.py +120 -0
- app/models/text_model.py +178 -0
- app/observability/__init__.py +2 -0
- app/observability/langsmith.py +59 -0
- app/observability/logging.py +41 -0
- app/pipeline/__init__.py +2 -0
- app/pipeline/decision_engine.py +142 -0
- app/pipeline/deep_analyzer.py +153 -0
- app/pipeline/fast_filter.py +124 -0
- app/pipeline/preprocessor.py +141 -0
- app/pipeline/risk_scorer.py +166 -0
- app/pipeline/workflow.py +327 -0
- app/services/__init__.py +2 -0
- app/services/gemini_service.py +247 -0
.dockerignore
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environments
|
| 2 |
+
venv/
|
| 3 |
+
.venv/
|
| 4 |
+
env/
|
| 5 |
+
.env
|
| 6 |
+
|
| 7 |
+
# Python
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.pyc
|
| 10 |
+
*.pyo
|
| 11 |
+
*.pyd
|
| 12 |
+
.Python
|
| 13 |
+
build/
|
| 14 |
+
develop-eggs/
|
| 15 |
+
dist/
|
| 16 |
+
downloads/
|
| 17 |
+
eggs/
|
| 18 |
+
.eggs/
|
| 19 |
+
lib/
|
| 20 |
+
lib64/
|
| 21 |
+
parts/
|
| 22 |
+
sdist/
|
| 23 |
+
var/
|
| 24 |
+
wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
|
| 29 |
+
# Data & Models (not baked into image — loaded at runtime)
|
| 30 |
+
model_cache/
|
| 31 |
+
tests/
|
| 32 |
+
.ipynb_checkpoints/
|
| 33 |
+
|
| 34 |
+
# Version Control
|
| 35 |
+
.git/
|
| 36 |
+
.github/
|
| 37 |
+
.gitignore
|
| 38 |
+
|
| 39 |
+
# IDEs
|
| 40 |
+
.vscode/
|
| 41 |
+
.idea/
|
| 42 |
+
|
| 43 |
+
# OS files
|
| 44 |
+
.DS_Store
|
| 45 |
+
Thumbs.db
|
| 46 |
+
|
| 47 |
+
# HF Spaces — README is required at repo root, not in image
|
| 48 |
+
# README.md is NOT excluded so HF can read the frontmatter
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
model_cache/
|
| 52 |
+
__pycache__/
|
| 53 |
+
*.pyc
|
| 54 |
+
.env
|
.env
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Server
|
| 2 |
+
HOST=0.0.0.0
|
| 3 |
+
PORT=8000
|
| 4 |
+
ENV=development
|
| 5 |
+
LOG_LEVEL=INFO
|
| 6 |
+
|
| 7 |
+
# JWT Authentication
|
| 8 |
+
JWT_ACCESS_SECRET=hubble-access-secret-change-in-production-32chars
|
| 9 |
+
JWT_REFRESH_SECRET=hubble-refresh-secret-change-in-production-32chars
|
| 10 |
+
JWT_ACCESS_EXPIRES_MINUTES=15
|
| 11 |
+
JWT_REFRESH_EXPIRES_DAYS=7
|
| 12 |
+
|
| 13 |
+
# Security
|
| 14 |
+
BCRYPT_ROUNDS=12
|
| 15 |
+
CORS_ORIGINS=http://localhost:3000,http://localhost:3001
|
| 16 |
+
|
| 17 |
+
# MongoDB (Atlas Cloud)
|
| 18 |
+
MONGODB_URI=mongodb+srv://<user>:<password>@<cluster>.mongodb.net/hubble?appName=Cluster0
|
| 19 |
+
MONGODB_DB_NAME=hubble
|
| 20 |
+
|
| 21 |
+
# Redis (Redis Cloud)
|
| 22 |
+
REDIS_URL=redis://default:<password>@<host>:<port>
|
| 23 |
+
REDIS_CACHE_TTL=300
|
| 24 |
+
|
| 25 |
+
# Gemini API Keys (comma-separated for rotation; supply real keys via environment variables or Space Secrets — never commit them)
|
| 26 |
+
GEMINI_API_KEYS=your-key-1,your-key-2,your-key-3
|
| 27 |
+
GEMINI_MODEL=gemini-2.5-flash
|
| 28 |
+
|
| 29 |
+
# LangSmith Tracing
|
| 30 |
+
LANGSMITH_API_KEY=your-langsmith-api-key
|
| 31 |
+
LANGSMITH_PROJECT=hubble-moderation
|
| 32 |
+
LANGSMITH_TRACING_V2=true
|
| 33 |
+
|
| 34 |
+
# Model Configuration
|
| 35 |
+
MODEL_CACHE_DIR=./model_cache
|
| 36 |
+
ONNX_ENABLED=false
|
| 37 |
+
TEXT_MODEL_NAME=unitary/toxic-bert
|
| 38 |
+
IMAGE_MODEL_NAME=google/efficientnet-b0
|
| 39 |
+
CLIP_MODEL_NAME=openai/clip-vit-base-patch32
|
| 40 |
+
|
| 41 |
+
# Risk Thresholds
|
| 42 |
+
RISK_LOW_MAX=30
|
| 43 |
+
RISK_MEDIUM_MAX=65
|
| 44 |
+
|
| 45 |
+
# Video Processing
|
| 46 |
+
VIDEO_MAX_FRAMES=10
|
| 47 |
+
VIDEO_FPS_SAMPLE=1
|
.env.example
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# Hubble AI Engine - Environment Configuration
|
| 3 |
+
# ============================================
|
| 4 |
+
# On Hugging Face Spaces, set these as Secrets
|
| 5 |
+
# in the Space settings (Settings > Secrets).
|
| 6 |
+
# They are injected as environment variables at runtime.
|
| 7 |
+
|
| 8 |
+
# Server
|
| 9 |
+
HOST=0.0.0.0
|
| 10 |
+
PORT=7860
|
| 11 |
+
ENV=production
|
| 12 |
+
LOG_LEVEL=INFO
|
| 13 |
+
|
| 14 |
+
# MongoDB (async via motor)
|
| 15 |
+
MONGODB_URI=mongodb+srv://<user>:<password>@<cluster>.mongodb.net/hubble
|
| 16 |
+
MONGODB_DB_NAME=hubble
|
| 17 |
+
|
| 18 |
+
# Redis
|
| 19 |
+
REDIS_URL=redis://default:<password>@<host>:<port>
|
| 20 |
+
REDIS_CACHE_TTL=300
|
| 21 |
+
|
| 22 |
+
# Gemini API Keys (comma-separated for rotation)
|
| 23 |
+
GEMINI_API_KEYS=your-key-1,your-key-2,your-key-3
|
| 24 |
+
GEMINI_MODEL=gemini-2.5-flash
|
| 25 |
+
|
| 26 |
+
# LangSmith Observability (optional)
|
| 27 |
+
LANGSMITH_API_KEY=your-langsmith-api-key
|
| 28 |
+
LANGSMITH_PROJECT=hubble-moderation
|
| 29 |
+
LANGSMITH_TRACING_V2=true
|
| 30 |
+
|
| 31 |
+
# JWT Secrets (use long random strings)
|
| 32 |
+
JWT_ACCESS_SECRET=your-access-secret-at-least-32-chars
|
| 33 |
+
JWT_REFRESH_SECRET=your-refresh-secret-at-least-32-chars
|
| 34 |
+
|
| 35 |
+
# Model Configuration
|
| 36 |
+
MODEL_CACHE_DIR=/tmp/model_cache
|
| 37 |
+
ONNX_ENABLED=false
|
| 38 |
+
TEXT_MODEL_NAME=unitary/toxic-bert
|
| 39 |
+
IMAGE_MODEL_NAME=google/efficientnet-b0
|
| 40 |
+
CLIP_MODEL_NAME=openai/clip-vit-base-patch32
|
| 41 |
+
|
| 42 |
+
# Risk Thresholds
|
| 43 |
+
RISK_LOW_MAX=30
|
| 44 |
+
RISK_MEDIUM_MAX=65
|
| 45 |
+
|
| 46 |
+
# Video Processing
|
| 47 |
+
VIDEO_MAX_FRAMES=10
|
| 48 |
+
VIDEO_FPS_SAMPLE=1
|
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.onnx.data filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces — Hubble AI Engine
|
| 2 |
+
FROM python:3.12-slim
|
| 3 |
+
|
| 4 |
+
# Runtime environment: no .pyc files, unbuffered logs, and the host/port to bind (HF Spaces requires port 7860)
|
| 5 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 6 |
+
ENV PYTHONUNBUFFERED=1
|
| 7 |
+
ENV HOST=0.0.0.0
|
| 8 |
+
ENV PORT=7860
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
# System dependencies for OpenCV and native builds
|
| 13 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 14 |
+
build-essential \
|
| 15 |
+
libgl1 \
|
| 16 |
+
libglib2.0-0 \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
# Install Python dependencies
|
| 20 |
+
COPY requirements.txt .
|
| 21 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 22 |
+
pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
|
| 23 |
+
pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
| 24 |
+
|
| 25 |
+
COPY . .
|
| 26 |
+
|
| 27 |
+
# HF Spaces runs containers as a non-root user (UID 1000)
|
| 28 |
+
RUN useradd -m -u 1000 user
|
| 29 |
+
RUN mkdir -p /tmp/model_cache && chown -R user:user /tmp/model_cache && chown -R user:user /app
|
| 30 |
+
USER user
|
| 31 |
+
|
| 32 |
+
EXPOSE 7860
|
| 33 |
+
|
| 34 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SentinelAI
|
| 3 |
+
emoji: 📈
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/__init__.py
|
| 2 |
+
"""Hubble AI Engine — Production-grade cyberbullying detection pipeline."""
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/__init__.py
|
| 2 |
+
"""API layer: routes, schemas, and versioning."""
|
app/api/router.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/router.py
|
| 2 |
+
# Root router — aggregates all versioned API routes
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter
|
| 5 |
+
from app.api.v1.health import router as health_router
|
| 6 |
+
from app.api.v1.analyze import router as analyze_router
|
| 7 |
+
from app.api.v1.history import router as history_router
|
| 8 |
+
from app.api.v1.auth import router as auth_router
|
| 9 |
+
from app.api.v1.users import router as users_router
|
| 10 |
+
from app.api.v1.scan import router as scan_router
|
| 11 |
+
from app.api.v1.alerts import router as alerts_router
|
| 12 |
+
|
| 13 |
+
# Main API router
|
| 14 |
+
api_router = APIRouter()

# Health routes are mounted at the root, without a version prefix.
api_router.include_router(health_router, prefix="")

# Every other router is mounted under /api/v1. Registration order matches
# the original grouping: raw AI pipeline (analyze, history), then auth,
# then the business-logic routers (users, scan, alerts).
for _v1_router in (
    analyze_router,
    history_router,
    auth_router,
    users_router,
    scan_router,
    alerts_router,
):
    api_router.include_router(_v1_router, prefix="/api/v1")
del _v1_router  # keep the loop variable out of the module namespace
|
app/api/schemas/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/schemas/__init__.py
|
| 2 |
+
"""Pydantic request/response schemas."""
|
app/api/schemas/requests.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/schemas/requests.py
|
| 2 |
+
# Pydantic request models for API endpoints
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TextAnalysisRequest(BaseModel):
    """Request body for text analysis."""

    # Required: between 1 and 10000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=10000,
        description="Text content to analyze",
    )
    # Everything below is optional and defaults to None.
    user_id: Optional[str] = Field(
        None,
        description="Optional user ID for history tracking",
    )
    source_app: Optional[str] = Field(
        None,
        description="Source application (e.g., 'instagram', 'whatsapp')",
    )
    metadata: Optional[dict] = Field(
        None,
        description="Additional metadata",
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "text": "You are so ugly nobody likes you",
                    "user_id": "user_123",
                    "source_app": "instagram",
                },
            ],
        },
    }
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ImageAnalysisRequest(BaseModel):
    """Metadata for image analysis (file sent as multipart)."""

    # Both fields are optional and default to None.
    user_id: Optional[str] = Field(
        None,
        description="Optional user ID for history tracking",
    )
    source_app: Optional[str] = Field(
        None,
        description="Source application",
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class VideoAnalysisRequest(BaseModel):
    """Metadata for video analysis (file sent as multipart)."""

    # Both fields are optional and default to None.
    user_id: Optional[str] = Field(
        None,
        description="Optional user ID for history tracking",
    )
    source_app: Optional[str] = Field(
        None,
        description="Source application",
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class HistoryRequest(BaseModel):
    """Parameters for history queries."""

    # Page size: defaults to 20, constrained to the range 1..100.
    limit: int = Field(
        20,
        ge=1,
        le=100,
        description="Maximum results to return",
    )
    # Offset: defaults to 0 and may not be negative.
    skip: int = Field(
        0,
        ge=0,
        description="Number of results to skip",
    )
|
app/api/schemas/responses.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/schemas/responses.py
|
| 2 |
+
# Pydantic response models for API endpoints
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
from typing import Optional, Literal, Any
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DecisionDetail(BaseModel):
    """Details of the moderation decision."""

    # Verdict, restricted to exactly these three literal values.
    action: Literal["ALLOWED", "WARNING", "BLOCKED"]
    reason: str
    severity: str

    # Optional extras: no parent alert and no notes unless set explicitly.
    should_alert_parent: bool = False
    escalation_notes: Optional[str] = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DeepAnalysisDetail(BaseModel):
    """Details from the deep analysis stage (only present for HIGH risk)."""

    # Required fields.
    is_confirmed: bool
    severity: str
    reasoning: str

    # Defaults to an empty list when omitted.
    categories: list[str] = []

    recommended_action: str
    confidence: float

    # Untyped mapping, empty by default.
    # NOTE(review): presumably label -> CLIP score; confirm against the producer.
    clip_scores: dict = {}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RiskDetail(BaseModel):
    """Breakdown of risk scoring."""

    score: float
    # Restricted to exactly these three literal values.
    level: Literal["LOW", "MEDIUM", "HIGH"]

    # Untyped mapping of score components, empty by default.
    components: dict = {}
    repeat_offender: bool = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FilterDetail(BaseModel):
    """Fast filter stage output."""

    is_flagged: bool

    # Mapping of string keys to float scores, empty by default.
    scores: dict[str, float] = {}

    max_score: float
    max_label: str

    # Defaults to an empty list when omitted.
    categories: list[str] = []
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AnalysisResponse(BaseModel):
    """
    Unified response for all /analyze/* endpoints.

    This is the primary contract between the AI engine
    and the Node.js backend (and any external consumers).
    """

    # ── Headline fields ──
    request_id: str = Field(
        ...,
        description="Unique request identifier",
    )
    input_type: Literal["text", "image", "video"]
    status: Literal["ALLOWED", "WARNING", "BLOCKED"]
    risk_level: Literal["LOW", "MEDIUM", "HIGH"]
    risk_score: float = Field(
        ...,
        ge=0,
        le=100,
        description="Composite risk score 0-100",
    )
    categories: list[str] = Field(
        default_factory=list,
        description="Detected abuse categories",
    )
    confidence: float = Field(
        ...,
        ge=0,
        le=1,
        description="Overall confidence 0-1",
    )

    # ── Detailed breakdowns (deep_analysis is optional) ──
    decision: DecisionDetail
    risk_detail: RiskDetail
    filter_detail: FilterDetail
    deep_analysis: Optional[DeepAnalysisDetail] = None

    # ── Metadata ──
    processing_time_ms: int = Field(
        ...,
        description="Total processing time in milliseconds",
    )
    trace_id: Optional[str] = Field(
        None,
        description="LangSmith trace ID for observability",
    )
    cached: bool = Field(
        False,
        description="Whether this result was served from cache",
    )

    # Example response attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "request_id": "req_abc123",
                    "input_type": "text",
                    "status": "WARNING",
                    "risk_level": "MEDIUM",
                    "risk_score": 45.2,
                    "categories": ["insult", "toxic"],
                    "confidence": 0.82,
                    "decision": {
                        "action": "WARNING",
                        "reason": "Content flagged as potentially harmful",
                        "severity": "medium",
                        "should_alert_parent": False,
                    },
                    "risk_detail": {
                        "score": 45.2,
                        "level": "MEDIUM",
                        "components": {
                            "base_score": 42.0,
                            "multi_category_penalty": 3.2,
                            "repeat_offender_boost": 0.0,
                        },
                        "repeat_offender": False,
                    },
                    "filter_detail": {
                        "is_flagged": True,
                        "scores": {"toxic": 0.78, "insult": 0.65},
                        "max_score": 0.78,
                        "max_label": "toxic",
                        "categories": ["toxic", "insult"],
                    },
                    "deep_analysis": None,
                    "processing_time_ms": 156,
                    "trace_id": None,
                    "cached": False,
                },
            ],
        },
    }
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class HealthResponse(BaseModel):
    """Health check response."""

    status: str
    version: str

    # Two name -> bool maps, one for models and one for services.
    # NOTE(review): presumably True means loaded/reachable; confirm in health.py.
    models: dict[str, bool]
    services: dict[str, bool]

    uptime_seconds: float
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class HistoryResponse(BaseModel):
    """Moderation history response."""

    user_id: str
    total: int
    # Each entry is a free-form dict; no per-item schema is enforced here.
    results: list[dict[str, Any]]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class ErrorResponse(BaseModel):
    """Standard error response."""

    error: str
    detail: str
    # Optional; None when no request identifier is supplied.
    request_id: Optional[str] = None
|
app/api/v1/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/__init__.py
|
| 2 |
+
"""API v1 endpoints."""
|
app/api/v1/alerts.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/alerts.py
|
| 2 |
+
# Alert endpoints for parents and children
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from app.core.dependencies import get_current_user
|
| 8 |
+
from app.db.models.user import UserDocument, UserRole
|
| 9 |
+
from app.db.models.alert import AlertDocument, AlertStatus
|
| 10 |
+
from app.observability.logging import get_logger
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
logger = get_logger(__name__)
|
| 14 |
+
router = APIRouter(prefix="/alerts", tags=["Alerts"])
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@router.get("")
async def list_alerts(
    page: int = 1,
    limit: int = 20,
    status: Optional[str] = None,
    severity: Optional[str] = None,
    user: UserDocument = Depends(get_current_user),
):
    """List alerts. Parents see all their children's alerts; children see their own.

    Results are sorted by ``created_at`` descending and paginated with a
    1-based ``page`` and a ``limit``. ``status`` and ``severity`` are
    optional filters.

    Raises:
        HTTPException: 422 when ``page`` is below 1, or when ``status`` /
            ``severity`` is not a valid ``AlertStatus`` / ``AlertSeverity``
            value.
    """
    # A page below 1 would turn into a negative skip offset further down.
    if page < 1:
        raise HTTPException(status_code=422, detail="page must be >= 1")

    # Convert the filters up front so a bad value becomes a client error
    # rather than an unhandled ValueError from the enum constructor.
    status_filter = None
    if status:
        try:
            status_filter = AlertStatus(status)
        except ValueError:
            raise HTTPException(
                status_code=422, detail=f"Invalid status: {status}"
            ) from None

    severity_filter = None
    if severity:
        # Local import kept from the original implementation.
        from app.db.models.alert import AlertSeverity

        try:
            severity_filter = AlertSeverity(severity)
        except ValueError:
            raise HTTPException(
                status_code=422, detail=f"Invalid severity: {severity}"
            ) from None

    # NOTE(review): `limit` is still unbounded; consider capping it
    # (HistoryRequest uses le=100) once client expectations are confirmed.
    skip = (page - 1) * limit

    # Parents are matched on parent_id; every other role on child_id.
    if user.role == UserRole.PARENT:
        query = AlertDocument.find(AlertDocument.parent_id == str(user.id))
    else:
        query = AlertDocument.find(AlertDocument.child_id == str(user.id))

    if status_filter is not None:
        query = query.find(AlertDocument.status == status_filter)
    if severity_filter is not None:
        query = query.find(AlertDocument.severity == severity_filter)

    alerts = await query.sort(-AlertDocument.created_at).skip(skip).limit(limit).to_list()
    return {
        "success": True,
        "page": page,
        "limit": limit,
        "alerts": [_fmt(a) for a in alerts],
    }
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Fetch a single alert by id. Responds 404 when no document is found;
# _check_access raises 403 when the caller is neither the alert's parent
# nor its child.
# NOTE(review): a malformed alert_id may raise inside AlertDocument.get
# before the 404 check; confirm against the ODM's behaviour.
@router.get("/{alert_id}")
async def get_alert(
    alert_id: str,
    user: UserDocument = Depends(get_current_user),
):
    alert = await AlertDocument.get(alert_id)
    if alert:
        _check_access(alert, user)
        return {"success": True, "alert": _fmt(alert)}
    raise HTTPException(status_code=404, detail="Alert not found")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AcknowledgeRequest(BaseModel):
    # Request body taken by the resolve endpoint; the notes are optional
    # and default to None.
    resolution_notes: Optional[str] = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@router.post("/{alert_id}/acknowledge")
async def acknowledge_alert(
    alert_id: str,
    user: UserDocument = Depends(get_current_user),
):
    """Mark an alert as acknowledged and return it.

    The caller must be the alert's parent or child (enforced by
    ``_check_access``).

    Raises:
        HTTPException: 404 when the alert does not exist, 403 when the
            caller may not access it.
    """
    alert = await AlertDocument.get(alert_id)
    if not alert:
        raise HTTPException(status_code=404, detail="Alert not found")
    _check_access(alert, user)

    # Read the clock once so acknowledged_at and updated_at are identical.
    now = datetime.utcnow()
    alert.status = AlertStatus.ACKNOWLEDGED
    alert.acknowledged_at = now
    alert.updated_at = now
    await alert.save()
    return {"success": True, "alert": _fmt(alert)}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@router.post("/{alert_id}/resolve")
async def resolve_alert(
    alert_id: str,
    body: AcknowledgeRequest,
    user: UserDocument = Depends(get_current_user),
):
    """Mark an alert as resolved, recording who resolved it and any notes.

    Only a user with the PARENT role whose id matches the alert's
    ``parent_id`` may resolve it.

    Raises:
        HTTPException: 403 when the caller is not a parent or does not own
            the alert, 404 when the alert does not exist.
    """
    if user.role != UserRole.PARENT:
        raise HTTPException(status_code=403, detail="Only parents can resolve alerts")

    alert = await AlertDocument.get(alert_id)
    if not alert:
        raise HTTPException(status_code=404, detail="Alert not found")
    if alert.parent_id != str(user.id):
        raise HTTPException(status_code=403, detail="Cannot resolve this alert")

    # Read the clock once so resolved_at and updated_at are identical.
    now = datetime.utcnow()
    alert.status = AlertStatus.RESOLVED
    alert.resolved_at = now
    alert.resolved_by = str(user.id)
    alert.resolution_notes = body.resolution_notes
    alert.updated_at = now
    await alert.save()
    return {"success": True, "alert": _fmt(alert)}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.get("/stats/summary")
async def alert_stats(user: UserDocument = Depends(get_current_user)):
    """Return alert counts for the calling parent.

    Counts are grouped by severity, by status and by category, alongside
    the total number of alerts whose ``parent_id`` matches the caller.

    Raises:
        HTTPException: 403 when the caller does not have the PARENT role.
    """
    # Function-scope import: keeps this block self-contained.
    from collections import Counter

    if user.role != UserRole.PARENT:
        raise HTTPException(status_code=403, detail="Only parents can view stats")

    alerts = await AlertDocument.find(
        AlertDocument.parent_id == str(user.id)
    ).to_list()

    # Counter keeps first-seen key order, matching the manual tallies it replaces.
    by_severity = Counter(a.severity.value for a in alerts)
    by_status = Counter(a.status.value for a in alerts)
    by_category = Counter(c for a in alerts for c in a.categories)

    # Convert back to plain dicts so the response payload type is unchanged.
    return {
        "success": True,
        "total": len(alerts),
        "by_severity": dict(by_severity),
        "by_status": dict(by_status),
        "by_category": dict(by_category),
    }
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ──────────────────────────────────────────────
|
| 129 |
+
# Helpers
|
| 130 |
+
# ──────────────────────────────────────────────
|
| 131 |
+
|
| 132 |
+
def _check_access(alert: AlertDocument, user: UserDocument):
    """Raise a 403 unless ``user`` is the alert's parent or its child."""
    caller_id = str(user.id)
    if caller_id in (alert.parent_id, alert.child_id):
        return
    raise HTTPException(status_code=403, detail="Cannot access this alert")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _fmt(a: AlertDocument) -> dict:
    """Serialize an alert document into a plain dict for API responses.

    Enum fields are emitted as their ``.value``, the id as a string, and
    timestamps as ISO-8601 strings. ``acknowledged_at`` and ``resolved_at``
    become ``None`` when unset; ``created_at`` is always formatted.
    """

    def _iso_or_none(moment):
        # Optional timestamps: format when set, otherwise pass None through.
        return moment.isoformat() if moment else None

    payload = {
        "id": str(a.id),
        "child_id": a.child_id,
        "parent_id": a.parent_id,
        "title": a.title,
        "message": a.message,
        "guidance": a.guidance,
        "severity": a.severity.value,
        "categories": a.categories,
        "status": a.status.value,
        "created_at": a.created_at.isoformat(),
        "acknowledged_at": _iso_or_none(a.acknowledged_at),
        "resolved_at": _iso_or_none(a.resolved_at),
    }
    return payload
|
app/api/v1/analyze.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/analyze.py
|
| 2 |
+
# Core analysis endpoints — text, image, video
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
from dataclasses import asdict
|
| 7 |
+
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
| 8 |
+
|
| 9 |
+
from app.api.schemas.requests import TextAnalysisRequest
|
| 10 |
+
from app.api.schemas.responses import (
|
| 11 |
+
AnalysisResponse,
|
| 12 |
+
DecisionDetail,
|
| 13 |
+
RiskDetail,
|
| 14 |
+
FilterDetail,
|
| 15 |
+
DeepAnalysisDetail,
|
| 16 |
+
ErrorResponse,
|
| 17 |
+
)
|
| 18 |
+
from app.services.mongo_service import mongo_service
|
| 19 |
+
from app.services.redis_service import redis_service
|
| 20 |
+
from app.pipeline.workflow import get_workflow, PipelineState
|
| 21 |
+
from app.observability.logging import get_logger
|
| 22 |
+
from app.config import get_settings
|
| 23 |
+
|
| 24 |
+
logger = get_logger(__name__)
|
| 25 |
+
settings = get_settings()
|
| 26 |
+
|
| 27 |
+
router = APIRouter(tags=["Analysis"])
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ──────────────────────────────────────────────
|
| 31 |
+
# Helper: Convert pipeline state to API response
|
| 32 |
+
# ──────────────────────────────────────────────
|
| 33 |
+
|
| 34 |
+
def _build_response(
    state: dict,
    input_type: str,
    request_id: str,
    start_time: float,
    cached: bool = False,
) -> AnalysisResponse:
    """Assemble the unified AnalysisResponse from a finished pipeline state.

    Stages that are missing (or falsy) in ``state`` are replaced by defaults:
    a "WARNING" decision, a zero-score "LOW" risk and an unflagged filter
    result. ``deep_analysis`` stays ``None`` when no deep result is present.
    ``processing_time_ms`` is the wall-clock time elapsed since ``start_time``.
    """
    verdict = state.get("decision")
    risk_info = state.get("risk_score")
    fast_filter = state.get("filter_result")
    deep = state.get("deep_result")

    # Decision block
    if verdict:
        decision_detail = DecisionDetail(
            action=verdict.action,
            reason=verdict.reason,
            severity=verdict.severity,
            should_alert_parent=verdict.should_alert_parent,
            escalation_notes=verdict.escalation_notes,
        )
    else:
        decision_detail = DecisionDetail(
            action="WARNING",
            reason="Pipeline incomplete",
            severity="medium",
            should_alert_parent=False,
            escalation_notes=None,
        )

    # Risk block
    if risk_info:
        risk_detail = RiskDetail(
            score=risk_info.score,
            level=risk_info.level,
            components=risk_info.components,
            repeat_offender=risk_info.repeat_offender,
        )
    else:
        risk_detail = RiskDetail(
            score=0.0,
            level="LOW",
            components={},
            repeat_offender=False,
        )

    # Fast-filter block
    if fast_filter:
        filter_detail = FilterDetail(
            is_flagged=fast_filter.is_flagged,
            scores=fast_filter.scores,
            max_score=fast_filter.max_score,
            max_label=fast_filter.max_label,
            categories=fast_filter.categories,
        )
    else:
        filter_detail = FilterDetail(
            is_flagged=False,
            scores={},
            max_score=0.0,
            max_label="",
            categories=[],
        )

    # Deep analysis is optional
    deep_analysis = (
        DeepAnalysisDetail(
            is_confirmed=deep.is_confirmed,
            severity=deep.severity,
            reasoning=deep.reasoning,
            categories=deep.categories,
            recommended_action=deep.recommended_action,
            confidence=deep.confidence,
            clip_scores=deep.clip_scores,
        )
        if deep
        else None
    )

    elapsed_ms = int((time.time() - start_time) * 1000)

    # Top-level summary fields mirror values from the nested detail models.
    return AnalysisResponse(
        request_id=request_id,
        input_type=input_type,
        status=decision_detail.action,
        risk_level=risk_detail.level,
        risk_score=risk_detail.score,
        categories=filter_detail.categories,
        confidence=filter_detail.max_score,
        decision=decision_detail,
        risk_detail=risk_detail,
        filter_detail=filter_detail,
        deep_analysis=deep_analysis,
        processing_time_ms=elapsed_ms,
        trace_id=None,  # TODO: capture from LangSmith
        cached=cached,
    )
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ──────────────────────────────────────────────
|
| 105 |
+
# POST /analyze/text
|
| 106 |
+
# ──────────────────────────────────────────────
|
| 107 |
+
|
| 108 |
+
@router.post(
    "/analyze/text",
    response_model=AnalysisResponse,
    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
    summary="Analyze text content for cyberbullying",
)
async def analyze_text(request: TextAnalysisRequest):
    """
    Full moderation pipeline for text content.

    Pipeline: Preprocess → RoBERTa Filter → Risk Score → [Deep Analysis] → Decision

    A cached result for the same text is returned when one exists; otherwise
    the workflow runs and its result is cached. Both paths record the request
    through ``_log_moderation`` so user history reflects every submission.

    Raises:
        HTTPException 400: text is longer than ``settings.max_text_length``.
        HTTPException 500: the pipeline reported an error or raised.
    """
    request_id = f"req_{uuid.uuid4().hex[:12]}"
    start_time = time.time()

    logger.info("analyze_text_started", request_id=request_id, text_length=len(request.text))

    # Validate text length
    if len(request.text) > settings.max_text_length:
        raise HTTPException(status_code=400, detail=f"Text too long (max {settings.max_text_length} chars)")

    # Check cache
    cached_result = await redis_service.get_cached_result(request.text, "text")
    if cached_result:
        logger.info("analyze_text_cached", request_id=request_id)
        cached_result["request_id"] = request_id
        cached_result["cached"] = True
        cached_result["processing_time_ms"] = int((time.time() - start_time) * 1000)
        cached_response = AnalysisResponse(**cached_result)

        # A cache hit is still a moderation event for this user: record it so
        # repeated submissions of the same text are not invisible in history.
        await _log_moderation(request_id, "text", request.user_id, cached_response)

        return cached_response

    # Run pipeline
    try:
        workflow = get_workflow()
        initial_state: PipelineState = {
            "input_type": "text",
            "raw_content": request.text,
            "user_id": request.user_id,
        }

        result_state = await workflow.ainvoke(initial_state)

        # Check for pipeline errors
        if result_state.get("error"):
            raise HTTPException(status_code=500, detail=result_state["error"])

        response = _build_response(result_state, "text", request_id, start_time)

        # Cache the result
        await redis_service.cache_result(
            request.text, "text", response.model_dump()
        )

        # Log to MongoDB
        await _log_moderation(request_id, "text", request.user_id, response)

        return response

    except HTTPException:
        # Preserve deliberate HTTP errors raised above.
        raise
    except Exception as e:
        logger.error("analyze_text_failed", request_id=request_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ──────────────────────────────────────────────
|
| 173 |
+
# POST /analyze/image
|
| 174 |
+
# ──────────────────────────────────────────────
|
| 175 |
+
|
| 176 |
+
@router.post(
    "/analyze/image",
    response_model=AnalysisResponse,
    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
    summary="Analyze image content for harmful material",
)
async def analyze_image(
    file: UploadFile = File(..., description="Image file to analyze"),
    user_id: str | None = Form(None, description="Optional user ID"),
    source_app: str | None = Form(None, description="Source application"),
):
    """
    Full moderation pipeline for image content.

    Pipeline: Preprocess → EfficientNet Filter → Risk Score → [CLIP + Gemini] → Decision

    ``source_app`` is accepted but not used by this handler.

    Raises:
        HTTPException 400: the upload is not declared as ``image/*`` or it
            exceeds ``settings.max_image_size``.
        HTTPException 500: the pipeline reported an error or raised.
    """
    request_id = f"req_{uuid.uuid4().hex[:12]}"
    start_time = time.time()

    # Validate file type
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image (JPEG, PNG, etc.)")

    logger.info("analyze_image_started", request_id=request_id, filename=file.filename)

    # Validate file size (cheap early rejection when the size is reported)
    if file.size and file.size > settings.max_image_size:
        raise HTTPException(status_code=400, detail=f"Image too large (max {settings.max_image_size / 1024 / 1024}MB)")

    try:
        image_bytes = await file.read()

        # The check above is skipped when file.size is missing or 0, so
        # enforce the limit again on the bytes actually received.
        if len(image_bytes) > settings.max_image_size:
            raise HTTPException(status_code=400, detail=f"Image too large (max {settings.max_image_size / 1024 / 1024}MB)")

        workflow = get_workflow()
        initial_state: PipelineState = {
            "input_type": "image",
            "raw_content": image_bytes,
            "user_id": user_id,
        }

        result_state = await workflow.ainvoke(initial_state)

        if result_state.get("error"):
            raise HTTPException(status_code=500, detail=result_state["error"])

        response = _build_response(result_state, "image", request_id, start_time)

        # Log to MongoDB
        await _log_moderation(request_id, "image", user_id, response)

        return response

    except HTTPException:
        # Preserve deliberate HTTP errors raised above.
        raise
    except Exception as e:
        logger.error("analyze_image_failed", request_id=request_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Image analysis failed: {str(e)}")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ──────────────────────────────────────────────
|
| 235 |
+
# POST /analyze/video
|
| 236 |
+
# ──────────────────────────────────────────────
|
| 237 |
+
|
| 238 |
+
@router.post(
    "/analyze/video",
    response_model=AnalysisResponse,
    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
    summary="Analyze video content for harmful material",
)
async def analyze_video(
    file: UploadFile = File(..., description="Video file to analyze"),
    user_id: str | None = Form(None, description="Optional user ID"),
    source_app: str | None = Form(None, description="Source application"),
):
    """
    Full moderation pipeline for video content.

    Pipeline: Extract frames → Per-frame EfficientNet → Aggregate risk → [Deep Analysis] → Decision

    ``source_app`` is accepted but not used by this handler.

    Raises:
        HTTPException 400: the upload is not declared as ``video/*`` or it
            exceeds ``settings.max_video_size``.
        HTTPException 500: the pipeline reported an error or raised.
    """
    request_id = f"req_{uuid.uuid4().hex[:12]}"
    start_time = time.time()

    # Validate file type
    if not file.content_type or not file.content_type.startswith("video/"):
        raise HTTPException(status_code=400, detail="File must be a video (MP4, AVI, etc.)")

    logger.info("analyze_video_started", request_id=request_id, filename=file.filename)

    # Validate file size (cheap early rejection when the size is reported)
    if file.size and file.size > settings.max_video_size:
        raise HTTPException(status_code=400, detail=f"Video too large (max {settings.max_video_size / 1024 / 1024}MB)")

    try:
        video_bytes = await file.read()

        # The check above is skipped when file.size is missing or 0, so
        # enforce the limit again on the bytes actually received.
        if len(video_bytes) > settings.max_video_size:
            raise HTTPException(status_code=400, detail=f"Video too large (max {settings.max_video_size / 1024 / 1024}MB)")

        workflow = get_workflow()
        initial_state: PipelineState = {
            "input_type": "video",
            "raw_content": video_bytes,
            "user_id": user_id,
        }

        result_state = await workflow.ainvoke(initial_state)

        if result_state.get("error"):
            raise HTTPException(status_code=500, detail=result_state["error"])

        response = _build_response(result_state, "video", request_id, start_time)

        # Log to MongoDB
        await _log_moderation(request_id, "video", user_id, response)

        return response

    except HTTPException:
        # Preserve deliberate HTTP errors raised above.
        raise
    except Exception as e:
        logger.error("analyze_video_failed", request_id=request_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Video analysis failed: {str(e)}")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# ──────────────────────────────────────────────
|
| 297 |
+
# Helper: Log moderation result to MongoDB
|
| 298 |
+
# ──────────────────────────────────────────────
|
| 299 |
+
|
| 300 |
+
async def _log_moderation(
    request_id: str,
    input_type: str,
    user_id: str | None,
    response: AnalysisResponse,
) -> None:
    """Persist one moderation outcome and, for known users, refresh history.

    Best-effort: any exception is logged as a warning and swallowed so that
    a logging failure never fails the calling request.
    """
    try:
        await mongo_service.log_moderation(
            {
                "request_id": request_id,
                "input_type": input_type,
                "user_id": user_id,
                "status": response.status,
                "risk_level": response.risk_level,
                "risk_score": response.risk_score,
                "categories": response.categories,
                "processing_time_ms": response.processing_time_ms,
            }
        )

        # Anonymous requests have no per-user history to maintain.
        if not user_id:
            return

        history_update = {
            "risk_level": response.risk_level,
            "categories": response.categories,
        }
        await mongo_service.update_user_history(user_id, history_update)
        # Drop the cached copy so the next read sees the update.
        await redis_service.invalidate_user_history(user_id)

    except Exception as exc:
        logger.warning("moderation_logging_failed", error=str(exc))
|
app/api/v1/auth.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/auth.py
|
| 2 |
+
# Auth endpoints: register (parent), create-child, login, refresh, logout
|
| 3 |
+
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from fastapi import APIRouter, HTTPException, status, Depends
|
| 6 |
+
from pydantic import BaseModel, EmailStr, field_validator
|
| 7 |
+
from app.db.models.user import UserDocument, UserRole
|
| 8 |
+
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_refresh_token
|
| 9 |
+
from app.core.dependencies import get_current_user
|
| 10 |
+
from app.observability.logging import get_logger
|
| 11 |
+
|
| 12 |
+
logger = get_logger(__name__)
|
| 13 |
+
router = APIRouter(prefix="/auth", tags=["Auth"])
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ──────────────────────────────────────────────
|
| 17 |
+
# Request / Response Schemas
|
| 18 |
+
# ──────────────────────────────────────────────
|
| 19 |
+
|
| 20 |
+
class RegisterParentRequest(BaseModel):
    """Request body for registering a parent account."""

    email: EmailStr
    username: str
    password: str
    first_name: str
    last_name: str
    phone: str | None = None
    consent_given: bool = True

    @field_validator("username")
    @classmethod
    def username_length(cls, v: str) -> str:
        """Strip surrounding whitespace, then require 3–30 characters."""
        # Strip first: the stored value is the stripped one, so that is the
        # value whose length must be checked ("  a  " must not pass).
        v = v.strip()
        if len(v) < 3 or len(v) > 30:
            raise ValueError("Username must be 3–30 characters")
        return v

    @field_validator("password")
    @classmethod
    def password_length(cls, v: str) -> str:
        """Require a password of at least 8 characters."""
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        return v
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CreateChildRequest(BaseModel):
    """Request body a parent sends to create a linked child account.

    Applies the same 3–30 character username and 8+ character password
    limits used for parent registration, so child accounts cannot be
    created with blank or trivially short credentials.
    """

    username: str
    password: str
    first_name: str
    last_name: str

    @field_validator("username")
    @classmethod
    def username_length(cls, v: str) -> str:
        """Strip surrounding whitespace, then require 3–30 characters."""
        v = v.strip()
        if len(v) < 3 or len(v) > 30:
            raise ValueError("Username must be 3–30 characters")
        return v

    @field_validator("password")
    @classmethod
    def password_length(cls, v: str) -> str:
        """Require a password of at least 8 characters."""
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        return v
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LoginRequest(BaseModel):
    """Credentials submitted to the login endpoint."""

    login: str  # email or username
    password: str
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RefreshRequest(BaseModel):
    """Body carrying a refresh token (used by the refresh and logout endpoints)."""

    refresh_token: str
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TokenResponse(BaseModel):
    """Schema describing an issued access/refresh token pair."""

    # NOTE(review): no endpoint visible here declares this as its
    # response_model; handlers return plain dicts — confirm whether it is used.
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # seconds
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _token_response(user: UserDocument) -> dict:
    """Issue a fresh access/refresh token pair for ``user``.

    Returns a dict with both tokens, the token type, the advertised access
    token lifetime in seconds, and the user's public profile.
    """
    subject = str(user.id)
    role_name = user.role.value
    return {
        "access_token": create_access_token(subject, role_name),
        "refresh_token": create_refresh_token(subject, role_name),
        "token_type": "bearer",
        # NOTE(review): hard-coded 15 minutes — confirm it matches the expiry
        # actually applied by create_access_token.
        "expires_in": 15 * 60,
        "user": user.to_public(),
    }
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ──────────────────────────────────────────────
|
| 80 |
+
# Endpoints
|
| 81 |
+
# ──────────────────────────────────────────────
|
| 82 |
+
|
| 83 |
+
@router.post("/register", status_code=status.HTTP_201_CREATED)
async def register_parent(body: RegisterParentRequest):
    """Create a parent account and return its first token pair.

    Responds with 409 when the email or the username is already in use.
    """
    # Reject duplicates: email first, then username.
    email_owner = await UserDocument.find_one(UserDocument.email == body.email)
    if email_owner:
        raise HTTPException(status_code=409, detail="Email already registered")

    username_owner = await UserDocument.find_one(UserDocument.username == body.username)
    if username_owner:
        raise HTTPException(status_code=409, detail="Username already taken")

    password_hash = hash_password(body.password)
    if body.consent_given:
        consent_timestamp = datetime.utcnow()
    else:
        consent_timestamp = None

    user = UserDocument(
        email=body.email,
        username=body.username,
        password_hash=password_hash,
        role=UserRole.PARENT,
        first_name=body.first_name,
        last_name=body.last_name,
        phone=body.phone,
        consent_given=body.consent_given,
        consent_date=consent_timestamp,
    )
    await user.insert()

    # Tokens are issued after the insert, then the refresh token is stored
    # on the user with a second write.
    tokens = _token_response(user)
    user.refresh_tokens.append(tokens["refresh_token"])
    await user.save()

    logger.info("parent_registered", user_id=str(user.id))
    return {"success": True, **tokens}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@router.post("/create-child", status_code=status.HTTP_201_CREATED)
async def create_child(
    body: CreateChildRequest,
    current_user: UserDocument = Depends(get_current_user),
):
    """Create a child account owned by the authenticated parent.

    Responds with 403 when the caller is not a parent and 409 when the
    requested username already exists.
    """
    if current_user.role != UserRole.PARENT:
        raise HTTPException(status_code=403, detail="Only parents can create child accounts")

    name_owner = await UserDocument.find_one(UserDocument.username == body.username)
    if name_owner:
        raise HTTPException(status_code=409, detail="Username already taken")

    parent_ref = str(current_user.id)
    child = UserDocument(
        username=body.username,
        password_hash=hash_password(body.password),
        role=UserRole.CHILD,
        first_name=body.first_name,
        last_name=body.last_name,
        parent_id=parent_ref,
        parental_consent=True,
        consent_given=True,
    )
    await child.insert()

    # Record the new child on the parent document as well.
    child_ref = str(child.id)
    current_user.children.append(child_ref)
    await current_user.save()

    logger.info("child_created", child_id=child_ref, parent_id=parent_ref)
    return {"success": True, "user": child.to_public()}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@router.post("/login")
async def login(body: LoginRequest):
    """Login with email or username + password.

    The identifier is treated as an email when it contains "@", otherwise as
    a username. It is looked up lower-cased first and then exactly as typed,
    because registration stores usernames with their original case: a purely
    lower-cased lookup would lock out anyone registered as e.g. "Alice".

    Raises:
        HTTPException 401: no matching account or wrong password.
        HTTPException 403: the account is deactivated.
    """
    identifier = body.login.strip()
    lookup_field = UserDocument.email if "@" in identifier else UserDocument.username

    # dict.fromkeys de-duplicates while keeping order: lower-case first,
    # then the exact spelling only when it differs.
    user = None
    for candidate in dict.fromkeys((identifier.lower(), identifier)):
        found = await UserDocument.find_one(lookup_field == candidate)
        if found and verify_password(body.password, found.password_hash):
            user = found
            break

    if user is None:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    if not user.is_active:
        raise HTTPException(status_code=403, detail="Account is deactivated")

    resp = _token_response(user)
    user.refresh_tokens.append(resp["refresh_token"])
    user.last_login_at = datetime.utcnow()
    await user.save()

    logger.info("user_logged_in", user_id=str(user.id))
    return {"success": True, **resp}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@router.post("/refresh")
async def refresh_token(body: RefreshRequest):
    """Rotate a refresh token: revoke the presented one, issue a new pair.

    Responds with 401 when the token cannot be decoded, the user is missing
    or inactive, or the token is not among the user's stored tokens.
    """
    presented = body.refresh_token

    claims = decode_refresh_token(presented)
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

    account = await UserDocument.get(claims["sub"])
    if not (account and account.is_active):
        raise HTTPException(status_code=401, detail="User not found")

    if presented not in account.refresh_tokens:
        raise HTTPException(status_code=401, detail="Refresh token revoked")

    # Swap the old token for the newly issued one.
    account.refresh_tokens.remove(presented)
    new_tokens = _token_response(account)
    account.refresh_tokens.append(new_tokens["refresh_token"])
    await account.save()

    return {"success": True, **new_tokens}
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@router.post("/logout")
async def logout(
    body: RefreshRequest | None = None,
    current_user: UserDocument = Depends(get_current_user),
):
    """Revoke refresh token(s). Omit body to logout all devices.

    With a body, only the supplied refresh token is revoked. A token that is
    not stored for the user (already rotated or unknown) is a no-op; it no
    longer ends every other session.
    """
    if body is None:
        # No token given: sign out everywhere.
        current_user.refresh_tokens.clear()
    elif body.refresh_token in current_user.refresh_tokens:
        current_user.refresh_tokens.remove(body.refresh_token)
    await current_user.save()
    return {"success": True, "message": "Logged out"}
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@router.get("/me")
async def get_me(current_user: UserDocument = Depends(get_current_user)):
    """Return the public profile of the authenticated caller."""
    profile = current_user.to_public()
    return {"success": True, "user": profile}
|
app/api/v1/health.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/health.py
|
| 2 |
+
# Health check endpoints
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
from app.api.schemas.responses import HealthResponse
|
| 7 |
+
from app.models.model_registry import model_registry
|
| 8 |
+
from app.services.redis_service import redis_service
|
| 9 |
+
from app.services.mongo_service import mongo_service
|
| 10 |
+
from app.services.gemini_service import gemini_service
|
| 11 |
+
|
| 12 |
+
router = APIRouter(tags=["Health"])
|
| 13 |
+
|
| 14 |
+
_startup_time = time.time()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Report model and backing-service status plus process uptime.

    The overall status is "healthy" only when every model status value is
    truthy; service connectivity is reported but does not affect it.
    """
    models = model_registry.get_status()
    overall = "healthy" if all(models.values()) else "degraded"

    services = {
        "mongodb": mongo_service.is_connected,
        "redis": redis_service.is_connected,
        "gemini": gemini_service.is_initialized,
    }

    uptime = round(time.time() - _startup_time, 1)
    return HealthResponse(
        status=overall,
        version="4.0.0",
        models=models,
        services=services,
        uptime_seconds=uptime,
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@router.get("/health/models")
async def model_health():
    """Return the model registry's status report as-is."""
    registry_status = model_registry.get_status()
    return registry_status
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@router.get("/health/ping")
async def ping():
    """Liveness probe: always answers with a static OK payload."""
    alive = {"status": "ok"}
    return alive
|
app/api/v1/history.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/history.py
|
| 2 |
+
# Moderation history endpoints
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 5 |
+
from app.api.schemas.responses import HistoryResponse
|
| 6 |
+
from app.services.mongo_service import mongo_service
|
| 7 |
+
from app.observability.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
router = APIRouter(tags=["History"])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@router.get(
    "/history/{user_id}",
    response_model=HistoryResponse,
    summary="Get moderation history for a user",
)
async def get_user_history(
    user_id: str,
    limit: int = Query(20, ge=1, le=100, description="Max results"),
    skip: int = Query(0, ge=0, description="Results to skip"),
):
    """
    Return one page of stored moderation results for ``user_id``.

    ``total`` is the number of results in the returned page, not the size
    of the user's full history. Responds with 503 when MongoDB is not
    connected.
    """
    if not mongo_service.is_connected:
        raise HTTPException(
            status_code=503,
            detail="MongoDB not available — history querying disabled",
        )

    records = await mongo_service.get_moderation_history(
        user_id,
        limit=limit,
        skip=skip,
    )
    page_size = len(records)

    return HistoryResponse(user_id=user_id, total=page_size, results=records)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@router.get(
    "/history/{user_id}/summary",
    summary="Get aggregated user stats",
)
async def get_user_summary(user_id: str):
    """
    Return the stored aggregate history record for ``user_id``.

    When no record exists, a zeroed summary is returned instead. Responds
    with 503 when MongoDB is not connected.
    """
    if not mongo_service.is_connected:
        raise HTTPException(status_code=503, detail="MongoDB not available")

    stored = await mongo_service.get_user_history(user_id)
    if stored:
        return stored

    # No history yet: report an empty summary rather than an error.
    return {
        "user_id": user_id,
        "total_scans": 0,
        "total_violations": 0,
        "total_warnings": 0,
        "violation_categories": {},
    }
|
app/api/v1/scan.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/scan.py
|
| 2 |
+
# Authenticated scan endpoints — wraps the AI pipeline and persists results
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
from app.core.dependencies import get_current_user, require_role
|
| 9 |
+
from app.db.models.user import UserDocument, UserRole
|
| 10 |
+
from app.db.models.scan_result import ScanResultDocument, RiskLevel
|
| 11 |
+
from app.db.models.alert import AlertDocument, AlertSeverity, AlertStatus
|
| 12 |
+
from app.pipeline.workflow import get_workflow, PipelineState
|
| 13 |
+
from app.observability.logging import get_logger
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
router = APIRouter(prefix="/scan", tags=["Scan"])
|
| 17 |
+
|
| 18 |
+
# Maps the pipeline's risk-level string to the AlertSeverity stored on alerts.
# Lookups in this file fall back to AlertSeverity.LOW for unknown levels.
_SEVERITY_MAP = {
    "LOW": AlertSeverity.LOW,
    "MEDIUM": AlertSeverity.MEDIUM,
    "HIGH": AlertSeverity.HIGH,
    "CRITICAL": AlertSeverity.CRITICAL,
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def _run_pipeline_and_persist(
    input_type: str,
    raw_content,
    user: UserDocument,
) -> dict:
    """Run the moderation workflow for ``user`` and store the outcome.

    Always inserts a ScanResultDocument. Additionally inserts an
    AlertDocument when the content is flagged and the user is a child with a
    linked parent. Returns a flat summary dict for the API response.

    Raises:
        HTTPException 500: the workflow result carries an "error" entry.
    """
    started_at = time.time()
    owner_id = str(user.id)

    outcome = await get_workflow().ainvoke(
        {
            "input_type": input_type,
            "raw_content": raw_content,
            "user_id": owner_id,
        }
    )

    pipeline_error = outcome.get("error")
    if pipeline_error:
        raise HTTPException(status_code=500, detail=pipeline_error)

    risk = outcome.get("risk_score")
    decision = outcome.get("decision")
    filter_result = outcome.get("filter_result")
    deep_result = outcome.get("deep_result")

    # Fall back to benign defaults for any stage that produced nothing.
    if risk:
        risk_level_str, risk_score = risk.level, risk.score
    else:
        risk_level_str, risk_score = "LOW", 0.0

    action = decision.action if decision else "ALLOWED"

    if filter_result:
        categories, is_flagged = filter_result.categories, filter_result.is_flagged
    else:
        categories, is_flagged = [], False

    reasoning = deep_result.reasoning if deep_result else None
    processing_ms = int((time.time() - started_at) * 1000)

    # Only text inputs get a stored preview (first 200 characters).
    preview = raw_content[:200] if isinstance(raw_content, str) else None

    # Persist ScanResult
    scan_doc = ScanResultDocument(
        user_id=owner_id,
        input_type=input_type,
        content_preview=preview,
        risk_level=RiskLevel(risk_level_str),
        risk_score=risk_score,
        categories=categories,
        is_flagged=is_flagged,
        action=action,
        reasoning=reasoning,
        processing_time_ms=processing_ms,
        deep_analysis_used=deep_result is not None,
    )
    await scan_doc.insert()
    scan_id = str(scan_doc.id)

    # Flagged content from a child with a linked parent raises an alert.
    if is_flagged and user.role == UserRole.CHILD and user.parent_id:
        alert = AlertDocument(
            child_id=owner_id,
            parent_id=user.parent_id,
            scan_result_id=scan_id,
            severity=_SEVERITY_MAP.get(risk_level_str, AlertSeverity.LOW),
            categories=categories,
            severity_score=risk_score,
        )
        alert.generate_content()
        await alert.insert()
        logger.info("alert_created", alert_id=str(alert.id), child_id=owner_id)

    return {
        "request_id": f"req_{uuid.uuid4().hex[:12]}",
        "input_type": input_type,
        "scan_id": scan_id,
        "status": action,
        "risk_level": risk_level_str,
        "risk_score": risk_score,
        "categories": categories,
        "is_flagged": is_flagged,
        "reasoning": reasoning,
        "processing_time_ms": processing_ms,
    }
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ──────────────────────────────────────────────
|
| 104 |
+
# Endpoints
|
| 105 |
+
# ──────────────────────────────────────────────
|
| 106 |
+
|
| 107 |
+
class ScanTextRequest(BaseModel):
    """Request body for the text scan endpoint."""

    # Raw text to scan; the scan_text endpoint rejects whitespace-only values.
    text: str
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@router.post("/text")
async def scan_text(
    body: ScanTextRequest,
    user: UserDocument = Depends(get_current_user),
):
    """Scan a text payload on behalf of the authenticated caller.

    Whitespace-only text is rejected with HTTP 400. Otherwise the text is
    handed to ``_run_pipeline_and_persist`` and its result is returned with a
    ``success`` flag added.
    """
    text = body.text
    if text.strip() == "":
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    outcome = await _run_pipeline_and_persist("text", text, user)
    response = {"success": True}
    response.update(outcome)
    return response
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@router.post("/image")
async def scan_image(
    file: UploadFile = File(...),
    user: UserDocument = Depends(get_current_user),
):
    """Scan an uploaded image on behalf of the authenticated caller.

    Args:
        file: Multipart upload; its declared content type must start with
            ``image/``.
        user: Authenticated user resolved from the bearer token.

    Returns:
        The pipeline result dict with ``success: True`` added.

    Raises:
        HTTPException: 400 when the upload is not declared as an image or is
            empty; 413 when it is larger than ``Settings.max_image_size``.
    """
    # Imported locally so the size limit does not depend on this module's
    # top-level import list.
    from app.config import get_settings

    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    max_bytes = get_settings().max_image_size
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large file into memory.
    image_bytes = await file.read(max_bytes + 1)
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Image file is empty")
    if len(image_bytes) > max_bytes:
        raise HTTPException(
            status_code=413,
            detail=f"Image exceeds the maximum allowed size of {max_bytes} bytes",
        )

    result = await _run_pipeline_and_persist("image", image_bytes, user)
    return {"success": True, **result}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@router.get("/history")
async def get_scan_history(
    page: int = 1,
    limit: int = 20,
    user: UserDocument = Depends(get_current_user),
):
    """Return one page of the current user's scan history, newest first.

    Args:
        page: 1-based page number.
        limit: Page size; values above 100 are capped at 100.
        user: Authenticated user resolved from the bearer token.

    Returns:
        A dict with ``success``, the effective ``page`` and ``limit``, and a
        ``results`` list of scan summaries.

    Raises:
        HTTPException: 400 when ``page`` or ``limit`` is not positive.
    """
    # page < 1 would produce a negative skip, and a non-positive limit does
    # not describe a page at all, so both are rejected up front.
    if page < 1 or limit < 1:
        raise HTTPException(
            status_code=400,
            detail="page and limit must be positive integers",
        )
    # Cap the page size so one request cannot pull the entire history.
    limit = min(limit, 100)

    skip = (page - 1) * limit
    scans = await ScanResultDocument.find(
        ScanResultDocument.user_id == str(user.id)
    ).sort(-ScanResultDocument.created_at).skip(skip).limit(limit).to_list()

    return {
        "success": True,
        "page": page,
        "limit": limit,
        "results": [
            {
                "id": str(s.id),
                "input_type": s.input_type,
                "risk_level": s.risk_level.value,
                "risk_score": s.risk_score,
                "action": s.action,
                "is_flagged": s.is_flagged,
                "categories": s.categories,
                "created_at": s.created_at.isoformat(),
            }
            for s in scans
        ],
    }
|
app/api/v1/users.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/api/v1/users.py
|
| 2 |
+
# User profile and parent-child management endpoints
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from app.core.dependencies import get_current_user
|
| 8 |
+
from app.db.models.user import UserDocument, UserRole
|
| 9 |
+
|
| 10 |
+
router = APIRouter(prefix="/users", tags=["Users"])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@router.get("/me")
async def get_profile(user: UserDocument = Depends(get_current_user)):
    """Return the authenticated user's public profile."""
    profile = user.to_public()
    return {"success": True, "user": profile}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class UpdateProfileRequest(BaseModel):
    """Partial profile update body; every field is optional."""

    # update_profile applies the name fields only when they are truthy.
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    # update_profile applies phone whenever it is not None, so "" is stored.
    phone: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@router.patch("/me")
async def update_profile(
    body: UpdateProfileRequest,
    user: UserDocument = Depends(get_current_user),
):
    """Apply a partial update to the authenticated user's profile.

    ``first_name`` and ``last_name`` are applied only when non-empty.
    ``phone`` is applied whenever it is provided (not ``None``), so an empty
    string is stored as given. When at least one field is applied,
    ``updated_at`` is refreshed and the document is saved; a request that
    changes nothing performs no write.

    Returns:
        A dict with ``success`` and the user's public profile.
    """
    # Local import: this module does not otherwise need datetime.
    from datetime import datetime, timezone

    changed = False
    if body.first_name:
        user.first_name = body.first_name
        changed = True
    if body.last_name:
        user.last_name = body.last_name
        changed = True
    if body.phone is not None:
        user.phone = body.phone
        changed = True

    if changed:
        # Naive UTC, matching the model's ``datetime.utcnow`` defaults.
        user.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        await user.save()

    return {"success": True, "user": user.to_public()}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.get("/children")
async def list_children(user: UserDocument = Depends(get_current_user)):
    """List the public profiles of every child account linked to this parent.

    Responds with HTTP 403 when the caller's role is not ``PARENT``.
    """
    is_parent = user.role == UserRole.PARENT
    if not is_parent:
        raise HTTPException(status_code=403, detail="Only parents can list children")

    linked_query = UserDocument.find(UserDocument.parent_id == str(user.id))
    linked = await linked_query.to_list()

    profiles = []
    for child in linked:
        profiles.append(child.to_public())
    return {"success": True, "children": profiles}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@router.get("/children/{child_id}")
async def get_child(
    child_id: str,
    user: UserDocument = Depends(get_current_user),
):
    """Return the public profile of one child linked to the calling parent.

    Args:
        child_id: Document id of the child, taken from the URL path.
        user: Authenticated user resolved from the bearer token.

    Raises:
        HTTPException: 403 when the caller is not a parent; 404 when the id
            is malformed, unknown, or belongs to another parent's child.
    """
    if user.role != UserRole.PARENT:
        raise HTTPException(status_code=403, detail="Only parents can view child profiles")

    # The id comes straight from the URL. Anything that is not a 24-character
    # hex string cannot be an ObjectId and would fail id validation inside
    # ``UserDocument.get`` (a 500); answer 404 instead.
    hex_digits = "0123456789abcdefABCDEF"
    if len(child_id) != 24 or any(ch not in hex_digits for ch in child_id):
        raise HTTPException(status_code=404, detail="Child not found")

    child = await UserDocument.get(child_id)
    # Same 404 for "missing" and "not yours", so foreign ids are not revealed.
    if not child or child.parent_id != str(user.id):
        raise HTTPException(status_code=404, detail="Child not found")
    return {"success": True, "child": child.to_public()}
|
app/config.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/config.py
|
| 2 |
+
# Centralized configuration via Pydantic Settings
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from pydantic_settings import BaseSettings
|
| 7 |
+
from pydantic import Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Settings(BaseSettings):
    """Application settings loaded from environment variables and ``.env``."""

    # --- Server ---
    host: str = "0.0.0.0"
    port: int = 7860
    env: str = "production"
    log_level: str = "INFO"

    # --- JWT ---
    # NOTE(review): both secrets default to the same placeholder string;
    # real deployments must override them with distinct random values.
    jwt_access_secret: str = Field(default="change-me-in-production-at-least-32-chars!!")
    jwt_refresh_secret: str = Field(default="change-me-in-production-at-least-32-chars!!")
    jwt_access_expires_minutes: int = 15
    jwt_refresh_expires_days: int = 7

    # --- Security ---
    bcrypt_rounds: int = 12
    cors_origins: str = "*"  # comma-separated

    # --- MongoDB ---
    mongodb_uri: str = Field(default="")
    mongodb_db_name: str = "hubble"

    # --- Redis ---
    redis_url: str = Field(default="")
    redis_cache_ttl: int = 300  # seconds

    # --- Gemini ---
    gemini_api_keys: str = ""  # comma-separated
    gemini_model: str = "gemini-2.5-flash"

    # --- LangSmith ---
    langsmith_api_key: str = ""
    langsmith_project: str = "hubble-moderation"
    langsmith_tracing_v2: bool = True

    # --- Models ---
    model_cache_dir: str = "/tmp/model_cache"
    onnx_enabled: bool = False
    text_model_name: str = "unitary/toxic-bert"
    image_model_name: str = "google/efficientnet-b0"
    clip_model_name: str = "openai/clip-vit-base-patch32"

    # --- Risk Thresholds ---
    risk_low_max: int = 30
    risk_medium_max: int = 65

    # --- Content Limits ---
    max_text_length: int = 10000
    max_image_size: int = 10 * 1024 * 1024  # 10MB
    max_video_size: int = 50 * 1024 * 1024  # 50MB

    # --- Video Processing ---
    video_max_frames: int = 10
    video_fps_sample: int = 1

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
        # ``model_cache_dir`` starts with pydantic's protected ``model_``
        # prefix; narrowing the protected namespaces avoids the field-name
        # conflict warning without renaming the setting.
        "protected_namespaces": ("settings_",),
    }

    @property
    def gemini_keys_list(self) -> list[str]:
        """Return the Gemini API keys split on commas, blanks removed."""
        if not self.gemini_api_keys:
            return []
        return [k.strip() for k in self.gemini_api_keys.split(",") if k.strip()]

    @property
    def cors_origins_list(self) -> list[str]:
        """Return the CORS origins split on commas, blanks removed."""
        return [o.strip() for o in self.cors_origins.split(",") if o.strip()]

    @property
    def model_cache_path(self) -> Path:
        """Return the model cache directory, creating it if it is missing."""
        path = Path(self.model_cache_dir)
        path.mkdir(parents=True, exist_ok=True)
        return path

    @property
    def is_production(self) -> bool:
        """Return True when ``env`` names the production environment.

        The comparison ignores case and surrounding whitespace, so a value
        such as ``Production`` is still treated as production.
        """
        return self.env.strip().lower() == "production"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@lru_cache()
def get_settings() -> Settings:
    """Build the Settings object on first call and reuse it afterwards."""
    instance = Settings()
    return instance
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# app/core/__init__.py
|
app/core/dependencies.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/core/dependencies.py
|
| 2 |
+
# FastAPI dependency injection — auth guards, role checks, shared helpers
|
| 3 |
+
|
| 4 |
+
from fastapi import Depends, HTTPException, status
|
| 5 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 6 |
+
from app.core.security import decode_access_token
|
| 7 |
+
from app.db.models.user import UserDocument, UserRole
|
| 8 |
+
|
| 9 |
+
_bearer = HTTPBearer(auto_error=False)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that always carries the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
) -> UserDocument:
    """Validate the bearer JWT and return the matching active user.

    Args:
        credentials: Parsed ``Authorization`` header, or ``None`` when the
            header is absent (``_bearer`` is built with ``auto_error=False``).

    Returns:
        The authenticated, active ``UserDocument``.

    Raises:
        HTTPException: 401 when the header is missing, the token is invalid
            or expired, the token has no subject, or the user does not exist
            or is deactivated.
    """
    if not credentials:
        raise _unauthorized("Not authenticated")

    payload = decode_access_token(credentials.credentials)
    if not payload:
        raise _unauthorized("Invalid or expired token")

    # A token without a subject cannot identify a user; treat it like any
    # other invalid token rather than letting a KeyError surface as a 500.
    user_id = payload.get("sub")
    if not user_id:
        raise _unauthorized("Invalid or expired token")

    user = await UserDocument.get(user_id)
    if not user or not user.is_active:
        raise _unauthorized("User not found or deactivated")
    return user
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def require_role(*roles: UserRole):
    """Return a dependency that admits only users holding one of ``roles``.

    The returned coroutine resolves the current user through
    ``get_current_user`` and raises HTTP 403 when the user's role is not
    among ``roles``; otherwise it returns the user.
    """
    async def _check(user: UserDocument = Depends(get_current_user)) -> UserDocument:
        if user.role not in roles:
            # Use ``.value`` explicitly: formatting a str-mixin Enum yields
            # "UserRole.CHILD" on Python 3.11+ but "child" on older versions.
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Role '{user.role.value}' is not permitted for this action",
            )
        return user
    return _check
|
app/core/security.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/core/security.py
|
| 2 |
+
# Password hashing (argon2-cffi) + JWT creation/verification (python-jose)
|
| 3 |
+
|
| 4 |
+
from datetime import datetime, timedelta, timezone
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from argon2 import PasswordHasher
|
| 7 |
+
from argon2.exceptions import VerifyMismatchError, VerificationError, InvalidHashError
|
| 8 |
+
from jose import JWTError, jwt
|
| 9 |
+
from app.config import get_settings
|
| 10 |
+
|
| 11 |
+
settings = get_settings()
|
| 12 |
+
|
| 13 |
+
# ──────────────────────────────────────────────
|
| 14 |
+
# Password hashing — Argon2id (modern, OWASP recommended)
|
| 15 |
+
# ──────────────────────────────────────────────
|
| 16 |
+
|
| 17 |
+
_ph = PasswordHasher()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def hash_password(plain: str) -> str:
    """Hash a plaintext password with the module-level Argon2 hasher."""
    digest = _ph.hash(plain)
    return digest
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def verify_password(plain: str, hashed: str) -> bool:
    """Check a plaintext password against a stored hash.

    Returns False instead of raising when the password does not match,
    verification fails, or the stored hash is malformed.
    """
    try:
        matched = _ph.verify(hashed, plain)
    except (VerifyMismatchError, VerificationError, InvalidHashError):
        return False
    return matched
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ──────────────────────────────────────────────
|
| 32 |
+
# JWT — HS256 access + refresh tokens
|
| 33 |
+
# ──────────────────────────────────────────────
|
| 34 |
+
|
| 35 |
+
ALGORITHM = "HS256"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _create_token(data: dict, secret: str, expires_delta: timedelta) -> str:
    """Sign a copy of ``data`` with an ``exp`` claim set ``expires_delta`` from now (UTC)."""
    expires_at = datetime.now(timezone.utc) + expires_delta
    claims = {**data, "exp": expires_at}
    return jwt.encode(claims, secret, algorithm=ALGORITHM)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def create_access_token(user_id: str, role: str) -> str:
    """Issue an access token for the user, signed with the access secret."""
    claims = {"sub": user_id, "role": role, "type": "access"}
    lifetime = timedelta(minutes=settings.jwt_access_expires_minutes)
    return _create_token(claims, settings.jwt_access_secret, lifetime)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def create_refresh_token(user_id: str, role: str) -> str:
    """Issue a refresh token for the user, signed with the refresh secret."""
    claims = {"sub": user_id, "role": role, "type": "refresh"}
    lifetime = timedelta(days=settings.jwt_refresh_expires_days)
    return _create_token(claims, settings.jwt_refresh_secret, lifetime)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def decode_access_token(token: str) -> Optional[dict]:
    """Decode an access token; return its claims, or None if it is unusable.

    None is returned when decoding raises ``JWTError`` or when the ``type``
    claim is not ``"access"``.
    """
    try:
        claims = jwt.decode(token, settings.jwt_access_secret, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims if claims.get("type") == "access" else None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def decode_refresh_token(token: str) -> Optional[dict]:
    """Decode a refresh token; return its claims, or None if it is unusable.

    None is returned when decoding raises ``JWTError`` or when the ``type``
    claim is not ``"refresh"``.
    """
    try:
        claims = jwt.decode(token, settings.jwt_refresh_secret, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims if claims.get("type") == "refresh" else None
|
app/db/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# app/db/__init__.py
|
app/db/connection.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/db/connection.py
|
| 2 |
+
# Beanie ODM initialization — reuses the existing mongo_service Motor client
|
| 3 |
+
|
| 4 |
+
from beanie import init_beanie
|
| 5 |
+
from app.observability.logging import get_logger
|
| 6 |
+
|
| 7 |
+
logger = get_logger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
async def connect_db() -> None:
    """Register the Beanie document models on mongo_service's existing client.

    Logs an error and returns without initializing when mongo_service is not
    connected. Any failure during initialization is logged and re-raised.
    """
    from app.services.mongo_service import mongo_service
    from app.db.models.user import UserDocument
    from app.db.models.scan_result import ScanResultDocument
    from app.db.models.alert import AlertDocument

    ready = mongo_service._connected and mongo_service.client is not None
    if not ready:
        logger.error("beanie_init_skipped", reason="mongo_service not connected")
        return

    try:
        db_name = mongo_service.settings.mongodb_db_name
        database = mongo_service.client[db_name]
        models = [UserDocument, ScanResultDocument, AlertDocument]
        await init_beanie(
            database=database,
            document_models=models,
            allow_index_dropping=False,
        )
        logger.info("beanie_initialized", db=mongo_service.settings.mongodb_db_name)
    except Exception as exc:
        logger.error("beanie_init_failed", error=str(exc))
        raise
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
async def close_db() -> None:
    """Do nothing: the Motor client is closed by mongo_service.disconnect()."""
    return None
|
app/db/models/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/db/models/__init__.py
|
| 2 |
+
from app.db.models.user import UserDocument, UserRole
|
| 3 |
+
from app.db.models.scan_result import ScanResultDocument, RiskLevel
|
| 4 |
+
from app.db.models.alert import AlertDocument, AlertSeverity, AlertStatus
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"UserDocument", "UserRole",
|
| 8 |
+
"ScanResultDocument", "RiskLevel",
|
| 9 |
+
"AlertDocument", "AlertSeverity", "AlertStatus",
|
| 10 |
+
]
|
app/db/models/alert.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/db/models/alert.py
|
| 2 |
+
# Alert Beanie document
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from beanie import Document
|
| 9 |
+
from pydantic import Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AlertSeverity(str, Enum):
    """Severity levels for a parent alert; values are lowercase strings."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AlertStatus(str, Enum):
    """Status values an alert can hold; AlertDocument defaults to PENDING."""

    PENDING = "pending"
    ACKNOWLEDGED = "acknowledged"
    RESOLVED = "resolved"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AlertDocument(Document):
    """Parent alert generated when flagged content is detected for a child."""

    child_id: str
    parent_id: str
    scan_result_id: str

    # Default to empty strings: the scan pipeline constructs the alert from
    # ids/severity only and fills these through generate_content(). Without
    # defaults that construction fails validation before the text is built.
    title: str = ""
    message: str = ""
    guidance: str = ""

    severity: AlertSeverity = AlertSeverity.LOW
    categories: list[str] = Field(default_factory=list)
    severity_score: float = 0.0

    status: AlertStatus = AlertStatus.PENDING
    parent_notified: bool = False
    child_notified: bool = False

    acknowledged_at: Optional[datetime] = None
    resolved_at: Optional[datetime] = None
    resolved_by: Optional[str] = None
    resolution_notes: Optional[str] = None

    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    class Settings:
        name = "alerts"

    def generate_content(self) -> None:
        """Populate title, message and guidance from severity and categories.

        The message lists the categories (or a generic phrase when there are
        none) and a score shown as ``N/100``.
        """
        category_text = ", ".join(self.categories) if self.categories else "potentially harmful content"
        # NOTE(review): this assumes severity_score is a 0-1 fraction. If the
        # pipeline's risk score is already on a 0-100 scale, the displayed
        # score is inflated; confirm against the risk scorer.
        score = int(self.severity_score * 100)

        if self.severity == AlertSeverity.LOW:
            self.title = "Mild Concern Detected"
            self.message = f"Content flagged for: {category_text}. Score: {score}/100."
            self.guidance = "This content contains some concerning elements. Consider talking about online safety."
        elif self.severity == AlertSeverity.MEDIUM:
            self.title = "Moderate Concern Detected"
            self.message = f"Concerning content detected: {category_text}. Score: {score}/100."
            self.guidance = "This content shows signs of potential cyberbullying. We recommend discussing this with your child."
        elif self.severity == AlertSeverity.HIGH:
            self.title = "⚠️ High Severity Alert"
            self.message = f"Serious concern detected: {category_text}. Score: {score}/100."
            self.guidance = "Immediate discussion with your child is recommended. Consider reaching out to school counselors."
        else:  # CRITICAL
            self.title = "🚨 CRITICAL ALERT - Immediate Action Required"
            self.message = f"Critical content detected: {category_text}. Score: {score}/100."
            self.guidance = "If there are threats of violence or self-harm, please contact emergency services immediately."
|
app/db/models/scan_result.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/db/models/scan_result.py
|
| 2 |
+
# ScanResult Beanie document
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from beanie import Document
|
| 9 |
+
from pydantic import Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RiskLevel(str, Enum):
    """Risk level stored on a scan result; values are uppercase strings."""

    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
    CRITICAL = "CRITICAL"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ScanResultDocument(Document):
    """Persisted result of a content scan through the AI pipeline."""

    user_id: str  # stringified id of the user who requested the scan
    input_type: str  # text | image | video
    content_preview: Optional[str] = None  # first 200 chars of text, or filename

    # Risk
    risk_level: RiskLevel = RiskLevel.LOW
    risk_score: float = 0.0
    categories: list[str] = Field(default_factory=list)
    is_flagged: bool = False

    # Decision
    action: str = "ALLOWED"  # ALLOWED | WARNED | BLOCKED | ESCALATED
    reasoning: Optional[str] = None

    # Pipeline metadata
    processing_time_ms: int = 0
    deep_analysis_used: bool = False

    # Naive timestamp produced by datetime.utcnow at construction time.
    created_at: datetime = Field(default_factory=datetime.utcnow)

    class Settings:
        # MongoDB collection name.
        name = "scan_results"
|
app/db/models/user.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/db/models/user.py
|
| 2 |
+
# Beanie User document — mirrors the Mongoose User model from Backend/
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from beanie import Document, Indexed
|
| 9 |
+
from pydantic import EmailStr, Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class UserRole(str, Enum):
    """Account roles; values are lowercase strings."""

    PARENT = "parent"
    CHILD = "child"
    ADMIN = "admin"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class UserDocument(Document):
    """
    User account persisted in MongoDB.

    Parent and child accounts are linked to each other so a parent can
    monitor a child's activity.
    """

    email: Optional[EmailStr] = None
    username: Indexed(str, unique=True)  # type: ignore[valid-type]
    password_hash: str  # Argon2 hash (see app/core/security.py); never returned by the API
    role: UserRole = UserRole.PARENT
    first_name: str
    last_name: str
    phone: Optional[str] = None
    date_of_birth: Optional[datetime] = None

    # Parent-child linkage
    parent_id: Optional[str] = None  # populated on child accounts
    children: list[str] = Field(default_factory=list)  # populated on parent accounts

    # Consent and privacy
    consent_given: bool = False
    consent_date: Optional[datetime] = None
    parental_consent: Optional[bool] = None  # child accounts only

    # Account status
    is_active: bool = True
    is_verified: bool = False
    last_login_at: Optional[datetime] = None

    # Security: refresh tokens are kept so they can be rotated or revoked
    refresh_tokens: list[str] = Field(default_factory=list)
    password_changed_at: Optional[datetime] = None

    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    class Settings:
        name = "users"
        # The Atlas collection already carries the indexes created by the
        # Mongoose model. Beanie relies on the field-level Indexed()
        # annotations and, with allow_index_dropping=False, will not try to
        # re-create conflicting ones.

    def to_public(self) -> dict:
        """Build the API-safe view of this user, leaving out sensitive fields."""
        public_view = dict(
            id=str(self.id),
            email=self.email,
            username=self.username,
            role=self.role.value,
            first_name=self.first_name,
            last_name=self.last_name,
            is_active=self.is_active,
            is_verified=self.is_verified,
            parent_id=self.parent_id,
            children=self.children,
            created_at=self.created_at.isoformat(),
        )
        return public_view
|
app/dependencies.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/dependencies.py
|
| 2 |
+
# FastAPI dependency injection
|
| 3 |
+
|
| 4 |
+
from app.models.model_registry import model_registry
|
| 5 |
+
from app.services.redis_service import redis_service
|
| 6 |
+
from app.services.mongo_service import mongo_service
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
async def require_models():
    """Dependency: return the model registry, or answer 503 while no model is loaded.

    The request is refused only when neither the text model nor the image
    model is reported as loaded.
    """
    loaded = model_registry.get_status()
    if not (loaded.get("text_model") or loaded.get("image_model")):
        from fastapi import HTTPException

        raise HTTPException(
            status_code=503,
            detail="AI models not loaded. Server is starting up.",
        )
    return model_registry
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
async def require_mongo():
    """Dependency: return mongo_service, or answer 503 when it is not connected."""
    if mongo_service.is_connected:
        return mongo_service

    from fastapi import HTTPException

    raise HTTPException(
        status_code=503,
        detail="MongoDB not available.",
    )
|
app/main.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/main.py
|
| 2 |
+
# FastAPI application factory — entry point for the Hubble AI Engine
|
| 3 |
+
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from fastapi.responses import JSONResponse
|
| 8 |
+
|
| 9 |
+
from app.config import get_settings
|
| 10 |
+
from app.observability.logging import setup_logging, get_logger
|
| 11 |
+
from app.observability.langsmith import setup_langsmith
|
| 12 |
+
from app.models.model_registry import model_registry
|
| 13 |
+
from app.services.redis_service import redis_service
|
| 14 |
+
from app.services.mongo_service import mongo_service
|
| 15 |
+
from app.services.gemini_service import gemini_service
|
| 16 |
+
from app.pipeline.workflow import get_workflow
|
| 17 |
+
from app.api.router import api_router
|
| 18 |
+
from app.db.connection import connect_db, close_db
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    Startup:
      1. Setup LangSmith tracing
      2. Connect MongoDB and initialize Beanie
      3. Connect Redis
      4. Initialize Gemini service
      5. Load all ML models
      6. Compile LangGraph workflow

    Shutdown (runs even when the application exits with an error):
      1. Disconnect Redis
      2. Disconnect MongoDB
    """
    logger = get_logger(__name__)
    settings = get_settings()

    # ── Startup ──
    logger.info("=" * 60)
    logger.info("[STARTUP] HUBBLE AI ENGINE — Starting up...")
    logger.info("=" * 60)

    # 1. LangSmith
    langsmith_ok = setup_langsmith()
    logger.info("langsmith", enabled=langsmith_ok)

    # 2. MongoDB (legacy motor service + Beanie ODM)
    logger.info("connecting_mongodb")
    await mongo_service.connect()
    await connect_db()  # initializes Beanie document models

    # 3. Redis
    logger.info("connecting_redis")
    await redis_service.connect()

    # 4. Gemini
    logger.info("initializing_gemini")
    gemini_service.initialize()

    # 5. ML Models
    logger.info("loading_models")
    model_status = await model_registry.load_all()
    logger.info("models_loaded", status=model_status)

    # 6. LangGraph Workflow
    logger.info("compiling_workflow")
    get_workflow()

    logger.info("=" * 60)
    logger.info("[READY] HUBBLE AI ENGINE — Ready!")
    logger.info(f" Environment: {settings.env}")
    logger.info(f" Port: {settings.port}")
    logger.info(f" Docs: http://localhost:{settings.port}/docs")
    logger.info("=" * 60)

    try:
        yield  # ── Application runs here ──
    finally:
        # ── Shutdown ──
        # In ``finally`` so connections are released even when an exception
        # (including cancellation) is thrown into the generator at ``yield``.
        logger.info("[SHUTDOWN] HUBBLE — Shutting down...")
        await redis_service.disconnect()
        await mongo_service.disconnect()
        await close_db()
        logger.info("Shutdown complete")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Sets up logging, builds the FastAPI instance (wired to ``lifespan``),
    installs the CORS middleware, mounts ``api_router`` and registers a small
    JSON index at ``/``.

    Returns:
        The configured FastAPI application.
    """
    settings = get_settings()

    # Configure logging first
    setup_logging()

    app = FastAPI(
        title="Hubble AI Engine — Cyberbullying Detection API",
        description=(
            "Production-grade content moderation pipeline with layered AI analysis. "
            "Supports text, image, and video inputs with risk-based routing."
        ),
        version="4.0.0",
        docs_url="/docs",
        redoc_url="/redoc",
        lifespan=lifespan,
    )

    # CORS — explicit origins from config in production, wildcard otherwise.
    cors_origins = settings.cors_origins_list if settings.is_production else ["*"]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        # A wildcard origin combined with credentials would let any site make
        # credentialed requests, so credentials are only enabled when the
        # origin list is explicit.
        allow_credentials="*" not in cors_origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Mount routes
    app.include_router(api_router)

    @app.get("/", include_in_schema=False)
    async def root():
        # NOTE(review): this index reports version "5.0.0" while the FastAPI
        # metadata above says "4.0.0" — confirm which one is current.
        return JSONResponse({
            "name": "Hubble Unified API",
            "version": "5.0.0",
            "docs": "/docs",
            "health": "/health",
            "endpoints": {
                "auth": "POST /api/v1/auth/{register|login|refresh|logout}",
                "users": "GET /api/v1/users/me",
                "scan_text": "POST /api/v1/scan/text",
                "scan_image": "POST /api/v1/scan/image",
                "scan_history": "GET /api/v1/scan/history",
                "alerts": "GET /api/v1/alerts",
                "analyze_text": "POST /api/v1/analyze/text (raw, no auth)",
                "analyze_image": "POST /api/v1/analyze/image (raw, no auth)",
            },
        })

    return app
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Module-level ASGI application; this is the object "app.main:app" refers to.
app = create_app()


if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn, with auto-reload enabled
    # whenever the settings do not report a production environment.
    import uvicorn

    settings = get_settings()
    run_options = {
        "host": settings.host,
        "port": settings.port,
        "reload": not settings.is_production,
    }
    uvicorn.run("app.main:app", **run_options)
|
app/models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/models/__init__.py
|
| 2 |
+
"""Model loading, ONNX optimization, and inference."""
|
app/models/clip_model.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/models/clip_model.py
|
| 2 |
+
# CLIP model for multimodal text-image alignment (deep analysis only)
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
from app.config import get_settings
|
| 7 |
+
from app.observability.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CLIPModel:
    """
    CLIP (Contrastive Language-Image Pre-Training) model.

    Used in the deep analysis path to compute semantic alignment
    between text descriptions and image content. This helps detect
    subtle multimodal threats (e.g., threatening text overlaid on images).

    Loading is best-effort: failures inside ``load()`` are logged and leave
    ``is_loaded`` False, and ``compute_similarity`` returns an error dict
    while the model is not loaded.
    """

    def __init__(self):
        self.settings = get_settings()
        self.model = None
        self.preprocess = None
        self.tokenizer = None
        self._loaded = False
        self.device = None

    def load(self) -> None:
        """Load the CLIP model, its image transform and its tokenizer.

        Errors raised while importing ``open_clip`` or building the model are
        logged and recorded as ``is_loaded == False`` instead of propagating.
        """
        import torch
        try:
            import open_clip

            model_name = self.settings.clip_model_name
            cache_dir = self.settings.model_cache_path / "clip"
            cache_dir.mkdir(parents=True, exist_ok=True)

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            logger.info("loading_clip_model", model=model_name)

            # Use OpenCLIP for flexibility.
            # NOTE(review): the architecture and weights are hard-coded here;
            # settings.clip_model_name is only logged — confirm that is intended.
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                "ViT-B-32",
                pretrained="laion2b_s34b_b79k",
                # Store weights in the project cache directory created above.
                cache_dir=str(cache_dir),
            )
            self.model = self.model.to(self.device)
            self.model.eval()

            self.tokenizer = open_clip.get_tokenizer("ViT-B-32")

            self._loaded = True
            logger.info("clip_model_loaded")

        except ImportError:
            logger.warning("clip_not_available", reason="open_clip not installed")
            self._loaded = False
        except Exception as e:
            logger.error("clip_load_failed", error=str(e))
            self._loaded = False

    def compute_similarity(self, image: Image.Image, texts: list[str]) -> dict:
        """
        Compute cosine similarity between an image and a list of text descriptions.

        Args:
            image: PIL Image.
            texts: List of text descriptions to compare against.

        Returns:
            Dict with similarities, best_match, and best_score. When the model
            is not loaded or ``texts`` is empty, a dict with an ``error``
            message and an empty ``similarities`` list is returned instead.
        """
        if not self._loaded:
            return {"error": "CLIP model not loaded", "similarities": []}
        if not texts:
            # Nothing to rank; argmax over an empty list would raise.
            return {"error": "No texts provided", "similarities": []}

        import torch

        # Preprocess image
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)

        # Tokenize texts
        text_tokens = self.tokenizer(texts).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            text_features = self.model.encode_text(text_tokens)

            # Normalize so the dot product below is a cosine similarity
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Cosine similarity: one score per text
            similarities = (image_features @ text_features.T).squeeze(0).cpu().numpy()

        sim_list = similarities.tolist()
        best_idx = int(np.argmax(sim_list))

        return {
            # Duplicate texts collapse to a single key here.
            "similarities": dict(zip(texts, sim_list)),
            "best_match": texts[best_idx],
            "best_score": sim_list[best_idx],
        }

    def align_content(self, image: Image.Image, context_text: str | None = None) -> dict:
        """
        Analyze image alignment with harmful content categories.

        Args:
            image: Image to analyze.
            context_text: Optional surrounding text context.

        Returns:
            Dict with category alignment scores, or the error dict from
            ``compute_similarity`` when the model is not loaded.
        """
        harmful_descriptions = [
            "a photo containing violence, fighting, or physical harm",
            "a photo containing nudity or sexual content",
            "a photo containing self-harm or suicide imagery",
            "a photo containing hate symbols or extremist content",
            "a photo containing drugs or substance abuse",
            "a safe and appropriate photo for children",
        ]

        result = self.compute_similarity(image, harmful_descriptions)

        if "error" in result:
            return result

        # Also check text-image alignment if context provided
        text_alignment = None
        if context_text:
            text_result = self.compute_similarity(image, [context_text, "unrelated content"])
            text_alignment = text_result["similarities"].get(context_text, 0.0)

        return {
            "category_scores": result["similarities"],
            "most_aligned": result["best_match"],
            "alignment_score": result["best_score"],
            "text_image_alignment": text_alignment,
        }

    @property
    def is_loaded(self) -> bool:
        """Whether ``load()`` completed successfully."""
        return self._loaded
|
app/models/image_model.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/models/image_model.py
|
| 2 |
+
# EfficientNet-based image classification model with ONNX optimization
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from app.config import get_settings
|
| 8 |
+
from app.observability.logging import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ImageClassificationModel:
    """
    Image content classifier using EfficientNet.

    Detects violence, NSFW content, and other harmful imagery.
    Supports ONNX (fast) and PyTorch (fallback) inference.

    The index-to-label mapping is captured at load time so that predictions
    keep their real label names after the PyTorch model has been released in
    favour of the ONNX session.
    """

    LABELS = ["safe", "violence", "nsfw", "self_harm", "hate_symbol"]

    def __init__(self):
        self.settings = get_settings()
        self.processor = None
        self.onnx_session = None
        self.pt_model = None
        self.device = None
        self._loaded = False
        self._num_labels = len(self.LABELS)
        # index -> label name taken from the model config (empty if unknown)
        self._id2label: dict[int, str] = {}

    def load(self) -> None:
        """Load the image processor and model.

        Uses an existing ONNX file when ONNX is enabled and the file exists;
        otherwise loads the PyTorch model and, if ONNX is enabled, tries to
        export it and switch to the exported session. A failed export keeps
        the PyTorch model as the backend.
        """
        from transformers import AutoImageProcessor

        model_name = self.settings.image_model_name
        cache_dir = self.settings.model_cache_path / "efficientnet"
        onnx_path = cache_dir / "image_classifier.onnx"

        logger.info("loading_image_model", model=model_name)

        # Load image processor
        try:
            self.processor = AutoImageProcessor.from_pretrained(
                model_name, cache_dir=cache_dir
            )
        except Exception as e:
            # Fallback: use a generic processor, but make the substitution visible.
            logger.warning(
                "image_processor_fallback",
                error=str(e),
                fallback="google/efficientnet-b0",
            )
            self.processor = AutoImageProcessor.from_pretrained(
                "google/efficientnet-b0", cache_dir=cache_dir
            )

        if self.settings.onnx_enabled and onnx_path.exists():
            from app.models.onnx_utils import load_onnx_session
            self.onnx_session = load_onnx_session(onnx_path)
            # No PyTorch model is loaded on this path, so fetch label names
            # from the model config alone.
            self._load_label_map(model_name, cache_dir)
            logger.info("image_model_loaded", backend="onnx")
        else:
            self._load_pytorch(model_name, cache_dir)
            if self.settings.onnx_enabled:
                try:
                    self._export_onnx(onnx_path)
                    from app.models.onnx_utils import load_onnx_session
                    self.onnx_session = load_onnx_session(onnx_path)
                    self.pt_model = None  # release PyTorch weights
                    logger.info("image_model_loaded", backend="onnx", note="exported")
                except Exception as e:
                    logger.warning("onnx_export_failed", error=str(e), fallback="pytorch")
            else:
                logger.info("image_model_loaded", backend="pytorch")

        self._loaded = True

    def _remember_labels(self, config) -> None:
        """Copy ``config.id2label`` (if non-empty) into ``self._id2label``."""
        id2label = getattr(config, "id2label", None)
        if id2label:
            self._id2label = {int(idx): str(name) for idx, name in id2label.items()}
            self._num_labels = len(self._id2label)

    def _load_label_map(self, model_name: str, cache_dir: Path) -> None:
        """Best-effort load of the label map from the model config (no weights)."""
        from transformers import AutoConfig

        try:
            config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
        except Exception as e:
            logger.warning("image_label_map_unavailable", error=str(e))
            return
        self._remember_labels(config)

    def _load_pytorch(self, model_name: str, cache_dir: Path) -> None:
        """Load the PyTorch model onto CUDA when available, else CPU."""
        import torch
        from transformers import AutoModelForImageClassification

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.pt_model = AutoModelForImageClassification.from_pretrained(
                model_name, cache_dir=cache_dir
            )
        except Exception as e:
            # If the model doesn't exist as a pretrained classifier, load base EfficientNet.
            # NOTE(review): the base checkpoint's labels are not this class's
            # harm categories, so predictions from it need review.
            logger.warning(
                "image_model_fallback",
                error=str(e),
                fallback="google/efficientnet-b0",
            )
            self.pt_model = AutoModelForImageClassification.from_pretrained(
                "google/efficientnet-b0", cache_dir=cache_dir
            )
        self.pt_model.to(self.device)
        self.pt_model.eval()

        # Keep the model's own labels so they outlive self.pt_model.
        self._remember_labels(self.pt_model.config)

    def _export_onnx(self, onnx_path: Path) -> None:
        """Export the loaded PyTorch model to ONNX at ``onnx_path``."""
        from app.models.onnx_utils import export_to_onnx

        # Trace with a tensor produced by the real processor so the exported
        # input size matches what predict() feeds at inference time.
        blank = Image.new("RGB", (224, 224))
        dummy_input = self.processor(images=blank, return_tensors="pt")["pixel_values"]
        dummy_input = dummy_input.to(self.device)
        export_to_onnx(
            model=self.pt_model,
            sample_input={"pixel_values": dummy_input},
            output_path=onnx_path,
            input_names=["pixel_values"],
            output_names=["logits"],
        )

    def predict(self, image: Image.Image) -> dict:
        """
        Classify an image for harmful content.

        Args:
            image: PIL Image. Non-RGB images are converted to RGB first.

        Returns:
            Dict with labels, scores, is_harmful, max_score, max_label.

        Raises:
            RuntimeError: If ``load()`` has not completed.
        """
        if not self._loaded:
            raise RuntimeError("Image model not loaded. Call load() first.")

        if image.mode != "RGB":
            # Palette, grayscale and alpha images do not have three channels.
            image = image.convert("RGB")

        # Preprocess with the model's processor
        inputs = self.processor(images=image, return_tensors="np" if self.onnx_session else "pt")

        if self.onnx_session:
            return self._predict_onnx(inputs)
        else:
            return self._predict_pytorch(inputs)

    def _predict_onnx(self, inputs) -> dict:
        """Run ONNX inference on processor output and format the first result."""
        from app.models.onnx_utils import onnx_inference

        pixel_values = inputs["pixel_values"].astype(np.float32)
        outputs = onnx_inference(self.onnx_session, {"pixel_values": pixel_values})
        logits = outputs[0][0]
        return self._format_output(logits)

    def _predict_pytorch(self, inputs) -> dict:
        """Run PyTorch inference on processor output and format the first result."""
        import torch

        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.pt_model(**inputs)
        logits = outputs.logits[0].cpu().numpy()
        return self._format_output(logits)

    def _format_output(self, logits: np.ndarray) -> dict:
        """Convert a 1-D logits array to the prediction dict."""
        # Softmax for single-label classification (shifted by the max for stability)
        exp_logits = np.exp(logits - np.max(logits))
        scores = (exp_logits / exp_logits.sum()).tolist()

        # Prefer the model's own labels; they are stored at load time because
        # self.pt_model is None whenever the ONNX backend is active.
        if self._id2label:
            labels = [self._id2label.get(i, f"class_{i}") for i in range(len(scores))]
        elif len(scores) == len(self.LABELS):
            # NOTE(review): assumes the model's output order matches LABELS — confirm.
            labels = list(self.LABELS)
        else:
            labels = [f"class_{i}" for i in range(len(scores))]

        max_idx = int(np.argmax(scores))

        # Determine if harmful (anything not classified as safe/non-violent)
        safe_keywords = {"safe", "non-violence", "non_violence", "normal", "neutral"}
        is_harmful = labels[max_idx].lower().replace("-", "_").replace(" ", "_") not in safe_keywords

        return {
            "labels": labels,
            "scores": scores,
            "is_harmful": is_harmful,
            "max_score": scores[max_idx],
            "max_label": labels[max_idx],
        }

    @property
    def is_loaded(self) -> bool:
        """Whether ``load()`` completed."""
        return self._loaded
|
app/models/model_registry.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/models/model_registry.py
|
| 2 |
+
# Singleton model registry — loads and manages all ML models
|
| 3 |
+
|
| 4 |
+
from app.models.text_model import TextToxicityModel
|
| 5 |
+
from app.models.image_model import ImageClassificationModel
|
| 6 |
+
from app.models.clip_model import CLIPModel
|
| 7 |
+
from app.observability.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ModelRegistry:
    """
    Holder for the shared instance of each ML model used by the service.

    ``load_all`` creates and loads the models; the properties then hand out
    the stored instances so the same objects are reused by every caller.
    """

    def __init__(self):
        self._text_model: TextToxicityModel | None = None
        self._image_model: ImageClassificationModel | None = None
        self._clip_model: CLIPModel | None = None

    def _create_and_load(self, slot: str, model_cls) -> None:
        """Instantiate ``model_cls``, store it on attribute ``slot``, then call ``load()``."""
        instance = model_cls()
        setattr(self, slot, instance)
        instance.load()

    async def load_all(self) -> dict[str, bool]:
        """
        Create and load every model; intended to run once at startup.

        A failure in one model is logged and recorded as ``False`` without
        preventing the remaining models from being attempted.

        Returns:
            Mapping of "text" / "image" / "clip" to whether that model loaded.
        """
        outcome: dict[str, bool] = {}

        # Text model (required)
        logger.info("registry_loading", model="text_toxicity")
        try:
            self._create_and_load("_text_model", TextToxicityModel)
        except Exception as e:
            logger.error("text_model_load_failed", error=str(e))
            outcome["text"] = False
        else:
            outcome["text"] = True

        # Image model (required)
        logger.info("registry_loading", model="image_classifier")
        try:
            self._create_and_load("_image_model", ImageClassificationModel)
        except Exception as e:
            logger.error("image_model_load_failed", error=str(e))
            outcome["image"] = False
        else:
            outcome["image"] = True

        # CLIP model (optional — only for deep analysis); its load() reports
        # failure through is_loaded, so that flag is the recorded status.
        logger.info("registry_loading", model="clip")
        try:
            self._create_and_load("_clip_model", CLIPModel)
            outcome["clip"] = self._clip_model.is_loaded
        except Exception as e:
            logger.warning("clip_model_load_failed", error=str(e))
            outcome["clip"] = False

        logger.info("registry_loaded", results=outcome)
        return outcome

    @property
    def text_model(self) -> TextToxicityModel:
        """The loaded text model; raises RuntimeError when it is missing or unloaded."""
        model = self._text_model
        if model is not None and model.is_loaded:
            return model
        raise RuntimeError("Text model not available")

    @property
    def image_model(self) -> ImageClassificationModel:
        """The loaded image model; raises RuntimeError when it is missing or unloaded."""
        model = self._image_model
        if model is not None and model.is_loaded:
            return model
        raise RuntimeError("Image model not available")

    @property
    def clip_model(self) -> CLIPModel:
        """The CLIP wrapper; raises RuntimeError only when it was never created."""
        model = self._clip_model
        if model is not None:
            return model
        raise RuntimeError("CLIP model not available")

    @property
    def clip_available(self) -> bool:
        """True when the CLIP wrapper exists and reports itself loaded."""
        if self._clip_model is None:
            return False
        return self._clip_model.is_loaded

    def get_status(self) -> dict:
        """Report, per model, whether it is present and loaded."""
        slots = (
            ("text_model", self._text_model),
            ("image_model", self._image_model),
            ("clip_model", self._clip_model),
        )
        return {
            key: (model.is_loaded if model else False)
            for key, model in slots
        }


# Global singleton
model_registry = ModelRegistry()
|
app/models/onnx_utils.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/models/onnx_utils.py
|
| 2 |
+
# ONNX export and inference utilities
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
from app.observability.logging import get_logger
|
| 7 |
+
|
| 8 |
+
logger = get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def export_to_onnx(
|
| 12 |
+
model,
|
| 13 |
+
sample_input: dict,
|
| 14 |
+
output_path: Path,
|
| 15 |
+
input_names: list[str] | None = None,
|
| 16 |
+
output_names: list[str] | None = None,
|
| 17 |
+
dynamic_axes: dict | None = None,
|
| 18 |
+
opset_version: int = 14,
|
| 19 |
+
) -> Path:
|
| 20 |
+
"""
|
| 21 |
+
Export a PyTorch model to ONNX format.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
model: PyTorch model (eval mode).
|
| 25 |
+
sample_input: Dict of tensor inputs for tracing.
|
| 26 |
+
output_path: Where to save the .onnx file.
|
| 27 |
+
input_names: Names for input tensors.
|
| 28 |
+
output_names: Names for output tensors.
|
| 29 |
+
dynamic_axes: Dynamic axes specification.
|
| 30 |
+
opset_version: ONNX opset version.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Path to the exported ONNX model.
|
| 34 |
+
"""
|
| 35 |
+
import torch
|
| 36 |
+
|
| 37 |
+
output_path = Path(output_path)
|
| 38 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
if input_names is None:
|
| 41 |
+
input_names = list(sample_input.keys())
|
| 42 |
+
if output_names is None:
|
| 43 |
+
output_names = ["logits"]
|
| 44 |
+
if dynamic_axes is None:
|
| 45 |
+
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}
|
| 46 |
+
|
| 47 |
+
# Prepare ordered tuple of inputs
|
| 48 |
+
input_tuple = tuple(sample_input[name] for name in input_names)
|
| 49 |
+
|
| 50 |
+
model.eval()
|
| 51 |
+
with torch.no_grad():
|
| 52 |
+
torch.onnx.export(
|
| 53 |
+
model,
|
| 54 |
+
input_tuple,
|
| 55 |
+
str(output_path),
|
| 56 |
+
input_names=input_names,
|
| 57 |
+
output_names=output_names,
|
| 58 |
+
dynamic_axes=dynamic_axes,
|
| 59 |
+
opset_version=opset_version,
|
| 60 |
+
do_constant_folding=True,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
logger.info("onnx_export_complete", path=str(output_path), size_mb=round(output_path.stat().st_size / 1e6, 1))
|
| 64 |
+
return output_path
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_onnx_session(model_path: Path, providers: list[str] | None = None):
    """
    Open an ONNX model file as an ONNX Runtime InferenceSession.

    Args:
        model_path: Path to .onnx file.
        providers: Execution providers to use. When omitted, CUDA is placed
            ahead of CPU if ONNX Runtime lists it as available; otherwise
            only CPU is used.

    Returns:
        ort.InferenceSession instance.
    """
    import onnxruntime as ort

    if providers is None:
        # CPU is always included; CUDA goes in front of it when present.
        providers = ["CPUExecutionProvider"]
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")

    path_text = str(model_path)
    session = ort.InferenceSession(path_text, providers=providers)
    logger.info("onnx_session_loaded", path=path_text, providers=providers)
    return session
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Input element types as reported by the session's input descriptors
# (e.g. "tensor(int64)") mapped to the numpy dtype to feed for each.
_ORT_TYPE_TO_NUMPY = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(double)": np.float64,
    "tensor(int8)": np.int8,
    "tensor(int16)": np.int16,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(uint8)": np.uint8,
    "tensor(uint16)": np.uint16,
    "tensor(uint32)": np.uint32,
    "tensor(uint64)": np.uint64,
    "tensor(bool)": np.bool_,
}


def onnx_inference(session, inputs: dict[str, np.ndarray]) -> list[np.ndarray]:
    """
    Run inference on an ONNX session.

    Each supplied array is cast to the dtype the session declares for that
    input. Keys in ``inputs`` that the session does not declare are ignored.

    Args:
        session: ONNX InferenceSession.
        inputs: Dict mapping input names to numpy arrays.

    Returns:
        List of output numpy arrays.
    """
    feed = {}
    for inp in session.get_inputs():
        if inp.name not in inputs:
            # Not supplied by the caller; session.run() is left to deal with it.
            continue
        dtype = _ORT_TYPE_TO_NUMPY.get(inp.type)
        if dtype is None:
            # Unrecognised type string: keep the previous heuristic.
            dtype = np.int64 if "int" in inp.type else np.float32
        # copy=False avoids a copy when the array already has the right dtype.
        feed[inp.name] = np.asarray(inputs[inp.name]).astype(dtype, copy=False)

    return session.run(None, feed)
|
app/models/text_model.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/models/text_model.py
|
| 2 |
+
# RoBERTa-based text toxicity model with ONNX optimization
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
from app.config import get_settings
|
| 7 |
+
from app.observability.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TextToxicityModel:
|
| 13 |
+
"""
|
| 14 |
+
Text toxicity classifier using a RoBERTa-based model.
|
| 15 |
+
|
| 16 |
+
Supports both ONNX (fast) and PyTorch (fallback) inference.
|
| 17 |
+
Model: unitary/toxic-bert (multi-label toxicity detection).
|
| 18 |
+
|
| 19 |
+
Labels: toxic, severe_toxic, obscene, threat, insult, identity_hate
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
|
| 23 |
+
|
| 24 |
+
def __init__(self):
|
| 25 |
+
self.settings = get_settings()
|
| 26 |
+
self.tokenizer = None
|
| 27 |
+
self.onnx_session = None
|
| 28 |
+
self.pt_model = None
|
| 29 |
+
self.device = None
|
| 30 |
+
self._loaded = False
|
| 31 |
+
|
| 32 |
+
def load(self) -> None:
|
| 33 |
+
"""Load the tokenizer and model (ONNX preferred, PyTorch fallback)."""
|
| 34 |
+
from transformers import AutoTokenizer
|
| 35 |
+
|
| 36 |
+
model_name = self.settings.text_model_name
|
| 37 |
+
cache_dir = self.settings.model_cache_path / "roberta"
|
| 38 |
+
onnx_path = cache_dir / "text_toxicity.onnx"
|
| 39 |
+
|
| 40 |
+
logger.info("loading_text_model", model=model_name)
|
| 41 |
+
|
| 42 |
+
# Load tokenizer
|
| 43 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
| 44 |
+
|
| 45 |
+
if self.settings.onnx_enabled and onnx_path.exists():
|
| 46 |
+
# Use existing ONNX model
|
| 47 |
+
from app.models.onnx_utils import load_onnx_session
|
| 48 |
+
self.onnx_session = load_onnx_session(onnx_path)
|
| 49 |
+
logger.info("text_model_loaded", backend="onnx")
|
| 50 |
+
elif self.settings.onnx_enabled:
|
| 51 |
+
# Load PyTorch, export to ONNX, then use ONNX
|
| 52 |
+
self._load_pytorch(model_name, cache_dir)
|
| 53 |
+
self._export_onnx(onnx_path)
|
| 54 |
+
# Switch to ONNX session
|
| 55 |
+
from app.models.onnx_utils import load_onnx_session
|
| 56 |
+
self.onnx_session = load_onnx_session(onnx_path)
|
| 57 |
+
self.pt_model = None # Free PyTorch memory
|
| 58 |
+
logger.info("text_model_loaded", backend="onnx", note="exported_from_pytorch")
|
| 59 |
+
else:
|
| 60 |
+
# PyTorch only
|
| 61 |
+
self._load_pytorch(model_name, cache_dir)
|
| 62 |
+
logger.info("text_model_loaded", backend="pytorch")
|
| 63 |
+
|
| 64 |
+
self._loaded = True
|
| 65 |
+
|
| 66 |
+
def _load_pytorch(self, model_name: str, cache_dir: Path) -> None:
|
| 67 |
+
"""Load the PyTorch model."""
|
| 68 |
+
import torch
|
| 69 |
+
from transformers import AutoModelForSequenceClassification
|
| 70 |
+
|
| 71 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 72 |
+
self.pt_model = AutoModelForSequenceClassification.from_pretrained(
|
| 73 |
+
model_name, cache_dir=cache_dir
|
| 74 |
+
)
|
| 75 |
+
self.pt_model.to(self.device)
|
| 76 |
+
self.pt_model.eval()
|
| 77 |
+
|
| 78 |
+
def _export_onnx(self, onnx_path: Path) -> None:
|
| 79 |
+
"""Export current PyTorch model to ONNX."""
|
| 80 |
+
import torch
|
| 81 |
+
from app.models.onnx_utils import export_to_onnx
|
| 82 |
+
|
| 83 |
+
sample = self.tokenizer(
|
| 84 |
+
"test input for export",
|
| 85 |
+
return_tensors="pt",
|
| 86 |
+
padding="max_length",
|
| 87 |
+
truncation=True,
|
| 88 |
+
max_length=128,
|
| 89 |
+
)
|
| 90 |
+
sample = {k: v.to(self.device) for k, v in sample.items()}
|
| 91 |
+
|
| 92 |
+
export_to_onnx(
|
| 93 |
+
model=self.pt_model,
|
| 94 |
+
sample_input=sample,
|
| 95 |
+
output_path=onnx_path,
|
| 96 |
+
input_names=["input_ids", "attention_mask"],
|
| 97 |
+
output_names=["logits"],
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def predict(self, text: str) -> dict:
    """
    Score a piece of text with the loaded toxicity model.

    The text is tokenized (padded/truncated to 128 tokens) and dispatched
    to the ONNX path when an ONNX session is set, otherwise to the
    PyTorch path.

    Args:
        text: Input text to classify.

    Returns:
        The dict produced by ``_format_output`` via the selected backend.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if not self._loaded:
        raise RuntimeError("Text model not loaded. Call load() first.")

    use_onnx = bool(self.onnx_session)

    # ONNX consumes numpy arrays; the PyTorch path consumes torch tensors.
    tensor_format = "np" if use_onnx else "pt"
    encoding = self.tokenizer(
        text,
        return_tensors=tensor_format,
        padding="max_length",
        truncation=True,
        max_length=128,
    )

    backend = self._predict_onnx if use_onnx else self._predict_pytorch
    return backend(encoding)
|
| 131 |
+
|
| 132 |
+
def _predict_onnx(self, encoding: dict) -> dict:
    """
    Run inference through the ONNX session.

    Args:
        encoding: Tokenizer output providing ``input_ids`` and
            ``attention_mask`` arrays (cast to int64 before inference).

    Returns:
        Formatted prediction dict from ``_format_output``.
    """
    from app.models.onnx_utils import onnx_inference

    feed = {
        name: encoding[name].astype(np.int64)
        for name in ("input_ids", "attention_mask")
    }
    raw_outputs = onnx_inference(self.onnx_session, feed)

    # First output, first row of the batch.
    first_row_logits = raw_outputs[0][0]
    return self._format_output(first_row_logits)
|
| 143 |
+
|
| 144 |
+
def _predict_pytorch(self, encoding: dict) -> dict:
|
| 145 |
+
"""Run PyTorch inference."""
|
| 146 |
+
import torch
|
| 147 |
+
|
| 148 |
+
inputs = {k: v.to(self.device) for k, v in encoding.items()}
|
| 149 |
+
with torch.no_grad():
|
| 150 |
+
outputs = self.pt_model(**inputs)
|
| 151 |
+
logits = outputs.logits[0].cpu().numpy()
|
| 152 |
+
return self._format_output(logits)
|
| 153 |
+
|
| 154 |
+
def _format_output(self, logits: np.ndarray) -> dict:
|
| 155 |
+
"""Convert raw logits to formatted prediction dict."""
|
| 156 |
+
# Sigmoid for multi-label classification
|
| 157 |
+
scores = 1 / (1 + np.exp(-logits))
|
| 158 |
+
scores = scores.tolist()
|
| 159 |
+
|
| 160 |
+
# Handle case where model has fewer outputs than expected labels
|
| 161 |
+
labels = self.LABELS[: len(scores)]
|
| 162 |
+
|
| 163 |
+
label_scores = dict(zip(labels, scores))
|
| 164 |
+
max_idx = int(np.argmax(scores))
|
| 165 |
+
is_toxic = any(s > 0.5 for s in scores)
|
| 166 |
+
|
| 167 |
+
return {
|
| 168 |
+
"labels": labels,
|
| 169 |
+
"scores": scores,
|
| 170 |
+
"label_scores": label_scores,
|
| 171 |
+
"is_toxic": is_toxic,
|
| 172 |
+
"max_score": scores[max_idx],
|
| 173 |
+
"max_label": labels[max_idx],
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
@property
def is_loaded(self) -> bool:
    """Expose the internal ``_loaded`` flag as a read-only property."""
    loaded_flag = self._loaded
    return loaded_flag
|
app/observability/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/observability/__init__.py
|
| 2 |
+
"""Observability: structured logging and LangSmith tracing."""
|
app/observability/langsmith.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/observability/langsmith.py
|
| 2 |
+
# LangSmith tracing integration for pipeline observability
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from app.config import get_settings
|
| 6 |
+
from app.observability.logging import get_logger
|
| 7 |
+
|
| 8 |
+
logger = get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def setup_langsmith() -> bool:
    """
    Export LangSmith settings as ``LANGCHAIN_*`` environment variables.

    When an API key is configured, ``LANGCHAIN_TRACING_V2``,
    ``LANGCHAIN_API_KEY`` and ``LANGCHAIN_PROJECT`` are written to
    ``os.environ``; without a key nothing is changed.

    Returns:
        True if the variables were set, False if no API key is configured.
    """
    settings = get_settings()
    api_key = settings.langsmith_api_key

    if api_key:
        project = settings.langsmith_project
        os.environ["LANGCHAIN_TRACING_V2"] = str(settings.langsmith_tracing_v2).lower()
        os.environ["LANGCHAIN_API_KEY"] = api_key
        os.environ["LANGCHAIN_PROJECT"] = project

        logger.info("langsmith_enabled", project=project)
        return True

    logger.info("langsmith_disabled", reason="No API key provided")
    return False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_trace_config(
|
| 37 |
+
run_name: str,
|
| 38 |
+
input_type: str,
|
| 39 |
+
metadata: dict | None = None,
|
| 40 |
+
) -> dict:
|
| 41 |
+
"""
|
| 42 |
+
Build configuration dict for LangGraph invoke calls.
|
| 43 |
+
This attaches metadata and tags to the LangSmith trace.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
run_name: Human-readable name for this trace run.
|
| 47 |
+
input_type: The content type being analyzed (text/image/video).
|
| 48 |
+
metadata: Additional key-value metadata.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Config dict to pass to workflow.invoke().
|
| 52 |
+
"""
|
| 53 |
+
tags = [f"input:{input_type}", "hubble-moderation"]
|
| 54 |
+
config = {
|
| 55 |
+
"run_name": run_name,
|
| 56 |
+
"tags": tags,
|
| 57 |
+
"metadata": metadata or {},
|
| 58 |
+
}
|
| 59 |
+
return {"configurable": config}
|
app/observability/logging.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/observability/logging.py
|
| 2 |
+
# Structured logging with structlog
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import logging
|
| 6 |
+
import structlog
|
| 7 |
+
from app.config import get_settings
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def setup_logging() -> None:
    """
    Configure structlog for the application.

    Output is rendered as JSON when ``settings.is_production`` is true and
    as coloured console lines otherwise. The minimum level comes from
    ``settings.log_level``; a level name is matched case-insensitively and
    an unrecognised value falls back to INFO.
    """
    settings = get_settings()

    # logging.getLevelName() only maps exact upper-case names to ints; for
    # anything else it returns the string "Level <name>", which is not a
    # usable level. Normalise first, then fall back to INFO.
    min_level = settings.log_level
    if isinstance(min_level, str):
        min_level = logging.getLevelName(min_level.strip().upper())
    if not isinstance(min_level, int):
        min_level = logging.INFO

    if settings.is_production:
        renderer = structlog.processors.JSONRenderer()
    else:
        renderer = structlog.dev.ConsoleRenderer(colors=True)

    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.dev.set_exc_info,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ],
        wrapper_class=structlog.make_filtering_bound_logger(min_level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )

    # Silence noisy third-party loggers
    for logger_name in ["uvicorn.access", "httpx", "httpcore"]:
        logging.getLogger(logger_name).setLevel(logging.WARNING)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_logger(name: str | None = None) -> structlog.BoundLogger:
    """
    Return a structlog logger.

    Args:
        name: Logger name; this module's ``__name__`` is used when it is
            omitted or empty.
    """
    if name:
        return structlog.get_logger(name)
    return structlog.get_logger(__name__)
|
app/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/__init__.py
|
| 2 |
+
"""Core moderation pipeline: preprocess → filter → score → route → decide."""
|
app/pipeline/decision_engine.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/decision_engine.py
|
| 2 |
+
# Rule-based decision engine: final moderation verdict
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from app.pipeline.risk_scorer import RiskScore
|
| 6 |
+
from app.pipeline.deep_analyzer import DeepAnalysisResult
|
| 7 |
+
from app.observability.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class Decision:
|
| 14 |
+
"""Final moderation decision."""
|
| 15 |
+
action: str # ALLOWED, WARNING, BLOCKED
|
| 16 |
+
reason: str
|
| 17 |
+
severity: str # low, medium, high, critical
|
| 18 |
+
categories: list[str] = field(default_factory=list)
|
| 19 |
+
should_alert_parent: bool = False
|
| 20 |
+
should_log: bool = True
|
| 21 |
+
escalation_notes: str | None = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DecisionEngine:
    """
    Rule-based engine that turns a risk assessment into a final verdict.

    Verdicts, keyed on ``risk.level``:
    - "LOW": ALLOWED, with ``should_log`` disabled.
    - "MEDIUM": WARNING, or BLOCKED with a parent alert when
      ``risk.repeat_offender`` is set.
    - "HIGH": BLOCKED when deep analysis confirms the threat or is
      missing; WARNING when deep analysis does not confirm it.
    - anything else: WARNING with reason "Unclassified risk level".
    """

    def decide(
        self,
        risk: RiskScore,
        deep_result: DeepAnalysisResult | None = None,
        user_history: dict | None = None,
    ) -> Decision:
        """
        Produce the final moderation decision and log it.

        Args:
            risk: Composite risk score from the scoring engine.
            deep_result: Optional deep analysis result; only consulted for
                HIGH risk.
            user_history: Optional user moderation history. Accepted for
                callers but not read by the current rules.

        Returns:
            Decision with action, reason, and metadata.
        """
        level = risk.level

        if level == "LOW":
            decision = Decision(
                action="ALLOWED",
                reason="Content passed all safety checks",
                severity="low",
                should_log=False,  # safe content is not logged
            )
        elif level == "MEDIUM":
            decision = self._decide_medium(risk)
        elif level == "HIGH":
            decision = self._decide_high(risk, deep_result)
        else:
            decision = Decision(
                action="WARNING",
                reason="Unclassified risk level",
                severity="medium",
            )

        logger.info(
            "decision_made",
            action=decision.action,
            severity=decision.severity,
            alert_parent=decision.should_alert_parent,
            risk_score=risk.score,
            risk_level=risk.level,
        )
        return decision

    def _decide_medium(self, risk: RiskScore) -> Decision:
        """Verdict for MEDIUM risk: warn, or block a repeat offender."""
        if not risk.repeat_offender:
            return Decision(
                action="WARNING",
                reason=f"Content flagged as potentially harmful (risk score: {risk.score})",
                severity="medium",
                should_alert_parent=False,
            )

        return Decision(
            action="BLOCKED",
            reason="Repeat offender with moderately harmful content — escalated to block",
            severity="high",
            should_alert_parent=True,
            escalation_notes="User has repeated violation history. Medium-risk content escalated.",
        )

    def _decide_high(
        self,
        risk: RiskScore,
        deep_result: DeepAnalysisResult | None,
    ) -> Decision:
        """Verdict for HIGH risk, depending on what deep analysis reported."""
        if not deep_result:
            # No deep analysis available: block and ask for manual review.
            return Decision(
                action="BLOCKED",
                reason=f"High-risk content detected (score: {risk.score}). Deep analysis unavailable.",
                severity="high",
                should_alert_parent=True,
                escalation_notes="Deep analysis was not performed. Manual review recommended.",
            )

        if deep_result.is_confirmed:
            severity = deep_result.severity
            notes = None
            if severity == "critical":
                # Only critical findings carry escalation notes.
                notes = (
                    "CRITICAL: Immediate review required. "
                    f"Recommended action: {deep_result.recommended_action}"
                )
            return Decision(
                action="BLOCKED",
                reason=deep_result.reasoning,
                severity=severity,
                categories=deep_result.categories,
                should_alert_parent=True,
                escalation_notes=notes,
            )

        # Deep analysis did not confirm the threat: downgrade to a warning.
        return Decision(
            action="WARNING",
            reason=(
                f"Content initially flagged as high-risk (score: {risk.score}) "
                f"but deep analysis did not confirm threat. "
                f"Reasoning: {deep_result.reasoning}"
            ),
            severity="low",
            should_alert_parent=False,
        )
|
app/pipeline/deep_analyzer.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/deep_analyzer.py
|
| 2 |
+
# Deep analysis layer: CLIP + Gemini reasoning (HIGH risk only)
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from app.models.model_registry import model_registry
|
| 7 |
+
from app.services.gemini_service import gemini_service
|
| 8 |
+
from app.pipeline.fast_filter import FilterResult
|
| 9 |
+
from app.utils.image_utils import image_to_base64
|
| 10 |
+
from app.observability.logging import get_logger
|
| 11 |
+
|
| 12 |
+
logger = get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class DeepAnalysisResult:
    """Result record produced by the deep analysis stage."""

    # Whether deep analysis confirms the threat
    is_confirmed: bool
    # One of: low, medium, high, critical
    severity: str
    # Explanation text taken from the Gemini response
    reasoning: str
    # Harm categories reported by the analysis
    categories: list[str] = field(default_factory=list)
    # One of: allow, warn, block, escalate
    recommended_action: str = "warn"
    # Confidence value reported by the analysis
    confidence: float = 0.0
    # CLIP alignment output, when it was computed
    clip_scores: dict = field(default_factory=dict)
    # The unmodified Gemini response dict
    gemini_raw: dict = field(default_factory=dict)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DeepAnalyzer:
    """
    Deep analysis layer built on CLIP alignment and Gemini reasoning.

    Steps:
    1. CLIP multimodal alignment (images only, when CLIP is available)
    2. Gemini reasoning through ``gemini_service``
    3. Packaging of both signals into a ``DeepAnalysisResult``
    """

    async def analyze_text(
        self,
        text: str,
        filter_result: FilterResult,
    ) -> DeepAnalysisResult:
        """
        Run deep analysis on flagged text.

        Args:
            text: Original text content.
            filter_result: Results from the fast filter stage; its
                categories and scores are forwarded to Gemini as context.

        Returns:
            DeepAnalysisResult built from the Gemini response.
        """
        logger.info("deep_analysis_text_started")

        fast_filter_context = {
            "flagged_categories": filter_result.categories,
            "max_score": filter_result.max_score,
            "max_label": filter_result.max_label,
            "all_scores": filter_result.scores,
        }
        gemini_response = await gemini_service.analyze_text(text, fast_filter_context)
        outcome = self._build_result(gemini_response)

        logger.info(
            "deep_analysis_text_complete",
            confirmed=outcome.is_confirmed,
            severity=outcome.severity,
            action=outcome.recommended_action,
        )
        return outcome

    async def analyze_image(
        self,
        image: Image.Image,
        filter_result: FilterResult,
        context_text: str | None = None,
    ) -> DeepAnalysisResult:
        """
        Run deep analysis on a flagged image.

        Args:
            image: PIL Image.
            filter_result: Results from the fast filter.
            context_text: Optional text accompanying the image; passed to
                the CLIP alignment step.

        Returns:
            DeepAnalysisResult carrying CLIP output and Gemini reasoning.
        """
        logger.info("deep_analysis_image_started")

        # Step 1: CLIP alignment (best effort)
        clip_scores = self._clip_alignment(image, context_text)

        # Step 2: Gemini reasoning over the base64-encoded image
        gemini_context = {
            "flagged_categories": filter_result.categories,
            "max_score": filter_result.max_score,
            "clip_alignment": clip_scores.get("most_aligned", "unknown"),
        }
        encoded_image = image_to_base64(image)
        gemini_response = await gemini_service.analyze_image(encoded_image, gemini_context)

        outcome = self._build_result(gemini_response, clip_scores)

        logger.info(
            "deep_analysis_image_complete",
            confirmed=outcome.is_confirmed,
            severity=outcome.severity,
        )
        return outcome

    def _clip_alignment(self, image: Image.Image, context_text: str | None) -> dict:
        """Return CLIP alignment output, or whatever was gathered before a failure."""
        if not model_registry.clip_available:
            return {}

        alignment = {}
        try:
            alignment = model_registry.clip_model.align_content(image, context_text)
            logger.info("clip_alignment_complete", most_aligned=alignment.get("most_aligned"))
        except Exception as exc:
            # CLIP is optional: log the failure and continue without it.
            logger.warning("clip_alignment_failed", error=str(exc))
        return alignment

    def _build_result(
        self,
        gemini_result: dict,
        clip_scores: dict | None = None,
    ) -> DeepAnalysisResult:
        """
        Translate a Gemini response dict into a DeepAnalysisResult.

        A response containing an ``error`` key is treated as a confirmed,
        medium-severity finding with low confidence; otherwise fields are
        read from the response with defaults for anything missing.
        """
        clip_payload = clip_scores or {}

        if "error" in gemini_result:
            # Gemini failed: assume harmful because it could not be verified.
            return DeepAnalysisResult(
                is_confirmed=True,
                severity="medium",
                reasoning=f"Deep analysis unavailable: {gemini_result['error']}. Defaulting to caution.",
                recommended_action="warn",
                confidence=0.3,
                clip_scores=clip_payload,
                gemini_raw=gemini_result,
            )

        read = gemini_result.get
        return DeepAnalysisResult(
            is_confirmed=read("is_confirmed", False),
            severity=read("severity", "medium"),
            reasoning=read("reasoning", "No reasoning provided"),
            categories=read("categories", []),
            recommended_action=read("recommended_action", "warn"),
            confidence=read("confidence", 0.5),
            clip_scores=clip_payload,
            gemini_raw=gemini_result,
        )
|
app/pipeline/fast_filter.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/fast_filter.py
|
| 2 |
+
# First-pass AI classification using ONNX-optimized models
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from app.models.model_registry import model_registry
|
| 7 |
+
from app.pipeline.preprocessor import ProcessedText, ProcessedImage
|
| 8 |
+
from app.observability.logging import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class FilterResult:
    """Result record produced by the fast filter stage."""

    # "text" or "image"
    input_type: str
    # Whether the content was flagged
    is_flagged: bool
    # Per-label scores
    scores: dict[str, float] = field(default_factory=dict)
    # Highest score and the label it belongs to
    max_score: float = 0.0
    max_label: str = ""
    # Labels that were flagged
    categories: list[str] = field(default_factory=list)
    # Confidence attached to the result
    confidence: float = 0.0
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class FastFilter:
    """
    First-pass classifier built on the models held by ``model_registry``.

    - ``filter_text`` scores text with the registry's text model
    - ``filter_image`` scores images with the registry's image model
    """

    # Scores strictly above these thresholds are flagged
    TEXT_FLAG_THRESHOLD = 0.4
    IMAGE_FLAG_THRESHOLD = 0.5

    def filter_text(self, processed: ProcessedText) -> FilterResult:
        """
        Classify preprocessed text.

        Args:
            processed: Preprocessed text input; ``processed.cleaned`` is
                what the model sees.

        Returns:
            FilterResult with per-label scores; flagged when at least one
            label scores above ``TEXT_FLAG_THRESHOLD``.
        """
        prediction = model_registry.text_model.predict(processed.cleaned)
        label_scores = prediction.get("label_scores", {})

        flagged_categories = [
            label
            for label, score in label_scores.items()
            if score > self.TEXT_FLAG_THRESHOLD
        ]
        is_flagged = bool(flagged_categories)

        top_score = prediction["max_score"]
        result = FilterResult(
            input_type="text",
            is_flagged=is_flagged,
            scores=label_scores,
            max_score=top_score,
            max_label=prediction["max_label"],
            categories=flagged_categories,
            confidence=top_score,
        )

        logger.info(
            "fast_filter_text",
            flagged=is_flagged,
            max_label=result.max_label,
            max_score=round(result.max_score, 3),
            categories=flagged_categories,
        )
        return result

    def filter_image(self, processed: ProcessedImage) -> FilterResult:
        """
        Classify a preprocessed image.

        Args:
            processed: Preprocessed image input; ``processed.image`` is
                what the model sees.

        Returns:
            FilterResult with per-label scores. ``is_flagged`` mirrors the
            model's ``is_harmful`` output, while ``categories`` lists the
            non-safe labels scoring above ``IMAGE_FLAG_THRESHOLD``.
        """
        prediction = model_registry.image_model.predict(processed.image)

        # Label -> score mapping from the model's parallel lists
        scores = dict(zip(prediction["labels"], prediction["scores"]))

        # Labels that never count as harmful, compared after normalising
        # case, hyphens and spaces.
        safe_labels = {"safe", "non-violence", "non_violence", "normal", "neutral"}
        flagged_categories = [
            label
            for label, score in scores.items()
            if score > self.IMAGE_FLAG_THRESHOLD
            and label.lower().replace("-", "_").replace(" ", "_") not in safe_labels
        ]

        is_flagged = prediction["is_harmful"]

        top_score = prediction["max_score"]
        result = FilterResult(
            input_type="image",
            is_flagged=is_flagged,
            scores=scores,
            max_score=top_score,
            max_label=prediction["max_label"],
            categories=flagged_categories,
            confidence=top_score,
        )

        logger.info(
            "fast_filter_image",
            flagged=is_flagged,
            max_label=result.max_label,
            max_score=round(result.max_score, 3),
        )
        return result
|
app/pipeline/preprocessor.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/preprocessor.py
|
| 2 |
+
# Input preprocessing: normalization, frame extraction, cleaning
|
| 3 |
+
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from app.config import get_settings
|
| 8 |
+
from app.observability.logging import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class ProcessedText:
    """Text after preprocessing, alongside the untouched original."""

    # Text exactly as received
    original: str
    # Text after cleaning/normalization
    cleaned: str
    # Word and character counts
    word_count: int
    char_count: int
    # Language code; a fixed default, no detection is performed here
    language: str = "en"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class ProcessedImage:
    """An image after preprocessing, with its recorded dimensions."""

    # The loaded PIL image
    image: Image.Image
    # Pixel dimensions
    width: int
    height: int
    # Format label; defaults to "RGB"
    format: str = "RGB"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class ProcessedVideo:
    """A video after preprocessing: its extracted frames plus metadata."""

    # Extracted frames, each wrapped as a ProcessedImage
    frames: list[ProcessedImage] = field(default_factory=list)
    # Number of extracted frames
    frame_count: int = 0
    # Video duration in seconds
    duration_seconds: float = 0.0
    # Raw metadata dict for the video
    metadata: dict = field(default_factory=dict)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Preprocessor:
    """
    Prepares raw input for the moderation pipeline.

    - Text: zero-width character removal, NFKC normalization, whitespace
      collapsing
    - Image: loading from bytes and recording dimensions
    - Video: frame extraction, each frame wrapped as a ProcessedImage
    """

    def __init__(self):
        self.settings = get_settings()

    def process_text(self, text: str) -> ProcessedText:
        """
        Clean and normalize input text.

        Zero-width characters are removed, the text is NFKC-normalized,
        and runs of whitespace are collapsed to single spaces with the
        ends trimmed.
        """
        import unicodedata

        # Zero-width characters are often used for obfuscation
        visible_only = re.sub(r"[\u200b\u200c\u200d\ufeff]", "", text)
        normalized = unicodedata.normalize("NFKC", visible_only)
        cleaned = re.sub(r"\s+", " ", normalized).strip()

        result = ProcessedText(
            original=text,
            cleaned=cleaned,
            word_count=len(cleaned.split()),
            char_count=len(cleaned),
        )

        logger.debug(
            "text_preprocessed",
            word_count=result.word_count,
            char_count=result.char_count,
        )
        return result

    def process_image(self, image_bytes: bytes) -> ProcessedImage:
        """
        Load an image from bytes and record its dimensions.

        Loading is delegated to ``load_image_from_bytes``.
        """
        from app.utils.image_utils import load_image_from_bytes

        result = self._wrap_image(load_image_from_bytes(image_bytes))

        logger.debug("image_preprocessed", width=result.width, height=result.height)
        return result

    def process_video(self, video_bytes: bytes) -> ProcessedVideo:
        """
        Extract frames from a video and wrap each as a ProcessedImage.

        Frame sampling is delegated to ``extract_frames`` using the
        configured ``video_max_frames`` and ``video_fps_sample`` settings.
        """
        from app.utils.video_utils import extract_frames, get_video_metadata

        metadata = get_video_metadata(video_bytes)
        raw_frames = extract_frames(
            video_bytes,
            max_frames=self.settings.video_max_frames,
            fps_sample=self.settings.video_fps_sample,
        )

        wrapped_frames = [self._wrap_image(frame) for frame in raw_frames]

        result = ProcessedVideo(
            frames=wrapped_frames,
            frame_count=len(wrapped_frames),
            duration_seconds=metadata.get("duration_seconds", 0.0),
            metadata=metadata,
        )

        logger.debug(
            "video_preprocessed",
            frames_extracted=result.frame_count,
            duration=result.duration_seconds,
        )
        return result

    @staticmethod
    def _wrap_image(image: Image.Image) -> ProcessedImage:
        """Wrap an image in a ProcessedImage together with its size."""
        width, height = image.size
        return ProcessedImage(image=image, width=width, height=height)
|
app/pipeline/risk_scorer.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/risk_scorer.py
|
| 2 |
+
# Composite risk scoring engine
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from app.config import get_settings
|
| 6 |
+
from app.pipeline.fast_filter import FilterResult
|
| 7 |
+
from app.observability.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class RiskScore:
    """Outcome of composite risk scoring for one piece of content."""

    # Final score on the 0-100 scale.
    score: float
    # Risk bucket: "LOW", "MEDIUM" or "HIGH".
    level: str
    # Per-factor breakdown of how the score was assembled.
    components: dict
    # Set when the scorer classified the user as a repeat offender.
    repeat_offender: bool = False
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RiskScorer:
    """
    Computes a composite risk score (0-100) from multiple signals.

    Scoring formula:
    - Base score from model confidence (weighted by category severity)
    - Repeat offender boost (user history)
    - Multi-category penalty (multiple harmful categories = higher risk)

    Level boundaries come from settings (``risk_low_max`` and
    ``risk_medium_max``). The intended defaults, per the original notes:
    - 0-30: LOW → Allow
    - 31-65: MEDIUM → Warning
    - 66-100: HIGH → Deep Analysis
    """

    # Category severity weights (how dangerous each type is)
    CATEGORY_WEIGHTS = {
        # Text categories (from RoBERTa toxic-bert)
        "toxic": 0.6,
        "severe_toxic": 1.0,
        "obscene": 0.5,
        "threat": 1.0,
        "insult": 0.5,
        "identity_hate": 0.9,
        # Image categories
        "violence": 0.9,
        "nsfw": 0.8,
        "self_harm": 1.0,
        "hate_symbol": 0.9,
        # Generic fallback
        "harassment": 0.7,
        "bullying": 0.7,
    }

    # Repeat offender thresholds
    REPEAT_OFFENDER_VIOLATIONS = 3
    REPEAT_OFFENDER_BOOST = 15  # points added

    def __init__(self):
        self.settings = get_settings()

    def score(
        self,
        filter_result: FilterResult,
        user_history: dict | None = None,
    ) -> RiskScore:
        """
        Compute composite risk score.

        Args:
            filter_result: Output from fast filter stage.
            user_history: Optional user moderation history.

        Returns:
            RiskScore with level classification.
        """
        # 1. Base score from model confidence
        base_score = self._compute_base_score(filter_result)

        # 2. Multi-category penalty
        multi_cat_penalty = self._multi_category_penalty(filter_result)

        # 3. Repeat offender boost
        repeat_boost, is_repeat = self._repeat_offender_boost(user_history)

        # 4. Combine and clamp into the 0-100 range
        raw_score = base_score + multi_cat_penalty + repeat_boost
        final_score = min(100.0, max(0.0, raw_score))

        # 5. Classify level
        level = self._classify_level(final_score)

        result = RiskScore(
            score=round(final_score, 1),
            level=level,
            components={
                "base_score": round(base_score, 1),
                "multi_category_penalty": round(multi_cat_penalty, 1),
                "repeat_offender_boost": round(repeat_boost, 1),
            },
            repeat_offender=is_repeat,
        )

        logger.info(
            "risk_scored",
            score=result.score,
            level=result.level,
            components=result.components,
            repeat_offender=is_repeat,
        )
        return result

    def _compute_base_score(self, result: FilterResult) -> float:
        """
        Compute base score from model predictions.

        Unflagged content is scaled into 0-20 from its highest prediction.
        Flagged content gets the severity-weighted average (0-100) of the
        scores of the *flagged* categories only. Averaging over every label
        the model emits would let the near-zero labels dilute a single
        strongly flagged one down into the LOW band.
        """
        if not result.is_flagged:
            # Even unflagged content gets a small score based on max prediction
            return result.max_score * 20  # Scale 0-1 → 0-20

        # NOTE(review): assumes ``result.categories`` holds the names of the
        # flagged categories and that they match keys of ``result.scores``
        # case-insensitively — confirm against FastFilter.
        flagged = {str(name).lower() for name in (result.categories or ())}
        scored = [
            (category, value)
            for category, value in result.scores.items()
            if category.lower() in flagged
        ]
        if not scored:
            # Names did not line up (or none were listed): fall back to
            # averaging over every reported score.
            scored = list(result.scores.items())

        weighted_sum = 0.0
        weight_total = 0.0
        for category, value in scored:
            # Unknown categories get a neutral 0.5 severity weight.
            weight = self.CATEGORY_WEIGHTS.get(category.lower(), 0.5)
            weighted_sum += value * weight * 100
            weight_total += weight

        if weight_total > 0:
            return weighted_sum / weight_total
        # Flagged but no per-category scores available.
        return result.max_score * 60

    def _multi_category_penalty(self, result: FilterResult) -> float:
        """Add penalty when multiple harmful categories are detected."""
        num_categories = len(result.categories)
        if num_categories <= 1:
            return 0.0
        # Each additional category adds 5 points
        return (num_categories - 1) * 5.0

    def _repeat_offender_boost(self, user_history: dict | None) -> tuple[float, bool]:
        """
        Boost score for users with violation history.

        Returns:
            ``(boost, is_repeat_offender)``. Users at or above
            ``REPEAT_OFFENDER_VIOLATIONS`` get the full boost; users with
            fewer violations get 3 points per violation.
        """
        if not user_history:
            return 0.0, False

        # ``or 0`` guards against a stored null/None violation count.
        total_violations = user_history.get("total_violations", 0) or 0
        is_repeat = total_violations >= self.REPEAT_OFFENDER_VIOLATIONS

        if is_repeat:
            return self.REPEAT_OFFENDER_BOOST, True
        elif total_violations > 0:
            # Smaller boost for users with some history
            return total_violations * 3.0, False
        return 0.0, False

    def _classify_level(self, score: float) -> str:
        """Map numeric score to risk level using the configured boundaries."""
        if score <= self.settings.risk_low_max:
            return "LOW"
        elif score <= self.settings.risk_medium_max:
            return "MEDIUM"
        else:
            return "HIGH"
|
app/pipeline/workflow.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/pipeline/workflow.py
|
| 2 |
+
# LangGraph state machine — orchestrates the full moderation pipeline
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from typing import Any, TypedDict, Literal
|
| 6 |
+
from dataclasses import asdict
|
| 7 |
+
|
| 8 |
+
from langgraph.graph import StateGraph, END
|
| 9 |
+
|
| 10 |
+
from app.pipeline.preprocessor import (
|
| 11 |
+
Preprocessor,
|
| 12 |
+
ProcessedText,
|
| 13 |
+
ProcessedImage,
|
| 14 |
+
ProcessedVideo,
|
| 15 |
+
)
|
| 16 |
+
from app.pipeline.fast_filter import FastFilter, FilterResult
|
| 17 |
+
from app.pipeline.risk_scorer import RiskScorer, RiskScore
|
| 18 |
+
from app.pipeline.deep_analyzer import DeepAnalyzer, DeepAnalysisResult
|
| 19 |
+
from app.pipeline.decision_engine import DecisionEngine, Decision
|
| 20 |
+
from app.services.mongo_service import mongo_service
|
| 21 |
+
from app.services.redis_service import redis_service
|
| 22 |
+
from app.observability.logging import get_logger
|
| 23 |
+
|
| 24 |
+
logger = get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ──────────────────────────────────────────────
|
| 28 |
+
# Pipeline State Schema
|
| 29 |
+
# ──────────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
class PipelineState(TypedDict, total=False):
    """Shared state passed between LangGraph pipeline nodes (every key optional)."""

    # --- Request ---
    input_type: str  # one of "text", "image", "video"
    raw_content: Any  # text arrives as str, image/video as bytes
    user_id: str | None

    # --- Preprocessing output (the key filled depends on input_type) ---
    processed_text: ProcessedText | None
    processed_image: ProcessedImage | None
    processed_video: ProcessedVideo | None

    # --- Stage results ---
    filter_result: FilterResult | None
    filter_results: list[FilterResult]  # per-frame results, video only
    risk_score: RiskScore | None
    deep_result: DeepAnalysisResult | None
    decision: Decision | None

    # --- User context ---
    user_history: dict | None

    # --- Diagnostics ---
    error: str | None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ──────────────────────────────────────────────
|
| 58 |
+
# Pipeline Node Functions
|
| 59 |
+
# ──────────────────────────────────────────────
|
| 60 |
+
|
| 61 |
+
preprocessor = Preprocessor()
|
| 62 |
+
fast_filter = FastFilter()
|
| 63 |
+
risk_scorer = RiskScorer()
|
| 64 |
+
deep_analyzer = DeepAnalyzer()
|
| 65 |
+
decision_engine = DecisionEngine()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
async def preprocess_node(state: PipelineState) -> dict:
    """
    Node 1: run the preprocessor that matches the request's input type.

    Returns a partial state update: the processed payload under the key for
    its input type, or an ``error`` entry when the type is unknown or the
    preprocessor raises.
    """
    kind = state["input_type"]
    payload = state["raw_content"]

    # (accepted input type, state key to fill, Preprocessor method to call)
    routes = (
        ("text", "processed_text", "process_text"),
        ("image", "processed_image", "process_image"),
        ("video", "processed_video", "process_video"),
    )

    try:
        for accepted, state_key, method_name in routes:
            if kind == accepted:
                handler = getattr(preprocessor, method_name)
                return {state_key: handler(payload)}
        return {"error": f"Unknown input type: {kind}"}
    except Exception as exc:
        logger.error("preprocess_failed", error=str(exc))
        return {"error": f"Preprocessing failed: {str(exc)}"}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
async def fetch_user_history_node(state: PipelineState) -> dict:
    """
    Node 1b: fetch the user's moderation history.

    Looks in the Redis cache first and falls back to MongoDB, caching the
    MongoDB result when one is found. History is optional context for
    scoring, so storage failures are logged and treated as "no history"
    rather than aborting the pipeline (matching how the other nodes
    handle their own failures).

    Returns:
        ``{"user_history": <dict or None>}``.
    """
    user_id = state.get("user_id")
    if not user_id:
        return {"user_history": None}

    # Try Redis cache first; a cache outage must not block the lookup.
    try:
        cached = await redis_service.get_user_history(user_id)
    except Exception as e:
        logger.warning("user_history_cache_read_failed", error=str(e))
        cached = None
    if cached:
        return {"user_history": cached}

    # Fall back to MongoDB
    try:
        history = await mongo_service.get_user_history(user_id)
    except Exception as e:
        logger.error("user_history_fetch_failed", error=str(e))
        return {"user_history": None}

    if history:
        # Best-effort cache fill: keep the fetched history even if this fails.
        try:
            await redis_service.cache_user_history(user_id, history)
        except Exception as e:
            logger.warning("user_history_cache_write_failed", error=str(e))
    return {"user_history": history}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
async def fast_filter_node(state: PipelineState) -> dict:
    """
    Node 2: Run fast AI filter.

    Text and images are filtered directly. For video every extracted frame
    is filtered and the frame with the highest ``max_score`` becomes the
    representative ``filter_result`` (all per-frame results are kept in
    ``filter_results``). A video with no frames yields an unflagged result.

    If an earlier node already recorded an error, that error is passed
    through unchanged so the root cause is not replaced by the generic
    "no processed content" message.
    """
    upstream_error = state.get("error")
    if upstream_error:
        return {"error": upstream_error}

    input_type = state["input_type"]

    try:
        if input_type == "text" and state.get("processed_text"):
            result = fast_filter.filter_text(state["processed_text"])
            return {"filter_result": result}

        elif input_type == "image" and state.get("processed_image"):
            result = fast_filter.filter_image(state["processed_image"])
            return {"filter_result": result}

        elif input_type == "video" and state.get("processed_video"):
            # Analyze each frame, take the worst result
            video = state["processed_video"]
            frame_results = [fast_filter.filter_image(frame) for frame in video.frames]

            # Use the highest-risk frame as the representative result
            if frame_results:
                worst = max(frame_results, key=lambda r: r.max_score)
                return {
                    "filter_result": worst,
                    "filter_results": frame_results,
                }
            else:
                return {
                    "filter_result": FilterResult(
                        input_type="video",
                        is_flagged=False,
                        max_score=0.0,
                    )
                }

        return {"error": "No processed content available for filtering"}

    except Exception as e:
        logger.error("fast_filter_failed", error=str(e))
        return {"error": f"Fast filter failed: {str(e)}"}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
async def risk_score_node(state: PipelineState) -> dict:
    """
    Node 3: Compute composite risk score.

    Scores ``filter_result`` together with the optional ``user_history``.
    When there is no filter result, an error recorded by an earlier node is
    kept as-is (so the root cause survives); otherwise a generic error is
    reported.
    """
    filter_result = state.get("filter_result")
    if not filter_result:
        return {"error": state.get("error") or "No filter result to score"}

    try:
        user_history = state.get("user_history")
        score = risk_scorer.score(filter_result, user_history)
        return {"risk_score": score}
    except Exception as e:
        logger.error("risk_score_failed", error=str(e))
        return {"error": f"Risk scoring failed: {str(e)}"}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def route_by_risk(state: PipelineState) -> str:
    """
    Pick the node that follows risk scoring.

    Only a HIGH risk level is sent to "deep_analysis"; LOW, MEDIUM and a
    missing risk score all go straight to "decide".
    """
    risk = state.get("risk_score")
    needs_deep_analysis = bool(risk) and risk.level == "HIGH"
    return "deep_analysis" if needs_deep_analysis else "decide"
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
async def deep_analysis_node(state: PipelineState) -> dict:
    """
    Node 4 (conditional): run the deep analyzer on the flagged content.

    Text is analyzed from its cleaned form; an available processed image is
    analyzed directly; for video the frame whose fast-filter result has the
    highest ``max_score`` is analyzed (first frame when no per-frame results
    exist). Any failure, or having nothing to analyze, yields
    ``{"deep_result": None}``.
    """
    kind = state["input_type"]
    prefilter = state.get("filter_result")

    try:
        if kind == "text" and state.get("processed_text"):
            text_verdict = await deep_analyzer.analyze_text(
                state["processed_text"].cleaned,
                prefilter,
            )
            return {"deep_result": text_verdict}

        if kind in ("image", "video") and state.get("processed_image"):
            image_verdict = await deep_analyzer.analyze_image(
                state["processed_image"].image,
                prefilter,
            )
            return {"deep_result": image_verdict}

        if kind == "video" and state.get("processed_video"):
            clip = state["processed_video"]
            if clip.frames:
                # Default to the first frame; upgrade to the worst-scoring
                # frame when per-frame filter results are available.
                target = clip.frames[0]
                per_frame = state.get("filter_results", [])
                if per_frame:
                    worst_idx, _ = max(
                        enumerate(per_frame),
                        key=lambda pair: pair[1].max_score,
                    )
                    if worst_idx < len(clip.frames):
                        target = clip.frames[worst_idx]

                frame_verdict = await deep_analyzer.analyze_image(
                    target.image,
                    prefilter,
                )
                return {"deep_result": frame_verdict}

        return {"deep_result": None}

    except Exception as exc:
        logger.error("deep_analysis_failed", error=str(exc))
        return {"deep_result": None}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
async def decision_node(state: PipelineState) -> dict:
    """
    Node 5: produce the final moderation decision.

    Delegates to the decision engine with the risk score, the optional deep
    analysis result and the optional user history. If there is no risk
    score, or the engine raises, a medium-severity WARNING decision is
    returned instead.
    """
    risk = state.get("risk_score")
    if not risk:
        # Emergency fallback: an upstream stage never produced a score.
        fallback = Decision(
            action="WARNING",
            reason="Pipeline error: no risk score available",
            severity="medium",
        )
        return {"decision": fallback}

    try:
        verdict = decision_engine.decide(
            risk,
            state.get("deep_result"),
            state.get("user_history"),
        )
        return {"decision": verdict}
    except Exception as exc:
        logger.error("decision_failed", error=str(exc))
        engine_failure = Decision(
            action="WARNING",
            reason=f"Decision engine error: {str(exc)}",
            severity="medium",
        )
        return {"decision": engine_failure}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ──────────────────────────────────────────────
|
| 262 |
+
# Build the LangGraph Workflow
|
| 263 |
+
# ──────────────────────────────────────────────
|
| 264 |
+
|
| 265 |
+
def build_moderation_workflow():
    """
    Construct and compile the LangGraph moderation pipeline.

    Flow:
        preprocess → fetch_history → fast_filter → risk_score
            ├─ LOW/MEDIUM → decide
            └─ HIGH → deep_analysis → decide

    The ``fetch_history`` node is chained between preprocessing and the fast
    filter so that ``user_history`` is populated before risk scoring and the
    final decision read it. It runs sequentially rather than in parallel
    with ``preprocess`` to keep the graph a simple chain.

    Returns:
        Compiled LangGraph workflow.
    """
    graph = StateGraph(PipelineState)

    # Add nodes
    graph.add_node("preprocess", preprocess_node)
    graph.add_node("fetch_history", fetch_user_history_node)
    graph.add_node("fast_filter", fast_filter_node)
    graph.add_node("risk_score", risk_score_node)
    graph.add_node("deep_analysis", deep_analysis_node)
    graph.add_node("decide", decision_node)

    # Define edges
    graph.set_entry_point("preprocess")

    # After preprocess, load the user's history, then run the fast filter.
    # Without these edges the fetch_history node is never reached.
    graph.add_edge("preprocess", "fetch_history")
    graph.add_edge("fetch_history", "fast_filter")

    # After fast filter, compute risk score
    graph.add_edge("fast_filter", "risk_score")

    # Conditional routing based on risk level
    graph.add_conditional_edges(
        "risk_score",
        route_by_risk,
        {
            "deep_analysis": "deep_analysis",
            "decide": "decide",
        },
    )

    # Deep analysis flows to decision
    graph.add_edge("deep_analysis", "decide")

    # Decision is the terminal node
    graph.add_edge("decide", END)

    # Compile
    workflow = graph.compile()
    logger.info("moderation_workflow_compiled")
    return workflow
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# Global compiled workflow (initialized at startup)
|
| 319 |
+
moderation_workflow = None
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def get_workflow():
    """Return the module-wide compiled workflow, building it on first use."""
    global moderation_workflow
    if moderation_workflow is not None:
        return moderation_workflow

    moderation_workflow = build_moderation_workflow()
    return moderation_workflow
|
app/services/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/services/__init__.py
|
| 2 |
+
"""External service integrations: MongoDB, Redis, Gemini."""
|
app/services/gemini_service.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/services/gemini_service.py
|
| 2 |
+
# Gemini API integration via LangChain for deep analysis reasoning
|
| 3 |
+
|
| 4 |
+
from app.config import get_settings
|
| 5 |
+
from app.observability.logging import get_logger
|
| 6 |
+
|
| 7 |
+
logger = get_logger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeminiService:
    """
    Google Gemini API client powered by LangChain.

    Used in the deep analysis path for contextual reasoning about flagged
    content. Provides:
    - Prompts that request a fixed JSON verdict, parsed into a dict
    - Rotation through the configured API keys on rate-limit errors
    - Error dicts (``{"error": ..., "is_confirmed": False}``) instead of
      exceptions when the call or the parsing fails
    """

    def __init__(self):
        self.settings = get_settings()
        self.llm = None
        self._current_key_idx = 0
        self._initialized = False

    def _build_llm(self, api_key: str):
        """Create a LangChain Gemini chat client bound to ``api_key``."""
        # Imported lazily so the module loads without the optional dependency.
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(
            model=self.settings.gemini_model,
            google_api_key=api_key,
            temperature=0.1,
            max_output_tokens=1024,
            convert_system_message_to_human=True,
        )

    def initialize(self) -> None:
        """Initialize the LangChain Gemini client with the current API key.

        Logs and leaves the service uninitialized when no keys are
        configured or the client cannot be constructed.
        """
        keys = self.settings.gemini_keys_list
        if not keys:
            logger.warning("gemini_no_keys", reason="No API keys configured")
            return

        try:
            self.llm = self._build_llm(keys[self._current_key_idx])
            self._initialized = True
            logger.info("gemini_initialized", model=self.settings.gemini_model)
        except Exception as e:
            logger.error("gemini_init_failed", error=str(e))

    def _rotate_key(self) -> bool:
        """Switch to the next API key (wrapping around).

        Returns:
            True when a client for the next key was built; False when no
            keys are configured or the client could not be constructed.
        """
        keys = self.settings.gemini_keys_list
        if not keys:
            return False

        self._current_key_idx = (self._current_key_idx + 1) % len(keys)
        try:
            self.llm = self._build_llm(keys[self._current_key_idx])
        except Exception as e:
            # Previously swallowed silently; log so rotation failures are visible.
            logger.error("gemini_key_rotation_failed", error=str(e))
            return False

        logger.info("gemini_key_rotated", key_index=self._current_key_idx)
        return True

    async def analyze_text(self, text: str, context: dict | None = None) -> dict:
        """
        Perform deep contextual analysis of flagged text.

        Args:
            text: The flagged text content.
            context: Additional context (fast filter results, categories, etc.).

        Returns:
            Parsed analysis dict, or an error dict with ``is_confirmed=False``
            when the service is uninitialized or the call fails.
        """
        if not self._initialized:
            return {"error": "Gemini not initialized", "is_confirmed": False}

        system_prompt = """You are an expert content moderator specializing in cyberbullying detection.
You are analyzing content that has been flagged as potentially harmful by automated filters.

Your task: Provide a detailed, contextual analysis. Consider:
- Intent and context (sarcasm, jokes, genuine threats)
- Severity level (mild rudeness vs. serious threats)
- Whether this constitutes cyberbullying
- Impact on minors (under 18)

Respond in this exact JSON format:
{
"is_confirmed": true/false,
"severity": "low" | "medium" | "high" | "critical",
"categories": ["category1", "category2"],
"reasoning": "Detailed explanation of your analysis",
"recommended_action": "allow" | "warn" | "block" | "escalate",
"confidence": 0.0-1.0
}"""

        context_str = ""
        if context:
            context_str = f"\n\nPre-filter context: {context}"

        human_message = f"""Analyze this flagged content:{context_str}

Content: "{text}"

Provide your JSON analysis:"""

        return await self._invoke(system_prompt, human_message)

    async def analyze_image(self, image_base64: str, context: dict | None = None) -> dict:
        """
        Perform deep analysis of a flagged image.

        Args:
            image_base64: Base64-encoded image (sent as a JPEG data URL).
            context: Additional context from fast filter.

        Returns:
            Parsed analysis dict, or an error dict with ``is_confirmed=False``
            when the service is uninitialized or the call fails.
        """
        if not self._initialized:
            return {"error": "Gemini not initialized", "is_confirmed": False}

        system_prompt = """You are an expert content moderator analyzing images for content harmful to minors.

Analyze the image for:
- Violence, gore, weapons
- Nudity, sexual content
- Drug/alcohol imagery
- Self-harm or suicide content
- Hate symbols, extremist content
- Cyberbullying imagery (humiliating photos, etc.)

Respond in this exact JSON format:
{
"is_confirmed": true/false,
"severity": "low" | "medium" | "high" | "critical",
"categories": ["category1", "category2"],
"reasoning": "Description of what was found",
"recommended_action": "allow" | "warn" | "block" | "escalate",
"confidence": 0.0-1.0
}"""

        context_str = f"\nPre-filter flags: {context}" if context else ""

        human_content = [
            {"type": "text", "text": f"Analyze this flagged image:{context_str}\n\nProvide your JSON analysis:"},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
            },
        ]

        return await self._invoke_multimodal(system_prompt, human_content)

    async def _invoke(self, system_prompt: str, human_message: str) -> dict:
        """Invoke Gemini with a text prompt, rotating keys on rate limits."""
        return await self._invoke_with_rotation(system_prompt, human_message)

    async def _invoke_multimodal(self, system_prompt: str, human_content: list) -> dict:
        """Invoke Gemini with multimodal content, rotating keys on rate limits."""
        return await self._invoke_with_rotation(system_prompt, human_content)

    async def _invoke_with_rotation(self, system_prompt: str, human_content) -> dict:
        """
        Send one system + human message pair and parse the reply.

        Makes at most one attempt per configured key. A rate-limit style
        failure ("429" or "quota" in the error text) rotates to the next key
        and retries; any other failure is returned immediately as an error
        dict.

        Args:
            system_prompt: Instruction text for the system message.
            human_content: Human message content — a string, or a list of
                content parts for multimodal requests.
        """
        from langchain_core.messages import SystemMessage, HumanMessage

        keys = self.settings.gemini_keys_list
        attempts = max(len(keys), 1)

        for attempt in range(attempts):
            try:
                messages = [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=human_content),
                ]
                response = await self.llm.ainvoke(messages)
                return self._parse_response(response.content)

            except Exception as e:
                error_str = str(e)
                if "429" in error_str or "quota" in error_str.lower():
                    logger.warning("gemini_rate_limited", attempt=attempt)
                    if not self._rotate_key():
                        break
                else:
                    logger.error("gemini_invoke_failed", error=error_str)
                    return {"error": error_str, "is_confirmed": False}

        return {"error": "All Gemini API keys exhausted", "is_confirmed": False}

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content into plain text.

        Message content may be a plain string or a list of parts (strings or
        dicts carrying a ``"text"`` entry); non-text parts are ignored.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces)
        return "" if content is None else str(content)

    def _parse_response(self, text: str | list) -> dict:
        """
        Parse the JSON object out of a Gemini reply.

        Args:
            text: Reply content; a string, or a list of content parts.

        Returns:
            The decoded JSON object, or an error dict (with a truncated copy
            of the raw reply) when no valid JSON object can be extracted.
        """
        import json
        import re

        text = self._content_to_text(text)

        # Grab everything from the first "{" to the last "}" so surrounding
        # prose or markdown code fences are ignored.
        json_match = re.search(r"\{[\s\S]*\}", text)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed

        logger.warning("gemini_parse_failed", raw_text=text[:200])
        return {
            "error": "Failed to parse Gemini response",
            "is_confirmed": False,
            "raw_response": text[:500],
        }

    @property
    def is_initialized(self) -> bool:
        """Whether ``initialize`` has successfully built a client."""
        return self._initialized
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# Global singleton
|
| 247 |
+
gemini_service = GeminiService()
|