| import os |
| import json |
| import xml.etree.ElementTree as ET |
| from neo4j import GraphDatabase |
|
|
| |
# Directory that holds the GraphML export and receives the generated JSON.
WORKING_DIR = "./dickens"

# Number of records sent to Neo4j per UNWIND statement.
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100

# Neo4j connection settings. Each value can be overridden through an
# environment variable of the same name so that real credentials do not have
# to be written into this file; the literals below are only fallbacks.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "your_password")
|
|
|
|
def _data_text(element, key, namespace, default=""):
    """Return the text of ``element``'s ``<data key=...>`` child, or ``default``.

    ``default`` is returned both when the child is missing and when it is
    present but empty (ElementTree reports ``.text`` as ``None`` in that case).
    """
    child = element.find(f"./data[@key='{key}']", namespace)
    if child is None or child.text is None:
        return default
    return child.text


def xml_to_json(xml_file):
    """Parse a GraphML file into a ``{"nodes": [...], "edges": [...]}`` dict.

    Node records carry ``id``, ``entity_type``, ``description`` and
    ``source_id``; edge records carry ``source``, ``target``, ``weight``,
    ``description``, ``keywords`` and ``source_id``. Missing or empty
    ``<data>`` elements yield ``""`` (``0.0`` for ``weight``).

    NOTE(review): the ``<data>`` key ids (d1-d3 for nodes, d5-d8 for edges)
    are fixed here; confirm they match the ``<key>`` declarations of the
    GraphML file being imported.

    Args:
        xml_file: Path (or file object) accepted by ``ElementTree.parse``.

    Returns:
        The parsed dict, or ``None`` if the file could not be parsed or any
        other error occurred (the error is printed).
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Echo the root element so the user can confirm the right file loaded.
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        # GraphML elements live in this default namespace.
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        data = {"nodes": [], "edges": []}

        for node in root.findall(".//node", namespace):
            data["nodes"].append(
                {
                    # Ids and entity types may be stored wrapped in literal
                    # double quotes; strip them.
                    "id": node.get("id").strip('"'),
                    "entity_type": _data_text(node, "d1", namespace).strip('"'),
                    "description": _data_text(node, "d2", namespace),
                    "source_id": _data_text(node, "d3", namespace),
                }
            )

        for edge in root.findall(".//edge", namespace):
            # A missing or blank weight falls back to 0.0 instead of raising.
            weight_text = _data_text(edge, "d5", namespace).strip()
            data["edges"].append(
                {
                    "source": edge.get("source").strip('"'),
                    "target": edge.get("target").strip('"'),
                    "weight": float(weight_text) if weight_text else 0.0,
                    "description": _data_text(edge, "d6", namespace),
                    "keywords": _data_text(edge, "d7", namespace),
                    "source_id": _data_text(edge, "d8", namespace),
                }
            )

        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")

        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        # Deliberately broad: callers rely on a None return for any failure.
        print(f"An error occurred: {e}")
        return None
|
|
|
|
def convert_xml_to_json(xml_path, output_path):
    """Convert the GraphML file at ``xml_path`` to JSON written to ``output_path``.

    Returns the parsed graph dict on success. Returns ``None`` (after printing
    a message) when ``xml_path`` does not exist or when parsing produced no
    data; in both cases nothing is written.
    """
    # Guard: nothing to do without an input file.
    if not os.path.exists(xml_path):
        print(f"Error: File not found - {xml_path}")
        return None

    graph = xml_to_json(xml_path)

    # Guard: the parser signals failure with a falsy result.
    if not graph:
        print("Failed to create JSON data")
        return None

    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(graph, out_file, ensure_ascii=False, indent=2)
    print(f"JSON file created: {output_path}")
    return graph
|
|
|
|
def process_in_batches(tx, query, data, batch_size):
    """Run ``query`` once per slice of ``data`` of at most ``batch_size`` items.

    Each slice is passed as the query parameter ``nodes`` when the query text
    contains the ``$nodes`` placeholder, and as ``edges`` otherwise.

    Args:
        tx: Object exposing ``run(query, parameters)``.
        query: Query text referencing ``$nodes`` or ``$edges``.
        data: Sequence of records to send.
        batch_size: Maximum number of records per ``run`` call; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Look for the parameter placeholder itself rather than the bare word
    # "nodes", which also appears in e.g. Cypher's nodes(path) function.
    param_name = "nodes" if "$nodes" in query else "edges"

    for start in range(0, len(data), batch_size):
        tx.run(query, {param_name: data[start : start + batch_size]})
|
|
|
|
def main():
    """Convert the GraphML export to JSON and load it into Neo4j.

    Steps:
      1. Convert ``graph_chunk_entity_relation.graphml`` in ``WORKING_DIR`` to
         ``graph_data.json``; stop if the conversion fails.
      2. Create the nodes in batches of ``BATCH_SIZE_NODES``.
      3. Create the relationships in batches of ``BATCH_SIZE_EDGES``.
      4. Set ``displayName`` and replace each node's labels with its
         ``entity_type``.

    The Cypher statements call APOC procedures (``apoc.create.addLabels``,
    ``apoc.create.relationship``, ``apoc.create.setLabels``). Any exception
    raised while talking to the database is printed, and the driver is always
    closed.
    """
    # Input and output paths.
    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_file = os.path.join(WORKING_DIR, "graph_data.json")

    # Convert XML to JSON; nothing to import if that failed.
    json_data = convert_xml_to_json(xml_file, json_file)
    if json_data is None:
        return

    nodes = json_data.get("nodes", [])
    edges = json_data.get("edges", [])

    # Merge a node per id, copy its properties, then swap the temporary
    # Entity label for a label equal to the node id.
    create_nodes_query = """
    UNWIND $nodes AS node
    MERGE (e:Entity {id: node.id})
    SET e.entity_type = node.entity_type,
        e.description = node.description,
        e.source_id = node.source_id,
        e.displayName = node.id
    REMOVE e:Entity
    WITH e, node
    CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
    RETURN count(*)
    """

    # Relationship type: the first of the listed words found in the keywords,
    # otherwise the first comma-separated keyword with double quotes removed.
    create_edges_query = """
    UNWIND $edges AS edge
    MATCH (source {id: edge.source})
    MATCH (target {id: edge.target})
    WITH source, target, edge,
         CASE
            WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
            WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
            WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
            WHEN edge.keywords CONTAINS 'located' THEN 'located'
            WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
            ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
         END AS relType
    CALL apoc.create.relationship(source, relType, {
      weight: edge.weight,
      description: edge.description,
      keywords: edge.keywords,
      source_id: edge.source_id
    }, target) YIELD rel
    RETURN count(*)
    """

    # Final pass over every node: displayName mirrors the id and the labels
    # are replaced by the entity type.
    set_displayname_and_labels_query = """
    MATCH (n)
    SET n.displayName = n.id
    WITH n
    CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
    RETURN count(*)
    """

    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

    try:
        with driver.session() as session:
            # Insert nodes in batches.
            session.execute_write(
                process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
            )

            # Insert relationships in batches.
            session.execute_write(
                process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
            )

            # Consume the auto-commit result so that a server-side failure is
            # raised here (and reported below) rather than being dropped when
            # the session closes.
            session.run(set_displayname_and_labels_query).consume()

    except Exception as e:
        print(f"Error occurred: {e}")

    finally:
        driver.close()


if __name__ == "__main__":
    main()
|
|