File size: 11,010 Bytes
71c1ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# app/pipeline/workflow.py
# LangGraph state machine — orchestrates the full moderation pipeline

from __future__ import annotations
from typing import Any, TypedDict, Literal
from dataclasses import asdict

from langgraph.graph import StateGraph, END

from app.pipeline.preprocessor import (
    Preprocessor,
    ProcessedText,
    ProcessedImage,
    ProcessedVideo,
)
from app.pipeline.fast_filter import FastFilter, FilterResult
from app.pipeline.risk_scorer import RiskScorer, RiskScore
from app.pipeline.deep_analyzer import DeepAnalyzer, DeepAnalysisResult
from app.pipeline.decision_engine import DecisionEngine, Decision
from app.services.mongo_service import mongo_service
from app.services.redis_service import redis_service
from app.observability.logging import get_logger

# Module-level logger; the node functions below log failures as an event name
# plus keyword fields (e.g. error=...).
logger = get_logger(__name__)


# ──────────────────────────────────────────────
# Pipeline State Schema
# ──────────────────────────────────────────────

class PipelineState(TypedDict, total=False):
    """
    State that flows through the LangGraph pipeline.

    Declared with ``total=False``, so every key is optional. Node functions
    in this module return partial dicts containing only the keys they
    produce, and read most keys defensively via ``state.get(...)``.
    """
    # Input
    input_type: str  # "text", "image" or "video"; anything else makes preprocess_node report an error
    raw_content: Any  # str for text, bytes for image/video
    user_id: str | None  # when falsy, fetch_user_history_node skips the lookup

    # Preprocessed (preprocess_node sets only the key matching input_type)
    processed_text: ProcessedText | None
    processed_image: ProcessedImage | None
    processed_video: ProcessedVideo | None

    # Pipeline stages
    filter_result: FilterResult | None  # for video: the frame result with the highest max_score
    filter_results: list[FilterResult]  # For video (multiple frames), one entry per frame
    risk_score: RiskScore | None  # route_by_risk sends level == "HIGH" to deep analysis
    deep_result: DeepAnalysisResult | None  # None when deep analysis was skipped or failed
    decision: Decision | None  # final outcome written by decision_node

    # Context
    user_history: dict | None  # None when there is no user_id

    # Metadata
    error: str | None  # human-readable message set by a node that could not do its work


# ──────────────────────────────────────────────
# Pipeline Node Functions
# ──────────────────────────────────────────────

# Stage objects are constructed once, when this module is imported, and are
# shared by every node function below.
preprocessor = Preprocessor()
fast_filter = FastFilter()
risk_scorer = RiskScorer()
deep_analyzer = DeepAnalyzer()
decision_engine = DecisionEngine()


async def preprocess_node(state: PipelineState) -> dict:
    """
    Node 1: Turn the raw input into its preprocessed form.

    Args:
        state: Pipeline state; ``input_type`` and ``raw_content`` must be set.

    Returns:
        A one-key dict: ``processed_text`` / ``processed_image`` /
        ``processed_video`` on success, or ``error`` when the input type is
        not recognised or the preprocessor raises.
    """
    kind = state["input_type"]
    payload = state["raw_content"]

    try:
        if kind == "text":
            return {"processed_text": preprocessor.process_text(payload)}
        if kind == "image":
            return {"processed_image": preprocessor.process_image(payload)}
        if kind == "video":
            return {"processed_video": preprocessor.process_video(payload)}
        # Nothing matched: report it instead of raising.
        return {"error": f"Unknown input type: {kind}"}
    except Exception as exc:
        detail = str(exc)
        logger.error("preprocess_failed", error=detail)
        return {"error": f"Preprocessing failed: {detail}"}


