camdog920 committed on
Commit
db5f09b
·
verified ·
1 Parent(s): 786ed57

Upload aether/knowledge.py

Browse files
Files changed (1) hide show
  1. aether/knowledge.py +444 -0
aether/knowledge.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AETHER Knowledge Graph Engine.
3
+ Integrates PyTorch Geometric patterns for relational reasoning:
4
+ - RGCN for node classification on knowledge graphs
5
+ - ComplEx for link prediction
6
+ - Neuro-symbolic bridge: learned attention over symbolic rules
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Dict, List, Any, Optional, Tuple
13
+ import networkx as nx
14
+ import numpy as np
15
+ import logging
16
+
17
+ logger = logging.getLogger("AETHER.Knowledge")
18
+
19
+
20
class RGCNLayer(nn.Module):
    """Simplified RGCN layer for knowledge graph reasoning.

    Every relation gets its own weight matrix, built as a learned linear
    combination (``comp``) of ``num_bases`` shared basis matrices
    (``bases``). Messages are summed over incoming edges without any
    normalisation, then a self-loop transform and a bias are added.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        num_relations: Number of relation types. Edges whose type lies
            outside ``[0, num_relations)`` contribute no message.
        num_bases: Number of shared basis matrices.
    """

    def __init__(self, in_dim: int, out_dim: int, num_relations: int,
                 num_bases: int = 4):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_relations = num_relations
        self.num_bases = num_bases

        # Allocated uninitialised; reset_parameters() fills every tensor.
        self.bases = nn.Parameter(torch.empty(num_bases, in_dim, out_dim))
        self.comp = nn.Parameter(torch.empty(num_relations, num_bases))
        self.self_loop = nn.Parameter(torch.empty(in_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(out_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the weight tensors and zero the bias."""
        nn.init.xavier_uniform_(self.bases)
        nn.init.xavier_uniform_(self.comp)
        nn.init.xavier_uniform_(self.self_loop)
        nn.init.zeros_(self.bias)

    def forward(self, x: Optional[torch.Tensor], edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Propagate node features along typed edges.

        Args:
            x: Node features of shape ``(num_nodes, in_dim)``. When ``None``,
                a ``torch.eye(num_nodes, in_dim)`` matrix is used, with
                ``num_nodes`` inferred from the largest index in
                ``edge_index`` (zero nodes if there are no edges).
            edge_index: Integer tensor of shape ``(2, num_edges)``; row 0
                holds source ids, row 1 target ids.
            edge_type: Integer tensor of shape ``(num_edges,)`` with the
                relation id of each edge.

        Returns:
            Tensor of shape ``(num_nodes, out_dim)`` with the dtype and
            device of the node features.
        """
        if x is None:
            # max() is undefined on an empty tensor, so treat "no edges" as
            # "no nodes" instead of raising.
            num_nodes = (int(edge_index.max().item()) + 1
                         if edge_index.numel() > 0 else 0)
            # Match the parameter dtype so the matmuls below work after
            # .double() / .half().
            x = torch.eye(num_nodes, self.in_dim, dtype=self.bases.dtype,
                          device=edge_index.device)
        else:
            num_nodes = x.size(0)

        # Per-relation weights: (num_relations, in_dim, out_dim).
        weight = torch.einsum('rb, bio -> rio', self.comp, self.bases)

        # new_zeros inherits dtype and device from x; index_add_ requires the
        # accumulator and the messages to share a dtype.
        out = x.new_zeros(num_nodes, self.out_dim)

        for rel_id in range(self.num_relations):
            mask = edge_type == rel_id
            if not mask.any():
                continue

            source, target = edge_index[:, mask]
            messages = torch.mm(x[source], weight[rel_id])
            out.index_add_(0, target, messages)

        out = out + torch.mm(x, self.self_loop)
        out = out + self.bias
        return out
70
+
71
+
72
class KnowledgeGraphEncoder(nn.Module):
    """Multi-layer RGCN encoder for knowledge graph embeddings.

    Learns one embedding per node id and refines it with ``num_layers``
    RGCN layers, each followed by LayerNorm and ReLU.

    Args:
        num_nodes: Number of rows in the node embedding table.
        hidden_dim: Embedding size, used for the input and output of every
            layer.
        num_relations: Number of relation types handled by each layer.
        num_layers: Number of stacked RGCN layers.
        num_bases: Number of basis matrices per RGCN layer.
    """

    def __init__(self, num_nodes: int, hidden_dim: int, num_relations: int,
                 num_layers: int = 2, num_bases: int = 4):
        super().__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.num_relations = num_relations

        self.node_embeddings = nn.Embedding(num_nodes, hidden_dim)

        # Every layer maps hidden_dim -> hidden_dim.
        self.layers = nn.ModuleList([
            RGCNLayer(
                in_dim=hidden_dim,
                out_dim=hidden_dim,
                num_relations=num_relations,
                num_bases=num_bases,
            )
            for _ in range(num_layers)
        ])

        self.norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim)
            for _ in range(num_layers)
        ])

    def forward(self, edge_index: torch.Tensor,
                edge_type: torch.Tensor) -> torch.Tensor:
        """Encode the nodes referenced by ``edge_index``.

        Args:
            edge_index: Integer tensor of shape ``(2, num_edges)``.
            edge_type: Integer tensor of shape ``(num_edges,)`` with the
                relation id of each edge.

        Returns:
            Tensor of shape ``(n, hidden_dim)`` where ``n`` is the largest
            node id in ``edge_index`` plus one; an empty ``(0, hidden_dim)``
            tensor when there are no edges.

        Raises:
            IndexError: If ``edge_index`` references a node id that is not
                covered by the embedding table.
        """
        if edge_index.numel() == 0:
            # max() of an empty tensor raises; no edges means no nodes to
            # encode under the "largest id + 1" convention used below.
            return self.node_embeddings.weight.new_zeros((0, self.hidden_dim))

        num_nodes = int(edge_index.max().item()) + 1
        capacity = self.node_embeddings.num_embeddings
        if num_nodes > capacity:
            raise IndexError(
                f"edge_index references node id {num_nodes - 1} but the "
                f"encoder only has {capacity} node embeddings"
            )

        x = self.node_embeddings(
            torch.arange(num_nodes, device=edge_index.device)
        )

        for layer, norm in zip(self.layers, self.norms):
            x = F.relu(norm(layer(x, edge_index, edge_type)))

        return x
