Spaces:
Build error
Build error
github-actions[bot] commited on
Commit ·
7566ac3
1
Parent(s): 31df32c
Auto-sync from demo at Fri Dec 26 08:29:01 UTC 2025
Browse files- graphgen/bases/__init__.py +1 -0
- graphgen/bases/base_evaluator.py +10 -0
- graphgen/bases/base_storage.py +43 -4
- graphgen/common/init_storage.py +39 -5
- graphgen/engine.py +2 -0
- graphgen/models/__init__.py +9 -1
- graphgen/models/evaluator/__init__.py +2 -4
- graphgen/models/evaluator/base_evaluator.py +0 -52
- graphgen/models/evaluator/kg/__init__.py +18 -0
- graphgen/models/evaluator/kg/accuracy_evaluator.py +350 -0
- graphgen/models/evaluator/kg/consistency_evaluator.py +380 -0
- graphgen/models/evaluator/kg/structure_evaluator.py +97 -0
- graphgen/models/evaluator/length_evaluator.py +0 -19
- graphgen/models/evaluator/qa/__init__.py +4 -0
- graphgen/models/evaluator/qa/length_evaluator.py +18 -0
- graphgen/models/evaluator/{mtld_evaluator.py → qa/mtld_evaluator.py} +18 -23
- graphgen/models/evaluator/qa/reward_evaluator.py +66 -0
- graphgen/models/evaluator/qa/uni_evaluator.py +105 -0
- graphgen/models/evaluator/reward_evaluator.py +0 -107
- graphgen/models/evaluator/uni_evaluator.py +0 -183
- graphgen/models/storage/graph/kuzu_storage.py +90 -1
- graphgen/models/storage/graph/networkx_storage.py +26 -1
- graphgen/operators/__init__.py +3 -0
- graphgen/operators/evaluate/__init__.py +3 -0
- graphgen/operators/evaluate/evaluate.py +0 -177
- graphgen/operators/evaluate/evaluate_service.py +181 -0
- graphgen/run.py +5 -4
- graphgen/templates/__init__.py +1 -0
- graphgen/templates/evaluation/__init__.py +1 -0
- graphgen/templates/evaluation/kg/__init__.py +2 -0
- graphgen/templates/evaluation/kg/accuracy_evaluation.py +156 -0
- graphgen/templates/evaluation/kg/consistency_evaluation.py +102 -0
- graphgen/utils/help_nltk.py +46 -24
- requirements.txt +2 -1
- webui/utils/count_tokens.py +3 -1
graphgen/bases/__init__.py
CHANGED
|
@@ -9,4 +9,5 @@ from .base_searcher import BaseSearcher
|
|
| 9 |
from .base_splitter import BaseSplitter
|
| 10 |
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
|
| 11 |
from .base_tokenizer import BaseTokenizer
|
|
|
|
| 12 |
from .datatypes import Chunk, Config, Node, QAPair, Token
|
|
|
|
| 9 |
from .base_splitter import BaseSplitter
|
| 10 |
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
|
| 11 |
from .base_tokenizer import BaseTokenizer
|
| 12 |
+
from .base_evaluator import BaseEvaluator
|
| 13 |
from .datatypes import Chunk, Config, Node, QAPair, Token
|
graphgen/bases/base_evaluator.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from .datatypes import QAPair
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BaseEvaluator(ABC):
|
| 6 |
+
@abstractmethod
|
| 7 |
+
def evaluate(self, pair: QAPair) -> float:
|
| 8 |
+
"""
|
| 9 |
+
Evaluate the text and return a score.
|
| 10 |
+
"""
|
graphgen/bases/base_storage.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
-
from typing import Generic, TypeVar, Union
|
| 3 |
|
| 4 |
T = TypeVar("T")
|
| 5 |
|
|
@@ -45,52 +46,90 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|
| 45 |
raise NotImplementedError
|
| 46 |
|
| 47 |
|
| 48 |
-
class BaseGraphStorage(StorageNameSpace):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def has_node(self, node_id: str) -> bool:
|
| 50 |
raise NotImplementedError
|
| 51 |
|
|
|
|
| 52 |
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 53 |
raise NotImplementedError
|
| 54 |
|
|
|
|
| 55 |
def node_degree(self, node_id: str) -> int:
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
|
|
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def get_node(self, node_id: str) -> Union[dict, None]:
|
| 62 |
raise NotImplementedError
|
| 63 |
|
|
|
|
| 64 |
def update_node(self, node_id: str, node_data: dict[str, str]):
|
| 65 |
raise NotImplementedError
|
| 66 |
|
|
|
|
| 67 |
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
|
| 68 |
raise NotImplementedError
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
|
| 71 |
raise NotImplementedError
|
| 72 |
|
|
|
|
| 73 |
def update_edge(
|
| 74 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 75 |
):
|
| 76 |
raise NotImplementedError
|
| 77 |
|
|
|
|
| 78 |
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
|
| 79 |
raise NotImplementedError
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
|
| 82 |
raise NotImplementedError
|
| 83 |
|
|
|
|
| 84 |
def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 85 |
raise NotImplementedError
|
| 86 |
|
|
|
|
| 87 |
def upsert_edge(
|
| 88 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 89 |
):
|
| 90 |
raise NotImplementedError
|
| 91 |
|
|
|
|
| 92 |
def delete_node(self, node_id: str):
|
| 93 |
raise NotImplementedError
|
| 94 |
|
|
|
|
| 95 |
def reload(self):
|
| 96 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict, Generic, List, Set, TypeVar, Union
|
| 4 |
|
| 5 |
T = TypeVar("T")
|
| 6 |
|
|
|
|
| 46 |
raise NotImplementedError
|
| 47 |
|
| 48 |
|
| 49 |
+
class BaseGraphStorage(StorageNameSpace, ABC):
|
| 50 |
+
@abstractmethod
|
| 51 |
+
def is_directed(self) -> bool:
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
@abstractmethod
|
| 55 |
def has_node(self, node_id: str) -> bool:
|
| 56 |
raise NotImplementedError
|
| 57 |
|
| 58 |
+
@abstractmethod
|
| 59 |
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 60 |
raise NotImplementedError
|
| 61 |
|
| 62 |
+
@abstractmethod
|
| 63 |
def node_degree(self, node_id: str) -> int:
|
| 64 |
raise NotImplementedError
|
| 65 |
|
| 66 |
+
@abstractmethod
|
| 67 |
+
def get_all_node_degrees(self) -> Dict[str, int]:
|
| 68 |
+
pass
|
| 69 |
|
| 70 |
+
def get_isolated_nodes(self) -> List[str]:
|
| 71 |
+
return [
|
| 72 |
+
node_id
|
| 73 |
+
for node_id, degree in self.get_all_node_degrees().items()
|
| 74 |
+
if degree == 0
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
@abstractmethod
|
| 78 |
def get_node(self, node_id: str) -> Union[dict, None]:
|
| 79 |
raise NotImplementedError
|
| 80 |
|
| 81 |
+
@abstractmethod
|
| 82 |
def update_node(self, node_id: str, node_data: dict[str, str]):
|
| 83 |
raise NotImplementedError
|
| 84 |
|
| 85 |
+
@abstractmethod
|
| 86 |
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
|
| 87 |
raise NotImplementedError
|
| 88 |
|
| 89 |
+
@abstractmethod
|
| 90 |
+
def get_node_count(self) -> int:
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
@abstractmethod
|
| 94 |
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
|
| 95 |
raise NotImplementedError
|
| 96 |
|
| 97 |
+
@abstractmethod
|
| 98 |
def update_edge(
|
| 99 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 100 |
):
|
| 101 |
raise NotImplementedError
|
| 102 |
|
| 103 |
+
@abstractmethod
|
| 104 |
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
|
| 105 |
raise NotImplementedError
|
| 106 |
|
| 107 |
+
@abstractmethod
|
| 108 |
+
def get_edge_count(self) -> int:
|
| 109 |
+
pass
|
| 110 |
+
|
| 111 |
+
@abstractmethod
|
| 112 |
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
|
| 113 |
raise NotImplementedError
|
| 114 |
|
| 115 |
+
@abstractmethod
|
| 116 |
def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 117 |
raise NotImplementedError
|
| 118 |
|
| 119 |
+
@abstractmethod
|
| 120 |
def upsert_edge(
|
| 121 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 122 |
):
|
| 123 |
raise NotImplementedError
|
| 124 |
|
| 125 |
+
@abstractmethod
|
| 126 |
def delete_node(self, node_id: str):
|
| 127 |
raise NotImplementedError
|
| 128 |
|
| 129 |
+
@abstractmethod
|
| 130 |
def reload(self):
|
| 131 |
raise NotImplementedError
|
| 132 |
+
|
| 133 |
+
@abstractmethod
|
| 134 |
+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
|
| 135 |
+
raise NotImplementedError
|
graphgen/common/init_storage.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Any, Dict, Union
|
| 2 |
|
| 3 |
import ray
|
| 4 |
|
|
@@ -68,6 +68,21 @@ class GraphStorageActor:
|
|
| 68 |
def index_done_callback(self):
|
| 69 |
return self.graph.index_done_callback()
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def has_node(self, node_id: str) -> bool:
|
| 72 |
return self.graph.has_node(node_id)
|
| 73 |
|
|
@@ -165,6 +180,21 @@ class RemoteGraphStorageProxy(BaseGraphStorage):
|
|
| 165 |
def index_done_callback(self):
|
| 166 |
return ray.get(self.actor.index_done_callback.remote())
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
def has_node(self, node_id: str) -> bool:
|
| 169 |
return ray.get(self.actor.has_node.remote(node_id))
|
| 170 |
|
|
@@ -239,10 +269,14 @@ class StorageFactory:
|
|
| 239 |
try:
|
| 240 |
actor_handle = ray.get_actor(actor_name)
|
| 241 |
except ValueError:
|
| 242 |
-
actor_handle =
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
ray.get(actor_handle.ready.remote())
|
| 247 |
return proxy_class(actor_handle)
|
| 248 |
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Set, Union
|
| 2 |
|
| 3 |
import ray
|
| 4 |
|
|
|
|
| 68 |
def index_done_callback(self):
|
| 69 |
return self.graph.index_done_callback()
|
| 70 |
|
| 71 |
+
def is_directed(self) -> bool:
|
| 72 |
+
return self.graph.is_directed()
|
| 73 |
+
|
| 74 |
+
def get_all_node_degrees(self) -> Dict[str, int]:
|
| 75 |
+
return self.graph.get_all_node_degrees()
|
| 76 |
+
|
| 77 |
+
def get_node_count(self) -> int:
|
| 78 |
+
return self.graph.get_node_count()
|
| 79 |
+
|
| 80 |
+
def get_edge_count(self) -> int:
|
| 81 |
+
return self.graph.get_edge_count()
|
| 82 |
+
|
| 83 |
+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
|
| 84 |
+
return self.graph.get_connected_components(undirected)
|
| 85 |
+
|
| 86 |
def has_node(self, node_id: str) -> bool:
|
| 87 |
return self.graph.has_node(node_id)
|
| 88 |
|
|
|
|
| 180 |
def index_done_callback(self):
|
| 181 |
return ray.get(self.actor.index_done_callback.remote())
|
| 182 |
|
| 183 |
+
def is_directed(self) -> bool:
|
| 184 |
+
return ray.get(self.actor.is_directed.remote())
|
| 185 |
+
|
| 186 |
+
def get_all_node_degrees(self) -> Dict[str, int]:
|
| 187 |
+
return ray.get(self.actor.get_all_node_degrees.remote())
|
| 188 |
+
|
| 189 |
+
def get_node_count(self) -> int:
|
| 190 |
+
return ray.get(self.actor.get_node_count.remote())
|
| 191 |
+
|
| 192 |
+
def get_edge_count(self) -> int:
|
| 193 |
+
return ray.get(self.actor.get_edge_count.remote())
|
| 194 |
+
|
| 195 |
+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
|
| 196 |
+
return ray.get(self.actor.get_connected_components.remote(undirected))
|
| 197 |
+
|
| 198 |
def has_node(self, node_id: str) -> bool:
|
| 199 |
return ray.get(self.actor.has_node.remote(node_id))
|
| 200 |
|
|
|
|
| 269 |
try:
|
| 270 |
actor_handle = ray.get_actor(actor_name)
|
| 271 |
except ValueError:
|
| 272 |
+
actor_handle = (
|
| 273 |
+
ray.remote(actor_class)
|
| 274 |
+
.options(
|
| 275 |
+
name=actor_name,
|
| 276 |
+
get_if_exists=True,
|
| 277 |
+
)
|
| 278 |
+
.remote(backend, working_dir, namespace)
|
| 279 |
+
)
|
| 280 |
ray.get(actor_handle.ready.remote())
|
| 281 |
return proxy_class(actor_handle)
|
| 282 |
|
graphgen/engine.py
CHANGED
|
@@ -271,6 +271,8 @@ class Engine:
|
|
| 271 |
|
| 272 |
for node in sorted_nodes:
|
| 273 |
self._execute_node(node, initial_ds)
|
|
|
|
|
|
|
| 274 |
|
| 275 |
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
|
| 276 |
return {node.id: self.datasets[node.id] for node in output_nodes}
|
|
|
|
| 271 |
|
| 272 |
for node in sorted_nodes:
|
| 273 |
self._execute_node(node, initial_ds)
|
| 274 |
+
if getattr(node, "save_output", False):
|
| 275 |
+
self.datasets[node.id] = self.datasets[node.id].materialize()
|
| 276 |
|
| 277 |
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
|
| 278 |
return {node.id: self.datasets[node.id] for node in output_nodes}
|
graphgen/models/__init__.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
| 1 |
-
from .evaluator import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from .generator import (
|
| 3 |
AggregatedGenerator,
|
| 4 |
AtomicGenerator,
|
|
|
|
| 1 |
+
from .evaluator import (
|
| 2 |
+
AccuracyEvaluator,
|
| 3 |
+
ConsistencyEvaluator,
|
| 4 |
+
LengthEvaluator,
|
| 5 |
+
MTLDEvaluator,
|
| 6 |
+
RewardEvaluator,
|
| 7 |
+
StructureEvaluator,
|
| 8 |
+
UniEvaluator,
|
| 9 |
+
)
|
| 10 |
from .generator import (
|
| 11 |
AggregatedGenerator,
|
| 12 |
AtomicGenerator,
|
graphgen/models/evaluator/__init__.py
CHANGED
|
@@ -1,4 +1,2 @@
|
|
| 1 |
-
from .
|
| 2 |
-
from .
|
| 3 |
-
from .reward_evaluator import RewardEvaluator
|
| 4 |
-
from .uni_evaluator import UniEvaluator
|
|
|
|
| 1 |
+
from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator
|
| 2 |
+
from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
|
|
|
|
|
|
graphgen/models/evaluator/base_evaluator.py
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
-
|
| 3 |
-
from tqdm.asyncio import tqdm as tqdm_async
|
| 4 |
-
|
| 5 |
-
from graphgen.bases.datatypes import QAPair
|
| 6 |
-
from graphgen.utils import create_event_loop
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class BaseEvaluator:
|
| 10 |
-
def __init__(self, max_concurrent: int = 100):
|
| 11 |
-
self.max_concurrent = max_concurrent
|
| 12 |
-
self.results: list[float] = None
|
| 13 |
-
|
| 14 |
-
def evaluate(self, pairs: list[QAPair]) -> list[float]:
|
| 15 |
-
"""
|
| 16 |
-
Evaluate the text and return a score.
|
| 17 |
-
"""
|
| 18 |
-
return create_event_loop().run_until_complete(self.async_evaluate(pairs))
|
| 19 |
-
|
| 20 |
-
async def async_evaluate(self, pairs: list[QAPair]) -> list[float]:
|
| 21 |
-
semaphore = asyncio.Semaphore(self.max_concurrent)
|
| 22 |
-
|
| 23 |
-
async def evaluate_with_semaphore(pair):
|
| 24 |
-
async with semaphore: # 获取Semaphore
|
| 25 |
-
return await self.evaluate_single(pair)
|
| 26 |
-
|
| 27 |
-
results = []
|
| 28 |
-
for result in tqdm_async(
|
| 29 |
-
asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
|
| 30 |
-
total=len(pairs),
|
| 31 |
-
):
|
| 32 |
-
results.append(await result)
|
| 33 |
-
return results
|
| 34 |
-
|
| 35 |
-
async def evaluate_single(self, pair: QAPair) -> float:
|
| 36 |
-
raise NotImplementedError()
|
| 37 |
-
|
| 38 |
-
def get_average_score(self, pairs: list[QAPair]) -> float:
|
| 39 |
-
"""
|
| 40 |
-
Get the average score of a batch of texts.
|
| 41 |
-
"""
|
| 42 |
-
results = self.evaluate(pairs)
|
| 43 |
-
self.results = results
|
| 44 |
-
return sum(self.results) / len(pairs)
|
| 45 |
-
|
| 46 |
-
def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
|
| 47 |
-
"""
|
| 48 |
-
Get the min and max score of a batch of texts.
|
| 49 |
-
"""
|
| 50 |
-
if self.results is None:
|
| 51 |
-
self.get_average_score(pairs)
|
| 52 |
-
return min(self.results), max(self.results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/evaluator/kg/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge Graph Quality Evaluator
|
| 3 |
+
|
| 4 |
+
This module provides comprehensive quality evaluation for knowledge graphs,
|
| 5 |
+
1. accuracy assessment (entity/relation/triple validation),
|
| 6 |
+
2. consistency assessment (attribute conflict detection), and structural
|
| 7 |
+
3. robustness assessment (noise ratio, connectivity, degree distribution).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from .accuracy_evaluator import AccuracyEvaluator
|
| 11 |
+
from .consistency_evaluator import ConsistencyEvaluator
|
| 12 |
+
from .structure_evaluator import StructureEvaluator
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"AccuracyEvaluator",
|
| 16 |
+
"ConsistencyEvaluator",
|
| 17 |
+
"StructureEvaluator",
|
| 18 |
+
]
|
graphgen/models/evaluator/kg/accuracy_evaluator.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
|
| 7 |
+
from graphgen.bases.datatypes import Chunk
|
| 8 |
+
from graphgen.templates import ACCURACY_EVALUATION_PROMPT
|
| 9 |
+
from graphgen.utils import detect_main_language, logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AccuracyEvaluator:
|
| 13 |
+
"""Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.
|
| 14 |
+
|
| 15 |
+
For each chunk, uses LLM to evaluate the quality of extracted entities and relations
|
| 16 |
+
by comparing them with the original chunk content. Provides multi-dimensional quality
|
| 17 |
+
scores (accuracy, completeness, precision).
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
graph_storage: BaseGraphStorage,
|
| 23 |
+
chunk_storage: BaseKVStorage,
|
| 24 |
+
llm_client: BaseLLMWrapper,
|
| 25 |
+
):
|
| 26 |
+
self.graph_storage = graph_storage
|
| 27 |
+
self.chunk_storage = chunk_storage
|
| 28 |
+
self.llm_client = llm_client
|
| 29 |
+
|
| 30 |
+
def evaluate(self) -> Dict[str, Any]:
|
| 31 |
+
"""Evaluate entity and relation extraction quality using LLM-as-a-Judge.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Dictionary containing entity_accuracy and relation_accuracy metrics.
|
| 35 |
+
"""
|
| 36 |
+
# 1. Load all chunks from storage
|
| 37 |
+
chunks = self._load_chunks_from_storage()
|
| 38 |
+
|
| 39 |
+
if not chunks:
|
| 40 |
+
logger.warning("No chunks found in storage")
|
| 41 |
+
return {"error": "No chunks found in storage"}
|
| 42 |
+
|
| 43 |
+
logger.info(f"Found {len(chunks)} chunks to evaluate")
|
| 44 |
+
|
| 45 |
+
# 2. Evaluate each chunk
|
| 46 |
+
entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks)
|
| 47 |
+
|
| 48 |
+
# 3. Aggregate results
|
| 49 |
+
return self._aggregate_evaluation_results(
|
| 50 |
+
entity_evaluations, relation_evaluations
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def _load_chunks_from_storage(self) -> List[Chunk]:
|
| 54 |
+
"""Load all chunks from chunk storage."""
|
| 55 |
+
chunks = []
|
| 56 |
+
all_chunk_data = self.chunk_storage.get_all()
|
| 57 |
+
|
| 58 |
+
for chunk_id, chunk_data in all_chunk_data.items():
|
| 59 |
+
try:
|
| 60 |
+
chunk = Chunk.from_dict(chunk_id, chunk_data)
|
| 61 |
+
chunks.append(chunk)
|
| 62 |
+
except Exception as e:
|
| 63 |
+
logger.warning(f"Failed to load chunk {chunk_id}: {e}")
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
return chunks
|
| 67 |
+
|
| 68 |
+
def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]:
|
| 69 |
+
"""Get all entities extracted from the specified chunk."""
|
| 70 |
+
entities = []
|
| 71 |
+
all_nodes = self.graph_storage.get_all_nodes() or []
|
| 72 |
+
|
| 73 |
+
for node_id, node_data in all_nodes:
|
| 74 |
+
if not isinstance(node_data, dict):
|
| 75 |
+
continue
|
| 76 |
+
source_ids = node_data.get("source_id", "").split("<SEP>")
|
| 77 |
+
# Check if this chunk_id is in the source_ids
|
| 78 |
+
if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]:
|
| 79 |
+
entities.append(
|
| 80 |
+
{
|
| 81 |
+
"entity_name": node_data.get("entity_name", node_id),
|
| 82 |
+
"entity_type": node_data.get("entity_type", ""),
|
| 83 |
+
"description": node_data.get("description", ""),
|
| 84 |
+
}
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
return entities
|
| 88 |
+
|
| 89 |
+
def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:
|
| 90 |
+
"""Get all relations extracted from the specified chunk."""
|
| 91 |
+
relations = []
|
| 92 |
+
all_edges = self.graph_storage.get_all_edges() or []
|
| 93 |
+
|
| 94 |
+
for src_id, dst_id, edge_data in all_edges:
|
| 95 |
+
if not isinstance(edge_data, dict):
|
| 96 |
+
continue
|
| 97 |
+
source_ids = edge_data.get("source_id", "").split("<SEP>")
|
| 98 |
+
# Check if this chunk_id is in the source_ids
|
| 99 |
+
if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]:
|
| 100 |
+
src_node = self.graph_storage.get_node(src_id) or {}
|
| 101 |
+
dst_node = self.graph_storage.get_node(dst_id) or {}
|
| 102 |
+
relations.append(
|
| 103 |
+
{
|
| 104 |
+
"source_entity": src_node.get("entity_name", src_id),
|
| 105 |
+
"target_entity": dst_node.get("entity_name", dst_id),
|
| 106 |
+
"relationship_summary": edge_data.get("description", ""),
|
| 107 |
+
}
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return relations
|
| 111 |
+
|
| 112 |
+
def _evaluate_all_chunks(
|
| 113 |
+
self, chunks: List[Chunk]
|
| 114 |
+
) -> tuple[List[Dict], List[Dict]]:
|
| 115 |
+
"""Evaluate all chunks sequentially."""
|
| 116 |
+
entity_evaluations = []
|
| 117 |
+
relation_evaluations = []
|
| 118 |
+
|
| 119 |
+
for chunk in chunks:
|
| 120 |
+
try:
|
| 121 |
+
entities = self._get_extracted_entities_for_chunk(chunk.id)
|
| 122 |
+
relations = self._get_extracted_relations_for_chunk(chunk.id)
|
| 123 |
+
|
| 124 |
+
entity_eval = self._evaluate_entity_extraction(chunk, entities)
|
| 125 |
+
relation_eval = self._evaluate_relation_extraction(chunk, relations)
|
| 126 |
+
|
| 127 |
+
entity_evaluations.append(entity_eval)
|
| 128 |
+
relation_evaluations.append(relation_eval)
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.error(f"Failed to evaluate chunk {chunk.id}: {e}")
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
return entity_evaluations, relation_evaluations
|
| 134 |
+
|
| 135 |
+
def _evaluate_entity_extraction(
|
| 136 |
+
self, chunk: Chunk, extracted_entities: List[Dict]
|
| 137 |
+
) -> Dict[str, Any]:
|
| 138 |
+
"""Use LLM to evaluate entity extraction quality."""
|
| 139 |
+
try:
|
| 140 |
+
lang = detect_main_language(chunk.content)
|
| 141 |
+
|
| 142 |
+
prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format(
|
| 143 |
+
chunk_content=chunk.content,
|
| 144 |
+
extracted_entities=json.dumps(
|
| 145 |
+
extracted_entities, ensure_ascii=False, indent=2
|
| 146 |
+
),
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 150 |
+
|
| 151 |
+
# Try to parse JSON response
|
| 152 |
+
try:
|
| 153 |
+
evaluation_result = json.loads(response)
|
| 154 |
+
except json.JSONDecodeError:
|
| 155 |
+
# Try to extract JSON from markdown code blocks or other formats
|
| 156 |
+
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 157 |
+
if json_match:
|
| 158 |
+
evaluation_result = json.loads(json_match.group(0))
|
| 159 |
+
else:
|
| 160 |
+
logger.warning(
|
| 161 |
+
f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}"
|
| 162 |
+
)
|
| 163 |
+
# Return default evaluation
|
| 164 |
+
evaluation_result = {
|
| 165 |
+
"accuracy": 0.0,
|
| 166 |
+
"completeness": 0.0,
|
| 167 |
+
"precision": 0.0,
|
| 168 |
+
"overall_score": 0.0,
|
| 169 |
+
"accuracy_reasoning": "Failed to parse LLM response",
|
| 170 |
+
"completeness_reasoning": "",
|
| 171 |
+
"precision_reasoning": "",
|
| 172 |
+
"issues": ["LLM response parsing failed"],
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
# Validate and calculate overall_score if not provided
|
| 176 |
+
if "overall_score" not in evaluation_result:
|
| 177 |
+
accuracy = float(evaluation_result.get("accuracy", 0.0))
|
| 178 |
+
completeness = float(evaluation_result.get("completeness", 0.0))
|
| 179 |
+
precision = float(evaluation_result.get("precision", 0.0))
|
| 180 |
+
evaluation_result["overall_score"] = (
|
| 181 |
+
0.4 * accuracy + 0.4 * completeness + 0.2 * precision
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
return {
|
| 185 |
+
"chunk_id": chunk.id,
|
| 186 |
+
"chunk_content": chunk.content[:200]
|
| 187 |
+
if chunk.content
|
| 188 |
+
else "", # First 200 chars for debugging
|
| 189 |
+
"extracted_entities_count": len(extracted_entities),
|
| 190 |
+
**evaluation_result,
|
| 191 |
+
}
|
| 192 |
+
except Exception as e:
|
| 193 |
+
logger.error(
|
| 194 |
+
f"Error evaluating entity extraction for chunk {chunk.id}: {e}"
|
| 195 |
+
)
|
| 196 |
+
return {
|
| 197 |
+
"chunk_id": chunk.id,
|
| 198 |
+
"chunk_content": chunk.content[:200] if chunk.content else "",
|
| 199 |
+
"extracted_entities_count": len(extracted_entities),
|
| 200 |
+
"accuracy": 0.0,
|
| 201 |
+
"completeness": 0.0,
|
| 202 |
+
"precision": 0.0,
|
| 203 |
+
"overall_score": 0.0,
|
| 204 |
+
"accuracy_reasoning": f"Evaluation failed: {str(e)}",
|
| 205 |
+
"completeness_reasoning": "",
|
| 206 |
+
"precision_reasoning": "",
|
| 207 |
+
"issues": [f"Evaluation error: {str(e)}"],
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
def _evaluate_relation_extraction(
|
| 211 |
+
self, chunk: Chunk, extracted_relations: List[Dict]
|
| 212 |
+
) -> Dict[str, Any]:
|
| 213 |
+
"""Use LLM to evaluate relation extraction quality."""
|
| 214 |
+
try:
|
| 215 |
+
lang = detect_main_language(chunk.content)
|
| 216 |
+
prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format(
|
| 217 |
+
chunk_content=chunk.content,
|
| 218 |
+
extracted_relations=json.dumps(
|
| 219 |
+
extracted_relations, ensure_ascii=False, indent=2
|
| 220 |
+
),
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
response = asyncio.run(self.llm_client.generate_answer(prompt))
|
| 224 |
+
|
| 225 |
+
# Try to parse JSON response
|
| 226 |
+
try:
|
| 227 |
+
evaluation_result = json.loads(response)
|
| 228 |
+
except json.JSONDecodeError:
|
| 229 |
+
# Try to extract JSON from markdown code blocks or other formats
|
| 230 |
+
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
| 231 |
+
if json_match:
|
| 232 |
+
evaluation_result = json.loads(json_match.group(0))
|
| 233 |
+
else:
|
| 234 |
+
logger.warning(
|
| 235 |
+
f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}"
|
| 236 |
+
)
|
| 237 |
+
# Return default evaluation
|
| 238 |
+
evaluation_result = {
|
| 239 |
+
"accuracy": 0.0,
|
| 240 |
+
"completeness": 0.0,
|
| 241 |
+
"precision": 0.0,
|
| 242 |
+
"overall_score": 0.0,
|
| 243 |
+
"accuracy_reasoning": "Failed to parse LLM response",
|
| 244 |
+
"completeness_reasoning": "",
|
| 245 |
+
"precision_reasoning": "",
|
| 246 |
+
"issues": ["LLM response parsing failed"],
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
# Validate and calculate overall_score if not provided
|
| 250 |
+
if "overall_score" not in evaluation_result:
|
| 251 |
+
accuracy = float(evaluation_result.get("accuracy", 0.0))
|
| 252 |
+
completeness = float(evaluation_result.get("completeness", 0.0))
|
| 253 |
+
precision = float(evaluation_result.get("precision", 0.0))
|
| 254 |
+
evaluation_result["overall_score"] = (
|
| 255 |
+
0.4 * accuracy + 0.4 * completeness + 0.2 * precision
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
return {
|
| 259 |
+
"chunk_id": chunk.id,
|
| 260 |
+
"chunk_content": chunk.content[:200] if chunk.content else "",
|
| 261 |
+
"extracted_relations_count": len(extracted_relations),
|
| 262 |
+
**evaluation_result,
|
| 263 |
+
}
|
| 264 |
+
except Exception as e:
|
| 265 |
+
logger.error(
|
| 266 |
+
f"Error evaluating relation extraction for chunk {chunk.id}: {e}"
|
| 267 |
+
)
|
| 268 |
+
return {
|
| 269 |
+
"chunk_id": chunk.id,
|
| 270 |
+
"chunk_content": chunk.content[:200] if chunk.content else "",
|
| 271 |
+
"extracted_relations_count": len(extracted_relations),
|
| 272 |
+
"accuracy": 0.0,
|
| 273 |
+
"completeness": 0.0,
|
| 274 |
+
"precision": 0.0,
|
| 275 |
+
"overall_score": 0.0,
|
| 276 |
+
"accuracy_reasoning": f"Evaluation failed: {str(e)}",
|
| 277 |
+
"completeness_reasoning": "",
|
| 278 |
+
"precision_reasoning": "",
|
| 279 |
+
"issues": [f"Evaluation error: {str(e)}"],
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
    @staticmethod
    def _aggregate_evaluation_results(
        entity_evaluations: List[Dict], relation_evaluations: List[Dict]
    ) -> Dict[str, Any]:
        """Aggregate per-chunk evaluation results into summary statistics.

        Args:
            entity_evaluations: Per-chunk entity-extraction results; each dict
                is expected to carry "overall_score", "accuracy",
                "completeness" and "precision" (missing keys default to 0.0).
            relation_evaluations: Per-chunk relation-extraction results with
                the same score keys.

        Returns:
            Dict with "entity_accuracy" and "relation_accuracy" sections, each
            holding mean/median/min/max/std per metric, the number of chunks
            evaluated, and the raw per-chunk results under "detailed_results".
        """

        def calculate_stats(scores: List[float]) -> Dict[str, float]:
            # Summary statistics for one score list; all-zero stats for empty input.
            if not scores:
                return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0}
            sorted_scores = sorted(scores)
            n = len(scores)
            mean = sum(scores) / n
            # Median: middle element for odd n, mean of the two middles for even n.
            median = (
                sorted_scores[n // 2]
                if n % 2 == 1
                else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2
            )
            # Population variance (divides by n, not n-1).
            variance = sum((x - mean) ** 2 for x in scores) / n
            std = variance**0.5

            return {
                "mean": mean,
                "median": median,
                "min": min(scores),
                "max": max(scores),
                "std": std,
            }

        # Extract scores (missing keys are treated as 0.0 so one malformed
        # chunk result cannot break aggregation).
        entity_overall_scores = [
            e.get("overall_score", 0.0) for e in entity_evaluations
        ]
        entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations]
        entity_completeness_scores = [
            e.get("completeness", 0.0) for e in entity_evaluations
        ]
        entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations]

        relation_overall_scores = [
            r.get("overall_score", 0.0) for r in relation_evaluations
        ]
        relation_accuracy_scores = [
            r.get("accuracy", 0.0) for r in relation_evaluations
        ]
        relation_completeness_scores = [
            r.get("completeness", 0.0) for r in relation_evaluations
        ]
        relation_precision_scores = [
            r.get("precision", 0.0) for r in relation_evaluations
        ]

        return {
            "entity_accuracy": {
                "overall_score": calculate_stats(entity_overall_scores),
                "accuracy": calculate_stats(entity_accuracy_scores),
                "completeness": calculate_stats(entity_completeness_scores),
                "precision": calculate_stats(entity_precision_scores),
                "total_chunks": len(entity_evaluations),
                "detailed_results": entity_evaluations,
            },
            "relation_accuracy": {
                "overall_score": calculate_stats(relation_overall_scores),
                "accuracy": calculate_stats(relation_accuracy_scores),
                "completeness": calculate_stats(relation_completeness_scores),
                "precision": calculate_stats(relation_precision_scores),
                "total_chunks": len(relation_evaluations),
                "detailed_results": relation_evaluations,
            },
        }
|
graphgen/models/evaluator/kg/consistency_evaluator.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
|
| 7 |
+
from graphgen.bases.datatypes import Chunk
|
| 8 |
+
from graphgen.templates.evaluation.kg.consistency_evaluation import (
|
| 9 |
+
ENTITY_DESCRIPTION_CONFLICT_PROMPT,
|
| 10 |
+
ENTITY_EXTRACTION_PROMPT,
|
| 11 |
+
ENTITY_TYPE_CONFLICT_PROMPT,
|
| 12 |
+
RELATION_CONFLICT_PROMPT,
|
| 13 |
+
)
|
| 14 |
+
from graphgen.utils import logger
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ConsistencyEvaluator:
    """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge.

    For entities with multiple source chunks, compares entity_type and description
    extracted from different chunks to detect semantic conflicts.
    """

    # Preset entity types accepted from the extraction LLM; anything else is
    # coerced to "concept" (see _extract_entity_from_chunk).
    VALID_ENTITY_TYPES = frozenset(
        {
            "concept",
            "date",
            "location",
            "keyword",
            "organization",
            "person",
            "event",
            "work",
            "nature",
            "artificial",
            "science",
            "technology",
            "mission",
            "gene",
        }
    )

    def __init__(
        self,
        graph_storage: "BaseGraphStorage",
        chunk_storage: "BaseKVStorage",
        llm_client: "BaseLLMWrapper",
    ):
        self.graph_storage = graph_storage
        self.chunk_storage = chunk_storage
        self.llm_client = llm_client

    @staticmethod
    def _parse_llm_json(response: str, context: str) -> Any:
        """Best-effort parse of a JSON object from an LLM response.

        Tries strict parsing first, then falls back to extracting the first
        ``{...}`` span (LLMs often wrap JSON in prose or markdown fences).

        Args:
            response: Raw LLM output.
            context: Human-readable description used in the warning log.

        Returns:
            The parsed object, or None when nothing parseable is found.
        """
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            pass
        json_match = re.search(r"\{.*\}", response, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass
        logger.warning(f"Failed to parse LLM JSON response for {context}")
        return None

    def evaluate(self) -> Dict[str, Any]:
        """Evaluate consistency by detecting semantic conflicts."""
        all_nodes = self.graph_storage.get_all_nodes() or []
        if not all_nodes:
            return {"error": "Empty graph"}

        return self._evaluate_consistency(all_nodes)

    def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
        """Run the conflict check over every entity extracted from >1 chunk."""
        # Only entities merged from several chunks can disagree with themselves.
        entities_with_multiple_sources = []
        for node_id, node_data in all_nodes:
            if not isinstance(node_data, dict):
                continue
            source_ids = node_data.get("source_id", "").split("<SEP>")
            source_ids = [sid.strip() for sid in source_ids if sid.strip()]
            if len(source_ids) > 1:  # Only check entities from multiple chunks
                entities_with_multiple_sources.append((node_id, node_data, source_ids))

        if not entities_with_multiple_sources:
            logger.info(
                "No entities with multiple sources found, skipping consistency check"
            )
            return {
                "conflict_rate": 0.0,
                "conflict_entities_count": 0,
                "total_entities": len(all_nodes),
                "conflicts": [],
            }

        logger.info(
            f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources"
        )

        # Evaluate entities sequentially; one failing entity must not abort the run.
        conflicts = []
        conflict_entities = set()

        for entity_info in entities_with_multiple_sources:
            try:
                entity_id, entity_conflicts = self._evaluate_entity_consistency(
                    entity_info
                )
                if entity_conflicts:
                    conflicts.extend(entity_conflicts)
                    conflict_entities.add(entity_id)
            except Exception as e:
                logger.error(f"Failed to evaluate entity {entity_info[0]}: {e}")
                continue

        total_entities = len(all_nodes)
        conflict_rate = (
            len(conflict_entities) / total_entities if total_entities > 0 else 0
        )

        return {
            "conflict_rate": conflict_rate,
            "conflict_entities_count": len(conflict_entities),
            "total_entities": total_entities,
            "entities_checked": len(entities_with_multiple_sources),
            "conflicts": conflicts[:100],  # Limit to first 100 conflicts
        }

    def _clean_entity_id(self, entity_id: str) -> str:
        """Strip whitespace and one layer of matching surrounding quotes."""
        clean_id = entity_id.strip()
        if (clean_id.startswith('"') and clean_id.endswith('"')) or (
            clean_id.startswith("'") and clean_id.endswith("'")
        ):
            clean_id = clean_id[1:-1].strip()
        return clean_id

    def _evaluate_entity_consistency(
        self, entity_info: tuple
    ) -> tuple[str, List[Dict]]:
        """Evaluate consistency for a single entity.

        Args:
            entity_info: Tuple of (entity_id, node_data, source_ids).

        Returns:
            Tuple of (entity_id, list of detected conflict dicts).
        """
        entity_id, _node_data, source_ids = entity_info
        # Clean entity_id for display
        clean_entity_id = self._clean_entity_id(entity_id)
        conflicts = []

        # Get chunks for this entity
        chunks = self._get_entity_chunks(source_ids)
        if len(chunks) < 2:
            return entity_id, []

        # Extract entity attributes from each chunk
        entity_extractions = {}
        for chunk in chunks:
            extraction = self._extract_entity_from_chunk(entity_id, chunk)
            if extraction:
                entity_extractions[chunk.id] = extraction

        if len(entity_extractions) < 2:
            return entity_id, []

        # Check entity type consistency
        type_extractions = {
            chunk_id: ext.get("entity_type", "")
            for chunk_id, ext in entity_extractions.items()
        }
        type_conflict = self._check_entity_type_consistency(
            entity_id, type_extractions
        )
        if type_conflict and type_conflict.get("has_conflict", False):
            conflicts.append(
                {
                    "entity_id": clean_entity_id,
                    "conflict_type": "entity_type",
                    "conflict_severity": type_conflict.get("conflict_severity", 0.0),
                    "conflict_reasoning": type_conflict.get("conflict_reasoning", ""),
                    "conflicting_values": type_conflict.get("conflicting_types", []),
                    "recommended_value": type_conflict.get("recommended_type", ""),
                }
            )

        # Check entity description consistency
        descriptions = {
            chunk_id: ext.get("description", "")
            for chunk_id, ext in entity_extractions.items()
        }
        desc_conflict = self._check_entity_description_consistency(
            entity_id, descriptions
        )
        if desc_conflict and desc_conflict.get("has_conflict", False):
            conflicts.append(
                {
                    "entity_id": clean_entity_id,
                    "conflict_type": "description",
                    "conflict_severity": desc_conflict.get("conflict_severity", 0.0),
                    "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""),
                    "conflicting_values": desc_conflict.get(
                        "conflicting_descriptions", []
                    ),
                    "conflict_details": desc_conflict.get("conflict_details", ""),
                }
            )

        return entity_id, conflicts

    def _get_entity_chunks(self, source_ids: List[str]) -> "List[Chunk]":
        """Load all chunks referenced by an entity; chunks that fail to load are skipped."""
        chunks = []
        for chunk_id in source_ids:
            chunk_data = self.chunk_storage.get_by_id(chunk_id)
            if chunk_data:
                try:
                    chunk = Chunk.from_dict(chunk_id, chunk_data)
                    chunks.append(chunk)
                except Exception as e:
                    logger.warning(f"Failed to load chunk {chunk_id}: {e}")
                    continue
        return chunks

    def _extract_entity_from_chunk(
        self, entity_id: str, chunk: "Chunk"
    ) -> Dict[str, str]:
        """Extract entity attributes from a chunk using LLM.

        Returns:
            {"entity_type": ..., "description": ...}, or {} on any failure.
        """
        try:
            # Clean entity_id: remove surrounding quotes if present
            clean_entity_id = self._clean_entity_id(entity_id)

            prompt = ENTITY_EXTRACTION_PROMPT.format(
                entity_name=clean_entity_id,
                chunk_content=chunk.content[:2000]
                if chunk.content
                else "",  # Limit content length
            )

            # NOTE(review): asyncio.run per call assumes no event loop is
            # already running in this thread — confirm callers are sync-only.
            response = asyncio.run(self.llm_client.generate_answer(prompt))

            extraction = self._parse_llm_json(
                response, f"entity {entity_id} in chunk {chunk.id}"
            )
            if extraction is None:
                return {}

            # Normalize entity_type to lowercase and validate against presets.
            entity_type = extraction.get("entity_type", "").lower().strip()
            if entity_type not in self.VALID_ENTITY_TYPES:
                if entity_type:  # If LLM provided a type but it's invalid
                    logger.warning(
                        f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, "
                        f"defaulting to 'concept'"
                    )
                entity_type = "concept"

            return {
                "entity_type": entity_type,
                "description": extraction.get("description", ""),
            }
        except Exception as e:
            logger.error(
                f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}"
            )
            return {}

    def _check_entity_type_consistency(
        self, entity_id: str, type_extractions: Dict[str, str]
    ) -> Dict[str, Any]:
        """Check entity type consistency using LLM."""
        if len(set(type_extractions.values())) <= 1:
            # All types are the same, no conflict
            return {"has_conflict": False}

        try:
            type_list = [
                f"Chunk {chunk_id}: {entity_type}"
                for chunk_id, entity_type in type_extractions.items()
                if entity_type
            ]

            prompt = ENTITY_TYPE_CONFLICT_PROMPT.format(
                entity_name=entity_id, type_extractions="\n".join(type_list)
            )

            response = asyncio.run(self.llm_client.generate_answer(prompt))

            result = self._parse_llm_json(response, f"type conflict of {entity_id}")
            if result is None:
                return {"has_conflict": False}
            return result
        except Exception as e:
            logger.error(f"Error checking type consistency for {entity_id}: {e}")
            return {"has_conflict": False}

    def _check_entity_description_consistency(
        self, entity_id: str, descriptions: Dict[str, str]
    ) -> Dict[str, Any]:
        """Check entity description consistency using LLM."""
        # Filter out empty descriptions
        valid_descriptions = {k: v for k, v in descriptions.items() if v}
        if len(valid_descriptions) < 2:
            return {"has_conflict": False}

        if len(set(valid_descriptions.values())) <= 1:
            # All descriptions are the same, no conflict
            return {"has_conflict": False}

        try:
            desc_list = [
                f"Chunk {chunk_id}: {description}"
                for chunk_id, description in valid_descriptions.items()
            ]

            prompt = ENTITY_DESCRIPTION_CONFLICT_PROMPT.format(
                entity_name=entity_id, descriptions="\n".join(desc_list)
            )

            response = asyncio.run(self.llm_client.generate_answer(prompt))

            result = self._parse_llm_json(
                response, f"description conflict of {entity_id}"
            )
            if result is None:
                return {"has_conflict": False}
            return result
        except Exception as e:
            logger.error(f"Error checking description consistency for {entity_id}: {e}")
            return {"has_conflict": False}

    def _check_relation_consistency(
        self, src_id: str, dst_id: str, relation_extractions: Dict[str, str]
    ) -> Dict[str, Any]:
        """Check relation consistency using LLM."""
        if len(set(relation_extractions.values())) <= 1:
            return {"has_conflict": False}

        try:
            rel_list = [
                f"Chunk {chunk_id}: {relation}"
                for chunk_id, relation in relation_extractions.items()
                if relation
            ]

            prompt = RELATION_CONFLICT_PROMPT.format(
                source_entity=src_id,
                target_entity=dst_id,
                relation_descriptions="\n".join(rel_list),
            )

            response = asyncio.run(self.llm_client.generate_answer(prompt))

            result = self._parse_llm_json(
                response, f"relation conflict of {src_id}->{dst_id}"
            )
            if result is None:
                return {"has_conflict": False}
            return result
        except Exception as e:
            logger.error(
                f"Error checking relation consistency for {src_id}->{dst_id}: {e}"
            )
            return {"has_conflict": False}
graphgen/models/evaluator/kg/structure_evaluator.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy import stats
|
| 5 |
+
|
| 6 |
+
from graphgen.bases import BaseGraphStorage
|
| 7 |
+
from graphgen.utils import logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class StructureEvaluator:
    """Evaluates structural robustness of the graph."""

    def __init__(
        self,
        graph_storage: "BaseGraphStorage",
        noise_ratio_threshold: float = 0.15,
        largest_cc_ratio_threshold: float = 0.90,
        avg_degree_min: float = 2.0,
        avg_degree_max: float = 5.0,
        powerlaw_r2_threshold: float = 0.75,
    ):
        self.graph_storage = graph_storage
        self.noise_ratio_threshold = noise_ratio_threshold
        self.largest_cc_ratio_threshold = largest_cc_ratio_threshold
        self.avg_degree_min = avg_degree_min
        self.avg_degree_max = avg_degree_max
        self.powerlaw_r2_threshold = powerlaw_r2_threshold

    def evaluate(self) -> Dict[str, Any]:
        """Evaluate the structural robustness of the graph.

        Returns:
            Dict of structural metrics (node/edge counts, noise ratio, largest
            connected-component ratio, average degree, power-law fit R²) plus
            an "is_robust" flag combining all thresholds; {"error": ...} for an
            empty graph.
        """
        storage = self.graph_storage

        total_nodes = storage.get_node_count()
        if total_nodes == 0:
            return {"error": "Empty graph"}

        total_edges = storage.get_edge_count()
        degree_map = storage.get_all_node_degrees()

        # Noise ratio: isolated nodes / total nodes
        isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0]
        noise_ratio = len(isolated_nodes) / total_nodes

        # Largest connected component
        components = storage.get_connected_components(undirected=True)
        largest_cc_ratio = (
            len(max(components, key=len)) / total_nodes if components else 0
        )

        avg_degree = sum(degree_map.values()) / total_nodes
        powerlaw_r2 = self._calculate_powerlaw_r2(degree_map)

        results = {
            "total_nodes": total_nodes,
            "total_edges": total_edges,
            "noise_ratio": noise_ratio,
            "largest_cc_ratio": largest_cc_ratio,
            "avg_degree": avg_degree,
            "powerlaw_r2": powerlaw_r2,
            "is_robust": (
                noise_ratio < self.noise_ratio_threshold
                and largest_cc_ratio > self.largest_cc_ratio_threshold
                and self.avg_degree_min <= avg_degree <= self.avg_degree_max
                and (
                    powerlaw_r2 is not None and powerlaw_r2 > self.powerlaw_r2_threshold
                )
            ),
        }

        return results

    @staticmethod
    def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]:
        """Fit the degree distribution on a log-log scale and return the fit R².

        Returns None when fewer than 10 non-isolated nodes exist or the fit fails.
        """
        degrees = [deg for deg in degree_map.values() if deg > 0]

        if len(degrees) < 10:
            logger.warning("Insufficient nodes for power law fitting")
            return None

        try:
            # Fit power law: log(y) = a * log(x) + b
            log_degrees = np.log(degrees)
            sorted_log_degrees = np.sort(log_degrees)
            x = np.arange(1, len(sorted_log_degrees) + 1)
            log_x = np.log(x)

            # Linear regression on log-log scale.
            # BUG FIX: scipy's linregress returns (slope, intercept, rvalue,
            # pvalue, stderr); the previous `r_value, *_` unpacking grabbed the
            # SLOPE, so the reported "R²" was actually slope². Use the named
            # rvalue field instead.
            fit = stats.linregress(log_x, sorted_log_degrees)
            return float(fit.rvalue**2)
        except Exception as e:
            logger.error(f"Power law R² calculation failed: {e}")
            return None
graphgen/models/evaluator/length_evaluator.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
from graphgen.bases.datatypes import QAPair
|
| 2 |
-
from graphgen.models.evaluator.base_evaluator import BaseEvaluator
|
| 3 |
-
from graphgen.models.tokenizer import Tokenizer
|
| 4 |
-
from graphgen.utils import create_event_loop
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class LengthEvaluator(BaseEvaluator):
|
| 8 |
-
def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
|
| 9 |
-
super().__init__(max_concurrent)
|
| 10 |
-
self.tokenizer_name = tokenizer_name
|
| 11 |
-
self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
|
| 12 |
-
|
| 13 |
-
async def evaluate_single(self, pair: QAPair) -> float:
|
| 14 |
-
loop = create_event_loop()
|
| 15 |
-
return await loop.run_in_executor(None, self._calculate_length, pair.answer)
|
| 16 |
-
|
| 17 |
-
def _calculate_length(self, text: str) -> float:
|
| 18 |
-
tokens = self.tokenizer.encode(text)
|
| 19 |
-
return len(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/evaluator/qa/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .length_evaluator import LengthEvaluator
|
| 2 |
+
from .mtld_evaluator import MTLDEvaluator
|
| 3 |
+
from .reward_evaluator import RewardEvaluator
|
| 4 |
+
from .uni_evaluator import UniEvaluator
|
graphgen/models/evaluator/qa/length_evaluator.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
from graphgen.bases import BaseEvaluator, QAPair
|
| 4 |
+
from graphgen.models.tokenizer import Tokenizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LengthEvaluator(BaseEvaluator):
    """Scores a QA pair by its token count."""

    def __init__(self, tokenizer_name: str = None):
        # Resolution order: explicit argument, then the TOKENIZER_MODEL
        # environment variable, then the cl100k_base default.
        if tokenizer_name:
            model = tokenizer_name
        else:
            model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        self.tokenizer: Tokenizer = Tokenizer(model)

    def evaluate(self, pair: QAPair) -> float:
        """Return the combined token length of the pair's question and answer."""
        return len(self.tokenizer.encode(pair.question + pair.answer))
|
graphgen/models/evaluator/{mtld_evaluator.py → qa/mtld_evaluator.py}
RENAMED
|
@@ -1,38 +1,33 @@
|
|
| 1 |
from typing import Set
|
| 2 |
|
| 3 |
-
from graphgen.bases
|
| 4 |
-
from graphgen.
|
| 5 |
-
from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
|
| 6 |
-
|
| 7 |
-
nltk_helper = NLTKHelper()
|
| 8 |
|
| 9 |
|
| 10 |
class MTLDEvaluator(BaseEvaluator):
|
| 11 |
"""
|
| 12 |
-
|
| 13 |
"""
|
| 14 |
|
| 15 |
-
def __init__(self,
|
| 16 |
-
|
| 17 |
-
self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("
|
| 18 |
-
self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("
|
| 19 |
-
|
| 20 |
-
async def evaluate_single(self, pair: QAPair) -> float:
|
| 21 |
-
loop = create_event_loop()
|
| 22 |
-
return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
|
| 23 |
|
| 24 |
-
def
|
| 25 |
"""
|
| 26 |
-
|
| 27 |
|
| 28 |
min is 1.0
|
| 29 |
higher is better
|
| 30 |
"""
|
|
|
|
| 31 |
if not text or not text.strip():
|
| 32 |
return 0.0
|
| 33 |
|
| 34 |
lang = detect_main_language(text)
|
| 35 |
-
tokens = nltk_helper.word_tokenize(text, lang)
|
| 36 |
|
| 37 |
stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
|
| 38 |
filtered_tokens = [word for word in tokens if word not in stopwords]
|
|
@@ -41,13 +36,13 @@ class MTLDEvaluator(BaseEvaluator):
|
|
| 41 |
if not filtered_tokens:
|
| 42 |
return 0
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
forward_factors = self._compute_factors(filtered_tokens, threshold)
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
|
| 49 |
|
| 50 |
-
#
|
| 51 |
return (forward_factors + backward_factors) / 2
|
| 52 |
|
| 53 |
@staticmethod
|
|
@@ -66,7 +61,7 @@ class MTLDEvaluator(BaseEvaluator):
|
|
| 66 |
current_segment = []
|
| 67 |
unique_words = set()
|
| 68 |
|
| 69 |
-
#
|
| 70 |
if current_segment:
|
| 71 |
ttr = len(unique_words) / len(current_segment)
|
| 72 |
if ttr <= threshold:
|
|
|
|
| 1 |
from typing import Set
|
| 2 |
|
| 3 |
+
from graphgen.bases import BaseEvaluator, QAPair
|
| 4 |
+
from graphgen.utils import NLTKHelper, detect_main_language
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class MTLDEvaluator(BaseEvaluator):
|
| 8 |
"""
|
| 9 |
+
Metrics for measuring the lexical diversity of text.
|
| 10 |
"""
|
| 11 |
|
| 12 |
+
def __init__(self, threshold: float = 0.72):
|
| 13 |
+
self.nltk_helper = NLTKHelper()
|
| 14 |
+
self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en"))
|
| 15 |
+
self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
|
| 16 |
+
self.threshold = threshold
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
def evaluate(self, pair: QAPair) -> float:
|
| 19 |
"""
|
| 20 |
+
Calculate the MTLD (Mean Token Length Diversity) score for a given text.
|
| 21 |
|
| 22 |
min is 1.0
|
| 23 |
higher is better
|
| 24 |
"""
|
| 25 |
+
text = pair.answer
|
| 26 |
if not text or not text.strip():
|
| 27 |
return 0.0
|
| 28 |
|
| 29 |
lang = detect_main_language(text)
|
| 30 |
+
tokens = self.nltk_helper.word_tokenize(text, lang)
|
| 31 |
|
| 32 |
stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
|
| 33 |
filtered_tokens = [word for word in tokens if word not in stopwords]
|
|
|
|
| 36 |
if not filtered_tokens:
|
| 37 |
return 0
|
| 38 |
|
| 39 |
+
# Compute forward factors
|
| 40 |
+
forward_factors = self._compute_factors(filtered_tokens, self.threshold)
|
| 41 |
|
| 42 |
+
# Compute backward factors
|
| 43 |
+
backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
|
| 44 |
|
| 45 |
+
# Compute average factors
|
| 46 |
return (forward_factors + backward_factors) / 2
|
| 47 |
|
| 48 |
@staticmethod
|
|
|
|
| 61 |
current_segment = []
|
| 62 |
unique_words = set()
|
| 63 |
|
| 64 |
+
# handle last segment
|
| 65 |
if current_segment:
|
| 66 |
ttr = len(unique_words) / len(current_segment)
|
| 67 |
if ttr <= threshold:
|
graphgen/models/evaluator/qa/reward_evaluator.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from graphgen.bases import BaseEvaluator, QAPair
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RewardEvaluator(BaseEvaluator):
    """
    Reward Model Evaluator for single QAPair evaluation.
    """

    def __init__(
        self,
        reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
        max_length: int = 2560,
        device: Optional[str] = None,
    ):
        """
        Initialize the reward evaluator.

        Args:
            reward_name: Model name or path on HuggingFace Hub
            max_length: Maximum token length for the model
            device: Device to run the model on. If None, auto-detect CUDA/CPU.
        """
        self.reward_name = reward_name
        self.max_length = max_length

        # Deferred imports: torch/transformers are only required when this
        # evaluator is actually instantiated.
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.torch = torch

        # Choose device: explicit argument wins, otherwise auto-detect.
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(reward_name)
            model = AutoModelForSequenceClassification.from_pretrained(reward_name)
            model.to(self.device)
            model.eval()  # inference mode: disable dropout etc.
            self.model = model
        except Exception as e:
            raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e

    def evaluate(self, pair: QAPair) -> float:
        """
        Evaluate a single question-answer pair using the reward model.

        Args:
            pair: QAPair containing question and answer strings

        Returns:
            Score as a float
        """
        # Tokenize question/answer as a sequence pair, truncated to max_length.
        encoded = self.tokenizer(
            pair.question,
            pair.answer,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        )
        batch = {key: tensor.to(self.device) for key, tensor in encoded.items()}

        # Forward pass without gradient tracking; the scalar logit is the score.
        with self.torch.no_grad():
            return self.model(**batch).logits[0].item()
graphgen/models/evaluator/qa/uni_evaluator.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/maszhongming/UniEval/tree/main
from typing import List, Optional

from graphgen.bases import BaseEvaluator, QAPair


class UniEvaluator(BaseEvaluator):
    """
    UniEvaluator for single QAPair evaluation across quality dimensions.

    Dimensions: naturalness, coherence, understandability

    Usage:
        evaluator = UniEvaluator()
        pair = QAPair(question="...", answer="...")
        scores = evaluator.evaluate(pair)
        # {"naturalness": 0.85, "coherence": 0.92, "understandability": 0.88}
    """

    DEFAULT_MODEL: str = "MingZhong/unieval-sum"
    DEFAULT_DIMS: List[str] = ["naturalness", "coherence", "understandability"]
    DEFAULT_MAX_LENGTH: int = 2560

    def __init__(
        self,
        model_name: Optional[str] = None,
        max_length: Optional[int] = None,
        device: Optional[str] = None,
    ):
        """
        Args:
            model_name: HuggingFace model name/path
            max_length: Tokenizer max sequence length
            device: 'cuda', 'cpu', or None for auto-detect
        """
        # Lazy imports keep torch/transformers optional at module import time.
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        self.torch = torch

        self.model_name = model_name or self.DEFAULT_MODEL
        self.max_length = max_length or self.DEFAULT_MAX_LENGTH
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load model & tokenizer once; the model stays in eval mode.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.model.to(self.device)
        self.model.eval()

        # Pre-compute the token ids of the "Yes"/"No" verdicts the model emits.
        self._yes_id = self.tokenizer("Yes")["input_ids"][0]
        self._no_id = self.tokenizer("No")["input_ids"][0]

    @staticmethod
    def _build_input_text(dimension: str, question: str, answer: str) -> str:
        """Construct input text for specified dimension."""
        if dimension == "naturalness":
            return f"question: Is this a natural response? </s> response: {answer}"
        if dimension == "coherence":
            return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
        if dimension == "understandability":
            return f"question: Is this an understandable response? </s> response: {answer}"
        raise NotImplementedError(f"Unsupported dimension '{dimension}'")

    def evaluate(
        self,
        pair: QAPair,
        dimensions: Optional[List[str]] = None,
    ) -> dict[str, float]:
        """Evaluate a single QAPair across specified dimensions."""
        dims = dimensions or self.DEFAULT_DIMS

        # Reject dimensions we have no prompt template for.
        invalid = set(dims) - set(self.DEFAULT_DIMS)
        if invalid:
            raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}")

        no_label = self.torch.tensor([[self._no_id]], device=self.device)
        scores: dict[str, float] = {}

        for dimension in dims:
            encoded = self.tokenizer(
                self._build_input_text(dimension, pair.question, pair.answer),
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            # Force a single decoder step and read its first-step logits.
            with self.torch.no_grad():
                step_logits = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=no_label,
                    use_cache=False,
                ).logits[:, 0, :]  # [1, vocab_size]

            probs = self.torch.softmax(step_logits, dim=-1)[0]
            yes_p = probs[self._yes_id]
            no_p = probs[self._no_id]
            scores[dimension] = (yes_p / (yes_p + no_p)).item()

        return scores
|
graphgen/models/evaluator/reward_evaluator.py
DELETED
|
@@ -1,107 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
-
from tqdm import tqdm
|
| 4 |
-
|
| 5 |
-
from graphgen.bases.datatypes import QAPair
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
@dataclass
|
| 9 |
-
class RewardEvaluator:
|
| 10 |
-
"""
|
| 11 |
-
Reward Model Evaluator.
|
| 12 |
-
OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
|
| 16 |
-
max_length: int = 2560
|
| 17 |
-
results: list[float] = None
|
| 18 |
-
|
| 19 |
-
def __post_init__(self):
|
| 20 |
-
import torch
|
| 21 |
-
|
| 22 |
-
self.num_gpus = torch.cuda.device_count()
|
| 23 |
-
|
| 24 |
-
@staticmethod
|
| 25 |
-
def process_chunk(rank, pairs, reward_name, max_length, return_dict):
|
| 26 |
-
import torch
|
| 27 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 28 |
-
|
| 29 |
-
device = f"cuda:{rank}"
|
| 30 |
-
torch.cuda.set_device(rank)
|
| 31 |
-
|
| 32 |
-
rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
|
| 33 |
-
tokenizer = AutoTokenizer.from_pretrained(reward_name)
|
| 34 |
-
rank_model.to(device)
|
| 35 |
-
rank_model.eval()
|
| 36 |
-
|
| 37 |
-
results = []
|
| 38 |
-
with torch.no_grad():
|
| 39 |
-
for pair in tqdm(pairs):
|
| 40 |
-
inputs = tokenizer(
|
| 41 |
-
pair.question,
|
| 42 |
-
pair.answer,
|
| 43 |
-
return_tensors="pt",
|
| 44 |
-
max_length=max_length,
|
| 45 |
-
truncation=True,
|
| 46 |
-
)
|
| 47 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 48 |
-
score = rank_model(**inputs).logits[0].item()
|
| 49 |
-
results.append(score)
|
| 50 |
-
|
| 51 |
-
return_dict[rank] = results
|
| 52 |
-
|
| 53 |
-
def evaluate(self, pairs: list[QAPair]) -> list[float]:
|
| 54 |
-
import torch.multiprocessing as mp
|
| 55 |
-
|
| 56 |
-
chunk_size = len(pairs) // self.num_gpus
|
| 57 |
-
chunks = []
|
| 58 |
-
for i in range(self.num_gpus):
|
| 59 |
-
start = i * chunk_size
|
| 60 |
-
end = start + chunk_size
|
| 61 |
-
if i == self.num_gpus - 1:
|
| 62 |
-
end = len(pairs)
|
| 63 |
-
chunks.append(pairs[start:end])
|
| 64 |
-
|
| 65 |
-
# multi-process
|
| 66 |
-
manager = mp.Manager()
|
| 67 |
-
return_dict = manager.dict()
|
| 68 |
-
processes = []
|
| 69 |
-
|
| 70 |
-
for rank, chunk in enumerate(chunks):
|
| 71 |
-
p = mp.Process(
|
| 72 |
-
target=self.process_chunk,
|
| 73 |
-
args=(rank, chunk, self.reward_name, self.max_length, return_dict),
|
| 74 |
-
)
|
| 75 |
-
p.start()
|
| 76 |
-
processes.append(p)
|
| 77 |
-
|
| 78 |
-
for p in processes:
|
| 79 |
-
p.join()
|
| 80 |
-
|
| 81 |
-
# 合并结果
|
| 82 |
-
results = []
|
| 83 |
-
for rank in range(len(chunks)):
|
| 84 |
-
results.extend(return_dict[rank])
|
| 85 |
-
|
| 86 |
-
for p in processes:
|
| 87 |
-
if p.is_alive():
|
| 88 |
-
p.terminate()
|
| 89 |
-
p.join()
|
| 90 |
-
|
| 91 |
-
return results
|
| 92 |
-
|
| 93 |
-
def get_average_score(self, pairs: list[QAPair]) -> float:
|
| 94 |
-
"""
|
| 95 |
-
Get the average score of a batch of texts.
|
| 96 |
-
"""
|
| 97 |
-
results = self.evaluate(pairs)
|
| 98 |
-
self.results = results
|
| 99 |
-
return sum(self.results) / len(pairs)
|
| 100 |
-
|
| 101 |
-
def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
|
| 102 |
-
"""
|
| 103 |
-
Get the min and max score of a batch of texts.
|
| 104 |
-
"""
|
| 105 |
-
if self.results is None:
|
| 106 |
-
self.get_average_score(pairs)
|
| 107 |
-
return min(self.results), max(self.results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/evaluator/uni_evaluator.py
DELETED
|
@@ -1,183 +0,0 @@
|
|
| 1 |
-
# https://github.com/maszhongming/UniEval/tree/main
|
| 2 |
-
|
| 3 |
-
from dataclasses import dataclass, field
|
| 4 |
-
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
|
| 7 |
-
from graphgen.bases.datatypes import QAPair
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def _add_questions(dimension: str, question: str, answer: str):
|
| 11 |
-
if dimension == "naturalness":
|
| 12 |
-
cur_input = (
|
| 13 |
-
"question: Is this a natural response in the dialogue? </s> response: "
|
| 14 |
-
+ answer
|
| 15 |
-
)
|
| 16 |
-
elif dimension == "coherence":
|
| 17 |
-
cur_input = (
|
| 18 |
-
"question: Is this a coherent response given the dialogue history? </s> response: "
|
| 19 |
-
+ answer
|
| 20 |
-
+ " </s> dialogue history: "
|
| 21 |
-
+ question
|
| 22 |
-
)
|
| 23 |
-
elif dimension == "understandability":
|
| 24 |
-
cur_input = (
|
| 25 |
-
"question: Is this an understandable response in the dialogue? </s> response: "
|
| 26 |
-
+ answer
|
| 27 |
-
)
|
| 28 |
-
else:
|
| 29 |
-
raise NotImplementedError(
|
| 30 |
-
"The input format for this dimension is still undefined. Please customize it first."
|
| 31 |
-
)
|
| 32 |
-
return cur_input
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
@dataclass
|
| 36 |
-
class UniEvaluator:
|
| 37 |
-
model_name: str = "MingZhong/unieval-sum"
|
| 38 |
-
dimensions: list = field(
|
| 39 |
-
default_factory=lambda: ["naturalness", "coherence", "understandability"]
|
| 40 |
-
)
|
| 41 |
-
max_length: int = 2560
|
| 42 |
-
results: dict = None
|
| 43 |
-
|
| 44 |
-
def __post_init__(self):
|
| 45 |
-
import torch
|
| 46 |
-
|
| 47 |
-
self.num_gpus = torch.cuda.device_count()
|
| 48 |
-
self.results = {}
|
| 49 |
-
|
| 50 |
-
@staticmethod
|
| 51 |
-
def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
|
| 52 |
-
import torch
|
| 53 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 54 |
-
|
| 55 |
-
device = f"cuda:{rank}"
|
| 56 |
-
torch.cuda.set_device(rank)
|
| 57 |
-
|
| 58 |
-
rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 59 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 60 |
-
rank_model.to(device)
|
| 61 |
-
rank_model.eval()
|
| 62 |
-
|
| 63 |
-
softmax = torch.nn.Softmax(dim=1)
|
| 64 |
-
|
| 65 |
-
pos_id = tokenizer("Yes")["input_ids"][0]
|
| 66 |
-
neg_id = tokenizer("No")["input_ids"][0]
|
| 67 |
-
|
| 68 |
-
results = []
|
| 69 |
-
with torch.no_grad():
|
| 70 |
-
for pair in tqdm(pairs):
|
| 71 |
-
text = _add_questions(dimension, pair.question, pair.answer)
|
| 72 |
-
|
| 73 |
-
tgt = "No"
|
| 74 |
-
|
| 75 |
-
encoded_src = tokenizer(
|
| 76 |
-
text,
|
| 77 |
-
max_length=max_length,
|
| 78 |
-
truncation=True,
|
| 79 |
-
padding=True,
|
| 80 |
-
return_tensors="pt",
|
| 81 |
-
)
|
| 82 |
-
encoded_tgt = tokenizer(
|
| 83 |
-
tgt,
|
| 84 |
-
max_length=max_length,
|
| 85 |
-
truncation=True,
|
| 86 |
-
padding=True,
|
| 87 |
-
return_tensors="pt",
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
src_tokens = encoded_src["input_ids"].to(device)
|
| 91 |
-
src_mask = encoded_src["attention_mask"].to(device)
|
| 92 |
-
|
| 93 |
-
tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1)
|
| 94 |
-
|
| 95 |
-
output = rank_model(
|
| 96 |
-
input_ids=src_tokens,
|
| 97 |
-
attention_mask=src_mask,
|
| 98 |
-
labels=tgt_tokens,
|
| 99 |
-
use_cache=False,
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
logits = output.logits.view(-1, rank_model.config.vocab_size)
|
| 103 |
-
|
| 104 |
-
pos_score = softmax(logits)[:, pos_id] # Yes
|
| 105 |
-
neg_score = softmax(logits)[:, neg_id]
|
| 106 |
-
score = pos_score / (pos_score + neg_score)
|
| 107 |
-
|
| 108 |
-
results.append(score.item())
|
| 109 |
-
|
| 110 |
-
return_dict[rank] = results
|
| 111 |
-
|
| 112 |
-
def evaluate(self, pairs: list[QAPair]) -> list[dict]:
|
| 113 |
-
import torch.multiprocessing as mp
|
| 114 |
-
|
| 115 |
-
final_results = []
|
| 116 |
-
for dimension in self.dimensions:
|
| 117 |
-
chunk_size = len(pairs) // self.num_gpus
|
| 118 |
-
chunks = []
|
| 119 |
-
for i in range(self.num_gpus):
|
| 120 |
-
start = i * chunk_size
|
| 121 |
-
end = start + chunk_size
|
| 122 |
-
if i == self.num_gpus - 1:
|
| 123 |
-
end = len(pairs)
|
| 124 |
-
chunks.append(pairs[start:end])
|
| 125 |
-
|
| 126 |
-
# multi-process
|
| 127 |
-
manager = mp.Manager()
|
| 128 |
-
return_dict = manager.dict()
|
| 129 |
-
processes = []
|
| 130 |
-
|
| 131 |
-
for rank, chunk in enumerate(chunks):
|
| 132 |
-
p = mp.Process(
|
| 133 |
-
target=self.process_chunk,
|
| 134 |
-
args=(
|
| 135 |
-
rank,
|
| 136 |
-
chunk,
|
| 137 |
-
self.model_name,
|
| 138 |
-
self.max_length,
|
| 139 |
-
dimension,
|
| 140 |
-
return_dict,
|
| 141 |
-
),
|
| 142 |
-
)
|
| 143 |
-
p.start()
|
| 144 |
-
processes.append(p)
|
| 145 |
-
|
| 146 |
-
for p in processes:
|
| 147 |
-
p.join()
|
| 148 |
-
|
| 149 |
-
# 合并结果
|
| 150 |
-
results = []
|
| 151 |
-
for rank in range(len(chunks)):
|
| 152 |
-
results.extend(return_dict[rank])
|
| 153 |
-
|
| 154 |
-
for p in processes:
|
| 155 |
-
if p.is_alive():
|
| 156 |
-
p.terminate()
|
| 157 |
-
p.join()
|
| 158 |
-
|
| 159 |
-
final_results.append({dimension: results})
|
| 160 |
-
return final_results
|
| 161 |
-
|
| 162 |
-
def get_average_score(self, pairs: list[QAPair]) -> dict:
|
| 163 |
-
"""
|
| 164 |
-
Get the average score of a batch of texts.
|
| 165 |
-
"""
|
| 166 |
-
results = self.evaluate(pairs)
|
| 167 |
-
final_results = {}
|
| 168 |
-
for result in results:
|
| 169 |
-
for key, value in result.items():
|
| 170 |
-
final_results[key] = sum(value) / len(value)
|
| 171 |
-
self.results[key] = value
|
| 172 |
-
return final_results
|
| 173 |
-
|
| 174 |
-
def get_min_max_score(self, pairs: list[QAPair]) -> dict:
|
| 175 |
-
"""
|
| 176 |
-
Get the min and max score of a batch of texts.
|
| 177 |
-
"""
|
| 178 |
-
if self.results is None:
|
| 179 |
-
self.get_average_score(pairs)
|
| 180 |
-
final_results = {}
|
| 181 |
-
for key, value in self.results.items():
|
| 182 |
-
final_results[key] = min(value), max(value)
|
| 183 |
-
return final_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/models/storage/graph/kuzu_storage.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
-
from typing import Any
|
| 5 |
|
| 6 |
try:
|
| 7 |
import kuzu
|
|
@@ -78,6 +79,94 @@ class KuzuStorage(BaseGraphStorage):
|
|
| 78 |
print(f"Error decoding JSON: {e}")
|
| 79 |
return {}
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def has_node(self, node_id: str) -> bool:
|
| 82 |
result = self._conn.execute(
|
| 83 |
"MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, List, Set
|
| 6 |
|
| 7 |
try:
|
| 8 |
import kuzu
|
|
|
|
| 79 |
print(f"Error decoding JSON: {e}")
|
| 80 |
return {}
|
| 81 |
|
| 82 |
+
def is_directed(self) -> bool:
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
def get_all_node_degrees(self) -> Dict[str, int]:
|
| 86 |
+
query = """
|
| 87 |
+
MATCH (n:Entity)
|
| 88 |
+
OPTIONAL MATCH (n)-[r]-()
|
| 89 |
+
RETURN n.id, count(r) as degree
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
result = self._conn.execute(query)
|
| 93 |
+
degree_map = {}
|
| 94 |
+
while result.has_next():
|
| 95 |
+
row = result.get_next()
|
| 96 |
+
if row and len(row) >= 2:
|
| 97 |
+
node_id, degree = row[0], row[1]
|
| 98 |
+
degree_map[node_id] = int(degree)
|
| 99 |
+
|
| 100 |
+
return degree_map
|
| 101 |
+
|
| 102 |
+
def get_isolated_nodes(self) -> List[str]:
|
| 103 |
+
query = """
|
| 104 |
+
MATCH (n:Entity)
|
| 105 |
+
WHERE NOT (n)--()
|
| 106 |
+
RETURN n.id
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
result = self._conn.execute(query)
|
| 110 |
+
return [row[0] for row in result if row]
|
| 111 |
+
|
| 112 |
+
def get_node_count(self) -> int:
|
| 113 |
+
result = self._conn.execute("MATCH (n:Entity) RETURN count(n)")
|
| 114 |
+
return result.get_next()[0]
|
| 115 |
+
|
| 116 |
+
def get_edge_count(self) -> int:
|
| 117 |
+
result = self._conn.execute("MATCH ()-[e:Relation]->() RETURN count(e)")
|
| 118 |
+
return result.get_next()[0]
|
| 119 |
+
|
| 120 |
+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
|
| 121 |
+
parent = {}
|
| 122 |
+
rank = {}
|
| 123 |
+
|
| 124 |
+
def find(x: str) -> str:
|
| 125 |
+
if parent[x] != x:
|
| 126 |
+
parent[x] = find(parent[x])
|
| 127 |
+
return parent[x]
|
| 128 |
+
|
| 129 |
+
def union(x: str, y: str):
|
| 130 |
+
root_x, root_y = find(x), find(y)
|
| 131 |
+
if root_x == root_y:
|
| 132 |
+
return
|
| 133 |
+
if rank[root_x] < rank[root_y]:
|
| 134 |
+
parent[root_x] = root_y
|
| 135 |
+
elif rank[root_x] > rank[root_y]:
|
| 136 |
+
parent[root_y] = root_x
|
| 137 |
+
else:
|
| 138 |
+
parent[root_y] = root_x
|
| 139 |
+
rank[root_x] += 1
|
| 140 |
+
|
| 141 |
+
all_nodes = self.get_all_node_degrees().keys()
|
| 142 |
+
for node_id in all_nodes:
|
| 143 |
+
parent[node_id] = node_id
|
| 144 |
+
rank[node_id] = 0
|
| 145 |
+
|
| 146 |
+
query = (
|
| 147 |
+
"""
|
| 148 |
+
MATCH (a:Entity)-[e:Relation]-(b:Entity)
|
| 149 |
+
RETURN DISTINCT a.id, b.id
|
| 150 |
+
"""
|
| 151 |
+
if undirected
|
| 152 |
+
else """
|
| 153 |
+
MATCH (a:Entity)-[e:Relation]->(b:Entity)
|
| 154 |
+
RETURN DISTINCT a.id, b.id
|
| 155 |
+
"""
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
result = self._conn.execute(query)
|
| 159 |
+
for row in result:
|
| 160 |
+
if row and len(row) >= 2:
|
| 161 |
+
union(row[0], row[1])
|
| 162 |
+
|
| 163 |
+
components_dict = defaultdict(set)
|
| 164 |
+
for node_id in all_nodes:
|
| 165 |
+
root = find(node_id)
|
| 166 |
+
components_dict[root].add(node_id)
|
| 167 |
+
|
| 168 |
+
return list(components_dict.values())
|
| 169 |
+
|
| 170 |
def has_node(self, node_id: str) -> bool:
|
| 171 |
result = self._conn.execute(
|
| 172 |
"MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
|
graphgen/models/storage/graph/networkx_storage.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import html
|
| 2 |
import os
|
| 3 |
from dataclasses import dataclass
|
| 4 |
-
from typing import Any, Optional, Union, cast
|
| 5 |
|
| 6 |
import networkx as nx
|
| 7 |
|
|
@@ -10,6 +10,31 @@ from graphgen.bases.base_storage import BaseGraphStorage
|
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class NetworkXStorage(BaseGraphStorage):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
@staticmethod
|
| 14 |
def load_nx_graph(file_name) -> Optional[nx.Graph]:
|
| 15 |
if os.path.exists(file_name):
|
|
|
|
| 1 |
import html
|
| 2 |
import os
|
| 3 |
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Set, Union, cast
|
| 5 |
|
| 6 |
import networkx as nx
|
| 7 |
|
|
|
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class NetworkXStorage(BaseGraphStorage):
|
| 13 |
+
def is_directed(self) -> bool:
|
| 14 |
+
return self._graph.is_directed()
|
| 15 |
+
|
| 16 |
+
def get_all_node_degrees(self) -> Dict[str, int]:
|
| 17 |
+
return {
|
| 18 |
+
str(node_id): int(self._graph.degree[node_id])
|
| 19 |
+
for node_id in self._graph.nodes()
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
def get_node_count(self) -> int:
|
| 23 |
+
return self._graph.number_of_nodes()
|
| 24 |
+
|
| 25 |
+
def get_edge_count(self) -> int:
|
| 26 |
+
return self._graph.number_of_edges()
|
| 27 |
+
|
| 28 |
+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
|
| 29 |
+
graph = self._graph
|
| 30 |
+
|
| 31 |
+
if undirected and graph.is_directed():
|
| 32 |
+
graph = graph.to_undirected()
|
| 33 |
+
|
| 34 |
+
return [
|
| 35 |
+
set(str(node) for node in comp) for comp in nx.connected_components(graph)
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
@staticmethod
|
| 39 |
def load_nx_graph(file_name) -> Optional[nx.Graph]:
|
| 40 |
if os.path.exists(file_name):
|
graphgen/operators/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from .build_kg import BuildKGService
|
| 2 |
from .chunk import ChunkService
|
|
|
|
| 3 |
from .extract import ExtractService
|
| 4 |
from .generate import GenerateService
|
| 5 |
from .judge import JudgeService
|
|
@@ -8,6 +9,7 @@ from .quiz import QuizService
|
|
| 8 |
from .read import read
|
| 9 |
from .search import SearchService
|
| 10 |
|
|
|
|
| 11 |
operators = {
|
| 12 |
"read": read,
|
| 13 |
"chunk": ChunkService,
|
|
@@ -18,4 +20,5 @@ operators = {
|
|
| 18 |
"search": SearchService,
|
| 19 |
"partition": PartitionService,
|
| 20 |
"generate": GenerateService,
|
|
|
|
| 21 |
}
|
|
|
|
| 1 |
from .build_kg import BuildKGService
|
| 2 |
from .chunk import ChunkService
|
| 3 |
+
from .evaluate import EvaluateService
|
| 4 |
from .extract import ExtractService
|
| 5 |
from .generate import GenerateService
|
| 6 |
from .judge import JudgeService
|
|
|
|
| 9 |
from .read import read
|
| 10 |
from .search import SearchService
|
| 11 |
|
| 12 |
+
|
| 13 |
operators = {
|
| 14 |
"read": read,
|
| 15 |
"chunk": ChunkService,
|
|
|
|
| 20 |
"search": SearchService,
|
| 21 |
"partition": PartitionService,
|
| 22 |
"generate": GenerateService,
|
| 23 |
+
"evaluate": EvaluateService,
|
| 24 |
}
|
graphgen/operators/evaluate/__init__.py
CHANGED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluate_service import EvaluateService
|
| 2 |
+
|
| 3 |
+
__all__ = ["EvaluateService"]
|
graphgen/operators/evaluate/evaluate.py
DELETED
|
@@ -1,177 +0,0 @@
|
|
| 1 |
-
# TODO: this module needs refactoring to merge into GraphGen framework
|
| 2 |
-
"""Evaluate the quality of the generated text using various metrics"""
|
| 3 |
-
|
| 4 |
-
import argparse
|
| 5 |
-
import json
|
| 6 |
-
import os
|
| 7 |
-
|
| 8 |
-
import pandas as pd
|
| 9 |
-
from dotenv import load_dotenv
|
| 10 |
-
|
| 11 |
-
from graphgen.bases.datatypes import QAPair
|
| 12 |
-
from graphgen.models import (
|
| 13 |
-
LengthEvaluator,
|
| 14 |
-
MTLDEvaluator,
|
| 15 |
-
RewardEvaluator,
|
| 16 |
-
UniEvaluator,
|
| 17 |
-
)
|
| 18 |
-
from graphgen.utils import logger, set_logger
|
| 19 |
-
|
| 20 |
-
sys_path = os.path.abspath(os.path.dirname(__file__))
|
| 21 |
-
set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
|
| 22 |
-
|
| 23 |
-
load_dotenv()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def evaluate_length(corpus, tokenizer_name):
|
| 27 |
-
length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name)
|
| 28 |
-
logger.info("Length evaluator loaded")
|
| 29 |
-
scores = length_evaluator.get_average_score(corpus)
|
| 30 |
-
logger.info("Length scores: %s", scores)
|
| 31 |
-
return scores
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def evaluate_mtld(corpus):
|
| 35 |
-
mtld_evaluator = MTLDEvaluator()
|
| 36 |
-
logger.info("MTLD evaluator loaded")
|
| 37 |
-
scores = mtld_evaluator.get_average_score(corpus)
|
| 38 |
-
logger.info("MTLD scores: %s", scores)
|
| 39 |
-
min_max_scores = mtld_evaluator.get_min_max_score(corpus)
|
| 40 |
-
logger.info("MTLD min max scores: %s", min_max_scores)
|
| 41 |
-
return scores, min_max_scores
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def evaluate_reward(corpus, reward_model_names):
|
| 45 |
-
scores = []
|
| 46 |
-
for reward_name in reward_model_names:
|
| 47 |
-
reward_evaluator = RewardEvaluator(reward_name=reward_name)
|
| 48 |
-
logger.info("Loaded reward model: %s", reward_name)
|
| 49 |
-
average_score = reward_evaluator.get_average_score(corpus)
|
| 50 |
-
logger.info("%s scores: %s", reward_name, average_score)
|
| 51 |
-
min_max_scores = reward_evaluator.get_min_max_score(corpus)
|
| 52 |
-
logger.info("%s min max scores: %s", reward_name, min_max_scores)
|
| 53 |
-
scores.append(
|
| 54 |
-
{
|
| 55 |
-
"reward_name": reward_name.split("/")[-1],
|
| 56 |
-
"score": average_score,
|
| 57 |
-
"min_max_scores": min_max_scores,
|
| 58 |
-
}
|
| 59 |
-
)
|
| 60 |
-
del reward_evaluator
|
| 61 |
-
clean_gpu_cache()
|
| 62 |
-
return scores
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def evaluate_uni(corpus, uni_model_name):
|
| 66 |
-
uni_evaluator = UniEvaluator(model_name=uni_model_name)
|
| 67 |
-
logger.info("Uni evaluator loaded with model %s", uni_model_name)
|
| 68 |
-
uni_scores = uni_evaluator.get_average_score(corpus)
|
| 69 |
-
for key, value in uni_scores.items():
|
| 70 |
-
logger.info("Uni %s scores: %s", key, value)
|
| 71 |
-
min_max_scores = uni_evaluator.get_min_max_score(corpus)
|
| 72 |
-
for key, value in min_max_scores.items():
|
| 73 |
-
logger.info("Uni %s min max scores: %s", key, value)
|
| 74 |
-
del uni_evaluator
|
| 75 |
-
clean_gpu_cache()
|
| 76 |
-
return (
|
| 77 |
-
uni_scores["naturalness"],
|
| 78 |
-
uni_scores["coherence"],
|
| 79 |
-
uni_scores["understandability"],
|
| 80 |
-
min_max_scores["naturalness"],
|
| 81 |
-
min_max_scores["coherence"],
|
| 82 |
-
min_max_scores["understandability"],
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def clean_gpu_cache():
|
| 87 |
-
import torch
|
| 88 |
-
|
| 89 |
-
if torch.cuda.is_available():
|
| 90 |
-
torch.cuda.empty_cache()
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
if __name__ == "__main__":
|
| 94 |
-
import torch.multiprocessing as mp
|
| 95 |
-
|
| 96 |
-
parser = argparse.ArgumentParser()
|
| 97 |
-
|
| 98 |
-
parser.add_argument(
|
| 99 |
-
"--folder", type=str, default="cache/data", help="folder to load data"
|
| 100 |
-
)
|
| 101 |
-
parser.add_argument(
|
| 102 |
-
"--output", type=str, default="cache/output", help="path to save output"
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
parser.add_argument(
|
| 106 |
-
"--tokenizer", type=str, default="cl100k_base", help="tokenizer name"
|
| 107 |
-
)
|
| 108 |
-
parser.add_argument(
|
| 109 |
-
"--reward",
|
| 110 |
-
type=str,
|
| 111 |
-
default="OpenAssistant/reward-model-deberta-v3-large-v2",
|
| 112 |
-
help="Comma-separated list of reward models",
|
| 113 |
-
)
|
| 114 |
-
parser.add_argument(
|
| 115 |
-
"--uni", type=str, default="MingZhong/unieval-sum", help="uni model name"
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
args = parser.parse_args()
|
| 119 |
-
|
| 120 |
-
if not os.path.exists(args.folder):
|
| 121 |
-
raise ValueError(f"Folder {args.folder} does not exist")
|
| 122 |
-
|
| 123 |
-
if not os.path.exists(args.output):
|
| 124 |
-
os.makedirs(args.output)
|
| 125 |
-
|
| 126 |
-
reward_models = args.reward.split(",")
|
| 127 |
-
|
| 128 |
-
results = []
|
| 129 |
-
|
| 130 |
-
logger.info("Data loaded from %s", args.folder)
|
| 131 |
-
mp.set_start_method("spawn")
|
| 132 |
-
|
| 133 |
-
for file in os.listdir(args.folder):
|
| 134 |
-
if file.endswith(".json"):
|
| 135 |
-
logger.info("Processing %s", file)
|
| 136 |
-
with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f:
|
| 137 |
-
data = json.load(f)
|
| 138 |
-
data = [
|
| 139 |
-
QAPair(question=data[key]["question"], answer=data[key]["answer"])
|
| 140 |
-
for key in data
|
| 141 |
-
]
|
| 142 |
-
|
| 143 |
-
length_scores = evaluate_length(data, args.tokenizer)
|
| 144 |
-
mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
|
| 145 |
-
reward_scores = evaluate_reward(data, reward_models)
|
| 146 |
-
(
|
| 147 |
-
uni_naturalness_scores,
|
| 148 |
-
uni_coherence_scores,
|
| 149 |
-
uni_understandability_scores,
|
| 150 |
-
min_max_uni_naturalness_scores,
|
| 151 |
-
min_max_uni_coherence_scores,
|
| 152 |
-
min_max_uni_understandability_scores,
|
| 153 |
-
) = evaluate_uni(data, args.uni)
|
| 154 |
-
|
| 155 |
-
result = {
|
| 156 |
-
"file": file,
|
| 157 |
-
"number": len(data),
|
| 158 |
-
"length": length_scores,
|
| 159 |
-
"mtld": mtld_scores,
|
| 160 |
-
"mtld_min_max": min_max_mtld_scores,
|
| 161 |
-
"uni_naturalness": uni_naturalness_scores,
|
| 162 |
-
"uni_coherence": uni_coherence_scores,
|
| 163 |
-
"uni_understandability": uni_understandability_scores,
|
| 164 |
-
"uni_naturalness_min_max": min_max_uni_naturalness_scores,
|
| 165 |
-
"uni_coherence_min_max": min_max_uni_coherence_scores,
|
| 166 |
-
"uni_understandability_min_max": min_max_uni_understandability_scores,
|
| 167 |
-
}
|
| 168 |
-
for reward_score in reward_scores:
|
| 169 |
-
result[reward_score["reward_name"]] = reward_score["score"]
|
| 170 |
-
result[f"{reward_score['reward_name']}_min_max"] = reward_score[
|
| 171 |
-
"min_max_scores"
|
| 172 |
-
]
|
| 173 |
-
|
| 174 |
-
results.append(result)
|
| 175 |
-
|
| 176 |
-
results = pd.DataFrame(results)
|
| 177 |
-
results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graphgen/operators/evaluate/evaluate_service.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
|
| 6 |
+
from graphgen.common import init_llm, init_storage
|
| 7 |
+
from graphgen.utils import logger, run_concurrent
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EvaluateService(BaseOperator):
    """Evaluate generated data quality along two axes.

    1. KG quality evaluation (accuracy / consistency / structure of the graph)
    2. QA quality evaluation (length / lexical diversity / reward / UniEval)

    Which evaluators are active is controlled by ``metrics``. QA and KG
    metrics may be mixed, but ``process`` runs QA evaluation with priority:
    if any QA evaluator is configured, KG evaluation is skipped for that
    batch (see ``process``).
    """

    def __init__(
        self,
        working_dir: str = "cache",
        metrics: list[str] = None,
        graph_backend: str = "kuzu",
        kv_backend: str = "rocksdb",
        **kwargs,
    ):
        super().__init__(working_dir=working_dir, op_name="evaluate_service")
        self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
        self.metrics = metrics or []
        # Extra per-evaluator options (mtld_params, reward_params, uni_params,
        # structure_params) are passed through **kwargs.
        self.kwargs = kwargs
        self.graph_storage = init_storage(
            backend=graph_backend, working_dir=working_dir, namespace="graph"
        )
        self.chunk_storage = init_storage(
            backend=kv_backend, working_dir=working_dir, namespace="chunk"
        )

        # Evaluators are created lazily from ``metrics`` so heavy model
        # dependencies are only imported when actually requested.
        self.qa_evaluators = {}
        self.kg_evaluators = {}
        self._init_evaluators()

    def _init_evaluators(self):
        """Instantiate QA and KG evaluators for each requested metric.

        Raises:
            ValueError: if a metric name is not one of the supported
                ``qa_*`` / ``kg_*`` metrics.
        """
        for metric in self.metrics:
            if metric == "qa_length":
                from graphgen.models import LengthEvaluator

                self.qa_evaluators[metric] = LengthEvaluator()
            elif metric == "qa_mtld":
                from graphgen.models import MTLDEvaluator

                self.qa_evaluators[metric] = MTLDEvaluator(
                    **self.kwargs.get("mtld_params", {})
                )
            elif metric == "qa_reward_score":
                from graphgen.models import RewardEvaluator

                self.qa_evaluators[metric] = RewardEvaluator(
                    **self.kwargs.get("reward_params", {})
                )
            elif metric == "qa_uni_score":
                from graphgen.models import UniEvaluator

                self.qa_evaluators[metric] = UniEvaluator(
                    **self.kwargs.get("uni_params", {})
                )
            elif metric == "kg_accuracy":
                from graphgen.models import AccuracyEvaluator

                self.kg_evaluators[metric] = AccuracyEvaluator(
                    graph_storage=self.graph_storage,
                    chunk_storage=self.chunk_storage,
                    llm_client=self.llm_client,
                )
            elif metric == "kg_consistency":
                from graphgen.models import ConsistencyEvaluator

                self.kg_evaluators[metric] = ConsistencyEvaluator(
                    graph_storage=self.graph_storage,
                    chunk_storage=self.chunk_storage,
                    llm_client=self.llm_client,
                )
            elif metric == "kg_structure":
                from graphgen.models import StructureEvaluator

                self.kg_evaluators[metric] = StructureEvaluator(
                    graph_storage=self.graph_storage,
                    **self.kwargs.get("structure_params", {}),
                )
            else:
                # Fix: this branch also rejects unknown KG metrics, so the
                # message must not claim every unknown metric is a QA metric.
                raise ValueError(f"Unknown metric: {metric}")

    async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
        """Score one QA item with every configured QA evaluator.

        Returns the item augmented with one key per metric (or per
        sub-metric when an evaluator returns a dict of scores). Returns an
        empty dict when the item has no usable question/answer, so callers
        can filter it out.
        """
        try:
            qa_pair = QAPair(
                question=str(item.get("question", "")),
                answer=str(item.get("answer", "")),
            )
            if not qa_pair.question or not qa_pair.answer:
                self.logger.error("Empty question or answer, skipping.")
                return {}
        except Exception as e:
            self.logger.error("Error in QAPair creation: %s", str(e))
            return {}

        for metric, evaluator in self.qa_evaluators.items():
            try:
                score = evaluator.evaluate(qa_pair)
                if isinstance(score, dict):
                    # Evaluators may return multiple sub-scores; flatten them.
                    for sub_metric, sub_score in score.items():
                        item[f"{metric}_{sub_metric}"] = float(sub_score)
                else:
                    item[metric] = float(score)
            except Exception as e:
                # One failing metric must not lose the whole item.
                self.logger.error("Error in %s evaluation: %s", metric, str(e))
                item[metric] = None
        return item

    def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run all QA evaluators over *items* concurrently.

        Accepts items either in chat ``messages`` format or already as
        ``{"question": ..., "answer": ...}`` (the transform keeps empty
        strings for missing roles, and ``_process_single_qa`` drops those).
        """

        def transform_messages_format(items: list[dict]) -> list[dict]:
            """
            Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...]
            """
            transformed = []
            for item in items:
                messages = item.get("messages", [])
                question = next(
                    (m["content"] for m in messages if m.get("role") == "user"), ""
                )
                answer = next(
                    (m["content"] for m in messages if m.get("role") == "assistant"), ""
                )

                transformed.append({"question": question, "answer": answer})
            return transformed

        if not items:
            return []

        if not self.qa_evaluators:
            self.logger.warning("No QA evaluators initialized, skipping QA evaluation")
            return []

        items = transform_messages_format(items)
        results = run_concurrent(
            self._process_single_qa,
            items,
            desc="Evaluating QA items",
            unit="item",
        )

        # Drop items rejected by _process_single_qa (empty dicts).
        results = [item for item in results if item]
        return results

    def _evaluate_kg(self) -> Dict[str, Any]:
        """Run every configured KG evaluator against the stored graph.

        Returns one entry per metric; a failing evaluator contributes an
        ``{"error": ...}`` entry instead of aborting the others.
        """
        results = {}

        for metric, evaluator in self.kg_evaluators.items():
            try:
                self.logger.info("Running %s evaluation...", metric)
                score = evaluator.evaluate()
                results[metric] = score
            except Exception as e:
                self.logger.error("Error in %s evaluation: %s", metric, str(e))
                results[metric] = {"error": str(e)}
        return results

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Evaluate *batch*, returning a DataFrame of scores.

        QA evaluation takes priority: when both QA and KG evaluators are
        configured, only QA evaluation runs for this batch (KG evaluation is
        batch-independent and would otherwise be recomputed per batch).
        """
        # QA evaluation
        if len(self.qa_evaluators) > 0:
            if len(self.kg_evaluators) > 0:
                self.logger.warning(
                    "Both QA and KG evaluators configured; "
                    "running QA evaluation only for this batch."
                )
            items = batch.to_dict(orient="records")
            results = self._evaluate_qa(items)
            return pd.DataFrame(results)

        # KG evaluation
        if len(self.kg_evaluators) > 0:
            results = self._evaluate_kg()
            # Convert dict to DataFrame (single row)
            return pd.DataFrame([results])

        # No metrics specified (use self.logger for consistency with the rest
        # of the class).
        self.logger.warning("No metrics specified, returning empty DataFrame")
        return pd.DataFrame()
|
graphgen/run.py
CHANGED
|
@@ -91,10 +91,11 @@ def main():
|
|
| 91 |
results = engine.execute(ds)
|
| 92 |
|
| 93 |
for node_id, dataset in results.items():
|
| 94 |
-
|
| 95 |
-
os.
|
|
|
|
| 96 |
dataset.write_json(
|
| 97 |
-
|
| 98 |
filename_provider=NodeFilenameProvider(node_id),
|
| 99 |
pandas_json_args_fn=lambda: {
|
| 100 |
"force_ascii": False,
|
|
@@ -102,7 +103,7 @@ def main():
|
|
| 102 |
"lines": True,
|
| 103 |
},
|
| 104 |
)
|
| 105 |
-
logger.info("Node %s results saved to %s", node_id,
|
| 106 |
|
| 107 |
save_config(os.path.join(output_path, "config.yaml"), config)
|
| 108 |
logger.info("GraphGen completed successfully. Data saved to %s", output_path)
|
|
|
|
| 91 |
results = engine.execute(ds)
|
| 92 |
|
| 93 |
for node_id, dataset in results.items():
|
| 94 |
+
logger.info("Saving results for node %s", node_id)
|
| 95 |
+
node_output_path = os.path.join(output_path, f"{node_id}")
|
| 96 |
+
os.makedirs(node_output_path, exist_ok=True)
|
| 97 |
dataset.write_json(
|
| 98 |
+
node_output_path,
|
| 99 |
filename_provider=NodeFilenameProvider(node_id),
|
| 100 |
pandas_json_args_fn=lambda: {
|
| 101 |
"force_ascii": False,
|
|
|
|
| 103 |
"lines": True,
|
| 104 |
},
|
| 105 |
)
|
| 106 |
+
logger.info("Node %s results saved to %s", node_id, node_output_path)
|
| 107 |
|
| 108 |
save_config(os.path.join(output_path, "config.yaml"), config)
|
| 109 |
logger.info("GraphGen completed successfully. Data saved to %s", output_path)
|
graphgen/templates/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
|
| 2 |
from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
|
|
|
|
| 3 |
from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT
|
| 4 |
from .generation import (
|
| 5 |
AGGREGATED_GENERATION_PROMPT,
|
|
|
|
| 1 |
from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
|
| 2 |
from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
|
| 3 |
+
from .evaluation import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT
|
| 4 |
from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT
|
| 5 |
from .generation import (
|
| 6 |
AGGREGATED_GENERATION_PROMPT,
|
graphgen/templates/evaluation/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .kg import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT
|
graphgen/templates/evaluation/kg/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .accuracy_evaluation import ACCURACY_EVALUATION_PROMPT
|
| 2 |
+
from .consistency_evaluation import CONSISTENCY_EVALUATION_PROMPT
|
graphgen/templates/evaluation/kg/accuracy_evaluation.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ENTITY_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的实体列表,评估实体提取的质量。
|
| 2 |
+
|
| 3 |
+
评估维度:
|
| 4 |
+
1. ACCURACY (准确性, 权重: 40%): 提取的实体是否正确,是否有误提取或错误识别
|
| 5 |
+
2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要实体
|
| 6 |
+
3. PRECISION (精确性, 权重: 20%): 提取的实体是否精确,命名是否准确
|
| 7 |
+
|
| 8 |
+
评分标准(每个维度 0-1 分):
|
| 9 |
+
- EXCELLENT (0.8-1.0): 高质量提取
|
| 10 |
+
- GOOD (0.6-0.79): 良好质量,有少量问题
|
| 11 |
+
- ACCEPTABLE (0.4-0.59): 可接受,有明显问题
|
| 12 |
+
- POOR (0.0-0.39): 质量差,需要改进
|
| 13 |
+
|
| 14 |
+
综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
|
| 15 |
+
|
| 16 |
+
请评估以下内容:
|
| 17 |
+
|
| 18 |
+
原始文本块:
|
| 19 |
+
{chunk_content}
|
| 20 |
+
|
| 21 |
+
提取的实体列表:
|
| 22 |
+
{extracted_entities}
|
| 23 |
+
|
| 24 |
+
请以 JSON 格式返回评估结果:
|
| 25 |
+
{{
|
| 26 |
+
"accuracy": <0-1之间的浮点数>,
|
| 27 |
+
"completeness": <0-1之间的浮点数>,
|
| 28 |
+
"precision": <0-1之间的浮点数>,
|
| 29 |
+
"overall_score": <综合评分>,
|
| 30 |
+
"accuracy_reasoning": "<准确性评估理由>",
|
| 31 |
+
"completeness_reasoning": "<完整性评估理由,包括遗漏的重要实体>",
|
| 32 |
+
"precision_reasoning": "<精确性评估理由>",
|
| 33 |
+
"issues": ["<发现的问题列表>"]
|
| 34 |
+
}}
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
ENTITY_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \
|
| 38 |
+
Your task is to evaluate the quality of entity extraction from a given text block and extracted entity list.
|
| 39 |
+
|
| 40 |
+
Evaluation Dimensions:
|
| 41 |
+
1. ACCURACY (Weight: 40%): Whether the extracted entities are correct, and if there are any false extractions or misidentifications
|
| 42 |
+
2. COMPLETENESS (Weight: 40%): Whether important entities from the text are missing
|
| 43 |
+
3. PRECISION (Weight: 20%): Whether the extracted entities are precise and accurately named
|
| 44 |
+
|
| 45 |
+
Scoring Criteria (0-1 scale for each dimension):
|
| 46 |
+
- EXCELLENT (0.8-1.0): High-quality extraction
|
| 47 |
+
- GOOD (0.6-0.79): Good quality with minor issues
|
| 48 |
+
- ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues
|
| 49 |
+
- POOR (0.0-0.39): Poor quality, needs improvement
|
| 50 |
+
|
| 51 |
+
Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
|
| 52 |
+
|
| 53 |
+
Please evaluate the following:
|
| 54 |
+
|
| 55 |
+
Original Text Block:
|
| 56 |
+
{chunk_content}
|
| 57 |
+
|
| 58 |
+
Extracted Entity List:
|
| 59 |
+
{extracted_entities}
|
| 60 |
+
|
| 61 |
+
Please return the evaluation result in JSON format:
|
| 62 |
+
{{
|
| 63 |
+
"accuracy": <float between 0-1>,
|
| 64 |
+
"completeness": <float between 0-1>,
|
| 65 |
+
"precision": <float between 0-1>,
|
| 66 |
+
"overall_score": <overall score>,
|
| 67 |
+
"accuracy_reasoning": "<reasoning for accuracy assessment>",
|
| 68 |
+
"completeness_reasoning": "<reasoning for completeness assessment, including important missing entities>",
|
| 69 |
+
"precision_reasoning": "<reasoning for precision assessment>",
|
| 70 |
+
"issues": ["<list of identified issues>"]
|
| 71 |
+
}}
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
RELATION_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的关系列表,评估关系抽取的质量。
|
| 75 |
+
|
| 76 |
+
评估维度:
|
| 77 |
+
1. ACCURACY (准确性, 权重: 40%): 提取的关系是否正确,关系描述是否准确
|
| 78 |
+
2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要关系
|
| 79 |
+
3. PRECISION (精确性, 权重: 20%): 关系描述是否精确,是否过于宽泛
|
| 80 |
+
|
| 81 |
+
评分标准(每个维度 0-1 分):
|
| 82 |
+
- EXCELLENT (0.8-1.0): 高质量提取
|
| 83 |
+
- GOOD (0.6-0.79): 良好质量,有少量问题
|
| 84 |
+
- ACCEPTABLE (0.4-0.59): 可接受,有明显问题
|
| 85 |
+
- POOR (0.0-0.39): 质量差,需要改进
|
| 86 |
+
|
| 87 |
+
综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
|
| 88 |
+
|
| 89 |
+
请评估以下内容:
|
| 90 |
+
|
| 91 |
+
原始文本块:
|
| 92 |
+
{chunk_content}
|
| 93 |
+
|
| 94 |
+
提取的关系列表:
|
| 95 |
+
{extracted_relations}
|
| 96 |
+
|
| 97 |
+
请以 JSON 格式返回评估结果:
|
| 98 |
+
{{
|
| 99 |
+
"accuracy": <0-1之间的浮点数>,
|
| 100 |
+
"completeness": <0-1之间的浮点数>,
|
| 101 |
+
"precision": <0-1之间的浮点数>,
|
| 102 |
+
"overall_score": <综合评分>,
|
| 103 |
+
"accuracy_reasoning": "<准确性评估理由>",
|
| 104 |
+
"completeness_reasoning": "<完整性评估理由,包括遗漏的重要关系>",
|
| 105 |
+
"precision_reasoning": "<精确性评估理由>",
|
| 106 |
+
"issues": ["<发现的问题列表>"]
|
| 107 |
+
}}
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
RELATION_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \
|
| 111 |
+
Your task is to evaluate the quality of relation extraction from a given text block and extracted relation list.
|
| 112 |
+
|
| 113 |
+
Evaluation Dimensions:
|
| 114 |
+
1. ACCURACY (Weight: 40%): Whether the extracted relations are correct and the relation descriptions are accurate
|
| 115 |
+
2. COMPLETENESS (Weight: 40%): Whether important relations from the text are missing
|
| 116 |
+
3. PRECISION (Weight: 20%): Whether the relation descriptions are precise and not overly broad
|
| 117 |
+
|
| 118 |
+
Scoring Criteria (0-1 scale for each dimension):
|
| 119 |
+
- EXCELLENT (0.8-1.0): High-quality extraction
|
| 120 |
+
- GOOD (0.6-0.79): Good quality with minor issues
|
| 121 |
+
- ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues
|
| 122 |
+
- POOR (0.0-0.39): Poor quality, needs improvement
|
| 123 |
+
|
| 124 |
+
Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision
|
| 125 |
+
|
| 126 |
+
Please evaluate the following:
|
| 127 |
+
|
| 128 |
+
Original Text Block:
|
| 129 |
+
{chunk_content}
|
| 130 |
+
|
| 131 |
+
Extracted Relation List:
|
| 132 |
+
{extracted_relations}
|
| 133 |
+
|
| 134 |
+
Please return the evaluation result in JSON format:
|
| 135 |
+
{{
|
| 136 |
+
"accuracy": <float between 0-1>,
|
| 137 |
+
"completeness": <float between 0-1>,
|
| 138 |
+
"precision": <float between 0-1>,
|
| 139 |
+
"overall_score": <overall score>,
|
| 140 |
+
"accuracy_reasoning": "<reasoning for accuracy assessment>",
|
| 141 |
+
"completeness_reasoning": "<reasoning for completeness assessment, including important missing relations>",
|
| 142 |
+
"precision_reasoning": "<reasoning for precision assessment>",
|
| 143 |
+
"issues": ["<list of identified issues>"]
|
| 144 |
+
}}
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
ACCURACY_EVALUATION_PROMPT = {
|
| 148 |
+
"zh": {
|
| 149 |
+
"ENTITY": ENTITY_EVALUATION_PROMPT_ZH,
|
| 150 |
+
"RELATION": RELATION_EVALUATION_PROMPT_ZH,
|
| 151 |
+
},
|
| 152 |
+
"en": {
|
| 153 |
+
"ENTITY": ENTITY_EVALUATION_PROMPT_EN,
|
| 154 |
+
"RELATION": RELATION_EVALUATION_PROMPT_EN,
|
| 155 |
+
},
|
| 156 |
+
}
|
graphgen/templates/evaluation/kg/consistency_evaluation.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ENTITY_TYPE_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。
|
| 2 |
+
|
| 3 |
+
实体名称:{entity_name}
|
| 4 |
+
|
| 5 |
+
在不同文本块中的类型提取结果:
|
| 6 |
+
{type_extractions}
|
| 7 |
+
|
| 8 |
+
预设的实体类型列表(供参考):
|
| 9 |
+
concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene
|
| 10 |
+
|
| 11 |
+
请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。
|
| 12 |
+
注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。
|
| 13 |
+
|
| 14 |
+
请以 JSON 格式返回:
|
| 15 |
+
{{
|
| 16 |
+
"has_conflict": <true/false>,
|
| 17 |
+
"conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>,
|
| 18 |
+
"conflict_reasoning": "<冲突判断的理由>",
|
| 19 |
+
"conflicting_types": ["<存在冲突的类型对>"],
|
| 20 |
+
"recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>"
|
| 21 |
+
}}
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
ENTITY_DESCRIPTION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。
|
| 25 |
+
|
| 26 |
+
实体名称:{entity_name}
|
| 27 |
+
|
| 28 |
+
在不同文本块中的描述:
|
| 29 |
+
{descriptions}
|
| 30 |
+
|
| 31 |
+
请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。
|
| 32 |
+
|
| 33 |
+
请以 JSON 格式返回:
|
| 34 |
+
{{
|
| 35 |
+
"has_conflict": <true/false>,
|
| 36 |
+
"conflict_severity": <0-1之间的浮点数>,
|
| 37 |
+
"conflict_reasoning": "<冲突判断的理由>",
|
| 38 |
+
"conflicting_descriptions": ["<存在冲突的描述对>"],
|
| 39 |
+
"conflict_details": "<具体的冲突内容>"
|
| 40 |
+
}}
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
RELATION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。
|
| 44 |
+
|
| 45 |
+
实体对:{source_entity} -> {target_entity}
|
| 46 |
+
|
| 47 |
+
在不同文本块中的关系描述:
|
| 48 |
+
{relation_descriptions}
|
| 49 |
+
|
| 50 |
+
请判断这些关系描述是否存在语义冲突。
|
| 51 |
+
|
| 52 |
+
请以 JSON 格式返回:
|
| 53 |
+
{{
|
| 54 |
+
"has_conflict": <true/false>,
|
| 55 |
+
"conflict_severity": <0-1之间的浮点数>,
|
| 56 |
+
"conflict_reasoning": "<冲突判断的理由>",
|
| 57 |
+
"conflicting_relations": ["<存在冲突的关系描述对>"]
|
| 58 |
+
}}
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
ENTITY_EXTRACTION_PROMPT = """从以下文本块中提取指定实体的类型和描述。
|
| 62 |
+
|
| 63 |
+
**重要**:你只需要提取指定的实体,不要提取其他实体。
|
| 64 |
+
|
| 65 |
+
实体名称:{entity_name}
|
| 66 |
+
|
| 67 |
+
文本块:
|
| 68 |
+
{chunk_content}
|
| 69 |
+
|
| 70 |
+
请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息:
|
| 71 |
+
|
| 72 |
+
1. entity_type: 实体类型,必须是以下预设类型之一(小写):
|
| 73 |
+
- concept: 概念
|
| 74 |
+
- date: 日期
|
| 75 |
+
- location: 地点
|
| 76 |
+
- keyword: 关键词
|
| 77 |
+
- organization: 组织
|
| 78 |
+
- person: 人物
|
| 79 |
+
- event: 事件
|
| 80 |
+
- work: 作品/工作
|
| 81 |
+
- nature: 自然
|
| 82 |
+
- artificial: 人工
|
| 83 |
+
- science: 科学
|
| 84 |
+
- technology: 技术
|
| 85 |
+
- mission: 任务
|
| 86 |
+
- gene: 基因
|
| 87 |
+
|
| 88 |
+
如果无法确定类型,请使用 "concept" 作为默认值。
|
| 89 |
+
|
| 90 |
+
2. description: 实体描述(简要描述该实体在文本中的作用和特征)
|
| 91 |
+
|
| 92 |
+
请以 JSON 格式返回:
|
| 93 |
+
{{
|
| 94 |
+
"entity_type": "<实体类型(必须是上述预设类型之一)>",
|
| 95 |
+
"description": "<实体描述>"
|
| 96 |
+
}}
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
CONSISTENCY_EVALUATION_PROMPT = {
|
| 100 |
+
"en": "",
|
| 101 |
+
"zh": ""
|
| 102 |
+
}
|
graphgen/utils/help_nltk.py
CHANGED
|
@@ -1,39 +1,61 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
from typing import Dict, List, Optional
|
|
|
|
| 3 |
import nltk
|
| 4 |
import jieba
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class NLTKHelper:
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
}
|
| 14 |
|
| 15 |
-
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
jieba.initialize()
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def get_stopwords(self, lang: str) -> List[str]:
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
self._stopwords[lang] = nltk.corpus.stopwords.words(lang)
|
| 27 |
-
return self._stopwords[lang]
|
| 28 |
-
|
| 29 |
-
@staticmethod
|
| 30 |
-
def word_tokenize(text: str, lang: str) -> List[str]:
|
| 31 |
if lang == "zh":
|
| 32 |
return jieba.lcut(text)
|
| 33 |
-
nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
|
| 34 |
-
try:
|
| 35 |
-
nltk.data.find("tokenizers/punkt_tab")
|
| 36 |
-
except LookupError:
|
| 37 |
-
nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data"))
|
| 38 |
|
| 39 |
return nltk.word_tokenize(text)
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
import os
|
| 3 |
+
from typing import Dict, List, Final, Optional
|
| 4 |
+
import warnings
|
| 5 |
import nltk
|
| 6 |
import jieba
|
| 7 |
|
| 8 |
+
warnings.filterwarnings(
|
| 9 |
+
"ignore",
|
| 10 |
+
category=UserWarning,
|
| 11 |
+
module=r"jieba\._compat"
|
| 12 |
+
)
|
| 13 |
|
| 14 |
class NLTKHelper:
    """
    NLTK helper class.

    Ensures the required NLTK corpora are present (downloading into a local
    ``resources/nltk_data`` directory on first use) and routes Chinese text
    through jieba while English goes through NLTK.
    """

    SUPPORTED_LANGUAGES: Final[Dict[str, str]] = {
        "en": "english",
        "zh": "chinese"
    }
    # Maps downloadable package name -> the nltk.data category it lives under.
    _NLTK_PACKAGES: Final[Dict[str, str]] = {
        "stopwords": "corpora",
        "punkt_tab": "tokenizers"
    }

    def __init__(self, nltk_data_path: Optional[str] = None):
        self._nltk_path = nltk_data_path or os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "resources",
            "nltk_data"
        )
        nltk.data.path.append(self._nltk_path)
        jieba.initialize()

        self._ensure_nltk_data("stopwords")
        self._ensure_nltk_data("punkt_tab")

        # Per-instance stopword cache. Fix: a functools.lru_cache on the
        # bound method keys on ``self``, keeping every instance alive for the
        # cache's lifetime and thrashing with maxsize=2 across instances
        # (ruff B019); a plain dict per instance avoids both problems.
        self._stopwords_cache: Dict[str, List[str]] = {}

    def _ensure_nltk_data(self, package_name: str) -> None:
        """
        ensure nltk data is downloaded
        """
        try:
            nltk.data.find(f"{self._NLTK_PACKAGES[package_name]}/{package_name}")
        except LookupError:
            nltk.download(package_name, download_dir=self._nltk_path, quiet=True)

    def get_stopwords(self, lang: str) -> List[str]:
        """Return the stopword list for *lang* ("en" or "zh"), cached per instance."""
        if lang not in self.SUPPORTED_LANGUAGES:
            raise ValueError(f"Language {lang} is not supported.")
        if lang not in self._stopwords_cache:
            self._stopwords_cache[lang] = nltk.corpus.stopwords.words(
                self.SUPPORTED_LANGUAGES[lang]
            )
        return self._stopwords_cache[lang]

    def word_tokenize(self, text: str, lang: str) -> List[str]:
        """Tokenize *text*: jieba for Chinese, NLTK (punkt) otherwise."""
        if lang not in self.SUPPORTED_LANGUAGES:
            raise ValueError(f"Language {lang} is not supported.")
        if lang == "zh":
            return jieba.lcut(text)

        return nltk.word_tokenize(text)
|
requirements.txt
CHANGED
|
@@ -23,7 +23,6 @@ aiohttp
|
|
| 23 |
socksio
|
| 24 |
pydantic
|
| 25 |
ray==2.52.1
|
| 26 |
-
kuzu
|
| 27 |
pyarrow
|
| 28 |
|
| 29 |
leidenalg
|
|
@@ -32,9 +31,11 @@ python-louvain
|
|
| 32 |
|
| 33 |
# storage
|
| 34 |
rocksdict
|
|
|
|
| 35 |
|
| 36 |
# KG
|
| 37 |
rdflib
|
|
|
|
| 38 |
|
| 39 |
# Bioinformatics
|
| 40 |
biopython
|
|
|
|
| 23 |
socksio
|
| 24 |
pydantic
|
| 25 |
ray==2.52.1
|
|
|
|
| 26 |
pyarrow
|
| 27 |
|
| 28 |
leidenalg
|
|
|
|
| 31 |
|
| 32 |
# storage
|
| 33 |
rocksdict
|
| 34 |
+
kuzu
|
| 35 |
|
| 36 |
# KG
|
| 37 |
rdflib
|
| 38 |
+
scipy
|
| 39 |
|
| 40 |
# Bioinformatics
|
| 41 |
biopython
|
webui/utils/count_tokens.py
CHANGED
|
@@ -7,10 +7,12 @@ import pandas as pd
|
|
| 7 |
# pylint: disable=wrong-import-position
|
| 8 |
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
sys.path.append(root_dir)
|
| 10 |
-
from graphgen.models import Tokenizer
|
| 11 |
|
| 12 |
|
| 13 |
def count_tokens(file, tokenizer_name, data_frame):
|
|
|
|
|
|
|
|
|
|
| 14 |
if not file or not os.path.exists(file):
|
| 15 |
return data_frame
|
| 16 |
|
|
|
|
| 7 |
# pylint: disable=wrong-import-position
|
| 8 |
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
sys.path.append(root_dir)
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def count_tokens(file, tokenizer_name, data_frame):
|
| 13 |
+
# Lazy import to avoid circular dependency
|
| 14 |
+
from graphgen.models import Tokenizer # pylint: disable=import-outside-toplevel
|
| 15 |
+
|
| 16 |
if not file or not os.path.exists(file):
|
| 17 |
return data_frame
|
| 18 |
|