# feat(training): improve self-play progress visibility and reward diagnostics
4aca4f5 | from __future__ import annotations | |
| import json | |
| import re | |
| from collections import Counter | |
| from dataclasses import dataclass, field | |
| from functools import lru_cache | |
| from typing import Any | |
| from osint_env.data.generator import ( | |
| build_swarm_v2_tool_trace, | |
| emit_swarm_v2_question, | |
| enumerate_swarm_v2_neighbors, | |
| select_swarm_v2_answer, | |
| trace_swarm_v2_path, | |
| ) | |
| from osint_env.domain.models import CanonicalGraph, Edge, TaskInstance | |
| from osint_env.env.reward import build_reward_model, compute_answer_reward | |
| from osint_env.env.spawn_reward_hooks import parl_reward_breakdown | |
| from osint_env.training.config import ( | |
| GeneratorRewardWeights, | |
| SwarmV2SharedContextConfig, | |
| SwarmV2ValidationConfig, | |
| ) | |
| def _iter_text_fragments(value: Any) -> list[str]: | |
| if value is None: | |
| return [] | |
| if isinstance(value, str): | |
| token = value.strip() | |
| return [token] if token else [] | |
| if isinstance(value, list): | |
| out: list[str] = [] | |
| for item in value: | |
| out.extend(_iter_text_fragments(item)) | |
| return out | |
| if isinstance(value, dict): | |
| out: list[str] = [] | |
| for key in ("content", "text", "output_text", "message", "choices"): | |
| if key in value: | |
| out.extend(_iter_text_fragments(value.get(key))) | |
| return out | |
| return [str(value)] | |
def decode_completion_text(completion: Any) -> str:
    """Join every non-empty text fragment found in *completion* with newlines."""
    fragments = _iter_text_fragments(completion)
    return "\n".join(fragment for fragment in fragments if fragment)
| def _extract_json_candidates(text: str) -> list[Any]: | |
| candidate = str(text or "").strip() | |
| if not candidate: | |
| return [] | |
| out: list[Any] = [] | |
| try: | |
| parsed = json.loads(candidate) | |
| except json.JSONDecodeError: | |
| parsed = None | |
| if isinstance(parsed, dict): | |
| out.append(parsed) | |
| for start_idx, ch in enumerate(candidate): | |
| if ch != "{": | |
| continue | |
| depth = 0 | |
| in_string = False | |
| escape = False | |
| for end_idx in range(start_idx, len(candidate)): | |
| current = candidate[end_idx] | |
| if escape: | |
| escape = False | |
| continue | |
| if current == "\\": | |
| escape = True | |
| continue | |
| if current == '"': | |
| in_string = not in_string | |
| continue | |
| if in_string: | |
| continue | |
| if current == "{": | |
| depth += 1 | |
| elif current == "}": | |
| depth -= 1 | |
| if depth < 0: | |
| break | |
| if depth != 0: | |
| continue | |
| snippet = candidate[start_idx : end_idx + 1] | |
| try: | |
| parsed = json.loads(snippet) | |
| except json.JSONDecodeError: | |
| break | |
| if isinstance(parsed, dict) and parsed not in out: | |
| out.append(parsed) | |
| break | |
| return out | |
def _extract_json_blob(text: str, preferred_keys: tuple[str, ...] = ()) -> Any:
    """Return the JSON object in *text* that best matches *preferred_keys*.

    With no preferred keys the first candidate wins. Otherwise the
    candidate containing the most preferred keys is chosen; ties keep the
    earliest candidate. Returns ``None`` when nothing parseable exists.
    """
    candidates = _extract_json_candidates(text)
    if not candidates:
        return None
    if not preferred_keys:
        return candidates[0]
    wanted = set(preferred_keys)
    chosen = candidates[0]
    chosen_score = -1
    for candidate in candidates:
        if not isinstance(candidate, dict):
            continue
        hits = sum(1 for key in wanted if key in candidate)
        if hits > chosen_score:
            chosen = candidate
            chosen_score = hits
    return chosen
def normalize_answer(text: str) -> str:
    """Canonicalize an answer string for comparison.

    Trims surrounding whitespace and quote characters, collapses internal
    whitespace runs to single spaces, and drops trailing periods/spaces.
    """
    normalized = str(text or "").strip().strip('"').strip("'")
    normalized = re.sub(r"\s+", " ", normalized)
    return normalized.rstrip(".\n ")
def extract_answer_from_completion(completion_text: str) -> str:
    """Pull the final answer out of a raw model completion.

    Preference order: a JSON blob carrying an "answer" key, then an
    ``answer: ...`` / ``answer = ...`` line, then the last non-blank line.
    Returns "" when the completion is empty.
    """
    blob = _extract_json_blob(completion_text, preferred_keys=("answer",))
    if isinstance(blob, dict):
        candidate = str(blob.get("answer", "")).strip()
        if candidate:
            return normalize_answer(candidate)
    line_match = re.search(r"answer\s*[:=]\s*(.+)", completion_text, flags=re.IGNORECASE)
    if line_match:
        return normalize_answer(line_match.group(1))
    non_blank = [stripped for stripped in (line.strip() for line in completion_text.splitlines()) if stripped]
    if not non_blank:
        return ""
    return normalize_answer(non_blank[-1])
@dataclass
class SwarmReplayToolCall:
    """One replayable tool invocation parsed from a generator completion.

    Bug fix: the ``@dataclass`` decorator was missing. Without it the
    ``field(default_factory=...)`` defaults are plain ``Field`` objects and
    the keyword construction used by ``_parse_tool_trace`` raises.
    """

    # Name of the tool that was invoked, e.g. "trace_path".
    tool_name: str
    # Keyword arguments the tool was called with.
    args: dict[str, Any] = field(default_factory=dict)
    # Structured output the tool reported.
    output: dict[str, Any] = field(default_factory=dict)
@dataclass
class SwarmOrchestratorTelemetry:
    """Orchestration counters reported by the generating swarm.

    Bug fix: the ``@dataclass`` decorator was missing, so keyword
    construction in ``_parse_orchestrator`` would raise and the defaults
    would be shared class attributes instead of per-instance values.
    """

    # How many subagents were spawned.
    spawn_count: int = 0
    # How many spawned subtasks ran to completion.
    finished_subtasks: int = 0
    # Number of critical-path steps; the parser floors this at 1.
    critical_steps: int = 1
    # Maximum fan-out observed.
    breadth: int = 0
    # Maximum nesting depth observed.
    depth: int = 0
