"""
AETHER Knowledge Graph Engine.

Integrates PyTorch Geometric patterns for relational reasoning:
- RGCN for node classification on knowledge graphs
- ComplEx for link prediction
- Neuro-symbolic bridge: learned attention over symbolic rules
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Any, Optional, Tuple
import networkx as nx
import numpy as np
import logging

logger = logging.getLogger("AETHER.Knowledge")


def _copy_overlapping_state(src: nn.Module, dst: nn.Module) -> None:
    """Copy ``src``'s state into ``dst`` wherever the two overlap.

    Tensors with the same name and the same trailing shape have their common
    leading slice copied (so an embedding table that gained rows keeps its old
    rows). Names missing from ``src`` or with incompatible trailing shapes are
    left at ``dst``'s freshly initialized values.
    """
    src_state = src.state_dict()
    with torch.no_grad():
        # state_dict() tensors share storage with the live parameters, so
        # in-place copies below update ``dst`` directly.
        for name, dst_tensor in dst.state_dict().items():
            src_tensor = src_state.get(name)
            if src_tensor is None or src_tensor.dim() != dst_tensor.dim():
                continue
            if src_tensor.dim() == 0:
                dst_tensor.copy_(src_tensor)
                continue
            if src_tensor.shape[1:] != dst_tensor.shape[1:]:
                continue
            rows = min(src_tensor.shape[0], dst_tensor.shape[0])
            dst_tensor[:rows].copy_(src_tensor[:rows])


class RGCNLayer(nn.Module):
    """Simplified RGCN layer for knowledge graph reasoning.

    Per-relation weights are built by basis decomposition:
    ``W_r = sum_b comp[r, b] * bases[b]``. Messages are summed (not
    degree-normalized) into each target node, then a self-loop transform and a
    bias are added.
    """

    def __init__(self, in_dim: int, out_dim: int, num_relations: int, num_bases: int = 4):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.bases = nn.Parameter(torch.Tensor(num_bases, in_dim, out_dim))
        self.comp = nn.Parameter(torch.Tensor(num_relations, num_bases))
        self.self_loop = nn.Parameter(torch.Tensor(in_dim, out_dim))
        self.bias = nn.Parameter(torch.Tensor(out_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize the weights and zero the bias."""
        nn.init.xavier_uniform_(self.bases)
        nn.init.xavier_uniform_(self.comp)
        nn.init.xavier_uniform_(self.self_loop)
        nn.init.zeros_(self.bias)

    def forward(self, x: Optional[torch.Tensor], edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Propagate node features along typed edges.

        Args:
            x: Node features of shape ``(num_nodes, in_dim)``. If ``None``,
                one-hot identity features are used and the node count is
                inferred from ``edge_index``.
            edge_index: ``(2, num_edges)`` tensor of (source, target) ids.
            edge_type: ``(num_edges,)`` relation id per edge. Ids outside
                ``[0, num_relations)`` contribute no message.

        Returns:
            ``(num_nodes, out_dim)`` tensor of updated node features.
        """
        if x is None:
            num_nodes = int(edge_index.max().item()) + 1
            x = torch.eye(num_nodes, self.in_dim, device=edge_index.device,
                          dtype=self.bases.dtype)
        else:
            num_nodes = x.size(0)

        weight = torch.einsum('rb,bio->rio', self.comp, self.bases)
        # Match x's dtype/device so index_add_ works for half/double models too.
        out = x.new_zeros(num_nodes, self.out_dim)
        for rel_id in range(self.num_relations):
            mask = edge_type == rel_id
            if not bool(mask.any()):
                continue
            source, target = edge_index[:, mask]
            out.index_add_(0, target, torch.mm(x[source], weight[rel_id]))

        return out + torch.mm(x, self.self_loop) + self.bias


class KnowledgeGraphEncoder(nn.Module):
    """Multi-layer RGCN encoder for knowledge graph embeddings.

    Learned per-node embeddings are passed through ``num_layers`` RGCN layers,
    each followed by LayerNorm and ReLU.
    """

    def __init__(self, num_nodes: int, hidden_dim: int, num_relations: int,
                 num_layers: int = 2, num_bases: int = 4):
        super().__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.num_relations = num_relations
        self.node_embeddings = nn.Embedding(num_nodes, hidden_dim)
        self.layers = nn.ModuleList([
            RGCNLayer(
                in_dim=hidden_dim,
                out_dim=hidden_dim,
                num_relations=num_relations,
                num_bases=num_bases,
            )
            for _ in range(num_layers)
        ])
        self.norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])

    def forward(self, edge_index: torch.Tensor, edge_type: torch.Tensor,
                num_nodes: Optional[int] = None) -> torch.Tensor:
        """Encode the graph and return ``(num_nodes, hidden_dim)`` embeddings.

        Args:
            edge_index: ``(2, num_edges)`` tensor of (source, target) ids.
            edge_type: ``(num_edges,)`` relation id per edge.
            num_nodes: Number of nodes to embed. Defaults to
                ``edge_index.max() + 1``. Pass it explicitly to include nodes
                whose ids exceed every edge endpoint. Must not exceed the
                embedding table size.
        """
        if num_nodes is None:
            num_nodes = int(edge_index.max().item()) + 1
        x = self.node_embeddings(torch.arange(num_nodes, device=edge_index.device))
        for layer, norm in zip(self.layers, self.norms):
            x = F.relu(norm(layer(x, edge_index, edge_type)))
        return x


class ComplExScorer(nn.Module):
    """ComplEx link prediction scorer for knowledge graph completion.

    Uses separate complex embeddings for the head and tail roles of each node.
    """

    def __init__(self, num_nodes: int, num_relations: int, hidden_dim: int = 50):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.hidden_dim = hidden_dim
        self.head_real = nn.Embedding(num_nodes, hidden_dim)
        self.head_imag = nn.Embedding(num_nodes, hidden_dim)
        self.tail_real = nn.Embedding(num_nodes, hidden_dim)
        self.tail_imag = nn.Embedding(num_nodes, hidden_dim)
        self.rel_real = nn.Embedding(num_relations, hidden_dim)
        self.rel_imag = nn.Embedding(num_relations, hidden_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize every embedding table."""
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(self, head_idx: torch.Tensor, rel_idx: torch.Tensor,
                tail_idx: torch.Tensor) -> torch.Tensor:
        """Return ``Re(<h, r, conj(t)>)`` for each (head, relation, tail) triple.

        The inputs are broadcast-compatible index tensors. The score is summed
        over the last (embedding) dimension.
        """
        hr = self.head_real(head_idx)
        hi = self.head_imag(head_idx)
        tr = self.tail_real(tail_idx)
        ti = self.tail_imag(tail_idx)
        rr = self.rel_real(rel_idx)
        ri = self.rel_imag(rel_idx)
        return torch.sum(
            hr * rr * tr + hr * ri * ti + hi * rr * ti - hi * ri * tr,
            dim=-1,
        )

    def loss(self, head_idx: torch.Tensor, rel_idx: torch.Tensor, tail_idx: torch.Tensor,
             negative_head: Optional[torch.Tensor] = None,
             negative_tail: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Logistic (softplus) loss over positives and one negative per positive.

        ``negative_head`` takes precedence over ``negative_tail``. If neither
        is given, tails are corrupted uniformly at random.
        """
        pos_score = self.forward(head_idx, rel_idx, tail_idx)
        if negative_head is not None:
            neg_score = self.forward(negative_head, rel_idx, tail_idx)
        elif negative_tail is not None:
            neg_score = self.forward(head_idx, rel_idx, negative_tail)
        else:
            neg_tail = torch.randint(0, self.num_nodes, tail_idx.size(),
                                     device=tail_idx.device)
            neg_score = self.forward(head_idx, rel_idx, neg_tail)
        pos_loss = F.softplus(-pos_score)
        neg_loss = F.softplus(neg_score)
        return (pos_loss + neg_loss).mean()


class KnowledgeGraphEngine(nn.Module):
    """
    Unified knowledge graph engine combining:
    - NetworkX for graph construction and symbolic reasoning
    - RGCN for learned embeddings
    - ComplEx for link prediction
    - Neuro-symbolic bridge for AETHER integration

    NOTE: ``self.graph`` is a ``DiGraph``, so it holds at most one edge (one
    relation) per ordered node pair. A later fact on the same pair replaces
    the earlier one.
    """

    def __init__(self, embedding_dim: int = 128, num_relations: int = 20, max_nodes: int = 10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_relations = num_relations
        self.max_nodes = max_nodes
        self.graph = nx.DiGraph()
        self.node_id_map: Dict[str, int] = {}
        self.relation_map: Dict[str, int] = {}
        self.next_node_id = 0
        self.next_rel_id = 0
        self.encoder: Optional[KnowledgeGraphEncoder] = None
        self.scorer: Optional[ComplExScorer] = None
        self.symbolic_attention = nn.Parameter(torch.ones(num_relations))
        # Each rule is a (premise, conclusion) pair of (head, relation, tail) triples.
        self.rules: List[Tuple[Tuple[str, str, str], Tuple[str, str, str]]] = []

    def _get_or_create_node(self, node_name: str) -> int:
        """Return the integer id for ``node_name``, registering it if new."""
        if node_name not in self.node_id_map:
            self.node_id_map[node_name] = self.next_node_id
            self.graph.add_node(self.next_node_id, name=node_name)
            self.next_node_id += 1
        return self.node_id_map[node_name]

    def _get_or_create_relation(self, relation: str) -> int:
        """Return the integer id for ``relation``, registering it if new."""
        if relation not in self.relation_map:
            self.relation_map[relation] = self.next_rel_id
            self.next_rel_id += 1
        return self.relation_map[relation]

    def add_fact(self, head: str, relation: str, tail: str, confidence: float = 1.0):
        """Insert the triple (head, relation, tail) and grow the models if needed."""
        h_id = self._get_or_create_node(head)
        t_id = self._get_or_create_node(tail)
        r_id = self._get_or_create_relation(relation)
        if self.graph.has_edge(h_id, t_id):
            previous = self.graph.edges[h_id, t_id].get('name')
            if previous != relation:
                # DiGraph keeps one edge per pair; make the data loss visible.
                logger.warning(
                    "Fact (%s, %s, %s) replaces existing relation %r on the same node pair",
                    head, relation, tail, previous,
                )
        self.graph.add_edge(h_id, t_id, relation=r_id, name=relation, confidence=confidence)
        self._ensure_model_capacity()

    def add_rule(self, premise: Tuple[str, str, str], conclusion: Tuple[str, str, str]):
        """Register a rule: if ``premise`` holds, ``conclusion`` is inferred.

        A ``"?"`` conclusion head is replaced by the query head at reasoning time.
        """
        self.rules.append((premise, conclusion))

    def _ensure_model_capacity(self):
        """Build or grow the encoder/scorer so every known node/relation fits.

        Node capacity is capped at ``max_nodes``. Relation capacity is at least
        ``num_relations``. Capacity grows geometrically, so a stream of
        ``add_fact`` calls triggers only a logarithmic number of rebuilds.
        Learned rows are copied into the resized models. Optimizers holding the
        old parameters must be recreated after a resize.
        """
        if self.next_node_id == 0:
            return
        needed_nodes = min(self.next_node_id, self.max_nodes)
        needed_rels = max(self.next_rel_id, self.num_relations)

        if self.scorer is None:
            num_nodes, num_rels = needed_nodes, needed_rels
        else:
            cur_nodes = self.scorer.num_nodes
            cur_rels = self.scorer.num_relations
            if cur_nodes >= needed_nodes and cur_rels >= needed_rels:
                return
            num_nodes = (min(self.max_nodes, max(needed_nodes, 2 * cur_nodes))
                         if needed_nodes > cur_nodes else cur_nodes)
            num_rels = max(needed_rels, 2 * cur_rels) if needed_rels > cur_rels else cur_rels

        device = self.symbolic_attention.device
        encoder = KnowledgeGraphEncoder(
            num_nodes=num_nodes,
            hidden_dim=self.embedding_dim,
            num_relations=num_rels,
            num_layers=2,
        ).to(device)
        scorer = ComplExScorer(
            num_nodes=num_nodes,
            num_relations=num_rels,
            hidden_dim=self.embedding_dim // 2,
        ).to(device)
        encoder.train(self.training)
        scorer.train(self.training)

        resized = self.scorer is not None
        if self.encoder is not None:
            _copy_overlapping_state(self.encoder, encoder)
        if self.scorer is not None:
            _copy_overlapping_state(self.scorer, scorer)
        self.encoder = encoder
        self.scorer = scorer
        logger.info("%s KG models: %d nodes, %d relations",
                    "Resized" if resized else "Initialized", num_nodes, num_rels)

    def reason_symbolic(self, query_head: str, query_relation: str) -> List[Dict]:
        """Answer (query_head, query_relation, ?) from the explicit graph.

        Returns results sorted by descending confidence, drawn from:
        - direct edges with the queried relation (confidence = stored value);
        - rules whose premise head is ``query_head`` and whose premise holds
          (confidence 0.8);
        - every simple path of up to two hops from the head, regardless of
          relation (confidence ``0.6 ** hops``).
        """
        results = []
        if query_head not in self.node_id_map:
            return results
        h_id = self.node_id_map[query_head]
        r_name = query_relation

        if r_name in self.relation_map:
            r_id = self.relation_map[r_name]
            for _, target, data in self.graph.out_edges(h_id, data=True):
                if data.get('relation') == r_id:
                    results.append({
                        "head": query_head,
                        "relation": r_name,
                        "tail": self.graph.nodes[target].get('name', str(target)),
                        "confidence": data.get('confidence', 1.0),
                        "path": "direct",
                    })

        for premise, conclusion in self.rules:
            p_head, _p_rel, _p_tail = premise
            c_head, c_rel, c_tail = conclusion
            if p_head == query_head and self._check_fact(premise):
                results.append({
                    "head": query_head if c_head == "?" else c_head,
                    "relation": c_rel,
                    "tail": c_tail,
                    "confidence": 0.8,
                    "path": "inferred",
                    "rule": f"{premise} -> {conclusion}",
                })

        for neighbor in nx.bfs_tree(self.graph, h_id, depth_limit=2).nodes():
            if neighbor == h_id:
                continue
            for path in nx.all_simple_paths(self.graph, h_id, neighbor, cutoff=2):
                if len(path) > 1:
                    # Label the hop by the relation of its first edge.
                    edge_data = self.graph.edges[path[0], path[1]]
                    results.append({
                        "head": query_head,
                        "relation": f"multi-hop via {edge_data.get('name', 'unknown')}",
                        "tail": self.graph.nodes[neighbor].get('name', str(neighbor)),
                        "confidence": 0.6 ** (len(path) - 1),
                        "path": "->".join(str(n) for n in path),
                    })

        return sorted(results, key=lambda x: x["confidence"], reverse=True)

    def _check_fact(self, fact: Tuple[str, str, str]) -> bool:
        """Return True if the exact (head, relation, tail) triple is stored."""
        h, r, t = fact
        if h not in self.node_id_map or t not in self.node_id_map:
            return False
        h_id = self.node_id_map[h]
        t_id = self.node_id_map[t]
        if r not in self.relation_map:
            return False
        r_id = self.relation_map[r]
        return self.graph.has_edge(h_id, t_id) and \
            self.graph.edges[h_id, t_id].get('relation') == r_id

    def reason_learned(self, query_head: str, query_relation: str, top_k: int = 5) -> List[Dict]:
        """Rank every known node as a tail for (query_head, query_relation) via ComplEx.

        Returns up to ``top_k`` results with ``confidence = sigmoid(score)``.
        Returns an empty list if the models are not built, the head or relation
        is unknown, or either lies beyond the scorer's capacity (e.g. nodes past
        ``max_nodes``).
        """
        if self.scorer is None or query_head not in self.node_id_map:
            return []
        r_id = self.relation_map.get(query_relation)
        if r_id is None:
            return []
        h_id = self.node_id_map[query_head]
        if h_id >= self.scorer.num_nodes or r_id >= self.scorer.num_relations:
            return []

        # Only rows backed by a real graph node are valid candidates.
        num_candidates = min(self.next_node_id, self.scorer.num_nodes)
        k = min(max(top_k, 0), num_candidates)
        if k == 0:
            return []

        device = self.scorer.head_real.weight.device
        batch_size = 1000
        chunks = []
        with torch.no_grad():
            for start in range(0, num_candidates, batch_size):
                tails = torch.arange(start, min(start + batch_size, num_candidates), device=device)
                heads = torch.full_like(tails, h_id)
                rels = torch.full_like(tails, r_id)
                chunks.append(self.scorer(heads, rels, tails))
            scores = torch.cat(chunks)
            top_scores, top_indices = torch.topk(scores, k)
            confidences = torch.sigmoid(top_scores).tolist()

        results = []
        for idx, confidence in zip(top_indices.tolist(), confidences):
            results.append({
                "head": query_head,
                "relation": query_relation,
                "tail": self.graph.nodes[idx].get('name', str(idx)),
                "confidence": confidence,
                "path": "learned",
            })
        return results

    def query(self, text_query: str, top_k: int = 5) -> Dict[str, Any]:
        """Parse ``"<head> <relation words...>"`` and fuse symbolic and learned answers.

        The query is lowercased. The head becomes the first word, capitalized;
        the relation is the remaining words joined by spaces. A single-word
        query uses the relation ``"related_to"``. Each source carries its
        fusion weight, but the merged list is ranked by raw confidence.
        """
        parts = text_query.lower().split()
        if len(parts) >= 2:
            head = parts[0].capitalize()
            relation = " ".join(parts[1:])
        else:
            head = text_query.capitalize()
            relation = "related_to"

        symbolic_results = self.reason_symbolic(head, relation)
        learned_results = self.reason_learned(head, relation, top_k)

        rel_id = self.relation_map.get(relation, 0)
        symbolic_weight = torch.sigmoid(self.symbolic_attention[rel_id % self.num_relations])
        learned_weight = 1.0 - symbolic_weight.item()

        all_results = []
        for r in symbolic_results[:top_k]:
            r["source"] = "symbolic"
            r["fusion_weight"] = symbolic_weight.item()
            all_results.append(r)
        for r in learned_results[:top_k]:
            r["source"] = "learned"
            r["fusion_weight"] = learned_weight
            all_results.append(r)

        all_results.sort(key=lambda x: x.get("confidence", 0), reverse=True)

        return {
            "query": text_query,
            "results": all_results[:top_k],
            "symbolic_weight": symbolic_weight.item(),
            "learned_weight": learned_weight,
            "num_symbolic": len(symbolic_results),
            "num_learned": len(learned_results),
        }

    def to_pyg_data(self) -> Dict[str, Any]:
        """Export edges as PyG-style tensors.

        Returns ``edge_index`` ``(2, E)`` and ``edge_type`` ``(E,)`` tensors,
        plus the integer ``num_nodes`` and ``num_relations``. Returns an empty
        dict if the graph has no edges.
        """
        edges = []
        edge_types = []
        for u, v, data in self.graph.edges(data=True):
            edges.append([u, v])
            edge_types.append(data.get('relation', 0))
        if not edges:
            return {}
        edge_index = torch.tensor(edges, dtype=torch.long).t()
        edge_type = torch.tensor(edge_types, dtype=torch.long)
        return {
            "edge_index": edge_index,
            "edge_type": edge_type,
            "num_nodes": self.next_node_id,
            "num_relations": self.next_rel_id,
        }

    def stats(self) -> Dict[str, Any]:
        """Return simple size counters for the graph and rule base."""
        return {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_relations": len(self.relation_map),
            "num_rules": len(self.rules),
            "node_names": len(self.node_id_map),
        }

    def export(self) -> Dict[str, Any]:
        """Return a plain-Python snapshot of nodes, edges, id maps and rules."""
        edges = []
        for u, v, data in self.graph.edges(data=True):
            edges.append({
                "source": u,
                "target": v,
                "relation_id": data.get('relation'),
                "relation_name": data.get('name'),
                "confidence": data.get('confidence'),
            })
        return {
            "nodes": {n: self.graph.nodes[n].get('name', str(n)) for n in self.graph.nodes()},
            "edges": edges,
            "node_id_map": self.node_id_map,
            "relation_map": self.relation_map,
            "rules": self.rules,
        }