async def fetch_user_history_node(state: PipelineState) -> dict:
    """
    Node 1b: Load the user's moderation history (Redis cache, then MongoDB).

    History is optional context for the later stages, so lookup failures are
    logged and degrade to "no history" rather than raising, in line with how
    the other nodes in this module handle errors.

    Args:
        state: Pipeline state; only ``user_id`` is read.

    Returns:
        ``{"user_history": history}``. ``history`` is the cached or stored
        record, or ``None`` when there is no user id or the lookup failed.
    """
    user_id = state.get("user_id")
    if not user_id:
        return {"user_history": None}

    # Try Redis cache first; a cache outage must not block the pipeline,
    # so fall through to MongoDB when the read fails.
    try:
        cached = await redis_service.get_user_history(user_id)
    except Exception as e:
        logger.error("user_history_cache_read_failed", error=str(e))
        cached = None
    if cached:
        return {"user_history": cached}

    # Fall back to MongoDB
    try:
        history = await mongo_service.get_user_history(user_id)
    except Exception as e:
        logger.error("user_history_fetch_failed", error=str(e))
        return {"user_history": None}

    if history:
        # Populating the cache is best-effort; the fetched history is still
        # returned when the write fails.
        try:
            await redis_service.cache_user_history(user_id, history)
        except Exception as e:
            logger.error("user_history_cache_write_failed", error=str(e))
    return {"user_history": history}


async def fast_filter_node(state: PipelineState) -> dict:
    """
    Node 2: Run the fast filter over the preprocessed content.

    Text and images produce a single ``filter_result``. For video every frame
    is filtered; the frame with the highest ``max_score`` becomes
    ``filter_result`` and the full per-frame list is returned as
    ``filter_results``. A video without frames yields an unflagged result
    with score 0.0.

    Returns:
        The keys described above, or ``{"error": ...}`` when no matching
        preprocessed content exists or the filter raises.
    """
    kind = state["input_type"]

    try:
        if kind == "text" and state.get("processed_text"):
            return {"filter_result": fast_filter.filter_text(state["processed_text"])}

        if kind == "image" and state.get("processed_image"):
            return {"filter_result": fast_filter.filter_image(state["processed_image"])}

        if kind == "video" and state.get("processed_video"):
            clip = state["processed_video"]
            per_frame = [fast_filter.filter_image(frame) for frame in clip.frames]

            if not per_frame:
                # No frames to inspect: report a clean, zero-score result.
                return {
                    "filter_result": FilterResult(
                        input_type="video",
                        is_flagged=False,
                        max_score=0.0,
                    )
                }

            # The riskiest frame stands in for the whole video.
            riskiest = max(per_frame, key=lambda r: r.max_score)
            return {
                "filter_result": riskiest,
                "filter_results": per_frame,
            }

        return {"error": "No processed content available for filtering"}

    except Exception as exc:
        detail = str(exc)
        logger.error("fast_filter_failed", error=detail)
        return {"error": f"Fast filter failed: {detail}"}


async def risk_score_node(state: PipelineState) -> dict:
    """
    Node 3: Compute the composite risk score.

    Scores ``filter_result`` together with ``user_history`` (which may be
    absent).

    Returns:
        ``{"risk_score": score}`` on success, or ``{"error": ...}`` when
        there is no filter result or the scorer raises.
    """
    stage_output = state.get("filter_result")
    if not stage_output:
        return {"error": "No filter result to score"}

    try:
        history = state.get("user_history")
        return {"risk_score": risk_scorer.score(stage_output, history)}
    except Exception as exc:
        detail = str(exc)
        logger.error("risk_score_failed", error=detail)
        return {"error": f"Risk scoring failed: {detail}"}


def route_by_risk(state: PipelineState) -> str:
    """
    Conditional router: pick the node that follows risk scoring.

    - HIGH → "deep_analysis"
    - anything else (LOW / MEDIUM / no risk score) → "decide"
    """
    scored = state.get("risk_score")
    needs_deep_look = bool(scored) and scored.level == "HIGH"
    return "deep_analysis" if needs_deep_look else "decide"


