sajith-0701 commited on
Commit
71c1ad2
·
0 Parent(s):

initial deployment for HF Spaces

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +54 -0
  2. .env +47 -0
  3. .env.example +48 -0
  4. .gitattributes +36 -0
  5. Dockerfile +34 -0
  6. README.md +11 -0
  7. app/__init__.py +2 -0
  8. app/api/__init__.py +2 -0
  9. app/api/router.py +27 -0
  10. app/api/schemas/__init__.py +2 -0
  11. app/api/schemas/requests.py +43 -0
  12. app/api/schemas/responses.py +136 -0
  13. app/api/v1/__init__.py +2 -0
  14. app/api/v1/alerts.py +151 -0
  15. app/api/v1/analyze.py +330 -0
  16. app/api/v1/auth.py +217 -0
  17. app/api/v1/health.py +51 -0
  18. app/api/v1/history.py +68 -0
  19. app/api/v1/scan.py +166 -0
  20. app/api/v1/users.py +61 -0
  21. app/config.py +99 -0
  22. app/core/__init__.py +1 -0
  23. app/core/dependencies.py +48 -0
  24. app/core/security.py +79 -0
  25. app/db/__init__.py +1 -0
  26. app/db/connection.py +36 -0
  27. app/db/models/__init__.py +10 -0
  28. app/db/models/alert.py +75 -0
  29. app/db/models/scan_result.py +43 -0
  30. app/db/models/user.py +74 -0
  31. app/dependencies.py +29 -0
  32. app/main.py +154 -0
  33. app/models/__init__.py +2 -0
  34. app/models/clip_model.py +145 -0
  35. app/models/image_model.py +180 -0
  36. app/models/model_registry.py +99 -0
  37. app/models/onnx_utils.py +120 -0
  38. app/models/text_model.py +178 -0
  39. app/observability/__init__.py +2 -0
  40. app/observability/langsmith.py +59 -0
  41. app/observability/logging.py +41 -0
  42. app/pipeline/__init__.py +2 -0
  43. app/pipeline/decision_engine.py +142 -0
  44. app/pipeline/deep_analyzer.py +153 -0
  45. app/pipeline/fast_filter.py +124 -0
  46. app/pipeline/preprocessor.py +141 -0
  47. app/pipeline/risk_scorer.py +166 -0
  48. app/pipeline/workflow.py +327 -0
  49. app/services/__init__.py +2 -0
  50. app/services/gemini_service.py +247 -0
