# OSINT/src/osint_env/data/metaqa.py
# Author: siddeshwar-kagatikar
# Commit d814291 — fix(rewards): never crash GRPO on malformed completions
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from pathlib import Path
import re
from typing import Iterable
from osint_env.domain.models import CanonicalGraph, Edge, Node, NodeType
# Matches the first [bracketed] span in a question; MetaQA marks the topic entity this way.
_TOPIC_PATTERN = re.compile(r"\[(.*?)\]")
@dataclass(slots=True)
class MetaQATaskRecord:
    """One MetaQA question/answer pair plus the bookkeeping fields the loader fills in."""

    # Raw question text; the topic entity is wrapped in [brackets].
    question: str
    # All acceptable answer strings, in file order.
    answers: list[str]
    # Convenience copy of answers[0].
    primary_answer: str
    # Canonical hop label: "1-hop", "2-hop", or "3-hop".
    hop_label: str
    # Numeric hop count parsed from hop_label.
    hop_count: int
    # Dataset split: "train", "dev", or "test".
    split: str
    # Question-type tag from the qa_{split}_qtype.txt sidecar ("" when that file is absent).
    qtype: str
    # Entity extracted from the [bracketed] span of the question ("" when none found).
    topic_entity: str
    # Edges supporting the answer; the loader leaves this empty (see infer_metaqa_support_edges).
    supporting_edges: list[Edge]
def _normalize_hop_label(value: str) -> str:
token = str(value or "").strip().lower().replace(" ", "")
if token in {"1", "1hop", "1-hop"}:
return "1-hop"
if token in {"2", "2hop", "2-hop"}:
return "2-hop"
if token in {"3", "3hop", "3-hop"}:
return "3-hop"
return ""
def _normalize_split(value: str) -> str:
token = str(value or "").strip().lower()
if token in {"train", "dev", "test"}:
return token
return ""
def _hop_count(label: str) -> int:
return int(label.split("-", 1)[0])
def _extract_topic_entity(question: str) -> str:
match = _TOPIC_PATTERN.search(str(question))
return match.group(1).strip() if match else ""
def _node_types_for_relation(rel: str) -> tuple[NodeType, NodeType]:
    """Map a MetaQA relation name to (source, destination) node types.

    Sources are always POST; the destination type depends on the relation,
    defaulting to USER for anything unrecognized.
    """
    relation = str(rel or "").strip().lower()
    destination_by_relation = {
        "directed_by": NodeType.USER,
        "written_by": NodeType.USER,
        "starred_actors": NodeType.USER,
        "release_year": NodeType.EVENT,
        "in_language": NodeType.LOCATION,
        "has_genre": NodeType.ORG,
        "has_tags": NodeType.ORG,
        "has_imdb_votes": NodeType.ORG,
    }
    return NodeType.POST, destination_by_relation.get(relation, NodeType.USER)
def _ensure_node(graph: CanonicalGraph, node_id: str, node_type: NodeType) -> None:
    """Insert a node for node_id unless the graph already has a non-None entry for it."""
    if graph.nodes.get(node_id) is not None:
        return
    graph.nodes[node_id] = Node(node_id=node_id, node_type=node_type, attrs={"name": node_id})
def _read_non_empty_lines(path: Path) -> list[str]:
return [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]
def _parse_kb_line(line: str) -> tuple[str, str, str] | None:
parts = [part.strip() for part in str(line).split("|", 2)]
if len(parts) != 3:
return None
src, rel, dst = parts
if not src or not rel or not dst:
return None
return src, rel, dst
def _undirected_adjacency(edges: Iterable[Edge]) -> dict[str, list[tuple[str, Edge]]]:
adj: dict[str, list[tuple[str, Edge]]] = {}
for edge in edges:
adj.setdefault(edge.src, []).append((edge.dst, edge))
adj.setdefault(edge.dst, []).append((edge.src, edge))
return adj
def _bfs_support_path(
topic_entity: str,
answer_candidates: list[str],
adjacency: dict[str, list[tuple[str, Edge]]],
max_depth: int,
) -> list[Edge]:
topic = str(topic_entity or "").strip()
if not topic or topic not in adjacency:
return []
answers = {item.strip() for item in answer_candidates if item.strip()}
if not answers:
return []
queue: deque[tuple[str, list[Edge]]] = deque([(topic, [])])
visited_depth: dict[str, int] = {topic: 0}
while queue:
node, path = queue.popleft()
depth = len(path)
if depth > max_depth:
continue
if node in answers and path:
return path
if depth == max_depth:
continue
for neighbor, edge in adjacency.get(node, []):
next_depth = depth + 1
best = visited_depth.get(neighbor)
if best is not None and best <= next_depth:
continue
visited_depth[neighbor] = next_depth
queue.append((neighbor, path + [edge]))
return []
def _infer_support_edges(
    topic_entity: str,
    answer_candidates: list[str],
    adjacency: dict[str, list[tuple[str, Edge]]],
    hop_count: int,
) -> list[Edge]:
    """Search for a support path, relaxing the depth cap when the exact hop count fails.

    Tries the nominal hop count first, then progressively looser caps
    (up to at least 4), returning the first non-empty path found.
    """
    for cap in (hop_count, hop_count + 1, hop_count + 2, max(4, hop_count + 3)):
        found = _bfs_support_path(
            topic_entity, answer_candidates, adjacency, max_depth=max(1, cap)
        )
        if found:
            return found
    return []
