github-actions[bot] committed on
Commit
2202d61
·
1 Parent(s): 44c2466

Auto-sync from demo at Fri Apr 10 08:47:17 UTC 2026

Browse files
graphgen/bases/base_generator.py CHANGED
@@ -74,4 +74,10 @@ class BaseGenerator(ABC):
74
  {"role": "assistant", "content": answer},
75
  ]
76
  }
 
 
 
 
 
 
77
  raise ValueError(f"Unknown output data format: {output_data_format}")
 
74
  {"role": "assistant", "content": answer},
75
  ]
76
  }
77
+
78
+ if output_data_format == "QA_pairs":
79
+ return {
80
+ "question": question,
81
+ "answer": answer,
82
+ }
83
  raise ValueError(f"Unknown output data format: {output_data_format}")
graphgen/models/__init__.py CHANGED
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
15
  AtomicGenerator,
16
  CoTGenerator,
17
  FillInBlankGenerator,
 
18
  MultiAnswerGenerator,
19
  MultiChoiceGenerator,
20
  MultiHopGenerator,
@@ -30,6 +31,8 @@ if TYPE_CHECKING:
30
  DFSPartitioner,
31
  ECEPartitioner,
32
  LeidenPartitioner,
 
 
33
  )
34
  from .reader import (
35
  CSVReader,
@@ -73,6 +76,7 @@ _import_map = {
73
  "QuizGenerator": ".generator",
74
  "TrueFalseGenerator": ".generator",
75
  "VQAGenerator": ".generator",
 
76
  # KG Builder
77
  "LightRAGKGBuilder": ".kg_builder",
78
  "MMKGBuilder": ".kg_builder",
@@ -86,6 +90,8 @@ _import_map = {
86
  "DFSPartitioner": ".partitioner",
87
  "ECEPartitioner": ".partitioner",
88
  "LeidenPartitioner": ".partitioner",
 
 
89
  # Reader
90
  "CSVReader": ".reader",
91
  "JSONReader": ".reader",
 
15
  AtomicGenerator,
16
  CoTGenerator,
17
  FillInBlankGenerator,
18
+ MaskedFillInBlankGenerator,
19
  MultiAnswerGenerator,
20
  MultiChoiceGenerator,
21
  MultiHopGenerator,
 
31
  DFSPartitioner,
32
  ECEPartitioner,
33
  LeidenPartitioner,
34
+ QuintuplePartitioner,
35
+ TriplePartitioner,
36
  )
37
  from .reader import (
38
  CSVReader,
 
76
  "QuizGenerator": ".generator",
77
  "TrueFalseGenerator": ".generator",
78
  "VQAGenerator": ".generator",
79
+ "MaskedFillInBlankGenerator": ".generator",
80
  # KG Builder
81
  "LightRAGKGBuilder": ".kg_builder",
82
  "MMKGBuilder": ".kg_builder",
 
90
  "DFSPartitioner": ".partitioner",
91
  "ECEPartitioner": ".partitioner",
92
  "LeidenPartitioner": ".partitioner",
93
+ "TriplePartitioner": ".partitioner",
94
+ "QuintuplePartitioner": ".partitioner",
95
  # Reader
96
  "CSVReader": ".reader",
97
  "JSONReader": ".reader",
graphgen/models/generator/__init__.py CHANGED
@@ -8,3 +8,4 @@ from .multi_hop_generator import MultiHopGenerator
8
  from .quiz_generator import QuizGenerator
9
  from .true_false_generator import TrueFalseGenerator
10
  from .vqa_generator import VQAGenerator
 
 
8
  from .quiz_generator import QuizGenerator
9
  from .true_false_generator import TrueFalseGenerator
10
  from .vqa_generator import VQAGenerator
11
+ from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
graphgen/models/generator/masked_fill_in_blank_generator.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import re
3
+ from typing import Any, Optional
4
+
5
+ from graphgen.bases import BaseGenerator
6
+ from graphgen.templates import AGGREGATED_GENERATION_PROMPT
7
+ from graphgen.utils import detect_main_language, logger
8
+
9
+ random.seed(42)
10
+
11
+
12
class MaskedFillInBlankGenerator(BaseGenerator):
    """Generate masked fill-in-the-blank QA pairs from graph communities.

    Generation is a TWO-STEP pipeline:
      1. rephrase -- ask the LLM to rewrite the input nodes and edges into a
         coherent passage that keeps the original meaning;
      2. mask -- pick one node at random and blank out its name in the
         rephrased passage, yielding a (masked passage, hidden name) pair.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Render the REPHRASE prompt for one ``(nodes, edges)`` batch.

        :param batch: pair of node tuples ``(name, attrs)`` and edge tuples
            ``(src, dst, attrs)``; each ``attrs`` dict must carry a
            ``description`` entry.
        :return: formatted prompt string in the detected main language.
        """
        nodes, edges = batch

        entity_lines = [
            f"{pos}. {name}: {attrs['description']}"
            for pos, (name, attrs) in enumerate(nodes, start=1)
        ]
        entities_str = "\n".join(entity_lines)

        relation_lines = [
            f"{pos}. {src} -- {dst}: {attrs['description']}"
            for pos, (src, dst, attrs) in enumerate(edges, start=1)
        ]
        relations_str = "\n".join(relation_lines)

        # Prompt language follows the dominant language of the input text.
        language = detect_main_language(entities_str + relations_str)

        # TODO: optionally prepend the original source chunks as extra
        # context (needs an add_context flag plus access to the text-chunk
        # storage to resolve source_id values back into raw text).
        return AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """Extract the ``<rephrased_text>...</rephrased_text>`` payload.

        :param response: raw LLM answer.
        :return: the inner text with surrounding whitespace and stray
            quote characters removed, or ``None`` when the tag is absent.
        """
        found = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if not found:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        return found.group(1).strip().strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> dict:
        # Intentionally a no-op: masking happens inside generate() instead
        # of via the usual parse_response hook.
        pass

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """Produce at most one ``{"question", "answer"}`` pair for *batch*.

        :param batch: a quintuple community — exactly 3 nodes and 2 edges.
        :return: a single-element list with the masked question and the
            hidden node name, or ``[]`` when the rephrased text cannot be
            parsed or the chosen name does not occur in it.
        """
        response = await self.llm_client.generate_answer(self.build_prompt(batch))
        context = self.parse_rephrased_text(response)
        if not context:
            return []

        nodes, edges = batch

        assert len(nodes) == 3, (
            "MaskedFillInBlankGenerator currently only supports quintuples that has 3 nodes, "
            f"but got {len(nodes)} nodes."
        )
        assert len(edges) == 2, (
            "MaskedFillInBlankGenerator currently only supports quintuples that has 2 edges, "
            f"but got {len(edges)} edges."
        )

        # Pick the node to hide; strip quoting/whitespace noise from its name
        # so the regex matches the plain surface form.
        target = random.choice(nodes)
        target_name = target[1]["entity_name"].strip("'\" \n\r\t")
        name_pattern = re.compile(re.escape(target_name), re.IGNORECASE)

        hit = name_pattern.search(context)
        if hit is None:
            logger.debug(
                "Regex Match Failed!\n"
                "Expected name of node: %s\n"
                "Actual context: %s\n",
                target_name,
                context,
            )
            return []

        # Keep the exact surface form that matched as the ground-truth answer.
        masked_context = name_pattern.sub("___", context)
        logger.debug("masked_context: %s", masked_context)
        return [{"question": masked_context, "answer": hit.group(0)}]
graphgen/models/partitioner/__init__.py CHANGED
@@ -3,3 +3,5 @@ from .bfs_partitioner import BFSPartitioner
3
  from .dfs_partitioner import DFSPartitioner
4
  from .ece_partitioner import ECEPartitioner
5
  from .leiden_partitioner import LeidenPartitioner
 
 
 
3
  from .dfs_partitioner import DFSPartitioner
4
  from .ece_partitioner import ECEPartitioner
5
  from .leiden_partitioner import LeidenPartitioner
6
+ from .quintuple_partitioner import QuintuplePartitioner
7
+ from .triple_partitioner import TriplePartitioner
graphgen/models/partitioner/quintuple_partitioner.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import deque
3
+ from typing import Any, Iterable, Set
4
+
5
+ from graphgen.bases import BaseGraphStorage, BasePartitioner
6
+ from graphgen.bases.datatypes import Community
7
+
8
+ random.seed(42)
9
+
10
+
11
class QuintuplePartitioner(BasePartitioner):
    """Partition a graph into edge-disjoint quintuples (node, edge, node, edge, node).

    1. Isolated nodes never appear in any quintuple.
    2. Each connected component is explored breadth-first; at every visited
       center node, pairs of still-unused incident edges are combined into
       one quintuple each.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        node_ids = [entry[0] for entry in g.get_all_nodes()]
        random.shuffle(node_ids)

        seen: Set[str] = set()
        consumed_edges: Set[frozenset[str]] = set()

        for start in node_ids:
            if start in seen:
                continue

            # BFS over the connected component containing `start`.
            frontier = deque([start])
            seen.add(start)

            while frontier:
                center = frontier.popleft()

                # Incident edges of `center` that no quintuple has claimed yet,
                # interleaved with standard BFS frontier maintenance.
                free_neighbors = []
                for nb in g.get_neighbors(center):
                    if frozenset((center, nb)) not in consumed_edges:
                        free_neighbors.append(nb)

                    if nb not in seen:
                        seen.add(nb)
                        frontier.append(nb)

                random.shuffle(free_neighbors)

                # Two neighbors plus the center form one quintuple. An odd
                # leftover edge simply stays free for now — it may still be
                # consumed later when its other endpoint acts as a center.
                pair_source = iter(free_neighbors)
                for left, right in zip(pair_source, pair_source):
                    consumed_edges.add(frozenset((center, left)))
                    consumed_edges.add(frozenset((center, right)))

                    # Sorted outer nodes make the ID independent of pair order.
                    lo, hi = sorted((left, right))
                    yield Community(
                        id=f"{lo}-{center}-{hi}",
                        nodes=[lo, center, hi],
                        edges=[tuple(sorted((lo, center))), tuple(sorted((center, hi)))],
                    )
graphgen/models/partitioner/triple_partitioner.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import deque
3
+ from typing import Any, Iterable, Set
4
+
5
+ from graphgen.bases import BaseGraphStorage, BasePartitioner
6
+ from graphgen.bases.datatypes import Community
7
+
8
+ random.seed(42)
9
+
10
+
11
class TriplePartitioner(BasePartitioner):
    """Partition a graph into distinct triples (node, edge, node).

    1. Isolated nodes are skipped automatically.
    2. Within each connected component, one triple is emitted per edge, in
       BFS discovery order.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        node_ids = [entry[0] for entry in g.get_all_nodes()]
        random.shuffle(node_ids)

        seen: Set[str] = set()
        emitted: Set[frozenset[str]] = set()

        for start in node_ids:
            if start in seen:
                continue

            # BFS over the connected component containing `start`.
            frontier = deque([start])
            seen.add(start)

            while frontier:
                center = frontier.popleft()

                for nb in g.get_neighbors(center):
                    key = frozenset((center, nb))

                    # First time this edge is reached -> a new triple.
                    if key not in emitted:
                        emitted.add(key)

                        # Sorted endpoints keep the community ID unique and
                        # independent of traversal direction.
                        lo, hi = sorted((center, nb))
                        yield Community(
                            id=f"{lo}-{hi}",
                            nodes=[lo, hi],
                            edges=[(lo, hi)],
                        )

                    # Standard BFS frontier maintenance.
                    if nb not in seen:
                        seen.add(nb)
                        frontier.append(nb)
graphgen/operators/generate/generate_service.py CHANGED
@@ -71,6 +71,10 @@ class GenerateService(BaseOperator):
71
  self.llm_client,
72
  num_of_questions=generate_kwargs.get("num_of_questions", 5),
73
  )
 
 
 
 
74
  elif self.method == "true_false":
75
  from graphgen.models import TrueFalseGenerator
76
 
 
71
  self.llm_client,
72
  num_of_questions=generate_kwargs.get("num_of_questions", 5),
73
  )
74
+ elif self.method == "masked_fill_in_blank":
75
+ from graphgen.models import MaskedFillInBlankGenerator
76
+
77
+ self.generator = MaskedFillInBlankGenerator(self.llm_client)
78
  elif self.method == "true_false":
79
  from graphgen.models import TrueFalseGenerator
80
 
graphgen/operators/partition/partition_service.py CHANGED
@@ -28,7 +28,7 @@ class PartitionService(BaseOperator):
28
 
29
  self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
30
  method = partition_kwargs["method"]
31
- self.method_params = partition_kwargs["method_params"]
32
 
33
  if method == "bfs":
34
  from graphgen.models import BFSPartitioner
@@ -57,6 +57,14 @@ class PartitionService(BaseOperator):
57
  if self.method_params.get("anchor_ids")
58
  else None,
59
  )
 
 
 
 
 
 
 
 
60
  else:
61
  raise ValueError(f"Unsupported partition method: {method}")
62
 
 
28
 
29
  self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
30
  method = partition_kwargs["method"]
31
+ self.method_params = partition_kwargs.get("method_params", {})
32
 
33
  if method == "bfs":
34
  from graphgen.models import BFSPartitioner
 
57
  if self.method_params.get("anchor_ids")
58
  else None,
59
  )
60
+ elif method == "triple":
61
+ from graphgen.models import TriplePartitioner
62
+
63
+ self.partitioner = TriplePartitioner()
64
+ elif method == "quintuple":
65
+ from graphgen.models import QuintuplePartitioner
66
+
67
+ self.partitioner = QuintuplePartitioner()
68
  else:
69
  raise ValueError(f"Unsupported partition method: {method}")
70