github-actions[bot] committed on
Commit
7566ac3
·
1 Parent(s): 31df32c

Auto-sync from demo at Fri Dec 26 08:29:01 UTC 2025

Browse files
Files changed (35) hide show
  1. graphgen/bases/__init__.py +1 -0
  2. graphgen/bases/base_evaluator.py +10 -0
  3. graphgen/bases/base_storage.py +43 -4
  4. graphgen/common/init_storage.py +39 -5
  5. graphgen/engine.py +2 -0
  6. graphgen/models/__init__.py +9 -1
  7. graphgen/models/evaluator/__init__.py +2 -4
  8. graphgen/models/evaluator/base_evaluator.py +0 -52
  9. graphgen/models/evaluator/kg/__init__.py +18 -0
  10. graphgen/models/evaluator/kg/accuracy_evaluator.py +350 -0
  11. graphgen/models/evaluator/kg/consistency_evaluator.py +380 -0
  12. graphgen/models/evaluator/kg/structure_evaluator.py +97 -0
  13. graphgen/models/evaluator/length_evaluator.py +0 -19
  14. graphgen/models/evaluator/qa/__init__.py +4 -0
  15. graphgen/models/evaluator/qa/length_evaluator.py +18 -0
  16. graphgen/models/evaluator/{mtld_evaluator.py → qa/mtld_evaluator.py} +18 -23
  17. graphgen/models/evaluator/qa/reward_evaluator.py +66 -0
  18. graphgen/models/evaluator/qa/uni_evaluator.py +105 -0
  19. graphgen/models/evaluator/reward_evaluator.py +0 -107
  20. graphgen/models/evaluator/uni_evaluator.py +0 -183
  21. graphgen/models/storage/graph/kuzu_storage.py +90 -1
  22. graphgen/models/storage/graph/networkx_storage.py +26 -1
  23. graphgen/operators/__init__.py +3 -0
  24. graphgen/operators/evaluate/__init__.py +3 -0
  25. graphgen/operators/evaluate/evaluate.py +0 -177
  26. graphgen/operators/evaluate/evaluate_service.py +181 -0
  27. graphgen/run.py +5 -4
  28. graphgen/templates/__init__.py +1 -0
  29. graphgen/templates/evaluation/__init__.py +1 -0
  30. graphgen/templates/evaluation/kg/__init__.py +2 -0
  31. graphgen/templates/evaluation/kg/accuracy_evaluation.py +156 -0
  32. graphgen/templates/evaluation/kg/consistency_evaluation.py +102 -0
  33. graphgen/utils/help_nltk.py +46 -24
  34. requirements.txt +2 -1
  35. webui/utils/count_tokens.py +3 -1
graphgen/bases/__init__.py CHANGED
@@ -9,4 +9,5 @@ from .base_searcher import BaseSearcher
9
  from .base_splitter import BaseSplitter
10
  from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
11
  from .base_tokenizer import BaseTokenizer
 
12
  from .datatypes import Chunk, Config, Node, QAPair, Token
 
9
  from .base_splitter import BaseSplitter
10
  from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
11
  from .base_tokenizer import BaseTokenizer
12
+ from .base_evaluator import BaseEvaluator
13
  from .datatypes import Chunk, Config, Node, QAPair, Token
graphgen/bases/base_evaluator.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod

from .datatypes import QAPair


class BaseEvaluator(ABC):
    """Abstract contract for single-pair evaluators.

    Concrete subclasses score one question-answer pair at a time.
    """

    @abstractmethod
    def evaluate(self, pair: QAPair) -> float:
        """
        Evaluate the text and return a score.
        """
graphgen/bases/base_storage.py CHANGED
@@ -1,5 +1,6 @@
 
1
  from dataclasses import dataclass
2
- from typing import Generic, TypeVar, Union
3
 
4
  T = TypeVar("T")
5
 
@@ -45,52 +46,90 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
45
  raise NotImplementedError
46
 
47
 
48
- class BaseGraphStorage(StorageNameSpace):
 
 
 
 
 
49
  def has_node(self, node_id: str) -> bool:
50
  raise NotImplementedError
51
 
 
52
  def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
53
  raise NotImplementedError
54
 
 
55
  def node_degree(self, node_id: str) -> int:
56
  raise NotImplementedError
57
 
58
- def edge_degree(self, src_id: str, tgt_id: str) -> int:
59
- raise NotImplementedError
 
60
 
 
 
 
 
 
 
 
 
61
  def get_node(self, node_id: str) -> Union[dict, None]:
62
  raise NotImplementedError
63
 
 
64
  def update_node(self, node_id: str, node_data: dict[str, str]):
65
  raise NotImplementedError
66
 
 
67
  def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
68
  raise NotImplementedError
69
 
 
 
 
 
 
70
  def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
71
  raise NotImplementedError
72
 
 
73
  def update_edge(
74
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
75
  ):
76
  raise NotImplementedError
77
 
 
78
  def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
79
  raise NotImplementedError
80
 
 
 
 
 
 
81
  def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
82
  raise NotImplementedError
83
 
 
84
  def upsert_node(self, node_id: str, node_data: dict[str, str]):
85
  raise NotImplementedError
86
 
 
87
  def upsert_edge(
88
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
89
  ):
90
  raise NotImplementedError
91
 
 
92
  def delete_node(self, node_id: str):
93
  raise NotImplementedError
94
 
 
95
  def reload(self):
96
  raise NotImplementedError
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
  from dataclasses import dataclass
3
+ from typing import Dict, Generic, List, Set, TypeVar, Union
4
 
5
  T = TypeVar("T")
6
 
 
46
  raise NotImplementedError
47
 
48
 
