Spaces:
Paused
Paused
| from __future__ import annotations | |
| from collections import deque | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import re | |
| from typing import Iterable | |
| from osint_env.domain.models import CanonicalGraph, Edge, Node, NodeType | |
# Non-greedy match of the first [bracketed] span in a question string;
# group(1) is the text inside the brackets (used to pull out the topic
# entity, e.g. "who directed [Top Gun]" -> "Top Gun").
_TOPIC_PATTERN = re.compile(r"\[(.*?)\]")
@dataclass
class MetaQATaskRecord:
    """One MetaQA question/answer task parsed from a qa_<split>.txt row.

    Bug fix: the class declared annotated fields but was missing the
    ``@dataclass`` decorator (``dataclass`` was imported and never used),
    so keyword construction in ``load_metaqa_dataset`` raised TypeError.
    """

    question: str  # raw question text, topic entity wrapped in [brackets]
    answers: list[str]  # all gold answers for the question
    primary_answer: str  # first gold answer (convenience accessor)
    hop_label: str  # canonical "1-hop" / "2-hop" / "3-hop"
    hop_count: int  # numeric hop count parsed from hop_label
    split: str  # "train" / "dev" / "test"
    qtype: str  # question-type tag; "" when no qtype file is present
    topic_entity: str  # entity extracted from [brackets]; "" when absent
    supporting_edges: list[Edge]  # KB edges supporting the answer (may be empty)
| def _normalize_hop_label(value: str) -> str: | |
| token = str(value or "").strip().lower().replace(" ", "") | |
| if token in {"1", "1hop", "1-hop"}: | |
| return "1-hop" | |
| if token in {"2", "2hop", "2-hop"}: | |
| return "2-hop" | |
| if token in {"3", "3hop", "3-hop"}: | |
| return "3-hop" | |
| return "" | |
| def _normalize_split(value: str) -> str: | |
| token = str(value or "").strip().lower() | |
| if token in {"train", "dev", "test"}: | |
| return token | |
| return "" | |
| def _hop_count(label: str) -> int: | |
| return int(label.split("-", 1)[0]) | |
| def _extract_topic_entity(question: str) -> str: | |
| match = _TOPIC_PATTERN.search(str(question)) | |
| return match.group(1).strip() if match else "" | |
def _node_types_for_relation(rel: str) -> tuple[NodeType, NodeType]:
    """Choose (source, destination) node types for a MetaQA KB relation.

    The source entity (presumably the movie in the MetaQA KB — confirm
    against the dataset) is always mapped to POST; the destination type
    is looked up by the case-insensitive relation name, defaulting to USER.
    """
    relation = str(rel or "").strip().lower()
    dst_by_relation = {
        "directed_by": NodeType.USER,
        "written_by": NodeType.USER,
        "starred_actors": NodeType.USER,
        "release_year": NodeType.EVENT,
        "in_language": NodeType.LOCATION,
        "has_genre": NodeType.ORG,
        "has_tags": NodeType.ORG,
        "has_imdb_votes": NodeType.ORG,
    }
    return NodeType.POST, dst_by_relation.get(relation, NodeType.USER)
def _ensure_node(graph: CanonicalGraph, node_id: str, node_type: NodeType) -> None:
    """Insert ``node_id`` into the graph if absent; existing nodes untouched."""
    if graph.nodes.get(node_id) is None:
        graph.nodes[node_id] = Node(
            node_id=node_id, node_type=node_type, attrs={"name": node_id}
        )
| def _read_non_empty_lines(path: Path) -> list[str]: | |
| return [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()] | |
| def _parse_kb_line(line: str) -> tuple[str, str, str] | None: | |
| parts = [part.strip() for part in str(line).split("|", 2)] | |
| if len(parts) != 3: | |
| return None | |
| src, rel, dst = parts | |
| if not src or not rel or not dst: | |
| return None | |
| return src, rel, dst | |
| def _undirected_adjacency(edges: Iterable[Edge]) -> dict[str, list[tuple[str, Edge]]]: | |
| adj: dict[str, list[tuple[str, Edge]]] = {} | |
| for edge in edges: | |
| adj.setdefault(edge.src, []).append((edge.dst, edge)) | |
| adj.setdefault(edge.dst, []).append((edge.src, edge)) | |
| return adj | |
| def _bfs_support_path( | |
| topic_entity: str, | |
| answer_candidates: list[str], | |
| adjacency: dict[str, list[tuple[str, Edge]]], | |
| max_depth: int, | |
| ) -> list[Edge]: | |
| topic = str(topic_entity or "").strip() | |
| if not topic or topic not in adjacency: | |
| return [] | |
| answers = {item.strip() for item in answer_candidates if item.strip()} | |
| if not answers: | |
| return [] | |
| queue: deque[tuple[str, list[Edge]]] = deque([(topic, [])]) | |
| visited_depth: dict[str, int] = {topic: 0} | |
| while queue: | |
| node, path = queue.popleft() | |
| depth = len(path) | |
| if depth > max_depth: | |
| continue | |
| if node in answers and path: | |
| return path | |
| if depth == max_depth: | |
| continue | |
| for neighbor, edge in adjacency.get(node, []): | |
| next_depth = depth + 1 | |
| best = visited_depth.get(neighbor) | |
| if best is not None and best <= next_depth: | |
| continue | |
| visited_depth[neighbor] = next_depth | |
| queue.append((neighbor, path + [edge])) | |
| return [] | |
def _infer_support_edges(
    topic_entity: str,
    answer_candidates: list[str],
    adjacency: dict[str, list[tuple[str, Edge]]],
    hop_count: int,
) -> list[Edge]:
    """Search for a support path with progressively relaxed depth limits.

    Tries the nominal hop count first, then hop_count+1, hop_count+2 and
    finally max(4, hop_count+3); each limit is clamped to at least 1.
    Returns the first path found, or [] when every attempt fails.
    """
    depth_limits = (hop_count, hop_count + 1, hop_count + 2, max(4, hop_count + 3))
    for depth_limit in depth_limits:
        found = _bfs_support_path(
            topic_entity, answer_candidates, adjacency, max_depth=max(1, depth_limit)
        )
        if found:
            return found
    return []