.dockerignore ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environments
2
+ venv/
3
+ .venv/
4
+ env/
5
+ .env
6
+
7
+ # Python
8
+ __pycache__/
9
+ *.pyc
10
+ *.pyo
11
+ *.pyd
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+
29
+ # Data & Models (not baked into image — loaded at runtime)
30
+ model_cache/
31
+ tests/
32
+ .ipynb_checkpoints/
33
+
34
+ # Version Control
35
+ .git/
36
+ .github/
37
+ .gitignore
38
+
39
+ # IDEs
40
+ .vscode/
41
+ .idea/
42
+
43
+ # OS files
44
+ .DS_Store
45
+ Thumbs.db
46
+
47
+ # HF Spaces — README is required at repo root, not in image
48
+ # README.md is NOT excluded so HF can read the frontmatter
49
+
50
+
51
+ model_cache/
52
+ __pycache__/
53
+ *.pyc
54
+ .env
.env ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Server
2
+ HOST=0.0.0.0
3
+ PORT=8000
4
+ ENV=development
5
+ LOG_LEVEL=INFO
6
+
7
+ # JWT Authentication
8
+ JWT_ACCESS_SECRET=hubble-access-secret-change-in-production-32chars
9
+ JWT_REFRESH_SECRET=hubble-refresh-secret-change-in-production-32chars
10
+ JWT_ACCESS_EXPIRES_MINUTES=15
11
+ JWT_REFRESH_EXPIRES_DAYS=7
12
+
13
+ # Security
14
+ BCRYPT_ROUNDS=12
15
+ CORS_ORIGINS=http://localhost:3000,http://localhost:3001
16
+
17
+ # MongoDB (Atlas Cloud)
18
+ MONGODB_URI=mongodb+srv://sajithjaganathan7_db_user:Winter_bear_07@cluster0.jxdvukx.mongodb.net/hubble?appName=Cluster0
19
+ MONGODB_DB_NAME=hubble
20
+
21
+ # Redis (Redis Cloud)
22
+ REDIS_URL=redis://default:dongH74t41QfBN0TO0e5ylWAVThXZoLR@redis-13470.crce281.ap-south-1-3.ec2.cloud.redislabs.com:13470
23
+ REDIS_CACHE_TTL=300
24
+
25
+ # Gemini API Keys (User should provide real keys for this)
26
+ GEMINI_API_KEYS=AIzaSyBL6onsP6Z-wG32nFgy8Bi7uGDGIopbRDE
27
+ GEMINI_MODEL=gemini-2.5-flash
28
+
29
+ # LangSmith Tracing
30
+ LANGSMITH_API_KEY=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
31
+ LANGSMITH_PROJECT=hubble-moderation
32
+ LANGSMITH_TRACING_V2=true
33
+
34
+ # Model Configuration
35
+ MODEL_CACHE_DIR=./model_cache
36
+ ONNX_ENABLED=false
37
+ TEXT_MODEL_NAME=unitary/toxic-bert
38
+ IMAGE_MODEL_NAME=google/efficientnet-b0
39
+ CLIP_MODEL_NAME=openai/clip-vit-base-patch32
40
+
41
+ # Risk Thresholds
42
+ RISK_LOW_MAX=30
43
+ RISK_MEDIUM_MAX=65
44
+
45
+ # Video Processing
46
+ VIDEO_MAX_FRAMES=10
47
+ VIDEO_FPS_SAMPLE=1
.env.example ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================
2
+ # Hubble AI Engine - Environment Configuration
3
+ # ============================================
4
+ # On Hugging Face Spaces, set these as Secrets
5
+ # in the Space settings (Settings > Secrets).
6
+ # They are injected as environment variables at runtime.
7
+
8
+ # Server
9
+ HOST=0.0.0.0
10
+ PORT=7860
11
+ ENV=production
12
+ LOG_LEVEL=INFO
13
+
14
+ # MongoDB (async via motor)
15
+ MONGODB_URI=mongodb+srv://<user>:<password>@<cluster>.mongodb.net/hubble
16
+ MONGODB_DB_NAME=hubble
17
+
18
+ # Redis
19
+ REDIS_URL=redis://default:<password>@<host>:<port>
20
+ REDIS_CACHE_TTL=300
21
+
22
+ # Gemini API Keys (comma-separated for rotation)
23
+ GEMINI_API_KEYS=your-key-1,your-key-2,your-key-3
24
+ GEMINI_MODEL=gemini-2.5-flash
25
+
26
+ # LangSmith Observability (optional)
27
+ LANGSMITH_API_KEY=your-langsmith-api-key
28
+ LANGSMITH_PROJECT=hubble-moderation
29
+ LANGSMITH_TRACING_V2=true
30
+
31
+ # JWT Secrets (use long random strings)
32
+ JWT_ACCESS_SECRET=your-access-secret-at-least-32-chars
33
+ JWT_REFRESH_SECRET=your-refresh-secret-at-least-32-chars
34
+
35
+ # Model Configuration
36
+ MODEL_CACHE_DIR=/tmp/model_cache
37
+ ONNX_ENABLED=false
38
+ TEXT_MODEL_NAME=unitary/toxic-bert
39
+ IMAGE_MODEL_NAME=google/efficientnet-b0
40
+ CLIP_MODEL_NAME=openai/clip-vit-base-patch32
41
+
42
+ # Risk Thresholds
43
+ RISK_LOW_MAX=30
44
+ RISK_MEDIUM_MAX=65
45
+
46
+ # Video Processing
47
+ VIDEO_MAX_FRAMES=10
48
+ VIDEO_FPS_SAMPLE=1
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.onnx.data filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces — Hubble AI Engine
2
+ FROM python:3.12-slim
3
+
4
+ # HF Spaces requires port 7860
5
+ ENV PYTHONDONTWRITEBYTECODE=1
6
+ ENV PYTHONUNBUFFERED=1
7
+ ENV HOST=0.0.0.0
8
+ ENV PORT=7860
9
+
10
+ WORKDIR /app
11
+
12
+ # System dependencies for OpenCV and native builds
13
+ RUN apt-get update && apt-get install -y --no-install-recommends \
14
+ build-essential \
15
+ libgl1 \
16
+ libglib2.0-0 \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ # Install Python dependencies
20
+ COPY requirements.txt .
21
+ RUN pip install --no-cache-dir --upgrade pip && \
22
+ pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
23
+ pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
24
+
25
+ COPY . .
26
+
27
+ # HF Spaces runs containers as a non-root user (UID 1000)
28
+ RUN useradd -m -u 1000 user
29
+ RUN mkdir -p /tmp/model_cache && chown -R user:user /tmp/model_cache && chown -R user:user /app
30
+ USER user
31
+
32
+ EXPOSE 7860
33
+
34
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SentinelAI
3
+ emoji: 📈
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/__init__.py
2
+ """Hubble AI Engine — Production-grade cyberbullying detection pipeline."""
app/api/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/api/__init__.py
2
+ """API layer: routes, schemas, and versioning."""
app/api/router.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/router.py
2
+ # Root router — aggregates all versioned API routes
3
+
4
+ from fastapi import APIRouter
5
+ from app.api.v1.health import router as health_router
6
+ from app.api.v1.analyze import router as analyze_router
7
+ from app.api.v1.history import router as history_router
8
+ from app.api.v1.auth import router as auth_router
9
+ from app.api.v1.users import router as users_router
10
+ from app.api.v1.scan import router as scan_router
11
+ from app.api.v1.alerts import router as alerts_router
12
+
13
+ # Main API router
14
+ api_router = APIRouter()
15
+
16
+ # ── Unauthenticated: health + raw AI pipeline ──
17
+ api_router.include_router(health_router, prefix="")
18
+ api_router.include_router(analyze_router, prefix="/api/v1")
19
+ api_router.include_router(history_router, prefix="/api/v1")
20
+
21
+ # ── Auth ──
22
+ api_router.include_router(auth_router, prefix="/api/v1")
23
+
24
+ # ── Authenticated business logic ──
25
+ api_router.include_router(users_router, prefix="/api/v1")
26
+ api_router.include_router(scan_router, prefix="/api/v1")
27
+ api_router.include_router(alerts_router, prefix="/api/v1")
app/api/schemas/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/api/schemas/__init__.py
2
+ """Pydantic request/response schemas."""
app/api/schemas/requests.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/schemas/requests.py
2
+ # Pydantic request models for API endpoints
3
+
4
+ from pydantic import BaseModel, Field
5
+ from typing import Optional
6
+
7
+
8
+ class TextAnalysisRequest(BaseModel):
9
+ """Request body for text analysis."""
10
+ text: str = Field(..., min_length=1, max_length=10000, description="Text content to analyze")
11
+ user_id: Optional[str] = Field(None, description="Optional user ID for history tracking")
12
+ source_app: Optional[str] = Field(None, description="Source application (e.g., 'instagram', 'whatsapp')")
13
+ metadata: Optional[dict] = Field(None, description="Additional metadata")
14
+
15
+ model_config = {
16
+ "json_schema_extra": {
17
+ "examples": [
18
+ {
19
+ "text": "You are so ugly nobody likes you",
20
+ "user_id": "user_123",
21
+ "source_app": "instagram",
22
+ }
23
+ ]
24
+ }
25
+ }
26
+
27
+
28
+ class ImageAnalysisRequest(BaseModel):
29
+ """Metadata for image analysis (file sent as multipart)."""
30
+ user_id: Optional[str] = Field(None, description="Optional user ID for history tracking")
31
+ source_app: Optional[str] = Field(None, description="Source application")
32
+
33
+
34
+ class VideoAnalysisRequest(BaseModel):
35
+ """Metadata for video analysis (file sent as multipart)."""
36
+ user_id: Optional[str] = Field(None, description="Optional user ID for history tracking")
37
+ source_app: Optional[str] = Field(None, description="Source application")
38
+
39
+
40
+ class HistoryRequest(BaseModel):
41
+ """Parameters for history queries."""
42
+ limit: int = Field(20, ge=1, le=100, description="Maximum results to return")
43
+ skip: int = Field(0, ge=0, description="Number of results to skip")
app/api/schemas/responses.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/schemas/responses.py
2
+ # Pydantic response models for API endpoints
3
+
4
+ from pydantic import BaseModel, Field
5
+ from typing import Optional, Literal, Any
6
+ from datetime import datetime
7
+
8
+
9
+ class DecisionDetail(BaseModel):
10
+ """Details of the moderation decision."""
11
+ action: Literal["ALLOWED", "WARNING", "BLOCKED"]
12
+ reason: str
13
+ severity: str
14
+ should_alert_parent: bool = False
15
+ escalation_notes: Optional[str] = None
16
+
17
+
18
+ class DeepAnalysisDetail(BaseModel):
19
+ """Details from the deep analysis stage (only present for HIGH risk)."""
20
+ is_confirmed: bool
21
+ severity: str
22
+ reasoning: str
23
+ categories: list[str] = []
24
+ recommended_action: str
25
+ confidence: float
26
+ clip_scores: dict = {}
27
+
28
+
29
+ class RiskDetail(BaseModel):
30
+ """Breakdown of risk scoring."""
31
+ score: float
32
+ level: Literal["LOW", "MEDIUM", "HIGH"]
33
+ components: dict = {}
34
+ repeat_offender: bool = False
35
+
36
+
37
+ class FilterDetail(BaseModel):
38
+ """Fast filter stage output."""
39
+ is_flagged: bool
40
+ scores: dict[str, float] = {}
41
+ max_score: float
42
+ max_label: str
43
+ categories: list[str] = []
44
+
45
+
46
+ class AnalysisResponse(BaseModel):
47
+ """
48
+ Unified response for all /analyze/* endpoints.
49
+
50
+ This is the primary contract between the AI engine
51
+ and the Node.js backend (and any external consumers).
52
+ """
53
+ request_id: str = Field(..., description="Unique request identifier")
54
+ input_type: Literal["text", "image", "video"]
55
+ status: Literal["ALLOWED", "WARNING", "BLOCKED"]
56
+ risk_level: Literal["LOW", "MEDIUM", "HIGH"]
57
+ risk_score: float = Field(..., ge=0, le=100, description="Composite risk score 0-100")
58
+ categories: list[str] = Field(default_factory=list, description="Detected abuse categories")
59
+ confidence: float = Field(..., ge=0, le=1, description="Overall confidence 0-1")
60
+
61
+ # Detailed breakdowns
62
+ decision: DecisionDetail
63
+ risk_detail: RiskDetail
64
+ filter_detail: FilterDetail
65
+ deep_analysis: Optional[DeepAnalysisDetail] = None
66
+
67
+ # Metadata
68
+ processing_time_ms: int = Field(..., description="Total processing time in milliseconds")
69
+ trace_id: Optional[str] = Field(None, description="LangSmith trace ID for observability")
70
+ cached: bool = Field(False, description="Whether this result was served from cache")
71
+
72
+ model_config = {
73
+ "json_schema_extra": {
74
+ "examples": [
75
+ {
76
+ "request_id": "req_abc123",
77
+ "input_type": "text",
78
+ "status": "WARNING",
79
+ "risk_level": "MEDIUM",
80
+ "risk_score": 45.2,
81
+ "categories": ["insult", "toxic"],
82
+ "confidence": 0.82,
83
+ "decision": {
84
+ "action": "WARNING",
85
+ "reason": "Content flagged as potentially harmful",
86
+ "severity": "medium",
87
+ "should_alert_parent": False,
88
+ },
89
+ "risk_detail": {
90
+ "score": 45.2,
91
+ "level": "MEDIUM",
92
+ "components": {
93
+ "base_score": 42.0,
94
+ "multi_category_penalty": 3.2,
95
+ "repeat_offender_boost": 0.0,
96
+ },
97
+ "repeat_offender": False,
98
+ },
99
+ "filter_detail": {
100
+ "is_flagged": True,
101
+ "scores": {"toxic": 0.78, "insult": 0.65},
102
+ "max_score": 0.78,
103
+ "max_label": "toxic",
104
+ "categories": ["toxic", "insult"],
105
+ },
106
+ "deep_analysis": None,
107
+ "processing_time_ms": 156,
108
+ "trace_id": None,
109
+ "cached": False,
110
+ }
111
+ ]
112
+ }
113
+ }
114
+
115
+
116
+ class HealthResponse(BaseModel):
117
+ """Health check response."""
118
+ status: str
119
+ version: str
120
+ models: dict[str, bool]
121
+ services: dict[str, bool]
122
+ uptime_seconds: float
123
+
124
+
125
+ class HistoryResponse(BaseModel):
126
+ """Moderation history response."""
127
+ user_id: str
128
+ total: int
129
+ results: list[dict[str, Any]]
130
+
131
+
132
+ class ErrorResponse(BaseModel):
133
+ """Standard error response."""
134
+ error: str
135
+ detail: str
136
+ request_id: Optional[str] = None
app/api/v1/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/api/v1/__init__.py
2
+ """API v1 endpoints."""
app/api/v1/alerts.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/alerts.py
2
+ # Alert endpoints for parents and children
3
+
4
+ from fastapi import APIRouter, HTTPException, Depends
5
+ from pydantic import BaseModel
6
+ from typing import Optional
7
+ from app.core.dependencies import get_current_user
8
+ from app.db.models.user import UserDocument, UserRole
9
+ from app.db.models.alert import AlertDocument, AlertStatus
10
+ from app.observability.logging import get_logger
11
+ from datetime import datetime
12
+
13
+ logger = get_logger(__name__)
14
+ router = APIRouter(prefix="/alerts", tags=["Alerts"])
15
+
16
+
17
+ @router.get("")
18
+ async def list_alerts(
19
+ page: int = 1,
20
+ limit: int = 20,
21
+ status: Optional[str] = None,
22
+ severity: Optional[str] = None,
23
+ user: UserDocument = Depends(get_current_user),
24
+ ):
25
+ """List alerts. Parents see all their children's alerts; children see their own."""
26
+ skip = (page - 1) * limit
27
+
28
+ if user.role == UserRole.PARENT:
29
+ query = AlertDocument.find(AlertDocument.parent_id == str(user.id))
30
+ else:
31
+ query = AlertDocument.find(AlertDocument.child_id == str(user.id))
32
+
33
+ if status:
34
+ query = query.find(AlertDocument.status == AlertStatus(status))
35
+ if severity:
36
+ from app.db.models.alert import AlertSeverity
37
+ query = query.find(AlertDocument.severity == AlertSeverity(severity))
38
+
39
+ alerts = await query.sort(-AlertDocument.created_at).skip(skip).limit(limit).to_list()
40
+ return {
41
+ "success": True,
42
+ "page": page,
43
+ "limit": limit,
44
+ "alerts": [_fmt(a) for a in alerts],
45
+ }
46
+
47
+
48
+ @router.get("/{alert_id}")
49
+ async def get_alert(
50
+ alert_id: str,
51
+ user: UserDocument = Depends(get_current_user),
52
+ ):
53
+ alert = await AlertDocument.get(alert_id)
54
+ if not alert:
55
+ raise HTTPException(status_code=404, detail="Alert not found")
56
+ _check_access(alert, user)
57
+ return {"success": True, "alert": _fmt(alert)}
58
+
59
+
60
+ class AcknowledgeRequest(BaseModel):
61
+ resolution_notes: Optional[str] = None
62
+
63
+
64
+ @router.post("/{alert_id}/acknowledge")
65
+ async def acknowledge_alert(
66
+ alert_id: str,
67
+ user: UserDocument = Depends(get_current_user),
68
+ ):
69
+ alert = await AlertDocument.get(alert_id)
70
+ if not alert:
71
+ raise HTTPException(status_code=404, detail="Alert not found")
72
+ _check_access(alert, user)
73
+ alert.status = AlertStatus.ACKNOWLEDGED
74
+ alert.acknowledged_at = datetime.utcnow()
75
+ alert.updated_at = datetime.utcnow()
76
+ await alert.save()
77
+ return {"success": True, "alert": _fmt(alert)}
78
+
79
+
80
+ @router.post("/{alert_id}/resolve")
81
+ async def resolve_alert(
82
+ alert_id: str,
83
+ body: AcknowledgeRequest,
84
+ user: UserDocument = Depends(get_current_user),
85
+ ):
86
+ if user.role != UserRole.PARENT:
87
+ raise HTTPException(status_code=403, detail="Only parents can resolve alerts")
88
+ alert = await AlertDocument.get(alert_id)
89
+ if not alert:
90
+ raise HTTPException(status_code=404, detail="Alert not found")
91
+ if alert.parent_id != str(user.id):
92
+ raise HTTPException(status_code=403, detail="Cannot resolve this alert")
93
+ alert.status = AlertStatus.RESOLVED
94
+ alert.resolved_at = datetime.utcnow()
95
+ alert.resolved_by = str(user.id)
96
+ alert.resolution_notes = body.resolution_notes
97
+ alert.updated_at = datetime.utcnow()
98
+ await alert.save()
99
+ return {"success": True, "alert": _fmt(alert)}
100
+
101
+
102
+ @router.get("/stats/summary")
103
+ async def alert_stats(user: UserDocument = Depends(get_current_user)):
104
+ if user.role != UserRole.PARENT:
105
+ raise HTTPException(status_code=403, detail="Only parents can view stats")
106
+ alerts = await AlertDocument.find(
107
+ AlertDocument.parent_id == str(user.id)
108
+ ).to_list()
109
+
110
+ by_severity: dict = {}
111
+ by_status: dict = {}
112
+ by_category: dict = {}
113
+ for a in alerts:
114
+ by_severity[a.severity.value] = by_severity.get(a.severity.value, 0) + 1
115
+ by_status[a.status.value] = by_status.get(a.status.value, 0) + 1
116
+ for c in a.categories:
117
+ by_category[c] = by_category.get(c, 0) + 1
118
+
119
+ return {
120
+ "success": True,
121
+ "total": len(alerts),
122
+ "by_severity": by_severity,
123
+ "by_status": by_status,
124
+ "by_category": by_category,
125
+ }
126
+
127
+
128
+ # ──────────────────────────────────────────────
129
+ # Helpers
130
+ # ──────────────────────────────────────────────
131
+
132
+ def _check_access(alert: AlertDocument, user: UserDocument):
133
+ if alert.parent_id != str(user.id) and alert.child_id != str(user.id):
134
+ raise HTTPException(status_code=403, detail="Cannot access this alert")
135
+
136
+
137
+ def _fmt(a: AlertDocument) -> dict:
138
+ return {
139
+ "id": str(a.id),
140
+ "child_id": a.child_id,
141
+ "parent_id": a.parent_id,
142
+ "title": a.title,
143
+ "message": a.message,
144
+ "guidance": a.guidance,
145
+ "severity": a.severity.value,
146
+ "categories": a.categories,
147
+ "status": a.status.value,
148
+ "created_at": a.created_at.isoformat(),
149
+ "acknowledged_at": a.acknowledged_at.isoformat() if a.acknowledged_at else None,
150
+ "resolved_at": a.resolved_at.isoformat() if a.resolved_at else None,
151
+ }
app/api/v1/analyze.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/analyze.py
2
+ # Core analysis endpoints — text, image, video
3
+
4
+ import time
5
+ import uuid
6
+ from dataclasses import asdict
7
+ from fastapi import APIRouter, UploadFile, File, Form, HTTPException
8
+
9
+ from app.api.schemas.requests import TextAnalysisRequest
10
+ from app.api.schemas.responses import (
11
+ AnalysisResponse,
12
+ DecisionDetail,
13
+ RiskDetail,
14
+ FilterDetail,
15
+ DeepAnalysisDetail,
16
+ ErrorResponse,
17
+ )
18
+ from app.services.mongo_service import mongo_service
19
+ from app.services.redis_service import redis_service
20
+ from app.pipeline.workflow import get_workflow, PipelineState
21
+ from app.observability.logging import get_logger
22
+ from app.config import get_settings
23
+
24
+ logger = get_logger(__name__)
25
+ settings = get_settings()
26
+
27
+ router = APIRouter(tags=["Analysis"])
28
+
29
+
30
+ # ──────────────────────────────────────────────
31
+ # Helper: Convert pipeline state to API response
32
+ # ──────────────────────────────────────────────
33
+
34
+ def _build_response(
35
+ state: dict,
36
+ input_type: str,
37
+ request_id: str,
38
+ start_time: float,
39
+ cached: bool = False,
40
+ ) -> AnalysisResponse:
41
+ """Convert pipeline output state into the unified AnalysisResponse."""
42
+
43
+ risk = state.get("risk_score")
44
+ decision = state.get("decision")
45
+ filter_result = state.get("filter_result")
46
+ deep_result = state.get("deep_result")
47
+
48
+ # Build nested models
49
+ decision_detail = DecisionDetail(
50
+ action=decision.action if decision else "WARNING",
51
+ reason=decision.reason if decision else "Pipeline incomplete",
52
+ severity=decision.severity if decision else "medium",
53
+ should_alert_parent=decision.should_alert_parent if decision else False,
54
+ escalation_notes=decision.escalation_notes if decision else None,
55
+ )
56
+
57
+ risk_detail = RiskDetail(
58
+ score=risk.score if risk else 0.0,
59
+ level=risk.level if risk else "LOW",
60
+ components=risk.components if risk else {},
61
+ repeat_offender=risk.repeat_offender if risk else False,
62
+ )
63
+
64
+ filter_detail = FilterDetail(
65
+ is_flagged=filter_result.is_flagged if filter_result else False,
66
+ scores=filter_result.scores if filter_result else {},
67
+ max_score=filter_result.max_score if filter_result else 0.0,
68
+ max_label=filter_result.max_label if filter_result else "",
69
+ categories=filter_result.categories if filter_result else [],
70
+ )
71
+
72
+ deep_analysis = None
73
+ if deep_result:
74
+ deep_analysis = DeepAnalysisDetail(
75
+ is_confirmed=deep_result.is_confirmed,
76
+ severity=deep_result.severity,
77
+ reasoning=deep_result.reasoning,
78
+ categories=deep_result.categories,
79
+ recommended_action=deep_result.recommended_action,
80
+ confidence=deep_result.confidence,
81
+ clip_scores=deep_result.clip_scores,
82
+ )
83
+
84
+ processing_time = int((time.time() - start_time) * 1000)
85
+
86
+ return AnalysisResponse(
87
+ request_id=request_id,
88
+ input_type=input_type,
89
+ status=decision_detail.action,
90
+ risk_level=risk_detail.level,
91
+ risk_score=risk_detail.score,
92
+ categories=filter_detail.categories,
93
+ confidence=filter_detail.max_score,
94
+ decision=decision_detail,
95
+ risk_detail=risk_detail,
96
+ filter_detail=filter_detail,
97
+ deep_analysis=deep_analysis,
98
+ processing_time_ms=processing_time,
99
+ trace_id=None, # TODO: capture from LangSmith
100
+ cached=cached,
101
+ )
102
+
103
+
104
+ # ──────────────────────────────────────────────
105
+ # POST /analyze/text
106
+ # ──────────────────────────────────────────────
107
+
108
+ @router.post(
109
+ "/analyze/text",
110
+ response_model=AnalysisResponse,
111
+ responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
112
+ summary="Analyze text content for cyberbullying",
113
+ )
114
+ async def analyze_text(request: TextAnalysisRequest):
115
+ """
116
+ Full moderation pipeline for text content.
117
+
118
+ Pipeline: Preprocess → RoBERTa Filter → Risk Score → [Deep Analysis] → Decision
119
+ """
120
+ request_id = f"req_{uuid.uuid4().hex[:12]}"
121
+ start_time = time.time()
122
+
123
+ logger.info("analyze_text_started", request_id=request_id, text_length=len(request.text))
124
+
125
+ # Validate text length
126
+ if len(request.text) > settings.max_text_length:
127
+ raise HTTPException(status_code=400, detail=f"Text too long (max {settings.max_text_length} chars)")
128
+
129
+ # Check cache
130
+ cached_result = await redis_service.get_cached_result(request.text, "text")
131
+ if cached_result:
132
+ logger.info("analyze_text_cached", request_id=request_id)
133
+ cached_result["request_id"] = request_id
134
+ cached_result["cached"] = True
135
+ cached_result["processing_time_ms"] = int((time.time() - start_time) * 1000)
136
+ return AnalysisResponse(**cached_result)
137
+
138
+ # Run pipeline
139
+ try:
140
+ workflow = get_workflow()
141
+ initial_state: PipelineState = {
142
+ "input_type": "text",
143
+ "raw_content": request.text,
144
+ "user_id": request.user_id,
145
+ }
146
+
147
+ result_state = await workflow.ainvoke(initial_state)
148
+
149
+ # Check for pipeline errors
150
+ if result_state.get("error"):
151
+ raise HTTPException(status_code=500, detail=result_state["error"])
152
+
153
+ response = _build_response(result_state, "text", request_id, start_time)
154
+
155
+ # Cache the result
156
+ await redis_service.cache_result(
157
+ request.text, "text", response.model_dump()
158
+ )
159
+
160
+ # Log to MongoDB
161
+ await _log_moderation(request_id, "text", request.user_id, response)
162
+
163
+ return response
164
+
165
+ except HTTPException:
166
+ raise
167
+ except Exception as e:
168
+ logger.error("analyze_text_failed", request_id=request_id, error=str(e))
169
+ raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
170
+
171
+
172
+ # ──────────────────────────────────────────────
173
+ # POST /analyze/image
174
+ # ──────────────────────────────────────────────
175
+
176
+ @router.post(
177
+ "/analyze/image",
178
+ response_model=AnalysisResponse,
179
+ responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
180
+ summary="Analyze image content for harmful material",
181
+ )
182
+ async def analyze_image(
183
+ file: UploadFile = File(..., description="Image file to analyze"),
184
+ user_id: str | None = Form(None, description="Optional user ID"),
185
+ source_app: str | None = Form(None, description="Source application"),
186
+ ):
187
+ """
188
+ Full moderation pipeline for image content.
189
+
190
+ Pipeline: Preprocess → EfficientNet Filter → Risk Score → [CLIP + Gemini] → Decision
191
+ """
192
+ request_id = f"req_{uuid.uuid4().hex[:12]}"
193
+ start_time = time.time()
194
+
195
+ # Validate file type
196
+ if not file.content_type or not file.content_type.startswith("image/"):
197
+ raise HTTPException(status_code=400, detail="File must be an image (JPEG, PNG, etc.)")
198
+
199
+ logger.info("analyze_image_started", request_id=request_id, filename=file.filename)
200
+
201
+ # Validate file size
202
+ if file.size and file.size > settings.max_image_size:
203
+ raise HTTPException(status_code=400, detail=f"Image too large (max {settings.max_image_size / 1024 / 1024}MB)")
204
+
205
+ try:
206
+ image_bytes = await file.read()
207
+
208
+ workflow = get_workflow()
209
+ initial_state: PipelineState = {
210
+ "input_type": "image",
211
+ "raw_content": image_bytes,
212
+ "user_id": user_id,
213
+ }
214
+
215
+ result_state = await workflow.ainvoke(initial_state)
216
+
217
+ if result_state.get("error"):
218
+ raise HTTPException(status_code=500, detail=result_state["error"])
219
+
220
+ response = _build_response(result_state, "image", request_id, start_time)
221
+
222
+ # Log to MongoDB
223
+ await _log_moderation(request_id, "image", user_id, response)
224
+
225
+ return response
226
+
227
+ except HTTPException:
228
+ raise
229
+ except Exception as e:
230
+ logger.error("analyze_image_failed", request_id=request_id, error=str(e))
231
+ raise HTTPException(status_code=500, detail=f"Image analysis failed: {str(e)}")
232
+
233
+
234
+ # ──────────────────────────────────────────────
235
+ # POST /analyze/video
236
+ # ──────────────────────────────────────────────
237
+
238
+ @router.post(
239
+ "/analyze/video",
240
+ response_model=AnalysisResponse,
241
+ responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
242
+ summary="Analyze video content for harmful material",
243
+ )
244
+ async def analyze_video(
245
+ file: UploadFile = File(..., description="Video file to analyze"),
246
+ user_id: str | None = Form(None, description="Optional user ID"),
247
+ source_app: str | None = Form(None, description="Source application"),
248
+ ):
249
+ """
250
+ Full moderation pipeline for video content.
251
+
252
+ Pipeline: Extract frames → Per-frame EfficientNet → Aggregate risk → [Deep Analysis] → Decision
253
+ """
254
+ request_id = f"req_{uuid.uuid4().hex[:12]}"
255
+ start_time = time.time()
256
+
257
+ # Validate file type
258
+ if not file.content_type or not file.content_type.startswith("video/"):
259
+ raise HTTPException(status_code=400, detail="File must be a video (MP4, AVI, etc.)")
260
+
261
+ logger.info("analyze_video_started", request_id=request_id, filename=file.filename)
262
+
263
+ # Validate file size
264
+ if file.size and file.size > settings.max_video_size:
265
+ raise HTTPException(status_code=400, detail=f"Video too large (max {settings.max_video_size / 1024 / 1024}MB)")
266
+
267
+ try:
268
+ video_bytes = await file.read()
269
+
270
+ workflow = get_workflow()
271
+ initial_state: PipelineState = {
272
+ "input_type": "video",
273
+ "raw_content": video_bytes,
274
+ "user_id": user_id,
275
+ }
276
+
277
+ result_state = await workflow.ainvoke(initial_state)
278
+
279
+ if result_state.get("error"):
280
+ raise HTTPException(status_code=500, detail=result_state["error"])
281
+
282
+ response = _build_response(result_state, "video", request_id, start_time)
283
+
284
+ # Log to MongoDB
285
+ await _log_moderation(request_id, "video", user_id, response)
286
+
287
+ return response
288
+
289
+ except HTTPException:
290
+ raise
291
+ except Exception as e:
292
+ logger.error("analyze_video_failed", request_id=request_id, error=str(e))
293
+ raise HTTPException(status_code=500, detail=f"Video analysis failed: {str(e)}")
294
+
295
+
296
+ # ──────────────────────────────────────────────
297
+ # Helper: Log moderation result to MongoDB
298
+ # ──────────────────────────────────────────────
299
+
300
+ async def _log_moderation(
301
+ request_id: str,
302
+ input_type: str,
303
+ user_id: str | None,
304
+ response: AnalysisResponse,
305
+ ) -> None:
306
+ """Log the moderation result and update user history."""
307
+ try:
308
+ log_entry = {
309
+ "request_id": request_id,
310
+ "input_type": input_type,
311
+ "user_id": user_id,
312
+ "status": response.status,
313
+ "risk_level": response.risk_level,
314
+ "risk_score": response.risk_score,
315
+ "categories": response.categories,
316
+ "processing_time_ms": response.processing_time_ms,
317
+ }
318
+ await mongo_service.log_moderation(log_entry)
319
+
320
+ # Update user history
321
+ if user_id:
322
+ await mongo_service.update_user_history(user_id, {
323
+ "risk_level": response.risk_level,
324
+ "categories": response.categories,
325
+ })
326
+ # Invalidate cached history
327
+ await redis_service.invalidate_user_history(user_id)
328
+
329
+ except Exception as e:
330
+ logger.warning("moderation_logging_failed", error=str(e))
app/api/v1/auth.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/auth.py
2
+ # Auth endpoints: register (parent), create-child, login, refresh, logout
3
+
4
+ from datetime import datetime
5
+ from fastapi import APIRouter, HTTPException, status, Depends
6
+ from pydantic import BaseModel, EmailStr, field_validator
7
+ from app.db.models.user import UserDocument, UserRole
8
+ from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_refresh_token
9
+ from app.core.dependencies import get_current_user
10
+ from app.observability.logging import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+ router = APIRouter(prefix="/auth", tags=["Auth"])
14
+
15
+
16
+ # ──────────────────────────────────────────────
17
+ # Request / Response Schemas
18
+ # ──────────────────────────────────────────────
19
+
20
+ class RegisterParentRequest(BaseModel):
21
+ email: EmailStr
22
+ username: str
23
+ password: str
24
+ first_name: str
25
+ last_name: str
26
+ phone: str | None = None
27
+ consent_given: bool = True
28
+
29
+ @field_validator("username")
30
+ @classmethod
31
+ def username_length(cls, v: str) -> str:
32
+ if len(v) < 3 or len(v) > 30:
33
+ raise ValueError("Username must be 3–30 characters")
34
+ return v.strip()
35
+
36
+ @field_validator("password")
37
+ @classmethod
38
+ def password_length(cls, v: str) -> str:
39
+ if len(v) < 8:
40
+ raise ValueError("Password must be at least 8 characters")
41
+ return v
42
+
43
+
44
+ class CreateChildRequest(BaseModel):
45
+ username: str
46
+ password: str
47
+ first_name: str
48
+ last_name: str
49
+
50
+
51
+ class LoginRequest(BaseModel):
52
+ login: str # email or username
53
+ password: str
54
+
55
+
56
+ class RefreshRequest(BaseModel):
57
+ refresh_token: str
58
+
59
+
60
+ class TokenResponse(BaseModel):
61
+ access_token: str
62
+ refresh_token: str
63
+ token_type: str = "bearer"
64
+ expires_in: int # seconds
65
+
66
+
67
+ def _token_response(user: UserDocument) -> dict:
68
+ access = create_access_token(str(user.id), user.role.value)
69
+ refresh = create_refresh_token(str(user.id), user.role.value)
70
+ return {
71
+ "access_token": access,
72
+ "refresh_token": refresh,
73
+ "token_type": "bearer",
74
+ "expires_in": 15 * 60,
75
+ "user": user.to_public(),
76
+ }
77
+
78
+
79
+ # ──────────────────────────────────────────────
80
+ # Endpoints
81
+ # ──────────────────────────────────────────────
82
+
83
+ @router.post("/register", status_code=status.HTTP_201_CREATED)
84
+ async def register_parent(body: RegisterParentRequest):
85
+ """Register a new parent account."""
86
+ # Email uniqueness
87
+ existing = await UserDocument.find_one(UserDocument.email == body.email)
88
+ if existing:
89
+ raise HTTPException(status_code=409, detail="Email already registered")
90
+
91
+ # Username uniqueness
92
+ existing_u = await UserDocument.find_one(UserDocument.username == body.username)
93
+ if existing_u:
94
+ raise HTTPException(status_code=409, detail="Username already taken")
95
+
96
+ user = UserDocument(
97
+ email=body.email,
98
+ username=body.username,
99
+ password_hash=hash_password(body.password),
100
+ role=UserRole.PARENT,
101
+ first_name=body.first_name,
102
+ last_name=body.last_name,
103
+ phone=body.phone,
104
+ consent_given=body.consent_given,
105
+ consent_date=datetime.utcnow() if body.consent_given else None,
106
+ )
107
+ await user.insert()
108
+
109
+ resp = _token_response(user)
110
+ # Save refresh token
111
+ user.refresh_tokens.append(resp["refresh_token"])
112
+ await user.save()
113
+
114
+ logger.info("parent_registered", user_id=str(user.id))
115
+ return {"success": True, **resp}
116
+
117
+
118
+ @router.post("/create-child", status_code=status.HTTP_201_CREATED)
119
+ async def create_child(
120
+ body: CreateChildRequest,
121
+ current_user: UserDocument = Depends(get_current_user),
122
+ ):
123
+ """Parent creates a child account linked to their account."""
124
+ if current_user.role != UserRole.PARENT:
125
+ raise HTTPException(status_code=403, detail="Only parents can create child accounts")
126
+
127
+ existing_u = await UserDocument.find_one(UserDocument.username == body.username)
128
+ if existing_u:
129
+ raise HTTPException(status_code=409, detail="Username already taken")
130
+
131
+ child = UserDocument(
132
+ username=body.username,
133
+ password_hash=hash_password(body.password),
134
+ role=UserRole.CHILD,
135
+ first_name=body.first_name,
136
+ last_name=body.last_name,
137
+ parent_id=str(current_user.id),
138
+ parental_consent=True,
139
+ consent_given=True,
140
+ )
141
+ await child.insert()
142
+
143
+ # Link child to parent
144
+ current_user.children.append(str(child.id))
145
+ await current_user.save()
146
+
147
+ logger.info("child_created", child_id=str(child.id), parent_id=str(current_user.id))
148
+ return {"success": True, "user": child.to_public()}
149
+
150
+
151
+ @router.post("/login")
152
+ async def login(body: LoginRequest):
153
+ """Login with email or username + password."""
154
+ login_val = body.login.strip().lower()
155
+
156
+ # Try email first, then username
157
+ if "@" in login_val:
158
+ user = await UserDocument.find_one(UserDocument.email == login_val)
159
+ else:
160
+ user = await UserDocument.find_one(UserDocument.username == login_val)
161
+
162
+ if not user or not verify_password(body.password, user.password_hash):
163
+ raise HTTPException(status_code=401, detail="Invalid credentials")
164
+
165
+ if not user.is_active:
166
+ raise HTTPException(status_code=403, detail="Account is deactivated")
167
+
168
+ resp = _token_response(user)
169
+ user.refresh_tokens.append(resp["refresh_token"])
170
+ user.last_login_at = datetime.utcnow()
171
+ await user.save()
172
+
173
+ logger.info("user_logged_in", user_id=str(user.id))
174
+ return {"success": True, **resp}
175
+
176
+
177
+ @router.post("/refresh")
178
+ async def refresh_token(body: RefreshRequest):
179
+ """Exchange a valid refresh token for a new token pair."""
180
+ payload = decode_refresh_token(body.refresh_token)
181
+ if not payload:
182
+ raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
183
+
184
+ user = await UserDocument.get(payload["sub"])
185
+ if not user or not user.is_active:
186
+ raise HTTPException(status_code=401, detail="User not found")
187
+
188
+ if body.refresh_token not in user.refresh_tokens:
189
+ raise HTTPException(status_code=401, detail="Refresh token revoked")
190
+
191
+ # Rotate: remove old, add new
192
+ user.refresh_tokens.remove(body.refresh_token)
193
+ resp = _token_response(user)
194
+ user.refresh_tokens.append(resp["refresh_token"])
195
+ await user.save()
196
+
197
+ return {"success": True, **resp}
198
+
199
+
200
+ @router.post("/logout")
201
+ async def logout(
202
+ body: RefreshRequest | None = None,
203
+ current_user: UserDocument = Depends(get_current_user),
204
+ ):
205
+ """Revoke refresh token(s). Omit body to logout all devices."""
206
+ if body and body.refresh_token in current_user.refresh_tokens:
207
+ current_user.refresh_tokens.remove(body.refresh_token)
208
+ else:
209
+ current_user.refresh_tokens.clear()
210
+ await current_user.save()
211
+ return {"success": True, "message": "Logged out"}
212
+
213
+
214
+ @router.get("/me")
215
+ async def get_me(current_user: UserDocument = Depends(get_current_user)):
216
+ """Return current authenticated user's profile."""
217
+ return {"success": True, "user": current_user.to_public()}
app/api/v1/health.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/health.py
2
+ # Health check endpoints
3
+
4
+ import time
5
+ from fastapi import APIRouter
6
+ from app.api.schemas.responses import HealthResponse
7
+ from app.models.model_registry import model_registry
8
+ from app.services.redis_service import redis_service
9
+ from app.services.mongo_service import mongo_service
10
+ from app.services.gemini_service import gemini_service
11
+
12
+ router = APIRouter(tags=["Health"])
13
+
14
+ _startup_time = time.time()
15
+
16
+
17
+ @router.get("/health", response_model=HealthResponse)
18
+ async def health_check():
19
+ """
20
+ Full health check — reports status of all models and services.
21
+ """
22
+ model_status = model_registry.get_status()
23
+ all_models_ok = all(model_status.values())
24
+
25
+ service_status = {
26
+ "mongodb": mongo_service.is_connected,
27
+ "redis": redis_service.is_connected,
28
+ "gemini": gemini_service.is_initialized,
29
+ }
30
+
31
+ overall = "healthy" if all_models_ok else "degraded"
32
+
33
+ return HealthResponse(
34
+ status=overall,
35
+ version="4.0.0",
36
+ models=model_status,
37
+ services=service_status,
38
+ uptime_seconds=round(time.time() - _startup_time, 1),
39
+ )
40
+
41
+
42
+ @router.get("/health/models")
43
+ async def model_health():
44
+ """Detailed model status."""
45
+ return model_registry.get_status()
46
+
47
+
48
+ @router.get("/health/ping")
49
+ async def ping():
50
+ """Lightweight liveness probe."""
51
+ return {"status": "ok"}
app/api/v1/history.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/history.py
2
+ # Moderation history endpoints
3
+
4
+ from fastapi import APIRouter, HTTPException, Query
5
+ from app.api.schemas.responses import HistoryResponse
6
+ from app.services.mongo_service import mongo_service
7
+ from app.observability.logging import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+ router = APIRouter(tags=["History"])
12
+
13
+
14
+ @router.get(
15
+ "/history/{user_id}",
16
+ response_model=HistoryResponse,
17
+ summary="Get moderation history for a user",
18
+ )
19
+ async def get_user_history(
20
+ user_id: str,
21
+ limit: int = Query(20, ge=1, le=100, description="Max results"),
22
+ skip: int = Query(0, ge=0, description="Results to skip"),
23
+ ):
24
+ """
25
+ Retrieve moderation history for a specific user.
26
+
27
+ Returns past moderation decisions with timestamps,
28
+ risk scores, and categories.
29
+ """
30
+ if not mongo_service.is_connected:
31
+ raise HTTPException(
32
+ status_code=503,
33
+ detail="MongoDB not available — history querying disabled",
34
+ )
35
+
36
+ results = await mongo_service.get_moderation_history(user_id, limit=limit, skip=skip)
37
+
38
+ return HistoryResponse(
39
+ user_id=user_id,
40
+ total=len(results),
41
+ results=results,
42
+ )
43
+
44
+
45
+ @router.get(
46
+ "/history/{user_id}/summary",
47
+ summary="Get aggregated user stats",
48
+ )
49
+ async def get_user_summary(user_id: str):
50
+ """
51
+ Get aggregated moderation statistics for a user.
52
+
53
+ Includes total scans, violations, warnings, and category breakdown.
54
+ """
55
+ if not mongo_service.is_connected:
56
+ raise HTTPException(status_code=503, detail="MongoDB not available")
57
+
58
+ history = await mongo_service.get_user_history(user_id)
59
+ if not history:
60
+ return {
61
+ "user_id": user_id,
62
+ "total_scans": 0,
63
+ "total_violations": 0,
64
+ "total_warnings": 0,
65
+ "violation_categories": {},
66
+ }
67
+
68
+ return history
app/api/v1/scan.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/scan.py
2
+ # Authenticated scan endpoints — wraps the AI pipeline and persists results
3
+
4
+ import time
5
+ import uuid
6
+ from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
7
+ from pydantic import BaseModel
8
+ from app.core.dependencies import get_current_user, require_role
9
+ from app.db.models.user import UserDocument, UserRole
10
+ from app.db.models.scan_result import ScanResultDocument, RiskLevel
11
+ from app.db.models.alert import AlertDocument, AlertSeverity, AlertStatus
12
+ from app.pipeline.workflow import get_workflow, PipelineState
13
+ from app.observability.logging import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+ router = APIRouter(prefix="/scan", tags=["Scan"])
17
+
18
+ _SEVERITY_MAP = {
19
+ "LOW": AlertSeverity.LOW,
20
+ "MEDIUM": AlertSeverity.MEDIUM,
21
+ "HIGH": AlertSeverity.HIGH,
22
+ "CRITICAL": AlertSeverity.CRITICAL,
23
+ }
24
+
25
+
26
+ async def _run_pipeline_and_persist(
27
+ input_type: str,
28
+ raw_content,
29
+ user: UserDocument,
30
+ ) -> dict:
31
+ """Run the AI pipeline, persist ScanResult and optional Alert, return result dict."""
32
+ start = time.time()
33
+
34
+ workflow = get_workflow()
35
+ state: PipelineState = {
36
+ "input_type": input_type,
37
+ "raw_content": raw_content,
38
+ "user_id": str(user.id),
39
+ }
40
+ result = await workflow.ainvoke(state)
41
+
42
+ if result.get("error"):
43
+ raise HTTPException(status_code=500, detail=result["error"])
44
+
45
+ risk = result.get("risk_score")
46
+ decision = result.get("decision")
47
+ filter_result = result.get("filter_result")
48
+ deep_result = result.get("deep_result")
49
+
50
+ risk_level_str = risk.level if risk else "LOW"
51
+ risk_score = risk.score if risk else 0.0
52
+ action = decision.action if decision else "ALLOWED"
53
+ categories = filter_result.categories if filter_result else []
54
+ is_flagged = filter_result.is_flagged if filter_result else False
55
+ reasoning = deep_result.reasoning if deep_result else None
56
+ processing_ms = int((time.time() - start) * 1000)
57
+
58
+ # Persist ScanResult
59
+ scan_doc = ScanResultDocument(
60
+ user_id=str(user.id),
61
+ input_type=input_type,
62
+ content_preview=(raw_content[:200] if isinstance(raw_content, str) else None),
63
+ risk_level=RiskLevel(risk_level_str),
64
+ risk_score=risk_score,
65
+ categories=categories,
66
+ is_flagged=is_flagged,
67
+ action=action,
68
+ reasoning=reasoning,
69
+ processing_time_ms=processing_ms,
70
+ deep_analysis_used=deep_result is not None,
71
+ )
72
+ await scan_doc.insert()
73
+
74
+ # Create Alert if child account and content is flagged
75
+ if is_flagged and user.role == UserRole.CHILD and user.parent_id:
76
+ severity = _SEVERITY_MAP.get(risk_level_str, AlertSeverity.LOW)
77
+ alert = AlertDocument(
78
+ child_id=str(user.id),
79
+ parent_id=user.parent_id,
80
+ scan_result_id=str(scan_doc.id),
81
+ severity=severity,
82
+ categories=categories,
83
+ severity_score=risk_score,
84
+ )
85
+ alert.generate_content()
86
+ await alert.insert()
87
+ logger.info("alert_created", alert_id=str(alert.id), child_id=str(user.id))
88
+
89
+ return {
90
+ "request_id": f"req_{uuid.uuid4().hex[:12]}",
91
+ "input_type": input_type,
92
+ "scan_id": str(scan_doc.id),
93
+ "status": action,
94
+ "risk_level": risk_level_str,
95
+ "risk_score": risk_score,
96
+ "categories": categories,
97
+ "is_flagged": is_flagged,
98
+ "reasoning": reasoning,
99
+ "processing_time_ms": processing_ms,
100
+ }
101
+
102
+
103
+ # ──────────────────────────────────────────────
104
+ # Endpoints
105
+ # ──────────────────────────────────────────────
106
+
107
+ class ScanTextRequest(BaseModel):
108
+ text: str
109
+
110
+
111
+ @router.post("/text")
112
+ async def scan_text(
113
+ body: ScanTextRequest,
114
+ user: UserDocument = Depends(get_current_user),
115
+ ):
116
+ """Scan text content. Requires authentication."""
117
+ if not body.text.strip():
118
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
119
+ result = await _run_pipeline_and_persist("text", body.text, user)
120
+ return {"success": True, **result}
121
+
122
+
123
+ @router.post("/image")
124
+ async def scan_image(
125
+ file: UploadFile = File(...),
126
+ user: UserDocument = Depends(get_current_user),
127
+ ):
128
+ """Scan image content. Requires authentication."""
129
+ if not file.content_type or not file.content_type.startswith("image/"):
130
+ raise HTTPException(status_code=400, detail="File must be an image")
131
+
132
+ image_bytes = await file.read()
133
+ result = await _run_pipeline_and_persist("image", image_bytes, user)
134
+ return {"success": True, **result}
135
+
136
+
137
+ @router.get("/history")
138
+ async def get_scan_history(
139
+ page: int = 1,
140
+ limit: int = 20,
141
+ user: UserDocument = Depends(get_current_user),
142
+ ):
143
+ """Get scan history for the current user."""
144
+ skip = (page - 1) * limit
145
+ scans = await ScanResultDocument.find(
146
+ ScanResultDocument.user_id == str(user.id)
147
+ ).sort(-ScanResultDocument.created_at).skip(skip).limit(limit).to_list()
148
+
149
+ return {
150
+ "success": True,
151
+ "page": page,
152
+ "limit": limit,
153
+ "results": [
154
+ {
155
+ "id": str(s.id),
156
+ "input_type": s.input_type,
157
+ "risk_level": s.risk_level.value,
158
+ "risk_score": s.risk_score,
159
+ "action": s.action,
160
+ "is_flagged": s.is_flagged,
161
+ "categories": s.categories,
162
+ "created_at": s.created_at.isoformat(),
163
+ }
164
+ for s in scans
165
+ ],
166
+ }
app/api/v1/users.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/api/v1/users.py
2
+ # User profile and parent-child management endpoints
3
+
4
+ from fastapi import APIRouter, HTTPException, Depends
5
+ from pydantic import BaseModel
6
+ from typing import Optional
7
+ from app.core.dependencies import get_current_user
8
+ from app.db.models.user import UserDocument, UserRole
9
+
10
+ router = APIRouter(prefix="/users", tags=["Users"])
11
+
12
+
13
+ @router.get("/me")
14
+ async def get_profile(user: UserDocument = Depends(get_current_user)):
15
+ return {"success": True, "user": user.to_public()}
16
+
17
+
18
+ class UpdateProfileRequest(BaseModel):
19
+ first_name: Optional[str] = None
20
+ last_name: Optional[str] = None
21
+ phone: Optional[str] = None
22
+
23
+
24
+ @router.patch("/me")
25
+ async def update_profile(
26
+ body: UpdateProfileRequest,
27
+ user: UserDocument = Depends(get_current_user),
28
+ ):
29
+ if body.first_name:
30
+ user.first_name = body.first_name
31
+ if body.last_name:
32
+ user.last_name = body.last_name
33
+ if body.phone is not None:
34
+ user.phone = body.phone
35
+ await user.save()
36
+ return {"success": True, "user": user.to_public()}
37
+
38
+
39
+ @router.get("/children")
40
+ async def list_children(user: UserDocument = Depends(get_current_user)):
41
+ """Parent: list all linked child accounts."""
42
+ if user.role != UserRole.PARENT:
43
+ raise HTTPException(status_code=403, detail="Only parents can list children")
44
+ children = await UserDocument.find(
45
+ UserDocument.parent_id == str(user.id)
46
+ ).to_list()
47
+ return {"success": True, "children": [c.to_public() for c in children]}
48
+
49
+
50
+ @router.get("/children/{child_id}")
51
+ async def get_child(
52
+ child_id: str,
53
+ user: UserDocument = Depends(get_current_user),
54
+ ):
55
+ """Parent: get a specific child's profile."""
56
+ if user.role != UserRole.PARENT:
57
+ raise HTTPException(status_code=403, detail="Only parents can view child profiles")
58
+ child = await UserDocument.get(child_id)
59
+ if not child or child.parent_id != str(user.id):
60
+ raise HTTPException(status_code=404, detail="Child not found")
61
+ return {"success": True, "child": child.to_public()}
app/config.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/config.py
2
+ # Centralized configuration via Pydantic Settings
3
+
4
+ from pathlib import Path
5
+ from functools import lru_cache
6
+ from pydantic_settings import BaseSettings
7
+ from pydantic import Field
8
+
9
+
10
+ class Settings(BaseSettings):
11
+ """Application settings loaded from environment variables."""
12
+
13
+ # --- Server ---
14
+ host: str = "0.0.0.0"
15
+ port: int = 7860
16
+ env: str = "production"
17
+ log_level: str = "INFO"
18
+
19
+ # --- JWT ---
20
+ jwt_access_secret: str = Field(default="change-me-in-production-at-least-32-chars!!")
21
+ jwt_refresh_secret: str = Field(default="change-me-in-production-at-least-32-chars!!")
22
+ jwt_access_expires_minutes: int = 15
23
+ jwt_refresh_expires_days: int = 7
24
+
25
+ # --- Security ---
26
+ bcrypt_rounds: int = 12
27
+ cors_origins: str = "*"
28
+
29
+ # --- MongoDB ---
30
+ mongodb_uri: str = Field(default="")
31
+ mongodb_db_name: str = "hubble"
32
+
33
+ # --- Redis ---
34
+ redis_url: str = Field(default="")
35
+ redis_cache_ttl: int = 300 # seconds
36
+
37
+ # --- Gemini ---
38
+ gemini_api_keys: str = "" # comma-separated
39
+ gemini_model: str = "gemini-2.5-flash"
40
+
41
+ # --- LangSmith ---
42
+ langsmith_api_key: str = ""
43
+ langsmith_project: str = "hubble-moderation"
44
+ langsmith_tracing_v2: bool = True
45
+
46
+ # --- Models ---
47
+ model_cache_dir: str = "/tmp/model_cache"
48
+ onnx_enabled: bool = False
49
+ text_model_name: str = "unitary/toxic-bert"
50
+ image_model_name: str = "google/efficientnet-b0"
51
+ clip_model_name: str = "openai/clip-vit-base-patch32"
52
+
53
+ # --- Risk Thresholds ---
54
+ risk_low_max: int = 30
55
+ risk_medium_max: int = 65
56
+
57
+ # --- Content Limits ---
58
+ max_text_length: int = 10000
59
+ max_image_size: int = 10 * 1024 * 1024 # 10MB
60
+ max_video_size: int = 50 * 1024 * 1024 # 50MB
61
+
62
+ # --- Video Processing ---
63
+ video_max_frames: int = 10
64
+ video_fps_sample: int = 1
65
+
66
+ model_config = {
67
+ "env_file": ".env",
68
+ "env_file_encoding": "utf-8",
69
+ "extra": "ignore",
70
+ }
71
+
72
+ @property
73
+ def gemini_keys_list(self) -> list[str]:
74
+ """Parse comma-separated Gemini API keys."""
75
+ if not self.gemini_api_keys:
76
+ return []
77
+ return [k.strip() for k in self.gemini_api_keys.split(",") if k.strip()]
78
+
79
+ @property
80
+ def cors_origins_list(self) -> list[str]:
81
+ """Parse comma-separated CORS origins."""
82
+ return [o.strip() for o in self.cors_origins.split(",") if o.strip()]
83
+
84
+ @property
85
+ def model_cache_path(self) -> Path:
86
+ """Resolved path for model cache directory."""
87
+ path = Path(self.model_cache_dir)
88
+ path.mkdir(parents=True, exist_ok=True)
89
+ return path
90
+
91
+ @property
92
+ def is_production(self) -> bool:
93
+ return self.env == "production"
94
+
95
+
96
+ @lru_cache()
97
+ def get_settings() -> Settings:
98
+ """Cached settings singleton."""
99
+ return Settings()
app/core/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # app/core/__init__.py
app/core/dependencies.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/core/dependencies.py
2
+ # FastAPI dependency injection — auth guards, role checks, shared helpers
3
+
4
+ from fastapi import Depends, HTTPException, status
5
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
6
+ from app.core.security import decode_access_token
7
+ from app.db.models.user import UserDocument, UserRole
8
+
9
+ _bearer = HTTPBearer(auto_error=False)
10
+
11
+
12
+ async def get_current_user(
13
+ credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
14
+ ) -> UserDocument:
15
+ """Extract and validate JWT; return the authenticated UserDocument."""
16
+ if not credentials:
17
+ raise HTTPException(
18
+ status_code=status.HTTP_401_UNAUTHORIZED,
19
+ detail="Not authenticated",
20
+ headers={"WWW-Authenticate": "Bearer"},
21
+ )
22
+
23
+ payload = decode_access_token(credentials.credentials)
24
+ if not payload:
25
+ raise HTTPException(
26
+ status_code=status.HTTP_401_UNAUTHORIZED,
27
+ detail="Invalid or expired token",
28
+ )
29
+
30
+ user = await UserDocument.get(payload["sub"])
31
+ if not user or not user.is_active:
32
+ raise HTTPException(
33
+ status_code=status.HTTP_401_UNAUTHORIZED,
34
+ detail="User not found or deactivated",
35
+ )
36
+ return user
37
+
38
+
39
+ def require_role(*roles: UserRole):
40
+ """Factory that returns a dependency requiring one of the given roles."""
41
+ async def _check(user: UserDocument = Depends(get_current_user)) -> UserDocument:
42
+ if user.role not in roles:
43
+ raise HTTPException(
44
+ status_code=status.HTTP_403_FORBIDDEN,
45
+ detail=f"Role '{user.role}' is not permitted for this action",
46
+ )
47
+ return user
48
+ return _check
app/core/security.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/core/security.py
2
+ # Password hashing (argon2-cffi) + JWT creation/verification (python-jose)
3
+
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import Optional
6
+ from argon2 import PasswordHasher
7
+ from argon2.exceptions import VerifyMismatchError, VerificationError, InvalidHashError
8
+ from jose import JWTError, jwt
9
+ from app.config import get_settings
10
+
11
+ settings = get_settings()
12
+
13
+ # ──────────────────────────────────────────────
14
+ # Password hashing — Argon2id (modern, OWASP recommended)
15
+ # ──────────────────────────────────────────────
16
+
17
+ _ph = PasswordHasher()
18
+
19
+
20
+ def hash_password(plain: str) -> str:
21
+ return _ph.hash(plain)
22
+
23
+
24
+ def verify_password(plain: str, hashed: str) -> bool:
25
+ try:
26
+ return _ph.verify(hashed, plain)
27
+ except (VerifyMismatchError, VerificationError, InvalidHashError):
28
+ return False
29
+
30
+
31
+ # ──────────────────────────────────────────────
32
+ # JWT — HS256 access + refresh tokens
33
+ # ──────────────────────────────────────────────
34
+
35
+ ALGORITHM = "HS256"
36
+
37
+
38
+ def _create_token(data: dict, secret: str, expires_delta: timedelta) -> str:
39
+ payload = data.copy()
40
+ payload["exp"] = datetime.now(timezone.utc) + expires_delta
41
+ return jwt.encode(payload, secret, algorithm=ALGORITHM)
42
+
43
+
44
+ def create_access_token(user_id: str, role: str) -> str:
45
+ return _create_token(
46
+ {"sub": user_id, "role": role, "type": "access"},
47
+ settings.jwt_access_secret,
48
+ timedelta(minutes=settings.jwt_access_expires_minutes),
49
+ )
50
+
51
+
52
+ def create_refresh_token(user_id: str, role: str) -> str:
53
+ return _create_token(
54
+ {"sub": user_id, "role": role, "type": "refresh"},
55
+ settings.jwt_refresh_secret,
56
+ timedelta(days=settings.jwt_refresh_expires_days),
57
+ )
58
+
59
+
60
+ def decode_access_token(token: str) -> Optional[dict]:
61
+ """Returns payload dict or None on failure."""
62
+ try:
63
+ payload = jwt.decode(token, settings.jwt_access_secret, algorithms=[ALGORITHM])
64
+ if payload.get("type") != "access":
65
+ return None
66
+ return payload
67
+ except JWTError:
68
+ return None
69
+
70
+
71
+ def decode_refresh_token(token: str) -> Optional[dict]:
72
+ """Returns payload dict or None on failure."""
73
+ try:
74
+ payload = jwt.decode(token, settings.jwt_refresh_secret, algorithms=[ALGORITHM])
75
+ if payload.get("type") != "refresh":
76
+ return None
77
+ return payload
78
+ except JWTError:
79
+ return None
app/db/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # app/db/__init__.py
app/db/connection.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/db/connection.py
2
+ # Beanie ODM initialization — reuses the existing mongo_service Motor client
3
+
4
+ from beanie import init_beanie
5
+ from app.observability.logging import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
+ async def connect_db() -> None:
11
+ """Initialize Beanie ODM using the already-connected mongo_service client."""
12
+ from app.services.mongo_service import mongo_service
13
+ from app.db.models.user import UserDocument
14
+ from app.db.models.scan_result import ScanResultDocument
15
+ from app.db.models.alert import AlertDocument
16
+
17
+ if not mongo_service._connected or mongo_service.client is None:
18
+ logger.error("beanie_init_skipped", reason="mongo_service not connected")
19
+ return
20
+
21
+ try:
22
+ db = mongo_service.client[mongo_service.settings.mongodb_db_name]
23
+ await init_beanie(
24
+ database=db,
25
+ document_models=[UserDocument, ScanResultDocument, AlertDocument],
26
+ allow_index_dropping=False,
27
+ )
28
+ logger.info("beanie_initialized", db=mongo_service.settings.mongodb_db_name)
29
+ except Exception as e:
30
+ logger.error("beanie_init_failed", error=str(e))
31
+ raise
32
+
33
+
34
+ async def close_db() -> None:
35
+ """No-op — Motor client is closed by mongo_service.disconnect()."""
36
+ pass
app/db/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/db/models/__init__.py
2
+ from app.db.models.user import UserDocument, UserRole
3
+ from app.db.models.scan_result import ScanResultDocument, RiskLevel
4
+ from app.db.models.alert import AlertDocument, AlertSeverity, AlertStatus
5
+
6
+ __all__ = [
7
+ "UserDocument", "UserRole",
8
+ "ScanResultDocument", "RiskLevel",
9
+ "AlertDocument", "AlertSeverity", "AlertStatus",
10
+ ]
app/db/models/alert.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/db/models/alert.py
2
+ # Alert Beanie document
3
+
4
+ from __future__ import annotations
5
+ from datetime import datetime
6
+ from enum import Enum
7
+ from typing import Optional
8
+ from beanie import Document
9
+ from pydantic import Field
10
+
11
+
12
+ class AlertSeverity(str, Enum):
13
+ LOW = "low"
14
+ MEDIUM = "medium"
15
+ HIGH = "high"
16
+ CRITICAL = "critical"
17
+
18
+
19
+ class AlertStatus(str, Enum):
20
+ PENDING = "pending"
21
+ ACKNOWLEDGED = "acknowledged"
22
+ RESOLVED = "resolved"
23
+
24
+
25
+ class AlertDocument(Document):
26
+ """Parent alert generated when flagged content is detected for a child."""
27
+
28
+ child_id: str
29
+ parent_id: str
30
+ scan_result_id: str
31
+
32
+ title: str
33
+ message: str
34
+ guidance: str = ""
35
+
36
+ severity: AlertSeverity = AlertSeverity.LOW
37
+ categories: list[str] = Field(default_factory=list)
38
+ severity_score: float = 0.0
39
+
40
+ status: AlertStatus = AlertStatus.PENDING
41
+ parent_notified: bool = False
42
+ child_notified: bool = False
43
+
44
+ acknowledged_at: Optional[datetime] = None
45
+ resolved_at: Optional[datetime] = None
46
+ resolved_by: Optional[str] = None
47
+ resolution_notes: Optional[str] = None
48
+
49
+ created_at: datetime = Field(default_factory=datetime.utcnow)
50
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
51
+
52
+ class Settings:
53
+ name = "alerts"
54
+
55
+ def generate_content(self) -> None:
56
+ """Populate title, message, guidance based on severity and categories."""
57
+ category_text = ", ".join(self.categories) if self.categories else "potentially harmful content"
58
+ score = int(self.severity_score * 100)
59
+
60
+ if self.severity == AlertSeverity.LOW:
61
+ self.title = "Mild Concern Detected"
62
+ self.message = f"Content flagged for: {category_text}. Score: {score}/100."
63
+ self.guidance = "This content contains some concerning elements. Consider talking about online safety."
64
+ elif self.severity == AlertSeverity.MEDIUM:
65
+ self.title = "Moderate Concern Detected"
66
+ self.message = f"Concerning content detected: {category_text}. Score: {score}/100."
67
+ self.guidance = "This content shows signs of potential cyberbullying. We recommend discussing this with your child."
68
+ elif self.severity == AlertSeverity.HIGH:
69
+ self.title = "⚠️ High Severity Alert"
70
+ self.message = f"Serious concern detected: {category_text}. Score: {score}/100."
71
+ self.guidance = "Immediate discussion with your child is recommended. Consider reaching out to school counselors."
72
+ else: # CRITICAL
73
+ self.title = "🚨 CRITICAL ALERT - Immediate Action Required"
74
+ self.message = f"Critical content detected: {category_text}. Score: {score}/100."
75
+ self.guidance = "If there are threats of violence or self-harm, please contact emergency services immediately."
app/db/models/scan_result.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/db/models/scan_result.py
2
+ # ScanResult Beanie document
3
+
4
+ from __future__ import annotations
5
+ from datetime import datetime
6
+ from enum import Enum
7
+ from typing import Optional
8
+ from beanie import Document
9
+ from pydantic import Field
10
+
11
+
12
+ class RiskLevel(str, Enum):
13
+ LOW = "LOW"
14
+ MEDIUM = "MEDIUM"
15
+ HIGH = "HIGH"
16
+ CRITICAL = "CRITICAL"
17
+
18
+
19
+ class ScanResultDocument(Document):
20
+ """Persisted result of a content scan through the AI pipeline."""
21
+
22
+ user_id: str
23
+ input_type: str # text | image | video
24
+ content_preview: Optional[str] = None # first 200 chars of text, or filename
25
+
26
+ # Risk
27
+ risk_level: RiskLevel = RiskLevel.LOW
28
+ risk_score: float = 0.0
29
+ categories: list[str] = Field(default_factory=list)
30
+ is_flagged: bool = False
31
+
32
+ # Decision
33
+ action: str = "ALLOWED" # ALLOWED | WARNED | BLOCKED | ESCALATED
34
+ reasoning: Optional[str] = None
35
+
36
+ # Pipeline metadata
37
+ processing_time_ms: int = 0
38
+ deep_analysis_used: bool = False
39
+
40
+ created_at: datetime = Field(default_factory=datetime.utcnow)
41
+
42
+ class Settings:
43
+ name = "scan_results"
app/db/models/user.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/db/models/user.py
2
+ # Beanie User document — mirrors the Mongoose User model from Backend/
3
+
4
+ from __future__ import annotations
5
+ from datetime import datetime
6
+ from enum import Enum
7
+ from typing import Optional
8
+ from beanie import Document, Indexed
9
+ from pydantic import EmailStr, Field
10
+
11
+
12
+ class UserRole(str, Enum):
13
+ PARENT = "parent"
14
+ CHILD = "child"
15
+ ADMIN = "admin"
16
+
17
+
18
+ class UserDocument(Document):
19
+ """
20
+ User document stored in MongoDB.
21
+ Supports Parent → Child relationship for monitoring.
22
+ """
23
+
24
+ email: Optional[EmailStr] = None
25
+ username: Indexed(str, unique=True) # type: ignore[valid-type]
26
+ password_hash: str # bcrypt hash — never returned in API responses
27
+ role: UserRole = UserRole.PARENT
28
+ first_name: str
29
+ last_name: str
30
+ phone: Optional[str] = None
31
+ date_of_birth: Optional[datetime] = None
32
+
33
+ # Parent-Child relationship
34
+ parent_id: Optional[str] = None # set for child accounts
35
+ children: list[str] = Field(default_factory=list) # set for parent accounts
36
+
37
+ # Consent & Privacy
38
+ consent_given: bool = False
39
+ consent_date: Optional[datetime] = None
40
+ parental_consent: Optional[bool] = None # for child accounts
41
+
42
+ # Account status
43
+ is_active: bool = True
44
+ is_verified: bool = False
45
+ last_login_at: Optional[datetime] = None
46
+
47
+ # Security — refresh tokens stored for rotation/revocation
48
+ refresh_tokens: list[str] = Field(default_factory=list)
49
+ password_changed_at: Optional[datetime] = None
50
+
51
+ created_at: datetime = Field(default_factory=datetime.utcnow)
52
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
53
+
54
+ class Settings:
55
+ name = "users"
56
+ # Atlas already has the correct indexes from the Mongoose model.
57
+ # Beanie will use the field-level Indexed() annotations and won't
58
+ # try to re-create conflicting ones when allow_index_dropping=False.
59
+
60
+ def to_public(self) -> dict:
61
+ """Return safe user dict without sensitive fields."""
62
+ return {
63
+ "id": str(self.id),
64
+ "email": self.email,
65
+ "username": self.username,
66
+ "role": self.role.value,
67
+ "first_name": self.first_name,
68
+ "last_name": self.last_name,
69
+ "is_active": self.is_active,
70
+ "is_verified": self.is_verified,
71
+ "parent_id": self.parent_id,
72
+ "children": self.children,
73
+ "created_at": self.created_at.isoformat(),
74
+ }
app/dependencies.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/dependencies.py
2
+ # FastAPI dependency injection
3
+
4
+ from app.models.model_registry import model_registry
5
+ from app.services.redis_service import redis_service
6
+ from app.services.mongo_service import mongo_service
7
+
8
+
9
+ async def require_models():
10
+ """Dependency: ensure models are loaded."""
11
+ status = model_registry.get_status()
12
+ if not status.get("text_model") and not status.get("image_model"):
13
+ from fastapi import HTTPException
14
+ raise HTTPException(
15
+ status_code=503,
16
+ detail="AI models not loaded. Server is starting up.",
17
+ )
18
+ return model_registry
19
+
20
+
21
+ async def require_mongo():
22
+ """Dependency: ensure MongoDB is connected."""
23
+ if not mongo_service.is_connected:
24
+ from fastapi import HTTPException
25
+ raise HTTPException(
26
+ status_code=503,
27
+ detail="MongoDB not available.",
28
+ )
29
+ return mongo_service
app/main.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/main.py
2
+ # FastAPI application factory — entry point for the Hubble AI Engine
3
+
4
+ from contextlib import asynccontextmanager
5
+ from fastapi import FastAPI
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+
9
+ from app.config import get_settings
10
+ from app.observability.logging import setup_logging, get_logger
11
+ from app.observability.langsmith import setup_langsmith
12
+ from app.models.model_registry import model_registry
13
+ from app.services.redis_service import redis_service
14
+ from app.services.mongo_service import mongo_service
15
+ from app.services.gemini_service import gemini_service
16
+ from app.pipeline.workflow import get_workflow
17
+ from app.api.router import api_router
18
+ from app.db.connection import connect_db, close_db
19
+
20
+
21
+ @asynccontextmanager
22
+ async def lifespan(app: FastAPI):
23
+ """
24
+ Application lifespan manager.
25
+
26
+ Startup:
27
+ 1. Configure logging
28
+ 2. Setup LangSmith tracing
29
+ 3. Connect MongoDB & Redis
30
+ 4. Initialize Gemini service
31
+ 5. Load all ML models
32
+ 6. Compile LangGraph workflow
33
+
34
+ Shutdown:
35
+ 1. Disconnect MongoDB & Redis
36
+ """
37
+ logger = get_logger(__name__)
38
+ settings = get_settings()
39
+
40
+ # ── Startup ──
41
+ logger.info("=" * 60)
42
+ logger.info("[STARTUP] HUBBLE AI ENGINE — Starting up...")
43
+ logger.info("=" * 60)
44
+
45
+ # 1. LangSmith
46
+ langsmith_ok = setup_langsmith()
47
+ logger.info("langsmith", enabled=langsmith_ok)
48
+
49
+ # 2. MongoDB (legacy motor service + Beanie ODM)
50
+ logger.info("connecting_mongodb")
51
+ await mongo_service.connect()
52
+ await connect_db() # initializes Beanie document models
53
+
54
+ # 3. Redis
55
+ logger.info("connecting_redis")
56
+ await redis_service.connect()
57
+
58
+ # 4. Gemini
59
+ logger.info("initializing_gemini")
60
+ gemini_service.initialize()
61
+
62
+ # 5. ML Models
63
+ logger.info("loading_models")
64
+ model_status = await model_registry.load_all()
65
+ logger.info("models_loaded", status=model_status)
66
+
67
+ # 6. LangGraph Workflow
68
+ logger.info("compiling_workflow")
69
+ get_workflow()
70
+
71
+ logger.info("=" * 60)
72
+ logger.info("[READY] HUBBLE AI ENGINE — Ready!")
73
+ logger.info(f" Environment: {settings.env}")
74
+ logger.info(f" Port: {settings.port}")
75
+ logger.info(f" Docs: http://localhost:{settings.port}/docs")
76
+ logger.info("=" * 60)
77
+
78
+ yield # ── Application runs here ──
79
+
80
+ # ── Shutdown ──
81
+ logger.info("[SHUTDOWN] HUBBLE — Shutting down...")
82
+ await redis_service.disconnect()
83
+ await mongo_service.disconnect()
84
+ await close_db()
85
+ logger.info("Shutdown complete")
86
+
87
+
88
+ def create_app() -> FastAPI:
89
+ """Create and configure the FastAPI application."""
90
+ settings = get_settings()
91
+
92
+ # Configure logging first
93
+ setup_logging()
94
+
95
+ app = FastAPI(
96
+ title="Hubble AI Engine — Cyberbullying Detection API",
97
+ description=(
98
+ "Production-grade content moderation pipeline with layered AI analysis. "
99
+ "Supports text, image, and video inputs with risk-based routing."
100
+ ),
101
+ version="4.0.0",
102
+ docs_url="/docs",
103
+ redoc_url="/redoc",
104
+ lifespan=lifespan,
105
+ )
106
+
107
+ # CORS — use origins from config
108
+ app.add_middleware(
109
+ CORSMiddleware,
110
+ allow_origins=settings.cors_origins_list if settings.is_production else ["*"],
111
+ allow_credentials=True,
112
+ allow_methods=["*"],
113
+ allow_headers=["*"],
114
+ )
115
+
116
+ # Mount routes
117
+ app.include_router(api_router)
118
+
119
+ @app.get("/", include_in_schema=False)
120
+ async def root():
121
+ return JSONResponse({
122
+ "name": "Hubble Unified API",
123
+ "version": "5.0.0",
124
+ "docs": "/docs",
125
+ "health": "/health",
126
+ "endpoints": {
127
+ "auth": "POST /api/v1/auth/{register|login|refresh|logout}",
128
+ "users": "GET /api/v1/users/me",
129
+ "scan_text": "POST /api/v1/scan/text",
130
+ "scan_image": "POST /api/v1/scan/image",
131
+ "scan_history": "GET /api/v1/scan/history",
132
+ "alerts": "GET /api/v1/alerts",
133
+ "analyze_text": "POST /api/v1/analyze/text (raw, no auth)",
134
+ "analyze_image": "POST /api/v1/analyze/image (raw, no auth)",
135
+ },
136
+ })
137
+
138
+ return app
139
+
140
+
141
+ # Create the app instance
142
+ app = create_app()
143
+
144
+
145
+ if __name__ == "__main__":
146
+ import uvicorn
147
+
148
+ settings = get_settings()
149
+ uvicorn.run(
150
+ "app.main:app",
151
+ host=settings.host,
152
+ port=settings.port,
153
+ reload=not settings.is_production,
154
+ )
app/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/models/__init__.py
2
+ """Model loading, ONNX optimization, and inference."""
app/models/clip_model.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/models/clip_model.py
2
+ # CLIP model for multimodal text-image alignment (deep analysis only)
3
+
4
+ from PIL import Image
5
+ import numpy as np
6
+ from app.config import get_settings
7
+ from app.observability.logging import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ class CLIPModel:
13
+ """
14
+ CLIP (Contrastive Language-Image Pre-Training) model.
15
+
16
+ Used in the deep analysis path to compute semantic alignment
17
+ between text descriptions and image content. This helps detect
18
+ subtle multimodal threats (e.g., threatening text overlaid on images).
19
+ """
20
+
21
+ def __init__(self):
22
+ self.settings = get_settings()
23
+ self.model = None
24
+ self.preprocess = None
25
+ self.tokenizer = None
26
+ self._loaded = False
27
+ self.device = None
28
+
29
+ def load(self) -> None:
30
+ """Load the CLIP model and preprocessor."""
31
+ import torch
32
+ try:
33
+ import open_clip
34
+
35
+ model_name = self.settings.clip_model_name
36
+ cache_dir = self.settings.model_cache_path / "clip"
37
+ cache_dir.mkdir(parents=True, exist_ok=True)
38
+
39
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ logger.info("loading_clip_model", model=model_name)
42
+
43
+ # Use OpenCLIP for flexibility
44
+ self.model, _, self.preprocess = open_clip.create_model_and_transforms(
45
+ "ViT-B-32",
46
+ pretrained="laion2b_s34b_b79k",
47
+ )
48
+ self.model = self.model.to(self.device)
49
+ self.model.eval()
50
+
51
+ self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
52
+
53
+ self._loaded = True
54
+ logger.info("clip_model_loaded")
55
+
56
+ except ImportError:
57
+ logger.warning("clip_not_available", reason="open_clip not installed")
58
+ self._loaded = False
59
+ except Exception as e:
60
+ logger.error("clip_load_failed", error=str(e))
61
+ self._loaded = False
62
+
63
+ def compute_similarity(self, image: Image.Image, texts: list[str]) -> dict:
64
+ """
65
+ Compute cosine similarity between an image and a list of text descriptions.
66
+
67
+ Args:
68
+ image: PIL Image.
69
+ texts: List of text descriptions to compare against.
70
+
71
+ Returns:
72
+ Dict with similarities, best_match, and best_score.
73
+ """
74
+ if not self._loaded:
75
+ return {"error": "CLIP model not loaded", "similarities": []}
76
+
77
+ import torch
78
+
79
+ # Preprocess image
80
+ image_input = self.preprocess(image).unsqueeze(0).to(self.device)
81
+
82
+ # Tokenize texts
83
+ text_tokens = self.tokenizer(texts).to(self.device)
84
+
85
+ with torch.no_grad():
86
+ image_features = self.model.encode_image(image_input)
87
+ text_features = self.model.encode_text(text_tokens)
88
+
89
+ # Normalize
90
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
91
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
92
+
93
+ # Cosine similarity
94
+ similarities = (image_features @ text_features.T).squeeze(0).cpu().numpy()
95
+
96
+ sim_list = similarities.tolist()
97
+ best_idx = int(np.argmax(sim_list))
98
+
99
+ return {
100
+ "similarities": dict(zip(texts, sim_list)),
101
+ "best_match": texts[best_idx],
102
+ "best_score": sim_list[best_idx],
103
+ }
104
+
105
+ def align_content(self, image: Image.Image, context_text: str | None = None) -> dict:
106
+ """
107
+ Analyze image alignment with harmful content categories.
108
+
109
+ Args:
110
+ image: Image to analyze.
111
+ context_text: Optional surrounding text context.
112
+
113
+ Returns:
114
+ Dict with category alignment scores.
115
+ """
116
+ harmful_descriptions = [
117
+ "a photo containing violence, fighting, or physical harm",
118
+ "a photo containing nudity or sexual content",
119
+ "a photo containing self-harm or suicide imagery",
120
+ "a photo containing hate symbols or extremist content",
121
+ "a photo containing drugs or substance abuse",
122
+ "a safe and appropriate photo for children",
123
+ ]
124
+
125
+ result = self.compute_similarity(image, harmful_descriptions)
126
+
127
+ if "error" in result:
128
+ return result
129
+
130
+ # Also check text-image alignment if context provided
131
+ text_alignment = None
132
+ if context_text:
133
+ text_result = self.compute_similarity(image, [context_text, "unrelated content"])
134
+ text_alignment = text_result["similarities"].get(context_text, 0.0)
135
+
136
+ return {
137
+ "category_scores": result["similarities"],
138
+ "most_aligned": result["best_match"],
139
+ "alignment_score": result["best_score"],
140
+ "text_image_alignment": text_alignment,
141
+ }
142
+
143
+ @property
144
+ def is_loaded(self) -> bool:
145
+ return self._loaded
app/models/image_model.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/models/image_model.py
2
+ # EfficientNet-based image classification model with ONNX optimization
3
+
4
+ from pathlib import Path
5
+ import numpy as np
6
+ from PIL import Image
7
+ from app.config import get_settings
8
+ from app.observability.logging import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
+ class ImageClassificationModel:
14
+ """
15
+ Image content classifier using EfficientNet.
16
+
17
+ Detects violence, NSFW content, and other harmful imagery.
18
+ Supports ONNX (fast) and PyTorch (fallback) inference.
19
+ """
20
+
21
+ LABELS = ["safe", "violence", "nsfw", "self_harm", "hate_symbol"]
22
+
23
+ def __init__(self):
24
+ self.settings = get_settings()
25
+ self.processor = None
26
+ self.onnx_session = None
27
+ self.pt_model = None
28
+ self.device = None
29
+ self._loaded = False
30
+ self._num_labels = len(self.LABELS)
31
+
32
+ def load(self) -> None:
33
+ """Load the image processor and model."""
34
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
35
+
36
+ model_name = self.settings.image_model_name
37
+ cache_dir = self.settings.model_cache_path / "efficientnet"
38
+ onnx_path = cache_dir / "image_classifier.onnx"
39
+
40
+ logger.info("loading_image_model", model=model_name)
41
+
42
+ # Load image processor
43
+ try:
44
+ self.processor = AutoImageProcessor.from_pretrained(
45
+ model_name, cache_dir=cache_dir
46
+ )
47
+ except Exception:
48
+ # Fallback: use a generic processor
49
+ from transformers import AutoImageProcessor
50
+ self.processor = AutoImageProcessor.from_pretrained(
51
+ "google/efficientnet-b0", cache_dir=cache_dir
52
+ )
53
+
54
+ if self.settings.onnx_enabled and onnx_path.exists():
55
+ from app.models.onnx_utils import load_onnx_session
56
+ self.onnx_session = load_onnx_session(onnx_path)
57
+ logger.info("image_model_loaded", backend="onnx")
58
+ else:
59
+ self._load_pytorch(model_name, cache_dir)
60
+ if self.settings.onnx_enabled:
61
+ try:
62
+ self._export_onnx(onnx_path)
63
+ from app.models.onnx_utils import load_onnx_session
64
+ self.onnx_session = load_onnx_session(onnx_path)
65
+ self.pt_model = None
66
+ logger.info("image_model_loaded", backend="onnx", note="exported")
67
+ except Exception as e:
68
+ logger.warning("onnx_export_failed", error=str(e), fallback="pytorch")
69
+ else:
70
+ logger.info("image_model_loaded", backend="pytorch")
71
+
72
+ self._loaded = True
73
+
74
+ def _load_pytorch(self, model_name: str, cache_dir: Path) -> None:
75
+ """Load PyTorch model."""
76
+ import torch
77
+ from transformers import AutoModelForImageClassification
78
+
79
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
+ try:
81
+ self.pt_model = AutoModelForImageClassification.from_pretrained(
82
+ model_name, cache_dir=cache_dir
83
+ )
84
+ except Exception:
85
+ # If the model doesn't exist as a pretrained classifier, load base EfficientNet
86
+ self.pt_model = AutoModelForImageClassification.from_pretrained(
87
+ "google/efficientnet-b0", cache_dir=cache_dir
88
+ )
89
+ self.pt_model.to(self.device)
90
+ self.pt_model.eval()
91
+
92
+ # Update labels from model config if available
93
+ if hasattr(self.pt_model.config, "id2label"):
94
+ model_labels = list(self.pt_model.config.id2label.values())
95
+ if model_labels:
96
+ self._num_labels = len(model_labels)
97
+
98
+ def _export_onnx(self, onnx_path: Path) -> None:
99
+ """Export to ONNX."""
100
+ import torch
101
+ from app.models.onnx_utils import export_to_onnx
102
+
103
+ dummy_input = torch.randn(1, 3, 224, 224).to(self.device)
104
+ export_to_onnx(
105
+ model=self.pt_model,
106
+ sample_input={"pixel_values": dummy_input},
107
+ output_path=onnx_path,
108
+ input_names=["pixel_values"],
109
+ output_names=["logits"],
110
+ )
111
+
112
+ def predict(self, image: Image.Image) -> dict:
113
+ """
114
+ Classify an image for harmful content.
115
+
116
+ Args:
117
+ image: PIL Image (RGB).
118
+
119
+ Returns:
120
+ Dict with labels, scores, is_harmful, max_score, max_label.
121
+ """
122
+ if not self._loaded:
123
+ raise RuntimeError("Image model not loaded. Call load() first.")
124
+
125
+ # Preprocess with the model's processor
126
+ inputs = self.processor(images=image, return_tensors="np" if self.onnx_session else "pt")
127
+
128
+ if self.onnx_session:
129
+ return self._predict_onnx(inputs)
130
+ else:
131
+ return self._predict_pytorch(inputs)
132
+
133
+ def _predict_onnx(self, inputs) -> dict:
134
+ """ONNX inference."""
135
+ from app.models.onnx_utils import onnx_inference
136
+
137
+ pixel_values = inputs["pixel_values"].astype(np.float32)
138
+ outputs = onnx_inference(self.onnx_session, {"pixel_values": pixel_values})
139
+ logits = outputs[0][0]
140
+ return self._format_output(logits)
141
+
142
+ def _predict_pytorch(self, inputs) -> dict:
143
+ """PyTorch inference."""
144
+ import torch
145
+
146
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
147
+ with torch.no_grad():
148
+ outputs = self.pt_model(**inputs)
149
+ logits = outputs.logits[0].cpu().numpy()
150
+ return self._format_output(logits)
151
+
152
+ def _format_output(self, logits: np.ndarray) -> dict:
153
+ """Convert logits to prediction dict."""
154
+ # Softmax for single-label classification
155
+ exp_logits = np.exp(logits - np.max(logits))
156
+ scores = (exp_logits / exp_logits.sum()).tolist()
157
+
158
+ # Map to our labels (or use model's own labels)
159
+ if self.pt_model and hasattr(self.pt_model.config, "id2label"):
160
+ labels = [self.pt_model.config.id2label.get(i, f"class_{i}") for i in range(len(scores))]
161
+ else:
162
+ labels = [f"class_{i}" for i in range(len(scores))]
163
+
164
+ max_idx = int(np.argmax(scores))
165
+
166
+ # Determine if harmful (anything not classified as safe/non-violent)
167
+ safe_keywords = {"safe", "non-violence", "non_violence", "normal", "neutral"}
168
+ is_harmful = labels[max_idx].lower().replace("-", "_").replace(" ", "_") not in safe_keywords
169
+
170
+ return {
171
+ "labels": labels,
172
+ "scores": scores,
173
+ "is_harmful": is_harmful,
174
+ "max_score": scores[max_idx],
175
+ "max_label": labels[max_idx],
176
+ }
177
+
178
+ @property
179
+ def is_loaded(self) -> bool:
180
+ return self._loaded
app/models/model_registry.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/models/model_registry.py
2
+ # Singleton model registry — loads and manages all ML models
3
+
4
+ from app.models.text_model import TextToxicityModel
5
+ from app.models.image_model import ImageClassificationModel
6
+ from app.models.clip_model import CLIPModel
7
+ from app.observability.logging import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ class ModelRegistry:
13
+ """
14
+ Central registry for all ML models.
15
+
16
+ Provides lazy-loading and lifecycle management.
17
+ Models are loaded once and reused across requests.
18
+ """
19
+
20
+ def __init__(self):
21
+ self._text_model: TextToxicityModel | None = None
22
+ self._image_model: ImageClassificationModel | None = None
23
+ self._clip_model: CLIPModel | None = None
24
+
25
+ async def load_all(self) -> dict[str, bool]:
26
+ """
27
+ Load all models. Called during application startup.
28
+
29
+ Returns:
30
+ Dict of model name → loaded status.
31
+ """
32
+ results = {}
33
+
34
+ # Text model (required)
35
+ logger.info("registry_loading", model="text_toxicity")
36
+ try:
37
+ self._text_model = TextToxicityModel()
38
+ self._text_model.load()
39
+ results["text"] = True
40
+ except Exception as e:
41
+ logger.error("text_model_load_failed", error=str(e))
42
+ results["text"] = False
43
+
44
+ # Image model (required)
45
+ logger.info("registry_loading", model="image_classifier")
46
+ try:
47
+ self._image_model = ImageClassificationModel()
48
+ self._image_model.load()
49
+ results["image"] = True
50
+ except Exception as e:
51
+ logger.error("image_model_load_failed", error=str(e))
52
+ results["image"] = False
53
+
54
+ # CLIP model (optional — only for deep analysis)
55
+ logger.info("registry_loading", model="clip")
56
+ try:
57
+ self._clip_model = CLIPModel()
58
+ self._clip_model.load()
59
+ results["clip"] = self._clip_model.is_loaded
60
+ except Exception as e:
61
+ logger.warning("clip_model_load_failed", error=str(e))
62
+ results["clip"] = False
63
+
64
+ logger.info("registry_loaded", results=results)
65
+ return results
66
+
67
+ @property
68
+ def text_model(self) -> TextToxicityModel:
69
+ if self._text_model is None or not self._text_model.is_loaded:
70
+ raise RuntimeError("Text model not available")
71
+ return self._text_model
72
+
73
+ @property
74
+ def image_model(self) -> ImageClassificationModel:
75
+ if self._image_model is None or not self._image_model.is_loaded:
76
+ raise RuntimeError("Image model not available")
77
+ return self._image_model
78
+
79
+ @property
80
+ def clip_model(self) -> CLIPModel:
81
+ if self._clip_model is None:
82
+ raise RuntimeError("CLIP model not available")
83
+ return self._clip_model
84
+
85
+ @property
86
+ def clip_available(self) -> bool:
87
+ return self._clip_model is not None and self._clip_model.is_loaded
88
+
89
+ def get_status(self) -> dict:
90
+ """Get health status of all models."""
91
+ return {
92
+ "text_model": self._text_model.is_loaded if self._text_model else False,
93
+ "image_model": self._image_model.is_loaded if self._image_model else False,
94
+ "clip_model": self._clip_model.is_loaded if self._clip_model else False,
95
+ }
96
+
97
+
98
+ # Global singleton
99
+ model_registry = ModelRegistry()
app/models/onnx_utils.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/models/onnx_utils.py
2
+ # ONNX export and inference utilities
3
+
4
+ from pathlib import Path
5
+ import numpy as np
6
+ from app.observability.logging import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
+ def export_to_onnx(
12
+ model,
13
+ sample_input: dict,
14
+ output_path: Path,
15
+ input_names: list[str] | None = None,
16
+ output_names: list[str] | None = None,
17
+ dynamic_axes: dict | None = None,
18
+ opset_version: int = 14,
19
+ ) -> Path:
20
+ """
21
+ Export a PyTorch model to ONNX format.
22
+
23
+ Args:
24
+ model: PyTorch model (eval mode).
25
+ sample_input: Dict of tensor inputs for tracing.
26
+ output_path: Where to save the .onnx file.
27
+ input_names: Names for input tensors.
28
+ output_names: Names for output tensors.
29
+ dynamic_axes: Dynamic axes specification.
30
+ opset_version: ONNX opset version.
31
+
32
+ Returns:
33
+ Path to the exported ONNX model.
34
+ """
35
+ import torch
36
+
37
+ output_path = Path(output_path)
38
+ output_path.parent.mkdir(parents=True, exist_ok=True)
39
+
40
+ if input_names is None:
41
+ input_names = list(sample_input.keys())
42
+ if output_names is None:
43
+ output_names = ["logits"]
44
+ if dynamic_axes is None:
45
+ dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}
46
+
47
+ # Prepare ordered tuple of inputs
48
+ input_tuple = tuple(sample_input[name] for name in input_names)
49
+
50
+ model.eval()
51
+ with torch.no_grad():
52
+ torch.onnx.export(
53
+ model,
54
+ input_tuple,
55
+ str(output_path),
56
+ input_names=input_names,
57
+ output_names=output_names,
58
+ dynamic_axes=dynamic_axes,
59
+ opset_version=opset_version,
60
+ do_constant_folding=True,
61
+ )
62
+
63
+ logger.info("onnx_export_complete", path=str(output_path), size_mb=round(output_path.stat().st_size / 1e6, 1))
64
+ return output_path
65
+
66
+
67
+ def load_onnx_session(model_path: Path, providers: list[str] | None = None):
68
+ """
69
+ Load an ONNX model as an InferenceSession.
70
+
71
+ Args:
72
+ model_path: Path to .onnx file.
73
+ providers: ONNX Runtime execution providers (defaults to CPU).
74
+
75
+ Returns:
76
+ ort.InferenceSession instance.
77
+ """
78
+ import onnxruntime as ort
79
+
80
+ if providers is None:
81
+ available = ort.get_available_providers()
82
+ # Prefer CUDA if available, else CPU
83
+ if "CUDAExecutionProvider" in available:
84
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
85
+ else:
86
+ providers = ["CPUExecutionProvider"]
87
+
88
+ session = ort.InferenceSession(str(model_path), providers=providers)
89
+ logger.info(
90
+ "onnx_session_loaded",
91
+ path=str(model_path),
92
+ providers=providers,
93
+ )
94
+ return session
95
+
96
+
97
+ def onnx_inference(session, inputs: dict[str, np.ndarray]) -> list[np.ndarray]:
98
+ """
99
+ Run inference on an ONNX session.
100
+
101
+ Args:
102
+ session: ONNX InferenceSession.
103
+ inputs: Dict mapping input names to numpy arrays.
104
+
105
+ Returns:
106
+ List of output numpy arrays.
107
+ """
108
+ # Ensure proper dtypes
109
+ feed = {}
110
+ for inp in session.get_inputs():
111
+ if inp.name in inputs:
112
+ arr = inputs[inp.name]
113
+ # Match expected dtype
114
+ if "int" in inp.type:
115
+ arr = arr.astype(np.int64)
116
+ else:
117
+ arr = arr.astype(np.float32)
118
+ feed[inp.name] = arr
119
+
120
+ return session.run(None, feed)
app/models/text_model.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/models/text_model.py
2
+ # RoBERTa-based text toxicity model with ONNX optimization
3
+
4
+ from pathlib import Path
5
+ import numpy as np
6
+ from app.config import get_settings
7
+ from app.observability.logging import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ class TextToxicityModel:
13
+ """
14
+ Text toxicity classifier using a RoBERTa-based model.
15
+
16
+ Supports both ONNX (fast) and PyTorch (fallback) inference.
17
+ Model: unitary/toxic-bert (multi-label toxicity detection).
18
+
19
+ Labels: toxic, severe_toxic, obscene, threat, insult, identity_hate
20
+ """
21
+
22
+ LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
23
+
24
+ def __init__(self):
25
+ self.settings = get_settings()
26
+ self.tokenizer = None
27
+ self.onnx_session = None
28
+ self.pt_model = None
29
+ self.device = None
30
+ self._loaded = False
31
+
32
+ def load(self) -> None:
33
+ """Load the tokenizer and model (ONNX preferred, PyTorch fallback)."""
34
+ from transformers import AutoTokenizer
35
+
36
+ model_name = self.settings.text_model_name
37
+ cache_dir = self.settings.model_cache_path / "roberta"
38
+ onnx_path = cache_dir / "text_toxicity.onnx"
39
+
40
+ logger.info("loading_text_model", model=model_name)
41
+
42
+ # Load tokenizer
43
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
44
+
45
+ if self.settings.onnx_enabled and onnx_path.exists():
46
+ # Use existing ONNX model
47
+ from app.models.onnx_utils import load_onnx_session
48
+ self.onnx_session = load_onnx_session(onnx_path)
49
+ logger.info("text_model_loaded", backend="onnx")
50
+ elif self.settings.onnx_enabled:
51
+ # Load PyTorch, export to ONNX, then use ONNX
52
+ self._load_pytorch(model_name, cache_dir)
53
+ self._export_onnx(onnx_path)
54
+ # Switch to ONNX session
55
+ from app.models.onnx_utils import load_onnx_session
56
+ self.onnx_session = load_onnx_session(onnx_path)
57
+ self.pt_model = None # Free PyTorch memory
58
+ logger.info("text_model_loaded", backend="onnx", note="exported_from_pytorch")
59
+ else:
60
+ # PyTorch only
61
+ self._load_pytorch(model_name, cache_dir)
62
+ logger.info("text_model_loaded", backend="pytorch")
63
+
64
+ self._loaded = True
65
+
66
+ def _load_pytorch(self, model_name: str, cache_dir: Path) -> None:
67
+ """Load the PyTorch model."""
68
+ import torch
69
+ from transformers import AutoModelForSequenceClassification
70
+
71
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+ self.pt_model = AutoModelForSequenceClassification.from_pretrained(
73
+ model_name, cache_dir=cache_dir
74
+ )
75
+ self.pt_model.to(self.device)
76
+ self.pt_model.eval()
77
+
78
+ def _export_onnx(self, onnx_path: Path) -> None:
79
+ """Export current PyTorch model to ONNX."""
80
+ import torch
81
+ from app.models.onnx_utils import export_to_onnx
82
+
83
+ sample = self.tokenizer(
84
+ "test input for export",
85
+ return_tensors="pt",
86
+ padding="max_length",
87
+ truncation=True,
88
+ max_length=128,
89
+ )
90
+ sample = {k: v.to(self.device) for k, v in sample.items()}
91
+
92
+ export_to_onnx(
93
+ model=self.pt_model,
94
+ sample_input=sample,
95
+ output_path=onnx_path,
96
+ input_names=["input_ids", "attention_mask"],
97
+ output_names=["logits"],
98
+ )
99
+
100
+ def predict(self, text: str) -> dict:
101
+ """
102
+ Predict toxicity scores for input text.
103
+
104
+ Args:
105
+ text: Input text to classify.
106
+
107
+ Returns:
108
+ Dict with:
109
+ - labels: list of label names
110
+ - scores: list of per-label probabilities
111
+ - is_toxic: bool (any label > 0.5)
112
+ - max_score: float (highest toxicity probability)
113
+ - max_label: str (label with highest probability)
114
+ """
115
+ if not self._loaded:
116
+ raise RuntimeError("Text model not loaded. Call load() first.")
117
+
118
+ # Tokenize
119
+ encoding = self.tokenizer(
120
+ text,
121
+ return_tensors="np" if self.onnx_session else "pt",
122
+ padding="max_length",
123
+ truncation=True,
124
+ max_length=128,
125
+ )
126
+
127
+ if self.onnx_session:
128
+ return self._predict_onnx(encoding)
129
+ else:
130
+ return self._predict_pytorch(encoding)
131
+
132
+ def _predict_onnx(self, encoding: dict) -> dict:
133
+ """Run ONNX inference."""
134
+ from app.models.onnx_utils import onnx_inference
135
+
136
+ inputs = {
137
+ "input_ids": encoding["input_ids"].astype(np.int64),
138
+ "attention_mask": encoding["attention_mask"].astype(np.int64),
139
+ }
140
+ outputs = onnx_inference(self.onnx_session, inputs)
141
+ logits = outputs[0][0] # (num_labels,)
142
+ return self._format_output(logits)
143
+
144
+ def _predict_pytorch(self, encoding: dict) -> dict:
145
+ """Run PyTorch inference."""
146
+ import torch
147
+
148
+ inputs = {k: v.to(self.device) for k, v in encoding.items()}
149
+ with torch.no_grad():
150
+ outputs = self.pt_model(**inputs)
151
+ logits = outputs.logits[0].cpu().numpy()
152
+ return self._format_output(logits)
153
+
154
+ def _format_output(self, logits: np.ndarray) -> dict:
155
+ """Convert raw logits to formatted prediction dict."""
156
+ # Sigmoid for multi-label classification
157
+ scores = 1 / (1 + np.exp(-logits))
158
+ scores = scores.tolist()
159
+
160
+ # Handle case where model has fewer outputs than expected labels
161
+ labels = self.LABELS[: len(scores)]
162
+
163
+ label_scores = dict(zip(labels, scores))
164
+ max_idx = int(np.argmax(scores))
165
+ is_toxic = any(s > 0.5 for s in scores)
166
+
167
+ return {
168
+ "labels": labels,
169
+ "scores": scores,
170
+ "label_scores": label_scores,
171
+ "is_toxic": is_toxic,
172
+ "max_score": scores[max_idx],
173
+ "max_label": labels[max_idx],
174
+ }
175
+
176
+ @property
177
+ def is_loaded(self) -> bool:
178
+ return self._loaded
app/observability/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/observability/__init__.py
2
+ """Observability: structured logging and LangSmith tracing."""
app/observability/langsmith.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/observability/langsmith.py
2
+ # LangSmith tracing integration for pipeline observability
3
+
4
+ import os
5
+ from app.config import get_settings
6
+ from app.observability.logging import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
+ def setup_langsmith() -> bool:
12
+ """
13
+ Configure LangSmith tracing via environment variables.
14
+ LangChain/LangGraph automatically pick up these env vars.
15
+
16
+ Returns:
17
+ True if LangSmith is configured, False otherwise.
18
+ """
19
+ settings = get_settings()
20
+
21
+ if not settings.langsmith_api_key:
22
+ logger.info("langsmith_disabled", reason="No API key provided")
23
+ return False
24
+
25
+ os.environ["LANGCHAIN_TRACING_V2"] = str(settings.langsmith_tracing_v2).lower()
26
+ os.environ["LANGCHAIN_API_KEY"] = settings.langsmith_api_key
27
+ os.environ["LANGCHAIN_PROJECT"] = settings.langsmith_project
28
+
29
+ logger.info(
30
+ "langsmith_enabled",
31
+ project=settings.langsmith_project,
32
+ )
33
+ return True
34
+
35
+
36
+ def get_trace_config(
37
+ run_name: str,
38
+ input_type: str,
39
+ metadata: dict | None = None,
40
+ ) -> dict:
41
+ """
42
+ Build configuration dict for LangGraph invoke calls.
43
+ This attaches metadata and tags to the LangSmith trace.
44
+
45
+ Args:
46
+ run_name: Human-readable name for this trace run.
47
+ input_type: The content type being analyzed (text/image/video).
48
+ metadata: Additional key-value metadata.
49
+
50
+ Returns:
51
+ Config dict to pass to workflow.invoke().
52
+ """
53
+ tags = [f"input:{input_type}", "hubble-moderation"]
54
+ config = {
55
+ "run_name": run_name,
56
+ "tags": tags,
57
+ "metadata": metadata or {},
58
+ }
59
+ return {"configurable": config}
app/observability/logging.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/observability/logging.py
2
+ # Structured logging with structlog
3
+
4
+ import sys
5
+ import logging
6
+ import structlog
7
+ from app.config import get_settings
8
+
9
+
10
+ def setup_logging() -> None:
11
+ """Configure structured logging for the application."""
12
+ settings = get_settings()
13
+
14
+ # Configure structlog
15
+ structlog.configure(
16
+ processors=[
17
+ structlog.contextvars.merge_contextvars,
18
+ structlog.processors.add_log_level,
19
+ structlog.processors.StackInfoRenderer(),
20
+ structlog.dev.set_exc_info,
21
+ structlog.processors.TimeStamper(fmt="iso"),
22
+ structlog.processors.JSONRenderer()
23
+ if settings.is_production
24
+ else structlog.dev.ConsoleRenderer(colors=True),
25
+ ],
26
+ wrapper_class=structlog.make_filtering_bound_logger(
27
+ logging.getLevelName(settings.log_level)
28
+ ),
29
+ context_class=dict,
30
+ logger_factory=structlog.PrintLoggerFactory(),
31
+ cache_logger_on_first_use=True,
32
+ )
33
+
34
+ # Silence noisy third-party loggers
35
+ for logger_name in ["uvicorn.access", "httpx", "httpcore"]:
36
+ logging.getLogger(logger_name).setLevel(logging.WARNING)
37
+
38
+
39
+ def get_logger(name: str | None = None) -> structlog.BoundLogger:
40
+ """Get a structured logger instance."""
41
+ return structlog.get_logger(name or __name__)
app/pipeline/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/pipeline/__init__.py
2
+ """Core moderation pipeline: preprocess → filter → score → route → decide."""
app/pipeline/decision_engine.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline/decision_engine.py
2
+ # Rule-based decision engine: final moderation verdict
3
+
4
+ from dataclasses import dataclass, field
5
+ from app.pipeline.risk_scorer import RiskScore
6
+ from app.pipeline.deep_analyzer import DeepAnalysisResult
7
+ from app.observability.logging import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ @dataclass
13
+ class Decision:
14
+ """Final moderation decision."""
15
+ action: str # ALLOWED, WARNING, BLOCKED
16
+ reason: str
17
+ severity: str # low, medium, high, critical
18
+ categories: list[str] = field(default_factory=list)
19
+ should_alert_parent: bool = False
20
+ should_log: bool = True
21
+ escalation_notes: str | None = None
22
+
23
+
24
+ class DecisionEngine:
25
+ """
26
+ Rule-based final decision engine.
27
+
28
+ Takes risk score + optional deep analysis and produces a final verdict.
29
+
30
+ Rules:
31
+ - LOW risk → ALLOWED (no action)
32
+ - MEDIUM risk → WARNING (increment user warning count, log)
33
+ - HIGH risk + deep_confirmed → BLOCKED (alert parent, log, escalate if critical)
34
+ - HIGH risk + deep_not_confirmed → WARNING (false positive recovery)
35
+ - Repeat offender with MEDIUM → BLOCKED (escalate)
36
+ """
37
+
38
+ def decide(
39
+ self,
40
+ risk: RiskScore,
41
+ deep_result: DeepAnalysisResult | None = None,
42
+ user_history: dict | None = None,
43
+ ) -> Decision:
44
+ """
45
+ Produce final moderation decision.
46
+
47
+ Args:
48
+ risk: Composite risk score from the scoring engine.
49
+ deep_result: Optional deep analysis result (only for HIGH risk).
50
+ user_history: Optional user moderation history.
51
+
52
+ Returns:
53
+ Decision with action, reason, and metadata.
54
+ """
55
+
56
+ # === LOW RISK ===
57
+ if risk.level == "LOW":
58
+ decision = Decision(
59
+ action="ALLOWED",
60
+ reason="Content passed all safety checks",
61
+ severity="low",
62
+ should_log=False, # Don't clutter logs with safe content
63
+ )
64
+
65
+ # === MEDIUM RISK ===
66
+ elif risk.level == "MEDIUM":
67
+ # Check for repeat offender escalation
68
+ if risk.repeat_offender:
69
+ decision = Decision(
70
+ action="BLOCKED",
71
+ reason="Repeat offender with moderately harmful content — escalated to block",
72
+ severity="high",
73
+ should_alert_parent=True,
74
+ escalation_notes="User has repeated violation history. Medium-risk content escalated.",
75
+ )
76
+ else:
77
+ decision = Decision(
78
+ action="WARNING",
79
+ reason=f"Content flagged as potentially harmful (risk score: {risk.score})",
80
+ severity="medium",
81
+ should_alert_parent=False,
82
+ )
83
+
84
+ # === HIGH RISK ===
85
+ elif risk.level == "HIGH":
86
+ if deep_result and deep_result.is_confirmed:
87
+ # Deep analysis confirms the threat
88
+ severity = deep_result.severity
89
+ should_escalate = severity == "critical"
90
+
91
+ decision = Decision(
92
+ action="BLOCKED",
93
+ reason=deep_result.reasoning,
94
+ severity=severity,
95
+ categories=deep_result.categories,
96
+ should_alert_parent=True,
97
+ escalation_notes=(
98
+ "CRITICAL: Immediate review required. "
99
+ f"Recommended action: {deep_result.recommended_action}"
100
+ if should_escalate
101
+ else None
102
+ ),
103
+ )
104
+ elif deep_result and not deep_result.is_confirmed:
105
+ # Deep analysis says it's a false positive
106
+ decision = Decision(
107
+ action="WARNING",
108
+ reason=(
109
+ f"Content initially flagged as high-risk (score: {risk.score}) "
110
+ f"but deep analysis did not confirm threat. "
111
+ f"Reasoning: {deep_result.reasoning}"
112
+ ),
113
+ severity="low",
114
+ should_alert_parent=False,
115
+ )
116
+ else:
117
+ # No deep analysis available — err on caution
118
+ decision = Decision(
119
+ action="BLOCKED",
120
+ reason=f"High-risk content detected (score: {risk.score}). Deep analysis unavailable.",
121
+ severity="high",
122
+ should_alert_parent=True,
123
+ escalation_notes="Deep analysis was not performed. Manual review recommended.",
124
+ )
125
+
126
+ else:
127
+ # Fallback
128
+ decision = Decision(
129
+ action="WARNING",
130
+ reason="Unclassified risk level",
131
+ severity="medium",
132
+ )
133
+
134
+ logger.info(
135
+ "decision_made",
136
+ action=decision.action,
137
+ severity=decision.severity,
138
+ alert_parent=decision.should_alert_parent,
139
+ risk_score=risk.score,
140
+ risk_level=risk.level,
141
+ )
142
+ return decision
app/pipeline/deep_analyzer.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline/deep_analyzer.py
2
+ # Deep analysis layer: CLIP + Gemini reasoning (HIGH risk only)
3
+
4
+ from dataclasses import dataclass, field
5
+ from PIL import Image
6
+ from app.models.model_registry import model_registry
7
+ from app.services.gemini_service import gemini_service
8
+ from app.pipeline.fast_filter import FilterResult
9
+ from app.utils.image_utils import image_to_base64
10
+ from app.observability.logging import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class DeepAnalysisResult:
17
+ """Output from the deep analysis stage."""
18
+ is_confirmed: bool # Does deep analysis confirm the threat?
19
+ severity: str # low, medium, high, critical
20
+ reasoning: str # Explanation from Gemini
21
+ categories: list[str] = field(default_factory=list)
22
+ recommended_action: str = "warn" # allow, warn, block, escalate
23
+ confidence: float = 0.0
24
+ clip_scores: dict = field(default_factory=dict)
25
+ gemini_raw: dict = field(default_factory=dict)
26
+
27
+
28
+ class DeepAnalyzer:
29
+ """
30
+ Deep analysis layer invoked only for HIGH-risk content.
31
+
32
+ Pipeline:
33
+ 1. CLIP multimodal alignment (if image present)
34
+ 2. Gemini reasoning via LangChain
35
+ 3. Combine signals into final assessment
36
+
37
+ This layer trades speed for accuracy — expected latency: 1-3 seconds.
38
+ """
39
+
40
+ async def analyze_text(
41
+ self,
42
+ text: str,
43
+ filter_result: FilterResult,
44
+ ) -> DeepAnalysisResult:
45
+ """
46
+ Deep analysis for flagged text content.
47
+
48
+ Args:
49
+ text: Original text content.
50
+ filter_result: Results from the fast filter stage.
51
+
52
+ Returns:
53
+ DeepAnalysisResult with Gemini reasoning.
54
+ """
55
+ logger.info("deep_analysis_text_started")
56
+
57
+ # Prepare context from fast filter
58
+ context = {
59
+ "flagged_categories": filter_result.categories,
60
+ "max_score": filter_result.max_score,
61
+ "max_label": filter_result.max_label,
62
+ "all_scores": filter_result.scores,
63
+ }
64
+
65
+ # Invoke Gemini for contextual reasoning
66
+ gemini_result = await gemini_service.analyze_text(text, context)
67
+
68
+ result = self._build_result(gemini_result)
69
+
70
+ logger.info(
71
+ "deep_analysis_text_complete",
72
+ confirmed=result.is_confirmed,
73
+ severity=result.severity,
74
+ action=result.recommended_action,
75
+ )
76
+ return result
77
+
78
+ async def analyze_image(
79
+ self,
80
+ image: Image.Image,
81
+ filter_result: FilterResult,
82
+ context_text: str | None = None,
83
+ ) -> DeepAnalysisResult:
84
+ """
85
+ Deep analysis for flagged image content.
86
+
87
+ Args:
88
+ image: PIL Image.
89
+ filter_result: Results from the fast filter.
90
+ context_text: Optional text accompanying the image.
91
+
92
+ Returns:
93
+ DeepAnalysisResult with CLIP alignment + Gemini reasoning.
94
+ """
95
+ logger.info("deep_analysis_image_started")
96
+
97
+ # Step 1: CLIP multimodal alignment
98
+ clip_scores = {}
99
+ if model_registry.clip_available:
100
+ try:
101
+ clip_result = model_registry.clip_model.align_content(image, context_text)
102
+ clip_scores = clip_result
103
+ logger.info("clip_alignment_complete", most_aligned=clip_result.get("most_aligned"))
104
+ except Exception as e:
105
+ logger.warning("clip_alignment_failed", error=str(e))
106
+
107
+ # Step 2: Gemini image reasoning
108
+ context = {
109
+ "flagged_categories": filter_result.categories,
110
+ "max_score": filter_result.max_score,
111
+ "clip_alignment": clip_scores.get("most_aligned", "unknown"),
112
+ }
113
+
114
+ image_b64 = image_to_base64(image)
115
+ gemini_result = await gemini_service.analyze_image(image_b64, context)
116
+
117
+ result = self._build_result(gemini_result, clip_scores)
118
+
119
+ logger.info(
120
+ "deep_analysis_image_complete",
121
+ confirmed=result.is_confirmed,
122
+ severity=result.severity,
123
+ )
124
+ return result
125
+
126
+ def _build_result(
127
+ self,
128
+ gemini_result: dict,
129
+ clip_scores: dict | None = None,
130
+ ) -> DeepAnalysisResult:
131
+ """Build DeepAnalysisResult from Gemini response."""
132
+ if "error" in gemini_result:
133
+ # Gemini failed — err on the side of caution
134
+ return DeepAnalysisResult(
135
+ is_confirmed=True, # Assume harmful if we can't verify
136
+ severity="medium",
137
+ reasoning=f"Deep analysis unavailable: {gemini_result['error']}. Defaulting to caution.",
138
+ recommended_action="warn",
139
+ confidence=0.3,
140
+ clip_scores=clip_scores or {},
141
+ gemini_raw=gemini_result,
142
+ )
143
+
144
+ return DeepAnalysisResult(
145
+ is_confirmed=gemini_result.get("is_confirmed", False),
146
+ severity=gemini_result.get("severity", "medium"),
147
+ reasoning=gemini_result.get("reasoning", "No reasoning provided"),
148
+ categories=gemini_result.get("categories", []),
149
+ recommended_action=gemini_result.get("recommended_action", "warn"),
150
+ confidence=gemini_result.get("confidence", 0.5),
151
+ clip_scores=clip_scores or {},
152
+ gemini_raw=gemini_result,
153
+ )
app/pipeline/fast_filter.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline/fast_filter.py
2
+ # First-pass AI classification using ONNX-optimized models
3
+
4
+ from dataclasses import dataclass, field
5
+ from PIL import Image
6
+ from app.models.model_registry import model_registry
7
+ from app.pipeline.preprocessor import ProcessedText, ProcessedImage
8
+ from app.observability.logging import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
+ @dataclass
14
+ class FilterResult:
15
+ """Output from the fast filter stage."""
16
+ input_type: str # "text", "image"
17
+ is_flagged: bool
18
+ scores: dict[str, float] = field(default_factory=dict)
19
+ max_score: float = 0.0
20
+ max_label: str = ""
21
+ categories: list[str] = field(default_factory=list)
22
+ confidence: float = 0.0
23
+
24
+
25
+ class FastFilter:
26
+ """
27
+ Fast AI filter using ONNX-optimized models.
28
+
29
+ - RoBERTa for text toxicity (multi-label)
30
+ - EfficientNet for image classification
31
+
32
+ This is the first gate in the pipeline. Designed for speed (<200ms).
33
+ """
34
+
35
+ # Toxicity threshold for flagging
36
+ TEXT_FLAG_THRESHOLD = 0.4
37
+ IMAGE_FLAG_THRESHOLD = 0.5
38
+
39
+ def filter_text(self, processed: ProcessedText) -> FilterResult:
40
+ """
41
+ Run RoBERTa text toxicity inference.
42
+
43
+ Args:
44
+ processed: Preprocessed text input.
45
+
46
+ Returns:
47
+ FilterResult with per-category toxicity scores.
48
+ """
49
+ text_model = model_registry.text_model
50
+ prediction = text_model.predict(processed.cleaned)
51
+
52
+ # Determine which categories are flagged
53
+ flagged_categories = []
54
+ label_scores = prediction.get("label_scores", {})
55
+ for label, score in label_scores.items():
56
+ if score > self.TEXT_FLAG_THRESHOLD:
57
+ flagged_categories.append(label)
58
+
59
+ is_flagged = len(flagged_categories) > 0
60
+
61
+ result = FilterResult(
62
+ input_type="text",
63
+ is_flagged=is_flagged,
64
+ scores=label_scores,
65
+ max_score=prediction["max_score"],
66
+ max_label=prediction["max_label"],
67
+ categories=flagged_categories,
68
+ confidence=prediction["max_score"],
69
+ )
70
+
71
+ logger.info(
72
+ "fast_filter_text",
73
+ flagged=is_flagged,
74
+ max_label=result.max_label,
75
+ max_score=round(result.max_score, 3),
76
+ categories=flagged_categories,
77
+ )
78
+ return result
79
+
80
+ def filter_image(self, processed: ProcessedImage) -> FilterResult:
81
+ """
82
+ Run EfficientNet image classification inference.
83
+
84
+ Args:
85
+ processed: Preprocessed image input.
86
+
87
+ Returns:
88
+ FilterResult with image classification scores.
89
+ """
90
+ image_model = model_registry.image_model
91
+ prediction = image_model.predict(processed.image)
92
+
93
+ # Map model output to categories
94
+ scores = {}
95
+ for label, score in zip(prediction["labels"], prediction["scores"]):
96
+ scores[label] = score
97
+
98
+ flagged_categories = []
99
+ for label, score in scores.items():
100
+ if score > self.IMAGE_FLAG_THRESHOLD:
101
+ # Check if this is a harmful category
102
+ safe_labels = {"safe", "non-violence", "non_violence", "normal", "neutral"}
103
+ if label.lower().replace("-", "_").replace(" ", "_") not in safe_labels:
104
+ flagged_categories.append(label)
105
+
106
+ is_flagged = prediction["is_harmful"]
107
+
108
+ result = FilterResult(
109
+ input_type="image",
110
+ is_flagged=is_flagged,
111
+ scores=scores,
112
+ max_score=prediction["max_score"],
113
+ max_label=prediction["max_label"],
114
+ categories=flagged_categories,
115
+ confidence=prediction["max_score"],
116
+ )
117
+
118
+ logger.info(
119
+ "fast_filter_image",
120
+ flagged=is_flagged,
121
+ max_label=result.max_label,
122
+ max_score=round(result.max_score, 3),
123
+ )
124
+ return result
app/pipeline/preprocessor.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline/preprocessor.py
2
+ # Input preprocessing: normalization, frame extraction, cleaning
3
+
4
+ import re
5
+ from dataclasses import dataclass, field
6
+ from PIL import Image
7
+ from app.config import get_settings
8
+ from app.observability.logging import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
+ @dataclass
14
+ class ProcessedText:
15
+ """Preprocessed text content."""
16
+ original: str
17
+ cleaned: str
18
+ word_count: int
19
+ char_count: int
20
+ language: str = "en" # placeholder for language detection
21
+
22
+
23
+ @dataclass
24
+ class ProcessedImage:
25
+ """Preprocessed image content."""
26
+ image: Image.Image
27
+ width: int
28
+ height: int
29
+ format: str = "RGB"
30
+
31
+
32
+ @dataclass
33
+ class ProcessedVideo:
34
+ """Preprocessed video — a list of extracted frames."""
35
+ frames: list[ProcessedImage] = field(default_factory=list)
36
+ frame_count: int = 0
37
+ duration_seconds: float = 0.0
38
+ metadata: dict = field(default_factory=dict)
39
+
40
+
41
+ class Preprocessor:
42
+ """
43
+ Input preprocessing for all content types.
44
+
45
+ - Text: cleaning, normalization
46
+ - Image: resize, format conversion
47
+ - Video: frame extraction + per-frame preprocessing
48
+ """
49
+
50
+ def __init__(self):
51
+ self.settings = get_settings()
52
+
53
+ def process_text(self, text: str) -> ProcessedText:
54
+ """
55
+ Clean and normalize input text.
56
+
57
+ - Strip excessive whitespace
58
+ - Remove zero-width characters
59
+ - Normalize unicode
60
+ """
61
+ import unicodedata
62
+
63
+ # Remove zero-width characters often used for obfuscation
64
+ cleaned = re.sub(r"[\u200b\u200c\u200d\ufeff]", "", text)
65
+
66
+ # Normalize unicode
67
+ cleaned = unicodedata.normalize("NFKC", cleaned)
68
+
69
+ # Collapse excessive whitespace
70
+ cleaned = re.sub(r"\s+", " ", cleaned).strip()
71
+
72
+ result = ProcessedText(
73
+ original=text,
74
+ cleaned=cleaned,
75
+ word_count=len(cleaned.split()),
76
+ char_count=len(cleaned),
77
+ )
78
+
79
+ logger.debug(
80
+ "text_preprocessed",
81
+ word_count=result.word_count,
82
+ char_count=result.char_count,
83
+ )
84
+ return result
85
+
86
+ def process_image(self, image_bytes: bytes) -> ProcessedImage:
87
+ """
88
+ Load and preprocess image from bytes.
89
+
90
+ - Convert to RGB
91
+ - Record dimensions
92
+ """
93
+ from app.utils.image_utils import load_image_from_bytes
94
+
95
+ image = load_image_from_bytes(image_bytes)
96
+ width, height = image.size
97
+
98
+ result = ProcessedImage(
99
+ image=image,
100
+ width=width,
101
+ height=height,
102
+ )
103
+
104
+ logger.debug("image_preprocessed", width=width, height=height)
105
+ return result
106
+
107
+ def process_video(self, video_bytes: bytes) -> ProcessedVideo:
108
+ """
109
+ Extract key frames from video.
110
+
111
+ Uses OpenCV to sample frames at configured intervals.
112
+ """
113
+ from app.utils.video_utils import extract_frames, get_video_metadata
114
+
115
+ metadata = get_video_metadata(video_bytes)
116
+ frames_pil = extract_frames(
117
+ video_bytes,
118
+ max_frames=self.settings.video_max_frames,
119
+ fps_sample=self.settings.video_fps_sample,
120
+ )
121
+
122
+ processed_frames = []
123
+ for frame in frames_pil:
124
+ w, h = frame.size
125
+ processed_frames.append(
126
+ ProcessedImage(image=frame, width=w, height=h)
127
+ )
128
+
129
+ result = ProcessedVideo(
130
+ frames=processed_frames,
131
+ frame_count=len(processed_frames),
132
+ duration_seconds=metadata.get("duration_seconds", 0.0),
133
+ metadata=metadata,
134
+ )
135
+
136
+ logger.debug(
137
+ "video_preprocessed",
138
+ frames_extracted=result.frame_count,
139
+ duration=result.duration_seconds,
140
+ )
141
+ return result
app/pipeline/risk_scorer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline/risk_scorer.py
2
+ # Composite risk scoring engine
3
+
4
+ from dataclasses import dataclass
5
+ from app.config import get_settings
6
+ from app.pipeline.fast_filter import FilterResult
7
+ from app.observability.logging import get_logger
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ @dataclass
13
+ class RiskScore:
14
+ """Composite risk assessment."""
15
+ score: float # 0-100
16
+ level: str # LOW, MEDIUM, HIGH
17
+ components: dict # breakdown of scoring factors
18
+ repeat_offender: bool = False
19
+
20
+
21
+ class RiskScorer:
22
+ """
23
+ Computes a composite risk score (0-100) from multiple signals.
24
+
25
+ Scoring formula:
26
+ - Base score from model confidence (weighted by category severity)
27
+ - Repeat offender boost (user history)
28
+ - Multi-category penalty (multiple harmful categories = higher risk)
29
+
30
+ Thresholds (configurable via env):
31
+ - 0-30: LOW → Allow
32
+ - 31-65: MEDIUM → Warning
33
+ - 66-100: HIGH → Deep Analysis
34
+ """
35
+
36
+ # Category severity weights (how dangerous each type is)
37
+ CATEGORY_WEIGHTS = {
38
+ # Text categories (from RoBERTa toxic-bert)
39
+ "toxic": 0.6,
40
+ "severe_toxic": 1.0,
41
+ "obscene": 0.5,
42
+ "threat": 1.0,
43
+ "insult": 0.5,
44
+ "identity_hate": 0.9,
45
+ # Image categories
46
+ "violence": 0.9,
47
+ "nsfw": 0.8,
48
+ "self_harm": 1.0,
49
+ "hate_symbol": 0.9,
50
+ # Generic fallback
51
+ "harassment": 0.7,
52
+ "bullying": 0.7,
53
+ }
54
+
55
+ # Repeat offender thresholds
56
+ REPEAT_OFFENDER_VIOLATIONS = 3
57
+ REPEAT_OFFENDER_BOOST = 15 # points added
58
+
59
+ def __init__(self):
60
+ self.settings = get_settings()
61
+
62
+ def score(
63
+ self,
64
+ filter_result: FilterResult,
65
+ user_history: dict | None = None,
66
+ ) -> RiskScore:
67
+ """
68
+ Compute composite risk score.
69
+
70
+ Args:
71
+ filter_result: Output from fast filter stage.
72
+ user_history: Optional user moderation history.
73
+
74
+ Returns:
75
+ RiskScore with level classification.
76
+ """
77
+ # 1. Base score from model confidence
78
+ base_score = self._compute_base_score(filter_result)
79
+
80
+ # 2. Multi-category penalty
81
+ multi_cat_penalty = self._multi_category_penalty(filter_result)
82
+
83
+ # 3. Repeat offender boost
84
+ repeat_boost, is_repeat = self._repeat_offender_boost(user_history)
85
+
86
+ # 4. Combine
87
+ raw_score = base_score + multi_cat_penalty + repeat_boost
88
+ final_score = min(100.0, max(0.0, raw_score))
89
+
90
+ # 5. Classify level
91
+ level = self._classify_level(final_score)
92
+
93
+ result = RiskScore(
94
+ score=round(final_score, 1),
95
+ level=level,
96
+ components={
97
+ "base_score": round(base_score, 1),
98
+ "multi_category_penalty": round(multi_cat_penalty, 1),
99
+ "repeat_offender_boost": round(repeat_boost, 1),
100
+ },
101
+ repeat_offender=is_repeat,
102
+ )
103
+
104
+ logger.info(
105
+ "risk_scored",
106
+ score=result.score,
107
+ level=result.level,
108
+ components=result.components,
109
+ repeat_offender=is_repeat,
110
+ )
111
+ return result
112
+
113
+ def _compute_base_score(self, result: FilterResult) -> float:
114
+ """
115
+ Compute base score from model predictions.
116
+
117
+ Uses weighted sum of flagged category scores.
118
+ """
119
+ if not result.is_flagged:
120
+ # Even unflagged content gets a small score based on max prediction
121
+ return result.max_score * 20 # Scale 0-1 → 0-20
122
+
123
+ # Weighted sum of flagged category scores
124
+ weighted_sum = 0.0
125
+ weight_total = 0.0
126
+
127
+ for category, score in result.scores.items():
128
+ weight = self.CATEGORY_WEIGHTS.get(category.lower(), 0.5)
129
+ weighted_sum += score * weight * 100
130
+ weight_total += weight
131
+
132
+ if weight_total > 0:
133
+ return weighted_sum / weight_total
134
+ return result.max_score * 60
135
+
136
+ def _multi_category_penalty(self, result: FilterResult) -> float:
137
+ """Add penalty when multiple harmful categories are detected."""
138
+ num_categories = len(result.categories)
139
+ if num_categories <= 1:
140
+ return 0.0
141
+ # Each additional category adds 5 points
142
+ return (num_categories - 1) * 5.0
143
+
144
+ def _repeat_offender_boost(self, user_history: dict | None) -> tuple[float, bool]:
145
+ """Boost score for users with violation history."""
146
+ if not user_history:
147
+ return 0.0, False
148
+
149
+ total_violations = user_history.get("total_violations", 0)
150
+ is_repeat = total_violations >= self.REPEAT_OFFENDER_VIOLATIONS
151
+
152
+ if is_repeat:
153
+ return self.REPEAT_OFFENDER_BOOST, True
154
+ elif total_violations > 0:
155
+ # Smaller boost for users with some history
156
+ return total_violations * 3.0, False
157
+ return 0.0, False
158
+
159
+ def _classify_level(self, score: float) -> str:
160
+ """Map numeric score to risk level."""
161
+ if score <= self.settings.risk_low_max:
162
+ return "LOW"
163
+ elif score <= self.settings.risk_medium_max:
164
+ return "MEDIUM"
165
+ else:
166
+ return "HIGH"
app/pipeline/workflow.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline/workflow.py
2
+ # LangGraph state machine — orchestrates the full moderation pipeline
3
+
4
+ from __future__ import annotations
5
+ from typing import Any, TypedDict, Literal
6
+ from dataclasses import asdict
7
+
8
+ from langgraph.graph import StateGraph, END
9
+
10
+ from app.pipeline.preprocessor import (
11
+ Preprocessor,
12
+ ProcessedText,
13
+ ProcessedImage,
14
+ ProcessedVideo,
15
+ )
16
+ from app.pipeline.fast_filter import FastFilter, FilterResult
17
+ from app.pipeline.risk_scorer import RiskScorer, RiskScore
18
+ from app.pipeline.deep_analyzer import DeepAnalyzer, DeepAnalysisResult
19
+ from app.pipeline.decision_engine import DecisionEngine, Decision
20
+ from app.services.mongo_service import mongo_service
21
+ from app.services.redis_service import redis_service
22
+ from app.observability.logging import get_logger
23
+
24
+ logger = get_logger(__name__)
25
+
26
+
27
+ # ──────────────────────────────────────────────
28
+ # Pipeline State Schema
29
+ # ──────────────────────────────────────────────
30
+
31
+ class PipelineState(TypedDict, total=False):
32
+ """State that flows through the LangGraph pipeline."""
33
+ # Input
34
+ input_type: str # "text", "image", "video"
35
+ raw_content: Any # str for text, bytes for image/video
36
+ user_id: str | None
37
+
38
+ # Preprocessed
39
+ processed_text: ProcessedText | None
40
+ processed_image: ProcessedImage | None
41
+ processed_video: ProcessedVideo | None
42
+
43
+ # Pipeline stages
44
+ filter_result: FilterResult | None
45
+ filter_results: list[FilterResult] # For video (multiple frames)
46
+ risk_score: RiskScore | None
47
+ deep_result: DeepAnalysisResult | None
48
+ decision: Decision | None
49
+
50
+ # Context
51
+ user_history: dict | None
52
+
53
+ # Metadata
54
+ error: str | None
55
+
56
+
57
+ # ──────────────────────────────────────────────
58
+ # Pipeline Node Functions
59
+ # ──────────────────────────────────────────────
60
+
61
+ preprocessor = Preprocessor()
62
+ fast_filter = FastFilter()
63
+ risk_scorer = RiskScorer()
64
+ deep_analyzer = DeepAnalyzer()
65
+ decision_engine = DecisionEngine()
66
+
67
+
68
+ async def preprocess_node(state: PipelineState) -> dict:
69
+ """Node 1: Preprocess the raw input."""
70
+ input_type = state["input_type"]
71
+ raw = state["raw_content"]
72
+
73
+ try:
74
+ if input_type == "text":
75
+ processed = preprocessor.process_text(raw)
76
+ return {"processed_text": processed}
77
+
78
+ elif input_type == "image":
79
+ processed = preprocessor.process_image(raw)
80
+ return {"processed_image": processed}
81
+
82
+ elif input_type == "video":
83
+ processed = preprocessor.process_video(raw)
84
+ return {"processed_video": processed}
85
+
86
+ else:
87
+ return {"error": f"Unknown input type: {input_type}"}
88
+
89
+ except Exception as e:
90
+ logger.error("preprocess_failed", error=str(e))
91
+ return {"error": f"Preprocessing failed: {str(e)}"}
92
+
93
+
94
+ async def fetch_user_history_node(state: PipelineState) -> dict:
95
+ """Node 1b: Fetch user moderation history (parallel with preprocess)."""
96
+ user_id = state.get("user_id")
97
+ if not user_id:
98
+ return {"user_history": None}
99
+
100
+ # Try Redis cache first
101
+ cached = await redis_service.get_user_history(user_id)
102
+ if cached:
103
+ return {"user_history": cached}
104
+
105
+ # Fall back to MongoDB
106
+ history = await mongo_service.get_user_history(user_id)
107
+ if history:
108
+ await redis_service.cache_user_history(user_id, history)
109
+ return {"user_history": history}
110
+
111
+
112
+ async def fast_filter_node(state: PipelineState) -> dict:
113
+ """Node 2: Run fast AI filter."""
114
+ input_type = state["input_type"]
115
+
116
+ try:
117
+ if input_type == "text" and state.get("processed_text"):
118
+ result = fast_filter.filter_text(state["processed_text"])
119
+ return {"filter_result": result}
120
+
121
+ elif input_type == "image" and state.get("processed_image"):
122
+ result = fast_filter.filter_image(state["processed_image"])
123
+ return {"filter_result": result}
124
+
125
+ elif input_type == "video" and state.get("processed_video"):
126
+ # Analyze each frame, take the worst result
127
+ video = state["processed_video"]
128
+ frame_results = []
129
+ for frame in video.frames:
130
+ result = fast_filter.filter_image(frame)
131
+ frame_results.append(result)
132
+
133
+ # Use the highest-risk frame as the representative result
134
+ if frame_results:
135
+ worst = max(frame_results, key=lambda r: r.max_score)
136
+ return {
137
+ "filter_result": worst,
138
+ "filter_results": frame_results,
139
+ }
140
+ else:
141
+ return {
142
+ "filter_result": FilterResult(
143
+ input_type="video",
144
+ is_flagged=False,
145
+ max_score=0.0,
146
+ )
147
+ }
148
+
149
+ return {"error": "No processed content available for filtering"}
150
+
151
+ except Exception as e:
152
+ logger.error("fast_filter_failed", error=str(e))
153
+ return {"error": f"Fast filter failed: {str(e)}"}
154
+
155
+
156
+ async def risk_score_node(state: PipelineState) -> dict:
157
+ """Node 3: Compute composite risk score."""
158
+ filter_result = state.get("filter_result")
159
+ if not filter_result:
160
+ return {"error": "No filter result to score"}
161
+
162
+ try:
163
+ user_history = state.get("user_history")
164
+ score = risk_scorer.score(filter_result, user_history)
165
+ return {"risk_score": score}
166
+ except Exception as e:
167
+ logger.error("risk_score_failed", error=str(e))
168
+ return {"error": f"Risk scoring failed: {str(e)}"}
169
+
170
+
171
+ def route_by_risk(state: PipelineState) -> str:
172
+ """
173
+ Conditional router: decides whether to do deep analysis or skip to decision.
174
+
175
+ - LOW / MEDIUM → skip directly to decision
176
+ - HIGH → go to deep analysis
177
+ """
178
+ risk = state.get("risk_score")
179
+ if risk and risk.level == "HIGH":
180
+ return "deep_analysis"
181
+ return "decide"
182
+
183
+
184
+ async def deep_analysis_node(state: PipelineState) -> dict:
185
+ """Node 4 (conditional): Deep analysis with CLIP + Gemini."""
186
+ input_type = state["input_type"]
187
+ filter_result = state.get("filter_result")
188
+
189
+ try:
190
+ if input_type == "text" and state.get("processed_text"):
191
+ result = await deep_analyzer.analyze_text(
192
+ state["processed_text"].cleaned,
193
+ filter_result,
194
+ )
195
+ return {"deep_result": result}
196
+
197
+ elif input_type in ("image", "video") and state.get("processed_image"):
198
+ result = await deep_analyzer.analyze_image(
199
+ state["processed_image"].image,
200
+ filter_result,
201
+ )
202
+ return {"deep_result": result}
203
+
204
+ elif input_type == "video" and state.get("processed_video"):
205
+ # Use the worst frame for deep analysis
206
+ video = state["processed_video"]
207
+ if video.frames:
208
+ # Find the worst frame based on filter_results
209
+ worst_frame = video.frames[0]
210
+ filter_results = state.get("filter_results", [])
211
+ if filter_results:
212
+ worst_idx = max(
213
+ range(len(filter_results)),
214
+ key=lambda i: filter_results[i].max_score,
215
+ )
216
+ if worst_idx < len(video.frames):
217
+ worst_frame = video.frames[worst_idx]
218
+
219
+ result = await deep_analyzer.analyze_image(
220
+ worst_frame.image,
221
+ filter_result,
222
+ )
223
+ return {"deep_result": result}
224
+
225
+ return {"deep_result": None}
226
+
227
+ except Exception as e:
228
+ logger.error("deep_analysis_failed", error=str(e))
229
+ return {"deep_result": None}
230
+
231
+
232
+ async def decision_node(state: PipelineState) -> dict:
233
+ """Node 5: Final decision."""
234
+ risk = state.get("risk_score")
235
+ if not risk:
236
+ # Emergency fallback
237
+ return {
238
+ "decision": Decision(
239
+ action="WARNING",
240
+ reason="Pipeline error: no risk score available",
241
+ severity="medium",
242
+ )
243
+ }
244
+
245
+ try:
246
+ deep_result = state.get("deep_result")
247
+ user_history = state.get("user_history")
248
+ decision = decision_engine.decide(risk, deep_result, user_history)
249
+ return {"decision": decision}
250
+ except Exception as e:
251
+ logger.error("decision_failed", error=str(e))
252
+ return {
253
+ "decision": Decision(
254
+ action="WARNING",
255
+ reason=f"Decision engine error: {str(e)}",
256
+ severity="medium",
257
+ )
258
+ }
259
+
260
+
261
+ # ──────────────────────────────────────────────
262
+ # Build the LangGraph Workflow
263
+ # ──────────────────────────────────────────────
264
+
265
+ def build_moderation_workflow():
266
+ """
267
+ Construct and compile the LangGraph moderation pipeline.
268
+
269
+ Flow:
270
+ preprocess → fast_filter → risk_score
271
+ ├─ LOW/MEDIUM → decide
272
+ └─ HIGH → deep_analysis → decide
273
+
274
+ Returns:
275
+ Compiled LangGraph workflow.
276
+ """
277
+ graph = StateGraph(PipelineState)
278
+
279
+ # Add nodes
280
+ graph.add_node("preprocess", preprocess_node)
281
+ graph.add_node("fetch_history", fetch_user_history_node)
282
+ graph.add_node("fast_filter", fast_filter_node)
283
+ graph.add_node("risk_score", risk_score_node)
284
+ graph.add_node("deep_analysis", deep_analysis_node)
285
+ graph.add_node("decide", decision_node)
286
+
287
+ # Define edges
288
+ graph.set_entry_point("preprocess")
289
+
290
+ # After preprocess, run fast filter
291
+ graph.add_edge("preprocess", "fast_filter")
292
+
293
+ # After fast filter, compute risk score
294
+ graph.add_edge("fast_filter", "risk_score")
295
+
296
+ # Conditional routing based on risk level
297
+ graph.add_conditional_edges(
298
+ "risk_score",
299
+ route_by_risk,
300
+ {
301
+ "deep_analysis": "deep_analysis",
302
+ "decide": "decide",
303
+ },
304
+ )
305
+
306
+ # Deep analysis flows to decision
307
+ graph.add_edge("deep_analysis", "decide")
308
+
309
+ # Decision is the terminal node
310
+ graph.add_edge("decide", END)
311
+
312
+ # Compile
313
+ workflow = graph.compile()
314
+ logger.info("moderation_workflow_compiled")
315
+ return workflow
316
+
317
+
318
+ # Global compiled workflow (initialized at startup)
319
+ moderation_workflow = None
320
+
321
+
322
+ def get_workflow():
323
+ """Get or create the compiled moderation workflow."""
324
+ global moderation_workflow
325
+ if moderation_workflow is None:
326
+ moderation_workflow = build_moderation_workflow()
327
+ return moderation_workflow
app/services/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # app/services/__init__.py
2
+ """External service integrations: MongoDB, Redis, Gemini."""
app/services/gemini_service.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/services/gemini_service.py
2
+ # Gemini API integration via LangChain for deep analysis reasoning
3
+
4
+ from app.config import get_settings
5
+ from app.observability.logging import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
+ class GeminiService:
11
+ """
12
+ Google Gemini API client powered by LangChain.
13
+
14
+ Used in the deep analysis path for contextual reasoning
15
+ about flagged content. Replaces the old raw REST calls
16
+ with structured LangChain invocations for:
17
+ - Reliable structured output (JSON)
18
+ - Automatic retry and fallback
19
+ - LangSmith trace integration
20
+ """
21
+
22
+ def __init__(self):
23
+ self.settings = get_settings()
24
+ self.llm = None
25
+ self._current_key_idx = 0
26
+ self._initialized = False
27
+
28
+ def initialize(self) -> None:
29
+ """Initialize the LangChain Gemini client."""
30
+ keys = self.settings.gemini_keys_list
31
+ if not keys:
32
+ logger.warning("gemini_no_keys", reason="No API keys configured")
33
+ return
34
+
35
+ try:
36
+ from langchain_google_genai import ChatGoogleGenerativeAI
37
+
38
+ self.llm = ChatGoogleGenerativeAI(
39
+ model=self.settings.gemini_model,
40
+ google_api_key=keys[self._current_key_idx],
41
+ temperature=0.1,
42
+ max_output_tokens=1024,
43
+ convert_system_message_to_human=True,
44
+ )
45
+ self._initialized = True
46
+ logger.info("gemini_initialized", model=self.settings.gemini_model)
47
+ except Exception as e:
48
+ logger.error("gemini_init_failed", error=str(e))
49
+
50
+ def _rotate_key(self) -> bool:
51
+ """Rotate to the next API key. Returns False if all keys exhausted."""
52
+ keys = self.settings.gemini_keys_list
53
+ if not keys:
54
+ return False
55
+
56
+ self._current_key_idx = (self._current_key_idx + 1) % len(keys)
57
+ try:
58
+ from langchain_google_genai import ChatGoogleGenerativeAI
59
+
60
+ self.llm = ChatGoogleGenerativeAI(
61
+ model=self.settings.gemini_model,
62
+ google_api_key=keys[self._current_key_idx],
63
+ temperature=0.1,
64
+ max_output_tokens=1024,
65
+ convert_system_message_to_human=True,
66
+ )
67
+ logger.info("gemini_key_rotated", key_index=self._current_key_idx)
68
+ return True
69
+ except Exception:
70
+ return False
71
+
72
+ async def analyze_text(self, text: str, context: dict | None = None) -> dict:
73
+ """
74
+ Perform deep contextual analysis of flagged text.
75
+
76
+ Args:
77
+ text: The flagged text content.
78
+ context: Additional context (fast filter results, categories, etc.).
79
+
80
+ Returns:
81
+ Structured analysis with verdict, reasoning, and severity.
82
+ """
83
+ if not self._initialized:
84
+ return {"error": "Gemini not initialized", "is_confirmed": False}
85
+
86
+ from langchain_core.messages import SystemMessage, HumanMessage
87
+
88
+ system_prompt = """You are an expert content moderator specializing in cyberbullying detection.
89
+ You are analyzing content that has been flagged as potentially harmful by automated filters.
90
+
91
+ Your task: Provide a detailed, contextual analysis. Consider:
92
+ - Intent and context (sarcasm, jokes, genuine threats)
93
+ - Severity level (mild rudeness vs. serious threats)
94
+ - Whether this constitutes cyberbullying
95
+ - Impact on minors (under 18)
96
+
97
+ Respond in this exact JSON format:
98
+ {
99
+ "is_confirmed": true/false,
100
+ "severity": "low" | "medium" | "high" | "critical",
101
+ "categories": ["category1", "category2"],
102
+ "reasoning": "Detailed explanation of your analysis",
103
+ "recommended_action": "allow" | "warn" | "block" | "escalate",
104
+ "confidence": 0.0-1.0
105
+ }"""
106
+
107
+ context_str = ""
108
+ if context:
109
+ context_str = f"\n\nPre-filter context: {context}"
110
+
111
+ human_message = f"""Analyze this flagged content:{context_str}
112
+
113
+ Content: "{text}"
114
+
115
+ Provide your JSON analysis:"""
116
+
117
+ return await self._invoke(system_prompt, human_message)
118
+
119
+ async def analyze_image(self, image_base64: str, context: dict | None = None) -> dict:
120
+ """
121
+ Perform deep analysis of a flagged image.
122
+
123
+ Args:
124
+ image_base64: Base64-encoded image.
125
+ context: Additional context from fast filter.
126
+
127
+ Returns:
128
+ Structured analysis dict.
129
+ """
130
+ if not self._initialized:
131
+ return {"error": "Gemini not initialized", "is_confirmed": False}
132
+
133
+ from langchain_core.messages import SystemMessage, HumanMessage
134
+
135
+ system_prompt = """You are an expert content moderator analyzing images for content harmful to minors.
136
+
137
+ Analyze the image for:
138
+ - Violence, gore, weapons
139
+ - Nudity, sexual content
140
+ - Drug/alcohol imagery
141
+ - Self-harm or suicide content
142
+ - Hate symbols, extremist content
143
+ - Cyberbullying imagery (humiliating photos, etc.)
144
+
145
+ Respond in this exact JSON format:
146
+ {
147
+ "is_confirmed": true/false,
148
+ "severity": "low" | "medium" | "high" | "critical",
149
+ "categories": ["category1", "category2"],
150
+ "reasoning": "Description of what was found",
151
+ "recommended_action": "allow" | "warn" | "block" | "escalate",
152
+ "confidence": 0.0-1.0
153
+ }"""
154
+
155
+ context_str = f"\nPre-filter flags: {context}" if context else ""
156
+
157
+ human_content = [
158
+ {"type": "text", "text": f"Analyze this flagged image:{context_str}\n\nProvide your JSON analysis:"},
159
+ {
160
+ "type": "image_url",
161
+ "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
162
+ },
163
+ ]
164
+
165
+ return await self._invoke_multimodal(system_prompt, human_content)
166
+
167
+ async def _invoke(self, system_prompt: str, human_message: str) -> dict:
168
+ """Invoke Gemini with retry on rate limits."""
169
+ from langchain_core.messages import SystemMessage, HumanMessage
170
+ import json
171
+
172
+ keys = self.settings.gemini_keys_list
173
+ attempts = max(len(keys), 1)
174
+
175
+ for attempt in range(attempts):
176
+ try:
177
+ messages = [
178
+ SystemMessage(content=system_prompt),
179
+ HumanMessage(content=human_message),
180
+ ]
181
+ response = await self.llm.ainvoke(messages)
182
+ return self._parse_response(response.content)
183
+
184
+ except Exception as e:
185
+ error_str = str(e)
186
+ if "429" in error_str or "quota" in error_str.lower():
187
+ logger.warning("gemini_rate_limited", attempt=attempt)
188
+ if not self._rotate_key():
189
+ break
190
+ else:
191
+ logger.error("gemini_invoke_failed", error=error_str)
192
+ return {"error": error_str, "is_confirmed": False}
193
+
194
+ return {"error": "All Gemini API keys exhausted", "is_confirmed": False}
195
+
196
+ async def _invoke_multimodal(self, system_prompt: str, human_content: list) -> dict:
197
+ """Invoke Gemini with multimodal content."""
198
+ from langchain_core.messages import SystemMessage, HumanMessage
199
+
200
+ keys = self.settings.gemini_keys_list
201
+ attempts = max(len(keys), 1)
202
+
203
+ for attempt in range(attempts):
204
+ try:
205
+ messages = [
206
+ SystemMessage(content=system_prompt),
207
+ HumanMessage(content=human_content),
208
+ ]
209
+ response = await self.llm.ainvoke(messages)
210
+ return self._parse_response(response.content)
211
+
212
+ except Exception as e:
213
+ error_str = str(e)
214
+ if "429" in error_str or "quota" in error_str.lower():
215
+ if not self._rotate_key():
216
+ break
217
+ else:
218
+ return {"error": error_str, "is_confirmed": False}
219
+
220
+ return {"error": "All Gemini API keys exhausted", "is_confirmed": False}
221
+
222
+ def _parse_response(self, text: str) -> dict:
223
+ """Parse JSON from Gemini response text."""
224
+ import json, re
225
+
226
+ try:
227
+ # Extract JSON block (handle markdown code fences)
228
+ json_match = re.search(r"\{[\s\S]*\}", text)
229
+ if json_match:
230
+ return json.loads(json_match.group())
231
+ except json.JSONDecodeError:
232
+ pass
233
+
234
+ logger.warning("gemini_parse_failed", raw_text=text[:200])
235
+ return {
236
+ "error": "Failed to parse Gemini response",
237
+ "is_confirmed": False,
238
+ "raw_response": text[:500],
239
+ }
240
+
241
+ @property
242
+ def is_initialized(self) -> bool:
243
+ return self._initialized
244
+
245
+
246
+ # Global singleton
247
+ gemini_service = GeminiService()