async def deep_analysis_node(state: PipelineState) -> dict:
    """
    Node 4 (conditional): Deep analysis of the flagged content.

    Text is analysed from its cleaned form and images from their image
    payload. For video, the frame whose fast-filter result has the highest
    ``max_score`` is analysed; the first frame is used when per-frame
    results are missing. ``filter_result`` is handed to the analyzer in
    every case.

    Returns:
        ``{"deep_result": result}``; ``result`` is ``None`` when there is
        nothing to analyse or the analyzer raises.
    """
    kind = state["input_type"]
    fast_result = state.get("filter_result")

    try:
        if kind == "text" and state.get("processed_text"):
            outcome = await deep_analyzer.analyze_text(
                state["processed_text"].cleaned,
                fast_result,
            )
            return {"deep_result": outcome}

        if kind in ("image", "video") and state.get("processed_image"):
            outcome = await deep_analyzer.analyze_image(
                state["processed_image"].image,
                fast_result,
            )
            return {"deep_result": outcome}

        if kind == "video" and state.get("processed_video"):
            clip = state["processed_video"]
            if clip.frames:
                # Default to the first frame, then prefer the frame whose
                # filter result scored highest (when the index is valid).
                target = clip.frames[0]
                per_frame = state.get("filter_results", [])
                if per_frame:
                    top_idx, _ = max(
                        enumerate(per_frame),
                        key=lambda pair: pair[1].max_score,
                    )
                    if top_idx < len(clip.frames):
                        target = clip.frames[top_idx]

                outcome = await deep_analyzer.analyze_image(
                    target.image,
                    fast_result,
                )
                return {"deep_result": outcome}

        return {"deep_result": None}

    except Exception as exc:
        # A failed deep analysis is reported as "no deep result".
        logger.error("deep_analysis_failed", error=str(exc))
        return {"deep_result": None}


async def decision_node(state: PipelineState) -> dict:
    """
    Node 5: Produce the final moderation decision.

    Feeds the risk score, the optional deep-analysis result and the optional
    user history to the decision engine. When the risk score is missing or
    the engine raises, a medium-severity WARNING decision is returned
    instead.

    Returns:
        ``{"decision": Decision}`` in every case.
    """

    def _warn(reason: str) -> dict:
        # Shared shape of both fallback decisions.
        return {
            "decision": Decision(
                action="WARNING",
                reason=reason,
                severity="medium",
            )
        }

    risk = state.get("risk_score")
    if not risk:
        # Emergency fallback
        return _warn("Pipeline error: no risk score available")

    try:
        verdict = decision_engine.decide(
            risk,
            state.get("deep_result"),
            state.get("user_history"),
        )
    except Exception as exc:
        logger.error("decision_failed", error=str(exc))
        return _warn(f"Decision engine error: {str(exc)}")
    return {"decision": verdict}


# ──────────────────────────────────────────────
# Build the LangGraph Workflow
# ──────────────────────────────────────────────

def build_moderation_workflow():
    """
    Construct and compile the LangGraph moderation pipeline.

    Flow:
        preprocess → fetch_history → fast_filter → risk_score
            ├─ LOW/MEDIUM → decide
            └─ HIGH → deep_analysis → decide

    ``fetch_history`` sits on the main path so that ``user_history`` is
    populated before ``risk_score`` and ``decide`` read it. A node that is
    registered without any edge is never executed.

    Returns:
        Compiled LangGraph workflow.
    """
    graph = StateGraph(PipelineState)

    # Add nodes
    graph.add_node("preprocess", preprocess_node)
    graph.add_node("fetch_history", fetch_user_history_node)
    graph.add_node("fast_filter", fast_filter_node)
    graph.add_node("risk_score", risk_score_node)
    graph.add_node("deep_analysis", deep_analysis_node)
    graph.add_node("decide", decision_node)

    # Define edges
    graph.set_entry_point("preprocess")

    # After preprocess, load the user's moderation history so the scoring
    # and decision stages can take it into account
    graph.add_edge("preprocess", "fetch_history")

    # With history in state, run fast filter
    graph.add_edge("fetch_history", "fast_filter")

    # After fast filter, compute risk score
    graph.add_edge("fast_filter", "risk_score")

    # Conditional routing based on risk level: route_by_risk returns one of
    # the mapping keys, which selects the next node
    graph.add_conditional_edges(
        "risk_score",
        route_by_risk,
        {
            "deep_analysis": "deep_analysis",
            "decide": "decide",
        },
    )

    # Deep analysis flows to decision
    graph.add_edge("deep_analysis", "decide")

    # Decision is the terminal node
    graph.add_edge("decide", END)

    # Compile
    workflow = graph.compile()
    logger.info("moderation_workflow_compiled")
    return workflow


# Module-wide cache for the compiled workflow. It stays None until the first
# get_workflow() call builds it.
moderation_workflow = None


def get_workflow():
    """Return the shared compiled workflow, building it on first use."""
    global moderation_workflow
    if moderation_workflow is not None:
        return moderation_workflow
    # First call: compile once and keep the result for later callers.
    moderation_workflow = build_moderation_workflow()
    return moderation_workflow