110
+
111
+
112
class ComplExScorer(nn.Module):
    """ComplEx link prediction scorer for knowledge graph completion.

    Nodes and relations are embedded as complex vectors stored as separate
    real and imaginary tables. This implementation keeps independent tables
    for a node acting as head and as tail.

    Args:
        num_nodes: Number of rows in each node embedding table.
        num_relations: Number of rows in each relation embedding table.
        hidden_dim: Size of the real (and of the imaginary) part.
    """

    def __init__(self, num_nodes: int, num_relations: int, hidden_dim: int = 50):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.hidden_dim = hidden_dim

        self.head_real = nn.Embedding(num_nodes, hidden_dim)
        self.head_imag = nn.Embedding(num_nodes, hidden_dim)
        self.tail_real = nn.Embedding(num_nodes, hidden_dim)
        self.tail_imag = nn.Embedding(num_nodes, hidden_dim)

        self.rel_real = nn.Embedding(num_relations, hidden_dim)
        self.rel_imag = nn.Embedding(num_relations, hidden_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise every embedding table."""
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(self, head_idx: torch.Tensor, rel_idx: torch.Tensor,
                tail_idx: torch.Tensor) -> torch.Tensor:
        """Score triples; a higher score means a more plausible triple.

        Args:
            head_idx: Integer tensor of head node ids.
            rel_idx: Integer tensor of relation ids.
            tail_idx: Integer tensor of tail node ids.

        Returns:
            Tensor of scores with the embedding dimension summed out.
        """
        hr = self.head_real(head_idx)
        hi = self.head_imag(head_idx)
        tr = self.tail_real(tail_idx)
        ti = self.tail_imag(tail_idx)
        rr = self.rel_real(rel_idx)
        ri = self.rel_imag(rel_idx)

        # Real part of <h, r, conj(t)> expanded into its four real terms.
        score = torch.sum(
            hr * rr * tr + hr * ri * ti + hi * rr * ti - hi * ri * tr,
            dim=-1
        )
        return score

    def loss(self, head_idx: torch.Tensor, rel_idx: torch.Tensor,
             tail_idx: torch.Tensor,
             negative_head: Optional[torch.Tensor] = None,
             negative_tail: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Softplus (logistic) loss over positive and corrupted triples.

        Args:
            head_idx: Head ids of the positive triples.
            rel_idx: Relation ids of the positive triples.
            tail_idx: Tail ids of the positive triples.
            negative_head: Optional corrupted heads, same shape as
                ``head_idx``.
            negative_tail: Optional corrupted tails, same shape as
                ``tail_idx``.

        Returns:
            Scalar loss. When both kinds of negatives are supplied their
            per-triple losses are averaged, so neither is silently dropped.
            When none are supplied, tails are corrupted uniformly at random.
        """
        pos_loss = F.softplus(-self.forward(head_idx, rel_idx, tail_idx))

        neg_losses = []
        if negative_head is not None:
            neg_losses.append(
                F.softplus(self.forward(negative_head, rel_idx, tail_idx))
            )
        if negative_tail is not None:
            neg_losses.append(
                F.softplus(self.forward(head_idx, rel_idx, negative_tail))
            )
        if not neg_losses:
            neg_tail = torch.randint(0, self.num_nodes, tail_idx.size(),
                                     device=tail_idx.device)
            neg_losses.append(
                F.softplus(self.forward(head_idx, rel_idx, neg_tail))
            )

        if len(neg_losses) == 1:
            neg_loss = neg_losses[0]
        else:
            # Average so the loss scale does not depend on how many negative
            # sets the caller passed.
            neg_loss = torch.stack(neg_losses).mean(dim=0)

        return (pos_loss + neg_loss).mean()
168
+
169
+
170
class KnowledgeGraphEngine(nn.Module):
    """
    Unified knowledge graph engine combining:
    - NetworkX for graph construction and symbolic reasoning
    - RGCN for learned embeddings
    - ComplEx for link prediction
    - Neuro-symbolic bridge for AETHER integration

    Facts live in a ``networkx.DiGraph`` keyed by integer node ids, so only
    one edge is kept per ordered (head, tail) pair: adding a second relation
    between the same two nodes overwrites the first.

    Args:
        embedding_dim: Hidden size of the RGCN encoder; the ComplEx scorer
            uses ``embedding_dim // 2``.
        num_relations: Size of the symbolic attention vector and minimum
            relation capacity of the learned models.
        max_nodes: Upper bound on the number of nodes covered by the learned
            models. Nodes beyond it are still stored in the graph and are
            available to symbolic reasoning.
    """

    def __init__(self, embedding_dim: int = 128, num_relations: int = 20,
                 max_nodes: int = 10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_relations = num_relations
        self.max_nodes = max_nodes

        self.graph = nx.DiGraph()
        self.node_id_map: Dict[str, int] = {}
        self.relation_map: Dict[str, int] = {}
        self.next_node_id = 0
        self.next_rel_id = 0

        # Created lazily by _ensure_model_capacity() once the first fact
        # has been added.
        self.encoder: Optional[KnowledgeGraphEncoder] = None
        self.scorer: Optional[ComplExScorer] = None

        self.symbolic_attention = nn.Parameter(torch.ones(num_relations))
        # Each rule is a (premise, conclusion) pair of (head, rel, tail).
        self.rules: List[Tuple[Tuple[str, str, str],
                               Tuple[str, str, str]]] = []

    def _get_or_create_node(self, node_name: str) -> int:
        """Return the integer id of ``node_name``, registering it if new."""
        if node_name not in self.node_id_map:
            self.node_id_map[node_name] = self.next_node_id
            self.graph.add_node(self.next_node_id, name=node_name)
            self.next_node_id += 1
        return self.node_id_map[node_name]

    def _get_or_create_relation(self, relation: str) -> int:
        """Return the integer id of ``relation``, registering it if new."""
        if relation not in self.relation_map:
            self.relation_map[relation] = self.next_rel_id
            self.next_rel_id += 1
        return self.relation_map[relation]

    def add_fact(self, head: str, relation: str, tail: str,
                 confidence: float = 1.0):
        """Store the triple ``(head, relation, tail)`` in the graph.

        Unknown nodes and relations are registered on the fly, and the
        learned models are created or enlarged to cover them.
        """
        h_id = self._get_or_create_node(head)
        t_id = self._get_or_create_node(tail)
        r_id = self._get_or_create_relation(relation)

        self.graph.add_edge(h_id, t_id, relation=r_id, name=relation,
                            confidence=confidence)
        self._ensure_model_capacity()

    def add_rule(self, premise: Tuple[str, str, str],
                 conclusion: Tuple[str, str, str]):
        """Register an inference rule ``premise -> conclusion``.

        A conclusion head of ``"?"`` is replaced by the query head in
        ``reason_symbolic``.
        """
        self.rules.append((premise, conclusion))

    @staticmethod
    def _grow_embedding(table: nn.Embedding, new_size: int,
                        xavier: bool) -> nn.Embedding:
        """Return a ``new_size``-row copy of ``table`` that keeps its rows."""
        old_weight = table.weight
        grown = nn.Embedding(new_size, table.embedding_dim).to(
            device=old_weight.device, dtype=old_weight.dtype
        )
        with torch.no_grad():
            if xavier:
                nn.init.xavier_uniform_(grown.weight)
            grown.weight[:table.num_embeddings] = old_weight
        return grown

    def _ensure_model_capacity(self):
        """Create the learned models on first use and grow them afterwards.

        The models used to be sized once, for whatever the graph contained
        after the first fact, so every node added later fell outside the
        embedding tables. Tables are now enlarged as the graph grows (node
        tables at least double each time, capped at ``max_nodes``) and the
        already-learned rows are preserved.

        Growing replaces parameter tensors, so an optimizer created before a
        growth step has to be rebuilt to pick up the new parameters.
        """
        wanted_nodes = min(self.next_node_id, self.max_nodes)
        if wanted_nodes <= 0:
            return
        wanted_rels = max(self.next_rel_id, self.num_relations)

        if self.encoder is None or self.scorer is None:
            self.encoder = KnowledgeGraphEncoder(
                num_nodes=wanted_nodes,
                hidden_dim=self.embedding_dim,
                num_relations=wanted_rels,
                num_layers=2,
            )
            self.scorer = ComplExScorer(
                num_nodes=wanted_nodes,
                num_relations=wanted_rels,
                hidden_dim=self.embedding_dim // 2,
            )
            logger.info("Initialized KG models: %s nodes, %s relations",
                        wanted_nodes, wanted_rels)
            return

        encoder, scorer = self.encoder, self.scorer

        # --- node capacity -------------------------------------------------
        enc_nodes = encoder.node_embeddings.num_embeddings
        if wanted_nodes > enc_nodes:
            target = min(self.max_nodes, max(wanted_nodes, 2 * enc_nodes))
            encoder.node_embeddings = self._grow_embedding(
                encoder.node_embeddings, target, xavier=False
            )
            encoder.num_nodes = target

        if wanted_nodes > scorer.num_nodes:
            target = min(self.max_nodes,
                         max(wanted_nodes, 2 * scorer.num_nodes))
            for attr in ("head_real", "head_imag", "tail_real", "tail_imag"):
                setattr(scorer, attr, self._grow_embedding(
                    getattr(scorer, attr), target, xavier=True
                ))
            scorer.num_nodes = target

        # --- relation capacity ---------------------------------------------
        if wanted_rels > scorer.num_relations:
            for attr in ("rel_real", "rel_imag"):
                setattr(scorer, attr, self._grow_embedding(
                    getattr(scorer, attr), wanted_rels, xavier=True
                ))
            scorer.num_relations = wanted_rels

        if wanted_rels > encoder.num_relations:
            for layer in encoder.layers:
                old_comp = layer.comp
                new_comp = torch.empty(
                    wanted_rels, old_comp.size(1),
                    device=old_comp.device, dtype=old_comp.dtype,
                )
                nn.init.xavier_uniform_(new_comp)
                new_comp[:old_comp.size(0)] = old_comp.detach()
                layer.comp = nn.Parameter(new_comp)
                layer.num_relations = wanted_rels
            encoder.num_relations = wanted_rels

    def reason_symbolic(self, query_head: str, query_relation: str) -> List[Dict]:
        """Answer a query from stored facts, rules and short graph paths.

        Three kinds of results are collected:
        - ``"direct"``: outgoing edges of the head with the queried relation,
        - ``"inferred"``: conclusions of rules whose premise starts at the
          head and is a stored fact (fixed confidence 0.8),
        - path results: every simple path of at most two edges starting at
          the head, labelled with the first edge's relation name and scored
          ``0.6 ** hops``. Single-edge paths are included here too,
          whatever their relation.

        Returns:
            Result dicts sorted by descending confidence; empty if the head
            is unknown.
        """
        if query_head not in self.node_id_map:
            return []

        h_id = self.node_id_map[query_head]
        results: List[Dict] = []

        r_id = self.relation_map.get(query_relation)
        if r_id is not None:
            for _, target, data in self.graph.out_edges(h_id, data=True):
                if data.get('relation') == r_id:
                    results.append({
                        "head": query_head,
                        "relation": query_relation,
                        "tail": self.graph.nodes[target].get('name', str(target)),
                        "confidence": data.get('confidence', 1.0),
                        "path": "direct",
                    })

        for premise, conclusion in self.rules:
            p_head, _p_rel, _p_tail = premise
            c_head, c_rel, c_tail = conclusion

            if p_head == query_head and self._check_fact(premise):
                results.append({
                    "head": query_head if c_head == "?" else c_head,
                    "relation": c_rel,
                    "tail": c_tail,
                    "confidence": 0.8,
                    "path": "inferred",
                    "rule": f"{premise} -> {conclusion}",
                })

        for neighbor in nx.bfs_tree(self.graph, h_id, depth_limit=2).nodes():
            if neighbor == h_id:
                continue
            for path in nx.all_simple_paths(self.graph, h_id, neighbor, cutoff=2):
                if len(path) > 1:
                    edge_data = self.graph.edges[path[0], path[1]]
                    results.append({
                        "head": query_head,
                        "relation": f"multi-hop via {edge_data.get('name', 'unknown')}",
                        "tail": self.graph.nodes[neighbor].get('name', str(neighbor)),
                        "confidence": 0.6 ** (len(path) - 1),
                        "path": "->".join(str(n) for n in path),
                    })

        return sorted(results, key=lambda x: x["confidence"], reverse=True)

    def _check_fact(self, fact: Tuple[str, str, str]) -> bool:
        """Return True if ``fact`` is stored as an edge of the graph."""
        h, r, t = fact
        if h not in self.node_id_map or t not in self.node_id_map:
            return False
        if r not in self.relation_map:
            return False

        h_id = self.node_id_map[h]
        t_id = self.node_id_map[t]
        r_id = self.relation_map[r]
        return (self.graph.has_edge(h_id, t_id)
                and self.graph.edges[h_id, t_id].get('relation') == r_id)

    def reason_learned(self, query_head: str, query_relation: str,
                       top_k: int = 5) -> List[Dict]:
        """Rank candidate tails for ``(query_head, query_relation, ?)``.

        Every node that exists in the graph and is covered by the scorer is
        scored with ComplEx; the ``top_k`` best are returned with the
        sigmoid of their score as confidence.

        Returns:
            Result dicts ordered by descending score; empty when the scorer
            does not exist yet, the head or relation is unknown or outside
            the scorer's tables, or ``top_k`` is not positive.
        """
        if self.scorer is None or query_head not in self.node_id_map:
            return []

        r_id = self.relation_map.get(query_relation)
        if r_id is None:
            return []

        h_id = self.node_id_map[query_head]
        scorer = self.scorer

        # Tables may hold spare rows (growth) or fewer rows than the graph
        # has nodes (max_nodes cap): only score ids valid in both.
        num_candidates = min(self.next_node_id, scorer.num_nodes)
        if (h_id >= scorer.num_nodes or r_id >= scorer.num_relations
                or num_candidates <= 0 or top_k <= 0):
            return []

        # Build index tensors where the scorer lives so this also works
        # after the module has been moved to another device.
        device = scorer.rel_real.weight.device
        batch_size = 1000
        chunks = []
        with torch.no_grad():  # inference only, no autograd graph needed
            for start in range(0, num_candidates, batch_size):
                stop = min(start + batch_size, num_candidates)
                tails = torch.arange(start, stop, device=device)
                heads = torch.full_like(tails, h_id)
                rels = torch.full_like(tails, r_id)
                chunks.append(scorer(heads, rels, tails))

        scores = torch.cat(chunks).float().cpu()
        top_scores, top_indices = torch.topk(
            scores, min(top_k, scores.numel())
        )

        results = []
        for idx, score in zip(top_indices.tolist(), top_scores):
            results.append({
                "head": query_head,
                "relation": query_relation,
                "tail": self.graph.nodes[idx].get('name', str(idx)),
                "confidence": torch.sigmoid(score).item(),
                "path": "learned",
            })

        return results

    @staticmethod
    def _resolve_key(name: str, mapping: Dict[str, int]) -> str:
        """Map ``name`` onto a key of ``mapping``, ignoring case if needed.

        An exact match wins. Otherwise the single key that equals ``name``
        case-insensitively is returned; with no such key, or several, the
        name is returned unchanged.
        """
        if name in mapping:
            return name
        folded = name.casefold()
        matches = [key for key in mapping
                   if isinstance(key, str) and key.casefold() == folded]
        return matches[0] if len(matches) == 1 else name

    def query(self, text_query: str, top_k: int = 5) -> Dict[str, Any]:
        """Answer a ``"<head> <relation ...>"`` text query.

        The first whitespace-separated word is the head and the remaining
        words form the relation; a single-word query uses the relation
        ``"related_to"``. Head and relation are matched against the stored
        names exactly first and case-insensitively as a fallback.

        Returns:
            Dict with the ``top_k`` merged results (sorted by confidence),
            the symbolic/learned fusion weights, and the number of results
            each reasoner produced.
        """
        parts = text_query.lower().split()

        if len(parts) >= 2:
            head = parts[0].capitalize()
            relation = " ".join(parts[1:])
        else:
            head = text_query.capitalize()
            relation = "related_to"

        # The normalisation above would otherwise miss names stored with a
        # different capitalisation (e.g. "alice" or "worksAt").
        head = self._resolve_key(head, self.node_id_map)
        relation = self._resolve_key(relation, self.relation_map)

        symbolic_results = self.reason_symbolic(head, relation)
        learned_results = self.reason_learned(head, relation, top_k)

        rel_id = self.relation_map.get(relation, 0)
        symbolic_weight = torch.sigmoid(
            self.symbolic_attention[rel_id % self.num_relations]
        ).item()
        learned_weight = 1.0 - symbolic_weight

        all_results = []

        for r in symbolic_results[:top_k]:
            r["source"] = "symbolic"
            r["fusion_weight"] = symbolic_weight
            all_results.append(r)

        for r in learned_results[:top_k]:
            r["source"] = "learned"
            r["fusion_weight"] = learned_weight
            all_results.append(r)

        all_results.sort(key=lambda x: x.get("confidence", 0), reverse=True)

        return {
            "query": text_query,
            "results": all_results[:top_k],
            "symbolic_weight": symbolic_weight,
            "learned_weight": learned_weight,
            "num_symbolic": len(symbolic_results),
            "num_learned": len(learned_results),
        }

    def to_pyg_data(self) -> Dict[str, torch.Tensor]:
        """Export the graph as edge tensors.

        Returns:
            Dict with ``edge_index`` (shape ``(2, num_edges)``),
            ``edge_type`` (shape ``(num_edges,)``), ``num_nodes`` and
            ``num_relations``; an empty dict when the graph has no edges.
        """
        edges = []
        edge_types = []

        for u, v, data in self.graph.edges(data=True):
            edges.append([u, v])
            edge_types.append(data.get('relation', 0))

        if not edges:
            return {}

        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_type = torch.tensor(edge_types, dtype=torch.long)

        return {
            "edge_index": edge_index,
            "edge_type": edge_type,
            "num_nodes": self.next_node_id,
            "num_relations": self.next_rel_id,
        }

    def stats(self) -> Dict[str, Any]:
        """Return size counters for the graph, relations and rules."""
        return {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_relations": len(self.relation_map),
            "num_rules": len(self.rules),
            "node_names": len(self.node_id_map),
        }

    def export(self) -> Dict[str, Any]:
        """Return a plain-Python snapshot of nodes, edges, maps and rules.

        The maps and the rule list are copies, so mutating the snapshot does
        not alter the engine's internal state.
        """
        edges = [
            {
                "source": u,
                "target": v,
                "relation_id": data.get('relation'),
                "relation_name": data.get('name'),
                "confidence": data.get('confidence'),
            }
            for u, v, data in self.graph.edges(data=True)
        ]

        return {
            "nodes": {n: self.graph.nodes[n].get('name', str(n))
                      for n in self.graph.nodes()},
            "edges": edges,
            "node_id_map": dict(self.node_id_map),
            "relation_map": dict(self.relation_map),
            "rules": list(self.rules),
        }