import os from collections import defaultdict from neo4j import GraphDatabase from dotenv import load_dotenv # Carico le env vars. Su HF Spaces pesca in automatico dai secrets. load_dotenv() class KnowledgeGraphPersister: def __init__(self): # Setup della connessione a Neo4j. uri = os.getenv("NEO4J_URI") user = os.getenv("NEO4J_USER") password = os.getenv("NEO4J_PASSWORD") try: self.driver = GraphDatabase.driver(uri, auth=(user, password)) self.driver.verify_connectivity() print(f"✅ Connesso a Neo4j ({uri}).") # Chiamo subito la creazione degli indici. Se partiamo a fare ingestion massiva # senza constraint, il DB collassa al primo blocco di MERGE. self._create_constraints() except Exception as e: print(f"❌ Errore critico connessione Neo4j: {e}") self.driver = None def close(self): # Chiudo pulito il driver (chiamato nel lifecycle shutdown dell'API) if self.driver: self.driver.close() def _create_constraints(self): if not self.driver: return # Senza questo vincolo UNIQUE, l'istruzione MERGE fa un Full Table Scan ogni volta. # Fondamentale per mantenere le transazioni < 10ms anche con migliaia di nodi. query = "CREATE CONSTRAINT resource_uri_unique IF NOT EXISTS FOR (n:Resource) REQUIRE n.uri IS UNIQUE" # Indice vettoriale nativo per le ricerche di similarità (dimensionato a 384 per matchare all-MiniLM) query_vector = """ CREATE VECTOR INDEX entity_embeddings IF NOT EXISTS FOR (n:Resource) ON (n.embedding) OPTIONS {indexConfig: { `vector.dimensions`: 384, `vector.similarity_function`: 'cosine' }} """ with self.driver.session() as session: try: session.run(query) print("⚡ Vincolo di unicità verificato.") except Exception as e: print(f"⚠️ Warning vincolo unicità: {e}") try: session.run(query_vector) print("⚡ Vector Index verificato.") except Exception as e: print(f"⚠️ Warning vector index: {e}") def sanitize_name(self, name): # Canonicalizzazione molto base: sostituisco spazi inutili e tolgo gli apici che spaccano le query Cypher. if not name: return "Unknown" return name.strip().replace(" ", "_").replace("'", "").replace('"', "") def sanitize_predicate(self, pred): # Cruciale per evitare Cypher Injection. In Cypher NON si può parametrizzare # il tipo di relazione in un MERGE (es. non puoi fare -[r:$pred]-). Devo per forza # iniettarlo nella stringa della query, quindi lo normalizzo in modo drastico. if not pred: return "RELATED_TO" pred = pred.replace(":", "_").replace("-", "_").replace(" ", "_") clean = "".join(x for x in pred if x.isalnum() or x == "_") # Convenzione Neo4j: le relationships sono sempre in UPPERCASE return clean.upper() if clean else "RELATED_TO" def save_triples(self, triples): if not self.driver or not triples: return print(f"💾 Preparazione Batch di {len(triples)} triple...") batched_by_pred = defaultdict(list) for t in triples: safe_pred = self.sanitize_predicate(t.predicate) item = { "subj_uri": self.sanitize_name(t.subject), "subj_label": t.subject, "subj_type": getattr(t, 'subject_type', '').replace(":", "_").replace("-", "_"), "obj_uri": self.sanitize_name(t.object), "obj_label": t.object, "obj_type": getattr(t, 'object_type', '').replace(":", "_").replace("-", "_"), "evidence": getattr(t, 'evidence', 'N/A'), "reasoning": getattr(t, 'reasoning', 'N/A'), "src": getattr(t, 'source', 'unknown') or 'unknown' } batched_by_pred[safe_pred].append(item) with self.driver.session() as session: for pred, data_list in batched_by_pred.items(): try: session.execute_write(self._unwind_write_tx, pred, data_list) print(f" -> Inserite {len(data_list)} relazioni :{pred}") except Exception as e: print(f"⚠️ Errore batch per relazione :{pred} -> {e}") print("✅ Salvataggio completato.") def save_entities_and_triples(self, entities_to_save, triples): if not self.driver: return # Ingestion a 2 step: prima butto dentro i nodi isolati con tutti i loro payload # (embedding vettoriali e link a Wikidata), poi in un secondo momento ci aggancio sopra le relazioni. if entities_to_save: print(f"💾 Salvataggio di {len(entities_to_save)} nodi singoli con vettori...") node_batch = [] for item in entities_to_save: item["uri"] = self.sanitize_name(item["label"]) node_batch.append(item) with self.driver.session() as session: session.execute_write(self._unwind_write_nodes, node_batch) if triples: self.save_triples(triples) @staticmethod def _unwind_write_nodes(tx, batch_data): # L'UNWIND è l'unico modo per fare VERO batching massivo in Neo4j senza distruggere la RAM. # Passo un intero array JSON ($batch) e Cypher lo "srotola" inserendo migliaia di nodi al volo. query = ( "UNWIND $batch AS row " "MERGE (n:Resource {uri: row.uri}) " "ON CREATE SET n.label = row.label, " " n.embedding = row.embedding, " " n.wikidata_sameAs = row.wikidata_sameAs, " " n.last_updated = datetime() " ) tx.run(query, batch=batch_data) @staticmethod def _unwind_write_tx(tx, predicate, batch_data): # Qui avviene la vera traduzione dal mondo RDF a quello Labeled Property Graph (LPG). if predicate in ["RDF_TYPE", "TYPE", "A", "CORE_HASTYPE"]: # Se l'LLM ha generato una tripla di classificazione ontologica, NON creo un nodo astratto inutile. # Uso APOC per convertire l'oggetto della tripla in una vera Label sul nodo di partenza. query = ( "UNWIND $batch AS row " "MERGE (s:Resource {uri: row.subj_uri}) " "ON CREATE SET s.label = row.subj_label, s.last_updated = datetime() " "WITH s, row " "SET s:$( [replace(row.obj_label, ':', '_')] ) " "RETURN count(node)" ) tx.run(query, batch=batch_data) else: # Per tutte le altre relazioni semantiche classiche (es. si_trova_in, ha_autore) # eseguo un merge standard tra le due entità. query = ( f"UNWIND $batch AS row " f"MERGE (s:Resource {{uri: row.subj_uri}}) " f"ON CREATE SET s.label = row.subj_label " f"MERGE (o:Resource {{uri: row.obj_uri}}) " f"ON CREATE SET o.label = row.obj_label " f"WITH s, o, row, " f" CASE WHEN row.subj_type <> '' THEN [row.subj_type] ELSE [] END AS s_labels, " f" CASE WHEN row.obj_type <> '' THEN [row.obj_type] ELSE [] END AS o_labels " f"SET s:$(s_labels), o:$(o_labels) " f"MERGE (s)-[r:`{predicate}`]->(o) " f"SET r.evidence = row.evidence, " f" r.reasoning = row.reasoning, " f" r.source = row.src, " f" r.last_updated = datetime()" ) tx.run(query, batch=batch_data)