49
class BaseGraphStorage(StorageNameSpace, ABC):
    """Abstract interface for graph storage backends.

    Concrete backends implement node/edge CRUD, degree queries and
    connectivity helpers. Every abstract method body consistently raises
    ``NotImplementedError`` (previously some used bare ``pass``) so that an
    accidental ``super()`` call fails loudly instead of silently returning
    ``None``.
    """

    @abstractmethod
    def is_directed(self) -> bool:
        """Return True if the underlying graph is directed."""
        raise NotImplementedError

    @abstractmethod
    def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_all_node_degrees(self) -> Dict[str, int]:
        """Return a mapping of node id -> degree for every node."""
        raise NotImplementedError

    def get_isolated_nodes(self) -> List[str]:
        """Return ids of nodes with degree zero, derived from get_all_node_degrees()."""
        return [
            node_id
            for node_id, degree in self.get_all_node_degrees().items()
            if degree == 0
        ]

    @abstractmethod
    def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    @abstractmethod
    def update_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    @abstractmethod
    def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
        raise NotImplementedError

    @abstractmethod
    def get_node_count(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    @abstractmethod
    def update_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    @abstractmethod
    def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
        raise NotImplementedError

    @abstractmethod
    def get_edge_count(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    @abstractmethod
    def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    @abstractmethod
    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    @abstractmethod
    def delete_node(self, node_id: str):
        raise NotImplementedError

    @abstractmethod
    def reload(self):
        raise NotImplementedError

    @abstractmethod
    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
        """Return connected components as sets of node ids."""
        raise NotImplementedError
graphgen/common/init_storage.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Union
2
 
3
  import ray
4
 
@@ -68,6 +68,21 @@ class GraphStorageActor:
68
  def index_done_callback(self):
69
  return self.graph.index_done_callback()
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def has_node(self, node_id: str) -> bool:
72
  return self.graph.has_node(node_id)
73
 
@@ -165,6 +180,21 @@ class RemoteGraphStorageProxy(BaseGraphStorage):
165
  def index_done_callback(self):
166
  return ray.get(self.actor.index_done_callback.remote())
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def has_node(self, node_id: str) -> bool:
169
  return ray.get(self.actor.has_node.remote(node_id))
170
 
@@ -239,10 +269,14 @@ class StorageFactory:
239
  try:
240
  actor_handle = ray.get_actor(actor_name)
241
  except ValueError:
242
- actor_handle = ray.remote(actor_class).options(
243
- name=actor_name,
244
- get_if_exists=True,
245
- ).remote(backend, working_dir, namespace)
 
 
 
 
246
  ray.get(actor_handle.ready.remote())
247
  return proxy_class(actor_handle)
248
 
 
1
+ from typing import Any, Dict, List, Set, Union
2
 
3
  import ray
4
 
 
68
  def index_done_callback(self):
69
  return self.graph.index_done_callback()
70
 
71
    def is_directed(self) -> bool:
        # Forward to the wrapped in-process graph backend.
        return self.graph.is_directed()
73
+
74
    def get_all_node_degrees(self) -> Dict[str, int]:
        # Forward to the wrapped in-process graph backend.
        return self.graph.get_all_node_degrees()
76
+
77
    def get_node_count(self) -> int:
        # Forward to the wrapped in-process graph backend.
        return self.graph.get_node_count()
79
+
80
    def get_edge_count(self) -> int:
        # Forward to the wrapped in-process graph backend.
        return self.graph.get_edge_count()
82
+
83
    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
        # Forward to the wrapped in-process graph backend.
        return self.graph.get_connected_components(undirected)
85
+
86
  def has_node(self, node_id: str) -> bool:
87
  return self.graph.has_node(node_id)
88
 
 
180
  def index_done_callback(self):
181
  return ray.get(self.actor.index_done_callback.remote())
182
 
183
    def is_directed(self) -> bool:
        # Blocking round-trip to the remote storage actor.
        return ray.get(self.actor.is_directed.remote())
185
+
186
    def get_all_node_degrees(self) -> Dict[str, int]:
        # Blocking round-trip to the remote storage actor.
        return ray.get(self.actor.get_all_node_degrees.remote())
188
+
189
    def get_node_count(self) -> int:
        # Blocking round-trip to the remote storage actor.
        return ray.get(self.actor.get_node_count.remote())
191
+
192
    def get_edge_count(self) -> int:
        # Blocking round-trip to the remote storage actor.
        return ray.get(self.actor.get_edge_count.remote())
194
+
195
    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
        # Blocking round-trip to the remote storage actor.
        return ray.get(self.actor.get_connected_components.remote(undirected))
197
+
198
  def has_node(self, node_id: str) -> bool:
199
  return ray.get(self.actor.has_node.remote(node_id))
200
 
 
269
  try:
270
  actor_handle = ray.get_actor(actor_name)
271
  except ValueError:
272
+ actor_handle = (
273
+ ray.remote(actor_class)
274
+ .options(
275
+ name=actor_name,
276
+ get_if_exists=True,
277
+ )
278
+ .remote(backend, working_dir, namespace)
279
+ )
280
  ray.get(actor_handle.ready.remote())
281
  return proxy_class(actor_handle)
282
 
graphgen/engine.py CHANGED
@@ -271,6 +271,8 @@ class Engine:
271
 
272
  for node in sorted_nodes:
273
  self._execute_node(node, initial_ds)
 
 
274
 
275
  output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
276
  return {node.id: self.datasets[node.id] for node in output_nodes}
 
271
 
272
  for node in sorted_nodes:
273
  self._execute_node(node, initial_ds)
274
+ if getattr(node, "save_output", False):
275
+ self.datasets[node.id] = self.datasets[node.id].materialize()
276
 
277
  output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
278
  return {node.id: self.datasets[node.id] for node in output_nodes}
graphgen/models/__init__.py CHANGED
@@ -1,4 +1,12 @@
1
- from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
 
 
 
 
 
 
 
 
2
  from .generator import (
3
  AggregatedGenerator,
4
  AtomicGenerator,
 
1
+ from .evaluator import (
2
+ AccuracyEvaluator,
3
+ ConsistencyEvaluator,
4
+ LengthEvaluator,
5
+ MTLDEvaluator,
6
+ RewardEvaluator,
7
+ StructureEvaluator,
8
+ UniEvaluator,
9
+ )
10
  from .generator import (
11
  AggregatedGenerator,
12
  AtomicGenerator,
graphgen/models/evaluator/__init__.py CHANGED
@@ -1,4 +1,2 @@
1
- from .length_evaluator import LengthEvaluator
2
- from .mtld_evaluator import MTLDEvaluator
3
- from .reward_evaluator import RewardEvaluator
4
- from .uni_evaluator import UniEvaluator
 
1
+ from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator
2
+ from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
 
 
graphgen/models/evaluator/base_evaluator.py DELETED
@@ -1,52 +0,0 @@
1
- import asyncio
2
-
3
- from tqdm.asyncio import tqdm as tqdm_async
4
-
5
- from graphgen.bases.datatypes import QAPair
6
- from graphgen.utils import create_event_loop
7
-
8
-
9
- class BaseEvaluator:
10
- def __init__(self, max_concurrent: int = 100):
11
- self.max_concurrent = max_concurrent
12
- self.results: list[float] = None
13
-
14
- def evaluate(self, pairs: list[QAPair]) -> list[float]:
15
- """
16
- Evaluate the text and return a score.
17
- """
18
- return create_event_loop().run_until_complete(self.async_evaluate(pairs))
19
-
20
- async def async_evaluate(self, pairs: list[QAPair]) -> list[float]:
21
- semaphore = asyncio.Semaphore(self.max_concurrent)
22
-
23
- async def evaluate_with_semaphore(pair):
24
- async with semaphore: # 获取Semaphore
25
- return await self.evaluate_single(pair)
26
-
27
- results = []
28
- for result in tqdm_async(
29
- asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
30
- total=len(pairs),
31
- ):
32
- results.append(await result)
33
- return results
34
-
35
- async def evaluate_single(self, pair: QAPair) -> float:
36
- raise NotImplementedError()
37
-
38
- def get_average_score(self, pairs: list[QAPair]) -> float:
39
- """
40
- Get the average score of a batch of texts.
41
- """
42
- results = self.evaluate(pairs)
43
- self.results = results
44
- return sum(self.results) / len(pairs)
45
-
46
- def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
47
- """
48
- Get the min and max score of a batch of texts.
49
- """
50
- if self.results is None:
51
- self.get_average_score(pairs)
52
- return min(self.results), max(self.results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/evaluator/kg/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge Graph Quality Evaluator
3
+
4
+ This module provides comprehensive quality evaluation for knowledge graphs,
5
+ 1. accuracy assessment (entity/relation/triple validation),
6
+ 2. consistency assessment (attribute conflict detection), and structural
7
+ 3. robustness assessment (noise ratio, connectivity, degree distribution).
8
+ """
9
+
10
+ from .accuracy_evaluator import AccuracyEvaluator
11
+ from .consistency_evaluator import ConsistencyEvaluator
12
+ from .structure_evaluator import StructureEvaluator
13
+
14
+ __all__ = [
15
+ "AccuracyEvaluator",
16
+ "ConsistencyEvaluator",
17
+ "StructureEvaluator",
18
+ ]
graphgen/models/evaluator/kg/accuracy_evaluator.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import re
4
+ from typing import Any, Dict, List
5
+
6
+ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
7
+ from graphgen.bases.datatypes import Chunk
8
+ from graphgen.templates import ACCURACY_EVALUATION_PROMPT
9
+ from graphgen.utils import detect_main_language, logger
10
+
11
+
12
class AccuracyEvaluator:
    """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.

    For each chunk, an LLM judge compares the extracted entities and relations
    with the original chunk content and returns multi-dimensional quality
    scores (accuracy, completeness, precision).

    Refactor note: the entity and relation paths previously duplicated ~70
    lines of prompt/parse/fallback logic; both now share _evaluate_extraction.
    """

    # Weights used to derive ``overall_score`` when the judge omits it.
    _SCORE_WEIGHTS = {"accuracy": 0.4, "completeness": 0.4, "precision": 0.2}

    def __init__(
        self,
        graph_storage: BaseGraphStorage,
        chunk_storage: BaseKVStorage,
        llm_client: BaseLLMWrapper,
    ):
        self.graph_storage = graph_storage
        self.chunk_storage = chunk_storage
        self.llm_client = llm_client

    def evaluate(self) -> Dict[str, Any]:
        """Evaluate entity and relation extraction quality using LLM-as-a-Judge.

        Returns:
            Dictionary containing entity_accuracy and relation_accuracy metrics,
            or ``{"error": ...}`` when the chunk storage is empty.
        """
        # 1. Load all chunks from storage
        chunks = self._load_chunks_from_storage()
        if not chunks:
            logger.warning("No chunks found in storage")
            return {"error": "No chunks found in storage"}

        logger.info(f"Found {len(chunks)} chunks to evaluate")

        # 2. Evaluate each chunk, then 3. aggregate.
        entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks)
        return self._aggregate_evaluation_results(
            entity_evaluations, relation_evaluations
        )

    def _load_chunks_from_storage(self) -> List[Chunk]:
        """Load all chunks from chunk storage, skipping records that fail to parse."""
        chunks: List[Chunk] = []
        for chunk_id, chunk_data in self.chunk_storage.get_all().items():
            try:
                chunks.append(Chunk.from_dict(chunk_id, chunk_data))
            except Exception as e:  # malformed record: log and keep going
                logger.warning(f"Failed to load chunk {chunk_id}: {e}")
        return chunks

    @staticmethod
    def _chunk_in_sources(chunk_id: str, record: dict) -> bool:
        """True if chunk_id appears in the record's '<SEP>'-joined source_id field."""
        source_ids = record.get("source_id", "").split("<SEP>")
        return chunk_id in [sid.strip() for sid in source_ids if sid.strip()]

    def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]:
        """Get all entities extracted from the specified chunk."""
        entities = []
        for node_id, node_data in self.graph_storage.get_all_nodes() or []:
            if isinstance(node_data, dict) and self._chunk_in_sources(
                chunk_id, node_data
            ):
                entities.append(
                    {
                        "entity_name": node_data.get("entity_name", node_id),
                        "entity_type": node_data.get("entity_type", ""),
                        "description": node_data.get("description", ""),
                    }
                )
        return entities

    def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:
        """Get all relations extracted from the specified chunk."""
        relations = []
        for src_id, dst_id, edge_data in self.graph_storage.get_all_edges() or []:
            if isinstance(edge_data, dict) and self._chunk_in_sources(
                chunk_id, edge_data
            ):
                # Resolve endpoint display names; fall back to raw node ids.
                src_node = self.graph_storage.get_node(src_id) or {}
                dst_node = self.graph_storage.get_node(dst_id) or {}
                relations.append(
                    {
                        "source_entity": src_node.get("entity_name", src_id),
                        "target_entity": dst_node.get("entity_name", dst_id),
                        "relationship_summary": edge_data.get("description", ""),
                    }
                )
        return relations

    def _evaluate_all_chunks(
        self, chunks: List[Chunk]
    ) -> tuple[List[Dict], List[Dict]]:
        """Evaluate all chunks sequentially; a failing chunk is logged and skipped."""
        entity_evaluations: List[Dict] = []
        relation_evaluations: List[Dict] = []

        for chunk in chunks:
            try:
                entities = self._get_extracted_entities_for_chunk(chunk.id)
                relations = self._get_extracted_relations_for_chunk(chunk.id)
                entity_evaluations.append(
                    self._evaluate_entity_extraction(chunk, entities)
                )
                relation_evaluations.append(
                    self._evaluate_relation_extraction(chunk, relations)
                )
            except Exception as e:
                logger.error(f"Failed to evaluate chunk {chunk.id}: {e}")

        return entity_evaluations, relation_evaluations

    def _evaluate_entity_extraction(
        self, chunk: Chunk, extracted_entities: List[Dict]
    ) -> Dict[str, Any]:
        """Use LLM to evaluate entity extraction quality."""
        return self._evaluate_extraction(
            chunk,
            extracted_entities,
            prompt_key="ENTITY",
            payload_key="extracted_entities",
            count_key="extracted_entities_count",
        )

    def _evaluate_relation_extraction(
        self, chunk: Chunk, extracted_relations: List[Dict]
    ) -> Dict[str, Any]:
        """Use LLM to evaluate relation extraction quality."""
        return self._evaluate_extraction(
            chunk,
            extracted_relations,
            prompt_key="RELATION",
            payload_key="extracted_relations",
            count_key="extracted_relations_count",
        )

    def _evaluate_extraction(
        self,
        chunk: Chunk,
        items: List[Dict],
        *,
        prompt_key: str,
        payload_key: str,
        count_key: str,
    ) -> Dict[str, Any]:
        """Shared LLM-as-a-Judge flow for both entity and relation evaluation.

        Args:
            chunk: Source chunk being judged.
            items: Extracted entities or relations for this chunk.
            prompt_key: "ENTITY" or "RELATION" — selects the judge prompt.
            payload_key: Prompt format kwarg carrying the serialized items.
            count_key: Result key reporting len(items).
        """
        base = {
            "chunk_id": chunk.id,
            # First 200 chars for debugging.
            "chunk_content": chunk.content[:200] if chunk.content else "",
            count_key: len(items),
        }
        try:
            lang = detect_main_language(chunk.content)
            prompt = ACCURACY_EVALUATION_PROMPT[lang][prompt_key].format(
                chunk_content=chunk.content,
                **{payload_key: json.dumps(items, ensure_ascii=False, indent=2)},
            )

            # NOTE(review): asyncio.run spins up a fresh event loop per call and
            # raises RuntimeError if invoked from an already-running loop —
            # confirm all callers are synchronous.
            response = asyncio.run(self.llm_client.generate_answer(prompt))

            evaluation_result = self._parse_judge_response(response, chunk.id)
            self._ensure_overall_score(evaluation_result)
            return {**base, **evaluation_result}
        except Exception as e:
            logger.error(
                f"Error evaluating {prompt_key.lower()} extraction for chunk {chunk.id}: {e}"
            )
            return {
                **base,
                **self._zero_evaluation(
                    f"Evaluation failed: {str(e)}", f"Evaluation error: {str(e)}"
                ),
            }

    @classmethod
    def _parse_judge_response(cls, response: str, chunk_id: str) -> Dict[str, Any]:
        """Parse the judge's JSON reply, tolerating surrounding prose/markdown."""
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            json_match = re.search(r"\{.*\}", response, re.DOTALL)
            if json_match:
                # May still raise; the caller's except turns that into a zero result.
                return json.loads(json_match.group(0))
            logger.warning(
                f"Failed to parse LLM response for chunk {chunk_id}: {response[:200]}"
            )
            return cls._zero_evaluation(
                "Failed to parse LLM response", "LLM response parsing failed"
            )

    @staticmethod
    def _zero_evaluation(reasoning: str, issue: str) -> Dict[str, Any]:
        """Zero-score result used when judging or response parsing fails."""
        return {
            "accuracy": 0.0,
            "completeness": 0.0,
            "precision": 0.0,
            "overall_score": 0.0,
            "accuracy_reasoning": reasoning,
            "completeness_reasoning": "",
            "precision_reasoning": "",
            "issues": [issue],
        }

    @classmethod
    def _ensure_overall_score(cls, result: Dict[str, Any]) -> None:
        """Derive overall_score from weighted dimensions when the judge omits it."""
        if "overall_score" not in result:
            result["overall_score"] = sum(
                weight * float(result.get(dim, 0.0))
                for dim, weight in cls._SCORE_WEIGHTS.items()
            )

    @staticmethod
    def _aggregate_evaluation_results(
        entity_evaluations: List[Dict], relation_evaluations: List[Dict]
    ) -> Dict[str, Any]:
        """Aggregate per-chunk evaluation results into summary statistics."""

        def calculate_stats(scores: List[float]) -> Dict[str, float]:
            # Population statistics (variance divided by n, not n-1).
            if not scores:
                return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0}
            sorted_scores = sorted(scores)
            n = len(scores)
            mean = sum(scores) / n
            median = (
                sorted_scores[n // 2]
                if n % 2 == 1
                else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2
            )
            variance = sum((x - mean) ** 2 for x in scores) / n
            return {
                "mean": mean,
                "median": median,
                "min": min(scores),
                "max": max(scores),
                "std": variance**0.5,
            }

        def summarize(evaluations: List[Dict]) -> Dict[str, Any]:
            # Shared per-dimension rollup for either entity or relation results.
            return {
                "overall_score": calculate_stats(
                    [e.get("overall_score", 0.0) for e in evaluations]
                ),
                "accuracy": calculate_stats(
                    [e.get("accuracy", 0.0) for e in evaluations]
                ),
                "completeness": calculate_stats(
                    [e.get("completeness", 0.0) for e in evaluations]
                ),
                "precision": calculate_stats(
                    [e.get("precision", 0.0) for e in evaluations]
                ),
                "total_chunks": len(evaluations),
                "detailed_results": evaluations,
            }

        return {
            "entity_accuracy": summarize(entity_evaluations),
            "relation_accuracy": summarize(relation_evaluations),
        }
graphgen/models/evaluator/kg/consistency_evaluator.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import re
4
+ from typing import Any, Dict, List
5
+
6
+ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
7
+ from graphgen.bases.datatypes import Chunk
8
+ from graphgen.templates.evaluation.kg.consistency_evaluation import (
9
+ ENTITY_DESCRIPTION_CONFLICT_PROMPT,
10
+ ENTITY_EXTRACTION_PROMPT,
11
+ ENTITY_TYPE_CONFLICT_PROMPT,
12
+ RELATION_CONFLICT_PROMPT,
13
+ )
14
+ from graphgen.utils import logger
15
+
16
+
17
+ class ConsistencyEvaluator:
18
+ """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge.
19
+
20
+ For entities with multiple source chunks, compares entity_type and description
21
+ extracted from different chunks to detect semantic conflicts.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ graph_storage: BaseGraphStorage,
27
+ chunk_storage: BaseKVStorage,
28
+ llm_client: BaseLLMWrapper,
29
+ ):
30
+ self.graph_storage = graph_storage
31
+ self.chunk_storage = chunk_storage
32
+ self.llm_client = llm_client
33
+
34
+ def evaluate(self) -> Dict[str, Any]:
35
+ """Evaluate consistency by detecting semantic conflicts."""
36
+ all_nodes = self.graph_storage.get_all_nodes() or []
37
+ if not all_nodes:
38
+ return {"error": "Empty graph"}
39
+
40
+ return self._evaluate_consistency(all_nodes)
41
+
42
+ def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
43
+ """Evaluate consistency by detecting semantic conflicts."""
44
+ # Filter entities with multiple source chunks
45
+ entities_with_multiple_sources = []
46
+ for node_id, node_data in all_nodes:
47
+ if not isinstance(node_data, dict):
48
+ continue
49
+ source_ids = node_data.get("source_id", "").split("<SEP>")
50
+ source_ids = [sid.strip() for sid in source_ids if sid.strip()]
51
+ if len(source_ids) > 1: # Only check entities from multiple chunks
52
+ entities_with_multiple_sources.append((node_id, node_data, source_ids))
53
+
54
+ if not entities_with_multiple_sources:
55
+ logger.info(
56
+ "No entities with multiple sources found, skipping consistency check"
57
+ )
58
+ return {
59
+ "conflict_rate": 0.0,
60
+ "conflict_entities_count": 0,
61
+ "total_entities": len(all_nodes),
62
+ "conflicts": [],
63
+ }
64
+
65
+ logger.info(
66
+ f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources"
67
+ )
68
+
69
+ # Evaluate entities sequentially
70
+ conflicts = []
71
+ conflict_entities = set()
72
+
73
+ for entity_info in entities_with_multiple_sources:
74
+ try:
75
+ entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info)
76
+ if entity_conflicts:
77
+ conflicts.extend(entity_conflicts)
78
+ conflict_entities.add(entity_id)
79
+ except Exception as e:
80
+ logger.error(
81
+ f"Failed to evaluate entity {entity_info[0]}: {e}"
82
+ )
83
+ continue
84
+
85
+ total_entities = len(all_nodes)
86
+ conflict_rate = (
87
+ len(conflict_entities) / total_entities if total_entities > 0 else 0
88
+ )
89
+
90
+ return {
91
+ "conflict_rate": conflict_rate,
92
+ "conflict_entities_count": len(conflict_entities),
93
+ "total_entities": total_entities,
94
+ "entities_checked": len(entities_with_multiple_sources),
95
+ "conflicts": conflicts[:100], # Limit to first 100 conflicts
96
+ }
97
+
98
+ def _clean_entity_id(self, entity_id: str) -> str:
99
+ """Clean entity ID by removing surrounding quotes."""
100
+ clean_id = entity_id.strip()
101
+ if (clean_id.startswith('"') and clean_id.endswith('"')) or (
102
+ clean_id.startswith("'") and clean_id.endswith("'")
103
+ ):
104
+ clean_id = clean_id[1:-1].strip()
105
+ return clean_id
106
+
107
    def _evaluate_entity_consistency(
        self, entity_info: tuple
    ) -> tuple[str, List[Dict]]:
        """Evaluate consistency for a single entity.

        Args:
            entity_info: Tuple of (entity_id, node_data, source_ids) as built
                by the multi-source filter in _evaluate_consistency.

        Returns:
            (entity_id, conflicts): conflicts is a list of dicts, one per
            detected type/description conflict; empty when fewer than two
            chunks (or two successful extractions) are available.
        """
        entity_id, _node_data, source_ids = entity_info
        # Clean entity_id for display only; lookups keep the raw id.
        clean_entity_id = self._clean_entity_id(entity_id)
        conflicts = []

        # Get chunks for this entity; conflicts need at least two sources.
        chunks = self._get_entity_chunks(source_ids)
        if len(chunks) < 2:
            return entity_id, []

        # Extract entity attributes (type + description) from each chunk via LLM.
        entity_extractions = {}
        for chunk in chunks:
            extraction = self._extract_entity_from_chunk(entity_id, chunk)
            if extraction:
                entity_extractions[chunk.id] = extraction

        if len(entity_extractions) < 2:
            return entity_id, []

        # Check entity type consistency across chunks.
        type_extractions = {
            chunk_id: ext.get("entity_type", "")
            for chunk_id, ext in entity_extractions.items()
        }
        type_conflict = self._check_entity_type_consistency(
            entity_id, type_extractions
        )
        if type_conflict and type_conflict.get("has_conflict", False):
            conflicts.append(
                {
                    "entity_id": clean_entity_id,
                    "conflict_type": "entity_type",
                    "conflict_severity": type_conflict.get("conflict_severity", 0.0),
                    "conflict_reasoning": type_conflict.get("conflict_reasoning", ""),
                    "conflicting_values": type_conflict.get("conflicting_types", []),
                    "recommended_value": type_conflict.get("recommended_type", ""),
                }
            )

        # Check entity description consistency across chunks.
        descriptions = {
            chunk_id: ext.get("description", "")
            for chunk_id, ext in entity_extractions.items()
        }
        desc_conflict = self._check_entity_description_consistency(
            entity_id, descriptions
        )
        if desc_conflict and desc_conflict.get("has_conflict", False):
            conflicts.append(
                {
                    "entity_id": clean_entity_id,
                    "conflict_type": "description",
                    "conflict_severity": desc_conflict.get("conflict_severity", 0.0),
                    "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""),
                    "conflicting_values": desc_conflict.get(
                        "conflicting_descriptions", []
                    ),
                    "conflict_details": desc_conflict.get("conflict_details", ""),
                }
            )

        return entity_id, conflicts
174
+
175
+ def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]:
176
+ """Get all chunks related to an entity."""
177
+ chunks = []
178
+ for chunk_id in source_ids:
179
+ chunk_data = self.chunk_storage.get_by_id(chunk_id)
180
+ if chunk_data:
181
+ try:
182
+ chunk = Chunk.from_dict(chunk_id, chunk_data)
183
+ chunks.append(chunk)
184
+ except Exception as e:
185
+ logger.warning(f"Failed to load chunk {chunk_id}: {e}")
186
+ continue
187
+ return chunks
188
+
189
    def _extract_entity_from_chunk(
        self, entity_id: str, chunk: Chunk
    ) -> Dict[str, str]:
        """Extract entity attributes from a chunk using LLM.

        Args:
            entity_id: Raw entity identifier (quotes stripped before prompting).
            chunk: Source chunk; content is truncated to 2000 characters.

        Returns:
            {"entity_type": ..., "description": ...} with entity_type lowercased
            and coerced to "concept" when not in the preset list; {} on any
            extraction or parsing failure.
        """
        try:
            # Clean entity_id: remove surrounding quotes if present
            clean_entity_id = self._clean_entity_id(entity_id)

            prompt = ENTITY_EXTRACTION_PROMPT.format(
                entity_name=clean_entity_id,
                chunk_content=chunk.content[:2000]
                if chunk.content
                else "",  # Limit content length
            )

            # NOTE(review): asyncio.run creates a fresh event loop per call and
            # raises if already inside a running loop -- confirm callers are sync.
            response = asyncio.run(self.llm_client.generate_answer(prompt))

            # Try to parse JSON response
            try:
                extraction = json.loads(response)
            except json.JSONDecodeError:
                # Try to extract JSON from markdown code blocks
                json_match = re.search(r"\{.*\}", response, re.DOTALL)
                if json_match:
                    extraction = json.loads(json_match.group(0))
                else:
                    logger.warning(
                        f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}"
                    )
                    return {}

            # Normalize entity_type to lowercase and validate
            entity_type = extraction.get("entity_type", "").lower().strip()
            # Valid preset types
            valid_types = {
                "concept",
                "date",
                "location",
                "keyword",
                "organization",
                "person",
                "event",
                "work",
                "nature",
                "artificial",
                "science",
                "technology",
                "mission",
                "gene",
            }
            # If entity_type is not in valid types, default to "concept"
            if entity_type not in valid_types:
                if entity_type:  # If LLM provided a type but it's invalid
                    logger.warning(
                        f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, "
                        f"defaulting to 'concept'"
                    )
                entity_type = "concept"

            return {
                "entity_type": entity_type,
                "description": extraction.get("description", ""),
            }
        except Exception as e:
            logger.error(
                f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}"
            )
            return {}
257
+
258
    def _check_entity_type_consistency(
        self, entity_id: str, type_extractions: Dict[str, str]
    ) -> Dict[str, Any]:
        """Check entity type consistency using LLM.

        Args:
            entity_id: Raw entity identifier.
            type_extractions: Mapping of chunk_id -> extracted entity type.

        Returns:
            The LLM's verdict dict (keys like "has_conflict",
            "conflict_severity", "conflicting_types", ...); falls back to
            {"has_conflict": False} when the call or parsing fails.
        """
        if len(set(type_extractions.values())) <= 1:
            # All types are the same, no conflict -- skip the LLM round-trip.
            return {"has_conflict": False}

        try:
            type_list = [
                f"Chunk {chunk_id}: {entity_type}"
                for chunk_id, entity_type in type_extractions.items()
                if entity_type
            ]

            prompt = ENTITY_TYPE_CONFLICT_PROMPT.format(
                entity_name=entity_id, type_extractions="\n".join(type_list)
            )

            # NOTE(review): asyncio.run spins up a new event loop per call --
            # confirm this never runs inside an existing loop.
            response = asyncio.run(self.llm_client.generate_answer(prompt))

            # Parse JSON response
            try:
                result = json.loads(response)
            except json.JSONDecodeError:
                # Fall back to the first {...} span (handles markdown fences).
                json_match = re.search(r"\{.*\}", response, re.DOTALL)
                if json_match:
                    result = json.loads(json_match.group(0))
                else:
                    logger.warning(
                        f"Failed to parse conflict detection response for {entity_id}"
                    )
                    return {"has_conflict": False}

            return result
        except Exception as e:
            logger.error(f"Error checking type consistency for {entity_id}: {e}")
            return {"has_conflict": False}
296
+
297
    def _check_entity_description_consistency(
        self, entity_id: str, descriptions: Dict[str, str]
    ) -> Dict[str, Any]:
        """Check entity description consistency using LLM.

        Args:
            entity_id: Raw entity identifier.
            descriptions: Mapping of chunk_id -> extracted description.

        Returns:
            The LLM's verdict dict; {"has_conflict": False} when there are
            fewer than two non-empty descriptions, all descriptions are
            identical, or the call/parsing fails.
        """
        # Filter out empty descriptions
        valid_descriptions = {k: v for k, v in descriptions.items() if v}
        if len(valid_descriptions) < 2:
            return {"has_conflict": False}

        if len(set(valid_descriptions.values())) <= 1:
            # All descriptions are the same, no conflict
            return {"has_conflict": False}

        try:
            desc_list = [
                f"Chunk {chunk_id}: {description}"
                for chunk_id, description in valid_descriptions.items()
            ]

            prompt = ENTITY_DESCRIPTION_CONFLICT_PROMPT.format(
                entity_name=entity_id, descriptions="\n".join(desc_list)
            )

            # NOTE(review): asyncio.run creates a new loop per call -- confirm sync context.
            response = asyncio.run(self.llm_client.generate_answer(prompt))

            # Parse JSON response
            try:
                result = json.loads(response)
            except json.JSONDecodeError:
                # Fall back to the first {...} span (handles markdown fences).
                json_match = re.search(r"\{.*\}", response, re.DOTALL)
                if json_match:
                    result = json.loads(json_match.group(0))
                else:
                    logger.warning(
                        f"Failed to parse conflict detection response for {entity_id}"
                    )
                    return {"has_conflict": False}

            return result
        except Exception as e:
            logger.error(f"Error checking description consistency for {entity_id}: {e}")
            return {"has_conflict": False}
339
+
340
    def _check_relation_consistency(
        self, src_id: str, dst_id: str, relation_extractions: Dict[str, str]
    ) -> Dict[str, Any]:
        """Check relation consistency using LLM.

        Args:
            src_id: Source entity identifier of the relation.
            dst_id: Target entity identifier of the relation.
            relation_extractions: Mapping of chunk_id -> relation description.

        Returns:
            The LLM's verdict dict; {"has_conflict": False} when all
            descriptions are identical or the call/parsing fails.
        """
        if len(set(relation_extractions.values())) <= 1:
            # Identical descriptions cannot conflict; skip the LLM call.
            return {"has_conflict": False}

        try:
            rel_list = [
                f"Chunk {chunk_id}: {relation}"
                for chunk_id, relation in relation_extractions.items()
                if relation
            ]

            prompt = RELATION_CONFLICT_PROMPT.format(
                source_entity=src_id,
                target_entity=dst_id,
                relation_descriptions="\n".join(rel_list),
            )

            # NOTE(review): asyncio.run creates a new loop per call -- confirm sync context.
            response = asyncio.run(self.llm_client.generate_answer(prompt))

            # Parse JSON response
            try:
                result = json.loads(response)
            except json.JSONDecodeError:
                # Fall back to the first {...} span (handles markdown fences).
                json_match = re.search(r"\{.*\}", response, re.DOTALL)
                if json_match:
                    result = json.loads(json_match.group(0))
                else:
                    logger.warning(
                        f"Failed to parse relation conflict response for {src_id}->{dst_id}"
                    )
                    return {"has_conflict": False}

            return result
        except Exception as e:
            logger.error(
                f"Error checking relation consistency for {src_id}->{dst_id}: {e}"
            )
            return {"has_conflict": False}
graphgen/models/evaluator/kg/structure_evaluator.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import numpy as np
4
+ from scipy import stats
5
+
6
+ from graphgen.bases import BaseGraphStorage
7
+ from graphgen.utils import logger
8
+
9
+
10
class StructureEvaluator:
    """Evaluates structural robustness of the graph.

    A graph is considered robust when it has few isolated nodes, one dominant
    connected component, a moderate average degree, and a degree distribution
    whose log-log linear fit (power-law proxy) has a high R².
    """

    def __init__(
        self,
        graph_storage: "BaseGraphStorage",
        noise_ratio_threshold: float = 0.15,
        largest_cc_ratio_threshold: float = 0.90,
        avg_degree_min: float = 2.0,
        avg_degree_max: float = 5.0,
        powerlaw_r2_threshold: float = 0.75,
    ):
        """
        Args:
            graph_storage: Backend exposing node/edge counts, node degrees
                and connected components.
            noise_ratio_threshold: Max allowed fraction of isolated nodes.
            largest_cc_ratio_threshold: Min fraction of nodes that must sit
                in the largest connected component.
            avg_degree_min: Lower bound of the acceptable average degree.
            avg_degree_max: Upper bound of the acceptable average degree.
            powerlaw_r2_threshold: Min R² for the power-law fit.
        """
        self.graph_storage = graph_storage
        self.noise_ratio_threshold = noise_ratio_threshold
        self.largest_cc_ratio_threshold = largest_cc_ratio_threshold
        self.avg_degree_min = avg_degree_min
        self.avg_degree_max = avg_degree_max
        self.powerlaw_r2_threshold = powerlaw_r2_threshold

    def evaluate(self) -> Dict[str, Any]:
        """Compute structural metrics and an overall robustness verdict.

        Returns:
            Dict with total_nodes, total_edges, noise_ratio, largest_cc_ratio,
            avg_degree, powerlaw_r2 (None when the fit is impossible) and
            is_robust; or {"error": ...} for an empty graph.
        """
        storage = self.graph_storage

        total_nodes = storage.get_node_count()
        if total_nodes == 0:
            return {"error": "Empty graph"}

        total_edges = storage.get_edge_count()
        degree_map = storage.get_all_node_degrees()

        # Noise ratio: isolated nodes / total nodes
        isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0]
        noise_ratio = len(isolated_nodes) / total_nodes

        # Largest connected component
        components = storage.get_connected_components(undirected=True)
        largest_cc_ratio = (
            len(max(components, key=len)) / total_nodes if components else 0
        )

        avg_degree = sum(degree_map.values()) / total_nodes
        powerlaw_r2 = self._calculate_powerlaw_r2(degree_map)

        results = {
            "total_nodes": total_nodes,
            "total_edges": total_edges,
            "noise_ratio": noise_ratio,
            "largest_cc_ratio": largest_cc_ratio,
            "avg_degree": avg_degree,
            "powerlaw_r2": powerlaw_r2,
            "is_robust": (
                noise_ratio < self.noise_ratio_threshold
                and largest_cc_ratio > self.largest_cc_ratio_threshold
                and self.avg_degree_min <= avg_degree <= self.avg_degree_max
                and (
                    powerlaw_r2 is not None and powerlaw_r2 > self.powerlaw_r2_threshold
                )
            ),
        }

        return results

    @staticmethod
    def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]:
        """R² of a linear fit of sorted log(degree) against log(rank).

        Returns None when fewer than 10 positive-degree nodes exist or the
        fit fails.
        """
        degrees = [deg for deg in degree_map.values() if deg > 0]

        if len(degrees) < 10:
            logger.warning("Insufficient nodes for power law fitting")
            return None

        try:
            # Fit power law: log(y) = a * log(x) + b
            log_degrees = np.log(degrees)
            sorted_log_degrees = np.sort(log_degrees)
            x = np.arange(1, len(sorted_log_degrees) + 1)
            log_x = np.log(x)

            # Linear regression on log-log scale.
            # BUG FIX: stats.linregress returns (slope, intercept, rvalue,
            # pvalue, stderr) -- the previous `r_value, *_ = ...` unpacked
            # the SLOPE, so the reported "R²" was actually slope².
            fit = stats.linregress(log_x, sorted_log_degrees)
            r2 = fit.rvalue**2

            return float(r2)
        except Exception as e:
            logger.error(f"Power law R² calculation failed: {e}")
            return None
97
+ return None
graphgen/models/evaluator/length_evaluator.py DELETED
@@ -1,19 +0,0 @@
1
- from graphgen.bases.datatypes import QAPair
2
- from graphgen.models.evaluator.base_evaluator import BaseEvaluator
3
- from graphgen.models.tokenizer import Tokenizer
4
- from graphgen.utils import create_event_loop
5
-
6
-
7
- class LengthEvaluator(BaseEvaluator):
8
- def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
9
- super().__init__(max_concurrent)
10
- self.tokenizer_name = tokenizer_name
11
- self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
12
-
13
- async def evaluate_single(self, pair: QAPair) -> float:
14
- loop = create_event_loop()
15
- return await loop.run_in_executor(None, self._calculate_length, pair.answer)
16
-
17
- def _calculate_length(self, text: str) -> float:
18
- tokens = self.tokenizer.encode(text)
19
- return len(tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/evaluator/qa/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .length_evaluator import LengthEvaluator
2
+ from .mtld_evaluator import MTLDEvaluator
3
+ from .reward_evaluator import RewardEvaluator
4
+ from .uni_evaluator import UniEvaluator
graphgen/models/evaluator/qa/length_evaluator.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from graphgen.bases import BaseEvaluator, QAPair
4
+ from graphgen.models.tokenizer import Tokenizer
5
+
6
+
7
class LengthEvaluator(BaseEvaluator):
    """Scores a QA pair by its combined token count."""

    def __init__(self, tokenizer_name: str = None):
        # Fall back to the TOKENIZER_MODEL env var, then to cl100k_base.
        model = tokenizer_name or os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        self.tokenizer: Tokenizer = Tokenizer(model)

    def evaluate(self, pair: QAPair) -> float:
        """Return the number of tokens in question + answer."""
        return len(self.tokenizer.encode(pair.question + pair.answer))
graphgen/models/evaluator/{mtld_evaluator.py → qa/mtld_evaluator.py} RENAMED
@@ -1,38 +1,33 @@
1
  from typing import Set
2
 
3
- from graphgen.bases.datatypes import QAPair
4
- from graphgen.models.evaluator.base_evaluator import BaseEvaluator
5
- from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
6
-
7
- nltk_helper = NLTKHelper()
8
 
9
 
10
  class MTLDEvaluator(BaseEvaluator):
11
  """
12
- 衡量文本词汇多样性的指标
13
  """
14
 
15
- def __init__(self, max_concurrent: int = 100):
16
- super().__init__(max_concurrent)
17
- self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
18
- self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))
19
-
20
- async def evaluate_single(self, pair: QAPair) -> float:
21
- loop = create_event_loop()
22
- return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
23
 
24
- def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
25
  """
26
- 计算MTLD (向前和向后的平均值)
27
 
28
  min is 1.0
29
  higher is better
30
  """
 
31
  if not text or not text.strip():
32
  return 0.0
33
 
34
  lang = detect_main_language(text)
35
- tokens = nltk_helper.word_tokenize(text, lang)
36
 
37
  stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
38
  filtered_tokens = [word for word in tokens if word not in stopwords]
@@ -41,13 +36,13 @@ class MTLDEvaluator(BaseEvaluator):
41
  if not filtered_tokens:
42
  return 0
43
 
44
- # 计算向前的MTLD
45
- forward_factors = self._compute_factors(filtered_tokens, threshold)
46
 
47
- # 计算向后的MTLD
48
- backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
49
 
50
- # 取平均值
51
  return (forward_factors + backward_factors) / 2
52
 
53
  @staticmethod
@@ -66,7 +61,7 @@ class MTLDEvaluator(BaseEvaluator):
66
  current_segment = []
67
  unique_words = set()
68
 
69
- # 处理最后一个不完整片段
70
  if current_segment:
71
  ttr = len(unique_words) / len(current_segment)
72
  if ttr <= threshold:
 
1
  from typing import Set
2
 
3
+ from graphgen.bases import BaseEvaluator, QAPair
4
+ from graphgen.utils import NLTKHelper, detect_main_language
 
 
 
5
 
6
 
7
  class MTLDEvaluator(BaseEvaluator):
8
  """
9
+ Metrics for measuring the lexical diversity of text.
10
  """
11
 
12
+ def __init__(self, threshold: float = 0.72):
13
+ self.nltk_helper = NLTKHelper()
14
+ self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en"))
15
+ self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
16
+ self.threshold = threshold
 
 
 
17
 
18
+ def evaluate(self, pair: QAPair) -> float:
19
  """
20
+ Calculate the MTLD (Mean Token Length Diversity) score for a given text.
21
 
22
  min is 1.0
23
  higher is better
24
  """
25
+ text = pair.answer
26
  if not text or not text.strip():
27
  return 0.0
28
 
29
  lang = detect_main_language(text)
30
+ tokens = self.nltk_helper.word_tokenize(text, lang)
31
 
32
  stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
33
  filtered_tokens = [word for word in tokens if word not in stopwords]
 
36
  if not filtered_tokens:
37
  return 0
38
 
39
+ # Compute forward factors
40
+ forward_factors = self._compute_factors(filtered_tokens, self.threshold)
41
 
42
+ # Compute backward factors
43
+ backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
44
 
45
+ # Compute average factors
46
  return (forward_factors + backward_factors) / 2
47
 
48
  @staticmethod
 
61
  current_segment = []
62
  unique_words = set()
63
 
64
+ # handle last segment
65
  if current_segment:
66
  ttr = len(unique_words) / len(current_segment)
67
  if ttr <= threshold:
graphgen/models/evaluator/qa/reward_evaluator.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from graphgen.bases import BaseEvaluator, QAPair
3
+
4
+
5
class RewardEvaluator(BaseEvaluator):
    """
    Reward Model Evaluator for single QAPair evaluation.
    """

    def __init__(
        self,
        reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
        max_length: int = 2560,
        device: Optional[str] = None,
    ):
        """
        Initialize the reward evaluator.

        Args:
            reward_name: Model name or path on HuggingFace Hub
            max_length: Maximum token length for the model
            device: Device to run the model on. If None, auto-detect CUDA/CPU.
        """
        self.reward_name = reward_name
        self.max_length = max_length

        # Deferred imports: torch/transformers are only needed when this
        # evaluator is actually constructed.
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.torch = torch
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(reward_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(reward_name)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e

    def evaluate(self, pair: QAPair) -> float:
        """
        Evaluate a single question-answer pair using the reward model.

        Args:
            pair: QAPair containing question and answer strings

        Returns:
            Score as a float
        """
        # Encode the (question, answer) pair and move tensors to the model device.
        encoded = self.tokenizer(
            pair.question,
            pair.answer,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Score without tracking gradients.
        with self.torch.no_grad():
            logits = self.model(**encoded).logits

        return logits[0].item()
graphgen/models/evaluator/qa/uni_evaluator.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/maszhongming/UniEval/tree/main
2
+ from typing import Optional, List
3
+ from graphgen.bases import BaseEvaluator, QAPair
4
+
5
+
6
class UniEvaluator(BaseEvaluator):
    """
    UniEvaluator for single QAPair evaluation across quality dimensions.

    Scores each dimension as P("Yes") / (P("Yes") + P("No")) taken from the
    seq2seq model's first decoding step (boolean-QA formulation).

    Dimensions: naturalness, coherence, understandability

    Usage:
        evaluator = UniEvaluator()
        pair = QAPair(question="...", answer="...")
        scores = evaluator.evaluate(pair)
        # {"naturalness": 0.85, "coherence": 0.92, "understandability": 0.88}
    """

    DEFAULT_MODEL: str = "MingZhong/unieval-sum"
    DEFAULT_DIMS: List[str] = ["naturalness", "coherence", "understandability"]
    DEFAULT_MAX_LENGTH: int = 2560

    def __init__(
        self,
        model_name: Optional[str] = None,
        max_length: Optional[int] = None,
        device: Optional[str] = None,
    ):
        """
        Args:
            model_name: HuggingFace model name/path
            max_length: Tokenizer max sequence length
            device: 'cuda', 'cpu', or None for auto-detect
        """
        # Deferred imports so torch/transformers are only required on use.
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        self.torch = torch

        self.model_name = model_name or self.DEFAULT_MODEL
        self.max_length = max_length or self.DEFAULT_MAX_LENGTH
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load model & tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.model.to(self.device)
        self.model.eval()

        # Pre-compute Yes/No token IDs.
        # NOTE(review): this takes the FIRST token id of each word; assumes the
        # model's tokenizer encodes "Yes"/"No" with a distinguishing leading
        # token -- confirm for tokenizers other than the default one.
        self._yes_id = self.tokenizer("Yes")["input_ids"][0]
        self._no_id = self.tokenizer("No")["input_ids"][0]

    @staticmethod
    def _build_input_text(dimension: str, question: str, answer: str) -> str:
        """Construct input text for specified dimension.

        Raises:
            NotImplementedError: For dimensions without a defined template.
        """
        if dimension == "naturalness":
            return f"question: Is this a natural response? </s> response: {answer}"
        if dimension == "coherence":
            return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
        if dimension == "understandability":
            return f"question: Is this an understandable response? </s> response: {answer}"
        raise NotImplementedError(f"Unsupported dimension '{dimension}'")

    def evaluate(
        self,
        pair: QAPair,
        dimensions: Optional[List[str]] = None,
    ) -> dict[str, float]:
        """Evaluate a single QAPair across specified dimensions.

        Args:
            pair: QAPair with question and answer strings.
            dimensions: Subset of DEFAULT_DIMS; defaults to all of them.

        Returns:
            Mapping of dimension name -> score in (0, 1).

        Raises:
            ValueError: If an unknown dimension is requested.
        """
        dimensions = dimensions or self.DEFAULT_DIMS

        # Validate dimensions
        invalid = set(dimensions) - set(self.DEFAULT_DIMS)
        if invalid:
            raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}")

        results = {}
        # Single-step decoder target; its value only fixes the sequence length,
        # the score is read from the vocabulary logits at position 0.
        no_token = self.torch.tensor([[self._no_id]], device=self.device)

        for dim in dimensions:
            # Tokenize input
            src = self.tokenizer(
                self._build_input_text(dim, pair.question, pair.answer),
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            )
            src_tokens = src["input_ids"].to(self.device)
            src_mask = src["attention_mask"].to(self.device)

            # Score: logits at the first decoding position over the vocabulary.
            with self.torch.no_grad():
                logits = self.model(
                    input_ids=src_tokens,
                    attention_mask=src_mask,
                    labels=no_token,
                    use_cache=False,
                ).logits[:, 0, :]  # [1, vocab_size]

            # Normalize "Yes" probability against the Yes/No pair.
            probs = self.torch.softmax(logits, dim=-1)[0]
            score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id])

            results[dim] = score.item()

        return results
graphgen/models/evaluator/reward_evaluator.py DELETED
@@ -1,107 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- from tqdm import tqdm
4
-
5
- from graphgen.bases.datatypes import QAPair
6
-
7
-
8
- @dataclass
9
- class RewardEvaluator:
10
- """
11
- Reward Model Evaluator.
12
- OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好
13
- """
14
-
15
- reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
16
- max_length: int = 2560
17
- results: list[float] = None
18
-
19
- def __post_init__(self):
20
- import torch
21
-
22
- self.num_gpus = torch.cuda.device_count()
23
-
24
- @staticmethod
25
- def process_chunk(rank, pairs, reward_name, max_length, return_dict):
26
- import torch
27
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
28
-
29
- device = f"cuda:{rank}"
30
- torch.cuda.set_device(rank)
31
-
32
- rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
33
- tokenizer = AutoTokenizer.from_pretrained(reward_name)
34
- rank_model.to(device)
35
- rank_model.eval()
36
-
37
- results = []
38
- with torch.no_grad():
39
- for pair in tqdm(pairs):
40
- inputs = tokenizer(
41
- pair.question,
42
- pair.answer,
43
- return_tensors="pt",
44
- max_length=max_length,
45
- truncation=True,
46
- )
47
- inputs = {k: v.to(device) for k, v in inputs.items()}
48
- score = rank_model(**inputs).logits[0].item()
49
- results.append(score)
50
-
51
- return_dict[rank] = results
52
-
53
- def evaluate(self, pairs: list[QAPair]) -> list[float]:
54
- import torch.multiprocessing as mp
55
-
56
- chunk_size = len(pairs) // self.num_gpus
57
- chunks = []
58
- for i in range(self.num_gpus):
59
- start = i * chunk_size
60
- end = start + chunk_size
61
- if i == self.num_gpus - 1:
62
- end = len(pairs)
63
- chunks.append(pairs[start:end])
64
-
65
- # multi-process
66
- manager = mp.Manager()
67
- return_dict = manager.dict()
68
- processes = []
69
-
70
- for rank, chunk in enumerate(chunks):
71
- p = mp.Process(
72
- target=self.process_chunk,
73
- args=(rank, chunk, self.reward_name, self.max_length, return_dict),
74
- )
75
- p.start()
76
- processes.append(p)
77
-
78
- for p in processes:
79
- p.join()
80
-
81
- # 合并结果
82
- results = []
83
- for rank in range(len(chunks)):
84
- results.extend(return_dict[rank])
85
-
86
- for p in processes:
87
- if p.is_alive():
88
- p.terminate()
89
- p.join()
90
-
91
- return results
92
-
93
- def get_average_score(self, pairs: list[QAPair]) -> float:
94
- """
95
- Get the average score of a batch of texts.
96
- """
97
- results = self.evaluate(pairs)
98
- self.results = results
99
- return sum(self.results) / len(pairs)
100
-
101
- def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
102
- """
103
- Get the min and max score of a batch of texts.
104
- """
105
- if self.results is None:
106
- self.get_average_score(pairs)
107
- return min(self.results), max(self.results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/evaluator/uni_evaluator.py DELETED
@@ -1,183 +0,0 @@
1
- # https://github.com/maszhongming/UniEval/tree/main
2
-
3
- from dataclasses import dataclass, field
4
-
5
- from tqdm import tqdm
6
-
7
- from graphgen.bases.datatypes import QAPair
8
-
9
-
10
- def _add_questions(dimension: str, question: str, answer: str):
11
- if dimension == "naturalness":
12
- cur_input = (
13
- "question: Is this a natural response in the dialogue? </s> response: "
14
- + answer
15
- )
16
- elif dimension == "coherence":
17
- cur_input = (
18
- "question: Is this a coherent response given the dialogue history? </s> response: "
19
- + answer
20
- + " </s> dialogue history: "
21
- + question
22
- )
23
- elif dimension == "understandability":
24
- cur_input = (
25
- "question: Is this an understandable response in the dialogue? </s> response: "
26
- + answer
27
- )
28
- else:
29
- raise NotImplementedError(
30
- "The input format for this dimension is still undefined. Please customize it first."
31
- )
32
- return cur_input
33
-
34
-
35
- @dataclass
36
- class UniEvaluator:
37
- model_name: str = "MingZhong/unieval-sum"
38
- dimensions: list = field(
39
- default_factory=lambda: ["naturalness", "coherence", "understandability"]
40
- )
41
- max_length: int = 2560
42
- results: dict = None
43
-
44
- def __post_init__(self):
45
- import torch
46
-
47
- self.num_gpus = torch.cuda.device_count()
48
- self.results = {}
49
-
50
- @staticmethod
51
- def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
52
- import torch
53
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
54
-
55
- device = f"cuda:{rank}"
56
- torch.cuda.set_device(rank)
57
-
58
- rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
59
- tokenizer = AutoTokenizer.from_pretrained(model_name)
60
- rank_model.to(device)
61
- rank_model.eval()
62
-
63
- softmax = torch.nn.Softmax(dim=1)
64
-
65
- pos_id = tokenizer("Yes")["input_ids"][0]
66
- neg_id = tokenizer("No")["input_ids"][0]
67
-
68
- results = []
69
- with torch.no_grad():
70
- for pair in tqdm(pairs):
71
- text = _add_questions(dimension, pair.question, pair.answer)
72
-
73
- tgt = "No"
74
-
75
- encoded_src = tokenizer(
76
- text,
77
- max_length=max_length,
78
- truncation=True,
79
- padding=True,
80
- return_tensors="pt",
81
- )
82
- encoded_tgt = tokenizer(
83
- tgt,
84
- max_length=max_length,
85
- truncation=True,
86
- padding=True,
87
- return_tensors="pt",
88
- )
89
-
90
- src_tokens = encoded_src["input_ids"].to(device)
91
- src_mask = encoded_src["attention_mask"].to(device)
92
-
93
- tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1)
94
-
95
- output = rank_model(
96
- input_ids=src_tokens,
97
- attention_mask=src_mask,
98
- labels=tgt_tokens,
99
- use_cache=False,
100
- )
101
-
102
- logits = output.logits.view(-1, rank_model.config.vocab_size)
103
-
104
- pos_score = softmax(logits)[:, pos_id] # Yes
105
- neg_score = softmax(logits)[:, neg_id]
106
- score = pos_score / (pos_score + neg_score)
107
-
108
- results.append(score.item())
109
-
110
- return_dict[rank] = results
111
-
112
- def evaluate(self, pairs: list[QAPair]) -> list[dict]:
113
- import torch.multiprocessing as mp
114
-
115
- final_results = []
116
- for dimension in self.dimensions:
117
- chunk_size = len(pairs) // self.num_gpus
118
- chunks = []
119
- for i in range(self.num_gpus):
120
- start = i * chunk_size
121
- end = start + chunk_size
122
- if i == self.num_gpus - 1:
123
- end = len(pairs)
124
- chunks.append(pairs[start:end])
125
-
126
- # multi-process
127
- manager = mp.Manager()
128
- return_dict = manager.dict()
129
- processes = []
130
-
131
- for rank, chunk in enumerate(chunks):
132
- p = mp.Process(
133
- target=self.process_chunk,
134
- args=(
135
- rank,
136
- chunk,
137
- self.model_name,
138
- self.max_length,
139
- dimension,
140
- return_dict,
141
- ),
142
- )
143
- p.start()
144
- processes.append(p)
145
-
146
- for p in processes:
147
- p.join()
148
-
149
- # 合并结果
150
- results = []
151
- for rank in range(len(chunks)):
152
- results.extend(return_dict[rank])
153
-
154
- for p in processes:
155
- if p.is_alive():
156
- p.terminate()
157
- p.join()
158
-
159
- final_results.append({dimension: results})
160
- return final_results
161
-
162
- def get_average_score(self, pairs: list[QAPair]) -> dict:
163
- """
164
- Get the average score of a batch of texts.
165
- """
166
- results = self.evaluate(pairs)
167
- final_results = {}
168
- for result in results:
169
- for key, value in result.items():
170
- final_results[key] = sum(value) / len(value)
171
- self.results[key] = value
172
- return final_results
173
-
174
- def get_min_max_score(self, pairs: list[QAPair]) -> dict:
175
- """
176
- Get the min and max score of a batch of texts.
177
- """
178
- if self.results is None:
179
- self.get_average_score(pairs)
180
- final_results = {}
181
- for key, value in self.results.items():
182
- final_results[key] = min(value), max(value)
183
- return final_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/storage/graph/kuzu_storage.py CHANGED
@@ -1,7 +1,8 @@
1
  import json
2
  import os
 
3
  from dataclasses import dataclass
4
- from typing import Any
5
 
6
  try:
7
  import kuzu
@@ -78,6 +79,94 @@ class KuzuStorage(BaseGraphStorage):
78
  print(f"Error decoding JSON: {e}")
79
  return {}
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def has_node(self, node_id: str) -> bool:
82
  result = self._conn.execute(
83
  "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
 
1
  import json
2
  import os
3
+ from collections import defaultdict
4
  from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Set
6
 
7
  try:
8
  import kuzu
 
79
  print(f"Error decoding JSON: {e}")
80
  return {}
81
 
82
+ def is_directed(self) -> bool:
83
+ return True
84
+
85
+ def get_all_node_degrees(self) -> Dict[str, int]:
86
+ query = """
87
+ MATCH (n:Entity)
88
+ OPTIONAL MATCH (n)-[r]-()
89
+ RETURN n.id, count(r) as degree
90
+ """
91
+
92
+ result = self._conn.execute(query)
93
+ degree_map = {}
94
+ while result.has_next():
95
+ row = result.get_next()
96
+ if row and len(row) >= 2:
97
+ node_id, degree = row[0], row[1]
98
+ degree_map[node_id] = int(degree)
99
+
100
+ return degree_map
101
+
102
+ def get_isolated_nodes(self) -> List[str]:
103
+ query = """
104
+ MATCH (n:Entity)
105
+ WHERE NOT (n)--()
106
+ RETURN n.id
107
+ """
108
+
109
+ result = self._conn.execute(query)
110
+ return [row[0] for row in result if row]
111
+
112
+ def get_node_count(self) -> int:
113
+ result = self._conn.execute("MATCH (n:Entity) RETURN count(n)")
114
+ return result.get_next()[0]
115
+
116
+ def get_edge_count(self) -> int:
117
+ result = self._conn.execute("MATCH ()-[e:Relation]->() RETURN count(e)")
118
+ return result.get_next()[0]
119
+
120
+ def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
121
+ parent = {}
122
+ rank = {}
123
+
124
+ def find(x: str) -> str:
125
+ if parent[x] != x:
126
+ parent[x] = find(parent[x])
127
+ return parent[x]
128
+
129
+ def union(x: str, y: str):
130
+ root_x, root_y = find(x), find(y)
131
+ if root_x == root_y:
132
+ return
133
+ if rank[root_x] < rank[root_y]:
134
+ parent[root_x] = root_y
135
+ elif rank[root_x] > rank[root_y]:
136
+ parent[root_y] = root_x
137
+ else:
138
+ parent[root_y] = root_x
139
+ rank[root_x] += 1
140
+
141
+ all_nodes = self.get_all_node_degrees().keys()
142
+ for node_id in all_nodes:
143
+ parent[node_id] = node_id
144
+ rank[node_id] = 0
145
+
146
+ query = (
147
+ """
148
+ MATCH (a:Entity)-[e:Relation]-(b:Entity)
149
+ RETURN DISTINCT a.id, b.id
150
+ """
151
+ if undirected
152
+ else """
153
+ MATCH (a:Entity)-[e:Relation]->(b:Entity)
154
+ RETURN DISTINCT a.id, b.id
155
+ """
156
+ )
157
+
158
+ result = self._conn.execute(query)
159
+ for row in result:
160
+ if row and len(row) >= 2:
161
+ union(row[0], row[1])
162
+
163
+ components_dict = defaultdict(set)
164
+ for node_id in all_nodes:
165
+ root = find(node_id)
166
+ components_dict[root].add(node_id)
167
+
168
+ return list(components_dict.values())
169
+
170
  def has_node(self, node_id: str) -> bool:
171
  result = self._conn.execute(
172
  "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
graphgen/models/storage/graph/networkx_storage.py CHANGED
@@ -1,7 +1,7 @@
1
  import html
2
  import os
3
  from dataclasses import dataclass
4
- from typing import Any, Optional, Union, cast
5
 
6
  import networkx as nx
7
 
@@ -10,6 +10,31 @@ from graphgen.bases.base_storage import BaseGraphStorage
10
 
11
  @dataclass
12
  class NetworkXStorage(BaseGraphStorage):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @staticmethod
14
  def load_nx_graph(file_name) -> Optional[nx.Graph]:
15
  if os.path.exists(file_name):
 
1
  import html
2
  import os
3
  from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Set, Union, cast
5
 
6
  import networkx as nx
7
 
 
10
 
11
  @dataclass
12
  class NetworkXStorage(BaseGraphStorage):
13
+ def is_directed(self) -> bool:
14
+ return self._graph.is_directed()
15
+
16
+ def get_all_node_degrees(self) -> Dict[str, int]:
17
+ return {
18
+ str(node_id): int(self._graph.degree[node_id])
19
+ for node_id in self._graph.nodes()
20
+ }
21
+
22
+ def get_node_count(self) -> int:
23
+ return self._graph.number_of_nodes()
24
+
25
+ def get_edge_count(self) -> int:
26
+ return self._graph.number_of_edges()
27
+
28
+ def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
29
+ graph = self._graph
30
+
31
+ if undirected and graph.is_directed():
32
+ graph = graph.to_undirected()
33
+
34
+ return [
35
+ set(str(node) for node in comp) for comp in nx.connected_components(graph)
36
+ ]
37
+
38
  @staticmethod
39
  def load_nx_graph(file_name) -> Optional[nx.Graph]:
40
  if os.path.exists(file_name):
graphgen/operators/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .build_kg import BuildKGService
2
  from .chunk import ChunkService
 
3
  from .extract import ExtractService
4
  from .generate import GenerateService
5
  from .judge import JudgeService
@@ -8,6 +9,7 @@ from .quiz import QuizService
8
  from .read import read
9
  from .search import SearchService
10
 
 
11
  operators = {
12
  "read": read,
13
  "chunk": ChunkService,
@@ -18,4 +20,5 @@ operators = {
18
  "search": SearchService,
19
  "partition": PartitionService,
20
  "generate": GenerateService,
 
21
  }
 
1
  from .build_kg import BuildKGService
2
  from .chunk import ChunkService
3
+ from .evaluate import EvaluateService
4
  from .extract import ExtractService
5
  from .generate import GenerateService
6
  from .judge import JudgeService
 
9
  from .read import read
10
  from .search import SearchService
11
 
12
+
13
  operators = {
14
  "read": read,
15
  "chunk": ChunkService,
 
20
  "search": SearchService,
21
  "partition": PartitionService,
22
  "generate": GenerateService,
23
+ "evaluate": EvaluateService,
24
  }
graphgen/operators/evaluate/__init__.py CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .evaluate_service import EvaluateService
2
+
3
+ __all__ = ["EvaluateService"]
graphgen/operators/evaluate/evaluate.py DELETED
@@ -1,177 +0,0 @@
1
- # TODO: this module needs refactoring to merge into GraphGen framework
2
- """Evaluate the quality of the generated text using various metrics"""
3
-
4
- import argparse
5
- import json
6
- import os
7
-
8
- import pandas as pd
9
- from dotenv import load_dotenv
10
-
11
- from graphgen.bases.datatypes import QAPair
12
- from graphgen.models import (
13
- LengthEvaluator,
14
- MTLDEvaluator,
15
- RewardEvaluator,
16
- UniEvaluator,
17
- )
18
- from graphgen.utils import logger, set_logger
19
-
20
- sys_path = os.path.abspath(os.path.dirname(__file__))
21
- set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
22
-
23
- load_dotenv()
24
-
25
-
26
- def evaluate_length(corpus, tokenizer_name):
27
- length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name)
28
- logger.info("Length evaluator loaded")
29
- scores = length_evaluator.get_average_score(corpus)
30
- logger.info("Length scores: %s", scores)
31
- return scores
32
-
33
-
34
- def evaluate_mtld(corpus):
35
- mtld_evaluator = MTLDEvaluator()
36
- logger.info("MTLD evaluator loaded")
37
- scores = mtld_evaluator.get_average_score(corpus)
38
- logger.info("MTLD scores: %s", scores)
39
- min_max_scores = mtld_evaluator.get_min_max_score(corpus)
40
- logger.info("MTLD min max scores: %s", min_max_scores)
41
- return scores, min_max_scores
42
-
43
-
44
- def evaluate_reward(corpus, reward_model_names):
45
- scores = []
46
- for reward_name in reward_model_names:
47
- reward_evaluator = RewardEvaluator(reward_name=reward_name)
48
- logger.info("Loaded reward model: %s", reward_name)
49
- average_score = reward_evaluator.get_average_score(corpus)
50
- logger.info("%s scores: %s", reward_name, average_score)
51
- min_max_scores = reward_evaluator.get_min_max_score(corpus)
52
- logger.info("%s min max scores: %s", reward_name, min_max_scores)
53
- scores.append(
54
- {
55
- "reward_name": reward_name.split("/")[-1],
56
- "score": average_score,
57
- "min_max_scores": min_max_scores,
58
- }
59
- )
60
- del reward_evaluator
61
- clean_gpu_cache()
62
- return scores
63
-
64
-
65
- def evaluate_uni(corpus, uni_model_name):
66
- uni_evaluator = UniEvaluator(model_name=uni_model_name)
67
- logger.info("Uni evaluator loaded with model %s", uni_model_name)
68
- uni_scores = uni_evaluator.get_average_score(corpus)
69
- for key, value in uni_scores.items():
70
- logger.info("Uni %s scores: %s", key, value)
71
- min_max_scores = uni_evaluator.get_min_max_score(corpus)
72
- for key, value in min_max_scores.items():
73
- logger.info("Uni %s min max scores: %s", key, value)
74
- del uni_evaluator
75
- clean_gpu_cache()
76
- return (
77
- uni_scores["naturalness"],
78
- uni_scores["coherence"],
79
- uni_scores["understandability"],
80
- min_max_scores["naturalness"],
81
- min_max_scores["coherence"],
82
- min_max_scores["understandability"],
83
- )
84
-
85
-
86
- def clean_gpu_cache():
87
- import torch
88
-
89
- if torch.cuda.is_available():
90
- torch.cuda.empty_cache()
91
-
92
-
93
- if __name__ == "__main__":
94
- import torch.multiprocessing as mp
95
-
96
- parser = argparse.ArgumentParser()
97
-
98
- parser.add_argument(
99
- "--folder", type=str, default="cache/data", help="folder to load data"
100
- )
101
- parser.add_argument(
102
- "--output", type=str, default="cache/output", help="path to save output"
103
- )
104
-
105
- parser.add_argument(
106
- "--tokenizer", type=str, default="cl100k_base", help="tokenizer name"
107
- )
108
- parser.add_argument(
109
- "--reward",
110
- type=str,
111
- default="OpenAssistant/reward-model-deberta-v3-large-v2",
112
- help="Comma-separated list of reward models",
113
- )
114
- parser.add_argument(
115
- "--uni", type=str, default="MingZhong/unieval-sum", help="uni model name"
116
- )
117
-
118
- args = parser.parse_args()
119
-
120
- if not os.path.exists(args.folder):
121
- raise ValueError(f"Folder {args.folder} does not exist")
122
-
123
- if not os.path.exists(args.output):
124
- os.makedirs(args.output)
125
-
126
- reward_models = args.reward.split(",")
127
-
128
- results = []
129
-
130
- logger.info("Data loaded from %s", args.folder)
131
- mp.set_start_method("spawn")
132
-
133
- for file in os.listdir(args.folder):
134
- if file.endswith(".json"):
135
- logger.info("Processing %s", file)
136
- with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f:
137
- data = json.load(f)
138
- data = [
139
- QAPair(question=data[key]["question"], answer=data[key]["answer"])
140
- for key in data
141
- ]
142
-
143
- length_scores = evaluate_length(data, args.tokenizer)
144
- mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
145
- reward_scores = evaluate_reward(data, reward_models)
146
- (
147
- uni_naturalness_scores,
148
- uni_coherence_scores,
149
- uni_understandability_scores,
150
- min_max_uni_naturalness_scores,
151
- min_max_uni_coherence_scores,
152
- min_max_uni_understandability_scores,
153
- ) = evaluate_uni(data, args.uni)
154
-
155
- result = {
156
- "file": file,
157
- "number": len(data),
158
- "length": length_scores,
159
- "mtld": mtld_scores,
160
- "mtld_min_max": min_max_mtld_scores,
161
- "uni_naturalness": uni_naturalness_scores,
162
- "uni_coherence": uni_coherence_scores,
163
- "uni_understandability": uni_understandability_scores,
164
- "uni_naturalness_min_max": min_max_uni_naturalness_scores,
165
- "uni_coherence_min_max": min_max_uni_coherence_scores,
166
- "uni_understandability_min_max": min_max_uni_understandability_scores,
167
- }
168
- for reward_score in reward_scores:
169
- result[reward_score["reward_name"]] = reward_score["score"]
170
- result[f"{reward_score['reward_name']}_min_max"] = reward_score[
171
- "min_max_scores"
172
- ]
173
-
174
- results.append(result)
175
-
176
- results = pd.DataFrame(results)
177
- results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/operators/evaluate/evaluate_service.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ import pandas as pd
4
+
5
+ from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
6
+ from graphgen.common import init_llm, init_storage
7
+ from graphgen.utils import logger, run_concurrent
8
+
9
+
10
+ class EvaluateService(BaseOperator):
11
+ """
12
+ 1. KG Quality Evaluation
13
+ 2. QA Quality Evaluation
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ working_dir: str = "cache",
19
+ metrics: list[str] = None,
20
+ graph_backend: str = "kuzu",
21
+ kv_backend: str = "rocksdb",
22
+ **kwargs,
23
+ ):
24
+ super().__init__(working_dir=working_dir, op_name="evaluate_service")
25
+ self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
26
+ self.metrics = metrics or []
27
+ self.kwargs = kwargs
28
+ self.graph_storage = init_storage(
29
+ backend=graph_backend, working_dir=working_dir, namespace="graph"
30
+ )
31
+ self.chunk_storage = init_storage(
32
+ backend=kv_backend, working_dir=working_dir, namespace="chunk"
33
+ )
34
+
35
+ # Initialize evaluators
36
+ self.qa_evaluators = {}
37
+ self.kg_evaluators = {}
38
+ self._init_evaluators()
39
+
40
+ def _init_evaluators(self):
41
+ """Initialize QA and KG evaluators based on metrics."""
42
+ for metric in self.metrics:
43
+ if metric == "qa_length":
44
+ from graphgen.models import LengthEvaluator
45
+
46
+ self.qa_evaluators[metric] = LengthEvaluator()
47
+ elif metric == "qa_mtld":
48
+ from graphgen.models import MTLDEvaluator
49
+
50
+ self.qa_evaluators[metric] = MTLDEvaluator(
51
+ **self.kwargs.get("mtld_params", {})
52
+ )
53
+ elif metric == "qa_reward_score":
54
+ from graphgen.models import RewardEvaluator
55
+
56
+ self.qa_evaluators[metric] = RewardEvaluator(
57
+ **self.kwargs.get("reward_params", {})
58
+ )
59
+ elif metric == "qa_uni_score":
60
+ from graphgen.models import UniEvaluator
61
+
62
+ self.qa_evaluators[metric] = UniEvaluator(
63
+ **self.kwargs.get("uni_params", {})
64
+ )
65
+ elif metric == "kg_accuracy":
66
+ from graphgen.models import AccuracyEvaluator
67
+
68
+ self.kg_evaluators[metric] = AccuracyEvaluator(
69
+ graph_storage=self.graph_storage,
70
+ chunk_storage=self.chunk_storage,
71
+ llm_client=self.llm_client,
72
+ )
73
+ elif metric == "kg_consistency":
74
+ from graphgen.models import ConsistencyEvaluator
75
+
76
+ self.kg_evaluators[metric] = ConsistencyEvaluator(
77
+ graph_storage=self.graph_storage,
78
+ chunk_storage=self.chunk_storage,
79
+ llm_client=self.llm_client,
80
+ )
81
+ elif metric == "kg_structure":
82
+ from graphgen.models import StructureEvaluator
83
+
84
+ self.kg_evaluators[metric] = StructureEvaluator(
85
+ graph_storage=self.graph_storage,
86
+ **self.kwargs.get("structure_params", {}),
87
+ )
88
+ else:
89
+ raise ValueError(f"Unknown QA metric: {metric}")
90
+
91
+ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
92
+ try:
93
+ qa_pair = QAPair(
94
+ question=str(item.get("question", "")),
95
+ answer=str(item.get("answer", "")),
96
+ )
97
+ if not qa_pair.question or not qa_pair.answer:
98
+ self.logger.error("Empty question or answer, skipping.")
99
+ return {}
100
+ except Exception as e:
101
+ self.logger.error("Error in QAPair creation: %s", str(e))
102
+ return {}
103
+
104
+ for metric, evaluator in self.qa_evaluators.items():
105
+ try:
106
+ score = evaluator.evaluate(qa_pair)
107
+ if isinstance(score, dict):
108
+ for sub_metric, sub_score in score.items():
109
+ item[f"{metric}_{sub_metric}"] = float(sub_score)
110
+ else:
111
+ item[metric] = float(score)
112
+ except Exception as e:
113
+ self.logger.error("Error in %s evaluation: %s", metric, str(e))
114
+ item[metric] = None
115
+ return item
116
+
117
+ def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
118
+ def transform_messages_format(items: list[dict]) -> list[dict]:
119
+ """
120
+ Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...]
121
+ """
122
+ transformed = []
123
+ for item in items:
124
+ messages = item.get("messages", [])
125
+ question = next(
126
+ (m["content"] for m in messages if m.get("role") == "user"), ""
127
+ )
128
+ answer = next(
129
+ (m["content"] for m in messages if m.get("role") == "assistant"), ""
130
+ )
131
+
132
+ transformed.append({"question": question, "answer": answer})
133
+ return transformed
134
+
135
+ if not items:
136
+ return []
137
+
138
+ if not self.qa_evaluators:
139
+ self.logger.warning("No QA evaluators initialized, skipping QA evaluation")
140
+ return []
141
+
142
+ items = transform_messages_format(items)
143
+ results = run_concurrent(
144
+ self._process_single_qa,
145
+ items,
146
+ desc="Evaluating QA items",
147
+ unit="item",
148
+ )
149
+
150
+ results = [item for item in results if item]
151
+ return results
152
+
153
+ def _evaluate_kg(self) -> Dict[str, Any]:
154
+ results = {}
155
+
156
+ for metric, evaluator in self.kg_evaluators.items():
157
+ try:
158
+ self.logger.info("Running %s evaluation...", metric)
159
+ score = evaluator.evaluate()
160
+ results[metric] = score
161
+ except Exception as e:
162
+ self.logger.error("Error in %s evaluation: %s", metric, str(e))
163
+ results[metric] = {"error": str(e)}
164
+ return results
165
+
166
+ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
167
+ # QA evaluation
168
+ if len(self.qa_evaluators) > 0:
169
+ items = batch.to_dict(orient="records")
170
+ results = self._evaluate_qa(items)
171
+ return pd.DataFrame(results)
172
+
173
+ # KG evaluation
174
+ if len(self.kg_evaluators) > 0:
175
+ results = self._evaluate_kg()
176
+ # Convert dict to DataFrame (single row)
177
+ return pd.DataFrame([results])
178
+
179
+ # No metrics specified
180
+ logger.warning("No metrics specified, returning empty DataFrame")
181
+ return pd.DataFrame()
graphgen/run.py CHANGED
@@ -91,10 +91,11 @@ def main():
91
  results = engine.execute(ds)
92
 
93
  for node_id, dataset in results.items():
94
- output_path = os.path.join(output_path, f"{node_id}")
95
- os.makedirs(output_path, exist_ok=True)
 
96
  dataset.write_json(
97
- output_path,
98
  filename_provider=NodeFilenameProvider(node_id),
99
  pandas_json_args_fn=lambda: {
100
  "force_ascii": False,
@@ -102,7 +103,7 @@ def main():
102
  "lines": True,
103
  },
104
  )
105
- logger.info("Node %s results saved to %s", node_id, output_path)
106
 
107
  save_config(os.path.join(output_path, "config.yaml"), config)
108
  logger.info("GraphGen completed successfully. Data saved to %s", output_path)
 
91
  results = engine.execute(ds)
92
 
93
  for node_id, dataset in results.items():
94
+ logger.info("Saving results for node %s", node_id)
95
+ node_output_path = os.path.join(output_path, f"{node_id}")
96
+ os.makedirs(node_output_path, exist_ok=True)
97
  dataset.write_json(
98
+ node_output_path,
99
  filename_provider=NodeFilenameProvider(node_id),
100
  pandas_json_args_fn=lambda: {
101
  "force_ascii": False,
 
103
  "lines": True,
104
  },
105
  )
106
+ logger.info("Node %s results saved to %s", node_id, node_output_path)
107
 
108
  save_config(os.path.join(output_path, "config.yaml"), config)
109
  logger.info("GraphGen completed successfully. Data saved to %s", output_path)
graphgen/templates/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
2
  from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
 
3
  from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT
4
  from .generation import (
5
  AGGREGATED_GENERATION_PROMPT,
 
1
  from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
2
  from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
3
+ from .evaluation import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT
4
  from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT
5
  from .generation import (
6
  AGGREGATED_GENERATION_PROMPT,
graphgen/templates/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .kg import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT
graphgen/templates/evaluation/kg/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .accuracy_evaluation import ACCURACY_EVALUATION_PROMPT
2
+ from .consistency_evaluation import CONSISTENCY_EVALUATION_PROMPT
graphgen/templates/evaluation/kg/accuracy_evaluation.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ENTITY_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的实体列表,评估实体提取的质量。
2
+
3
+ 评估维度:
4
+ 1. ACCURACY (准确性, 权重: 40%): 提取的实体是否正确,是否有误提取或错误识别
5
+ 2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要实体
6
+ 3. PRECISION (精确性, 权重: 20%): 提取的实体是否精确,命名是否准确
7
+
8
+ 评分标准(每个维度 0-1 分):
9
+ - EXCELLENT (0.8-1.0): 高质量提取
10
+ - GOOD (0.6-0.79): 良好质量,有少量问题
11
+ - ACCEPTABLE (0.4-0.59): 可接受,有明显问题
12
+ - POOR (0.0-0.39): 质量差,需要改进
13
+
14
+ 综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
15
+
16
+ 请评估以下内容:
17
+
18
+ 原始文本块:
19
+ {chunk_content}
20
+
21
+ 提取的实体列表:
22
+ {extracted_entities}
23
+
24
+ 请以 JSON 格式返回评估结果:
25
+ {{
26
+ "accuracy": <0-1之间的浮点数>,
27
+ "completeness": <0-1之间的浮点数>,
28
+ "precision": <0-1之间的浮点数>,
29
+ "overall_score": <综合评分>,
30
+ "accuracy_reasoning": "<准确性评估理由>",
31
+ "completeness_reasoning": "<完整性评估理由,包括遗漏的重要实体>",
32
+ "precision_reasoning": "<精确性评估理由>",
33
+ "issues": ["<发现的问题列表>"]
34
+ }}
35
+ """
36
+
37
+ ENTITY_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \
38
+ Your task is to evaluate the quality of entity extraction from a given text block and extracted entity list.
39
+
40
+ Evaluation Dimensions:
41
+ 1. ACCURACY (Weight: 40%): Whether the extracted entities are correct, and if there are any false extractions or misidentifications
42
+ 2. COMPLETENESS (Weight: 40%): Whether important entities from the text are missing
43
+ 3. PRECISION (Weight: 20%): Whether the extracted entities are precise and accurately named
44
+
45
+ Scoring Criteria (0-1 scale for each dimension):
46
+ - EXCELLENT (0.8-1.0): High-quality extraction
47
+ - GOOD (0.6-0.79): Good quality with minor issues
48
+ - ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues
49
+ - POOR (0.0-0.39): Poor quality, needs improvement
50
+
51
+ Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
52
+
53
+ Please evaluate the following:
54
+
55
+ Original Text Block:
56
+ {chunk_content}
57
+
58
+ Extracted Entity List:
59
+ {extracted_entities}
60
+
61
+ Please return the evaluation result in JSON format:
62
+ {{
63
+ "accuracy": <float between 0-1>,
64
+ "completeness": <float between 0-1>,
65
+ "precision": <float between 0-1>,
66
+ "overall_score": <overall score>,
67
+ "accuracy_reasoning": "<reasoning for accuracy assessment>",
68
+ "completeness_reasoning": "<reasoning for completeness assessment, including important missing entities>",
69
+ "precision_reasoning": "<reasoning for precision assessment>",
70
+ "issues": ["<list of identified issues>"]
71
+ }}
72
+ """
73
+
74
+ RELATION_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的关系列表,评估关系抽取的质量。
75
+
76
+ 评估维度:
77
+ 1. ACCURACY (准确性, 权重: 40%): 提取的关系是否正确,关系描述是否准确
78
+ 2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要关系
79
+ 3. PRECISION (精确性, 权重: 20%): 关系描述是否精确,是否过于宽泛
80
+
81
+ 评分标准(每个维度 0-1 分):
82
+ - EXCELLENT (0.8-1.0): 高质量提取
83
+ - GOOD (0.6-0.79): 良好质量,有少量问题
84
+ - ACCEPTABLE (0.4-0.59): 可接受,有明显问题
85
+ - POOR (0.0-0.39): 质量差,需要改进
86
+
87
+ 综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
88
+
89
+ 请评估以下内容:
90
+
91
+ 原始文本块:
92
+ {chunk_content}
93
+
94
+ 提取的关系列表:
95
+ {extracted_relations}
96
+
97
+ 请以 JSON 格式返回评估结果:
98
+ {{
99
+ "accuracy": <0-1之间的浮点数>,
100
+ "completeness": <0-1之间的浮点数>,
101
+ "precision": <0-1之间的浮点数>,
102
+ "overall_score": <综合评分>,
103
+ "accuracy_reasoning": "<准确性评估理由>",
104
+ "completeness_reasoning": "<完整性评估理由,包括遗漏的重要关系>",
105
+ "precision_reasoning": "<精确性评估理由>",
106
+ "issues": ["<发现的问题列表>"]
107
+ }}
108
+ """
109
+
110
+ RELATION_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \
111
+ Your task is to evaluate the quality of relation extraction from a given text block and extracted relation list.
112
+
113
+ Evaluation Dimensions:
114
+ 1. ACCURACY (Weight: 40%): Whether the extracted relations are correct and the relation descriptions are accurate
115
+ 2. COMPLETENESS (Weight: 40%): Whether important relations from the text are missing
116
+ 3. PRECISION (Weight: 20%): Whether the relation descriptions are precise and not overly broad
117
+
118
+ Scoring Criteria (0-1 scale for each dimension):
119
+ - EXCELLENT (0.8-1.0): High-quality extraction
120
+ - GOOD (0.6-0.79): Good quality with minor issues
121
+ - ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues
122
+ - POOR (0.0-0.39): Poor quality, needs improvement
123
+
124
+ Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
125
+
126
+ Please evaluate the following:
127
+
128
+ Original Text Block:
129
+ {chunk_content}
130
+
131
+ Extracted Relation List:
132
+ {extracted_relations}
133
+
134
+ Please return the evaluation result in JSON format:
135
+ {{
136
+ "accuracy": <float between 0-1>,
137
+ "completeness": <float between 0-1>,
138
+ "precision": <float between 0-1>,
139
+ "overall_score": <overall score>,
140
+ "accuracy_reasoning": "<reasoning for accuracy assessment>",
141
+ "completeness_reasoning": "<reasoning for completeness assessment, including important missing relations>",
142
+ "precision_reasoning": "<reasoning for precision assessment>",
143
+ "issues": ["<list of identified issues>"]
144
+ }}
145
+ """
146
+
147
+ ACCURACY_EVALUATION_PROMPT = {
148
+ "zh": {
149
+ "ENTITY": ENTITY_EVALUATION_PROMPT_ZH,
150
+ "RELATION": RELATION_EVALUATION_PROMPT_ZH,
151
+ },
152
+ "en": {
153
+ "ENTITY": ENTITY_EVALUATION_PROMPT_EN,
154
+ "RELATION": RELATION_EVALUATION_PROMPT_EN,
155
+ },
156
+ }
graphgen/templates/evaluation/kg/consistency_evaluation.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ENTITY_TYPE_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。
2
+
3
+ 实体名称:{entity_name}
4
+
5
+ 在不同文本块中的类型提取结果:
6
+ {type_extractions}
7
+
8
+ 预设的实体类型列表(供参考):
9
+ concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene
10
+
11
+ 请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。
12
+ 注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。
13
+
14
+ 请以 JSON 格式返回:
15
+ {{
16
+ "has_conflict": <true/false>,
17
+ "conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>,
18
+ "conflict_reasoning": "<冲突判断的理由>",
19
+ "conflicting_types": ["<存在冲突的类型对>"],
20
+ "recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>"
21
+ }}
22
+ """
23
+
24
+ ENTITY_DESCRIPTION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。
25
+
26
+ 实体名称:{entity_name}
27
+
28
+ 在不同文本块中的描述:
29
+ {descriptions}
30
+
31
+ 请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。
32
+
33
+ 请以 JSON 格式返回:
34
+ {{
35
+ "has_conflict": <true/false>,
36
+ "conflict_severity": <0-1之间的浮点数>,
37
+ "conflict_reasoning": "<冲突判断的理由>",
38
+ "conflicting_descriptions": ["<存在冲突的描述对>"],
39
+ "conflict_details": "<具体的冲突内容>"
40
+ }}
41
+ """
42
+
43
+ RELATION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。
44
+
45
+ 实体对:{source_entity} -> {target_entity}
46
+
47
+ 在不同文本块中的关系描述:
48
+ {relation_descriptions}
49
+
50
+ 请判断这些关系描述是否存在语义冲突。
51
+
52
+ 请以 JSON 格式返回:
53
+ {{
54
+ "has_conflict": <true/false>,
55
+ "conflict_severity": <0-1之间的浮点数>,
56
+ "conflict_reasoning": "<冲突判断的理由>",
57
+ "conflicting_relations": ["<存在冲突的关系描述对>"]
58
+ }}
59
+ """
60
+
61
+ ENTITY_EXTRACTION_PROMPT = """从以下文本块中提取指定实体的类型和描述。
62
+
63
+ **重要**:你只需要提取指定的实体,不要提取其他实体。
64
+
65
+ 实体名称:{entity_name}
66
+
67
+ 文本块:
68
+ {chunk_content}
69
+
70
+ 请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息:
71
+
72
+ 1. entity_type: 实体类型,必须是以下预设类型之一(小写):
73
+ - concept: 概念
74
+ - date: 日期
75
+ - location: 地点
76
+ - keyword: 关键词
77
+ - organization: 组织
78
+ - person: 人物
79
+ - event: 事件
80
+ - work: 作品/工作
81
+ - nature: 自然
82
+ - artificial: 人工
83
+ - science: 科学
84
+ - technology: 技术
85
+ - mission: 任务
86
+ - gene: 基因
87
+
88
+ 如果无法确定类型,请使用 "concept" 作为默认值。
89
+
90
+ 2. description: 实体描述(简要描述该实体在文本中的作用和特征)
91
+
92
+ 请以 JSON 格式返回:
93
+ {{
94
+ "entity_type": "<实体类型(必须是上述预设类型之一)>",
95
+ "description": "<实体描述>"
96
+ }}
97
+ """
98
+
99
+ CONSISTENCY_EVALUATION_PROMPT = {
100
+ "en": "",
101
+ "zh": ""
102
+ }
graphgen/utils/help_nltk.py CHANGED
@@ -1,39 +1,61 @@
 
1
  import os
2
- from typing import Dict, List, Optional
 
3
  import nltk
4
  import jieba
5
 
6
- resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources")
7
-
 
 
 
8
 
9
  class NLTKHelper:
10
- _stopwords: Dict[str, Optional[List[str]]] = {
11
- "english": None,
12
- "chinese": None,
 
 
 
 
 
 
 
 
13
  }
14
 
15
- def __init__(self):
 
 
 
 
 
 
16
  jieba.initialize()
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def get_stopwords(self, lang: str) -> List[str]:
19
- nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
20
- if self._stopwords[lang] is None:
21
- try:
22
- nltk.data.find("corpora/stopwords")
23
- except LookupError:
24
- nltk.download("stopwords", download_dir=os.path.join(resource_path, "nltk_data"))
25
-
26
- self._stopwords[lang] = nltk.corpus.stopwords.words(lang)
27
- return self._stopwords[lang]
28
-
29
- @staticmethod
30
- def word_tokenize(text: str, lang: str) -> List[str]:
31
  if lang == "zh":
32
  return jieba.lcut(text)
33
- nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
34
- try:
35
- nltk.data.find("tokenizers/punkt_tab")
36
- except LookupError:
37
- nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data"))
38
 
39
  return nltk.word_tokenize(text)
 
1
+ from functools import lru_cache
2
  import os
3
+ from typing import Dict, List, Final, Optional
4
+ import warnings
5
  import nltk
6
  import jieba
7
 
8
+ warnings.filterwarnings(
9
+ "ignore",
10
+ category=UserWarning,
11
+ module=r"jieba\._compat"
12
+ )
13
 
14
  class NLTKHelper:
15
+ """
16
+ NLTK helper class
17
+ """
18
+
19
+ SUPPORTED_LANGUAGES: Final[Dict[str, str]] = {
20
+ "en": "english",
21
+ "zh": "chinese"
22
+ }
23
+ _NLTK_PACKAGES: Final[Dict[str, str]] = {
24
+ "stopwords": "corpora",
25
+ "punkt_tab": "tokenizers"
26
  }
27
 
28
+ def __init__(self, nltk_data_path: Optional[str] = None):
29
+ self._nltk_path = nltk_data_path or os.path.join(
30
+ os.path.dirname(os.path.dirname(__file__)),
31
+ "resources",
32
+ "nltk_data"
33
+ )
34
+ nltk.data.path.append(self._nltk_path)
35
  jieba.initialize()
36
 
37
+ self._ensure_nltk_data("stopwords")
38
+ self._ensure_nltk_data("punkt_tab")
39
+
40
+ def _ensure_nltk_data(self, package_name: str) -> None:
41
+ """
42
+ ensure nltk data is downloaded
43
+ """
44
+ try:
45
+ nltk.data.find(f"{self._NLTK_PACKAGES[package_name]}/{package_name}")
46
+ except LookupError:
47
+ nltk.download(package_name, download_dir=self._nltk_path, quiet=True)
48
+
49
+ @lru_cache(maxsize=2)
50
  def get_stopwords(self, lang: str) -> List[str]:
51
+ if lang not in self.SUPPORTED_LANGUAGES:
52
+ raise ValueError(f"Language {lang} is not supported.")
53
+ return nltk.corpus.stopwords.words(self.SUPPORTED_LANGUAGES[lang])
54
+
55
+ def word_tokenize(self, text: str, lang: str) -> List[str]:
56
+ if lang not in self.SUPPORTED_LANGUAGES:
57
+ raise ValueError(f"Language {lang} is not supported.")
 
 
 
 
 
58
  if lang == "zh":
59
  return jieba.lcut(text)
 
 
 
 
 
60
 
61
  return nltk.word_tokenize(text)
requirements.txt CHANGED
@@ -23,7 +23,6 @@ aiohttp
23
  socksio
24
  pydantic
25
  ray==2.52.1
26
- kuzu
27
  pyarrow
28
 
29
  leidenalg
@@ -32,9 +31,11 @@ python-louvain
32
 
33
  # storage
34
  rocksdict
 
35
 
36
  # KG
37
  rdflib
 
38
 
39
  # Bioinformatics
40
  biopython
 
23
  socksio
24
  pydantic
25
  ray==2.52.1
 
26
  pyarrow
27
 
28
  leidenalg
 
31
 
32
  # storage
33
  rocksdict
34
+ kuzu
35
 
36
  # KG
37
  rdflib
38
+ scipy
39
 
40
  # Bioinformatics
41
  biopython
webui/utils/count_tokens.py CHANGED
@@ -7,10 +7,12 @@ import pandas as pd
7
  # pylint: disable=wrong-import-position
8
  root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
  sys.path.append(root_dir)
10
- from graphgen.models import Tokenizer
11
 
12
 
13
  def count_tokens(file, tokenizer_name, data_frame):
 
 
 
14
  if not file or not os.path.exists(file):
15
  return data_frame
16
 
 
7
  # pylint: disable=wrong-import-position
8
  root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
  sys.path.append(root_dir)
 
10
 
11
 
12
  def count_tokens(file, tokenizer_name, data_frame):
13
+ # Lazy import to avoid circular dependency
14
+ from graphgen.models import Tokenizer # pylint: disable=import-outside-toplevel
15
+
16
  if not file or not os.path.exists(file):
17
  return data_frame
18