""" rag_pipeline.py — BGE-M3 embedding → Qdrant ANN → Memgraph trust scoring. Pipeline stages: 1. Embed incoming claim with BGE-M3 (BAAI/bge-m3) via FastEmbed 2. Query Qdrant HNSW index (ef=128, top-8, recency filter 72h) 3. Traverse Memgraph trust graph via Bolt to compute trust score 4. Return RagContext dataclass consumed by agents.py Why BGE-M3: - 1024-dimensional dense embeddings, multilingual (100+ languages) - Better factual recall on news content vs. OpenAI text-embedding-3 - Runs on CPU, completely free — no API calls - Supports late interaction (ColBERT) scoring in Qdrant v1.9+ Why in-memory Memgraph over Neo4j: - Pure in-memory graph store → ~100x faster Cypher for real-time scoring - Same Bolt protocol driver compatibility - Single Docker image, no disk I/O for hot queries """ from __future__ import annotations import asyncio import os import time from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any import structlog log = structlog.get_logger(__name__) # ── Lazy singletons (initialized on first use) ──────────────────────────────── _embed_model = None _qdrant_client = None _memgraph_driver = None _executor = ProcessPoolExecutor(max_workers=2) def _get_embed_model(): global _embed_model if _embed_model is None: try: from fastembed import TextEmbedding _embed_model = TextEmbedding( model_name="BAAI/bge-m3", max_length=512, # cache_dir ensures the model is downloaded once cache_dir=os.getenv("EMBED_CACHE_DIR", "/tmp/fastembed_cache"), ) log.info("embed_model.loaded", model="BAAI/bge-m3") except Exception as exc: log.warning("embed_model.unavailable", error=str(exc)) return _embed_model def _get_qdrant(): global _qdrant_client if _qdrant_client is None: try: from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, VectorParams, HnswConfigDiff, OptimizersConfigDiff ) url = os.getenv("QDRANT_URL", "http://localhost:6333") _qdrant_client = QdrantClient(url=url, timeout=5) # Ensure collection exists collections = [c.name for c in _qdrant_client.get_collections().collections] if "claims" not in collections: _qdrant_client.create_collection( collection_name="claims", vectors_config=VectorParams(size=1024, distance=Distance.COSINE), hnsw_config=HnswConfigDiff(ef_construct=128, m=16), optimizers_config=OptimizersConfigDiff(indexing_threshold=1000), ) log.info("qdrant.collection_created", name="claims") log.info("qdrant.connected", url=url) except Exception as exc: log.warning("qdrant.unavailable", error=str(exc)) return _qdrant_client def _get_memgraph(): global _memgraph_driver if _memgraph_driver is None: try: import neo4j # Bolt-compatible with Memgraph host = os.getenv("MEMGRAPH_HOST", "localhost") port = int(os.getenv("MEMGRAPH_PORT", "7687")) _memgraph_driver = neo4j.GraphDatabase.driver( f"bolt://{host}:{port}", auth=( os.getenv("MEMGRAPH_USER", ""), os.getenv("MEMGRAPH_PASS", ""), ), connection_timeout=3, ) log.info("memgraph.connected", host=host) except Exception as exc: log.warning("memgraph.unavailable", error=str(exc)) return _memgraph_driver # ── Data models ─────────────────────────────────────────────────────────────── @dataclass class RetrievedDoc: text: str score: float source_url: str domain: str ingested_at: float author_verified: bool = False @dataclass class RagContext: claim_text: str claim_hash: str retrieved_docs: list[RetrievedDoc] = field(default_factory=list) trust_score: float = 0.5 community_note: bool = False corroboration_count: int = 0 has_verified_source: bool = False # ── Embedding (CPU-bound, runs in ProcessPoolExecutor) ──────────────────────── def _embed_sync(texts: list[str]) -> list[list[float]]: """Synchronous embedding — called from ProcessPoolExecutor.""" model = _get_embed_model() if model is None: # Fallback: zero vector of correct dimensionality return [[0.0] * 1024 for _ in texts] return [list(v) for v in model.embed(texts)] async def embed_texts(texts: list[str]) -> list[list[float]]: """Async wrapper: offloads CPU-bound embedding to process pool.""" loop = asyncio.get_running_loop() return await loop.run_in_executor(_executor, _embed_sync, texts) # ── Qdrant retrieval ────────────────────────────────────────────────────────── async def retrieve_from_qdrant( query_vector: list[float], top_k: int = 8, recency_hours: int = 72, ) -> list[RetrievedDoc]: """ ANN search with: - ef=128 for high recall at query time - Payload filter: ingested_at > now - 72h (keeps results recent) - Returns top_k nearest neighbors """ client = _get_qdrant() if client is None: return _mock_retrieved_docs() try: from qdrant_client.models import Filter, FieldCondition, Range cutoff_ts = time.time() - (recency_hours * 3600) results = client.search( collection_name="claims", query_vector=query_vector, limit=top_k, with_payload=True, search_params={"hnsw_ef": 128}, query_filter=Filter( must=[ FieldCondition( key="ingested_at", range=Range(gte=cutoff_ts), ) ] ), ) return [ RetrievedDoc( text=r.payload.get("text", ""), score=r.score, source_url=r.payload.get("source_url", ""), domain=r.payload.get("domain", "unknown"), ingested_at=r.payload.get("ingested_at", 0.0), author_verified=r.payload.get("author_verified", False), ) for r in results ] except Exception as exc: log.warning("qdrant.search_error", error=str(exc)) return _mock_retrieved_docs() # ── Memgraph trust scoring ──────────────────────────────────────────────────── TRUST_SCORE_CYPHER = """ MATCH (c:Claim {hash: $hash}) OPTIONAL MATCH (a:Author)-[:REPORTED]->(c) OPTIONAL MATCH (c)<-[:CORROBORATED_BY]-(s:Source) OPTIONAL MATCH (c)-[:HAS_NOTE]->(n:CommunityNote {active: true}) RETURN c.hash AS hash, collect(DISTINCT a.verified) AS author_verified_flags, collect(DISTINCT a.account_type) AS author_types, count(DISTINCT s) AS corroboration_count, count(DISTINCT n) AS active_notes """ def _compute_trust_score( author_verified_flags: list[bool], author_types: list[str], corroboration_count: int, active_notes: int, ) -> float: """ Trust score algorithm (deterministic, no LLM needed): Base: 0.50 Verified gov/news official: +0.30 Per corroborating source: +0.05 (max +0.25) Active Community Note: -0.40 Clamped to [0.0, 1.0]. """ score = 0.50 official_types = {"government", "official_news"} if any(v for v in author_verified_flags) and any( t in official_types for t in author_types ): score += 0.30 corroborations_boost = min(corroboration_count * 0.05, 0.25) score += corroborations_boost if active_notes > 0: score -= 0.40 return max(0.0, min(1.0, score)) async def get_trust_score(claim_hash: str) -> tuple[float, bool, int]: """ Query Memgraph for trust metadata. Returns: (trust_score, has_community_note, corroboration_count) """ driver = _get_memgraph() if driver is None: return 0.5, False, 0 try: loop = asyncio.get_running_loop() def _query(tx): result = tx.run(TRUST_SCORE_CYPHER, hash=claim_hash) record = result.single() if record is None: return None return dict(record) def _run_sync(): with driver.session() as session: return session.execute_read(_query) record = await loop.run_in_executor(None, _run_sync) if record is None: return 0.5, False, 0 trust = _compute_trust_score( author_verified_flags=record["author_verified_flags"] or [], author_types=record["author_types"] or [], corroboration_count=record["corroboration_count"] or 0, active_notes=record["active_notes"] or 0, ) return trust, bool(record["active_notes"]), record["corroboration_count"] or 0 except Exception as exc: log.warning("memgraph.query_error", error=str(exc)) return 0.5, False, 0 # ── Main entry point ────────────────────────────────────────────────────────── async def build_rag_context(claim_text: str, claim_hash: str) -> RagContext: """ Full RAG context assembly: 1. Embed claim → query Qdrant (concurrent with trust score fetch) 2. Retrieve Memgraph trust data 3. Assemble RagContext """ ctx = RagContext(claim_text=claim_text, claim_hash=claim_hash) # Embed + retrieve concurrently with trust score lookup embed_task = asyncio.create_task(embed_texts([claim_text])) trust_task = asyncio.create_task(get_trust_score(claim_hash)) vectors, (trust_score, has_note, corroborations) = await asyncio.gather( embed_task, trust_task ) query_vector = vectors[0] docs = await retrieve_from_qdrant(query_vector, top_k=8) ctx.retrieved_docs = docs ctx.trust_score = trust_score ctx.community_note = has_note ctx.corroboration_count = corroborations ctx.has_verified_source = any(d.author_verified for d in docs) log.debug( "rag_context.built", claim_hash=claim_hash[:8], docs=len(docs), trust_score=round(trust_score, 3), community_note=has_note, ) return ctx # ── Mock data for offline development ──────────────────────────────────────── def _mock_retrieved_docs() -> list[RetrievedDoc]: """Realistic mock documents returned when Qdrant is unavailable.""" return [ RetrievedDoc( text="Scientists publish peer-reviewed study confirming the phenomenon with 95% confidence.", score=0.87, source_url="https://reuters.com/science/study-2024", domain="reuters.com", ingested_at=time.time() - 3600, author_verified=True, ), RetrievedDoc( text="Multiple independent sources corroborate the original report.", score=0.75, source_url="https://apnews.com/article/corroboration-2024", domain="apnews.com", ingested_at=time.time() - 7200, author_verified=True, ), RetrievedDoc( text="Context and background on the related historical precedent.", score=0.61, source_url="https://bbc.com/news/context", domain="bbc.com", ingested_at=time.time() - 14400, author_verified=False, ), ]