from __future__ import annotations from concurrent.futures import ThreadPoolExecutor, as_completed import json from pathlib import Path import random import re from dataclasses import dataclass from typing import TYPE_CHECKING, Any from osint_env.data.metaqa import MetaQATaskRecord, infer_metaqa_support_edges, load_metaqa_dataset from osint_env.domain.models import ( CanonicalGraph, Edge, EnvironmentConfig, Node, NodeType, SeedEdgeSpec, SeedQuestionSpec, TaskInstance, ) if TYPE_CHECKING: from osint_env.llm.interface import LLMClient @dataclass(slots=True) class PlatformViews: microblog_posts: list[dict] forum_threads: list[dict] profiles: list[dict] alias_lookup: dict[str, str] def _edge_payload(edge: Edge) -> dict[str, Any]: return { "src": edge.src, "rel": edge.rel, "dst": edge.dst, "confidence": float(edge.confidence), } def _normalize_swarm_v2_path_edges(path: list[Edge | dict[str, Any]]) -> list[Edge]: out: list[Edge] = [] for row in path: if isinstance(row, Edge): out.append(Edge(row.src, row.rel, row.dst, float(row.confidence))) continue if not isinstance(row, dict): return [] src = str(row.get("src", "")).strip() rel = str(row.get("rel", "")).strip() dst = str(row.get("dst", "")).strip() if not src or not rel or not dst: return [] try: confidence = float(row.get("confidence", 1.0)) except (TypeError, ValueError): confidence = 1.0 out.append(Edge(src=src, rel=rel, dst=dst, confidence=confidence)) return out def enumerate_swarm_v2_neighbors(graph: CanonicalGraph, node_id: str) -> list[Edge]: edges = [edge for edge in graph.edges if edge.src == node_id] edges.sort(key=lambda edge: (edge.src, edge.rel, edge.dst)) return [Edge(edge.src, edge.rel, edge.dst, float(edge.confidence)) for edge in edges] def trace_swarm_v2_path(graph: CanonicalGraph, path: list[Edge | dict[str, Any]]) -> list[Edge]: edges = _normalize_swarm_v2_path_edges(path) if not edges: return [] graph_edges = {(edge.src, edge.rel, edge.dst) for edge in graph.edges} for idx, edge in enumerate(edges): if (edge.src, edge.rel, edge.dst) not in graph_edges: return [] if idx > 0 and edges[idx - 1].dst != edge.src: return [] return edges def select_swarm_v2_answer(path_edges: list[Edge]) -> str: if not path_edges: return "" return path_edges[-1].dst def emit_swarm_v2_question(path_edges: list[Edge]) -> str: if not path_edges: return "" start = path_edges[0].src relation_path = " -> ".join(edge.rel for edge in path_edges) hops = len(path_edges) return ( f"If you start at {start} and follow the relation path {relation_path}, " f"which entity do you reach after {hops} hops?" ) def build_swarm_v2_tool_trace(graph: CanonicalGraph, path_edges: list[Edge]) -> list[dict[str, Any]]: traced = trace_swarm_v2_path(graph, path_edges) if not traced: return [] tool_trace: list[dict[str, Any]] = [] for idx, edge in enumerate(traced): neighbors = enumerate_swarm_v2_neighbors(graph, edge.src) tool_trace.append( { "tool_name": "enumerate_neighbors", "args": { "node_id": edge.src, "hop_index": idx, "expected_edge": _edge_payload(edge), }, "output": { "neighbors": [_edge_payload(candidate) for candidate in neighbors], }, } ) tool_trace.append( { "tool_name": "trace_path", "args": { "path": [_edge_payload(edge) for edge in traced], }, "output": { "path": [_edge_payload(edge) for edge in traced], }, } ) answer = select_swarm_v2_answer(traced) tool_trace.append( { "tool_name": "select_answer", "args": { "strategy": "path_dst", }, "output": { "answer": answer, }, } ) question = emit_swarm_v2_question(traced) tool_trace.append( { "tool_name": "emit_question", "args": { "style": "relation_path_v1", }, "output": { "question": question, }, } ) return tool_trace def build_swarm_v2_canonical_subgraph( graph: CanonicalGraph, path_edges: list[Edge], max_extra_edges: int = 4, ) -> dict[str, Any]: traced = trace_swarm_v2_path(graph, path_edges) if not traced: return {"nodes": [], "edges": [], "path": []} path_nodes = {traced[0].src} for edge in traced: path_nodes.add(edge.src) path_nodes.add(edge.dst) path_keys = {(edge.src, edge.rel, edge.dst) for edge in traced} extra_edges: list[Edge] = [] for edge in graph.edges: key = (edge.src, edge.rel, edge.dst) if key in path_keys: continue if edge.src in path_nodes or edge.dst in path_nodes: extra_edges.append(Edge(edge.src, edge.rel, edge.dst, float(edge.confidence))) if len(extra_edges) >= max(0, int(max_extra_edges)): break subgraph_edges = list(traced) + extra_edges subgraph_nodes = sorted({edge.src for edge in subgraph_edges} | {edge.dst for edge in subgraph_edges}) return { "nodes": subgraph_nodes, "edges": [_edge_payload(edge) for edge in subgraph_edges], "path": [_edge_payload(edge) for edge in traced], "answer": select_swarm_v2_answer(traced), } def build_swarm_v2_path_candidates( graph: CanonicalGraph, rng: random.Random, count: int, min_hops: int = 2, max_hops: int = 4, ) -> list[list[Edge]]: if count <= 0: return [] outgoing: dict[str, list[Edge]] = {} for edge in graph.edges: outgoing.setdefault(edge.src, []).append(edge) def path_match_count(path: list[Edge], limit: int = 4) -> int: if not path: return 0 relations = [edge.rel for edge in path] answer = path[-1].dst start = path[0].src match_count = 0 stack: list[tuple[str, int, tuple[str, ...]]] = [(start, 0, (start,))] while stack: node_id, rel_idx, seen_nodes = stack.pop() if rel_idx >= len(relations): if node_id == answer: match_count += 1 if match_count >= limit: return match_count continue relation = relations[rel_idx] for edge in outgoing.get(node_id, []): if edge.rel != relation: continue if edge.dst in seen_nodes: continue stack.append((edge.dst, rel_idx + 1, seen_nodes + (edge.dst,))) return match_count starts = [node_id for node_id, edges in outgoing.items() if edges] if not starts: return [] seen: set[tuple[tuple[str, str, str], ...]] = set() candidates: list[list[Edge]] = [] attempt_budget = max(16, count * 20) lower_hops = max(1, int(min_hops)) upper_hops = max(lower_hops, int(max_hops)) for _ in range(attempt_budget): if len(candidates) >= count: break current = rng.choice(starts) target_hops = rng.randint(lower_hops, upper_hops) path: list[Edge] = [] visited_nodes = {current} for _hop in range(target_hops): options = [edge for edge in outgoing.get(current, []) if edge.dst not in visited_nodes] if not options: break edge = rng.choice(options) path.append(Edge(edge.src, edge.rel, edge.dst, float(edge.confidence))) current = edge.dst visited_nodes.add(current) if len(path) < lower_hops: continue if path_match_count(path) != 1: continue key = tuple((edge.src, edge.rel, edge.dst) for edge in path) if key in seen: continue seen.add(key) candidates.append(path) if candidates: return candidates[:count] # Fall back to unique 1-hop paths only when the graph is too shallow for multi-hop traces. for edge in graph.edges: key = ((edge.src, edge.rel, edge.dst),) if key in seen: continue if path_match_count([edge]) != 1: continue seen.add(key) candidates.append([Edge(edge.src, edge.rel, edge.dst, float(edge.confidence))]) if len(candidates) >= count: break return candidates[:count] class DatasetGenerator: def __init__(self, config: EnvironmentConfig, llm: LLMClient | None = None): self.config = config self.rng = random.Random(config.seed) self.llm = llm self._metaqa_records: list[MetaQATaskRecord] = [] @staticmethod def _edge_key(edge: Edge) -> tuple[str, str, str]: return (edge.src, edge.rel, edge.dst) def _dataset_mode(self) -> str: token = str(getattr(self.config, "dataset_mode", "canonical") or "canonical").strip().lower() return "metaqa" if token == "metaqa" else "canonical" @staticmethod def _metaqa_difficulty(hop_label: str) -> str: hop = str(hop_label).strip().lower() if hop == "1-hop": return "easy" if hop == "2-hop": return "medium" return "hard" @staticmethod def _infer_node_type(node_id: str) -> NodeType: prefix = str(node_id).split("_", 1)[0].lower() mapping = { "user": NodeType.USER, "alias": NodeType.ALIAS, "org": NodeType.ORG, "loc": NodeType.LOCATION, "location": NodeType.LOCATION, "post": NodeType.POST, "thr": NodeType.THREAD, "thread": NodeType.THREAD, "event": NodeType.EVENT, } return mapping.get(prefix, NodeType.USER) def _ensure_node(self, graph: CanonicalGraph, node_id: str) -> None: if node_id in graph.nodes: return node_type = self._infer_node_type(node_id) attrs: dict[str, Any] = {} if node_type == NodeType.USER: attrs = {"name": node_id, "org": "Unknown", "location": "Unknown"} if node_type == NodeType.ALIAS: attrs = {"handle": f"@{node_id}"} graph.nodes[node_id] = Node(node_id=node_id, node_type=node_type, attrs=attrs) def _add_edge_if_missing(self, graph: CanonicalGraph, edge: Edge) -> None: key = self._edge_key(edge) if any(self._edge_key(existing) == key for existing in graph.edges): return self._ensure_node(graph, edge.src) self._ensure_node(graph, edge.dst) graph.edges.append(edge) @staticmethod def _extract_json_blob(text: str) -> Any: text = str(text).strip() if not text: return None for start, end in (("{", "}"), ("[", "]")): left = text.find(start) right = text.rfind(end) if left >= 0 and right > left: snippet = text[left : right + 1] try: return json.loads(snippet) except json.JSONDecodeError: continue return None def _apply_seed_nodes(self, graph: CanonicalGraph) -> None: for node_spec in self.config.seeding.seeded_nodes: node_type = ( node_spec.node_type if isinstance(node_spec.node_type, NodeType) else self._infer_node_type(node_spec.node_id) ) existing = graph.nodes.get(node_spec.node_id) attrs = dict(existing.attrs) if existing else {} attrs.update(node_spec.attrs) graph.nodes[node_spec.node_id] = Node(node_spec.node_id, node_type, attrs) def _apply_seed_edges(self, graph: CanonicalGraph) -> None: for edge_spec in self.config.seeding.seeded_edges: self._add_edge_if_missing( graph, Edge( src=edge_spec.src, rel=edge_spec.rel, dst=edge_spec.dst, confidence=float(edge_spec.confidence), ), ) @staticmethod def _normalize_edge_candidates(value: Any) -> list[SeedEdgeSpec]: items: list[SeedEdgeSpec] = [] if not isinstance(value, list): return items for row in value: if not isinstance(row, dict): continue src = str(row.get("src", "")).strip() rel = str(row.get("rel", "")).strip() dst = str(row.get("dst", "")).strip() if not src or not rel or not dst: continue try: confidence = float(row.get("confidence", 1.0)) except (TypeError, ValueError): confidence = 1.0 items.append(SeedEdgeSpec(src=src, rel=rel, dst=dst, confidence=confidence)) return items @staticmethod def _split_budget(total: int, parts: int) -> list[int]: if total <= 0: return [] slots = max(1, parts) base = total // slots remainder = total % slots chunks = [base + (1 if i < remainder else 0) for i in range(slots)] return [chunk for chunk in chunks if chunk > 0] @staticmethod def _shared_context_blob(graph: CanonicalGraph, node_limit: int = 100, edge_limit: int = 80) -> str: payload = { "known_nodes": sorted(graph.nodes.keys())[:node_limit], "known_edges": [ {"src": edge.src, "rel": edge.rel, "dst": edge.dst} for edge in graph.edges[: min(edge_limit, len(graph.edges))] ], } return json.dumps(payload) def _llm_generate_json_with_retry(self, prompt: str) -> Any: if self.llm is None: return None attempts = max(1, int(self.config.seeding.llm_generation_retries)) for _ in range(attempts): try: response = self.llm.generate([{"role": "system", "content": prompt}], tools=[]) except Exception: continue parsed = self._extract_json_blob(response.content) if parsed is not None: return parsed return None def _run_generation_workers(self, prompts: list[str]) -> list[Any]: if not prompts: return [] max_workers = max(1, min(self.config.seeding.llm_generation_workers, len(prompts))) if not self.config.seeding.llm_generation_parallel or max_workers == 1: output: list[Any] = [] for prompt in prompts: parsed = self._llm_generate_json_with_retry(prompt) if parsed is not None: output.append(parsed) return output output = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(self._llm_generate_json_with_retry, prompt) for prompt in prompts] for future in as_completed(futures): try: parsed = future.result() except Exception: parsed = None if parsed is not None: output.append(parsed) return output def _template_fallback_allowed(self) -> bool: if self.llm is None: return True return bool(self.config.seeding.allow_template_fallback_on_llm_failure) def _template_generated_edges(self, graph: CanonicalGraph, budget: int) -> list[Edge]: if budget <= 0: return [] users = [n.node_id for n in graph.nodes.values() if n.node_type == NodeType.USER] aliases = [n.node_id for n in graph.nodes.values() if n.node_type == NodeType.ALIAS] if len(users) < 2: return [] generated: list[Edge] = [] rels = ["connected_to", "mentions", "co_occurs_with"] for _ in range(budget * 3): if len(generated) >= budget: break roll = self.rng.random() if aliases and roll < 0.2: src = self.rng.choice(aliases) dst = self.rng.choice(users) rel = "alias_of" elif roll < 0.75: src, dst = self.rng.sample(users, 2) rel = self.rng.choice(rels) else: src = self.rng.choice(users) dst = self.rng.choice([u for u in users if u != src]) rel = "connected_to" generated.append(Edge(src=src, rel=rel, dst=dst, confidence=0.7)) return generated[:budget] def _llm_expand_graph(self, graph: CanonicalGraph, budget: int) -> list[Edge]: if budget <= 0: return [] if self.llm is None: return self._template_generated_edges(graph, budget) shared_context = self._shared_context_blob(graph) workers = max(1, min(self.config.seeding.llm_generation_workers, budget)) chunks = self._split_budget(budget, workers) focus_tracks = ["entity_linking", "network_expansion", "org_location", "event_trace"] prompts: list[str] = [] for idx, chunk_budget in enumerate(chunks): focus = focus_tracks[idx % len(focus_tracks)] prompts.append( ( "SEED_GRAPH_EXPANSION_AGENT\n" "SHARED_CONTEXT\n" f"{shared_context}\n" f"worker_id: {idx}\n" f"focus: {focus}\n" f"budget: {chunk_budget}\n" "Generate plausible graph edges for OSINT retrieval.\n" "Return STRICT JSON object: {\"edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}.\n" "Prefer known nodes from SHARED_CONTEXT and avoid duplicates." ) ) generated: list[Edge] = [] seen: set[tuple[str, str, str]] = set() for payload in self._run_generation_workers(prompts): raw_edges: Any = None if isinstance(payload, dict): raw_edges = payload.get("edges") elif isinstance(payload, list): raw_edges = payload for edge_spec in self._normalize_edge_candidates(raw_edges): key = (edge_spec.src, edge_spec.rel, edge_spec.dst) if key in seen: continue seen.add(key) generated.append(Edge(edge_spec.src, edge_spec.rel, edge_spec.dst, float(edge_spec.confidence))) if len(generated) >= budget: break if len(generated) >= budget: break if len(generated) < budget: residual = budget - len(generated) residual_prompt = ( "SEED_GRAPH_EXPANSION_AGENT\n" "SHARED_CONTEXT\n" f"{shared_context}\n" f"budget: {residual}\n" "Generate any remaining high-utility edges.\n" "Return STRICT JSON object: {\"edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}." ) payload = self._llm_generate_json_with_retry(residual_prompt) raw_edges: Any = payload.get("edges") if isinstance(payload, dict) else payload for edge_spec in self._normalize_edge_candidates(raw_edges): key = (edge_spec.src, edge_spec.rel, edge_spec.dst) if key in seen: continue seen.add(key) generated.append(Edge(edge_spec.src, edge_spec.rel, edge_spec.dst, float(edge_spec.confidence))) if len(generated) >= budget: break if len(generated) < budget and self._template_fallback_allowed(): for edge in self._template_generated_edges(graph, budget - len(generated)): key = (edge.src, edge.rel, edge.dst) if key in seen: continue seen.add(key) generated.append(edge) if len(generated) >= budget: break return generated[:budget] @staticmethod def _extract_entity_tokens(question: str) -> list[str]: return re.findall(r"\b(?:alias|user|org|loc|post|thr|thread|event)_[a-zA-Z0-9_]+\b", question) @staticmethod def _normalize_difficulty(value: str, index: int) -> str: token = str(value or "").strip().lower() if token in {"easy", "e"}: return "easy" if token in {"mid", "medium", "m"}: return "medium" if token in {"high", "hard", "h"}: return "hard" if index < 10: return "easy" if index < 20: return "medium" return "hard" @staticmethod def _task_type_for_difficulty(base_task_type: str, difficulty: str) -> str: token = str(base_task_type or "").strip().lower() if token and token != "fixed_trace": return token if difficulty == "easy": return "easy_trace" if difficulty == "medium": return "medium_trace" return "hard_trace" @staticmethod def _grader_for_difficulty(difficulty: str) -> dict[str, Any]: return { "type": "difficulty_exact_match", "answer_type": "node_id", "case_sensitive": True, "reward_profile": difficulty, "logic": { "easy": "single_agent_simplified", "medium": "reduced_components", "hard": "full_reward", }.get(difficulty, "full_reward"), } def _task_metadata(self, index: int, base_task_type: str, metadata: dict[str, Any] | None = None) -> dict[str, Any]: out = dict(metadata or {}) difficulty = self._normalize_difficulty(out.get("difficulty", ""), index) out["difficulty"] = difficulty out.setdefault("grader", self._grader_for_difficulty(difficulty)) out.setdefault("scenario", self._task_type_for_difficulty(base_task_type, difficulty)) return out def _infer_answer_from_question(self, question: str, graph: CanonicalGraph) -> str: entities = self._extract_entity_tokens(question) question_l = question.lower() alias_tokens = [token for token in entities if token.startswith("alias_")] if alias_tokens: alias = alias_tokens[0] for edge in graph.edges: if edge.rel == "alias_of" and edge.src == alias: return edge.dst if "connected" in question_l: user_tokens = [token for token in entities if token.startswith("user_")] if user_tokens: source = user_tokens[0] for edge in graph.edges: if edge.rel == "connected_to" and edge.src == source: return edge.dst if "works at" in question_l: for edge in graph.edges: if edge.rel != "works_at": continue org = graph.nodes.get(edge.dst) org_name = str((org.attrs or {}).get("name", "")).lower() if org else "" if org_name and org_name in question_l: return edge.src return entities[0] if entities else "unknown" def _infer_support_edges(self, question: str, answer: str, graph: CanonicalGraph) -> list[Edge]: if answer: for edge in graph.edges: if edge.dst == answer or edge.src == answer: if edge.src in question or edge.dst in question or edge.rel in question.lower(): return [edge] entities = self._extract_entity_tokens(question) for edge in graph.edges: if edge.src in entities or edge.dst in entities: return [edge] return [] def _seeded_tasks(self, graph: CanonicalGraph) -> list[TaskInstance]: tasks: list[TaskInstance] = [] for idx, question_spec in enumerate(self.config.seeding.seeded_questions): answer = question_spec.answer or self._infer_answer_from_question(question_spec.question, graph) metadata = self._task_metadata(idx, question_spec.task_type, dict(question_spec.metadata)) difficulty = str(metadata.get("difficulty", "hard")) if question_spec.supporting_edges: support = [ Edge(src=e.src, rel=e.rel, dst=e.dst, confidence=float(e.confidence)) for e in question_spec.supporting_edges ] else: support = self._infer_support_edges(question_spec.question, answer, graph) tasks.append( TaskInstance( task_id=f"seed_task_{idx}", task_type=self._task_type_for_difficulty(question_spec.task_type, difficulty), question=question_spec.question, answer=answer, supporting_edges=support, metadata=metadata, ) ) return tasks def _template_tasks(self, graph: CanonicalGraph, count: int, start_idx: int = 0) -> list[TaskInstance]: alias_edges = [e for e in graph.edges if e.rel == "alias_of"] conn_edges = [e for e in graph.edges if e.rel == "connected_to"] work_edges = [e for e in graph.edges if e.rel == "works_at"] tasks: list[TaskInstance] = [] for i in range(count): mode = self.rng.choice(["identity_resolution", "network_discovery", "event_tracing"]) if mode == "identity_resolution" and alias_edges: edge = self.rng.choice(alias_edges) q = f"Which canonical user owns alias {edge.src}?" a = edge.dst support = [edge] elif mode == "network_discovery" and conn_edges: edge = self.rng.choice(conn_edges) q = f"Who is connected to {edge.src}?" a = edge.dst support = [edge] else: edge = self.rng.choice(work_edges) org_node = graph.nodes.get(edge.dst) org_name = (org_node.attrs or {}).get("name", edge.dst) if org_node else edge.dst q = f"Which user works at {org_name}?" a = edge.src support = [edge] tasks.append( TaskInstance( task_id=f"task_{start_idx + i}", task_type=mode, question=q, answer=a, supporting_edges=support, metadata=self._task_metadata(start_idx + i, mode), ) ) return tasks def _llm_generated_tasks(self, graph: CanonicalGraph, count: int, start_idx: int) -> list[TaskInstance]: if count <= 0: return [] if self.llm is None: return self._template_tasks(graph, count=count, start_idx=start_idx) candidate_edges = [ {"src": edge.src, "rel": edge.rel, "dst": edge.dst} for edge in graph.edges if edge.rel in {"alias_of", "connected_to", "works_at"} ][:60] shared_context = json.dumps( { "known_nodes": sorted(graph.nodes.keys())[:100], "edge_sample": candidate_edges, } ) workers = max(1, min(self.config.seeding.llm_generation_workers, count)) chunks = self._split_budget(count, workers) focus_tracks = ["identity_resolution", "network_discovery", "event_tracing", "deanonymization"] prompts: list[str] = [] for idx, chunk_budget in enumerate(chunks): focus = focus_tracks[idx % len(focus_tracks)] prompts.append( ( "SEED_TASK_EXPANSION_AGENT\n" "SHARED_CONTEXT\n" f"{shared_context}\n" f"worker_id: {idx}\n" f"focus: {focus}\n" f"task_budget: {chunk_budget}\n" "Generate OSINT QA tasks with answers and support edges.\n" "Return STRICT JSON object: {\"tasks\": [{\"task_type\": str, \"question\": str, \"answer\": str, \"supporting_edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}]}." ) ) llm_tasks: list[TaskInstance] = [] seen_questions: set[str] = set() for payload in self._run_generation_workers(prompts): raw_tasks: Any = None if isinstance(payload, dict): raw_tasks = payload.get("tasks") elif isinstance(payload, list): raw_tasks = payload if not isinstance(raw_tasks, list): continue for row in raw_tasks: if not isinstance(row, dict): continue question = str(row.get("question", "")).strip() if not question: continue key = question.lower() if key in seen_questions: continue seen_questions.add(key) answer = str(row.get("answer", "")).strip() or self._infer_answer_from_question(question, graph) task_type = str(row.get("task_type", "llm_generated")).strip() or "llm_generated" support_specs = self._normalize_edge_candidates(row.get("supporting_edges")) if support_specs: support = [Edge(e.src, e.rel, e.dst, e.confidence) for e in support_specs] else: support = self._infer_support_edges(question, answer, graph) llm_tasks.append( TaskInstance( task_id=f"task_{start_idx + len(llm_tasks)}", task_type=task_type, question=question, answer=answer, supporting_edges=support, metadata=self._task_metadata( start_idx + len(llm_tasks), task_type, {"generated_by": "llm", "shared_context": True}, ), ) ) if len(llm_tasks) >= count: break if len(llm_tasks) >= count: break if len(llm_tasks) < count: residual = count - len(llm_tasks) residual_prompt = ( "SEED_TASK_EXPANSION_AGENT\n" "SHARED_CONTEXT\n" f"{shared_context}\n" f"task_budget: {residual}\n" "Generate additional tasks not already present in SHARED_CONTEXT.\n" "Return STRICT JSON object: {\"tasks\": [{\"task_type\": str, \"question\": str, \"answer\": str, \"supporting_edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}]}." ) payload = self._llm_generate_json_with_retry(residual_prompt) raw_tasks: Any = payload.get("tasks") if isinstance(payload, dict) else payload if isinstance(raw_tasks, list): for row in raw_tasks: if not isinstance(row, dict): continue question = str(row.get("question", "")).strip() if not question: continue key = question.lower() if key in seen_questions: continue seen_questions.add(key) answer = str(row.get("answer", "")).strip() or self._infer_answer_from_question(question, graph) task_type = str(row.get("task_type", "llm_generated")).strip() or "llm_generated" support_specs = self._normalize_edge_candidates(row.get("supporting_edges")) if support_specs: support = [Edge(e.src, e.rel, e.dst, e.confidence) for e in support_specs] else: support = self._infer_support_edges(question, answer, graph) llm_tasks.append( TaskInstance( task_id=f"task_{start_idx + len(llm_tasks)}", task_type=task_type, question=question, answer=answer, supporting_edges=support, metadata=self._task_metadata( start_idx + len(llm_tasks), task_type, {"generated_by": "llm", "shared_context": True}, ), ) ) if len(llm_tasks) >= count: break if len(llm_tasks) < count and self._template_fallback_allowed(): llm_tasks.extend( self._template_tasks( graph, count=count - len(llm_tasks), start_idx=start_idx + len(llm_tasks), ) ) return llm_tasks[:count] def _metaqa_selected_records(self, count: int) -> list[MetaQATaskRecord]: records = list(self._metaqa_records) if not records: return [] if count <= 0 or len(records) <= count: return records grouped: dict[str, list[MetaQATaskRecord]] = {} for record in records: grouped.setdefault(record.hop_label, []).append(record) hop_keys = sorted(grouped.keys()) if not hop_keys: return records[:count] selected: list[MetaQATaskRecord] = [] leftovers: list[MetaQATaskRecord] = [] per_hop = max(1, count // len(hop_keys)) for hop in hop_keys: bucket = list(grouped[hop]) self.rng.shuffle(bucket) take = min(len(bucket), per_hop) selected.extend(bucket[:take]) leftovers.extend(bucket[take:]) if len(selected) < count: self.rng.shuffle(leftovers) selected.extend(leftovers[: count - len(selected)]) return selected[:count] def _metaqa_tasks(self, graph: CanonicalGraph, count: int) -> list[TaskInstance]: records = self._metaqa_selected_records(count) tasks: list[TaskInstance] = [] for idx, record in enumerate(records): difficulty = self._metaqa_difficulty(record.hop_label) support_edges = list(record.supporting_edges) if not support_edges: support_edges = infer_metaqa_support_edges( graph=graph, topic_entity=record.topic_entity, answer_candidates=record.answers, hop_count=record.hop_count, ) metadata = { "difficulty": difficulty, "hop": record.hop_label, "split": record.split, "source": "metaqa", "dataset_mode": "metaqa", "qtype": record.qtype, "topic_entity": record.topic_entity, "all_answers": list(record.answers), "grader": { "type": "metaqa_exact_match", "answer_type": "entity_name", "case_sensitive": True, "reward_profile": difficulty, "logic": "hop_trace", }, "scenario": f"metaqa_{record.hop_label}", } task_type = f"metaqa_{record.hop_label}" tasks.append( TaskInstance( task_id=f"metaqa_{record.hop_label}_{record.split}_{idx}", task_type=task_type, question=record.question, answer=record.primary_answer, supporting_edges=support_edges, metadata=metadata, ) ) return tasks def _build_platform_views_metaqa(self, graph: CanonicalGraph) -> PlatformViews: node_names = { node_id: str((node.attrs or {}).get("name") or node_id) for node_id, node in graph.nodes.items() } microblog_posts: list[dict] = [] for idx, edge in enumerate(graph.edges): microblog_posts.append( { "post_id": f"post_metaqa_{idx}", "user_id": edge.src, "canonical_user": edge.src, "text": f"{edge.src} {edge.rel} {edge.dst}", "references": [edge.src, edge.dst], "reference_names": [node_names.get(edge.src, edge.src), node_names.get(edge.dst, edge.dst)], "mentions": [edge.dst], "timestamp": 100000 + idx, } ) relation_groups: dict[str, list[Edge]] = {} for edge in graph.edges: relation_groups.setdefault(edge.rel, []).append(edge) forum_threads: list[dict] = [] for idx, rel in enumerate(sorted(relation_groups.keys())[:200]): group = relation_groups.get(rel, [])[:10] forum_threads.append( { "thread_id": f"thr_metaqa_{idx}", "topic": rel, "author_id": group[0].src if group else "metaqa", "comments": [ { "user_id": edge.src, "text": f"{edge.src} {edge.rel} {edge.dst}", } for edge in group ], "references": [edge.dst for edge in group], "discusses": [edge.dst for edge in group], } ) neighbors: dict[str, set[str]] = {} for edge in graph.edges: neighbors.setdefault(edge.src, set()).add(edge.dst) neighbors.setdefault(edge.dst, set()).add(edge.src) profiles: list[dict] = [] for node_id in sorted(graph.nodes.keys()): node = graph.nodes[node_id] profiles.append( { "user_id": node_id, "name": str((node.attrs or {}).get("name") or node_id), "org": str(node.node_type.value), "org_id": str(node.node_type.value), "location": "metaqa", "location_id": "metaqa", "alias_ids": [], "connections": sorted(neighbors.get(node_id, set()))[:8], "work_history": [str(node.node_type.value)], } ) return PlatformViews( microblog_posts=microblog_posts, forum_threads=forum_threads, profiles=profiles, alias_lookup={}, ) def build_canonical_graph(self) -> CanonicalGraph: if self._dataset_mode() == "metaqa": root = Path(self.config.metaqa_root) kb_path = Path(self.config.metaqa_kb_path) if str(self.config.metaqa_kb_path).strip() else None graph, records = load_metaqa_dataset( root=root, kb_path=kb_path, variant=self.config.metaqa_variant, hops=list(self.config.metaqa_hops), splits=list(self.config.metaqa_splits), ) self._metaqa_records = records self._apply_seed_nodes(graph) self._apply_seed_edges(graph) return graph graph = CanonicalGraph() orgs = ["Apex Dynamics", "Helios Labs", "Northbridge"] locations = ["Bengaluru", "Pune", "Hyderabad", "Delhi"] for i in range(self.config.n_users): uid = f"user_{i}" org = self.rng.choice(orgs) loc = self.rng.choice(locations) graph.nodes[uid] = Node(uid, NodeType.USER, {"name": f"Person {i}", "org": org, "location": loc}) org_id = f"org_{org.lower().replace(' ', '_')}" loc_id = f"loc_{loc.lower()}" graph.nodes.setdefault(org_id, Node(org_id, NodeType.ORG, {"name": org})) graph.nodes.setdefault(loc_id, Node(loc_id, NodeType.LOCATION, {"name": loc})) graph.edges.append(Edge(uid, "works_at", org_id)) graph.edges.append(Edge(uid, "located_in", loc_id)) if self.rng.random() < self.config.alias_density: alias = f"alias_{i}_{self.rng.randint(100,999)}" graph.nodes[alias] = Node(alias, NodeType.ALIAS, {"handle": f"@{alias}"}) graph.edges.append(Edge(alias, "alias_of", uid)) users = [n for n in graph.nodes.values() if n.node_type == NodeType.USER] for _ in range(max(1, self.config.n_users // 2)): a, b = self.rng.sample(users, 2) graph.edges.append(Edge(a.node_id, "connected_to", b.node_id, confidence=0.8)) self._apply_seed_nodes(graph) self._apply_seed_edges(graph) if self.config.seeding.llm_generate_remaining_graph: llm_edges = self._llm_expand_graph(graph, self.config.seeding.llm_generated_edge_budget) for edge in llm_edges: self._add_edge_if_missing(graph, edge) return graph def build_platform_views(self, graph: CanonicalGraph) -> PlatformViews: if self._dataset_mode() == "metaqa": return self._build_platform_views_metaqa(graph) users = [n for n in graph.nodes.values() if n.node_type == NodeType.USER] aliases = [n for n in graph.nodes.values() if n.node_type == NodeType.ALIAS] alias_owner = {e.src: e.dst for e in graph.edges if e.rel == "alias_of"} user_aliases: dict[str, list[str]] = {} for alias_id, user_id in alias_owner.items(): user_aliases.setdefault(user_id, []).append(alias_id) node_names = { node_id: str((node.attrs or {}).get("name") or (node.attrs or {}).get("handle") or node_id) for node_id, node in graph.nodes.items() } microblog_posts: list[dict] = [] for i, user in enumerate(users): poster = user.node_id if aliases and self.rng.random() < 0.45: candidate = self.rng.choice(aliases).node_id poster = candidate text = f"Update {i} from {user.attrs['org']} #{user.attrs['location'].lower()}" if self.rng.random() < self.config.noise_level: text = f"Rumor: {text} maybe fake" microblog_posts.append( { "post_id": f"post_{i}", "user_id": poster, "canonical_user": alias_owner.get(poster, user.node_id), "text": text, "references": [], "reference_names": [], "mentions": [f"user_{self.rng.randint(0, self.config.n_users - 1)}"], "timestamp": 1000 + i, } ) authored_posts: dict[str, str] = {} post_references: dict[str, list[str]] = {} for edge in graph.edges: if edge.rel == "authored_post": authored_posts[edge.dst] = edge.src elif edge.rel == "references" and edge.src.startswith("post_"): post_references.setdefault(edge.src, []).append(edge.dst) for post_id, author_id in authored_posts.items(): refs = post_references.get(post_id, []) ref_names = [node_names.get(ref, ref) for ref in refs] author_label = node_names.get(author_id, author_id) text_parts = [f"{post_id} update from {author_label}"] if ref_names: text_parts.append("references " + ", ".join(ref_names)) if refs: text_parts.append("ids " + ", ".join(refs)) post_payload = { "post_id": post_id, "user_id": author_id, "canonical_user": alias_owner.get(author_id, author_id), "text": ". ".join(text_parts), "references": refs, "reference_names": ref_names, "mentions": [], "timestamp": 5000 + len(microblog_posts), } existing_idx = next((idx for idx, row in enumerate(microblog_posts) if row["post_id"] == post_id), None) if existing_idx is None: microblog_posts.append(post_payload) else: microblog_posts[existing_idx] = post_payload forum_threads: list[dict] = [] for i in range(max(8, self.config.n_users // 3)): author = self.rng.choice(users).node_id forum_threads.append( { "thread_id": f"thr_{i}", "topic": self.rng.choice(["security", "startup", "ai", "infra"]), "author_id": author, "comments": [ {"user_id": self.rng.choice(users).node_id, "text": "Following this."}, {"user_id": self.rng.choice(users).node_id, "text": "Interesting link."}, ], "references": [], "discusses": [], } ) authored_threads: dict[str, str] = {} thread_refs: dict[str, list[str]] = {} thread_discusses: dict[str, list[str]] = {} for edge in graph.edges: if edge.rel == "authored_thread": authored_threads[edge.dst] = edge.src elif edge.rel == "references" and edge.src.startswith(("thr_", "thread_")): thread_refs.setdefault(edge.src, []).append(edge.dst) elif edge.rel == "discusses" and edge.src.startswith(("thr_", "thread_")): thread_discusses.setdefault(edge.src, []).append(edge.dst) for thread_id, author_id in authored_threads.items(): node = graph.nodes.get(thread_id) refs = thread_refs.get(thread_id, []) discussed = thread_discusses.get(thread_id, []) comments = [] for ref in refs: comments.append({"user_id": author_id, "text": f"Reference: {node_names.get(ref, ref)} ({ref})"}) for item in discussed: comments.append({"user_id": author_id, "text": f"Discusses: {node_names.get(item, item)} ({item})"}) thread_payload = { "thread_id": thread_id, "topic": str((node.attrs or {}).get("topic", "seeded")) if node else "seeded", "author_id": author_id, "title": node_names.get(thread_id, thread_id), "comments": comments, "references": refs, "discusses": discussed, } existing_idx = next((idx for idx, row in enumerate(forum_threads) if row["thread_id"] == thread_id), None) if existing_idx is None: forum_threads.append(thread_payload) else: forum_threads[existing_idx] = thread_payload profiles: list[dict] = [] for user in users: conns = [e.dst for e in graph.edges if e.src == user.node_id and e.rel == "connected_to"][:5] org_id = next((e.dst for e in graph.edges if e.src == user.node_id and e.rel == "works_at"), "") location_id = next((e.dst for e in graph.edges if e.src == user.node_id and e.rel == "located_in"), "") profiles.append( { "user_id": user.node_id, "name": user.attrs["name"], "org": user.attrs["org"], "org_id": org_id, "location": user.attrs["location"], "location_id": location_id, "alias_ids": sorted(user_aliases.get(user.node_id, [])), "connections": conns, "work_history": [user.attrs["org"]], } ) for i in range(int(len(users) * self.config.red_herring_rate)): profiles.append( { "user_id": f"noise_{i}", "name": f"P{self.rng.randint(100,999)}", "org": self.rng.choice(["Stealth Co", "Unknown Ventures"]), "org_id": "", "location": self.rng.choice(["Remote", "Unknown"]), "location_id": "", "alias_ids": [], "connections": [], "work_history": [], } ) return PlatformViews(microblog_posts, forum_threads, profiles, alias_lookup=alias_owner) def generate_tasks(self, graph: CanonicalGraph, views: PlatformViews, count: int = 12) -> list[TaskInstance]: if self._dataset_mode() == "metaqa": metaqa_tasks = self._metaqa_tasks(graph=graph, count=max(1, count)) if metaqa_tasks: return metaqa_tasks tasks = self._seeded_tasks(graph) target_count = max(1, count, len(tasks)) llm_budget = min( max(0, self.config.seeding.llm_generated_task_budget), max(0, target_count - len(tasks)), ) if self.config.seeding.llm_generate_remaining_tasks and llm_budget > 0: tasks.extend(self._llm_generated_tasks(graph, count=llm_budget, start_idx=len(tasks))) if len(tasks) < target_count and self._template_fallback_allowed(): tasks.extend(self._template_tasks(graph, count=target_count - len(tasks), start_idx=len(tasks))) if not tasks: tasks.extend(self._template_tasks(graph, count=target_count, start_idx=0)) return tasks[:target_count]