def infer_metaqa_support_edges(
    graph: CanonicalGraph,
    topic_entity: str,
    answer_candidates: list[str],
    hop_count: int,
) -> list[Edge]:
    """Public wrapper: build an undirected view of graph.edges and infer supporting edges."""
    return _infer_support_edges(
        topic_entity=topic_entity,
        answer_candidates=answer_candidates,
        adjacency=_undirected_adjacency(graph.edges),
        hop_count=hop_count,
    )
def _normalized_hop_labels(hops: list[str]) -> list[str]:
    """Normalize requested hop labels, defaulting to all three when none are valid."""
    labels = [label for label in (_normalize_hop_label(hop) for hop in hops) if label]
    return labels or ["1-hop", "2-hop", "3-hop"]


def _normalized_split_labels(splits: list[str]) -> list[str]:
    """Normalize requested splits, defaulting to train/dev/test when none are valid."""
    labels = [label for label in (_normalize_split(split) for split in splits) if label]
    return labels or ["train", "dev", "test"]


def _normalized_variant(variant: str) -> str:
    """Return "vanilla" or "ntm"; anything else falls back to "vanilla"."""
    token = str(variant or "vanilla").strip().lower()
    return token if token in {"vanilla", "ntm"} else "vanilla"


def _load_kb_graph(kb_file: Path) -> CanonicalGraph:
    """Parse "src|rel|dst" KB rows into a CanonicalGraph, de-duplicating triples."""
    graph = CanonicalGraph()
    seen_edges: set[tuple[str, str, str]] = set()
    for raw_line in _read_non_empty_lines(kb_file):
        row = _parse_kb_line(raw_line)
        if row is None or row in seen_edges:
            continue
        seen_edges.add(row)
        src, rel, dst = row
        src_type, dst_type = _node_types_for_relation(rel)
        _ensure_node(graph, src, src_type)
        _ensure_node(graph, dst, dst_type)
        graph.edges.append(Edge(src=src, rel=rel, dst=dst, confidence=1.0))
    return graph


def _qa_records_from_file(
    qa_path: Path,
    qtypes: list[str],
    hop: str,
    split: str,
) -> list[MetaQATaskRecord]:
    """Parse one qa_{split}.txt file ("question<TAB>ans1|ans2|...") into task records.

    Malformed rows (missing tab, empty question, no answers) are skipped.
    NOTE(review): qtypes are matched to QA rows by non-empty-line index; this
    assumes the qtype sidecar has the same blank-line layout as the QA file —
    confirm against the dataset files.
    """
    records: list[MetaQATaskRecord] = []
    for idx, row in enumerate(_read_non_empty_lines(qa_path)):
        parts = row.split("\t")
        if len(parts) < 2:
            continue
        question = parts[0].strip()
        answers = [item.strip() for item in parts[1].strip().split("|") if item.strip()]
        if not question or not answers:
            continue
        records.append(
            MetaQATaskRecord(
                question=question,
                answers=answers,
                primary_answer=answers[0],
                hop_label=hop,
                hop_count=_hop_count(hop),
                split=split,
                qtype=qtypes[idx] if idx < len(qtypes) else "",
                topic_entity=_extract_topic_entity(question),
                supporting_edges=[],
            )
        )
    return records


def load_metaqa_dataset(
    root: str | Path,
    kb_path: str | Path | None,
    variant: str,
    hops: list[str],
    splits: list[str],
) -> tuple[CanonicalGraph, list[MetaQATaskRecord]]:
    """Load the MetaQA KB graph and QA task records from a dataset directory.

    Args:
        root: Dataset root directory containing per-hop subdirectories.
        kb_path: Explicit KB triples file; defaults to root/kb.txt when falsy.
        variant: "vanilla" or "ntm" (invalid values fall back to "vanilla").
        hops: Hop selectors ("1", "2hop", "3-hop", ...); invalid entries are
            dropped, and an empty selection means all three hops.
        splits: Split selectors ("train"/"dev"/"test"); empty means all three.

    Returns:
        (graph, records): the de-duplicated KB graph and the parsed QA records.
        QA files that do not exist on disk are silently skipped.

    Raises:
        FileNotFoundError: when root or the KB file does not exist.
    """
    root_path = Path(root)
    if not root_path.exists():
        raise FileNotFoundError(f"MetaQA root not found: {root_path}")
    kb_file = Path(kb_path) if kb_path else root_path / "kb.txt"
    if not kb_file.exists():
        raise FileNotFoundError(f"MetaQA KB file not found: {kb_file}")

    graph = _load_kb_graph(kb_file)
    hop_labels = _normalized_hop_labels(hops)
    split_labels = _normalized_split_labels(splits)
    variant_token = _normalized_variant(variant)

    records: list[MetaQATaskRecord] = []
    for hop in hop_labels:
        hop_dir = root_path / hop
        for split in split_labels:
            qa_path = hop_dir / variant_token / f"qa_{split}.txt"
            if not qa_path.exists():
                continue
            # The qtype sidecar lives beside the hop directory, not inside the variant dir.
            qtype_path = hop_dir / f"qa_{split}_qtype.txt"
            qtypes = _read_non_empty_lines(qtype_path) if qtype_path.exists() else []
            records.extend(_qa_records_from_file(qa_path, qtypes, hop, split))
    return graph, records