def infer_metaqa_support_edges(
    graph: CanonicalGraph,
    topic_entity: str,
    answer_candidates: list[str],
    hop_count: int,
) -> list[Edge]:
    """Public wrapper: build the undirected adjacency from ``graph.edges``
    and delegate the relaxed-depth search to ``_infer_support_edges``."""
    return _infer_support_edges(
        topic_entity=topic_entity,
        answer_candidates=answer_candidates,
        adjacency=_undirected_adjacency(graph.edges),
        hop_count=hop_count,
    )
def _load_kb_graph(kb_file: Path) -> CanonicalGraph:
    """Build a CanonicalGraph from a MetaQA KB file of ``src|rel|dst`` lines.

    Malformed rows are skipped, duplicate triples are deduplicated, and
    every accepted edge is stored with confidence 1.0.
    """
    graph = CanonicalGraph()
    seen_edges: set[tuple[str, str, str]] = set()
    for raw_line in _read_non_empty_lines(kb_file):
        row = _parse_kb_line(raw_line)
        if row is None:
            continue
        src, rel, dst = row
        edge_key = (src, rel, dst)
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        src_type, dst_type = _node_types_for_relation(rel)
        _ensure_node(graph, src, src_type)
        _ensure_node(graph, dst, dst_type)
        graph.edges.append(Edge(src=src, rel=rel, dst=dst, confidence=1.0))
    return graph


def _parse_qa_file(
    qa_lines: list[str],
    qtypes: list[str],
    hop: str,
    split: str,
) -> list[MetaQATaskRecord]:
    """Convert one qa_<split>.txt file's lines into task records.

    Each line is ``question<TAB>answer1|answer2|...``; rows without a tab,
    an answer, or a question are skipped. ``qtypes`` is aligned by line
    index; a missing/short qtype list yields "".
    NOTE(review): both files are read via _read_non_empty_lines, which
    drops blank lines — interior blank lines in only one of the two files
    would shift the qtype alignment. Confirm the dataset files have none.
    """
    records: list[MetaQATaskRecord] = []
    hop_count = _hop_count(hop)  # hoisted: constant per file
    for idx, row in enumerate(qa_lines):
        parts = row.split("\t")
        if len(parts) < 2:
            continue
        question = parts[0].strip()
        answers = [item.strip() for item in parts[1].split("|") if item.strip()]
        if not question or not answers:
            continue
        records.append(
            MetaQATaskRecord(
                question=question,
                answers=answers,
                primary_answer=answers[0],
                hop_label=hop,
                hop_count=hop_count,
                split=split,
                qtype=qtypes[idx] if idx < len(qtypes) else "",
                topic_entity=_extract_topic_entity(question),
                supporting_edges=[],
            )
        )
    return records


def load_metaqa_dataset(
    root: str | Path,
    kb_path: str | Path | None,
    variant: str,
    hops: list[str],
    splits: list[str],
) -> tuple[CanonicalGraph, list[MetaQATaskRecord]]:
    """Load the MetaQA KB graph plus QA task records.

    Args:
        root: dataset root containing per-hop directories ("1-hop", ...).
        kb_path: explicit KB file path; defaults to ``<root>/kb.txt``.
        variant: "vanilla" or "ntm"; anything else falls back to "vanilla".
        hops: hop-label filter; unrecognised entries are dropped and an
            empty selection means all three hop levels.
        splits: split filter; same fallback semantics as ``hops``.

    Returns:
        ``(graph, records)`` — the KB graph and the parsed task records.

    Raises:
        FileNotFoundError: when the root directory or KB file is missing.
        (Missing qa files for a hop/split combination are silently skipped.)
    """
    root_path = Path(root)
    if not root_path.exists():
        raise FileNotFoundError(f"MetaQA root not found: {root_path}")
    kb_file = Path(kb_path) if kb_path else root_path / "kb.txt"
    if not kb_file.exists():
        raise FileNotFoundError(f"MetaQA KB file not found: {kb_file}")

    graph = _load_kb_graph(kb_file)

    hop_labels = [label for label in (_normalize_hop_label(h) for h in hops) if label]
    if not hop_labels:
        hop_labels = ["1-hop", "2-hop", "3-hop"]
    split_labels = [label for label in (_normalize_split(s) for s in splits) if label]
    if not split_labels:
        split_labels = ["train", "dev", "test"]
    variant_token = str(variant or "vanilla").strip().lower()
    if variant_token not in {"vanilla", "ntm"}:
        variant_token = "vanilla"

    records: list[MetaQATaskRecord] = []
    for hop in hop_labels:
        hop_dir = root_path / hop
        for split in split_labels:
            # Layout: <root>/<hop>/<variant>/qa_<split>.txt with the qtype
            # file one level up, next to the variant directories.
            qa_path = hop_dir / variant_token / f"qa_{split}.txt"
            if not qa_path.exists():
                continue
            qtype_path = hop_dir / f"qa_{split}_qtype.txt"
            qtypes = _read_non_empty_lines(qtype_path) if qtype_path.exists() else []
            records.extend(
                _parse_qa_file(_read_non_empty_lines(qa_path), qtypes, hop, split)
            )
    return graph, records