| """ |
| AETHER Knowledge Graph Engine. |
| Integrates PyTorch Geometric patterns for relational reasoning: |
| - RGCN for node classification on knowledge graphs |
| - ComplEx for link prediction |
| - Neuro-symbolic bridge: learned attention over symbolic rules |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, List, Any, Optional, Tuple |
| import networkx as nx |
| import numpy as np |
| import logging |
|
|
| logger = logging.getLogger("AETHER.Knowledge") |
|
|
|
|
class RGCNLayer(nn.Module):
    """Simplified relational graph convolution layer with basis decomposition.

    Each relation's weight matrix is a learned linear combination
    (``comp``) of ``num_bases`` shared basis matrices (``bases``).  Messages
    from source nodes are summed into their target nodes without any
    normalisation; a self-loop transform and a bias are then added.  No
    activation is applied inside the layer.
    """

    def __init__(self, in_dim: int, out_dim: int, num_relations: int,
                 num_bases: int = 4):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_relations = num_relations
        self.num_bases = num_bases

        # torch.empty makes the "uninitialised until reset_parameters" intent
        # explicit (the legacy torch.Tensor(...) constructor does the same).
        self.bases = nn.Parameter(torch.empty(num_bases, in_dim, out_dim))
        self.comp = nn.Parameter(torch.empty(num_relations, num_bases))
        self.self_loop = nn.Parameter(torch.empty(in_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(out_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the weight tensors and zero the bias."""
        nn.init.xavier_uniform_(self.bases)
        nn.init.xavier_uniform_(self.comp)
        nn.init.xavier_uniform_(self.self_loop)
        nn.init.zeros_(self.bias)

    def forward(self, x: Optional[torch.Tensor], edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Propagate node features along typed edges.

        Args:
            x: Node features, one row per node with ``in_dim`` columns, or
                ``None`` to use a rectangular identity matrix as features
                (the node count is then ``edge_index.max() + 1``).
            edge_index: Integer tensor whose row 0 holds source node ids and
                row 1 holds target node ids.
            edge_type: Relation id of each edge, aligned with the columns of
                ``edge_index``.

        Returns:
            Tensor with one row per node and ``out_dim`` columns.
        """
        if x is None:
            # Featureless mode.  An empty edge list yields zero nodes instead
            # of crashing on max() of an empty tensor.
            num_nodes = (int(edge_index.max().item()) + 1
                         if edge_index.numel() > 0 else 0)
            # Build the identity in the parameters' dtype so the layer keeps
            # working after .double() / .half().
            x = torch.eye(num_nodes, self.in_dim, device=edge_index.device,
                          dtype=self.bases.dtype)
        else:
            num_nodes = x.size(0)

        # Per-relation weights: weight[r] = sum_b comp[r, b] * bases[b].
        weight = torch.einsum('rb, bio -> rio', self.comp, self.bases)

        # The accumulator must share the messages' dtype: index_add_ does not
        # type-promote, so a float32 default breaks non-float32 modules.
        out = torch.zeros(num_nodes, self.out_dim, device=x.device,
                          dtype=weight.dtype)

        for rel_id in range(self.num_relations):
            mask = edge_type == rel_id
            if not mask.any():
                continue

            rel_edges = edge_index[:, mask]
            source = rel_edges[0]
            target = rel_edges[1]

            messages = torch.mm(x[source], weight[rel_id])
            # Sum (not mean) of incoming messages per target node.
            out.index_add_(0, target, messages)

        out = out + torch.mm(x, self.self_loop)
        out = out + self.bias
        return out
|
|
|
|
class KnowledgeGraphEncoder(nn.Module):
    """Stack of RGCN layers applied on top of a learned node-embedding table.

    Every layer maps ``hidden_dim`` to ``hidden_dim`` and is followed by a
    LayerNorm and a ReLU.
    """

    def __init__(self, num_nodes: int, hidden_dim: int, num_relations: int,
                 num_layers: int = 2, num_bases: int = 4):
        super().__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.num_relations = num_relations

        self.node_embeddings = nn.Embedding(num_nodes, hidden_dim)

        # All layers share the same width, so input and output sizes match.
        convolutions = []
        for _ in range(num_layers):
            convolutions.append(
                RGCNLayer(
                    in_dim=hidden_dim,
                    out_dim=hidden_dim,
                    num_relations=num_relations,
                    num_bases=num_bases,
                )
            )
        self.layers = nn.ModuleList(convolutions)

        layer_norms = []
        for _ in range(num_layers):
            layer_norms.append(nn.LayerNorm(hidden_dim))
        self.norms = nn.ModuleList(layer_norms)

    def forward(self, edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Return embeddings for node ids ``0 .. edge_index.max()``.

        The number of rows is derived from the largest node id appearing in
        ``edge_index``, not from ``self.num_nodes``.
        """
        highest_id = int(edge_index.max().item())
        node_ids = torch.arange(highest_id + 1, device=edge_index.device)
        hidden = self.node_embeddings(node_ids)

        for conv, layer_norm in zip(self.layers, self.norms):
            hidden = F.relu(layer_norm(conv(hidden, edge_index, edge_type)))

        return hidden
|
|
|
|
class ComplExScorer(nn.Module):
    """ComplEx-style link prediction scorer for knowledge graph completion.

    Entities have separate embedding tables for the head role and the tail
    role, each with a real and an imaginary part; relations have one real
    and one imaginary table.  A triple is scored as
    ``Re(<h, r, conj(t)>)`` summed over the embedding dimension.
    """

    def __init__(self, num_nodes: int, num_relations: int, hidden_dim: int = 50):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.hidden_dim = hidden_dim

        self.head_real = nn.Embedding(num_nodes, hidden_dim)
        self.head_imag = nn.Embedding(num_nodes, hidden_dim)
        self.tail_real = nn.Embedding(num_nodes, hidden_dim)
        self.tail_imag = nn.Embedding(num_nodes, hidden_dim)

        self.rel_real = nn.Embedding(num_relations, hidden_dim)
        self.rel_imag = nn.Embedding(num_relations, hidden_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise every embedding table."""
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(self, head_idx: torch.Tensor, rel_idx: torch.Tensor,
                tail_idx: torch.Tensor) -> torch.Tensor:
        """Score triples; higher means more plausible.

        Args:
            head_idx: Integer ids of head entities.
            rel_idx: Integer ids of relations.
            tail_idx: Integer ids of tail entities.

        Returns:
            One raw (unbounded) score per triple.
        """
        hr = self.head_real(head_idx)
        hi = self.head_imag(head_idx)
        tr = self.tail_real(tail_idx)
        ti = self.tail_imag(tail_idx)
        rr = self.rel_real(rel_idx)
        ri = self.rel_imag(rel_idx)

        # Real part of (hr + i*hi) * (rr + i*ri) * (tr - i*ti).
        score = torch.sum(
            hr * rr * tr + hr * ri * ti + hi * rr * ti - hi * ri * tr,
            dim=-1
        )
        return score

    def loss(self, head_idx: torch.Tensor, rel_idx: torch.Tensor,
             tail_idx: torch.Tensor,
             negative_head: Optional[torch.Tensor] = None,
             negative_tail: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Softplus (logistic) loss over positive and corrupted triples.

        Args:
            head_idx, rel_idx, tail_idx: The positive triples.
            negative_head: Optional replacement heads for the positives.
            negative_tail: Optional replacement tails for the positives.
                When both negative sets are supplied, both are used and their
                loss terms are averaged.  When neither is supplied, tails are
                corrupted with uniformly random entity ids (which may
                occasionally coincide with a true tail).

        Returns:
            Scalar loss: mean over triples of positive plus negative terms.
        """
        pos_score = self.forward(head_idx, rel_idx, tail_idx)

        # Collect every corruption the caller asked for; previously
        # negative_tail was silently dropped whenever negative_head was given.
        neg_scores = []
        if negative_head is not None:
            neg_scores.append(self.forward(negative_head, rel_idx, tail_idx))
        if negative_tail is not None:
            neg_scores.append(self.forward(head_idx, rel_idx, negative_tail))
        if not neg_scores:
            neg_tail = torch.randint(0, self.num_nodes, tail_idx.size(),
                                     device=tail_idx.device)
            neg_scores.append(self.forward(head_idx, rel_idx, neg_tail))

        pos_loss = F.softplus(-pos_score)
        # With a single negative set this reduces exactly to softplus(neg).
        neg_loss = sum(F.softplus(s) for s in neg_scores) / len(neg_scores)

        return (pos_loss + neg_loss).mean()
|
|
|
|
class KnowledgeGraphEngine(nn.Module):
    """
    Unified knowledge graph engine combining:
    - NetworkX for graph construction and symbolic reasoning
    - RGCN for learned embeddings
    - ComplEx for link prediction
    - Neuro-symbolic bridge for AETHER integration

    Facts are stored in a ``networkx.DiGraph`` keyed by integer node ids, so
    at most one edge (one relation) is kept per ordered node pair: adding a
    second fact between the same head and tail overwrites the first.
    """

    def __init__(self, embedding_dim: int = 128, num_relations: int = 20,
                 max_nodes: int = 10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_relations = num_relations
        self.max_nodes = max_nodes

        self.graph = nx.DiGraph()
        self.node_id_map: Dict[str, int] = {}
        self.relation_map: Dict[str, int] = {}
        self.next_node_id = 0
        self.next_rel_id = 0

        # Created lazily by _ensure_model_capacity once the first fact exists.
        self.encoder: Optional[KnowledgeGraphEncoder] = None
        self.scorer: Optional[ComplExScorer] = None

        # One logit per relation slot; sigmoid gives the symbolic weight.
        self.symbolic_attention = nn.Parameter(torch.ones(num_relations))
        # Each rule is a (premise, conclusion) pair of (head, rel, tail) triples.
        self.rules: List[Tuple[Tuple[str, str, str], Tuple[str, str, str]]] = []

    def _get_or_create_node(self, node_name: str) -> int:
        """Return the integer id for ``node_name``, registering it if new."""
        if node_name not in self.node_id_map:
            self.node_id_map[node_name] = self.next_node_id
            self.graph.add_node(self.next_node_id, name=node_name)
            self.next_node_id += 1
        return self.node_id_map[node_name]

    def _get_or_create_relation(self, relation: str) -> int:
        """Return the integer id for ``relation``, registering it if new."""
        if relation not in self.relation_map:
            self.relation_map[relation] = self.next_rel_id
            self.next_rel_id += 1
        return self.relation_map[relation]

    def add_fact(self, head: str, relation: str, tail: str,
                 confidence: float = 1.0):
        """Store the triple ``(head, relation, tail)`` as a directed edge.

        Unknown nodes and relations are registered on the fly, and the
        learned models are created or grown to cover them.
        """
        h_id = self._get_or_create_node(head)
        t_id = self._get_or_create_node(tail)
        r_id = self._get_or_create_relation(relation)

        self.graph.add_edge(h_id, t_id, relation=r_id, name=relation,
                            confidence=confidence)
        self._ensure_model_capacity()

    def add_rule(self, premise: Tuple[str, str, str],
                 conclusion: Tuple[str, str, str]):
        """Register an inference rule ``premise -> conclusion``.

        A conclusion head of ``"?"`` is replaced by the query head when the
        rule fires in ``reason_symbolic``.
        """
        self.rules.append((premise, conclusion))

    @staticmethod
    def _inherit_weights(old: nn.Module, new: nn.Module) -> None:
        """Move ``new`` to ``old``'s device/dtype and copy overlapping weights."""
        reference = next(old.parameters(), None)
        if reference is not None:
            new.to(device=reference.device, dtype=reference.dtype)

        old_state = old.state_dict()
        with torch.no_grad():
            # state_dict tensors share storage with the parameters, so the
            # in-place copy below updates the new module directly.
            for key, fresh in new.state_dict().items():
                previous = old_state.get(key)
                if previous is None or previous.dim() != fresh.dim():
                    continue
                overlap = tuple(
                    slice(0, min(a, b))
                    for a, b in zip(previous.shape, fresh.shape)
                )
                fresh[overlap].copy_(previous[overlap])

    def _ensure_model_capacity(self):
        """Create the learned models, or grow them when the graph outgrows them.

        Node capacity is capped at ``max_nodes``; nodes beyond the cap are
        only reachable through symbolic reasoning.  Capacity grows
        geometrically so repeated ``add_fact`` calls do not rebuild the
        models every time.  Growing replaces ``self.encoder`` and
        ``self.scorer`` with new modules (existing weights are copied over),
        so any optimizer holding the old parameters must be recreated.
        """
        if self.next_node_id <= 0:
            return

        needed_nodes = min(self.next_node_id, self.max_nodes)
        needed_rels = max(self.next_rel_id, self.num_relations)

        old_encoder, old_scorer = self.encoder, self.scorer
        first_build = old_encoder is None or old_scorer is None

        if first_build:
            num_nodes, num_rels = needed_nodes, needed_rels
        else:
            have_nodes = min(old_encoder.num_nodes, old_scorer.num_nodes)
            have_rels = min(old_encoder.num_relations, old_scorer.num_relations)
            if needed_nodes <= have_nodes and needed_rels <= have_rels:
                return
            num_nodes = have_nodes
            if needed_nodes > have_nodes:
                num_nodes = min(max(needed_nodes, 2 * have_nodes), self.max_nodes)
            num_rels = have_rels
            if needed_rels > have_rels:
                num_rels = max(needed_rels, 2 * have_rels)

        new_encoder = KnowledgeGraphEncoder(
            num_nodes=num_nodes,
            hidden_dim=self.embedding_dim,
            num_relations=num_rels,
            num_layers=2,
        )
        new_scorer = ComplExScorer(
            num_nodes=num_nodes,
            num_relations=num_rels,
            hidden_dim=self.embedding_dim // 2,
        )

        if old_encoder is not None:
            self._inherit_weights(old_encoder, new_encoder)
        if old_scorer is not None:
            self._inherit_weights(old_scorer, new_scorer)

        self.encoder = new_encoder
        self.scorer = new_scorer

        if first_build:
            logger.info("Initialized KG models: %d nodes, %d relations",
                        num_nodes, num_rels)
        else:
            logger.info("Resized KG models: %d nodes, %d relations",
                        num_nodes, num_rels)

    def reason_symbolic(self, query_head: str, query_relation: str) -> List[Dict]:
        """Answer ``(query_head, query_relation, ?)`` from the stored graph.

        Results combine three sources and are sorted by descending
        confidence:

        * direct edges from the head carrying the queried relation;
        * rules whose premise head equals ``query_head`` and whose premise
          is a stored fact (fixed confidence 0.8);
        * every simple path of one or two edges from the head, reported as
          "multi-hop" with confidence ``0.6 ** number_of_edges``
          regardless of the queried relation.

        Returns an empty list when the head is unknown.
        """
        results: List[Dict] = []

        if query_head not in self.node_id_map:
            return results

        h_id = self.node_id_map[query_head]
        r_name = query_relation

        if r_name in self.relation_map:
            r_id = self.relation_map[r_name]
            for _, target, data in self.graph.out_edges(h_id, data=True):
                if data.get('relation') == r_id:
                    results.append({
                        "head": query_head,
                        "relation": r_name,
                        "tail": self.graph.nodes[target].get('name', str(target)),
                        "confidence": data.get('confidence', 1.0),
                        "path": "direct",
                    })

        for premise, conclusion in self.rules:
            p_head, _p_rel, _p_tail = premise
            c_head, c_rel, c_tail = conclusion

            if p_head == query_head and self._check_fact(premise):
                # "?" in the conclusion head is a placeholder for the query head.
                if c_head == "?":
                    c_head = query_head

                results.append({
                    "head": c_head,
                    "relation": c_rel,
                    "tail": c_tail,
                    "confidence": 0.8,
                    "path": "inferred",
                    "rule": f"{premise} -> {conclusion}",
                })

        for neighbor in nx.bfs_tree(self.graph, h_id, depth_limit=2).nodes():
            if neighbor != h_id:
                for path in nx.all_simple_paths(self.graph, h_id, neighbor, cutoff=2):
                    if len(path) > 1:
                        # Label the path by the relation of its first hop.
                        edge_data = self.graph.edges[path[0], path[1]]
                        results.append({
                            "head": query_head,
                            "relation": f"multi-hop via {edge_data.get('name', 'unknown')}",
                            "tail": self.graph.nodes[neighbor].get('name', str(neighbor)),
                            "confidence": 0.6 ** (len(path) - 1),
                            "path": "->".join(str(n) for n in path),
                        })

        return sorted(results, key=lambda x: x["confidence"], reverse=True)

    def _check_fact(self, fact: Tuple[str, str, str]) -> bool:
        """Return True when ``fact`` is stored as an edge with that relation."""
        h, r, t = fact
        if h not in self.node_id_map or t not in self.node_id_map:
            return False

        if r not in self.relation_map:
            return False

        h_id = self.node_id_map[h]
        t_id = self.node_id_map[t]
        r_id = self.relation_map[r]

        return self.graph.has_edge(h_id, t_id) and \
            self.graph.edges[h_id, t_id].get('relation') == r_id

    def reason_learned(self, query_head: str, query_relation: str,
                       top_k: int = 5) -> List[Dict]:
        """Rank candidate tails for ``(query_head, query_relation, ?)``.

        Every registered node covered by the scorer is scored as a tail and
        the ``top_k`` best are returned with ``sigmoid(score)`` as
        confidence.  Returns an empty list when the scorer does not exist
        yet, when the head or relation is unknown or lies outside the
        scorer's capacity, or when ``top_k`` is not positive.
        """
        if self.scorer is None or query_head not in self.node_id_map:
            return []

        h_id = self.node_id_map[query_head]
        r_id = self.relation_map.get(query_relation)

        if r_id is None:
            return []

        # Only ids that are both registered in the graph and covered by the
        # embedding tables are valid candidates.
        num_candidates = min(self.scorer.num_nodes, self.next_node_id)
        if (h_id >= self.scorer.num_nodes
                or r_id >= self.scorer.num_relations
                or num_candidates <= 0
                or top_k <= 0):
            return []

        # Build index tensors on the scorer's device so this also works
        # after the engine has been moved off the CPU.
        device = next(self.scorer.parameters()).device
        all_tails = torch.arange(num_candidates, device=device)

        batch_size = 1000
        score_chunks = []
        # Inference only: no autograd graph is needed for ranking.
        with torch.no_grad():
            for start in range(0, num_candidates, batch_size):
                batch_tails = all_tails[start:start + batch_size]
                h_batch = torch.full_like(batch_tails, h_id)
                r_batch = torch.full_like(batch_tails, r_id)
                score_chunks.append(self.scorer(h_batch, r_batch, batch_tails))

        scores = torch.cat(score_chunks)
        top_scores, top_indices = torch.topk(scores, min(top_k, num_candidates))

        results = []
        for idx, score in zip(top_indices.tolist(), top_scores):
            node_name = self.graph.nodes[idx].get('name', str(idx))
            results.append({
                "head": query_head,
                "relation": query_relation,
                "tail": node_name,
                "confidence": torch.sigmoid(score).item(),
                "path": "learned",
            })

        return results

    def query(self, text_query: str, top_k: int = 5) -> Dict[str, Any]:
        """Answer a free-text query of the form ``"<head> <relation words>"``.

        The query is lower-cased and split on whitespace; the first token,
        passed through ``str.capitalize``, is the head and the remaining
        tokens joined by spaces are the relation.  A query with fewer than
        two tokens uses the relation ``"related_to"``.  Symbolic and learned
        results are tagged with their source and fusion weight, merged, and
        ranked by their own confidence.
        """
        parts = text_query.lower().split()

        if len(parts) >= 2:
            head = parts[0].capitalize()
            relation = " ".join(parts[1:])
        else:
            head = text_query.capitalize()
            relation = "related_to"

        symbolic_results = self.reason_symbolic(head, relation)
        learned_results = self.reason_learned(head, relation, top_k)

        # Unknown relations fall back to slot 0; the modulo keeps relation
        # ids beyond the attention vector's length in range.
        rel_id = self.relation_map.get(relation, 0)
        symbolic_weight = torch.sigmoid(
            self.symbolic_attention[rel_id % self.num_relations]
        ).item()
        learned_weight = 1.0 - symbolic_weight

        all_results = []

        for r in symbolic_results[:top_k]:
            r["source"] = "symbolic"
            r["fusion_weight"] = symbolic_weight
            all_results.append(r)

        for r in learned_results[:top_k]:
            r["source"] = "learned"
            r["fusion_weight"] = learned_weight
            all_results.append(r)

        all_results.sort(key=lambda x: x.get("confidence", 0), reverse=True)

        return {
            "query": text_query,
            "results": all_results[:top_k],
            "symbolic_weight": symbolic_weight,
            "learned_weight": learned_weight,
            "num_symbolic": len(symbolic_results),
            "num_learned": len(learned_results),
        }

    def to_pyg_data(self) -> Dict[str, Any]:
        """Export the edges as ``edge_index`` / ``edge_type`` tensors.

        Returns an empty dict when the graph has no edges; otherwise a dict
        with the two long tensors plus the integer node and relation counts.
        """
        edges = []
        edge_types = []

        for u, v, data in self.graph.edges(data=True):
            edges.append([u, v])
            edge_types.append(data.get('relation', 0))

        if not edges:
            return {}

        # Transpose so row 0 holds sources and row 1 holds targets.
        edge_index = torch.tensor(edges, dtype=torch.long).t()
        edge_type = torch.tensor(edge_types, dtype=torch.long)

        return {
            "edge_index": edge_index,
            "edge_type": edge_type,
            "num_nodes": self.next_node_id,
            "num_relations": self.next_rel_id,
        }

    def stats(self) -> Dict[str, Any]:
        """Return summary counts for the graph, relations and rules."""
        return {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_relations": len(self.relation_map),
            "num_rules": len(self.rules),
            "node_names": len(self.node_id_map),
        }

    def export(self) -> Dict[str, Any]:
        """Return a plain-Python snapshot of nodes, edges, maps and rules.

        The maps and the rule list are copies, so mutating the snapshot does
        not alter the engine's internal state.
        """
        edges = [
            {
                "source": u,
                "target": v,
                "relation_id": data.get('relation'),
                "relation_name": data.get('name'),
                "confidence": data.get('confidence'),
            }
            for u, v, data in self.graph.edges(data=True)
        ]

        return {
            "nodes": {n: self.graph.nodes[n].get('name', str(n))
                      for n in self.graph.nodes()},
            "edges": edges,
            "node_id_map": dict(self.node_id_map),
            "relation_map": dict(self.relation_map),
            "rules": list(self.rules),
        }
|
|