@dataclass
class ReplayValidationResult:
    """Outcome of hard-gated replay validation for one generated candidate.

    Bug fix: the ``@dataclass`` decorator was missing; the validator
    constructs this type with keyword arguments and relies on the
    ``field(default_factory=list)`` default, neither of which work without it.
    """

    # True only when no failure reason was collected.
    is_valid: bool
    # Sorted, de-duplicated failure reason tags.
    reasons: list[str] = field(default_factory=list)
    # Highest similarity to any previously accepted question.
    duplicate_similarity: float = 0.0
    # Distinct node count touched by the supporting edges.
    context_nodes: int = 0
    # Supporting edge count.
    context_edges: int = 0
    # Number of graph paths matching the replayed relation chain (must be 1).
    unique_path_count: int = 0
    # Question/answer re-derived deterministically from the replay.
    replayed_question: str = ""
    replayed_answer: str = ""
    # Edges of the deterministically replayed path.
    replayed_edges: list[Edge] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the result to plain JSON-compatible types for logging."""
        return {
            "is_valid": self.is_valid,
            "reasons": list(self.reasons),
            "duplicate_similarity": float(self.duplicate_similarity),
            "context_nodes": int(self.context_nodes),
            "context_edges": int(self.context_edges),
            "unique_path_count": int(self.unique_path_count),
            "replayed_question": self.replayed_question,
            "replayed_answer": self.replayed_answer,
            "replayed_edges": [
                {
                    "src": edge.src,
                    "rel": edge.rel,
                    "dst": edge.dst,
                    "confidence": float(edge.confidence),
                }
                for edge in self.replayed_edges
            ],
        }
| def _parse_edge_rows(value: Any, max_support_edges: int) -> list[Edge]: | |
| if not isinstance(value, list): | |
| return [] | |
| out: list[Edge] = [] | |
| for row in value[:max_support_edges]: | |
| if not isinstance(row, dict): | |
| continue | |
| src = str(row.get("src", "")).strip() | |
| rel = str(row.get("rel", "")).strip() | |
| dst = str(row.get("dst", "")).strip() | |
| if not src or not rel or not dst: | |
| continue | |
| try: | |
| confidence = float(row.get("confidence", 1.0)) | |
| except (TypeError, ValueError): | |
| confidence = 1.0 | |
| out.append(Edge(src=src, rel=rel, dst=dst, confidence=confidence)) | |
| return out | |
def _parse_tool_trace(value: Any) -> list[SwarmReplayToolCall]:
    """Convert raw trace rows into SwarmReplayToolCall records.

    Accepts either "tool_name" or legacy "tool" for the name, and either
    "output" or legacy "result" for the payload; rows without a tool name
    (or that are not dicts) are dropped. Non-dict args/output become {}.
    """
    if not isinstance(value, list):
        return []
    calls: list[SwarmReplayToolCall] = []
    for row in value:
        if not isinstance(row, dict):
            continue
        name = str(row.get("tool_name", row.get("tool", ""))).strip()
        if not name:
            continue
        raw_args = row.get("args", {})
        raw_output = row.get("output", row.get("result", {}))
        calls.append(
            SwarmReplayToolCall(
                tool_name=name,
                args=dict(raw_args) if isinstance(raw_args, dict) else {},
                output=dict(raw_output) if isinstance(raw_output, dict) else {},
            )
        )
    return calls
| def _parse_subagent_outputs(value: Any) -> list[str]: | |
| if not isinstance(value, list): | |
| return [] | |
| out: list[str] = [] | |
| for row in value: | |
| if isinstance(row, str): | |
| token = row.strip() | |
| elif isinstance(row, dict): | |
| token = str(row.get("content", row.get("summary", ""))).strip() | |
| else: | |
| token = str(row).strip() | |
| if token: | |
| out.append(token) | |
| return out | |
| def _coerce_int(value: Any, default: int) -> int: | |
| """Best-effort int coercion that NEVER raises. | |
| Models routinely emit garbage like ``"none"``, ``"N/A"``, ``true``, | |
| ``"2 agents"`` for fields the schema requires to be integers. The | |
| reward function runs inside the GRPO training loop, so a single | |
| ``ValueError`` here crashes the entire training job. Be defensive. | |
| """ | |
| if value is None: | |
| return default | |
| if isinstance(value, bool): | |
| return int(value) | |
| if isinstance(value, int): | |
| return value | |
| if isinstance(value, float): | |
| if value != value or value in (float("inf"), float("-inf")): | |
| return default | |
| return int(value) | |
| if isinstance(value, str): | |
| token = value.strip() | |
| if not token: | |
| return default | |
| try: | |
| return int(token) | |
| except ValueError: | |
| try: | |
| return int(float(token)) | |
| except ValueError: | |
| match = re.search(r"[-+]?\d+(?:\.\d+)?", token) | |
| if match: | |
| try: | |
| return int(float(match.group(0))) | |
| except ValueError: | |
| return default | |
| return default | |
| return default | |
def _parse_orchestrator(value: Any) -> SwarmOrchestratorTelemetry:
    """Parse orchestrator telemetry, clamping every counter to a sane floor.

    Non-dict input yields default telemetry. critical_steps is floored at 1
    (it is used as a divisor downstream); all other counters at 0.
    """
    if not isinstance(value, dict):
        return SwarmOrchestratorTelemetry()

    def clamped(key: str, default: int, floor: int) -> int:
        return max(floor, _coerce_int(value.get(key), default))

    return SwarmOrchestratorTelemetry(
        spawn_count=clamped("spawn_count", 0, 0),
        finished_subtasks=clamped("finished_subtasks", 0, 0),
        critical_steps=clamped("critical_steps", 1, 1),
        breadth=clamped("breadth", 0, 0),
        depth=clamped("depth", 0, 0),
    )
@dataclass
class GeneratedTaskCandidate:
    """Structured view of one generator completion, before replay validation.

    Bug fix: the ``@dataclass`` decorator was missing;
    ``parse_generated_task_completion`` constructs this type with keyword
    arguments and the ``field(default_factory=...)`` defaults require it.
    """

    # Parsed question/answer text (already normalized for the answer).
    question: str
    answer: str
    # Declared supporting path edges (capped by max_support_edges).
    supporting_edges: list[Edge]
    # Declared task type; defaults to "adversarial_trace" upstream.
    task_type: str
    # True only when both a question and an answer were recovered.
    is_valid: bool
    # Optional replayable tool trace and per-subagent text outputs.
    tool_trace: list[SwarmReplayToolCall] = field(default_factory=list)
    subagent_outputs: list[str] = field(default_factory=list)
    # Optional canonical-graph declaration shipped with the completion.
    canonical_edges: list[Edge] = field(default_factory=list)
    canonical_nodes: list[str] = field(default_factory=list)
    # Orchestration counters (spawns, finishes, breadth/depth).
    orchestrator: SwarmOrchestratorTelemetry = field(default_factory=SwarmOrchestratorTelemetry)
    # Free-form validation metadata echoed from the completion.
    validation: dict[str, Any] = field(default_factory=dict)
def parse_generated_task_completion(completion_text: str, max_support_edges: int = 8) -> GeneratedTaskCandidate:
    """Parse a raw generator completion into a GeneratedTaskCandidate.

    Prefers a JSON blob carrying the generation schema; falls back to
    regex extraction of ``question:`` / ``answer:`` lines when fields are
    missing. ``is_valid`` only asserts that both a question and an answer
    were recovered — replay validation happens elsewhere.
    """
    blob = _extract_json_blob(
        completion_text,
        preferred_keys=("question", "answer", "supporting_edges", "tool_trace", "canonical_graph"),
    )
    question = ""
    answer = ""
    task_type = "adversarial_trace"
    supporting_edges: list[Edge] = []
    tool_trace: list[SwarmReplayToolCall] = []
    subagent_outputs: list[str] = []
    canonical_edges: list[Edge] = []
    canonical_nodes: list[str] = []
    orchestrator = SwarmOrchestratorTelemetry()
    validation: dict[str, Any] = {}
    if isinstance(blob, dict):
        question = str(blob.get("question", "")).strip()
        answer = normalize_answer(str(blob.get("answer", "")).strip())
        task_type = str(blob.get("task_type", "adversarial_trace")).strip() or "adversarial_trace"
        supporting_edges = _parse_edge_rows(blob.get("supporting_edges", []), max_support_edges=max_support_edges)
        tool_trace = _parse_tool_trace(blob.get("tool_trace", []))
        subagent_outputs = _parse_subagent_outputs(blob.get("subagent_outputs", []))
        orchestrator = _parse_orchestrator(blob.get("orchestrator"))
        raw_validation = blob.get("validation")
        validation = dict(raw_validation) if isinstance(raw_validation, dict) else {}
        graph_blob = blob.get("canonical_graph", {})
        if isinstance(graph_blob, dict):
            canonical_nodes = [
                str(node_id).strip()
                for node_id in graph_blob.get("nodes", [])
                if str(node_id).strip()
            ]
            # Canonical graphs may be larger than a single support path.
            canonical_edges = _parse_edge_rows(
                graph_blob.get("edges", []),
                max_support_edges=max(1, max_support_edges * 4),
            )
    if not question:
        fallback = re.search(r"question\s*[:=]\s*(.+)", completion_text, flags=re.IGNORECASE)
        if fallback:
            question = fallback.group(1).strip()
    if not answer:
        answer = extract_answer_from_completion(completion_text)
    return GeneratedTaskCandidate(
        question=question,
        answer=answer,
        supporting_edges=supporting_edges,
        task_type=task_type,
        is_valid=bool(question and answer),
        tool_trace=tool_trace,
        subagent_outputs=subagent_outputs,
        canonical_edges=canonical_edges,
        canonical_nodes=canonical_nodes,
        orchestrator=orchestrator,
        validation=validation,
    )
| def _token_set(text: str) -> set[str]: | |
| return set(re.findall(r"[a-zA-Z0-9_]+", str(text).lower())) | |
def _jaccard_similarity(left: str, right: str) -> float:
    """Jaccard overlap of the token sets of two strings (1.0 when both empty)."""
    left_tokens = _token_set(left)
    right_tokens = _token_set(right)
    if not left_tokens and not right_tokens:
        return 1.0
    if not left_tokens or not right_tokens:
        return 0.0
    union = left_tokens | right_tokens
    return len(left_tokens & right_tokens) / max(1, len(union))
# Canonical swarm_v2 multi-hop question template. Named groups capture the
# start entity, the "->"-joined relation path, and the declared hop count;
# used by _swarm_v2_question_signature for structural duplicate detection.
_SWARM_V2_QUESTION_RE = re.compile(
    r"^If you start at (?P<start>.+?) and follow the relation path "
    r"(?P<relations>.+?), which entity do you reach after (?P<hops>\d+) hops\?$"
)
def _swarm_v2_question_signature(question: str) -> tuple[str, tuple[str, ...], int] | None:
    """Decompose a templated swarm_v2 question into (start, relations, hops).

    Returns None when the question does not match the canonical template,
    or when either the start entity or the relation path comes out empty.
    """
    matched = _SWARM_V2_QUESTION_RE.match(str(question or "").strip())
    if matched is None:
        return None
    start_entity = normalize_answer(matched.group("start")).lower()
    relation_path = tuple(
        part.strip().lower()
        for part in matched.group("relations").split("->")
        if part.strip()
    )
    if not start_entity or not relation_path:
        return None
    return start_entity, relation_path, int(matched.group("hops"))
def _swarm_v2_question_similarity(left: str, right: str) -> float:
    """Similarity in [0, 1] between two swarm_v2 questions.

    When both questions parse against the canonical template they are
    compared structurally — start entity (weight 0.55), relation path
    (0.35, with token-Jaccard partial credit), hop count (0.10). Otherwise
    plain token Jaccard similarity is used.
    """
    sig_left = _swarm_v2_question_signature(left)
    sig_right = _swarm_v2_question_signature(right)
    if sig_left is None or sig_right is None:
        return _jaccard_similarity(left, right)
    start_left, rels_left, hops_left = sig_left
    start_right, rels_right, hops_right = sig_right
    start_score = float(start_left == start_right)
    if rels_left == rels_right:
        path_score = 1.0
    else:
        path_score = _jaccard_similarity(" ".join(rels_left), " ".join(rels_right))
    hop_score = float(hops_left == hops_right)
    return 0.55 * start_score + 0.35 * path_score + 0.10 * hop_score
| def _distinct_ngram_ratio(texts: list[str], n: int = 2) -> float: | |
| tokens: list[str] = [] | |
| for text in texts: | |
| tokens.extend(re.findall(r"[a-zA-Z0-9_]+", text.lower())) | |
| if len(tokens) < n: | |
| return 0.0 if texts else 1.0 | |
| ngrams = [tuple(tokens[idx : idx + n]) for idx in range(0, len(tokens) - n + 1)] | |
| if not ngrams: | |
| return 0.0 | |
| return len(set(ngrams)) / max(1, len(ngrams)) | |
class SwarmV2ReplayValidator:
    """Hard-gated replay validator for deterministic swarm_v2 generation.

    A candidate is accepted only when its declared supporting path replays
    deterministically against the canonical graph, derives a unique answer,
    does not leak the answer into the question, is not a near-duplicate of
    a previously seen question, and stays inside the context budgets.
    """

    def __init__(
        self,
        graph: CanonicalGraph,
        validation: SwarmV2ValidationConfig,
        shared_context: SwarmV2SharedContextConfig,
        seen_questions: list[str] | None = None,
    ):
        self.graph = graph
        self.validation = validation
        self.shared_context = shared_context
        # Defensive copy: callers may keep mutating their own list.
        self.seen_questions = list(seen_questions or [])
        # O(1) membership structures for node/edge existence checks.
        self.graph_nodes = set(graph.nodes.keys())
        self.graph_edges = {(edge.src, edge.rel, edge.dst) for edge in graph.edges}
        # Outgoing-edge adjacency, keyed by source node, for path counting.
        self.outgoing: dict[str, list[Edge]] = {}
        for edge in graph.edges:
            self.outgoing.setdefault(edge.src, []).append(edge)

    def remember(self, question: str) -> None:
        """Record an accepted question for future duplicate detection.

        The history is capped (trimmed from 4096 to the most recent 2048)
        so long training runs do not grow memory unboundedly.
        """
        token = str(question).strip()
        if not token:
            return
        self.seen_questions.append(token)
        if len(self.seen_questions) > 4096:
            self.seen_questions = self.seen_questions[-2048:]

    def _count_matching_paths(self, start: str, relations: list[str], answer: str, limit: int = 4) -> int:
        """Count simple paths from *start* along *relations* ending at *answer*.

        Iterative DFS over the outgoing adjacency; a node may not repeat
        within a path. Stops early once *limit* matches are found — the
        caller only needs to know whether the count is exactly 1.
        """
        if not start or not relations:
            return 0
        count = 0
        # Stack entries: (current node, next relation index, nodes on path).
        stack: list[tuple[str, int, tuple[str, ...]]] = [(start, 0, (start,))]
        while stack:
            node_id, rel_idx, seen_nodes = stack.pop()
            if rel_idx >= len(relations):
                # Whole relation chain consumed; check the endpoint.
                if node_id == answer:
                    count += 1
                    if count >= limit:
                        return count
                continue
            relation = relations[rel_idx]
            for edge in self.outgoing.get(node_id, []):
                if edge.rel != relation:
                    continue
                if edge.dst in seen_nodes:
                    # Simple paths only — no revisiting.
                    continue
                stack.append((edge.dst, rel_idx + 1, seen_nodes + (edge.dst,)))
        return count

    def _replay_tool_trace(self, candidate: GeneratedTaskCandidate) -> tuple[list[str], list[Edge], str, str]:
        """Replay the candidate's tool trace against the canonical graph.

        Returns ``(failure_reasons, replayed_edges, replayed_answer,
        replayed_question)``. When the candidate shipped no trace, one is
        synthesized from its supporting edges so legacy completions can
        still be validated. Declared answers/questions that disagree with
        the deterministic replay are flagged as non-replayable.
        """
        reasons: list[str] = []
        replayed_edges: list[Edge] = []
        replayed_answer = ""
        replayed_question = ""
        declared_answer = ""
        declared_question = ""
        tool_trace = list(candidate.tool_trace)
        # Fallback path source when no trace_path call provides one.
        trace_path_source: Any = candidate.supporting_edges
        if not tool_trace and candidate.supporting_edges:
            synthesized_trace = build_swarm_v2_tool_trace(self.graph, candidate.supporting_edges)
            tool_trace = _parse_tool_trace(synthesized_trace)
        for call in tool_trace:
            if call.tool_name == "enumerate_neighbors":
                node_id = str(call.args.get("node_id", "")).strip()
                expected_edge = call.args.get("expected_edge", {})
                if not node_id:
                    reasons.append("non_replayable_tool_calls")
                    continue
                neighbors = enumerate_swarm_v2_neighbors(self.graph, node_id)
                if not neighbors:
                    reasons.append("non_replayable_tool_calls")
                if isinstance(expected_edge, dict):
                    expected_key = (
                        str(expected_edge.get("src", "")).strip(),
                        str(expected_edge.get("rel", "")).strip(),
                        str(expected_edge.get("dst", "")).strip(),
                    )
                    # The edge the trace claims to have observed must really
                    # be among this node's neighbors.
                    if expected_key not in {(edge.src, edge.rel, edge.dst) for edge in neighbors}:
                        reasons.append("non_replayable_tool_calls")
            elif call.tool_name == "trace_path":
                trace_path_source = call.args.get("path", trace_path_source)
                replayed_edges = trace_swarm_v2_path(self.graph, trace_path_source)
                if not replayed_edges:
                    reasons.append("non_replayable_tool_calls")
            elif call.tool_name == "select_answer":
                declared_answer = normalize_answer(str(call.output.get("answer", "")).strip())
            elif call.tool_name == "emit_question":
                declared_question = str(call.output.get("question", "")).strip()
            else:
                # Unknown tool names are not replayable by definition.
                reasons.append("non_replayable_tool_calls")
        # Fallback chain: trace-derived path, then the declared support edges.
        if not replayed_edges:
            replayed_edges = trace_swarm_v2_path(self.graph, trace_path_source)
        if not replayed_edges and candidate.supporting_edges:
            replayed_edges = trace_swarm_v2_path(self.graph, candidate.supporting_edges)
        if not replayed_edges:
            reasons.append("non_replayable_tool_calls")
            return reasons, replayed_edges, replayed_answer, replayed_question
        replayed_answer = select_swarm_v2_answer(replayed_edges)
        replayed_question = emit_swarm_v2_question(replayed_edges)
        # Declared answer/question must match the deterministic replay.
        if declared_answer and declared_answer != normalize_answer(replayed_answer):
            reasons.append("non_replayable_tool_calls")
        if declared_question and declared_question != replayed_question:
            reasons.append("non_replayable_tool_calls")
        return reasons, replayed_edges, replayed_answer, replayed_question

    def validate(self, candidate: GeneratedTaskCandidate) -> ReplayValidationResult:
        """Run every hard gate and return the aggregated validation result.

        Reasons are de-duplicated and sorted; the candidate is valid only
        when no reason at all was collected.
        """
        reasons: list[str] = []
        if not candidate.question or not candidate.answer:
            reasons.append("missing_question_or_answer")
        if not candidate.supporting_edges:
            reasons.append("malformed_support_edges")
        if len(candidate.supporting_edges) > self.validation.max_support_edges:
            reasons.append("context_or_support_budget_overflow")
        edge_keys = [(edge.src, edge.rel, edge.dst) for edge in candidate.supporting_edges]
        if len(set(edge_keys)) != len(edge_keys):
            # Repeated edges in the declared path.
            reasons.append("malformed_support_edges")
        for edge in candidate.supporting_edges:
            if edge.src not in self.graph_nodes or edge.dst not in self.graph_nodes:
                reasons.append("unseen_nodes_or_edges")
                break
            if (edge.src, edge.rel, edge.dst) not in self.graph_edges:
                reasons.append("unseen_nodes_or_edges")
                break
        replay_reasons, replayed_edges, replayed_answer, replayed_question = self._replay_tool_trace(candidate)
        reasons.extend(replay_reasons)
        if replayed_edges:
            expected_keys = [(edge.src, edge.rel, edge.dst) for edge in replayed_edges]
            # Declared supporting edges must equal the replayed path exactly
            # (same edges, same order).
            if expected_keys != edge_keys:
                reasons.append("non_replayable_tool_calls")
            relations = [edge.rel for edge in replayed_edges]
            unique_path_count = self._count_matching_paths(
                start=replayed_edges[0].src,
                relations=relations,
                answer=replayed_answer or candidate.answer,
            )
        else:
            unique_path_count = 0
        if unique_path_count != 1:
            # The answer must be derivable by exactly one path.
            reasons.append("non_unique_derivation_path")
        if replayed_answer and normalize_answer(replayed_answer) != normalize_answer(candidate.answer):
            reasons.append("non_replayable_tool_calls")
        if replayed_question and replayed_question != candidate.question:
            reasons.append("non_replayable_tool_calls")
        if candidate.answer and normalize_answer(candidate.answer).lower() in candidate.question.lower():
            reasons.append("answer_leakage")
        duplicate_similarity = 0.0
        if candidate.question and self.seen_questions:
            duplicate_similarity = max(
                _swarm_v2_question_similarity(candidate.question, seen_question)
                for seen_question in self.seen_questions
            )
        if duplicate_similarity >= self.validation.duplicate_similarity_threshold:
            reasons.append("duplicate_or_near_duplicate")
        context_nodes = len({edge.src for edge in candidate.supporting_edges} | {edge.dst for edge in candidate.supporting_edges})
        context_edges = len(candidate.supporting_edges)
        # Effective budget is the tighter of the validation and
        # shared-context caps.
        max_context_nodes = min(self.validation.max_context_nodes, self.shared_context.max_nodes)
        max_context_edges = min(self.validation.max_context_edges, self.shared_context.max_edges)
        if context_nodes > max_context_nodes or context_edges > max_context_edges:
            reasons.append("context_or_support_budget_overflow")
        if len(candidate.supporting_edges) > self.validation.max_path_hops:
            reasons.append("context_or_support_budget_overflow")
        return ReplayValidationResult(
            is_valid=not reasons,
            reasons=sorted(set(reasons)),
            duplicate_similarity=duplicate_similarity,
            context_nodes=context_nodes,
            context_edges=context_edges,
            unique_path_count=unique_path_count,
            replayed_question=replayed_question,
            replayed_answer=replayed_answer,
            replayed_edges=replayed_edges,
        )
class AnswererJudge:
    """Lightweight frozen answerer used to score adversarial hardness."""

    def __init__(self, model_name_or_path: str, max_new_tokens: int = 48):
        # HF hub id or local checkpoint path of the frozen answerer.
        self.model_name_or_path = model_name_or_path
        # Generation budget for the answer span.
        self.max_new_tokens = max_new_tokens
        # Lazily loaded on first answer() call — see _ensure_loaded.
        self._model = None
        self._tokenizer = None

    def _ensure_loaded(self) -> None:
        """Load tokenizer/model on first use; no-op on subsequent calls."""
        if self._model is not None and self._tokenizer is not None:
            return
        # Imported lazily so merely constructing the judge does not require
        # torch/transformers to be importable.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        model_kwargs: dict[str, Any] = {}
        if torch.cuda.is_available():
            # Shard across available GPUs and use bf16 to reduce memory.
            model_kwargs["device_map"] = "auto"
            model_kwargs["torch_dtype"] = torch.bfloat16
        model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **model_kwargs)
        model.eval()
        self._model = model
        self._tokenizer = tokenizer

    def answer(self, question: str) -> str:
        """Greedy-decode a short answer string for *question*."""
        self._ensure_loaded()
        assert self._model is not None
        assert self._tokenizer is not None
        import torch
        prompt = (
            "You are an OSINT answering model. "
            "Answer with only the final entity string.\n"
            f"Question: {question}\n"
            "Answer:"
        )
        tokenizer = self._tokenizer
        model = self._model
        encoded = tokenizer(prompt, return_tensors="pt")
        # Move inputs to wherever the (possibly sharded) model lives.
        device = next(model.parameters()).device
        encoded = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            output = model.generate(
                **encoded,
                max_new_tokens=max(1, int(self.max_new_tokens)),
                do_sample=False,  # deterministic greedy decoding
                temperature=0.0,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Drop the prompt tokens; decode only the newly generated span.
        generated = output[0][encoded["input_ids"].shape[1] :]
        completion = tokenizer.decode(generated, skip_special_tokens=True)
        return normalize_answer(extract_answer_from_completion(completion))
class GeneratorRewardFunction:
    """Reward for the graph/question generation swarm in adversarial self-play."""

    def __init__(
        self,
        graph: CanonicalGraph,
        answerer_judge: AnswererJudge,
        weights: GeneratorRewardWeights,
        max_support_edges: int = 8,
        pipeline_mode: str = "legacy",
        swarm_v2_validation: SwarmV2ValidationConfig | None = None,
        swarm_v2_shared_context: SwarmV2SharedContextConfig | None = None,
        parl_max_parallel_hint: int = 0,
    ):
        self.graph = graph
        self.answerer_judge = answerer_judge
        self.weights = weights
        self.max_support_edges = max_support_edges
        # Normalized pipeline selector; empty/whitespace falls back to legacy.
        self.pipeline_mode = str(pipeline_mode).strip().lower() or "legacy"
        # Precomputed membership sets used by consistency scoring.
        self.graph_nodes = set(graph.nodes.keys())
        self.graph_edges = {(edge.src, edge.rel, edge.dst) for edge in graph.edges}
        # Question history feeding diversity and duplicate detection.
        self._seen_questions: list[str] = []
        # Fall back to defaults derived from max_support_edges when no
        # explicit swarm_v2 configs are supplied.
        self.swarm_v2_validation = swarm_v2_validation or SwarmV2ValidationConfig(
            max_support_edges=max_support_edges
        )
        self.swarm_v2_shared_context = swarm_v2_shared_context or SwarmV2SharedContextConfig()
        self.parl_max_parallel_hint = max(0, int(parl_max_parallel_hint or 0))
        self._swarm_v2_validator = SwarmV2ReplayValidator(
            graph=graph,
            validation=self.swarm_v2_validation,
            shared_context=self.swarm_v2_shared_context,
            seen_questions=self._seen_questions,
        )
        # Diagnostic state surfaced in training logs.
        self._debug_batches_seen = 0
        self._debug_reason_counter: Counter[str] = Counter()
        self._debug_reward_window: list[float] = []
        self._debug_last_batch: dict[str, Any] = {}
| def _std(values: list[float]) -> float: | |
| if len(values) <= 1: | |
| return 0.0 | |
| mean = sum(values) / len(values) | |
| variance = sum((value - mean) ** 2 for value in values) / len(values) | |
| return variance ** 0.5 | |
| def _invalid_swarm_v2_reward( | |
| self, | |
| candidate: GeneratedTaskCandidate, | |
| validation_result: ReplayValidationResult, | |
| completion_text: str = "", | |
| ) -> float: | |
| # Avoid a constant hard penalty. Keep invalid samples negative but | |
| # graded so GRPO still gets reward variance/advantages when quality | |
| # differs. Three layers of signal: | |
| # (1) per-reason penalty (caps how bad it can get) | |
| # (2) partial credit for any parseable structural element | |
| # (3) tiny text-level signal so completely-collapsed completions | |
| # differ from completions that at least *attempt* JSON. | |
| reason_penalty = { | |
| "missing_question_or_answer": 0.35, | |
| "malformed_support_edges": 0.25, | |
| "non_replayable_tool_calls": 0.25, | |
| "non_unique_derivation_path": 0.20, | |
| "unseen_nodes_or_edges": 0.25, | |
| "answer_leakage": 0.30, | |
| "duplicate_or_near_duplicate": 0.15, | |
| "context_or_support_budget_overflow": 0.15, | |
| } | |
| penalty = 0.10 | |
| for reason in validation_result.reasons: | |
| penalty += reason_penalty.get(reason, 0.10) | |
| partial_credit = 0.0 | |
| if candidate.question: | |
| partial_credit += 0.25 | |
| if candidate.answer: | |
| partial_credit += 0.25 | |
| if candidate.supporting_edges: | |
| partial_credit += min(0.36, 0.12 * len(candidate.supporting_edges)) | |
| if candidate.tool_trace: | |
| partial_credit += min(0.20, 0.05 * len(candidate.tool_trace)) | |
| if candidate.subagent_outputs: | |
| partial_credit += 0.10 | |
| if candidate.canonical_edges or candidate.canonical_nodes: | |
| partial_credit += 0.12 | |
| text_signal = self._completion_text_signal(completion_text) | |
| reward = partial_credit - penalty + text_signal | |
| return float(max(-1.25, min(-0.02, reward))) | |
| def _completion_text_signal(completion_text: str) -> float: | |
| """Small [0, 0.25] bonus that grades how 'JSON-like' a raw completion is. | |
| Important for GRPO: if every sample in a group is unparseable garbage | |
| but the *raw text* differs in JSON-likeness, we still produce non-zero | |
| advantages. Without this the reward collapses to a flat floor and | |
| ``frac_reward_zero_std`` stays at 1.0 forever. | |
| """ | |
| if not completion_text: | |
| return 0.0 | |
| text = completion_text.strip() | |
| if not text: | |
| return 0.0 | |
| signal = 0.0 | |
| # Brace cues (model is trying to emit JSON). | |
| signal += 0.03 * min(2, text.count("{")) | |
| signal += 0.03 * min(2, text.count("}")) | |
| signal += 0.01 * min(4, text.count("[")) | |
| signal += 0.01 * min(4, text.count("]")) | |
| # Schema-keyword cues. Each keyword bumps the signal a tiny amount. | |
| cues = ( | |
| "\"question\"", | |
| "\"answer\"", | |
| "\"supporting_edges\"", | |
| "\"tool_trace\"", | |
| "\"task_type\"", | |
| "\"orchestrator\"", | |
| ) | |
| signal += 0.015 * sum(1 for cue in cues if cue in text) | |
| # Diversity proxy: number of unique short tokens. Pure repetition | |
| # collapses this; varied output keeps it nonzero. Caps very fast. | |
| sample = text[:512] | |
| unique_words = len(set(sample.split())) if sample else 0 | |
| signal += min(0.04, unique_words / 800.0) | |
| # Length proxy capped — purely-empty vs anything-emitted differs. | |
| length_bump = min(0.03, len(text) / 8000.0) | |
| signal += length_bump | |
| return min(0.25, signal) | |
| def _validity_score(self, candidate: GeneratedTaskCandidate) -> float: | |
| score = 0.0 | |
| if candidate.question: | |
| score += 0.4 | |
| if candidate.answer: | |
| score += 0.4 | |
| if len(candidate.supporting_edges) <= self.max_support_edges: | |
| score += 0.2 | |
| return min(1.0, score) | |
| def _consistency_score(self, candidate: GeneratedTaskCandidate) -> float: | |
| if not candidate.question or not candidate.answer: | |
| return 0.0 | |
| edge_consistency = 0.0 | |
| if candidate.supporting_edges: | |
| matches = sum( | |
| 1 | |
| for edge in candidate.supporting_edges | |
| if (edge.src, edge.rel, edge.dst) in self.graph_edges | |
| ) | |
| edge_consistency = matches / max(1, len(candidate.supporting_edges)) | |
| answer_in_graph = 1.0 if candidate.answer in self.graph_nodes else 0.0 | |
| answer_in_edges = 1.0 if any( | |
| candidate.answer in {edge.src, edge.dst} for edge in candidate.supporting_edges | |
| ) else 0.0 | |
| question_mentions_graph_symbol = 1.0 if any( | |
| node_id in candidate.question for node_id in self.graph_nodes | |
| ) else 0.0 | |
| return ( | |
| 0.45 * edge_consistency | |
| + 0.30 * max(answer_in_graph, answer_in_edges) | |
| + 0.25 * question_mentions_graph_symbol | |
| ) | |
| def _diversity_score(self, question: str) -> float: | |
| if not self._seen_questions: | |
| return 1.0 | |
| max_similarity = max(_jaccard_similarity(question, prior) for prior in self._seen_questions) | |
| return max(0.0, 1.0 - max_similarity) | |
| def _hardness_score(self, candidate: GeneratedTaskCandidate) -> float: | |
| if not candidate.is_valid: | |
| return -1.0 | |
| predicted_answer = normalize_answer(self.answerer_judge.answer(candidate.question)) | |
| target_answer = normalize_answer(candidate.answer) | |
| return 1.0 if predicted_answer != target_answer else -0.4 | |
| def _support_path_coverage(candidate: GeneratedTaskCandidate) -> float: | |
| if not candidate.supporting_edges: | |
| return 0.0 | |
| keys = {(edge.src, edge.rel, edge.dst) for edge in candidate.supporting_edges} | |
| return len(keys) / max(1, len(candidate.supporting_edges)) | |
| def _swarm_diversity_score(self, candidate: GeneratedTaskCandidate) -> float: | |
| if not candidate.subagent_outputs: | |
| return 0.0 | |
| distinct_ratio = _distinct_ngram_ratio(candidate.subagent_outputs, n=2) | |
| path_coverage = self._support_path_coverage(candidate) | |
| return max(0.0, min(1.0, (0.7 * distinct_ratio) + (0.3 * path_coverage))) | |
| def _context_pressure_score(self, validation_result: ReplayValidationResult) -> float: | |
| if not validation_result.is_valid: | |
| return 0.0 | |
| node_util = validation_result.context_nodes / max(1, self.swarm_v2_shared_context.max_nodes) | |
| edge_util = validation_result.context_edges / max(1, self.swarm_v2_shared_context.max_edges) | |
| utilization = max(node_util, edge_util) | |
| target = max(0.05, float(self.swarm_v2_shared_context.target_pressure)) | |
| if utilization > 1.0: | |
| return 0.0 | |
| gap = abs(utilization - target) | |
| return max(0.0, 1.0 - (gap / max(target, 1.0 - target))) | |
| def _parl_scores(self, candidate: GeneratedTaskCandidate) -> tuple[float, float]: | |
| breakdown = parl_reward_breakdown( | |
| task_outcome_reward=0.0, | |
| spawn_count=candidate.orchestrator.spawn_count, | |
| finished_subtasks=candidate.orchestrator.finished_subtasks, | |
| critical_steps=candidate.orchestrator.critical_steps, | |
| lambda_parallel=0.15, | |
| lambda_finish=0.20, | |
| anneal=1.0, | |
| breadth=candidate.orchestrator.breadth, | |
| depth=candidate.orchestrator.depth, | |
| max_parallel_hint=self.parl_max_parallel_hint, | |
| ) | |
| return breakdown.parallel, breakdown.finish | |
| def _swarm_v2_reward( | |
| self, | |
| candidate: GeneratedTaskCandidate, | |
| completion_text: str = "", | |
| ) -> tuple[float, ReplayValidationResult]: | |
| validator = self._swarm_v2_validator | |
| validator.seen_questions = list(self._seen_questions) | |
| validation_result = validator.validate(candidate) | |
| if not validation_result.is_valid: | |
| return ( | |
| self._invalid_swarm_v2_reward(candidate, validation_result, completion_text), | |
| validation_result, | |
| ) | |
| hardness = self._hardness_score(candidate) | |
| swarm_diversity = self._swarm_diversity_score(candidate) | |
| context_pressure = self._context_pressure_score(validation_result) | |
| parl_parallel, parl_finish = self._parl_scores(candidate) | |
| hardness_component = max(0.0, min(1.0, (hardness + 0.4) / 1.4)) | |
| consistency_component = max( | |
| 0.0, | |
| min( | |
| 1.0, | |
| (0.55 * context_pressure) | |
| + (0.25 * parl_parallel) | |
| + (0.20 * parl_finish), | |
| ), | |
| ) | |
| completion_component = max(0.0, min(1.0, self._completion_text_signal(completion_text) / 0.25)) | |
| reward = ( | |
| self.weights.validity | |
| + (self.weights.hardness * hardness_component) | |
| + (self.weights.diversity * swarm_diversity) | |
| + (self.weights.consistency * consistency_component) | |
| + (0.05 * completion_component) | |
| ) | |
| return reward, validation_result | |
    def __call__(
        self,
        prompts: list[Any] | None = None,
        completions: list[Any] | None = None,
        **kwargs: Any,
    ) -> list[float]:
        """Score a batch of generator completions for GRPO.

        Parses each completion into a task candidate, rewards it via the
        swarm_v2 replay path or the legacy weighted-component path, and
        records per-batch debug statistics on the instance. Never raises:
        parse/reward failures fall back to a floored reward so a single
        bad completion cannot crash a training step.
        """
        del prompts
        # TRL may pass completions positionally or inside kwargs.
        if completions is None:
            completions = list(kwargs.get("completions", []))
        rewards: list[float] = []
        batch_reasons: Counter[str] = Counter()
        valid_count = 0
        for completion in completions:
            try:
                text = decode_completion_text(completion)
            except Exception:
                text = ""
            try:
                candidate = parse_generated_task_completion(
                    text, max_support_edges=self.max_support_edges
                )
            except Exception as exc:
                # Hard guard: a single malformed completion must NEVER take
                # down GRPO. Fall through to the invalid-floor with a tiny
                # text signal so the group still has reward variance.
                print(
                    f"[reward_debug][parse_error] {type(exc).__name__}: {exc}; "
                    f"text_head={text[:120]!r}"
                )
                rewards.append(
                    float(max(-1.8, -0.6 + self._completion_text_signal(text)))
                )
                batch_reasons["parse_error"] += 1
                continue
            if self.pipeline_mode == "swarm_v2":
                try:
                    reward, validation_result = self._swarm_v2_reward(
                        candidate, completion_text=text
                    )
                except Exception as exc:
                    # Same guard as parse errors: degrade to the floor
                    # instead of aborting the batch.
                    print(
                        f"[reward_debug][reward_error] {type(exc).__name__}: {exc}"
                    )
                    rewards.append(
                        float(max(-1.8, -0.6 + self._completion_text_signal(text)))
                    )
                    batch_reasons["reward_error"] += 1
                    continue
                # Clamp the swarm_v2 reward to its configured range.
                rewards.append(float(max(-1.8, min(1.2, reward))))
                if validation_result.is_valid and candidate.question:
                    valid_count += 1
                    # Record the question for diversity scoring; trim the
                    # history so memory stays bounded.
                    self._seen_questions.append(candidate.question)
                    if len(self._seen_questions) > 4096:
                        self._seen_questions = self._seen_questions[-2048:]
                else:
                    for reason in validation_result.reasons:
                        batch_reasons[reason] += 1
            else:
                # Legacy pipeline: weighted sum of the four scalar scores.
                validity = self._validity_score(candidate)
                consistency = self._consistency_score(candidate)
                diversity = self._diversity_score(candidate.question) if candidate.question else 0.0
                hardness = self._hardness_score(candidate)
                reward = (
                    self.weights.validity * validity
                    + self.weights.hardness * hardness
                    + self.weights.diversity * diversity
                    + self.weights.consistency * consistency
                )
                rewards.append(float(max(-2.0, min(1.2, reward))))
            # Legacy mode records every question here; swarm_v2 already did
            # so above for valid candidates only.
            if self.pipeline_mode != "swarm_v2" and candidate.question:
                self._seen_questions.append(candidate.question)
                if len(self._seen_questions) > 4096:
                    self._seen_questions = self._seen_questions[-2048:]
        # Per-batch debug snapshot consumed by stdout logging / tooling.
        self._debug_batches_seen += 1
        self._debug_reward_window.extend(rewards)
        self._debug_reward_window = self._debug_reward_window[-512:]
        self._debug_reason_counter.update(batch_reasons)
        batch_mean = float(sum(rewards) / max(1, len(rewards)))
        batch_std = float(self._std(rewards))
        # Mean-centered rewards approximate the GRPO group advantage.
        advantages = [float(value - batch_mean) for value in rewards]
        self._debug_last_batch = {
            "batch_rewards": list(rewards),
            "batch_reward_mean": batch_mean,
            "batch_reward_std": batch_std,
            "advantage_proxy_min": min(advantages) if advantages else 0.0,
            "advantage_proxy_max": max(advantages) if advantages else 0.0,
            "advantage_proxy_std": float(self._std(advantages)),
            "valid_count": int(valid_count),
            "invalid_count": int(max(0, len(rewards) - valid_count)),
            "valid_output_ratio": float(valid_count / max(1, len(rewards))),
            "top_invalid_reasons": batch_reasons.most_common(5),
        }
        # Periodic stdout heartbeat (swarm_v2 only, every 10th batch).
        if self.pipeline_mode == "swarm_v2" and (self._debug_batches_seen % 10 == 0):
            window_std = self._std(self._debug_reward_window)
            print(
                "[reward_debug][generator] "
                f"batches={self._debug_batches_seen} "
                f"window_reward_std={window_std:.6f} "
                f"last_batch_valid={valid_count}/{len(rewards)} "
                f"top_invalid_reasons={batch_reasons.most_common(3)}"
            )
        return rewards
class AnswererRewardFunction:
    """Answer-swarm reward wrapper that reuses the environment answer reward logic."""

    def __init__(
        self,
        graph: CanonicalGraph,
        pipeline_mode: str = "legacy",
        parl_max_parallel_hint: int = 0,
    ):
        self.reward_model = build_reward_model(graph)
        self.pipeline_mode = str(pipeline_mode).strip().lower() or "legacy"
        self.parl_max_parallel_hint = max(0, int(parl_max_parallel_hint or 0))
        # Mirror GeneratorRewardFunction observability: TRL's GRPOTrainer
        # already logs `rewards/AnswererRewardFunction/{mean,std}` to W&B
        # at every `logging_steps`, but we ALSO publish a per-batch debug
        # snapshot so the [reward_debug][last_batch] line appears in stdout
        # for the answerer phase, exactly like it does for the generator.
        self._debug_batches_seen = 0
        self._debug_reward_window: list[float] = []
        self._debug_last_batch: dict[str, Any] = {}

    @staticmethod
    def _std(values: list[float]) -> float:
        """Population standard deviation; 0.0 for fewer than two samples.

        Must be a @staticmethod: call sites use ``self._std(rewards)``, and
        without the decorator the instance would be bound as ``values``.
        """
        if len(values) <= 1:
            return 0.0
        mean = sum(values) / len(values)
        variance = sum((value - mean) ** 2 for value in values) / len(values)
        return variance ** 0.5

    @staticmethod
    def _parse_support_edges(value: Any) -> list[Edge]:
        """Parse a JSON string or list of edge dicts into Edge objects.

        Rows missing src/rel/dst are skipped; unparseable confidence values
        fall back to 1.0. Non-list payloads yield an empty list.
        @staticmethod for the same ``self.``-call-site reason as ``_std``.
        """
        payload = value
        if isinstance(value, str):
            try:
                payload = json.loads(value)
            except json.JSONDecodeError:
                payload = []
        out: list[Edge] = []
        if not isinstance(payload, list):
            return out
        for row in payload:
            if not isinstance(row, dict):
                continue
            src = str(row.get("src", "")).strip()
            rel = str(row.get("rel", "")).strip()
            dst = str(row.get("dst", "")).strip()
            if not src or not rel or not dst:
                continue
            try:
                confidence = float(row.get("confidence", 1.0))
            except (TypeError, ValueError):
                confidence = 1.0
            out.append(Edge(src=src, rel=rel, dst=dst, confidence=confidence))
        return out

    @staticmethod
    def _value_at(column: Any, index: int, default: Any) -> Any:
        """Safe positional lookup into a (possibly absent) dataset column."""
        if isinstance(column, list) and index < len(column):
            return column[index]
        return default

    @staticmethod
    def _extract_predicted_edges(completion_text: str, support_edges: list[Edge]) -> list[Edge]:
        """Recover the edges a completion claims to rely on.

        Prefers a structured ``supporting_edges`` payload inside the
        completion's JSON blob; falls back to a lenient substring match of
        each ground-truth edge's (src, rel, dst) in the lowercased text.
        """
        blob = _extract_json_blob(completion_text)
        if isinstance(blob, dict):
            structured_edges = _parse_edge_rows(
                blob.get("supporting_edges", []), max_support_edges=len(support_edges)
            )
            if structured_edges:
                return structured_edges
        text = completion_text.lower()
        matched: list[Edge] = []
        for edge in support_edges:
            if edge.src.lower() in text and edge.rel.lower() in text and edge.dst.lower() in text:
                matched.append(edge)
        return matched

    def _extract_orchestrator_reward(self, completion_text: str, base_reward: float) -> float:
        """Blend the base answer reward with PARL orchestration shaping.

        Legacy mode passes the base reward through untouched; swarm_v2 mode
        parses orchestrator telemetry out of the completion and applies the
        PARL breakdown on top of the task outcome reward.
        """
        if self.pipeline_mode != "swarm_v2":
            return float(base_reward)
        blob = _extract_json_blob(completion_text)
        orchestrator = (
            _parse_orchestrator(blob.get("orchestrator"))
            if isinstance(blob, dict)
            else SwarmOrchestratorTelemetry()
        )
        breakdown = parl_reward_breakdown(
            task_outcome_reward=base_reward,
            spawn_count=orchestrator.spawn_count,
            finished_subtasks=orchestrator.finished_subtasks,
            critical_steps=orchestrator.critical_steps,
            lambda_parallel=0.15,
            lambda_finish=0.20,
            anneal=1.0,
            breadth=orchestrator.breadth,
            depth=orchestrator.depth,
            max_parallel_hint=self.parl_max_parallel_hint,
        )
        return float(breakdown.total)

    def __call__(
        self,
        prompts: list[Any],
        completions: list[Any],
        answer: list[Any] | None = None,
        question: list[Any] | None = None,
        supporting_edges_json: list[Any] | None = None,
        difficulty: list[Any] | None = None,
        **kwargs: Any,
    ) -> list[float]:
        """Score a batch of answerer completions against the dataset columns.

        Each completion is matched to its row by positional index; the
        environment's `compute_answer_reward` supplies the base reward and
        swarm_v2 mode layers PARL shaping on top. Also updates the per-batch
        debug snapshot and the periodic stdout heartbeat.
        """
        rewards: list[float] = []
        success_count = 0
        graph_f1_sum = 0.0
        for idx, completion in enumerate(completions):
            completion_text = decode_completion_text(completion)
            predicted_answer = extract_answer_from_completion(completion_text)
            target_answer = normalize_answer(str(self._value_at(answer, idx, "")))
            question_text = str(self._value_at(question, idx, "")).strip()
            if not question_text:
                # Fall back to the raw prompt when the dataset has no
                # dedicated question column.
                question_text = str(self._value_at(prompts, idx, "")).strip()
            support_payload = self._value_at(supporting_edges_json, idx, [])
            support_edges = self._parse_support_edges(support_payload)
            difficulty_level = str(self._value_at(difficulty, idx, "hard")).strip() or "hard"
            task = TaskInstance(
                task_id=f"train_task_{idx}",
                task_type="adversarial_trace",
                question=question_text,
                answer=target_answer,
                supporting_edges=support_edges,
                metadata={"difficulty": difficulty_level},
            )
            pred_edges = self._extract_predicted_edges(completion_text, support_edges)
            breakdown = compute_answer_reward(
                proposed_answer=predicted_answer,
                task=task,
                pred_edges=pred_edges,
                tool_outputs=[],
                step_count=1,
                model=self.reward_model,
                difficulty=difficulty_level,
            )
            final_reward = self._extract_orchestrator_reward(completion_text, breakdown.total)
            rewards.append(final_reward)
            if predicted_answer and target_answer and normalize_answer(predicted_answer) == target_answer:
                success_count += 1
            graph_f1_sum += float(getattr(breakdown, "graph_f1", 0.0) or 0.0)
        # Mirror GeneratorRewardFunction debug surface so the answerer reward
        # is visible to the same downstream tooling (printed by
        # `_train_grpo_phase` and forwarded to W&B by TRL).
        self._debug_batches_seen += 1
        self._debug_reward_window.extend(rewards)
        self._debug_reward_window = self._debug_reward_window[-512:]
        batch_size = max(1, len(rewards))
        batch_mean = float(sum(rewards) / batch_size)
        batch_std = float(self._std(rewards))
        advantages = [float(value - batch_mean) for value in rewards]
        self._debug_last_batch = {
            "batch_rewards": list(rewards),
            "batch_reward_mean": batch_mean,
            "batch_reward_std": batch_std,
            "advantage_proxy_min": min(advantages) if advantages else 0.0,
            "advantage_proxy_max": max(advantages) if advantages else 0.0,
            "advantage_proxy_std": float(self._std(advantages)),
            "exact_match_count": int(success_count),
            "exact_match_ratio": float(success_count / batch_size),
            "avg_graph_f1": float(graph_f1_sum / batch_size),
        }
        if self._debug_batches_seen % 10 == 0:
            window_std = self._std(self._debug_reward_window)
            print(
                "[reward_debug][answerer] "
                f"batches={self._debug_batches_seen} "
                f"window_reward_std={window_std:.6f} "
                f"last_batch_mean={batch_mean:.6f} "
                f"last_batch_std={batch_std:.6f} "
                f"exact_match_ratio={self._debug_last_batch['exact_match_ratio']:.3f} "
                f"avg_graph_f1={self._debug_last_batch['avg_graph_f1']:.3f}",
                flush=True,
            )
        return rewards