from __future__ import annotations import json import math import re from collections import Counter from dataclasses import asdict, dataclass from osint_env.domain.models import CanonicalGraph, Edge, TaskInstance @dataclass(slots=True) class RewardModel: relation_idf: dict[str, float] max_relation_idf: float hub_penalty: dict[str, float] max_hub_penalty: float type_priors: dict[tuple[str, str, str], float] @dataclass(slots=True) class EdgeRewardBreakdown: total: float global_accuracy: float soft_shaping: float efficiency: float diversity: float relation_informativeness: float entity_informativeness: float connectivity_gain: float def to_dict(self) -> dict[str, float]: return asdict(self) @dataclass(slots=True) class AnswerRewardBreakdown: total: float format_reward: float correctness: float knowledge_carrier: float knowledge_indexing: float connectivity: float graph_f1: float efficiency: float compactness: float relation_informativeness: float entity_informativeness: float repetition_penalty: float def to_dict(self) -> dict[str, float]: return asdict(self) def _normalize_difficulty(value: str) -> str: token = str(value or "").strip().lower() if token in {"easy", "e"}: return "easy" if token in {"mid", "medium", "m"}: return "medium" if token in {"high", "hard", "h"}: return "hard" return "hard" def build_reward_model(graph: CanonicalGraph) -> RewardModel: relation_freq: Counter[str] = Counter(e.rel for e in graph.edges) total_edges = max(1, len(graph.edges)) relation_idf = { rel: math.log((1.0 + total_edges) / (1.0 + freq)) + 1.0 for rel, freq in relation_freq.items() } max_relation_idf = max(relation_idf.values()) if relation_idf else 1.0 degree: Counter[str] = Counter() for edge in graph.edges: degree[edge.src] += 1 degree[edge.dst] += 1 hub_penalty = {node_id: math.log(1.0 + deg) for node_id, deg in degree.items()} max_hub_penalty = max(hub_penalty.values()) if hub_penalty else 1.0 type_counts: Counter[tuple[str, str, str]] = Counter() rel_counts: Counter[str] = Counter() for edge in graph.edges: src = graph.nodes.get(edge.src) dst = graph.nodes.get(edge.dst) if src is None or dst is None: continue key = (str(src.node_type.value), edge.rel, str(dst.node_type.value)) type_counts[key] += 1 rel_counts[edge.rel] += 1 type_priors = { key: count / max(1, rel_counts[key[1]]) for key, count in type_counts.items() } return RewardModel( relation_idf=relation_idf, max_relation_idf=max_relation_idf, hub_penalty=hub_penalty, max_hub_penalty=max_hub_penalty, type_priors=type_priors, ) def edge_in_truth(edge: Edge, task: TaskInstance) -> bool: return any(e.src == edge.src and e.rel == edge.rel and e.dst == edge.dst for e in task.supporting_edges) def _cosine(a: Counter[str], b: Counter[str]) -> float: common = set(a) & set(b) num = sum(a[t] * b[t] for t in common) den = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values())) return (num / den) if den else 0.0 def _edge_signature(edge: Edge) -> Counter[str]: # Approximate path/edge embedding using relation and endpoint prefixes. src_prefix = edge.src.split("_", 1)[0] dst_prefix = edge.dst.split("_", 1)[0] return Counter({f"rel:{edge.rel}": 2, f"src:{src_prefix}": 1, f"dst:{dst_prefix}": 1}) def _soft_fact_score(edge: Edge, model: RewardModel, graph: CanonicalGraph) -> float: if any(e.src == edge.src and e.rel == edge.rel and e.dst == edge.dst for e in graph.edges): return 1.0 src = graph.nodes.get(edge.src) dst = graph.nodes.get(edge.dst) if src is None or dst is None: return 0.0 type_key = (str(src.node_type.value), edge.rel, str(dst.node_type.value)) prior = model.type_priors.get(type_key, 0.0) # A tiny domain heuristic: alias links are common and worth soft credit even without exact support edge. alias_bias = 0.2 if (edge.rel == "alias_of" and edge.src.startswith("alias_") and edge.dst.startswith("user_")) else 0.0 relation_exists = any(e.rel == edge.rel for e in graph.edges) relation_bonus = 0.1 if relation_exists else 0.0 return max(0.0, min(1.0, 0.1 + (0.65 * prior) + alias_bias + relation_bonus)) def _normalized_relation_info(rel: str, model: RewardModel) -> float: idf = model.relation_idf.get(rel, 1.0) return idf / max(1e-6, model.max_relation_idf) def _normalized_entity_info(src: str, dst: str, model: RewardModel) -> float: src_h = model.hub_penalty.get(src, 0.0) dst_h = model.hub_penalty.get(dst, 0.0) mean_hub = (src_h + dst_h) / 2.0 # UniRel-style preference for low-degree intermediates: lower hub penalty -> higher informativeness. return 1.0 - (mean_hub / max(1e-6, model.max_hub_penalty)) def _is_reachable_undirected(edges: list[Edge], src: str, dst: str) -> bool: if src == dst: return True adj: dict[str, set[str]] = {} for edge in edges: adj.setdefault(edge.src, set()).add(edge.dst) adj.setdefault(edge.dst, set()).add(edge.src) seen = {src} stack = [src] while stack: node = stack.pop() for nxt in adj.get(node, set()): if nxt == dst: return True if nxt not in seen: seen.add(nxt) stack.append(nxt) return False def _connectivity_gain(edge: Edge, existing_edges: list[Edge]) -> float: # Reward edges that bridge disconnected regions and penalize already-connected shortcuts. if edge.src == edge.dst: return -0.06 already_connected = _is_reachable_undirected(existing_edges, edge.src, edge.dst) if already_connected: return -0.03 return 0.10 def _sigmoid_temperature(value: float, temperature: float = 2.0) -> float: scaled = float(value) / max(1e-6, float(temperature)) if scaled >= 0: z = math.exp(-scaled) return 1.0 / (1.0 + z) z = math.exp(scaled) return z / (1.0 + z) def compute_edge_reward( edge: Edge, task: TaskInstance, existing_edges: list[Edge], step_count: int, model: RewardModel, graph: CanonicalGraph, difficulty: str = "hard", ) -> EdgeRewardBreakdown: in_truth = edge_in_truth(edge, task) difficulty_level = _normalize_difficulty(difficulty) # DeepPath-inspired global accuracy term. global_accuracy = 0.85 if in_truth else -0.55 # D18 reward shaping: R = Rb + (1 - Rb) * f, where f is a soft fact plausibility score. base_reward = 1.0 if in_truth else 0.0 shaped = base_reward + ((1.0 - base_reward) * _soft_fact_score(edge, model, graph)) soft_shaping = 0.30 * (shaped - 0.5) # DeepPath-inspired efficiency term: earlier useful edges are better. efficiency = 0.10 * (1.0 / max(1, step_count)) # DeepPath-inspired diversity term: discourage repeated edge patterns. if not existing_edges: diversity = 0.08 else: new_sig = _edge_signature(edge) avg_similarity = sum(_cosine(new_sig, _edge_signature(e)) for e in existing_edges) / len(existing_edges) novelty = 1.0 - avg_similarity diversity = 0.14 * (novelty - 0.5) # UniRel-style informativeness terms. relation_informativeness = 0.12 * (_normalized_relation_info(edge.rel, model) - 0.5) entity_informativeness = 0.12 * (_normalized_entity_info(edge.src, edge.dst, model) - 0.5) # Additional structural utility shaping for KG construction. connectivity_gain = _connectivity_gain(edge, existing_edges) if difficulty_level == "easy": global_accuracy = 0.75 if in_truth else -0.45 soft_shaping = 0.0 diversity = 0.0 relation_informativeness = 0.0 entity_informativeness = 0.0 connectivity_gain = 0.0 efficiency = 0.15 * (1.0 / max(1, step_count)) elif difficulty_level == "medium": diversity = 0.0 relation_informativeness = 0.0 entity_informativeness = 0.0 raw_total = ( global_accuracy + soft_shaping + efficiency + diversity + relation_informativeness + entity_informativeness + connectivity_gain ) total = _sigmoid_temperature(raw_total, temperature=2.0) return EdgeRewardBreakdown( total=total, global_accuracy=global_accuracy, soft_shaping=soft_shaping, efficiency=efficiency, diversity=diversity, relation_informativeness=relation_informativeness, entity_informativeness=entity_informativeness, connectivity_gain=connectivity_gain, ) def _connectivity_ratio(pred_edges: list[Edge], task: TaskInstance) -> float: nodes = {e.src for e in task.supporting_edges} | {e.dst for e in task.supporting_edges} if len(nodes) <= 1: return 1.0 adj: dict[str, set[str]] = {} for edge in pred_edges: adj.setdefault(edge.src, set()).add(edge.dst) adj.setdefault(edge.dst, set()).add(edge.src) start = next(iter(nodes)) seen = {start} stack = [start] while stack: cur = stack.pop() for nxt in adj.get(cur, set()): if nxt not in seen: seen.add(nxt) stack.append(nxt) return len(seen & nodes) / max(1, len(nodes)) def _knowledge_indexing_recall(task: TaskInstance, tool_outputs: list[dict[str, object]]) -> float: gold_terms = {task.answer.lower()} for edge in task.supporting_edges: gold_terms.add(edge.src.lower()) gold_terms.add(edge.dst.lower()) gold_terms.add(edge.rel.lower()) serialized = json.dumps(tool_outputs).lower() covered = sum(1 for term in gold_terms if term and term in serialized) return covered / max(1, len(gold_terms)) def _knowledge_carrier_reward(pred_edges: list[Edge], task: TaskInstance) -> float: pred = {(e.src, e.rel, e.dst) for e in pred_edges} truth = {(e.src, e.rel, e.dst) for e in task.supporting_edges} deducible = bool(truth & pred) return 0.4 if deducible else -0.2 def _extract_query_entities(question: str) -> set[str]: pattern = r"\b(?:alias|user|org|loc|post|thr|thread|event)_[a-zA-Z0-9_]+\b" return set(re.findall(pattern, question)) def _max_connected_seed_count(pred_edges: list[Edge], seeds: set[str]) -> int: if not seeds: return 0 adj: dict[str, set[str]] = {} for edge in pred_edges: adj.setdefault(edge.src, set()).add(edge.dst) adj.setdefault(edge.dst, set()).add(edge.src) best = 1 for seed in seeds: seen = {seed} stack = [seed] while stack: cur = stack.pop() for nxt in adj.get(cur, set()): if nxt not in seen: seen.add(nxt) stack.append(nxt) connected_seed_count = len(seeds & seen) best = max(best, connected_seed_count) return best def _unirel_connectivity_score(pred_edges: list[Edge], seeds: set[str]) -> float: # UniRel-style discrete connectivity range projected to [-1, 1] for stable weighting. n = len(seeds) if n <= 1: return 0.0 connected = _max_connected_seed_count(pred_edges, seeds) raw = -math.floor(n / 2) + max(0, connected - 1) lo = -math.floor(n / 2) hi = math.ceil(n / 2) - 1 if hi <= lo: return 0.0 return ((raw - lo) / (hi - lo)) * 2.0 - 1.0 def _subgraph_relation_informativeness(pred_edges: list[Edge], model: RewardModel | None) -> float: if not pred_edges or model is None: return 0.0 avg = sum(_normalized_relation_info(edge.rel, model) for edge in pred_edges) / len(pred_edges) return avg - 0.5 def _subgraph_entity_informativeness(pred_edges: list[Edge], model: RewardModel | None) -> float: if not pred_edges or model is None: return 0.0 avg = sum(_normalized_entity_info(edge.src, edge.dst, model) for edge in pred_edges) / len(pred_edges) return avg - 0.5 def _relation_repetition_ratio(pred_edges: list[Edge]) -> float: if len(pred_edges) <= 1: return 0.0 rels = [edge.rel for edge in pred_edges] unique = len(set(rels)) return 1.0 - (unique / len(rels)) def _deducible_answer(proposed_answer: str, task: TaskInstance, pred_edges: list[Edge]) -> bool: if proposed_answer != task.answer: return False truth = {(edge.src, edge.rel, edge.dst) for edge in task.supporting_edges} pred = {(edge.src, edge.rel, edge.dst) for edge in pred_edges} if truth & pred: return True seeds = _extract_query_entities(task.question) if not seeds: return False for seed in seeds: if _is_reachable_undirected(pred_edges, seed, proposed_answer): return True return False def compute_answer_reward( proposed_answer: str, task: TaskInstance, pred_edges: list[Edge], tool_outputs: list[dict[str, object]], step_count: int, model: RewardModel | None = None, difficulty: str = "hard", ) -> AnswerRewardBreakdown: difficulty_level = _normalize_difficulty(difficulty) format_reward = 0.15 if proposed_answer else -0.55 correctness = 1.15 if proposed_answer == task.answer else -1.0 # AutoGraph-R1 style task utility decomposition. knowledge_carrier = 0.50 if _deducible_answer(proposed_answer, task, pred_edges) else -0.25 knowledge_indexing = 0.45 * _knowledge_indexing_recall(task, tool_outputs) # UniRel-style connectivity over seed entities. seed_entities = _extract_query_entities(task.question) seed_entities.add(task.answer) connectivity = 0.30 * _unirel_connectivity_score(pred_edges, seed_entities) graph_f1 = 0.55 * compute_graph_f1(pred_edges, task.supporting_edges) efficiency = 0.12 * (1.0 / max(1, step_count)) extra_edges = max(0, len(pred_edges) - len(task.supporting_edges)) compactness = -0.05 * extra_edges relation_informativeness = 0.12 * _subgraph_relation_informativeness(pred_edges, model) entity_informativeness = 0.12 * _subgraph_entity_informativeness(pred_edges, model) # AutoGraph-R1 repetition control variant used in larger models. repetition_penalty = -0.10 * _relation_repetition_ratio(pred_edges) if difficulty_level == "easy": knowledge_carrier = 0.0 knowledge_indexing = 0.25 * _knowledge_indexing_recall(task, tool_outputs) connectivity = 0.0 graph_f1 = 0.0 efficiency = 0.18 * (1.0 / max(1, step_count)) compactness = 0.0 relation_informativeness = 0.0 entity_informativeness = 0.0 repetition_penalty = 0.0 elif difficulty_level == "medium": connectivity = 0.18 * _unirel_connectivity_score(pred_edges, seed_entities) graph_f1 = 0.35 * compute_graph_f1(pred_edges, task.supporting_edges) compactness = -0.04 * extra_edges relation_informativeness = 0.0 entity_informativeness = 0.0 repetition_penalty = 0.0 raw_total = ( format_reward + correctness + knowledge_carrier + knowledge_indexing + connectivity + graph_f1 + efficiency + compactness + relation_informativeness + entity_informativeness + repetition_penalty ) total = _sigmoid_temperature(raw_total, temperature=2.0) return AnswerRewardBreakdown( total=total, format_reward=format_reward, correctness=correctness, knowledge_carrier=knowledge_carrier, knowledge_indexing=knowledge_indexing, connectivity=connectivity, graph_f1=graph_f1, efficiency=efficiency, compactness=compactness, relation_informativeness=relation_informativeness, entity_informativeness=entity_informativeness, repetition_penalty=repetition_penalty, ) def compute_graph_f1(pred_edges: list[Edge], truth_edges: list[Edge]) -> float: pred = {(e.src, e.rel, e.dst) for e in pred_edges} truth = {(e.src, e.rel, e.dst) for e in truth_edges} if not pred and not truth: return 1.0 if not pred or not truth: return 0.0 tp = len(pred & truth) p = tp / len(pred) if pred else 0.0 r = tp / len(truth) if truth else 0.0 return (2 * p * r / (p + r)) if (p + r) else 0.0