Spaces:
Build error
Build error
github-actions[bot] committed on
Commit ·
2202d61
1
Parent(s): 44c2466
Auto-sync from demo at Fri Apr 10 08:47:17 UTC 2026
Browse files- graphgen/bases/base_generator.py +6 -0
- graphgen/models/__init__.py +6 -0
- graphgen/models/generator/__init__.py +1 -0
- graphgen/models/generator/masked_fill_in_blank_generator.py +134 -0
- graphgen/models/partitioner/__init__.py +2 -0
- graphgen/models/partitioner/quintuple_partitioner.py +74 -0
- graphgen/models/partitioner/triple_partitioner.py +58 -0
- graphgen/operators/generate/generate_service.py +4 -0
- graphgen/operators/partition/partition_service.py +9 -1
graphgen/bases/base_generator.py
CHANGED
|
@@ -74,4 +74,10 @@ class BaseGenerator(ABC):
|
|
| 74 |
{"role": "assistant", "content": answer},
|
| 75 |
]
|
| 76 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
raise ValueError(f"Unknown output data format: {output_data_format}")
|
|
|
|
| 74 |
{"role": "assistant", "content": answer},
|
| 75 |
]
|
| 76 |
}
|
| 77 |
+
|
| 78 |
+
if output_data_format == "QA_pairs":
|
| 79 |
+
return {
|
| 80 |
+
"question": question,
|
| 81 |
+
"answer": answer,
|
| 82 |
+
}
|
| 83 |
raise ValueError(f"Unknown output data format: {output_data_format}")
|
graphgen/models/__init__.py
CHANGED
|
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
|
|
| 15 |
AtomicGenerator,
|
| 16 |
CoTGenerator,
|
| 17 |
FillInBlankGenerator,
|
|
|
|
| 18 |
MultiAnswerGenerator,
|
| 19 |
MultiChoiceGenerator,
|
| 20 |
MultiHopGenerator,
|
|
@@ -30,6 +31,8 @@ if TYPE_CHECKING:
|
|
| 30 |
DFSPartitioner,
|
| 31 |
ECEPartitioner,
|
| 32 |
LeidenPartitioner,
|
|
|
|
|
|
|
| 33 |
)
|
| 34 |
from .reader import (
|
| 35 |
CSVReader,
|
|
@@ -73,6 +76,7 @@ _import_map = {
|
|
| 73 |
"QuizGenerator": ".generator",
|
| 74 |
"TrueFalseGenerator": ".generator",
|
| 75 |
"VQAGenerator": ".generator",
|
|
|
|
| 76 |
# KG Builder
|
| 77 |
"LightRAGKGBuilder": ".kg_builder",
|
| 78 |
"MMKGBuilder": ".kg_builder",
|
|
@@ -86,6 +90,8 @@ _import_map = {
|
|
| 86 |
"DFSPartitioner": ".partitioner",
|
| 87 |
"ECEPartitioner": ".partitioner",
|
| 88 |
"LeidenPartitioner": ".partitioner",
|
|
|
|
|
|
|
| 89 |
# Reader
|
| 90 |
"CSVReader": ".reader",
|
| 91 |
"JSONReader": ".reader",
|
|
|
|
| 15 |
AtomicGenerator,
|
| 16 |
CoTGenerator,
|
| 17 |
FillInBlankGenerator,
|
| 18 |
+
MaskedFillInBlankGenerator,
|
| 19 |
MultiAnswerGenerator,
|
| 20 |
MultiChoiceGenerator,
|
| 21 |
MultiHopGenerator,
|
|
|
|
| 31 |
DFSPartitioner,
|
| 32 |
ECEPartitioner,
|
| 33 |
LeidenPartitioner,
|
| 34 |
+
QuintuplePartitioner,
|
| 35 |
+
TriplePartitioner,
|
| 36 |
)
|
| 37 |
from .reader import (
|
| 38 |
CSVReader,
|
|
|
|
| 76 |
"QuizGenerator": ".generator",
|
| 77 |
"TrueFalseGenerator": ".generator",
|
| 78 |
"VQAGenerator": ".generator",
|
| 79 |
+
"MaskedFillInBlankGenerator": ".generator",
|
| 80 |
# KG Builder
|
| 81 |
"LightRAGKGBuilder": ".kg_builder",
|
| 82 |
"MMKGBuilder": ".kg_builder",
|
|
|
|
| 90 |
"DFSPartitioner": ".partitioner",
|
| 91 |
"ECEPartitioner": ".partitioner",
|
| 92 |
"LeidenPartitioner": ".partitioner",
|
| 93 |
+
"TriplePartitioner": ".partitioner",
|
| 94 |
+
"QuintuplePartitioner": ".partitioner",
|
| 95 |
# Reader
|
| 96 |
"CSVReader": ".reader",
|
| 97 |
"JSONReader": ".reader",
|
graphgen/models/generator/__init__.py
CHANGED
|
@@ -8,3 +8,4 @@ from .multi_hop_generator import MultiHopGenerator
|
|
| 8 |
from .quiz_generator import QuizGenerator
|
| 9 |
from .true_false_generator import TrueFalseGenerator
|
| 10 |
from .vqa_generator import VQAGenerator
|
|
|
|
|
|
| 8 |
from .quiz_generator import QuizGenerator
|
| 9 |
from .true_false_generator import TrueFalseGenerator
|
| 10 |
from .vqa_generator import VQAGenerator
|
| 11 |
+
from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
|
graphgen/models/generator/masked_fill_in_blank_generator.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import re
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
|
| 5 |
+
from graphgen.bases import BaseGenerator
|
| 6 |
+
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
|
| 7 |
+
from graphgen.utils import detect_main_language, logger
|
| 8 |
+
|
| 9 |
+
random.seed(42)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MaskedFillInBlankGenerator(BaseGenerator):
    """Generate masked fill-in-the-blank QA pairs from a graph quintuple.

    The generator runs a TWO-STEP pipeline:
    1. rephrase: ask the LLM to rewrite the given nodes and edges as a
       coherent text that keeps the original meaning.
    2. mask: pick one of the three nodes at random and blank out every
       occurrence of its name in the rephrased text; the blanked-out
       name becomes the answer.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Build the REPHRASE prompt for one (nodes, edges) batch.

        :param batch: tuple of (nodes, edges) describing the subgraph
        :return: the formatted rephrasing prompt
        """
        nodes, edges = batch
        entity_lines = [
            f"{idx + 1}. {node[0]}: {node[1]['description']}"
            for idx, node in enumerate(nodes)
        ]
        relation_lines = [
            f"{idx + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for idx, edge in enumerate(edges)
        ]
        entities_str = "\n".join(entity_lines)
        relations_str = "\n".join(relation_lines)
        # Prompt language follows the dominant language of the input text.
        language = detect_main_language(entities_str + relations_str)

        # TODO: configure add_context — optionally look up the original
        # source chunks (via source_id) and append them to the prompt.
        return AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """Extract the rephrased text from the LLM response.

        :param response: raw LLM output containing <rephrased_text> tags
        :return: the rephrased text, or None if the tags are missing
        """
        found = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if not found:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        # Drop surrounding whitespace and stray quote characters.
        return found.group(1).strip().strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> dict:
        # Unused: QA pairs are assembled directly inside generate().
        pass

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """Generate a masked fill-in-blank QA pair for one batch.

        :param batch: a quintuple batch of (3 nodes, 2 edges)
        :return: a single-element list of {"question", "answer"} dicts,
            or [] when rephrasing or masking fails
        """
        prompt = self.build_prompt(batch)
        llm_response = await self.llm_client.generate_answer(prompt)
        context = self.parse_rephrased_text(llm_response)
        if not context:
            return []

        nodes, edges = batch

        assert len(nodes) == 3, (
            "MaskedFillInBlankGenerator currently only supports quintuples that has 3 nodes, "
            f"but got {len(nodes)} nodes."
        )
        assert len(edges) == 2, (
            "MaskedFillInBlankGenerator currently only supports quintuples that has 2 edges, "
            f"but got {len(edges)} edges."
        )

        # Choose the node whose name will be masked out of the context.
        target = random.choice(nodes)
        target_name = target[1]["entity_name"].strip("'\" \n\r\t")
        mask_pattern = re.compile(re.escape(target_name), re.IGNORECASE)

        hit = mask_pattern.search(context)
        if hit is None:
            # The LLM paraphrase does not contain the node name verbatim;
            # drop this batch rather than emit an unanswerable question.
            logger.debug(
                "Regex Match Failed!\n"
                "Expected name of node: %s\n"
                "Actual context: %s\n",
                target_name,
                context,
            )
            return []

        # The matched surface form (original casing) is the ground truth.
        answer = hit.group(0)
        masked_context = mask_pattern.sub("___", context)

        logger.debug("masked_context: %s", masked_context)
        return [
            {
                "question": masked_context,
                "answer": answer,
            }
        ]
|
graphgen/models/partitioner/__init__.py
CHANGED
|
@@ -3,3 +3,5 @@ from .bfs_partitioner import BFSPartitioner
|
|
| 3 |
from .dfs_partitioner import DFSPartitioner
|
| 4 |
from .ece_partitioner import ECEPartitioner
|
| 5 |
from .leiden_partitioner import LeidenPartitioner
|
|
|
|
|
|
|
|
|
| 3 |
from .dfs_partitioner import DFSPartitioner
|
| 4 |
from .ece_partitioner import ECEPartitioner
|
| 5 |
from .leiden_partitioner import LeidenPartitioner
|
| 6 |
+
from .quintuple_partitioner import QuintuplePartitioner
|
| 7 |
+
from .triple_partitioner import TriplePartitioner
|
graphgen/models/partitioner/quintuple_partitioner.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from collections import deque
|
| 3 |
+
from typing import Any, Iterable, Set
|
| 4 |
+
|
| 5 |
+
from graphgen.bases import BaseGraphStorage, BasePartitioner
|
| 6 |
+
from graphgen.bases.datatypes import Community
|
| 7 |
+
|
| 8 |
+
random.seed(42)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class QuintuplePartitioner(BasePartitioner):
    """Partition the graph into distinct quintuples (node, edge, node, edge, node).

    1. Isolated nodes are ignored automatically.
    2. Within each connected component, quintuples are yielded in BFS order.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        candidates = [pair[0] for pair in g.get_all_nodes()]
        random.shuffle(candidates)

        seen: Set[str] = set()
        consumed_edges: Set[frozenset[str]] = set()

        for start in candidates:
            if start in seen:
                continue

            # BFS over one connected component, starting from `start`.
            frontier = deque([start])
            seen.add(start)

            while frontier:
                center = frontier.popleft()

                # Neighbors reachable from `center` via not-yet-consumed edges.
                free_neighbors = []
                for nb in g.get_neighbors(center):
                    if frozenset((center, nb)) not in consumed_edges:
                        free_neighbors.append(nb)

                    # Standard BFS bookkeeping.
                    if nb not in seen:
                        seen.add(nb)
                        frontier.append(nb)

                random.shuffle(free_neighbors)

                # Every two neighbors paired around `center` form one quintuple.
                # Note: with an odd neighbor count the leftover edge simply
                # stays unconsumed; it may still join a quintuple later when
                # its other endpoint is processed as a center node.
                for i in range(0, len(free_neighbors) - 1, 2):
                    left = free_neighbors[i]
                    right = free_neighbors[i + 1]

                    consumed_edges.add(frozenset((center, left)))
                    consumed_edges.add(frozenset((center, right)))

                    # Sort endpoints so the community ID is deterministic.
                    lo, hi = sorted((left, right))

                    yield Community(
                        id=f"{lo}-{center}-{hi}",
                        nodes=[lo, center, hi],
                        edges=[tuple(sorted((lo, center))), tuple(sorted((center, hi)))],
                    )
|
graphgen/models/partitioner/triple_partitioner.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from collections import deque
|
| 3 |
+
from typing import Any, Iterable, Set
|
| 4 |
+
|
| 5 |
+
from graphgen.bases import BaseGraphStorage, BasePartitioner
|
| 6 |
+
from graphgen.bases.datatypes import Community
|
| 7 |
+
|
| 8 |
+
random.seed(42)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TriplePartitioner(BasePartitioner):
    """Partition the graph into distinct triples (node, edge, node).

    1. Isolated nodes are ignored automatically.
    2. Within each connected component, triples are yielded in BFS order.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        order = [pair[0] for pair in g.get_all_nodes()]
        random.shuffle(order)

        seen: Set[str] = set()
        emitted: Set[frozenset[str]] = set()

        for start in order:
            if start in seen:
                continue

            # BFS over the connected component containing `start`.
            frontier = deque([start])
            seen.add(start)

            while frontier:
                current = frontier.popleft()

                for nb in g.get_neighbors(current):
                    key = frozenset((current, nb))

                    # First time this edge is encountered -> a new triple.
                    if key not in emitted:
                        emitted.add(key)

                        # Sorted endpoints keep the community ID unique
                        # regardless of traversal direction.
                        lo, hi = sorted((current, nb))
                        yield Community(
                            id=f"{lo}-{hi}",
                            nodes=[lo, hi],
                            edges=[(lo, hi)],
                        )

                    # Continue the BFS.
                    if nb not in seen:
                        seen.add(nb)
                        frontier.append(nb)
|
graphgen/operators/generate/generate_service.py
CHANGED
|
@@ -71,6 +71,10 @@ class GenerateService(BaseOperator):
|
|
| 71 |
self.llm_client,
|
| 72 |
num_of_questions=generate_kwargs.get("num_of_questions", 5),
|
| 73 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
elif self.method == "true_false":
|
| 75 |
from graphgen.models import TrueFalseGenerator
|
| 76 |
|
|
|
|
| 71 |
self.llm_client,
|
| 72 |
num_of_questions=generate_kwargs.get("num_of_questions", 5),
|
| 73 |
)
|
| 74 |
+
elif self.method == "masked_fill_in_blank":
|
| 75 |
+
from graphgen.models import MaskedFillInBlankGenerator
|
| 76 |
+
|
| 77 |
+
self.generator = MaskedFillInBlankGenerator(self.llm_client)
|
| 78 |
elif self.method == "true_false":
|
| 79 |
from graphgen.models import TrueFalseGenerator
|
| 80 |
|
graphgen/operators/partition/partition_service.py
CHANGED
|
@@ -28,7 +28,7 @@ class PartitionService(BaseOperator):
|
|
| 28 |
|
| 29 |
self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
|
| 30 |
method = partition_kwargs["method"]
|
| 31 |
-
self.method_params = partition_kwargs
|
| 32 |
|
| 33 |
if method == "bfs":
|
| 34 |
from graphgen.models import BFSPartitioner
|
|
@@ -57,6 +57,14 @@ class PartitionService(BaseOperator):
|
|
| 57 |
if self.method_params.get("anchor_ids")
|
| 58 |
else None,
|
| 59 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
else:
|
| 61 |
raise ValueError(f"Unsupported partition method: {method}")
|
| 62 |
|
|
|
|
| 28 |
|
| 29 |
self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
|
| 30 |
method = partition_kwargs["method"]
|
| 31 |
+
self.method_params = partition_kwargs.get("method_params", {})
|
| 32 |
|
| 33 |
if method == "bfs":
|
| 34 |
from graphgen.models import BFSPartitioner
|
|
|
|
| 57 |
if self.method_params.get("anchor_ids")
|
| 58 |
else None,
|
| 59 |
)
|
| 60 |
+
elif method == "triple":
|
| 61 |
+
from graphgen.models import TriplePartitioner
|
| 62 |
+
|
| 63 |
+
self.partitioner = TriplePartitioner()
|
| 64 |
+
elif method == "quintuple":
|
| 65 |
+
from graphgen.models import QuintuplePartitioner
|
| 66 |
+
|
| 67 |
+
self.partitioner = QuintuplePartitioner()
|
| 68 |
else:
|
| 69 |
raise ValueError(f"Unsupported partition method: {method}")
|
| 70 |
|