"""Loader for the MetaQA benchmark.

Builds a CanonicalGraph from the MetaQA knowledge base and parses the
per-hop question-answer splits into MetaQATaskRecord entries.
"""

from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from pathlib import Path
import re
from typing import Iterable

from osint_env.domain.models import CanonicalGraph, Edge, Node, NodeType

# MetaQA marks the topic entity of each question in square brackets,
# e.g. "what movies are about [ginger rogers]".
_TOPIC_PATTERN = re.compile(r"\[(.*?)\]")


@dataclass(slots=True)
class MetaQATaskRecord:
    """One MetaQA question with its gold answers and provenance metadata."""

    question: str
    answers: list[str]
    primary_answer: str
    hop_label: str
    hop_count: int
    split: str
    qtype: str
    topic_entity: str
    supporting_edges: list[Edge]


def _normalize_hop_label(value: str) -> str:
    """Map loose hop spellings ("1", "1hop", "1-hop") onto canonical labels."""
    token = str(value or "").strip().lower().replace(" ", "")
    if token in {"1", "1hop", "1-hop"}:
        return "1-hop"
    if token in {"2", "2hop", "2-hop"}:
        return "2-hop"
    if token in {"3", "3hop", "3-hop"}:
        return "3-hop"
    return ""


def _normalize_split(value: str) -> str:
    """Return a canonical split name, or "" for anything unrecognized."""
    token = str(value or "").strip().lower()
    if token in {"train", "dev", "test"}:
        return token
    return ""


def _hop_count(label: str) -> int:
    """Parse the leading integer out of a label such as "2-hop"."""
    return int(label.split("-", 1)[0])


def _extract_topic_entity(question: str) -> str:
    """Pull the bracketed topic entity out of a MetaQA question, if present."""
    match = _TOPIC_PATTERN.search(str(question))
    return match.group(1).strip() if match else ""


def _node_types_for_relation(rel: str) -> tuple[NodeType, NodeType]:
    """Map a MetaQA KB relation onto (source, destination) canonical node types.

    The KB subject is always a movie (modeled as POST); the destination type
    depends on the relation: people for crew/cast relations, EVENT for the
    release year, LOCATION for the language, and ORG for the remaining
    categorical attributes.
    """
    relation = str(rel or "").strip().lower()
    src_type = NodeType.POST
    if relation in {"directed_by", "written_by", "starred_actors"}:
        return src_type, NodeType.USER
    if relation == "release_year":
        return src_type, NodeType.EVENT
    if relation == "in_language":
        return src_type, NodeType.LOCATION
    if relation in {"has_genre", "has_tags", "has_imdb_votes"}:
        return src_type, NodeType.ORG
    return src_type, NodeType.USER


def _ensure_node(graph: CanonicalGraph, node_id: str, node_type: NodeType) -> None:
    """Insert a node into the graph if it is not already present."""
    existing = graph.nodes.get(node_id)
    if existing is not None:
        return
    graph.nodes[node_id] = Node(node_id=node_id, node_type=node_type, attrs={"name": node_id})


def _read_non_empty_lines(path: Path) -> list[str]:
    """Read a UTF-8 text file and return its stripped, non-empty lines."""
    return [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]


def _parse_kb_line(line: str) -> tuple[str, str, str] | None:
    """Parse one "subject|relation|object" KB line; return None if malformed."""
    parts = [part.strip() for part in str(line).split("|", 2)]
    if len(parts) != 3:
        return None
    src, rel, dst = parts
    if not src or not rel or not dst:
        return None
    return src, rel, dst


def _undirected_adjacency(edges: Iterable[Edge]) -> dict[str, list[tuple[str, Edge]]]:
    """Index edges by endpoint in both directions for undirected traversal."""
    adj: dict[str, list[tuple[str, Edge]]] = {}
    for edge in edges:
        adj.setdefault(edge.src, []).append((edge.dst, edge))
        adj.setdefault(edge.dst, []).append((edge.src, edge))
    return adj


def _bfs_support_path(
    topic_entity: str,
    answer_candidates: list[str],
    adjacency: dict[str, list[tuple[str, Edge]]],
    max_depth: int,
) -> list[Edge]:
    """Breadth-first search from the topic entity to any answer entity.

    Returns the edges along the first (shortest) path found, or [] when the
    topic is missing from the graph or no answer is reachable within
    max_depth hops.
    """
    topic = str(topic_entity or "").strip()
    if not topic or topic not in adjacency:
        return []
    answers = {item.strip() for item in answer_candidates if item.strip()}
    if not answers:
        return []
    queue: deque[tuple[str, list[Edge]]] = deque([(topic, [])])
    visited_depth: dict[str, int] = {topic: 0}
    while queue:
        node, path = queue.popleft()
        depth = len(path)
        if depth > max_depth:
            continue
        if node in answers and path:
            # Require a non-empty path so a topic that is itself an answer
            # cannot short-circuit the search with no supporting edges.
            return path
        if depth == max_depth:
            continue
        for neighbor, edge in adjacency.get(node, []):
            next_depth = depth + 1
            best = visited_depth.get(neighbor)
            if best is not None and best <= next_depth:
                continue
            visited_depth[neighbor] = next_depth
            queue.append((neighbor, path + [edge]))
    return []
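
# Illustration (hypothetical entities): with edges A --r1--> B and B --r2--> C,
# a search from topic "A" with answers {"C"} and max_depth=2 returns the two
# edges of the A-B-C path; with max_depth=1 it returns [].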


def _infer_support_edges(
    topic_entity: str,
    answer_candidates: list[str],
    adjacency: dict[str, list[tuple[str, Edge]]],
    hop_count: int,
) -> list[Edge]:
    """Search at the nominal hop count first, then progressively relax the depth limit."""
    for limit in (hop_count, hop_count + 1, hop_count + 2, max(4, hop_count + 3)):
        path = _bfs_support_path(topic_entity, answer_candidates, adjacency, max_depth=max(1, limit))
        if path:
            return path
    return []


def infer_metaqa_support_edges(
    graph: CanonicalGraph,
    topic_entity: str,
    answer_candidates: list[str],
    hop_count: int,
) -> list[Edge]:
    """Find a supporting edge path for one record against the full graph."""
    adjacency = _undirected_adjacency(graph.edges)
    return _infer_support_edges(
        topic_entity=topic_entity,
        answer_candidates=answer_candidates,
        adjacency=adjacency,
        hop_count=hop_count,
    )


def load_metaqa_dataset(
    root: str | Path,
    kb_path: str | Path | None,
    variant: str,
    hops: list[str],
    splits: list[str],
) -> tuple[CanonicalGraph, list[MetaQATaskRecord]]:
    """Load the MetaQA KB into a CanonicalGraph and parse the QA splits.

    Expects the standard MetaQA layout: a kb.txt at the root (unless kb_path
    overrides it) and per-hop directories such as 1-hop/vanilla/qa_train.txt
    with sibling qa_train_qtype.txt files.
    """
    root_path = Path(root)
    if not root_path.exists():
        raise FileNotFoundError(f"MetaQA root not found: {root_path}")
    kb_file = Path(kb_path) if kb_path else root_path / "kb.txt"
    if not kb_file.exists():
        raise FileNotFoundError(f"MetaQA KB file not found: {kb_file}")
    graph = CanonicalGraph()
    seen_edges: set[tuple[str, str, str]] = set()
    for raw_line in _read_non_empty_lines(kb_file):
        row = _parse_kb_line(raw_line)
        if row is None:
            continue
        src, rel, dst = row
        edge_key = (src, rel, dst)
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        src_type, dst_type = _node_types_for_relation(rel)
        _ensure_node(graph, src, src_type)
        _ensure_node(graph, dst, dst_type)
        graph.edges.append(Edge(src=src, rel=rel, dst=dst, confidence=1.0))
    # Fall back to all hops/splits when the requested lists normalize to nothing.
    hop_labels = [_normalize_hop_label(hop) for hop in hops]
    hop_labels = [hop for hop in hop_labels if hop]
    if not hop_labels:
        hop_labels = ["1-hop", "2-hop", "3-hop"]
    split_labels = [_normalize_split(split) for split in splits]
    split_labels = [split for split in split_labels if split]
    if not split_labels:
        split_labels = ["train", "dev", "test"]
    variant_token = str(variant or "vanilla").strip().lower()
    if variant_token not in {"vanilla", "ntm"}:
        variant_token = "vanilla"
    records: list[MetaQATaskRecord] = []
    for hop in hop_labels:
        hop_dir = root_path / hop
        for split in split_labels:
            qa_path = hop_dir / variant_token / f"qa_{split}.txt"
            if not qa_path.exists():
                continue
            qa_lines = _read_non_empty_lines(qa_path)
            # Question-type annotations sit at the hop-directory level and are
            # aligned line-by-line with the QA file when present.
            qtype_path = hop_dir / f"qa_{split}_qtype.txt"
            qtypes = _read_non_empty_lines(qtype_path) if qtype_path.exists() else []
            for idx, row in enumerate(qa_lines):
                # Each QA line is "question<TAB>answer1|answer2|...".
                parts = row.split("\t")
                if len(parts) < 2:
                    continue
                question = parts[0].strip()
                answer_blob = parts[1].strip()
                answers = [item.strip() for item in answer_blob.split("|") if item.strip()]
                if not question or not answers:
                    continue
                topic_entity = _extract_topic_entity(question)
                qtype = qtypes[idx] if idx < len(qtypes) else ""
                records.append(
                    MetaQATaskRecord(
                        question=question,
                        answers=answers,
                        primary_answer=answers[0],
                        hop_label=hop,
                        hop_count=_hop_count(hop),
                        split=split,
                        qtype=qtype,
                        topic_entity=topic_entity,
                        # Filled in lazily via infer_metaqa_support_edges.
                        supporting_edges=[],
                    )
                )
    return graph, records
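

if __name__ == "__main__":
    # Minimal usage sketch, assuming a local MetaQA checkout; the path below
    # is hypothetical and the hop/split selection is illustrative only.
    graph, records = load_metaqa_dataset(
        root="data/metaqa",
        kb_path=None,
        variant="vanilla",
        hops=["2-hop"],
        splits=["dev"],
    )
    print(f"graph: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
    if records:
        record = records[0]
        edges = infer_metaqa_support_edges(
            graph=graph,
            topic_entity=record.topic_entity,
            answer_candidates=record.answers,
            hop_count=record.hop_count,
        )
        print(f"{record.question!r} -> {record.primary_answer!r} via {len(edges)} edge(s)")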