import os
import re
import logging
import configparser
from dataclasses import dataclass
from typing import final

import pipmaster as pm
from dotenv import load_dotenv
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock

# Install the neo4j driver on demand before importing it
if not pm.is_installed("neo4j"):
    pm.install("neo4j")

from neo4j import (
    AsyncGraphDatabase,
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
)

# Load variables from .env; existing OS environment variables take precedence
load_dotenv(dotenv_path=".env", override=False)

config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")

# Silence verbose neo4j driver logging; only errors are surfaced
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
|
|
|
| @final |
| @dataclass |
| class Neo4JStorage(BaseGraphStorage): |
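    """Neo4j implementation of BaseGraphStorage.

    All nodes of a workspace share a single workspace label and are identified
    by their ``entity_id`` property; edges are stored with the relationship
    type ``DIRECTED`` but matched without direction.

    A minimal usage sketch (``my_config`` and ``my_embedding_func`` are
    hypothetical caller-supplied objects):

        storage = Neo4JStorage(
            namespace="chunk_entity_relation",
            global_config=my_config,
            embedding_func=my_embedding_func,
        )
        await storage.initialize()
        await storage.upsert_node(
            "Alice", {"entity_id": "Alice", "entity_type": "Person"}
        )
        print(await storage.has_node("Alice"))  # True
        await storage.finalize()
    """
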
    def __init__(self, namespace, global_config, embedding_func, workspace=None):
        # The NEO4J_WORKSPACE environment variable overrides the workspace argument
        neo4j_workspace = os.environ.get("NEO4J_WORKSPACE")
        if neo4j_workspace and neo4j_workspace.strip():
            workspace = neo4j_workspace

        # Default to "base" so every node always carries a non-empty workspace label
        if not workspace or not str(workspace).strip():
            workspace = "base"
|
|
| super().__init__( |
| namespace=namespace, |
| workspace=workspace, |
| global_config=global_config, |
| embedding_func=embedding_func, |
| ) |
| self._driver = None |
|
|
| def _get_workspace_label(self) -> str: |
| """Return workspace label (guaranteed non-empty during initialization)""" |
| return self.workspace |
|
|
| def _is_chinese_text(self, text: str) -> bool: |
| """Check if text contains Chinese characters.""" |
| chinese_pattern = re.compile(r"[\u4e00-\u9fff]+") |
| return bool(chinese_pattern.search(text)) |
|
|
| async def initialize(self): |
| async with get_data_init_lock(): |
            URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
            USERNAME = os.environ.get(
                "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
            )
            PASSWORD = os.environ.get(
                "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
            )
            # Fail fast with a clear message instead of passing None to the driver
            if not URI:
                raise ValueError(
                    "Neo4j URI is not configured: set NEO4J_URI or [neo4j] uri in config.ini"
                )
| MAX_CONNECTION_POOL_SIZE = int( |
| os.environ.get( |
| "NEO4J_MAX_CONNECTION_POOL_SIZE", |
| config.get("neo4j", "connection_pool_size", fallback=100), |
| ) |
| ) |
            CONNECTION_TIMEOUT = float(
                os.environ.get(
                    "NEO4J_CONNECTION_TIMEOUT",
                    config.get("neo4j", "connection_timeout", fallback=30.0),
                )
            )
            CONNECTION_ACQUISITION_TIMEOUT = float(
                os.environ.get(
                    "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
                    config.get(
                        "neo4j", "connection_acquisition_timeout", fallback=30.0
                    ),
                )
            )
            MAX_TRANSACTION_RETRY_TIME = float(
                os.environ.get(
                    "NEO4J_MAX_TRANSACTION_RETRY_TIME",
                    config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
                )
            )
            MAX_CONNECTION_LIFETIME = float(
                os.environ.get(
                    "NEO4J_MAX_CONNECTION_LIFETIME",
                    config.get("neo4j", "max_connection_lifetime", fallback=300.0),
                )
            )
            LIVENESS_CHECK_TIMEOUT = float(
                os.environ.get(
                    "NEO4J_LIVENESS_CHECK_TIMEOUT",
                    config.get("neo4j", "liveness_check_timeout", fallback=30.0),
                )
            )
| KEEP_ALIVE = os.environ.get( |
| "NEO4J_KEEP_ALIVE", |
| config.get("neo4j", "keep_alive", fallback="true"), |
| ).lower() in ("true", "1", "yes", "on") |
            DATABASE = os.environ.get(
                "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
            )
            # Deriving the database name from the namespace is kept only for
            # compatibility with legacy deployments.
|
|
| self._driver: AsyncDriver = AsyncGraphDatabase.driver( |
| URI, |
| auth=(USERNAME, PASSWORD), |
| max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, |
| connection_timeout=CONNECTION_TIMEOUT, |
| connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, |
| max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, |
| max_connection_lifetime=MAX_CONNECTION_LIFETIME, |
| liveness_check_timeout=LIVENESS_CHECK_TIMEOUT, |
| keep_alive=KEEP_ALIVE, |
| ) |
|
|
| |
            # Try the requested database first, then fall back to the default database
            for database in (DATABASE, None):
| self._DATABASE = database |
| connected = False |
|
|
| try: |
| async with self._driver.session(database=database) as session: |
| try: |
| result = await session.run("MATCH (n) RETURN n LIMIT 0") |
| await result.consume() |
| logger.info( |
| f"[{self.workspace}] Connected to {database} at {URI}" |
| ) |
| connected = True |
| except neo4jExceptions.ServiceUnavailable as e: |
| logger.error( |
| f"[{self.workspace}] " |
| + f"Database {database} at {URI} is not available" |
| ) |
| raise e |
| except neo4jExceptions.AuthError as e: |
| logger.error( |
| f"[{self.workspace}] Authentication failed for {database} at {URI}" |
| ) |
| raise e |
| except neo4jExceptions.ClientError as e: |
| if e.code == "Neo.ClientError.Database.DatabaseNotFound": |
                                logger.info(
                                    f"[{self.workspace}] "
                                    + f"Database {database} at {URI} not found. Attempting to create it."
                                )
| try: |
| async with self._driver.session() as session: |
| result = await session.run( |
| f"CREATE DATABASE `{database}` IF NOT EXISTS" |
| ) |
| await result.consume() |
| logger.info( |
| f"[{self.workspace}] " |
| + f"Database {database} at {URI} created" |
| ) |
| connected = True |
                        except (
                            neo4jExceptions.ClientError,
                            neo4jExceptions.DatabaseError,
                        ) as create_error:
                            if create_error.code in (
                                "Neo.ClientError.Statement.UnsupportedAdministrationCommand",
                                "Neo.DatabaseError.Statement.ExecutionFailed",
                            ):
                                if database is not None:
                                    logger.warning(
                                        f"[{self.workspace}] This Neo4j instance does not support creating databases. Try Neo4j Desktop/Enterprise or DozerDB instead. Falling back to the default database."
                                    )
                            if database is None:
                                logger.error(
                                    f"[{self.workspace}] Failed to create {database} at {URI}"
                                )
                                raise create_error
|
|
| if connected: |
| workspace_label = self._get_workspace_label() |
| |
                    # Make sure a schema index exists on entity_id for this workspace
                    try:
| async with self._driver.session(database=database) as session: |
| await session.run( |
| f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)" |
| ) |
                            logger.info(
                                f"[{self.workspace}] Ensured index on entity_id for {workspace_label} in {database}"
                            )
                    except Exception as e:
                        logger.warning(
                            f"[{self.workspace}] Failed to create entity_id index: {str(e)}"
                        )
|
|
| |
                    # Full-text index backing fuzzy label search (see search_labels)
                    await self._create_fulltext_index(
                        self._driver, self._DATABASE, workspace_label
                    )
| break |
|
|
| async def _create_fulltext_index( |
| self, driver: AsyncDriver, database: str, workspace_label: str |
| ): |
| """Create a full-text index on the entity_id property with Chinese tokenizer support.""" |
| index_name = "entity_id_fulltext_idx" |
| try: |
| async with driver.session(database=database) as session: |
| |
                # Check whether the full-text index already exists and its state
                check_index_query = "SHOW FULLTEXT INDEXES"
| result = await session.run(check_index_query) |
| indexes = await result.data() |
| await result.consume() |
|
|
| existing_index = None |
| for idx in indexes: |
| if idx["name"] == index_name: |
| existing_index = idx |
| break |
|
|
| |
| if existing_index: |
| index_state = existing_index.get("state", "UNKNOWN") |
| logger.info( |
| f"[{self.workspace}] Found existing index '{index_name}' with state: {index_state}" |
| ) |
|
|
| if index_state == "ONLINE": |
| logger.info( |
| f"[{self.workspace}] Full-text index '{index_name}' already exists and is online. Skipping recreation." |
| ) |
| return |
| else: |
| logger.warning( |
| f"[{self.workspace}] Existing index '{index_name}' is not online (state: {index_state}). Will recreate." |
| ) |
| else: |
| logger.info( |
| f"[{self.workspace}] No existing index '{index_name}' found. Creating new index." |
| ) |
|
|
| |
| needs_recreation = ( |
| existing_index is not None |
| and existing_index.get("state") != "ONLINE" |
| ) |
| needs_creation = existing_index is None |
|
|
| if needs_recreation or needs_creation: |
| |
| if needs_recreation: |
| try: |
| drop_query = f"DROP INDEX {index_name}" |
| result = await session.run(drop_query) |
| await result.consume() |
| logger.info( |
| f"[{self.workspace}] Dropped existing index '{index_name}'" |
| ) |
| except Exception as drop_error: |
| logger.warning( |
| f"[{self.workspace}] Failed to drop existing index: {str(drop_error)}" |
| ) |
|
|
| |
| logger.info( |
| f"[{self.workspace}] Creating full-text index '{index_name}' with Chinese tokenizer support." |
| ) |
|
|
| try: |
| create_index_query = f""" |
| CREATE FULLTEXT INDEX {index_name} |
| FOR (n:`{workspace_label}`) ON EACH [n.entity_id] |
| OPTIONS {{ |
| indexConfig: {{ |
| `fulltext.analyzer`: 'cjk', |
| `fulltext.eventually_consistent`: true |
| }} |
| }} |
| """ |
| result = await session.run(create_index_query) |
| await result.consume() |
| logger.info( |
| f"[{self.workspace}] Successfully created full-text index '{index_name}' with CJK analyzer." |
| ) |
| except Exception as cjk_error: |
| |
| logger.warning( |
| f"[{self.workspace}] CJK analyzer not supported: {str(cjk_error)}. " |
| "Falling back to standard analyzer." |
| ) |
| create_index_query = f""" |
| CREATE FULLTEXT INDEX {index_name} |
| FOR (n:`{workspace_label}`) ON EACH [n.entity_id] |
| """ |
| result = await session.run(create_index_query) |
| await result.consume() |
| logger.info( |
| f"[{self.workspace}] Successfully created full-text index '{index_name}' with standard analyzer." |
| ) |
|
|
| except Exception as e: |
| |
| if "Unknown command" in str(e) or "invalid syntax" in str(e).lower(): |
| logger.warning( |
| f"[{self.workspace}] Could not create or verify full-text index '{index_name}'. " |
| "This might be because you are using a Neo4j version that does not support it. " |
| "Search functionality will fall back to slower, non-indexed queries." |
| ) |
| else: |
| logger.error( |
| f"[{self.workspace}] Failed to create or verify full-text index '{index_name}': {str(e)}" |
| ) |
|
|
| async def finalize(self): |
| """Close the Neo4j driver and release all resources""" |
| async with get_graph_db_lock(): |
| if self._driver: |
| await self._driver.close() |
| self._driver = None |
|
|
| async def __aexit__(self, exc_type, exc, tb): |
| """Ensure driver is closed when context manager exits""" |
| await self.finalize() |
|
|
    async def index_done_callback(self) -> None:
        # Neo4j persists writes immediately, so there is nothing to flush here
        pass
|
|
| async def has_node(self, node_id: str) -> bool: |
| """ |
| Check if a node with the given label exists in the database |
| |
| Args: |
| node_id: Label of the node to check |
| |
| Returns: |
| bool: True if node exists, False otherwise |
| |
| Raises: |
| ValueError: If node_id is invalid |
| Exception: If there is an error executing the query |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" |
| result = await session.run(query, entity_id=node_id) |
| single_result = await result.single() |
| await result.consume() |
| return single_result["node_exists"] |
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
                )
                # Do not touch `result` here: it is unbound if session.run() failed,
                # and the session context manager releases any open result.
                raise
|
|
| async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: |
| """ |
| Check if an edge exists between two nodes |
| |
| Args: |
| source_node_id: Label of the source node |
| target_node_id: Label of the target node |
| |
| Returns: |
| bool: True if edge exists, False otherwise |
| |
| Raises: |
| ValueError: If either node_id is invalid |
| Exception: If there is an error executing the query |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| query = ( |
| f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " |
| "RETURN COUNT(r) > 0 AS edgeExists" |
| ) |
| result = await session.run( |
| query, |
| source_entity_id=source_node_id, |
| target_entity_id=target_node_id, |
| ) |
| single_result = await result.single() |
| await result.consume() |
| return single_result["edgeExists"] |
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                )
                raise
|
|
| async def get_node(self, node_id: str) -> dict[str, str] | None: |
| """Get node by its label identifier, return only node properties |
| |
| Args: |
| node_id: The node label to look up |
| |
| Returns: |
| dict: Node properties if found |
| None: If node not found |
| |
| Raises: |
| ValueError: If node_id is invalid |
| Exception: If there is an error executing the query |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| query = ( |
| f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" |
| ) |
| result = await session.run(query, entity_id=node_id) |
                try:
                    # Fetch up to two records so duplicate nodes can be detected
                    records = await result.fetch(2)
|
|
| if len(records) > 1: |
| logger.warning( |
| f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node." |
| ) |
| if records: |
| node = records[0]["n"] |
| node_dict = dict(node) |
| |
| if "labels" in node_dict: |
| node_dict["labels"] = [ |
| label |
| for label in node_dict["labels"] |
| if label != workspace_label |
| ] |
| |
| return node_dict |
| return None |
| finally: |
| await result.consume() |
| except Exception as e: |
| logger.error( |
| f"[{self.workspace}] Error getting node for {node_id}: {str(e)}" |
| ) |
| raise |
|
|
| async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: |
| """ |
| Retrieve multiple nodes in one query using UNWIND. |
| |
| Args: |
| node_ids: List of node entity IDs to fetch. |
| |
| Returns: |
            A dictionary mapping each found node_id to its node data; IDs with
            no matching node are omitted from the result.
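
        Example (hypothetical IDs):
            nodes = await storage.get_nodes_batch(["Alice", "Bob"])
            # -> {"Alice": {"entity_id": "Alice", ...}, "Bob": {...}}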
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| UNWIND $node_ids AS id |
| MATCH (n:`{workspace_label}` {{entity_id: id}}) |
| RETURN n.entity_id AS entity_id, n |
| """ |
| result = await session.run(query, node_ids=node_ids) |
| nodes = {} |
| async for record in result: |
| entity_id = record["entity_id"] |
| node = record["n"] |
| node_dict = dict(node) |
| |
| if "labels" in node_dict: |
| node_dict["labels"] = [ |
| label |
| for label in node_dict["labels"] |
| if label != workspace_label |
| ] |
| nodes[entity_id] = node_dict |
| await result.consume() |
| return nodes |
|
|
| async def node_degree(self, node_id: str) -> int: |
| """Get the degree (number of relationships) of a node with the given label. |
| If multiple nodes have the same label, returns the degree of the first node. |
| If no node is found, returns 0. |
| |
| Args: |
| node_id: The label of the node |
| |
| Returns: |
| int: The number of relationships the node has, or 0 if no node found |
| |
| Raises: |
| ValueError: If node_id is invalid |
| Exception: If there is an error executing the query |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| query = f""" |
| MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) |
| OPTIONAL MATCH (n)-[r]-() |
| RETURN COUNT(r) AS degree |
| """ |
| result = await session.run(query, entity_id=node_id) |
| try: |
| record = await result.single() |
|
|
| if not record: |
| logger.warning( |
| f"[{self.workspace}] No node found with label '{node_id}'" |
| ) |
| return 0 |
|
|
| degree = record["degree"] |
| |
| |
| |
| return degree |
| finally: |
| await result.consume() |
| except Exception as e: |
| logger.error( |
| f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}" |
| ) |
| raise |
|
|
| async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: |
| """ |
| Retrieve the degree for multiple nodes in a single query using UNWIND. |
| |
| Args: |
| node_ids: List of node labels (entity_id values) to look up. |
| |
| Returns: |
| A dictionary mapping each node_id to its degree (number of relationships). |
| If a node is not found, its degree will be set to 0. |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| UNWIND $node_ids AS id |
| MATCH (n:`{workspace_label}` {{entity_id: id}}) |
| RETURN n.entity_id AS entity_id, count {{ (n)--() }} AS degree; |
| """ |
| result = await session.run(query, node_ids=node_ids) |
| degrees = {} |
| async for record in result: |
| entity_id = record["entity_id"] |
| degrees[entity_id] = record["degree"] |
| await result.consume() |
|
|
| |
| for nid in node_ids: |
| if nid not in degrees: |
| logger.warning( |
| f"[{self.workspace}] No node found with label '{nid}'" |
| ) |
| degrees[nid] = 0 |
|
|
| |
| return degrees |
|
|
| async def edge_degree(self, src_id: str, tgt_id: str) -> int: |
| """Get the total degree (sum of relationships) of two nodes. |
| |
| Args: |
| src_id: Label of the source node |
| tgt_id: Label of the target node |
| |
| Returns: |
| int: Sum of the degrees of both nodes |
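
        Example:
            If src_id has 3 relationships and tgt_id has 2, the result is 5.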
| """ |
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # node_degree() already returns 0 for missing nodes, so the sum is safe
        degrees = int(src_degree) + int(trg_degree)
        return degrees
|
|
| async def edge_degrees_batch( |
| self, edge_pairs: list[tuple[str, str]] |
| ) -> dict[tuple[str, str], int]: |
| """ |
| Calculate the combined degree for each edge (sum of the source and target node degrees) |
| in batch using the already implemented node_degrees_batch. |
| |
| Args: |
| edge_pairs: List of (src, tgt) tuples. |
| |
| Returns: |
| A dictionary mapping each (src, tgt) tuple to the sum of their degrees. |
| """ |
| |
| unique_node_ids = {src for src, _ in edge_pairs} |
| unique_node_ids.update({tgt for _, tgt in edge_pairs}) |
|
|
| |
| degrees = await self.node_degrees_batch(list(unique_node_ids)) |
|
|
| |
| edge_degrees = {} |
| for src, tgt in edge_pairs: |
| edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) |
| return edge_degrees |
|
|
| async def get_edge( |
| self, source_node_id: str, target_node_id: str |
| ) -> dict[str, str] | None: |
| """Get edge properties between two nodes. |
| |
| Args: |
| source_node_id: Label of the source node |
| target_node_id: Label of the target node |
| |
| Returns: |
| dict: Edge properties if found, default properties if not found or on error |
| |
| Raises: |
| ValueError: If either node_id is invalid |
| Exception: If there is an error executing the query |
| """ |
| workspace_label = self._get_workspace_label() |
| try: |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) |
| RETURN properties(r) as edge_properties |
| """ |
| result = await session.run( |
| query, |
| source_entity_id=source_node_id, |
| target_entity_id=target_node_id, |
| ) |
| try: |
| records = await result.fetch(2) |
|
|
| if len(records) > 1: |
| logger.warning( |
| f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge." |
| ) |
| if records: |
| try: |
| edge_result = dict(records[0]["edge_properties"]) |
| |
| |
| required_keys = { |
| "weight": 1.0, |
| "source_id": None, |
| "description": None, |
| "keywords": None, |
| } |
| for key, default_value in required_keys.items(): |
| if key not in edge_result: |
| edge_result[key] = default_value |
| logger.warning( |
| f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} " |
| f"missing {key}, using default: {default_value}" |
| ) |
|
|
| |
| |
| |
| return edge_result |
| except (KeyError, TypeError, ValueError) as e: |
| logger.error( |
| f"[{self.workspace}] Error processing edge properties between {source_node_id} " |
| f"and {target_node_id}: {str(e)}" |
| ) |
| |
| return { |
| "weight": 1.0, |
| "source_id": None, |
| "description": None, |
| "keywords": None, |
| } |
|
|
| |
| |
| |
| |
| return None |
| finally: |
| await result.consume() |
|
|
| except Exception as e: |
| logger.error( |
| f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" |
| ) |
| raise |
|
|
| async def get_edges_batch( |
| self, pairs: list[dict[str, str]] |
| ) -> dict[tuple[str, str], dict]: |
| """ |
| Retrieve edge properties for multiple (src, tgt) pairs in one query. |
| |
| Args: |
| pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] |
| |
| Returns: |
| A dictionary mapping (src, tgt) tuples to their edge properties. |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| UNWIND $pairs AS pair |
| MATCH (start:`{workspace_label}` {{entity_id: pair.src}})-[r:DIRECTED]-(end:`{workspace_label}` {{entity_id: pair.tgt}}) |
| RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges |
| """ |
| result = await session.run(query, pairs=pairs) |
| edges_dict = {} |
| async for record in result: |
| src = record["src_id"] |
| tgt = record["tgt_id"] |
| edges = record["edges"] |
| if edges and len(edges) > 0: |
| edge_props = edges[0] |
| |
| for key, default in { |
| "weight": 1.0, |
| "source_id": None, |
| "description": None, |
| "keywords": None, |
| }.items(): |
| if key not in edge_props: |
| edge_props[key] = default |
| edges_dict[(src, tgt)] = edge_props |
| else: |
| |
| edges_dict[(src, tgt)] = { |
| "weight": 1.0, |
| "source_id": None, |
| "description": None, |
| "keywords": None, |
| } |
| await result.consume() |
| return edges_dict |
|
|
| async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: |
| """Retrieves all edges (relationships) for a particular node identified by its label. |
| |
| Args: |
| source_node_id: Label of the node to get edges for |
| |
| Returns: |
| list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges |
| None: If no edges found |
| |
| Raises: |
| ValueError: If source_node_id is invalid |
| Exception: If there is an error executing the query |
| """ |
| try: |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| workspace_label = self._get_workspace_label() |
| query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) |
| OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) |
| WHERE connected.entity_id IS NOT NULL |
| RETURN n, r, connected""" |
| results = await session.run(query, entity_id=source_node_id) |
|
|
| edges = [] |
| async for record in results: |
| source_node = record["n"] |
| connected_node = record["connected"] |
|
|
| |
| if not source_node or not connected_node: |
| continue |
|
|
| source_label = ( |
| source_node.get("entity_id") |
| if source_node.get("entity_id") |
| else None |
| ) |
| target_label = ( |
| connected_node.get("entity_id") |
| if connected_node.get("entity_id") |
| else None |
| ) |
|
|
| if source_label and target_label: |
| edges.append((source_label, target_label)) |
|
|
| await results.consume() |
| return edges |
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                )
                raise
| except Exception as e: |
| logger.error( |
| f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}" |
| ) |
| raise |
|
|
| async def get_nodes_edges_batch( |
| self, node_ids: list[str] |
| ) -> dict[str, list[tuple[str, str]]]: |
| """ |
| Batch retrieve edges for multiple nodes in one query using UNWIND. |
| For each node, returns both outgoing and incoming edges to properly represent |
| the undirected graph nature. |
| |
| Args: |
| node_ids: List of node IDs (entity_id) for which to retrieve edges. |
| |
| Returns: |
| A dictionary mapping each node ID to its list of edge tuples (source, target). |
| For each node, the list includes both: |
| - Outgoing edges: (queried_node, connected_node) |
| - Incoming edges: (connected_node, queried_node) |
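
        Example (hypothetical IDs):
            edges = await storage.get_nodes_edges_batch(["Alice"])
            # -> {"Alice": [("Alice", "Bob"), ("Carol", "Alice")]}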
| """ |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| |
| workspace_label = self._get_workspace_label() |
| query = f""" |
| UNWIND $node_ids AS id |
| MATCH (n:`{workspace_label}` {{entity_id: id}}) |
| OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) |
| RETURN id AS queried_id, n.entity_id AS node_entity_id, |
| connected.entity_id AS connected_entity_id, |
| startNode(r).entity_id AS start_entity_id |
| """ |
| result = await session.run(query, node_ids=node_ids) |
|
|
| |
| edges_dict = {node_id: [] for node_id in node_ids} |
|
|
| |
| async for record in result: |
| queried_id = record["queried_id"] |
| node_entity_id = record["node_entity_id"] |
| connected_entity_id = record["connected_entity_id"] |
| start_entity_id = record["start_entity_id"] |
|
|
| |
| if not node_entity_id or not connected_entity_id: |
| continue |
|
|
| |
| |
| |
                if start_entity_id == node_entity_id:
                    # Outgoing edge: (queried_node, connected_node)
                    edges_dict[queried_id].append((node_entity_id, connected_entity_id))
                else:
                    # Incoming edge: (connected_node, queried_node)
                    edges_dict[queried_id].append((connected_entity_id, node_entity_id))
|
|
| await result.consume() |
| return edges_dict |
|
|
| async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| UNWIND $chunk_ids AS chunk_id |
| MATCH (n:`{workspace_label}`) |
| WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) |
| RETURN DISTINCT n |
| """ |
| result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) |
| nodes = [] |
| async for record in result: |
| node = record["n"] |
| node_dict = dict(node) |
| |
| node_dict["id"] = node_dict.get("entity_id") |
| nodes.append(node_dict) |
| await result.consume() |
| return nodes |
|
|
| async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| UNWIND $chunk_ids AS chunk_id |
| MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) |
| WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) |
| RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties |
| """ |
| result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) |
| edges = [] |
| async for record in result: |
| edge_properties = record["properties"] |
| edge_properties["source"] = record["source"] |
| edge_properties["target"] = record["target"] |
| edges.append(edge_properties) |
| await result.consume() |
| return edges |
|
|
| @retry( |
| stop=stop_after_attempt(3), |
| wait=wait_exponential(multiplier=1, min=4, max=10), |
| retry=retry_if_exception_type( |
| ( |
| neo4jExceptions.ServiceUnavailable, |
| neo4jExceptions.TransientError, |
| neo4jExceptions.WriteServiceUnavailable, |
| neo4jExceptions.ClientError, |
| neo4jExceptions.SessionExpired, |
| ConnectionResetError, |
| OSError, |
| ) |
| ), |
| ) |
| async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: |
| """ |
| Upsert a node in the Neo4j database. |
| |
| Args: |
| node_id: The unique identifier for the node (used as label) |
| node_data: Dictionary of node properties |
| """ |
        workspace_label = self._get_workspace_label()
        properties = node_data
        if "entity_id" not in properties:
            raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
        entity_type = properties.get("entity_type")
        if not entity_type:
            raise ValueError("Neo4j: node properties must contain an 'entity_type' field")
|
|
| try: |
| async with self._driver.session(database=self._DATABASE) as session: |
|
|
| async def execute_upsert(tx: AsyncManagedTransaction): |
| query = f""" |
| MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) |
| SET n += $properties |
| SET n:`{entity_type}` |
| """ |
| result = await tx.run( |
| query, entity_id=node_id, properties=properties |
| ) |
| await result.consume() |
|
|
| await session.execute_write(execute_upsert) |
| except Exception as e: |
| logger.error(f"[{self.workspace}] Error during upsert: {str(e)}") |
| raise |
|
|
| @retry( |
| stop=stop_after_attempt(3), |
| wait=wait_exponential(multiplier=1, min=4, max=10), |
| retry=retry_if_exception_type( |
| ( |
| neo4jExceptions.ServiceUnavailable, |
| neo4jExceptions.TransientError, |
| neo4jExceptions.WriteServiceUnavailable, |
| neo4jExceptions.ClientError, |
| neo4jExceptions.SessionExpired, |
| ConnectionResetError, |
| OSError, |
| ) |
| ), |
| ) |
| async def upsert_edge( |
| self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] |
| ) -> None: |
| """ |
| Upsert an edge and its properties between two nodes identified by their labels. |
| Ensures both source and target nodes exist and are unique before creating the edge. |
| Uses entity_id property to uniquely identify nodes. |
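        Edges are stored with the single relationship type ``DIRECTED`` and are
        matched without direction by the read methods in this class.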
| |
| Args: |
| source_node_id (str): Label of the source node (used as identifier) |
| target_node_id (str): Label of the target node (used as identifier) |
| edge_data (dict): Dictionary of properties to set on the edge |
| |
| Raises: |
| ValueError: If either source or target node does not exist or is not unique |
| """ |
| try: |
| edge_properties = edge_data |
| async with self._driver.session(database=self._DATABASE) as session: |
|
|
| async def execute_upsert(tx: AsyncManagedTransaction): |
| workspace_label = self._get_workspace_label() |
| query = f""" |
| MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) |
| WITH source |
| MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) |
| MERGE (source)-[r:DIRECTED]-(target) |
| SET r += $properties |
| RETURN r, source, target |
| """ |
| result = await tx.run( |
| query, |
| source_entity_id=source_node_id, |
| target_entity_id=target_node_id, |
| properties=edge_properties, |
| ) |
| try: |
| await result.fetch(2) |
| finally: |
| await result.consume() |
|
|
| await session.execute_write(execute_upsert) |
| except Exception as e: |
| logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}") |
| raise |
|
|
| async def get_knowledge_graph( |
| self, |
| node_label: str, |
| max_depth: int = 3, |
        max_nodes: int | None = None,
| ) -> KnowledgeGraph: |
| """ |
| Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. |
| |
| Args: |
            node_label: Label of the starting node; "*" means all nodes
            max_depth: Maximum depth of the subgraph, defaults to 3
            max_nodes: Maximum number of nodes returned by BFS; defaults to the
                global max_graph_nodes setting
| |
| Returns: |
| KnowledgeGraph object containing nodes and edges, with an is_truncated flag |
| indicating whether the graph was truncated due to max_nodes limit |
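
        Example (hypothetical entity):
            kg = await storage.get_knowledge_graph("Alice", max_depth=2, max_nodes=100)
            print(len(kg.nodes), len(kg.edges), kg.is_truncated)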
| """ |
| |
        # Resolve max_nodes against the global limit, never exceeding it
        if max_nodes is None:
            max_nodes = self.global_config.get("max_graph_nodes", 1000)
        else:
            max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
|
|
| workspace_label = self._get_workspace_label() |
| result = KnowledgeGraph() |
| seen_nodes = set() |
| seen_edges = set() |
|
|
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| if node_label == "*": |
| |
| count_query = ( |
| f"MATCH (n:`{workspace_label}`) RETURN count(n) as total" |
| ) |
| count_result = None |
| try: |
| count_result = await session.run(count_query) |
| count_record = await count_result.single() |
|
|
| if count_record and count_record["total"] > max_nodes: |
| result.is_truncated = True |
| logger.info( |
| f"[{self.workspace}] Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" |
| ) |
| finally: |
| if count_result: |
| await count_result.consume() |
|
|
| |
                    # Keep the highest-degree nodes, then collect all edges among them
                    main_query = f"""
| MATCH (n:`{workspace_label}`) |
| OPTIONAL MATCH (n)-[r]-() |
| WITH n, COALESCE(count(r), 0) AS degree |
| ORDER BY degree DESC |
| LIMIT $max_nodes |
| WITH collect({{node: n}}) AS filtered_nodes |
| UNWIND filtered_nodes AS node_info |
| WITH collect(node_info.node) AS kept_nodes, filtered_nodes |
| OPTIONAL MATCH (a)-[r]-(b) |
| WHERE a IN kept_nodes AND b IN kept_nodes |
| RETURN filtered_nodes AS node_info, |
| collect(DISTINCT r) AS relationships |
| """ |
| result_set = None |
| try: |
| result_set = await session.run( |
| main_query, |
| {"max_nodes": max_nodes}, |
| ) |
| record = await result_set.single() |
| finally: |
| if result_set: |
| await result_set.consume() |
|
|
                else:
                    # Specific start node: run an unrestricted BFS first to learn
                    # the subgraph size (requires the APOC plugin)
                    full_query = f"""
| MATCH (start:`{workspace_label}`) |
| WHERE start.entity_id = $entity_id |
| WITH start |
| CALL apoc.path.subgraphAll(start, {{ |
| relationshipFilter: '', |
| labelFilter: '{workspace_label}', |
| minLevel: 0, |
| maxLevel: $max_depth, |
| bfs: true |
| }}) |
| YIELD nodes, relationships |
| WITH nodes, relationships, size(nodes) AS total_nodes |
| UNWIND nodes AS node |
| WITH collect({{node: node}}) AS node_info, relationships, total_nodes |
| RETURN node_info, relationships, total_nodes |
| """ |
|
|
| |
| full_result = None |
| try: |
| full_result = await session.run( |
| full_query, |
| { |
| "entity_id": node_label, |
| "max_depth": max_depth, |
| }, |
| ) |
| full_record = await full_result.single() |
|
|
| |
| if not full_record: |
| logger.debug( |
| f"[{self.workspace}] No nodes found for entity_id: {node_label}" |
| ) |
| return result |
|
|
| |
| total_nodes = full_record["total_nodes"] |
|
|
| if total_nodes <= max_nodes: |
| |
| logger.debug( |
| f"[{self.workspace}] Using full result with {total_nodes} nodes (no truncation needed)" |
| ) |
| record = full_record |
| else: |
| |
| result.is_truncated = True |
| logger.info( |
| f"[{self.workspace}] Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" |
| ) |
|
|
| |
| limited_query = f""" |
| MATCH (start:`{workspace_label}`) |
| WHERE start.entity_id = $entity_id |
| WITH start |
| CALL apoc.path.subgraphAll(start, {{ |
| relationshipFilter: '', |
| labelFilter: '{workspace_label}', |
| minLevel: 0, |
| maxLevel: $max_depth, |
| limit: $max_nodes, |
| bfs: true |
| }}) |
| YIELD nodes, relationships |
| UNWIND nodes AS node |
| WITH collect({{node: node}}) AS node_info, relationships |
| RETURN node_info, relationships |
| """ |
| result_set = None |
| try: |
| result_set = await session.run( |
| limited_query, |
| { |
| "entity_id": node_label, |
| "max_depth": max_depth, |
| "max_nodes": max_nodes, |
| }, |
| ) |
| record = await result_set.single() |
| finally: |
| if result_set: |
| await result_set.consume() |
| finally: |
| if full_result: |
| await full_result.consume() |
|
|
| if record: |
| |
| for node_info in record["node_info"]: |
| node = node_info["node"] |
| node_id = node.id |
| if node_id not in seen_nodes: |
| result.nodes.append( |
| KnowledgeGraphNode( |
| id=f"{node_id}", |
| labels=[node.get("entity_id")], |
| properties=dict(node), |
| ) |
| ) |
| seen_nodes.add(node_id) |
|
|
| |
| for rel in record["relationships"]: |
| edge_id = rel.id |
| if edge_id not in seen_edges: |
| start = rel.start_node |
| end = rel.end_node |
| result.edges.append( |
| KnowledgeGraphEdge( |
| id=f"{edge_id}", |
| type=rel.type, |
| source=f"{start.id}", |
| target=f"{end.id}", |
| properties=dict(rel), |
| ) |
| ) |
| seen_edges.add(edge_id) |
|
|
| logger.info( |
| f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" |
| ) |
|
|
| except neo4jExceptions.ClientError as e: |
| logger.warning(f"[{self.workspace}] APOC plugin error: {str(e)}") |
| if node_label != "*": |
| logger.warning( |
| f"[{self.workspace}] Neo4j: falling back to basic Cypher recursive search..." |
| ) |
| return await self._robust_fallback(node_label, max_depth, max_nodes) |
| else: |
| logger.warning( |
| f"[{self.workspace}] Neo4j: APOC plugin error with wildcard query, returning empty result" |
| ) |
|
|
| return result |
|
|
| async def _robust_fallback( |
| self, node_label: str, max_depth: int, max_nodes: int |
| ) -> KnowledgeGraph: |
| """ |
| Fallback implementation when APOC plugin is not available or incompatible. |
| This method implements the same functionality as get_knowledge_graph but uses |
| only basic Cypher queries and true breadth-first traversal instead of APOC procedures. |
| """ |
| from collections import deque |
|
|
| result = KnowledgeGraph() |
| visited_nodes = set() |
| visited_edges = set() |
| visited_edge_pairs = set() |
|
|
| |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) |
| RETURN id(n) as node_id, n |
| """ |
| node_result = await session.run(query, entity_id=node_label) |
| try: |
| node_record = await node_result.single() |
| if not node_record: |
| return result |
|
|
| |
| start_node = KnowledgeGraphNode( |
| id=f"{node_record['n'].get('entity_id')}", |
| labels=[node_record["n"].get("entity_id")], |
| properties=dict(node_record["n"]._properties), |
| ) |
| finally: |
| await node_result.consume() |
|
|
| |
| |
        # BFS queue of (node, edge_to_node, depth); the start node has no edge
        queue = deque([(start_node, None, 0)])
|
|
| |
| while queue and len(visited_nodes) < max_nodes: |
| |
| current_node, current_edge, current_depth = queue.popleft() |
|
|
| |
| if current_node.id in visited_nodes: |
| continue |
|
|
| if current_depth > max_depth: |
| logger.debug( |
| f"[{self.workspace}] Skipping node at depth {current_depth} (max_depth: {max_depth})" |
| ) |
| continue |
|
|
| |
| result.nodes.append(current_node) |
| visited_nodes.add(current_node.id) |
|
|
| |
| if current_edge and current_edge.id not in visited_edges: |
| result.edges.append(current_edge) |
| visited_edges.add(current_edge.id) |
|
|
| |
| if len(visited_nodes) >= max_nodes: |
| result.is_truncated = True |
| logger.info( |
| f"[{self.workspace}] Graph truncated: breadth-first search limited to: {max_nodes} nodes" |
| ) |
| break |
|
|
| |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| workspace_label = self._get_workspace_label() |
| query = f""" |
| MATCH (a:`{workspace_label}` {{entity_id: $entity_id}})-[r]-(b) |
| WITH r, b, id(r) as edge_id, id(b) as target_id |
| RETURN r, b, edge_id, target_id |
| """ |
| results = await session.run(query, entity_id=current_node.id) |
|
|
| |
                # Fetch up to 1000 neighboring records for the current node
                records = await results.fetch(1000)
                await results.consume()
|
|
| |
| for record in records: |
| rel = record["r"] |
| edge_id = str(record["edge_id"]) |
|
|
| if edge_id not in visited_edges: |
| b_node = record["b"] |
| target_id = b_node.get("entity_id") |
|
|
| if target_id: |
| |
| target_node = KnowledgeGraphNode( |
| id=f"{target_id}", |
| labels=[target_id], |
| properties=dict(b_node._properties), |
| ) |
|
|
| |
| target_edge = KnowledgeGraphEdge( |
| id=f"{edge_id}", |
| type=rel.type, |
| source=f"{current_node.id}", |
| target=f"{target_id}", |
| properties=dict(rel), |
| ) |
|
|
| |
                            # Deduplicate undirected edges by their sorted endpoint pair
                            sorted_pair = tuple(sorted([current_node.id, target_id]))

                            if sorted_pair not in visited_edge_pairs:
                                # Keep the edge when the target is already included
                                # or can still be reached within max_depth
                                if (
                                    target_id in visited_nodes
                                    or current_depth < max_depth
                                ):
                                    result.edges.append(target_edge)
                                    visited_edges.add(edge_id)
                                    visited_edge_pairs.add(sorted_pair)
|
|
| |
                            # Queue unvisited targets that are still within depth
                            if target_id not in visited_nodes:
                                if current_depth < max_depth:
                                    queue.append((target_node, None, current_depth + 1))
                                else:
                                    logger.debug(
                                        f"[{self.workspace}] Node {target_id} beyond max depth {max_depth}, edge added but node not included"
                                    )
                            else:
                                logger.debug(
                                    f"[{self.workspace}] Node {target_id} already visited, edge added but node not queued"
                                )
| else: |
| logger.warning( |
| f"[{self.workspace}] Skipping edge {edge_id} due to missing entity_id on target node" |
| ) |
|
|
| logger.info( |
| f"[{self.workspace}] BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" |
| ) |
| return result |
|
|
| async def get_all_labels(self) -> list[str]: |
| """ |
| Get all existing node labels in the database |
| Returns: |
| ["Person", "Company", ...] # Alphabetically sorted label list |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| |
| |
|
|
| |
| query = f""" |
| MATCH (n:`{workspace_label}`) |
| WHERE n.entity_id IS NOT NULL |
| RETURN DISTINCT n.entity_id AS label |
| ORDER BY label |
| """ |
| result = await session.run(query) |
| labels = [] |
| try: |
| async for record in result: |
| labels.append(record["label"]) |
| finally: |
| await ( |
| result.consume() |
| ) |
| return labels |
|
|
| @retry( |
| stop=stop_after_attempt(3), |
| wait=wait_exponential(multiplier=1, min=4, max=10), |
| retry=retry_if_exception_type( |
| ( |
| neo4jExceptions.ServiceUnavailable, |
| neo4jExceptions.TransientError, |
| neo4jExceptions.WriteServiceUnavailable, |
| neo4jExceptions.ClientError, |
| neo4jExceptions.SessionExpired, |
| ConnectionResetError, |
| OSError, |
| ) |
| ), |
| ) |
| async def delete_node(self, node_id: str) -> None: |
| """Delete a node with the specified label |
| |
| Args: |
| node_id: The label of the node to delete |
| """ |
|
|
| async def _do_delete(tx: AsyncManagedTransaction): |
| workspace_label = self._get_workspace_label() |
| query = f""" |
| MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) |
| DETACH DELETE n |
| """ |
| result = await tx.run(query, entity_id=node_id) |
| logger.debug(f"[{self.workspace}] Deleted node with label '{node_id}'") |
| await result.consume() |
|
|
| try: |
| async with self._driver.session(database=self._DATABASE) as session: |
| await session.execute_write(_do_delete) |
| except Exception as e: |
| logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}") |
| raise |
|
|
| @retry( |
| stop=stop_after_attempt(3), |
| wait=wait_exponential(multiplier=1, min=4, max=10), |
| retry=retry_if_exception_type( |
| ( |
| neo4jExceptions.ServiceUnavailable, |
| neo4jExceptions.TransientError, |
| neo4jExceptions.WriteServiceUnavailable, |
| neo4jExceptions.ClientError, |
| neo4jExceptions.SessionExpired, |
| ConnectionResetError, |
| OSError, |
| ) |
| ), |
| ) |
| async def remove_nodes(self, nodes: list[str]): |
| """Delete multiple nodes |
| |
| Args: |
| nodes: List of node labels to be deleted |
| """ |
| for node in nodes: |
| await self.delete_node(node) |
|
|
| @retry( |
| stop=stop_after_attempt(3), |
| wait=wait_exponential(multiplier=1, min=4, max=10), |
| retry=retry_if_exception_type( |
| ( |
| neo4jExceptions.ServiceUnavailable, |
| neo4jExceptions.TransientError, |
| neo4jExceptions.WriteServiceUnavailable, |
| neo4jExceptions.ClientError, |
| neo4jExceptions.SessionExpired, |
| ConnectionResetError, |
| OSError, |
| ) |
| ), |
| ) |
| async def remove_edges(self, edges: list[tuple[str, str]]): |
| """Delete multiple edges |
| |
| Args: |
| edges: List of edges to be deleted, each edge is a (source, target) tuple |
| """ |
| for source, target in edges: |
|
|
| async def _do_delete_edge(tx: AsyncManagedTransaction): |
| workspace_label = self._get_workspace_label() |
| query = f""" |
| MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) |
| DELETE r |
| """ |
| result = await tx.run( |
| query, source_entity_id=source, target_entity_id=target |
| ) |
| logger.debug( |
| f"[{self.workspace}] Deleted edge from '{source}' to '{target}'" |
| ) |
| await result.consume() |
|
|
| try: |
| async with self._driver.session(database=self._DATABASE) as session: |
| await session.execute_write(_do_delete_edge) |
| except Exception as e: |
| logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") |
| raise |
|
|
| async def get_all_nodes(self) -> list[dict]: |
| """Get all nodes in the graph. |
| |
| Returns: |
| A list of all nodes, where each node is a dictionary of its properties |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| MATCH (n:`{workspace_label}`) |
| RETURN n |
| """ |
| result = await session.run(query) |
| nodes = [] |
| async for record in result: |
| node = record["n"] |
| node_dict = dict(node) |
| |
| node_dict["id"] = node_dict.get("entity_id") |
| nodes.append(node_dict) |
| await result.consume() |
| return nodes |
|
|
| async def get_all_edges(self) -> list[dict]: |
| """Get all edges in the graph. |
| |
| Returns: |
| A list of all edges, where each edge is a dictionary of its properties |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| query = f""" |
| MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) |
| RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties |
| """ |
| result = await session.run(query) |
| edges = [] |
| async for record in result: |
| edge_properties = record["properties"] |
| edge_properties["source"] = record["source"] |
| edge_properties["target"] = record["target"] |
| edges.append(edge_properties) |
| await result.consume() |
| return edges |
|
|
| async def get_popular_labels(self, limit: int = 300) -> list[str]: |
| """Get popular labels by node degree (most connected entities) |
| |
| Args: |
| limit: Maximum number of labels to return |
| |
| Returns: |
| List of labels sorted by degree (highest first) |
| """ |
| workspace_label = self._get_workspace_label() |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| try: |
| query = f""" |
| MATCH (n:`{workspace_label}`) |
| WHERE n.entity_id IS NOT NULL |
| OPTIONAL MATCH (n)-[r]-() |
| WITH n.entity_id AS label, count(r) AS degree |
| ORDER BY degree DESC, label ASC |
| LIMIT $limit |
| RETURN label |
| """ |
| result = await session.run(query, limit=limit) |
| labels = [] |
| async for record in result: |
| labels.append(record["label"]) |
| await result.consume() |
|
|
| logger.debug( |
| f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})" |
| ) |
| return labels |
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error getting popular labels: {str(e)}"
                )
                raise
|
|
| async def search_labels(self, query: str, limit: int = 50) -> list[str]: |
| """ |
| Search labels with fuzzy matching, using a full-text index for performance if available. |
| Enhanced with Chinese text support using CJK analyzer. |
| Falls back to a slower CONTAINS search if the index is not available or fails. |
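
        Example (hypothetical data):
            labels = await storage.search_labels("ali", limit=10)
            # Might return ["Alice", "Alibaba"], ranked by match quality.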
| """ |
| workspace_label = self._get_workspace_label() |
| query_strip = query.strip() |
| if not query_strip: |
| return [] |
|
|
| query_lower = query_strip.lower() |
| is_chinese = self._is_chinese_text(query_strip) |
| index_name = "entity_id_fulltext_idx" |
|
|
| |
| try: |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| if is_chinese: |
| |
| cypher_query = f""" |
| CALL db.index.fulltext.queryNodes($index_name, $search_query) YIELD node, score |
| WITH node, score |
| WHERE node:`{workspace_label}` |
| WITH node.entity_id AS label, score |
| WITH label, score, |
| CASE |
| WHEN label = $query_strip THEN score + 1000 |
| WHEN label CONTAINS $query_strip THEN score + 500 |
| ELSE score |
| END AS final_score |
| RETURN label |
| ORDER BY final_score DESC, label ASC |
| LIMIT $limit |
| """ |
| |
                    # Chinese text: query the CJK-analyzed index with the raw string
                    search_query = query_strip
| else: |
| |
| cypher_query = f""" |
| CALL db.index.fulltext.queryNodes($index_name, $search_query) YIELD node, score |
| WITH node, score |
| WHERE node:`{workspace_label}` |
| WITH node.entity_id AS label, toLower(node.entity_id) AS label_lower, score |
| WITH label, label_lower, score, |
| CASE |
| WHEN label_lower = $query_lower THEN score + 1000 |
| WHEN label_lower STARTS WITH $query_lower THEN score + 500 |
| WHEN label_lower CONTAINS ' ' + $query_lower OR label_lower CONTAINS '_' + $query_lower THEN score + 50 |
| ELSE score |
| END AS final_score |
| RETURN label |
| ORDER BY final_score DESC, label ASC |
| LIMIT $limit |
| """ |
| search_query = f"{query_strip}*" |
|
|
| result = await session.run( |
| cypher_query, |
| index_name=index_name, |
| search_query=search_query, |
| query_lower=query_lower, |
| query_strip=query_strip, |
| limit=limit, |
| ) |
| labels = [record["label"] async for record in result] |
| await result.consume() |
|
|
| logger.debug( |
| f"[{self.workspace}] Full-text search ({'Chinese' if is_chinese else 'Latin'}) for '{query}' returned {len(labels)} results (limit: {limit})" |
| ) |
| return labels |
|
|
| except Exception as e: |
| |
| logger.warning( |
| f"[{self.workspace}] Full-text search failed with error: {str(e)}. " |
| "Falling back to slower, non-indexed search." |
| ) |
|
|
| |
| async with self._driver.session( |
| database=self._DATABASE, default_access_mode="READ" |
| ) as session: |
| if is_chinese: |
| |
| cypher_query = f""" |
| MATCH (n:`{workspace_label}`) |
| WHERE n.entity_id IS NOT NULL |
| WITH n.entity_id AS label |
| WHERE label CONTAINS $query_strip |
| WITH label, |
| CASE |
| WHEN label = $query_strip THEN 1000 |
| WHEN label STARTS WITH $query_strip THEN 500 |
| ELSE 100 - size(label) |
| END AS score |
| ORDER BY score DESC, label ASC |
| LIMIT $limit |
| RETURN label |
| """ |
| result = await session.run( |
| cypher_query, query_strip=query_strip, limit=limit |
| ) |
| else: |
| |
| cypher_query = f""" |
| MATCH (n:`{workspace_label}`) |
| WHERE n.entity_id IS NOT NULL |
| WITH n.entity_id AS label, toLower(n.entity_id) AS label_lower |
| WHERE label_lower CONTAINS $query_lower |
| WITH label, label_lower, |
| CASE |
| WHEN label_lower = $query_lower THEN 1000 |
| WHEN label_lower STARTS WITH $query_lower THEN 500 |
| ELSE 100 - size(label) |
| END AS score |
| ORDER BY score DESC, label ASC |
| LIMIT $limit |
| RETURN label |
| """ |
| result = await session.run( |
| cypher_query, query_lower=query_lower, limit=limit |
| ) |
|
|
| labels = [record["label"] async for record in result] |
| await result.consume() |
| logger.debug( |
| f"[{self.workspace}] Fallback search ({'Chinese' if is_chinese else 'Latin'}) for '{query}' returned {len(labels)} results (limit: {limit})" |
| ) |
| return labels |
|
|
| async def drop(self) -> dict[str, str]: |
| """Drop all data from current workspace storage and clean up resources |
| |
| This method will delete all nodes and relationships in the current workspace only. |
| |
| Returns: |
| dict[str, str]: Operation status and message |
| - On success: {"status": "success", "message": "workspace data dropped"} |
| - On failure: {"status": "error", "message": "<error details>"} |
| """ |
| async with get_graph_db_lock(): |
| workspace_label = self._get_workspace_label() |
| try: |
| async with self._driver.session(database=self._DATABASE) as session: |
| |
| query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" |
| result = await session.run(query) |
| await result.consume() |
|
|
| |
| |
| |
| return { |
| "status": "success", |
| "message": f"workspace '{workspace_label}' data dropped", |
| } |
| except Exception as e: |
| logger.error( |
| f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" |
| ) |
| return {"status": "error", "message": str(e)} |
|
|