muthuk1's picture
Add GraphRAG novelties engine: PPR scoring, token budget, path pruning, spreading activation, hybrid router, incremental updates
4dd67cd verified
"""
GraphRAG Novelties Engine
==========================
Six cutting-edge techniques from 2024-2025 GraphRAG literature,
implemented as modular components that plug into the orchestration layer.
1. PPR Confidence-Weighted Retrieval (CatRAG, 2602.01965)
2. Graph Token Budget Controller (TERAG, 2509.18667)
3. Flow-Pruned Path Serializer (PathRAG, 2502.14902)
4. Spreading Activation Context Scorer (SA-RAG, 2512.15922)
5. PolyG Hybrid Retrieval Router (RAGRouter-Bench, 2602.00296)
6. Incremental Graph Updater (TG-RAG, 2510.13590)
"""
import logging
import math
import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ═══════════════════════════════════════════════════════════
# 1. PPR CONFIDENCE-WEIGHTED RETRIEVAL
# Paper: CatRAG (2602.01965), HippoRAG 2 (2502.14802)
# Key idea: Personalized PageRank from seed entities with
# query-aware dynamic edge weights. PPR score = confidence.
# ═══════════════════════════════════════════════════════════
class PPRConfidenceScorer:
"""
Runs Personalized PageRank from query-matched seed entities.
Each node's PPR score becomes its confidence weight for context ranking.
Formula (TG-RAG, 2510.13590):
s(edge) = s(v1) + s(v2) # sum of incident node PPR scores
s(chunk) = w(c) Γ— Ξ£ s(edges_in_c) # weighted by query-chunk similarity
"""
def __init__(self, damping: float = 0.85, max_iterations: int = 20,
convergence: float = 1e-6):
self.damping = damping
self.max_iter = max_iterations
self.convergence = convergence
def compute_ppr(
self,
adjacency: Dict[str, List[Tuple[str, float]]], # node β†’ [(neighbor, weight)]
seed_nodes: List[str],
seed_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
"""
Compute Personalized PageRank scores.
Returns: {node_id: ppr_score}
"""
all_nodes = set(adjacency.keys())
for neighbors in adjacency.values():
for n, _ in neighbors:
all_nodes.add(n)
n = len(all_nodes)
if n == 0:
return {}
# Personalization vector (seed distribution)
personalization: Dict[str, float] = {}
if seed_weights:
total = sum(seed_weights.values())
personalization = {k: v / total for k, v in seed_weights.items()} if total > 0 else {}
else:
for s in seed_nodes:
personalization[s] = 1.0 / len(seed_nodes) if seed_nodes else 0
# Initialize scores
scores = {node: 1.0 / n for node in all_nodes}
# Power iteration
for iteration in range(self.max_iter):
new_scores: Dict[str, float] = {}
for node in all_nodes:
# Teleport component
teleport = personalization.get(node, 0.0)
# Random walk component
walk_score = 0.0
for source, neighbors in adjacency.items():
out_weight = sum(w for _, w in neighbors)
if out_weight > 0:
for target, weight in neighbors:
if target == node:
walk_score += scores[source] * (weight / out_weight)
new_scores[node] = (1 - self.damping) * teleport + self.damping * walk_score
# Check convergence
diff = sum(abs(new_scores[n] - scores[n]) for n in all_nodes)
scores = new_scores
if diff < self.convergence:
logger.debug(f"PPR converged at iteration {iteration + 1}")
break
return scores
def score_contexts(
self,
ppr_scores: Dict[str, float],
entity_to_chunks: Dict[str, List[str]], # entity_id β†’ [chunk_ids]
chunk_texts: Dict[str, str],
query_similarity: Optional[Dict[str, float]] = None, # chunk_id β†’ sim score
) -> List[Tuple[str, str, float]]:
"""
Score and rank chunks using PPR scores of their entities.
Returns: [(chunk_id, text, confidence_score)] sorted by score desc.
"""
chunk_scores: Dict[str, float] = defaultdict(float)
for entity_id, chunk_ids in entity_to_chunks.items():
entity_ppr = ppr_scores.get(entity_id, 0.0)
for cid in chunk_ids:
chunk_scores[cid] += entity_ppr
# Multiply by query-chunk similarity if available
if query_similarity:
for cid in chunk_scores:
sim = query_similarity.get(cid, 0.5)
chunk_scores[cid] *= (1 + sim)
results = [
(cid, chunk_texts.get(cid, ""), score)
for cid, score in chunk_scores.items()
if cid in chunk_texts
]
results.sort(key=lambda x: x[2], reverse=True)
return results
# ═══════════════════════════════════════════════════════════
# 2. GRAPH TOKEN BUDGET CONTROLLER
# Paper: TERAG (2509.18667)
# Key idea: Cap context by token budget. Prioritize by
# concept_frequency Γ— semantic_relevance. 97% reduction possible.
# ═══════════════════════════════════════════════════════════
class TokenBudgetController:
"""
Controls the token budget for graph-retrieved context.
Prioritizes high-value content within a fixed token limit.
TERAG insight: 3-11% of LightRAG's token cost retains 80%+ accuracy.
"""
def __init__(self, max_tokens: int = 2000, chars_per_token: float = 4.0):
self.max_tokens = max_tokens
self.chars_per_token = chars_per_token
def estimate_tokens(self, text: str) -> int:
return max(1, int(len(text) / self.chars_per_token))
def prune_context(
self,
scored_items: List[Tuple[str, float]], # [(text, score)]
budget: Optional[int] = None,
) -> Tuple[List[str], Dict[str, Any]]:
"""
Select highest-scored items within token budget.
Returns: (selected_texts, stats)
"""
limit = budget or self.max_tokens
selected: List[str] = []
total_tokens = 0
total_available = sum(self.estimate_tokens(t) for t, _ in scored_items)
items_considered = 0
# Sort by score descending
sorted_items = sorted(scored_items, key=lambda x: x[1], reverse=True)
for text, score in sorted_items:
tokens = self.estimate_tokens(text)
items_considered += 1
if total_tokens + tokens <= limit:
selected.append(text)
total_tokens += tokens
elif total_tokens == 0:
# At least include one item (truncated)
truncated = text[:int(limit * self.chars_per_token)]
selected.append(truncated)
total_tokens = limit
break
stats = {
"budget_tokens": limit,
"used_tokens": total_tokens,
"utilization_pct": round(total_tokens / limit * 100, 1) if limit > 0 else 0,
"items_selected": len(selected),
"items_available": len(scored_items),
"tokens_saved": total_available - total_tokens,
"reduction_pct": round((1 - total_tokens / max(total_available, 1)) * 100, 1),
}
return selected, stats
# ═══════════════════════════════════════════════════════════
# 3. FLOW-PRUNED PATH SERIALIZER
# Paper: PathRAG (2502.14902)
# Key idea: Retrieve key relational paths between entities,
# prune low-flow paths, serialize for LLM consumption.
# Exploits "lost-in-the-middle" by placing best paths first.
# ═══════════════════════════════════════════════════════════
class PathPruner:
"""
Extracts and prunes reasoning paths between entities.
High-reliability paths placed FIRST in context (recency bias exploit).
PathRAG result: 62-65% win rate vs LightRAG in comprehensiveness.
"""
def find_paths(
self,
adjacency: Dict[str, List[Tuple[str, str, float]]], # node β†’ [(neighbor, relation, weight)]
source: str,
target: str,
max_depth: int = 3,
max_paths: int = 5,
) -> List[List[Tuple[str, str, str]]]:
"""
Find top paths between source and target using DFS.
Returns: [[(entity, relation, next_entity), ...], ...]
"""
paths: List[List[Tuple[str, str, str]]] = []
def dfs(current: str, target: str, path: List[Tuple[str, str, str]],
visited: Set[str], depth: int):
if depth > max_depth or len(paths) >= max_paths * 3:
return
if current == target and path:
paths.append(list(path))
return
visited.add(current)
for neighbor, relation, weight in adjacency.get(current, []):
if neighbor not in visited:
path.append((current, relation, neighbor))
dfs(neighbor, target, path, visited, depth + 1)
path.pop()
visited.discard(current)
dfs(source, target, [], set(), 0)
return paths
def score_and_prune(
self,
paths: List[List[Tuple[str, str, str]]],
edge_weights: Dict[Tuple[str, str], float],
threshold: float = 0.1,
) -> List[Tuple[List[Tuple[str, str, str]], float]]:
"""
Score paths by accumulated edge weight, prune below threshold.
Returns: [(path, score)] sorted by score desc.
"""
scored = []
for path in paths:
score = 1.0
for src, rel, tgt in path:
w = edge_weights.get((src, tgt), edge_weights.get((tgt, src), 0.5))
score *= w
if score >= threshold:
scored.append((path, score))
scored.sort(key=lambda x: x[1], reverse=True)
return scored
def serialize_paths(
self,
scored_paths: List[Tuple[List[Tuple[str, str, str]], float]],
max_paths: int = 5,
) -> str:
"""
Serialize paths into LLM-friendly text.
HIGH-reliability paths placed FIRST (exploits lost-in-the-middle bias).
"""
lines = ["### Reasoning Paths (ranked by reliability):"]
for i, (path, score) in enumerate(scored_paths[:max_paths]):
chain = " β†’ ".join(
[path[0][0]] + [f"--[{rel}]--> {tgt}" for _, rel, tgt in path]
)
lines.append(f" Path {i+1} (confidence: {score:.3f}): {chain}")
return "\n".join(lines)
# ═══════════════════════════════════════════════════════════
# 4. SPREADING ACTIVATION CONTEXT SCORER
# Paper: SA-RAG (2512.15922)
# Key idea: Activate seed nodes, propagate activation through
# graph edges with decay. Activation score = retrieval priority.
# Result: +39% answer correctness on MuSiQue.
# ═══════════════════════════════════════════════════════════
class SpreadingActivation:
"""
Spreading Activation from seed entities through the knowledge graph.
Nodes with high activation are most relevant to the query.
"""
def __init__(self, decay_factor: float = 0.7, threshold: float = 0.01,
max_steps: int = 3):
self.decay = decay_factor
self.threshold = threshold
self.max_steps = max_steps
def activate(
self,
adjacency: Dict[str, List[Tuple[str, float]]],
seed_activations: Dict[str, float],
) -> Dict[str, float]:
"""
Spread activation from seeds through the graph.
Returns: {node_id: activation_level}
"""
activations = dict(seed_activations)
frontier = set(seed_activations.keys())
for step in range(self.max_steps):
new_activations: Dict[str, float] = {}
next_frontier: Set[str] = set()
for node in frontier:
current_activation = activations.get(node, 0.0)
if current_activation < self.threshold:
continue
for neighbor, weight in adjacency.get(node, []):
spread = current_activation * weight * self.decay
if spread >= self.threshold:
existing = new_activations.get(neighbor, 0.0)
new_activations[neighbor] = max(existing, spread)
next_frontier.add(neighbor)
# Merge new activations (keep max)
for node, act in new_activations.items():
activations[node] = max(activations.get(node, 0.0), act)
frontier = next_frontier
if not frontier:
break
return activations
def rank_contexts(
self,
activations: Dict[str, float],
entity_to_chunks: Dict[str, List[str]],
chunk_texts: Dict[str, str],
) -> List[Tuple[str, str, float]]:
"""Rank chunks by sum of their entities' activation levels."""
chunk_scores: Dict[str, float] = defaultdict(float)
for entity_id, chunk_ids in entity_to_chunks.items():
act = activations.get(entity_id, 0.0)
for cid in chunk_ids:
chunk_scores[cid] += act
results = [
(cid, chunk_texts.get(cid, ""), score)
for cid, score in chunk_scores.items()
if score > 0 and cid in chunk_texts
]
results.sort(key=lambda x: x[2], reverse=True)
return results
# ═══════════════════════════════════════════════════════════
# 5. POLYG HYBRID RETRIEVAL ROUTER
# Papers: RAGRouter-Bench (2602.00296), PolyG (2504.02112)
# Key idea: 4-class query taxonomy determines retrieval strategy.
# No single paradigm wins everywhere β€” route adaptively.
# ═══════════════════════════════════════════════════════════
class PolyGRouter:
"""
Enhanced hybrid router using PolyG's 4-class query taxonomy.
Routes queries to optimal retrieval strategy:
- entity_centric β†’ Graph 1-hop lookup
- relation_lookup β†’ Vector semantic search
- multi_hop β†’ Graph traversal (PPR + paths)
- summarization β†’ Community summaries
- hybrid β†’ Both vector + graph (dual channel)
"""
# Regex patterns for query classification
ENTITY_PATTERNS = [
r"^(what|who|where) (is|are|was|were) ",
r"^tell me about ",
r"^describe ",
r"^define ",
]
RELATION_PATTERNS = [
r"(what|which) .* (did|does|do) .* (do|make|create|write|direct)",
r"(what|which) .* (position|role|job|title)",
r"how (did|does|do) .* (relate|connect)",
]
MULTI_HOP_PATTERNS = [
r"(same|both|compare|difference|which.*first|who.*born.*first)",
r"(what|who) .* (the|a) .* (that|which|who) ",
r"(capital|director|author|founder) .* (of|for) .* (the|a) .* (that|which)",
]
SUMMARIZATION_PATTERNS = [
r"^(summarize|overview|main themes|what are the)",
r"(overall|in general|broadly)",
]
def classify_query(self, query: str) -> Dict[str, Any]:
"""
Classify query into retrieval strategy.
Returns: {strategy, confidence, query_type, reasoning}
"""
q = query.lower().strip()
# Score each category
scores = {
"entity_centric": 0.0,
"relation_lookup": 0.0,
"multi_hop": 0.0,
"summarization": 0.0,
}
for pattern in self.ENTITY_PATTERNS:
if re.search(pattern, q):
scores["entity_centric"] += 0.4
for pattern in self.RELATION_PATTERNS:
if re.search(pattern, q):
scores["relation_lookup"] += 0.4
for pattern in self.MULTI_HOP_PATTERNS:
if re.search(pattern, q):
scores["multi_hop"] += 0.5
for pattern in self.SUMMARIZATION_PATTERNS:
if re.search(pattern, q):
scores["summarization"] += 0.4
# Structural signals
question_marks = q.count("?")
word_count = len(q.split())
has_comparison = any(w in q for w in ["same", "both", "compare", "difference", "versus", "vs"])
has_chain = any(w in q for w in ["that", "which", "who", "where", "whose"])
entity_count = sum(1 for word in q.split() if word[0:1].isupper()) if q else 0
if has_comparison:
scores["multi_hop"] += 0.3
if has_chain:
scores["multi_hop"] += 0.2
if word_count > 15:
scores["multi_hop"] += 0.1
if word_count < 8 and entity_count <= 1:
scores["entity_centric"] += 0.2
if entity_count >= 2:
scores["multi_hop"] += 0.15
# Determine winner
best_type = max(scores, key=scores.get) # type: ignore
best_score = scores[best_type]
# Map to strategy
strategy_map = {
"entity_centric": "graph_lookup",
"relation_lookup": "vector_search",
"multi_hop": "graph_traversal",
"summarization": "community_summary",
}
# If no strong signal, use hybrid
if best_score < 0.2:
strategy = "hybrid"
best_type = "ambiguous"
else:
strategy = strategy_map[best_type]
return {
"strategy": strategy,
"query_type": best_type,
"confidence": round(min(best_score, 1.0), 3),
"scores": {k: round(v, 3) for k, v in scores.items()},
"use_graph": strategy in ["graph_lookup", "graph_traversal", "hybrid"],
"use_vector": strategy in ["vector_search", "hybrid"],
"use_community": strategy == "community_summary",
"reasoning": f"Classified as '{best_type}' (score={best_score:.2f}) β†’ {strategy}",
}
# ═══════════════════════════════════════════════════════════
# 6. INCREMENTAL GRAPH UPDATER
# Papers: TG-RAG (2510.13590), LightRAG (2410.05779)
# Key idea: Add new documents without rebuilding the entire graph.
# Only recompute communities for affected subgraph.
# ═══════════════════════════════════════════════════════════
class IncrementalGraphUpdater:
"""
Supports incremental document ingestion without full graph rebuild.
New entities merge with existing by embedding similarity.
Community re-detection scoped to affected neighborhoods only.
"""
def __init__(self, merge_threshold: float = 0.85):
self.merge_threshold = merge_threshold
def find_merge_candidates(
self,
new_entity: Dict[str, Any], # {name, type, embedding}
existing_entities: List[Dict[str, Any]],
similarity_fn=None,
) -> Optional[str]:
"""
Find existing entity to merge with (deduplication).
Returns existing entity_id if merge candidate found, else None.
"""
if not similarity_fn:
from .graph_layer import cosine_similarity
similarity_fn = cosine_similarity
new_emb = new_entity.get("embedding", [])
if not new_emb:
return None
best_sim = 0.0
best_id = None
for existing in existing_entities:
existing_emb = existing.get("embedding", [])
if not existing_emb:
continue
sim = similarity_fn(new_emb, existing_emb)
if sim > best_sim:
best_sim = sim
best_id = existing.get("entity_id")
if best_sim >= self.merge_threshold and best_id:
logger.info(f"Merge: '{new_entity.get('name')}' β†’ existing '{best_id}' (sim={best_sim:.3f})")
return best_id
return None
def compute_affected_scope(
self,
new_entity_ids: Set[str],
adjacency: Dict[str, List[str]],
scope_hops: int = 2,
) -> Set[str]:
"""
Find entities affected by new additions (for scoped community re-detection).
Returns set of entity_ids within scope_hops of new entities.
"""
affected = set(new_entity_ids)
frontier = set(new_entity_ids)
for _ in range(scope_hops):
next_frontier: Set[str] = set()
for node in frontier:
for neighbor in adjacency.get(node, []):
if neighbor not in affected:
affected.add(neighbor)
next_frontier.add(neighbor)
frontier = next_frontier
return affected
def plan_update(
self,
new_entities: List[Dict[str, Any]],
new_relations: List[Dict[str, Any]],
existing_entity_count: int,
) -> Dict[str, Any]:
"""
Plan the incremental update (for logging/dashboard display).
Returns update plan with estimated savings.
"""
return {
"new_entities": len(new_entities),
"new_relations": len(new_relations),
"existing_entities": existing_entity_count,
"merge_candidates_to_check": min(len(new_entities) * 10, existing_entity_count),
"community_redetection_scope": "affected_subgraph_only",
"estimated_llm_calls_saved": max(0, existing_entity_count - len(new_entities) * 2),
"vs_full_rebuild_savings_pct": round(
(1 - len(new_entities) / max(existing_entity_count, 1)) * 100, 1
) if existing_entity_count > 0 else 0,
}
# ═══════════════════════════════════════════════════════════
# NOVELTY ORCHESTRATOR β€” Combines all 6 techniques
# ═══════════════════════════════════════════════════════════
class NoveltyEngine:
"""
Orchestrates all 6 novelty techniques into a single pipeline.
Used by the main orchestration layer to enhance GraphRAG retrieval.
"""
def __init__(self, token_budget: int = 2000):
self.ppr = PPRConfidenceScorer()
self.budget = TokenBudgetController(max_tokens=token_budget)
self.paths = PathPruner()
self.activation = SpreadingActivation()
self.router = PolyGRouter()
self.updater = IncrementalGraphUpdater()
def enhanced_retrieve(
self,
query: str,
adjacency: Dict[str, List[Tuple[str, float]]],
seed_entities: List[str],
entity_to_chunks: Dict[str, List[str]],
chunk_texts: Dict[str, str],
seed_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, Any]:
"""
Full novelty-enhanced retrieval pipeline:
1. Route query β†’ determine strategy
2. PPR scoring from seeds
3. Spreading activation for expanded context
4. Token budget pruning
5. Return ranked, pruned context with metadata
"""
# Step 1: Route
routing = self.router.classify_query(query)
# Step 2: PPR
ppr_scores = self.ppr.compute_ppr(adjacency, seed_entities, seed_weights)
# Step 3: Spreading Activation
seed_acts = {s: 1.0 for s in seed_entities}
activations = self.activation.activate(adjacency, seed_acts)
# Step 4: Combined scoring (PPR + activation)
combined_chunks: Dict[str, float] = defaultdict(float)
for entity_id, chunk_ids in entity_to_chunks.items():
ppr_s = ppr_scores.get(entity_id, 0.0)
act_s = activations.get(entity_id, 0.0)
combined = 0.6 * ppr_s + 0.4 * act_s # weighted blend
for cid in chunk_ids:
combined_chunks[cid] = max(combined_chunks[cid], combined)
# Step 5: Token budget pruning
scored_items = [
(chunk_texts.get(cid, ""), score)
for cid, score in combined_chunks.items()
if cid in chunk_texts
]
selected_texts, budget_stats = self.budget.prune_context(scored_items)
return {
"contexts": selected_texts,
"routing": routing,
"budget_stats": budget_stats,
"ppr_top_entities": sorted(ppr_scores.items(), key=lambda x: x[1], reverse=True)[:10],
"activation_spread": len([v for v in activations.values() if v > 0.01]),
"technique_chain": [
f"PolyG Router β†’ {routing['strategy']}",
f"PPR Scoring (damping={self.ppr.damping})",
f"Spreading Activation (decay={self.activation.decay})",
f"Token Budget ({budget_stats['used_tokens']}/{budget_stats['budget_tokens']} tokens)",
f"Reduction: {budget_stats['reduction_pct']}%",
],
}