OSINT / src /osint_env /training /self_play.py
ritishshrirao's picture
Add evaluation, minor updates to HF space
8ad6382
from __future__ import annotations
import inspect
import json
import os
import re
import time
from dataclasses import dataclass
from pathlib import Path
import random
from typing import Any
from osint_env.data.generator import (
build_swarm_v2_canonical_subgraph,
build_swarm_v2_path_candidates,
build_swarm_v2_tool_trace,
emit_swarm_v2_question,
select_swarm_v2_answer,
trace_swarm_v2_path,
)
from osint_env.domain.models import Edge, EnvironmentConfig, TaskInstance
from osint_env.env.environment import OSINTEnvironment
from osint_env.env.reward import compute_graph_f1
from osint_env.llm import build_llm_client
from osint_env.training.config import (
KimiGRPOPhaseConfig,
LoraTuningConfig,
SelfPlayTrainingConfig,
SwarmV2SwarmConfig,
)
from osint_env.training.rewards import (
AnswererJudge,
AnswererRewardFunction,
GeneratorRewardFunction,
SwarmV2ReplayValidator,
decode_completion_text,
extract_answer_from_completion,
normalize_answer,
parse_generated_task_completion,
)
@dataclass(slots=True)
class _RoundArtifacts:
round_index: int
generator_dataset_path: str
answerer_dataset_path: str
generated_tasks_path: str
def _is_true_env(value: str | None) -> bool:
    """Interpret an environment-variable value as a boolean flag.

    Matching is case-insensitive and ignores surrounding whitespace; only
    "1", "true", "yes", "y" and "on" count as true. ``None``, empty and any
    other value are false.
    """
    raw = value if value else ""
    return str(raw).strip().lower() in ("1", "true", "yes", "y", "on")
def _resolve_hf_upload_token() -> str:
    """Return the first non-blank Hugging Face token found in the environment.

    Variables are consulted in order: ``HF_TOKEN``, ``HUGGINGFACE_HUB_TOKEN``,
    ``HUGGING_FACE_HUB_TOKEN``. Returns ``""`` when none holds a value.
    """
    stripped_values = (
        str(os.getenv(env_name, "")).strip()
        for env_name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    )
    # Lazy generator: stops reading the environment at the first hit.
    return next((candidate for candidate in stripped_values if candidate), "")
def _slugify_hf_repo_name(value: str) -> str:
    """Turn an arbitrary string into a Hugging Face Hub-safe name segment.

    The input is stripped and lower-cased. Every run of characters outside
    ``[a-z0-9._-]`` is replaced by ``-``. Repeated ``-`` and repeated ``.`` are
    collapsed, because the Hub rejects names containing ``--`` or ``..``.
    Leading and trailing ``-``/``.`` are removed. May return ``""`` when
    nothing usable remains.
    """
    token = re.sub(r"[^a-zA-Z0-9._-]+", "-", str(value).strip().lower())
    token = re.sub(r"-{2,}", "-", token)
    token = re.sub(r"\.{2,}", ".", token)
    return token.strip("-.")
def _default_hf_checkpoint_repo_id(run_dir: Path) -> str:
    """Resolve the Hub repo id used for checkpoint uploads and resume downloads.

    Resolution order:
      1. ``OSINT_HF_CHECKPOINT_REPO_ID``, stripped, when set.
      2. Derived from ``SPACE_ID`` (or ``HF_SPACE_ID``) of the form
         ``owner/space`` as ``owner/<slug(space + suffix)>``. The suffix comes
         from ``OSINT_HF_CHECKPOINT_REPO_SUFFIX`` (default ``-checkpoints``).
      3. ``""`` when no usable ``owner/...`` space id is available; callers
         treat this as "uploads disabled".

    ``run_dir`` is not used; it is kept for signature compatibility.
    """
    explicit = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_ID", "")).strip()
    if explicit:
        return explicit
    space_id = str(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID") or "").strip()
    owner, separator, space_name = space_id.partition("/")
    if not separator or not owner:
        # "name" has no namespace and "/name" has an empty one; neither can
        # form a valid repo id.
        return ""
    suffix = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_SUFFIX", "-checkpoints")).strip() or "-checkpoints"
    # Hub repo names are capped at 96 characters. Appending the suffix to a
    # long Space name can exceed that and make create_repo fail, so truncate
    # and re-strip any separator left dangling by the cut.
    max_repo_name_len = 96
    repo_name = _slugify_hf_repo_name(f"{space_name}{suffix}")[:max_repo_name_len].strip("-.")
    return f"{owner}/{repo_name or 'osint-self-play-checkpoints'}"
def _hf_checkpoint_repo_prefix(run_dir: Path) -> str:
    """Return the folder inside the Hub repo under which this run's files go.

    ``OSINT_HF_CHECKPOINT_PATH_PREFIX`` wins when set, after surrounding
    whitespace and slashes are removed. Otherwise the slugified
    run-directory name is used, falling back to ``"self-play"``.
    """
    override = str(os.getenv("OSINT_HF_CHECKPOINT_PATH_PREFIX", "")).strip().strip("/")
    if not override:
        override = _slugify_hf_repo_name(run_dir.name) or "self-play"
    return override
def _hf_relative_repo_path(local_path: Path, run_dir: Path) -> str:
    """Map a local path to its destination path inside the checkpoint repo.

    Paths under ``run_dir`` keep their relative layout beneath the run
    prefix. Paths outside ``run_dir`` go directly under the prefix, by file
    name. ``run_dir`` itself maps to the bare prefix.
    """
    prefix = _hf_checkpoint_repo_prefix(run_dir)
    try:
        relative = local_path.relative_to(run_dir).as_posix()
    except ValueError:
        relative = local_path.name
    if relative == ".":
        # Path.relative_to yields "." when local_path == run_dir. Without this
        # check the result would be "<prefix>/.", which is not a meaningful
        # repo path.
        relative = ""
    return f"{prefix}/{relative}".strip("/")
def _maybe_upload_folder_to_hf(local_dir: Path, run_dir: Path, commit_message: str) -> None:
    """Best-effort push of ``local_dir`` into the run's Hub checkpoint repo.

    Does nothing when no repo id or token is configured, or when
    ``local_dir`` does not exist. Prints a notice and returns when
    ``huggingface_hub`` is not installed. Errors raised by ``create_repo`` or
    ``upload_folder`` are not caught here; they propagate to the caller.

    The repo type comes from ``OSINT_HF_CHECKPOINT_REPO_TYPE`` (default
    ``model``). The repo is created private unless
    ``OSINT_HF_CHECKPOINT_REPO_PRIVATE`` is set to a non-truthy value.
    """
    repo_id = _default_hf_checkpoint_repo_id(run_dir)
    token = _resolve_hf_upload_token()
    if not (repo_id and token and local_dir.exists()):
        return
    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("[self_play][hf_upload] huggingface_hub missing; skipping checkpoint upload.")
        return

    repo_type = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model")).strip() or "model"
    is_private = _is_true_env(os.getenv("OSINT_HF_CHECKPOINT_REPO_PRIVATE", "1"))
    remote_path = _hf_relative_repo_path(local_dir, run_dir)

    hub = HfApi(token=token)
    hub.create_repo(repo_id=repo_id, repo_type=repo_type, private=is_private, exist_ok=True)

    # Resume-only trainer state (optimizer/scheduler/RNG snapshots, trainer
    # state, training args) is excluded to keep uploads small; only
    # inference-relevant artifacts are meant to be shared.
    # NOTE(review): huggingface_hub matches ignore_patterns with fnmatch
    # against paths relative to ``local_dir``, so "**/optimizer.pt" needs at
    # least one directory component. When ``local_dir`` is itself a
    # checkpoint-* folder (as in the on-save callback), its top-level
    # optimizer.pt etc. are still uploaded. Confirm whether resume relies on
    # that before tightening these patterns.
    resume_only_files = (
        "optimizer.pt",
        "scheduler.pt",
        "rng_state.pth",
        "trainer_state.json",
        "training_args.bin",
    )
    ignore_patterns = ["*.pyc", "__pycache__", ".DS_Store"]
    ignore_patterns.extend(f"**/{file_name}" for file_name in resume_only_files)

    hub.upload_folder(
        folder_path=str(local_dir),
        repo_id=repo_id,
        repo_type=repo_type,
        path_in_repo=remote_path,
        commit_message=commit_message,
        ignore_patterns=ignore_patterns,
    )
    print(f"[self_play][hf_upload] uploaded {local_dir} -> {repo_type}:{repo_id}/{remote_path}")
def _maybe_upload_file_to_hf(local_file: Path, run_dir: Path, commit_message: str) -> None:
    """Best-effort push of a single file into the run's Hub checkpoint repo.

    Does nothing when no repo id or token is configured, or when
    ``local_file`` does not exist. Prints a notice and returns when
    ``huggingface_hub`` is not installed. Errors raised by ``create_repo`` or
    ``upload_file`` propagate to the caller.

    The repo type comes from ``OSINT_HF_CHECKPOINT_REPO_TYPE`` (default
    ``model``). The repo is created private unless
    ``OSINT_HF_CHECKPOINT_REPO_PRIVATE`` is set to a non-truthy value.
    """
    repo_id = _default_hf_checkpoint_repo_id(run_dir)
    token = _resolve_hf_upload_token()
    if not (repo_id and token and local_file.exists()):
        return
    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("[self_play][hf_upload] huggingface_hub missing; skipping artifact upload.")
        return

    repo_type = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model")).strip() or "model"
    is_private = _is_true_env(os.getenv("OSINT_HF_CHECKPOINT_REPO_PRIVATE", "1"))
    remote_path = _hf_relative_repo_path(local_file, run_dir)

    hub = HfApi(token=token)
    hub.create_repo(repo_id=repo_id, repo_type=repo_type, private=is_private, exist_ok=True)
    hub.upload_file(
        path_or_fileobj=str(local_file),
        repo_id=repo_id,
        repo_type=repo_type,
        path_in_repo=remote_path,
        commit_message=commit_message,
    )
    print(f"[self_play][hf_upload] uploaded {local_file} -> {repo_type}:{repo_id}/{remote_path}")
def _require_training_stack() -> tuple[Any, Any, Any]:
    """Import and return ``(Dataset, GRPOConfig, GRPOTrainer)``.

    Raises:
        RuntimeError: if ``datasets`` or ``trl`` (or the named classes) cannot
            be imported. The original ImportError is chained as the cause.
    """
    try:
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as missing:
        install_hint = (
            "Training stack is missing. Install train dependencies first: "
            "python -m pip install -e .[train]"
        )
        raise RuntimeError(install_hint) from missing
    stack = (Dataset, GRPOConfig, GRPOTrainer)
    return stack
def _build_hf_checkpoint_upload_callback(output_dir: Path, run_dir: Path) -> Any:
    """Return a Trainer callback that uploads each fresh ``checkpoint-*`` to
    HF Hub right after Transformers' Trainer saves it. Returns None if
    on-save uploads are disabled, no repo id or token is configured, or
    transformers is unavailable.

    This pairs with ``_maybe_download_phase_checkpoints_from_hf``: a Space
    restarted mid-phase can pull the most recent checkpoint and warm-start
    instead of starting from step 0. Honors the same env vars as the
    post-phase upload helper:
    - ``OSINT_HF_CHECKPOINT_REPO_ID`` (or auto-derived from ``SPACE_ID``)
    - ``OSINT_HF_CHECKPOINT_REPO_TYPE`` (default ``model``)
    - ``OSINT_HF_UPLOAD_ON_SAVE`` (default ``1``; set to 0 to disable)

    Only the main process (``state.is_world_process_zero``) uploads, so
    distributed runs do not push the same checkpoint once per rank. Upload
    failures are logged and never interrupt training.
    """
    if not _is_true_env(os.getenv("OSINT_HF_UPLOAD_ON_SAVE", "1")):
        return None
    repo_id = _default_hf_checkpoint_repo_id(run_dir)
    token = _resolve_hf_upload_token()
    if not repo_id or not token:
        return None
    try:
        from transformers import TrainerCallback
    except ImportError:
        print(
            "[self_play][hf_upload] transformers.TrainerCallback unavailable; "
            "intermediate checkpoint uploads disabled.",
            flush=True,
        )
        return None
    captured_output_dir = output_dir
    captured_run_dir = run_dir

    class _HfHubCheckpointUploadCallback(TrainerCallback):  # type: ignore[misc]
        """Upload the latest local ``checkpoint-*`` directory after each save."""

        def __init__(self) -> None:
            # Step most recently handed to the upload helper; saves at or below
            # this step are skipped.
            self._last_uploaded_step: int | None = None
            # Consecutive upload failures; reset after a successful upload.
            self._failures = 0

        def on_save(self, args: Any, state: Any, control: Any, **kwargs: Any) -> Any:  # noqa: D401
            # The Trainer invokes callbacks on every process in distributed
            # runs. Only the main process should push to the Hub; otherwise
            # each rank uploads the same checkpoint concurrently.
            if not getattr(state, "is_world_process_zero", True):
                return control
            try:
                latest = _latest_local_checkpoint(captured_output_dir)
                if latest is None:
                    return control
                step = int(latest.name.split("-", 1)[1]) if "-" in latest.name else 0
                if self._last_uploaded_step is not None and step <= self._last_uploaded_step:
                    return control
                print(
                    f"[self_play][hf_upload] on_save uploading {latest.name} "
                    f"to HF Hub (phase_dir={captured_output_dir.name}, step={step}).",
                    flush=True,
                )
                _maybe_upload_folder_to_hf(
                    latest,
                    captured_run_dir,
                    f"Intermediate checkpoint upload step={step} ({captured_output_dir.name})",
                )
                self._last_uploaded_step = step
                self._failures = 0
            except Exception as exc:  # noqa: BLE001
                # Deliberately best-effort: a failed upload must never abort
                # training, and the next save retries.
                self._failures += 1
                print(
                    f"[self_play][hf_upload] on_save upload failed "
                    f"({type(exc).__name__}: {exc}). failures={self._failures}. "
                    "Continuing training; next save will retry.",
                    flush=True,
                )
            return control

    return _HfHubCheckpointUploadCallback()
def _latest_local_checkpoint(output_dir: Path) -> Path | None:
    """Return the ``checkpoint-<step>`` subdirectory with the highest step.

    Non-directories and entries whose suffix is not an integer are ignored.
    Returns None when ``output_dir`` is missing or holds no such checkpoint.
    If several directories parse to the same step, the one glob yields last
    wins.
    """
    if not output_dir.exists():
        return None
    best_path: Path | None = None
    best_step = 0
    for entry in output_dir.glob("checkpoint-*"):
        if not entry.is_dir():
            continue
        try:
            entry_step = int(entry.name.split("-", 1)[-1])
        except ValueError:
            continue
        # ">=" keeps the later entry on ties, matching a stable sort + [-1].
        if best_path is None or entry_step >= best_step:
            best_path, best_step = entry, entry_step
    return best_path
def _final_model_already_present(output_dir: Path) -> bool:
    """Report whether ``output_dir/final_model`` already holds model weights.

    Weights are any ``*.safetensors`` file or legacy
    ``pytorch_model*.bin`` shard directly inside ``final_model``.
    """
    final_dir = output_dir / "final_model"
    if not final_dir.is_dir():
        return False
    weight_patterns = ("*.safetensors", "pytorch_model*.bin")
    return any(next(final_dir.glob(pattern), None) is not None for pattern in weight_patterns)
def _maybe_download_phase_checkpoints_from_hf(output_dir: Path, run_dir: Path) -> Path | None:
    """If no local checkpoint exists, try to recover the latest checkpoint
    for this phase from the HF Hub repo used for uploads. Returns the local
    path of the restored ``checkpoint-*`` directory, or None.

    Designed to make Space restarts non-destructive: on a fresh container,
    checkpoints previously pushed to ``OSINT_HF_CHECKPOINT_REPO_ID`` are
    pulled back so training can resume.

    Only a non-empty ``checkpoint-<step>`` directory is ever returned. If the
    download fails or yields no files, the target directory is removed when
    empty. Otherwise later ``_latest_local_checkpoint`` calls would mistake it
    for a usable checkpoint.

    ``snapshot_download`` mirrors the repo layout under ``output_dir``, so
    files are moved from ``<output_dir>/<repo prefix>/checkpoint-<step>``
    into ``<output_dir>/checkpoint-<step>``. Files already present at the
    destination are left alone. The emptied mirror directories are not
    removed.
    """
    local_latest = _latest_local_checkpoint(output_dir)
    if local_latest is not None:
        return local_latest
    repo_id = _default_hf_checkpoint_repo_id(run_dir)
    token = _resolve_hf_upload_token()
    if not repo_id or not token:
        return None
    try:
        from huggingface_hub import HfApi, snapshot_download
    except ImportError:
        return None
    repo_type = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model")).strip() or "model"
    api = HfApi(token=token)
    try:
        files = list(api.list_repo_files(repo_id=repo_id, repo_type=repo_type))
    except Exception as exc:  # noqa: BLE001
        print(
            f"[self_play][resume] could not list files in {repo_id}: "
            f"{type(exc).__name__}: {exc}",
            flush=True,
        )
        return None

    phase_prefix = _hf_relative_repo_path(output_dir, run_dir)
    phase_prefix_clean = phase_prefix.strip("/") + "/"
    remote_steps: set[int] = set()
    for remote_path in files:
        if not remote_path.startswith(phase_prefix_clean):
            continue
        relative = remote_path[len(phase_prefix_clean) :]
        checkpoint_name, separator, _ = relative.partition("/")
        # Only files nested inside a checkpoint-<step>/ folder count.
        if not separator or not checkpoint_name.startswith("checkpoint-"):
            continue
        try:
            remote_steps.add(int(checkpoint_name.split("-", 1)[1]))
        except ValueError:
            continue
    if not remote_steps:
        return None

    def _discard_if_empty(path: Path) -> None:
        # rmdir only succeeds on an existing, empty directory; any other state
        # (missing, non-empty, not a dir) is intentionally left untouched.
        try:
            path.rmdir()
        except OSError:
            pass

    best_step = max(remote_steps)
    target_local_dir = output_dir / f"checkpoint-{best_step}"
    remote_checkpoint = f"{phase_prefix_clean}checkpoint-{best_step}"
    print(
        f"[self_play][resume] downloading phase checkpoint from HF Hub: "
        f"repo={repo_id} prefix={remote_checkpoint} "
        f"-> {target_local_dir}",
        flush=True,
    )
    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type=repo_type,
            local_dir=str(output_dir),
            allow_patterns=[f"{remote_checkpoint}/*"],
            token=token,
        )
        downloaded_root = output_dir / remote_checkpoint
        if downloaded_root.is_dir() and downloaded_root != target_local_dir:
            # Create the target only once there is something to put in it.
            target_local_dir.mkdir(parents=True, exist_ok=True)
            for item in downloaded_root.iterdir():
                dest = target_local_dir / item.name
                if dest.exists():
                    continue
                item.replace(dest)
    except Exception as exc:  # noqa: BLE001
        print(
            f"[self_play][resume] failed to download checkpoint from HF Hub: "
            f"{type(exc).__name__}: {exc}",
            flush=True,
        )
        _discard_if_empty(target_local_dir)
        return None

    if target_local_dir.is_dir() and any(target_local_dir.iterdir()):
        return target_local_dir
    print(
        "[self_play][resume] HF Hub download produced no checkpoint files; starting fresh.",
        flush=True,
    )
    _discard_if_empty(target_local_dir)
    return None
def _task_to_edge_json(task: TaskInstance) -> str:
    """Serialize ``task.supporting_edges`` to a key-sorted JSON array.

    Each edge becomes an object with ``src``, ``rel``, ``dst`` and
    ``confidence``; the confidence is coerced to ``float``.
    """
    serialized_edges = []
    for supporting_edge in task.supporting_edges:
        serialized_edges.append(
            dict(
                src=supporting_edge.src,
                rel=supporting_edge.rel,
                dst=supporting_edge.dst,
                confidence=float(supporting_edge.confidence),
            )
        )
    return json.dumps(serialized_edges, sort_keys=True)
def _edge_payload(edge: Edge) -> dict[str, Any]:
    """Return ``edge`` as a plain dict (``src``, ``rel``, ``dst``, float ``confidence``)."""
    payload: dict[str, Any] = {field: getattr(edge, field) for field in ("src", "rel", "dst")}
    payload["confidence"] = float(edge.confidence)
    return payload
def _edges_from_payload(rows: Any, max_edges: int) -> list[Edge]:
    """Parse up to ``max_edges`` edge dicts into ``Edge`` objects.

    Non-list input yields ``[]``. Rows that are not dicts, or whose ``src``,
    ``rel`` or ``dst`` is missing, ``None`` or blank, are skipped. A
    confidence that is missing, unparseable or non-finite (NaN/inf) falls
    back to ``1.0``. A non-positive ``max_edges`` yields no edges.
    """
    import math

    if not isinstance(rows, list):
        return []

    def _field(row: dict[str, Any], key: str) -> str:
        # str(None) would yield the literal "None", slip past the emptiness
        # check and fabricate a node with that id.
        value = row.get(key)
        return "" if value is None else str(value).strip()

    edges: list[Edge] = []
    # Clamp so a negative limit means "no edges" instead of Python's
    # "all but the last N" slice semantics.
    for row in rows[: max(0, max_edges)]:
        if not isinstance(row, dict):
            continue
        src = _field(row, "src")
        rel = _field(row, "rel")
        dst = _field(row, "dst")
        if not src or not rel or not dst:
            continue
        try:
            confidence = float(row.get("confidence", 1.0))
        except (TypeError, ValueError):
            confidence = 1.0
        if not math.isfinite(confidence):
            confidence = 1.0
        edges.append(Edge(src=src, rel=rel, dst=dst, confidence=confidence))
    return edges
def _compact_shared_context(
    shared_context: dict[str, Any],
    max_nodes: int = 8,
    max_edges: int = 6,
) -> dict[str, Any]:
    """Trim a shared-context dict to its first ``max_nodes`` nodes and ``max_edges`` edges.

    Missing ``nodes``/``edges`` keys become empty lists; both values are
    returned as new lists.
    """
    nodes = shared_context.get("nodes", [])
    edges = shared_context.get("edges", [])
    return {"nodes": list(nodes)[:max_nodes], "edges": list(edges)[:max_edges]}
def _task_shared_context(
    env: OSINTEnvironment,
    task: TaskInstance,
    cfg: SelfPlayTrainingConfig,
) -> dict[str, Any]:
    """Return the shared graph context to show alongside ``task``.

    If the task metadata carries a ``canonical_graph`` dict, its nodes and
    edges are truncated to the ``cfg.swarm_v2.shared_context`` limits.
    Otherwise a context is sampled from ``env``. The RNG is seeded from the
    sum of the task id's code points, so the sample depends only on the task
    id and the graph.
    """
    canonical_graph = dict(task.metadata or {}).get("canonical_graph")
    if not isinstance(canonical_graph, dict):
        seed = sum(map(ord, task.task_id))
        limits = cfg.swarm_v2.shared_context
        return _graph_context_for_prompt(
            env=env,
            max_nodes=limits.max_nodes,
            max_edges=limits.max_edges,
            rng=random.Random(seed),
        )
    limits = cfg.swarm_v2.shared_context
    return {
        "nodes": list(canonical_graph.get("nodes", []))[: limits.max_nodes],
        "edges": list(canonical_graph.get("edges", []))[: limits.max_edges],
    }
def _swarm_v2_worker_packets(
    canonical_candidate: dict[str, Any],
    shared_context: dict[str, Any],
    swarm_cfg: SwarmV2SwarmConfig,
) -> dict[str, Any]:
    """Split a canonical candidate into per-worker prompt packets.

    Path edges come from the candidate's ``path`` (or ``edges`` when there is
    no ``path`` key). They are capped at ``2 * max_depth``, with a minimum of
    1. If none parse, up to two entries of ``edges`` are used instead.

    Packet contents:
    - path agent: the path edges.
    - question agent: the start node and relation sequence.
    - context agent: a compacted shared context.
    - planner: the swarm size limits as ints.
    """
    raw_path = canonical_candidate.get("path", canonical_candidate.get("edges", []))
    hops = _edges_from_payload(raw_path, max_edges=max(1, swarm_cfg.max_depth * 2))
    if not hops:
        hops = _edges_from_payload(canonical_candidate.get("edges", []), max_edges=2)

    path_agent = {
        "path_edges": [_edge_payload(hop) for hop in hops],
        "goal": "Choose one contiguous replayable path from the canonical candidate.",
    }
    question_agent = {
        "start_node": hops[0].src if hops else "",
        "relation_path": [hop.rel for hop in hops],
        "goal": "Write a compact question that describes the path without leaking the answer.",
    }
    context_agent = {
        "shared_context": _compact_shared_context(shared_context),
        "goal": "Keep support/context usage compact and graph-grounded.",
    }
    planner = {
        "max_agents": int(swarm_cfg.max_agents),
        "max_breadth": int(swarm_cfg.max_breadth),
        "max_depth": int(swarm_cfg.max_depth),
    }
    return {
        "path_agent": path_agent,
        "question_agent": question_agent,
        "context_agent": context_agent,
        "planner": planner,
    }
def _serialize_tool_trace(tool_trace: Any) -> list[dict[str, Any]]:
    """Convert tool-call objects into JSON-friendly dicts.

    Each call contributes ``tool_name`` (as ``str``), ``args`` and
    ``output``. Those two are copied when they are dicts and replaced by
    ``{}`` otherwise. Calls without a truthy ``tool_name`` are dropped.
    ``None`` input yields ``[]``.
    """
    records: list[dict[str, Any]] = []
    for call in tool_trace or ():
        name = getattr(call, "tool_name", "")
        call_args = getattr(call, "args", {})
        call_output = getattr(call, "output", {})
        if not name:
            continue
        records.append(
            {
                "tool_name": str(name),
                "args": dict(call_args) if isinstance(call_args, dict) else {},
                "output": dict(call_output) if isinstance(call_output, dict) else {},
            }
        )
    return records
def _canonical_example_payload(
    graph: Any,
    canonical_candidate: dict[str, Any],
    swarm_cfg: SwarmV2SwarmConfig,
) -> dict[str, Any]:
    """Build a schema example for the generator prompt from a canonical candidate.

    Up to two of the candidate's ``edges`` are replayed through ``graph``
    with ``trace_swarm_v2_path``. The raw edges are used when tracing returns
    a falsy result, and at most two hops are kept. Question and answer come
    from ``emit_swarm_v2_question`` / ``select_swarm_v2_answer``, and the
    orchestrator counters are clamped by ``swarm_cfg``. With no usable
    edges, a placeholder payload with an empty answer is returned.
    """
    seed_edges = _edges_from_payload(canonical_candidate.get("edges", []), max_edges=2)
    hops = trace_swarm_v2_path(graph, seed_edges) or seed_edges
    if not hops:
        placeholder_orchestrator = {
            "spawn_count": 1,
            "finished_subtasks": 1,
            "critical_steps": 1,
            "breadth": 1,
            "depth": 1,
        }
        return {
            "question": "Which entity is reached by following the provided replayable relation path?",
            "answer": "",
            "task_type": "swarm_v2_trace",
            "supporting_edges": [],
            "subagent_outputs": ["path_agent: no replayable edge"],
            "orchestrator": placeholder_orchestrator,
        }

    hops = hops[:2]
    hop_count = len(hops)
    spawned = min(swarm_cfg.max_agents, max(1, hop_count + 1))
    question_text = emit_swarm_v2_question(hops)
    answer_text = select_swarm_v2_answer(hops)
    notes = [
        f"path_agent_{position}: {hop.src} --{hop.rel}--> {hop.dst}"
        for position, hop in enumerate(hops)
    ]
    notes.append("question_agent: emitted compact relation-path question")
    notes.append("context_agent: kept shared context focused on replayable edges")
    return {
        "question": question_text,
        "answer": answer_text,
        "task_type": f"swarm_v2_{hop_count}hop_trace",
        "supporting_edges": [_edge_payload(hop) for hop in hops],
        "subagent_outputs": notes,
        "orchestrator": {
            "spawn_count": spawned,
            "finished_subtasks": spawned,
            "critical_steps": max(1, hop_count),
            "breadth": min(swarm_cfg.max_breadth, spawned),
            "depth": min(swarm_cfg.max_depth, 1),
        },
    }
def _difficulty_for_task(task: TaskInstance) -> str:
    """Return "easy", "medium" or "hard" for ``task``.

    An explicit ``metadata["difficulty"]`` of one of those values wins
    (case-insensitive). Otherwise MetaQA 1-hop tasks are easy, 2-hop tasks
    are medium, and everything else is hard.
    """
    declared = str(dict(task.metadata or {}).get("difficulty", "")).strip().lower()
    if declared in ("easy", "medium", "hard"):
        return declared
    for prefix, level in (("metaqa_1-hop", "easy"), ("metaqa_2-hop", "medium")):
        if task.task_type.startswith(prefix):
            return level
    return "hard"
def _answer_prompt(question: str) -> str:
    """Build the plain (non-swarm) answerer prompt for ``question``.

    The model is asked for a single ``{"answer": ...}`` JSON object. The
    question is the last line, with no trailing newline.
    """
    instructions = [
        "You are the answer-generation swarm for an OSINT graph task.",
        "Return ONLY one compact JSON object. Do not use markdown. Do not add prose.",
        'Required schema: {"answer": "<entity_or_value>"}',
        'Valid example: {"answer": "user_7"}',
        f"Question: {question}",
    ]
    return "\n".join(instructions)
def _swarm_v2_answer_prompt(
    question: str,
    shared_context: dict[str, Any],
    swarm_cfg: SwarmV2SwarmConfig,
) -> str:
    """Build the swarm-v2 answerer prompt.

    The prompt embeds a compacted, key-sorted JSON copy of
    ``shared_context``. It asks for ``answer``, ``supporting_edges`` and
    ``orchestrator`` keys and ends with ``"JSON:"``.
    """
    del swarm_cfg  # accepted for signature compatibility; not used
    schema_example = (
        '{"answer":"user_7","supporting_edges":[{"src":"alias_7_123","rel":"alias_of",'
        '"dst":"user_7","confidence":1.0}],"orchestrator":{"spawn_count":2,'
        '"finished_subtasks":2,"critical_steps":1,"breadth":2,"depth":1}}'
    )
    context_json = json.dumps(
        _compact_shared_context(shared_context), separators=(",", ":"), sort_keys=True
    )
    prompt_lines = [
        "You answer one OSINT graph question using ONLY the shared context.",
        "Output rules:",
        "- Return ONLY one compact JSON object. No markdown. No prose. End with }.",
        "- Required keys: answer, supporting_edges, orchestrator.",
        "- supporting_edges: list of {src, rel, dst, confidence} taken from shared edges.",
        "- orchestrator integer keys: spawn_count, finished_subtasks, critical_steps, breadth, depth.",
        "Example schema:",
        schema_example,
        f"Shared context: {context_json}",
        f"Question: {question}",
        "JSON:",
    ]
    return "\n".join(prompt_lines)
def _build_answerer_rows(tasks: list[TaskInstance]) -> list[dict[str, Any]]:
    """Build one plain-prompt answerer training row per task, in input order."""
    return [
        {
            "prompt": _answer_prompt(task.question),
            "question": task.question,
            "answer": str(task.answer),
            "supporting_edges_json": _task_to_edge_json(task),
            "difficulty": _difficulty_for_task(task),
            "task_type": task.task_type,
            "task_id": task.task_id,
        }
        for task in tasks
    ]
def _build_swarm_v2_answerer_rows(
    env: OSINTEnvironment,
    tasks: list[TaskInstance],
    cfg: SelfPlayTrainingConfig,
) -> list[dict[str, Any]]:
    """Build one swarm-v2 answerer training row per task, in input order.

    Each prompt embeds the task's shared graph context (see
    ``_task_shared_context``).
    """

    def _row_for(task: TaskInstance) -> dict[str, Any]:
        context = _task_shared_context(env=env, task=task, cfg=cfg)
        swarm_prompt = _swarm_v2_answer_prompt(
            question=task.question,
            shared_context=context,
            swarm_cfg=cfg.swarm_v2.answerer_swarm,
        )
        return {
            "prompt": swarm_prompt,
            "question": task.question,
            "answer": str(task.answer),
            "supporting_edges_json": _task_to_edge_json(task),
            "difficulty": _difficulty_for_task(task),
            "task_type": task.task_type,
            "task_id": task.task_id,
        }

    return [_row_for(task) for task in tasks]
def _graph_context_for_prompt(
    env: OSINTEnvironment,
    max_nodes: int,
    max_edges: int,
    rng: random.Random,
) -> dict[str, Any]:
    """Sample a prompt-sized view of ``env.graph``.

    Node ids are sorted and then down-sampled with ``rng.sample`` only when
    there are more than ``max_nodes``; edges are handled the same way with
    ``max_edges``. Nodes are sampled before edges, which fixes the order of
    RNG consumption. Edges are reduced to ``src``/``rel``/``dst``.
    """
    graph = env.graph
    nodes = sorted(graph.nodes.keys())
    if len(nodes) > max_nodes:
        nodes = rng.sample(nodes, k=max_nodes)
    edge_pool = list(graph.edges)
    if len(edge_pool) > max_edges:
        edge_pool = rng.sample(edge_pool, k=max_edges)
    edge_rows = [dict(src=item.src, rel=item.rel, dst=item.dst) for item in edge_pool]
    return {"nodes": nodes, "edges": edge_rows}
def _generator_prompt(context_blob: dict[str, Any], anchor_questions: list[str]) -> str:
    """Build the legacy adversarial task-generator prompt.

    The prompt includes the key-sorted JSON graph context and a bulleted
    list of anchor questions to avoid. It ends with a trailing newline.
    """
    anchor_block = "\n".join(f"- {anchor}" for anchor in anchor_questions)
    sections = [
        "You are the adversarial question-and-graph generation swarm in self-play.",
        "Generate one challenging but answerable OSINT task that makes answering difficult.",
        "Use only entities and relations from the provided graph context.",
        "Prefer multi-hop traces and avoid duplicates of the anchor questions.",
        "Return strict JSON with keys: question, answer, task_type, supporting_edges.",
        "supporting_edges must be a list of objects with src, rel, dst, confidence.",
        "Graph context:",
        json.dumps(context_blob, sort_keys=True),
        "Anchor questions to avoid:",
        anchor_block,
    ]
    return "\n".join(sections) + "\n"
def _build_generator_rows(
    env: OSINTEnvironment,
    cfg: SelfPlayTrainingConfig,
    rng: random.Random,
) -> list[dict[str, Any]]:
    """Build ``max(1, cfg.generator_prompts_per_round)`` legacy generator prompts.

    Each iteration samples a fresh graph context and then up to five anchor
    questions from the environment's existing tasks, both from ``rng``, in
    that order.
    """
    prior_questions = [task.question for task in env.tasks]
    anchor_k = min(5, len(prior_questions))
    generated: list[dict[str, Any]] = []
    for _ in range(max(1, cfg.generator_prompts_per_round)):
        blob = _graph_context_for_prompt(
            env=env,
            max_nodes=cfg.max_graph_context_nodes,
            max_edges=cfg.max_graph_context_edges,
            rng=rng,
        )
        anchors = rng.sample(prior_questions, k=anchor_k) if anchor_k > 0 else []
        generated.append({"prompt": _generator_prompt(blob, anchors)})
    return generated
def _swarm_v2_generator_prompt(
    graph: Any,
    shared_context: dict[str, Any],
    canonical_candidate: dict[str, Any],
    anchor_questions: list[str],
    swarm_cfg: SwarmV2SwarmConfig,
    canonical_graph_mode: str,
) -> str:
    """Build the swarm-v2 task-generator prompt.

    The prompt contains output rules, a worked schema example, worker
    packets, the compacted canonical candidate and shared context (all as
    compact key-sorted JSON) and prior questions to avoid. It ends with
    ``"JSON:"``. In ``canonical_graph_mode`` "generate" (the default when
    blank) the model may propose canonical-graph updates; any other mode
    tells it to reuse the candidate unchanged.
    """

    def _compact_json(value: Any) -> str:
        return json.dumps(value, separators=(",", ":"), sort_keys=True)

    anchor_lines = "\n".join(f"- {prior}" for prior in anchor_questions)
    mode = str(canonical_graph_mode).strip().lower() or "generate"
    example = _canonical_example_payload(graph, canonical_candidate, swarm_cfg)
    packets = _swarm_v2_worker_packets(
        canonical_candidate=canonical_candidate,
        shared_context=shared_context,
        swarm_cfg=swarm_cfg,
    )
    if mode == "generate":
        mode_instruction = (
            "You may propose canonical_graph updates when they improve replayability and keep it graph-grounded."
        )
    else:
        mode_instruction = (
            "Reuse the provided canonical candidate as-is; do not add, remove, or modify canonical_graph nodes/edges."
        )
    candidate_view = _compact_shared_context(canonical_candidate)
    prompt_lines = [
        "You coordinate a compact multi-agent OSINT task-generation swarm.",
        "Output rules:",
        "- Return ONLY one JSON object. No markdown. No prose. End with }.",
        "- Required keys: question, answer, task_type, supporting_edges, subagent_outputs, orchestrator.",
        "- Optional keys: canonical_graph, validation.",
        "- supporting_edges: non-empty list of {src, rel, dst, confidence}, taken from canonical edges.",
        "- supporting_edges must form one contiguous replayable path. Keep it compact.",
        "- Do NOT emit verbose tool traces or neighbor dumps; replay tools are derived from supporting_edges.",
        "- answer = final dst of the trace. question describes the path without leaking the answer.",
        "- subagent_outputs: 2-4 terse strings summarizing path_agent/question_agent/context_agent work.",
        "- orchestrator: integer keys spawn_count, finished_subtasks, critical_steps, breadth, depth.",
        f"- canonical_graph_mode={mode}: {mode_instruction}",
        "- Favor minimal shared context per worker so question generation stays parallel-friendly.",
        "Example (copy schema, not values):",
        _compact_json(example),
        "Worker packets:",
        _compact_json(packets),
        "Canonical candidate (use these edges):",
        _compact_json(candidate_view),
        "Shared context:",
        _compact_json(_compact_shared_context(shared_context)),
        f"Avoid these prior questions: {anchor_lines}",
        "JSON:",
    ]
    return "\n".join(prompt_lines)
def _build_swarm_v2_generator_rows(
    env: OSINTEnvironment,
    cfg: SelfPlayTrainingConfig,
    rng: random.Random,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Build swarm-v2 generator prompt rows plus their canonical subgraphs.

    Path candidates (at least 2 hops, at most
    ``cfg.swarm_v2.validation.max_path_hops``) are drawn from ``rng`` first.
    Then, per candidate, a shared context and up to five anchor questions
    are sampled from ``rng`` in that order. Returns ``(rows,
    canonical_candidates)``, whose entries are aligned index-for-index.
    """
    prior_questions = [task.question for task in env.tasks]
    path_candidates = build_swarm_v2_path_candidates(
        env.graph,
        rng=rng,
        count=max(1, cfg.generator_prompts_per_round),
        min_hops=2,
        max_hops=cfg.swarm_v2.validation.max_path_hops,
    )
    prompt_rows: list[dict[str, Any]] = []
    subgraphs: list[dict[str, Any]] = []
    for position, hop_edges in enumerate(path_candidates):
        context_limits = cfg.swarm_v2.shared_context
        context = _graph_context_for_prompt(
            env=env,
            max_nodes=context_limits.max_nodes,
            max_edges=context_limits.max_edges,
            rng=rng,
        )
        subgraph = build_swarm_v2_canonical_subgraph(
            env.graph,
            path_edges=hop_edges,
            max_extra_edges=max(0, context_limits.max_edges - len(hop_edges)),
        )
        anchor_k = min(5, len(prior_questions))
        anchors = rng.sample(prior_questions, k=anchor_k) if anchor_k > 0 else []
        prompt_text = _swarm_v2_generator_prompt(
            graph=env.graph,
            shared_context=context,
            canonical_candidate=subgraph,
            anchor_questions=anchors,
            swarm_cfg=cfg.swarm_v2.generator_swarm,
            canonical_graph_mode=cfg.canonical_graph_mode,
        )
        packets = _swarm_v2_worker_packets(
            canonical_candidate=subgraph,
            shared_context=context,
            swarm_cfg=cfg.swarm_v2.generator_swarm,
        )
        prompt_rows.append(
            {
                "prompt": prompt_text,
                "candidate_id": f"candidate_{position}",
                "canonical_graph_json": json.dumps(subgraph, sort_keys=True),
                "shared_context_json": json.dumps(context, sort_keys=True),
                "worker_packets_json": json.dumps(packets, sort_keys=True),
            }
        )
        subgraphs.append(subgraph)
    return prompt_rows, subgraphs
def _safe_build_grpo_config(
    phase: KimiGRPOPhaseConfig,
    output_dir: str,
    grpo_config_cls: Any,
    report_to: list[str] | None = None,
    run_name: str = "",
) -> Any:
    """Instantiate ``grpo_config_cls`` from a phase config, tolerating TRL drift.

    Only keyword arguments named in ``grpo_config_cls.__init__``'s signature
    are passed. The rest are dropped, and their names are printed once so a
    silently ignored option is visible in the logs.

    ``dataloader_persistent_workers`` and ``dataloader_prefetch_factor`` are
    only forwarded when ``dataloader_num_workers > 0``. With in-process data
    loading (0 workers), transformers/PyTorch reject both settings outright.
    A ``None`` prefetch factor is forwarded as ``None``.
    """
    num_workers = int(phase.dataloader_num_workers)
    kwargs: dict[str, Any] = {
        "output_dir": output_dir,
        "learning_rate": float(phase.learning_rate),
        "max_steps": int(phase.max_steps),
        "per_device_train_batch_size": int(phase.per_device_train_batch_size),
        "gradient_accumulation_steps": int(phase.gradient_accumulation_steps),
        "num_generations": int(phase.num_generations),
        "max_completion_length": int(phase.max_completion_length),
        "temperature": float(phase.temperature),
        "top_p": float(phase.top_p),
        "repetition_penalty": float(phase.repetition_penalty),
        "beta": float(phase.beta),
        "epsilon": float(phase.epsilon),
        "num_iterations": int(phase.num_iterations),
        "loss_type": str(phase.loss_type),
        "scale_rewards": str(phase.scale_rewards),
        "logging_steps": int(phase.logging_steps),
        "save_steps": int(phase.save_steps),
        "save_total_limit": int(phase.save_total_limit),
        "optim": str(phase.optim),
        "bf16": bool(phase.bf16),
        "tf32": bool(phase.tf32),
        "gradient_checkpointing": bool(phase.gradient_checkpointing),
        "dataloader_num_workers": num_workers,
        "generation_batch_size": int(phase.generation_batch_size),
        "max_prompt_length": int(phase.max_prompt_length),
        "remove_unused_columns": False,
        "use_vllm": bool(phase.use_vllm),
        "vllm_mode": str(phase.vllm_mode),
        "report_to": list(report_to or []),
    }
    if num_workers > 0:
        # Worker-process-only settings; setting them with 0 workers raises.
        kwargs["dataloader_persistent_workers"] = bool(phase.dataloader_persistent_workers)
        prefetch = phase.dataloader_prefetch_factor
        kwargs["dataloader_prefetch_factor"] = None if prefetch is None else int(prefetch)
    cleaned_run_name = str(run_name).strip()
    if cleaned_run_name:
        kwargs["run_name"] = cleaned_run_name

    accepted = inspect.signature(grpo_config_cls.__init__).parameters
    filtered = {key: value for key, value in kwargs.items() if key in accepted}
    dropped = sorted(key for key in kwargs if key not in accepted)
    if dropped:
        cls_name = getattr(grpo_config_cls, "__name__", repr(grpo_config_cls))
        print(
            f"[self_play][grpo_config] {cls_name} does not accept: {', '.join(dropped)}; ignoring.",
            flush=True,
        )
    return grpo_config_cls(**filtered)
def _build_lora_config(lora: LoraTuningConfig) -> Any:
    """Translate the project's LoRA settings into a ``peft.LoraConfig``.

    ``r`` and ``alpha`` are clamped to at least 1. An unknown ``task_type``
    falls back to ``TaskType.CAUSAL_LM``. ``target_modules`` is passed as a
    list, except that a single string (e.g. a regex or PEFT's
    ``"all-linear"`` shortcut) or ``None`` is forwarded unchanged.

    Raises:
        RuntimeError: if PEFT is not installed.
    """
    try:
        from peft import LoraConfig, TaskType
    except ImportError as exc:
        raise RuntimeError(
            "LoRA tuning selected, but PEFT is not installed. "
            "Install train dependencies first: python -m pip install -e .[train]"
        ) from exc
    task_type_token = str(lora.task_type or "CAUSAL_LM").strip().upper()
    task_type = getattr(TaskType, task_type_token, TaskType.CAUSAL_LM)
    raw_targets = lora.target_modules
    if raw_targets is None or isinstance(raw_targets, str):
        # list("all-linear") would explode the string into single characters,
        # which would not match any real module names.
        target_modules = raw_targets
    else:
        target_modules = list(raw_targets)
    return LoraConfig(
        r=max(1, int(lora.r)),
        lora_alpha=max(1, int(lora.alpha)),
        lora_dropout=float(lora.dropout),
        target_modules=target_modules,
        bias=str(lora.bias),
        task_type=task_type,
    )
def _coerce_named_reward_func(reward_function: Any) -> Any:
    """Return a callable with a stable __name__ for TRL compatibility.

    Objects that already have a non-blank ``__name__``, and non-callables,
    are returned unchanged. Other callables get ``__name__`` set to their
    class name. When that assignment is refused, a thin forwarding wrapper
    carrying the name is returned instead.
    """
    already_named = hasattr(reward_function, "__name__") and bool(
        str(getattr(reward_function, "__name__", "")).strip()
    )
    if already_named or not callable(reward_function):
        return reward_function
    # Some TRL versions read reward_funcs[i].__name__ when logging rewards.
    label = reward_function.__class__.__name__ or "reward_func"
    try:
        setattr(reward_function, "__name__", label)
    except Exception:
        # e.g. instances with __slots__ reject new attributes.
        def _wrapped_reward(*args: Any, **kwargs: Any) -> Any:
            return reward_function(*args, **kwargs)

        _wrapped_reward.__name__ = label
        return _wrapped_reward
    return reward_function
def _grpo_log_history_values(log_history: list[Any], key: str) -> list[float]:
    """Collect ``key`` from Trainer ``log_history`` rows as floats.

    Non-dict rows, rows without ``key``, and rows whose value is ``None`` are
    skipped, so a single sparse log entry cannot crash post-phase diagnostics.
    """
    return [
        float(row[key])
        for row in log_history
        if isinstance(row, dict) and row.get(key) is not None
    ]


def _grpo_enforce_strict_checks(diagnostics: dict[str, Any], phase_label: str) -> None:
    """Raise ``AssertionError`` when a finished phase shows no sign of learning.

    Explicit ``raise`` is used instead of ``assert`` statements so the checks
    still run when Python is started with ``-O`` (which strips asserts).
    """
    checks = [
        (
            diagnostics["reward_max"] != diagnostics["reward_min"],
            f"Constant reward detected in {phase_label}: {diagnostics['reward_min']}",
        ),
        (diagnostics["reward_std_max"] > 0.0, f"reward_std stayed zero in {phase_label}"),
        (diagnostics["kl_max"] > 0.0, f"KL stayed zero in {phase_label}"),
        (diagnostics["loss_abs_max"] > 0.0, f"Loss stayed zero in {phase_label}"),
        (diagnostics["grad_norm_max"] > 0.0, f"Grad norm stayed zero in {phase_label}"),
        # The HF Trainer clears gradients after every optimizer step
        # (Module.zero_grad defaults to set_to_none=True in PyTorch 2.x), so
        # usually no .grad tensors survive train(). Only judge grad contents
        # when some were actually retained; grad_norm_max covers the rest.
        (
            diagnostics["params_with_grad"] == 0 or diagnostics["nonzero_grad_tensors"] > 0,
            f"No non-zero grads in {phase_label}",
        ),
        (
            diagnostics["fingerprint_changed_count"] > 0,
            f"No parameter fingerprint change in {phase_label}",
        ),
    ]
    for passed, message in checks:
        if not passed:
            raise AssertionError(message)


def _train_grpo_phase(
    model_name_or_path: str,
    phase: KimiGRPOPhaseConfig,
    rows: list[dict[str, Any]],
    reward_function: Any,
    output_dir: Path,
    tuning_mode: str,
    lora: LoraTuningConfig,
    report_to: list[str] | None = None,
    run_name: str = "",
) -> dict[str, Any]:
    """Run one GRPO training phase and return a summary of what it produced.

    Flow:
      * If ``output_dir`` already holds a ``final_model`` (checked by
        ``_final_model_already_present``), training is skipped and a summary
        with ``resumed_skipped=True`` is returned.
      * Otherwise the latest local ``checkpoint-*`` is resumed. When there is
        none, checkpoints are fetched from the HF Hub if available. A
        checkpoint lacking ``optimizer.pt``/``trainer_state.json`` is used
        only as warm-start weights with a fresh trainer state.
      * A ``GRPOTrainer`` is built (optionally with a LoRA ``peft_config`` and
        an HF checkpoint-upload callback), trained, and the final model plus
        tokenizer are saved to ``output_dir / "final_model"``.
      * Diagnostics (reward/KL/loss/grad-norm ranges from the trainer log
        history, trainable-parameter counts, parameter fingerprint changes)
        are attached under ``"diagnostics"``.

    Args:
        model_name_or_path: Base model id/path (overridden by a warm-start
            checkpoint when one is found).
        phase: GRPO hyper-parameters for this phase.
        rows: Training rows turned into a ``datasets.Dataset``.
        reward_function: Reward callable passed to TRL.
        output_dir: Phase directory, laid out as
            ``<run_dir>/round_NNN/<phase_subdir>``.
        tuning_mode: ``"lora"`` for LoRA adapters; anything else is full
            fine-tuning.
        lora: LoRA settings, used only in LoRA mode.
        report_to: Trainer reporting integrations (e.g. ``["wandb"]``).
        run_name: Optional run name; also used as the log label.

    Returns:
        Dict with model paths, checkpoint dirs, ``global_step``,
        ``training_loss``, ``train_rows``, the tuning mode, and (after real
        training) ``diagnostics``.

    Raises:
        RuntimeError: LoRA mode requested but the installed TRL
            ``GRPOTrainer`` does not accept ``peft_config``.
        AssertionError: ``OSINT_TRAIN_STRICT_ASSERTS`` is truthy and the
            diagnostics indicate the phase did not learn.
    """
    Dataset, GRPOConfig, GRPOTrainer = _require_training_stack()
    output_dir.mkdir(parents=True, exist_ok=True)
    phase_label = str(run_name).strip() or str(output_dir.name)
    reward_class_name = type(reward_function).__name__
    normalized_mode = str(tuning_mode).strip().lower()
    # Output layout: <run_dir>/round_NNN/<phase_subdir>. Match the run_dir
    # used by the corresponding HF upload helpers so resume paths line up
    # with where checkpoints were written.
    run_dir_for_resume = output_dir.parents[1] if len(output_dir.parents) >= 2 else output_dir.parent
    if _final_model_already_present(output_dir):
        final_dir = output_dir / "final_model"
        print(
            f"[self_play][resume] phase={phase_label} already has final_model at {final_dir}. "
            f"Skipping retrain on Space restart.",
            flush=True,
        )
        checkpoint_dirs = [str(path) for path in sorted(output_dir.glob("checkpoint-*")) if path.is_dir()]
        return {
            "model_path": str(final_dir),
            "final_model_path": str(final_dir),
            "phase_output_dir": str(output_dir),
            "checkpoint_dirs": checkpoint_dirs,
            "global_step": int(getattr(phase, "max_steps", 0) or 0),
            "training_loss": 0.0,
            "train_rows": len(rows),
            "tuning_mode": normalized_mode or "full",
            "is_full_finetune": normalized_mode != "lora",
            "resumed_skipped": True,
        }
    resume_checkpoint = _latest_local_checkpoint(output_dir)
    resume_is_full_state = bool(
        resume_checkpoint is not None
        and (resume_checkpoint / "optimizer.pt").exists()
        and (resume_checkpoint / "trainer_state.json").exists()
    )
    if resume_checkpoint is None:
        downloaded = _maybe_download_phase_checkpoints_from_hf(output_dir, run_dir_for_resume)
        if downloaded is not None:
            resume_checkpoint = downloaded
            resume_is_full_state = bool(
                (downloaded / "optimizer.pt").exists()
                and (downloaded / "trainer_state.json").exists()
            )
    # If the resume checkpoint is weights-only (e.g. recovered from HF Hub
    # which intentionally drops optimizer state to keep uploads small),
    # warm-start the model from those weights and start a fresh trainer
    # state. Better than restarting from the base model.
    warm_start_from: Path | None = None
    if resume_checkpoint is not None and not resume_is_full_state:
        warm_start_from = resume_checkpoint
        resume_checkpoint = None
    if warm_start_from is not None:
        model_name_or_path = str(warm_start_from)
    dataset = Dataset.from_list(rows)
    args = _safe_build_grpo_config(
        phase=phase,
        output_dir=str(output_dir),
        grpo_config_cls=GRPOConfig,
        report_to=report_to,
        run_name=run_name,
    )
    trainer_kwargs: dict[str, Any] = {
        "model": model_name_or_path,
        "args": args,
        "reward_funcs": _coerce_named_reward_func(reward_function),
        "train_dataset": dataset,
    }
    upload_callback = _build_hf_checkpoint_upload_callback(output_dir, run_dir_for_resume)
    if upload_callback is not None:
        trainer_kwargs["callbacks"] = [upload_callback]
        print(
            f"[self_play][hf_upload] phase={phase_label} intermediate checkpoint "
            f"uploads enabled (every save_steps={phase.save_steps}).",
            flush=True,
        )
    if normalized_mode == "lora":
        trainer_signature = inspect.signature(GRPOTrainer.__init__)
        if "peft_config" not in trainer_signature.parameters:
            raise RuntimeError("Installed TRL version does not expose peft_config in GRPOTrainer.")
        trainer_kwargs["peft_config"] = _build_lora_config(lora)
    print(
        f"[self_play] Starting phase: {phase_label} rows={len(rows)} "
        f"max_steps={phase.max_steps}",
        flush=True,
    )
    print(
        f"[self_play][reward_setup] phase={phase_label} "
        f"reward_function={reward_class_name} "
        f"wandb_metric=rewards/{reward_class_name}/mean "
        f"logging_steps={phase.logging_steps} "
        f"num_generations={phase.num_generations} "
        f"per_device_train_batch_size={phase.per_device_train_batch_size}",
        flush=True,
    )
    if resume_checkpoint is not None:
        print(
            f"[self_play][resume] phase={phase_label} resuming (full state) from checkpoint={resume_checkpoint}",
            flush=True,
        )
    elif warm_start_from is not None:
        print(
            f"[self_play][resume] phase={phase_label} warm-starting from weights only at {warm_start_from} "
            f"(no optimizer state available)",
            flush=True,
        )
    # Same truthy-token parsing as the rest of this module's env flags.
    strict_asserts = _is_true_env(os.getenv("OSINT_TRAIN_STRICT_ASSERTS"))
    trainer = GRPOTrainer(**trainer_kwargs)
    # Snapshot a cheap per-tensor statistic for up to 32 trainable tensors so
    # we can tell afterwards whether training actually moved any weights.
    tracked_params = [
        (name, param)
        for name, param in trainer.model.named_parameters()
        if getattr(param, "requires_grad", False)
    ][:32]
    pre_update_fingerprint = {
        name: float(param.detach().float().abs().mean().item())
        for name, param in tracked_params
    }
    if resume_checkpoint is not None:
        train_output = trainer.train(resume_from_checkpoint=str(resume_checkpoint))
    else:
        train_output = trainer.train()
    final_dir = output_dir / "final_model"
    trainer.save_model(str(final_dir))
    trainer_tokenizer = getattr(trainer, "processing_class", None) or getattr(trainer, "tokenizer", None)
    if trainer_tokenizer is not None and hasattr(trainer_tokenizer, "save_pretrained"):
        trainer_tokenizer.save_pretrained(str(final_dir))
    checkpoint_dirs = [str(path) for path in sorted(output_dir.glob("checkpoint-*")) if path.is_dir()]
    global_step = int(getattr(train_output, "global_step", 0))
    training_loss = float(getattr(train_output, "training_loss", 0.0))
    total_param_count = int(sum(param.numel() for param in trainer.model.parameters()))
    result = {
        "model_path": str(final_dir),
        "final_model_path": str(final_dir),
        "phase_output_dir": str(output_dir),
        "checkpoint_dirs": checkpoint_dirs,
        "global_step": global_step,
        "training_loss": training_loss,
        "train_rows": len(rows),
        "tuning_mode": normalized_mode or "full",
        "is_full_finetune": normalized_mode != "lora",
    }
    log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
    reward_values = _grpo_log_history_values(log_history, "reward")
    reward_std_values = _grpo_log_history_values(log_history, "reward_std")
    kl_values = _grpo_log_history_values(log_history, "kl")
    grad_norm_values = _grpo_log_history_values(log_history, "grad_norm")
    loss_values = _grpo_log_history_values(log_history, "loss")
    entropy_values = _grpo_log_history_values(log_history, "entropy")
    trainable_params = [param for param in trainer.model.parameters() if getattr(param, "requires_grad", False)]
    grad_tensors = [param.grad for param in trainable_params if getattr(param, "grad", None) is not None]
    trainable_param_count = int(sum(param.numel() for param in trainable_params))
    params_with_grad = int(len(grad_tensors))
    nonzero_grad_tensors = int(
        sum(
            1
            for grad in grad_tensors
            if float(grad.detach().abs().sum().item()) > 0.0
        )
    )
    diagnostics = {
        "reward_min": min(reward_values) if reward_values else 0.0,
        "reward_max": max(reward_values) if reward_values else 0.0,
        "reward_std_max": max(reward_std_values) if reward_std_values else 0.0,
        "kl_max": max(kl_values) if kl_values else 0.0,
        "loss_abs_max": max((abs(value) for value in loss_values), default=0.0),
        "grad_norm_max": max(grad_norm_values) if grad_norm_values else 0.0,
        "entropy_min": min(entropy_values) if entropy_values else 0.0,
        "entropy_max": max(entropy_values) if entropy_values else 0.0,
        "total_param_count": total_param_count,
        "trainable_param_count": trainable_param_count,
        "trainable_fraction": (
            float(trainable_param_count / total_param_count)
            if total_param_count > 0
            else 0.0
        ),
        "params_with_grad": params_with_grad,
        "nonzero_grad_tensors": nonzero_grad_tensors,
        "fingerprint_param_count": len(pre_update_fingerprint),
        "fingerprint_changed_count": 0,
    }
    if pre_update_fingerprint:
        changed_count = 0
        for name, param in tracked_params:
            after_value = float(param.detach().float().abs().mean().item())
            before_value = pre_update_fingerprint.get(name, after_value)
            if abs(after_value - before_value) > 1e-9:
                changed_count += 1
        diagnostics["fingerprint_changed_count"] = changed_count
    result["diagnostics"] = diagnostics
    print(
        "[self_play][diagnostics] "
        f"{phase_label} reward_range=({diagnostics['reward_min']:.4f},{diagnostics['reward_max']:.4f}) "
        f"reward_std_max={diagnostics['reward_std_max']:.6f} "
        f"kl_max={diagnostics['kl_max']:.6f} "
        f"loss_abs_max={diagnostics['loss_abs_max']:.6f} "
        f"grad_norm_max={diagnostics['grad_norm_max']:.6f} "
        f"nonzero_grad_tensors={diagnostics['nonzero_grad_tensors']}/{max(1, diagnostics['params_with_grad'])} "
        f"fingerprint_changed={diagnostics['fingerprint_changed_count']}/{max(1, diagnostics['fingerprint_param_count'])}",
        flush=True,
    )
    if strict_asserts:
        _grpo_enforce_strict_checks(diagnostics, phase_label)
    reward_debug = getattr(reward_function, "_debug_last_batch", None)
    if isinstance(reward_debug, dict):
        print(
            f"[reward_debug][last_batch] {phase_label} reward_function={reward_class_name} "
            f"{json.dumps(reward_debug, sort_keys=True)}",
            flush=True,
        )
    if reward_values:
        print(
            f"[self_play][reward_history] {phase_label} reward_function={reward_class_name} "
            f"steps_logged={len(reward_values)} "
            f"reward_first={reward_values[0]:.6f} "
            f"reward_last={reward_values[-1]:.6f} "
            f"reward_mean={(sum(reward_values) / len(reward_values)):.6f} "
            f"reward_min={min(reward_values):.6f} "
            f"reward_max={max(reward_values):.6f} "
            f"wandb_metric=rewards/{reward_class_name}/mean",
            flush=True,
        )
    else:
        print(
            f"[self_play][reward_history] {phase_label} reward_function={reward_class_name} "
            "no_reward_logs_in_state (TRL never wrote a 'reward' field; check logging_steps / num_generations)",
            flush=True,
        )
    print(
        "[self_play] Finished phase: "
        f"{phase_label} global_step={global_step} training_loss={training_loss} output={final_dir}",
        flush=True,
    )
    return result
def _resolve_reporting(training_config: SelfPlayTrainingConfig, phase_name: str, round_index: int) -> tuple[list[str], str]:
if not training_config.wandb_enabled:
return [], ""
if training_config.wandb_project:
os.environ["WANDB_PROJECT"] = str(training_config.wandb_project)
if training_config.wandb_entity:
os.environ["WANDB_ENTITY"] = str(training_config.wandb_entity)
prefix = str(training_config.wandb_run_name_prefix).strip() or "self-play"
run_name = f"{prefix}-r{round_index:03d}-{phase_name}"
return ["wandb"], run_name
def _resolve_initial_models(cfg: SelfPlayTrainingConfig) -> tuple[str, str]:
topology = str(cfg.model_topology).strip().lower()
if topology == "shared":
shared = str(cfg.shared_model_name_or_path).strip()
if not shared:
shared = str(cfg.answerer_phase.model_name_or_path).strip() or str(cfg.generator_phase.model_name_or_path).strip()
return shared, shared
return str(cfg.generator_phase.model_name_or_path), str(cfg.answerer_phase.model_name_or_path)
def _fallback_generated_tasks(
base_tasks: list[TaskInstance],
round_index: int,
count: int,
rng: random.Random,
) -> list[TaskInstance]:
if not base_tasks:
return []
selected = list(base_tasks)
rng.shuffle(selected)
selected = selected[: max(1, count)]
out: list[TaskInstance] = []
for idx, task in enumerate(selected):
metadata = dict(task.metadata or {})
metadata.update(
{
"generated_by": "fallback_generator",
"difficulty": "hard",
"round": round_index,
"scenario": "adversarial_trace",
"grader": {
"type": "difficulty_exact_match",
"answer_type": "node_id",
"case_sensitive": True,
"reward_profile": "hard",
},
}
)
out.append(
TaskInstance(
task_id=f"adv_r{round_index}_{idx}",
task_type="adversarial_trace",
question=f"[Adversarial] {task.question}",
answer=task.answer,
supporting_edges=list(task.supporting_edges),
metadata=metadata,
)
)
return out
def _sample_generated_tasks_with_model(
    model_name_or_path: str,
    prompts: list[str],
    round_index: int,
    count: int,
    max_support_edges: int,
    max_new_tokens: int,
    batch_size: int = 4,
) -> list[TaskInstance]:
    """Sample adversarial tasks from a generator checkpoint (legacy pipeline).

    Prompts are decoded in batches with nucleus sampling (``top_p=0.95``,
    ``temperature=1.0``, at least 64 new tokens). Each completion is parsed
    with ``parse_generated_task_completion``, and only valid parses become
    ``TaskInstance`` objects. Sampling stops once ``count`` valid tasks exist,
    so fewer may be returned when prompts run out.

    Args:
        model_name_or_path: Checkpoint to load with ``transformers``.
        prompts: Generator prompts, consumed in order.
        round_index: Round number embedded in task ids and metadata.
        count: Target number of valid tasks; ``<= 0`` returns ``[]``.
        max_support_edges: Passed through to the completion parser.
        max_new_tokens: Requested decode length (floored at 64).
        batch_size: Prompts per ``generate`` call (floored at 1).

    Returns:
        Up to ``count`` tasks with ids ``adv_r{round_index}_{i}``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    if count <= 0 or not prompts:
        return []
    prompt_total = len(prompts)
    print(
        f"[self_play][sample_generator_legacy] start model={model_name_or_path} "
        f"prompts={prompt_total} target_valid={count} max_new_tokens={max_new_tokens}",
        flush=True,
    )
    loading_began = time.monotonic()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Batched decoder-only generation needs a pad token and left padding.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "padding_side", "right") != "left":
        tokenizer.padding_side = "left"
    load_options: dict[str, Any] = {}
    if torch.cuda.is_available():
        load_options.update(device_map="auto", torch_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **load_options)
    model.eval()
    device = next(model.parameters()).device
    print(
        f"[self_play][sample_generator_legacy] model_loaded device={device} "
        f"load_elapsed={time.monotonic() - loading_began:.1f}s",
        flush=True,
    )
    tasks: list[TaskInstance] = []
    sampling_began = time.monotonic()
    step = max(1, int(batch_size or 1))
    decode_budget = max(64, int(max_new_tokens))
    parsed_count = 0
    for offset in range(0, prompt_total, step):
        if len(tasks) >= count:
            break
        chunk = prompts[offset : offset + step]
        tokenized = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        model_inputs = {name: tensor.to(device) for name, tensor in tokenized.items()}
        prompt_width = model_inputs["input_ids"].shape[1]
        print(
            f"[self_play][sample_legacy] batch_start={offset}/{prompt_total} "
            f"batch_size={len(chunk)} max_new_tokens={decode_budget} "
            f"input_len={prompt_width} generating...",
            flush=True,
        )
        generate_began = time.monotonic()
        with torch.no_grad():
            sequences = model.generate(
                **model_inputs,
                max_new_tokens=decode_budget,
                do_sample=True,
                top_p=0.95,
                temperature=1.0,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )
        print(
            f"[self_play][sample_legacy] batch_start={offset} "
            f"generate_elapsed={time.monotonic() - generate_began:.1f}s",
            flush=True,
        )
        for sequence in sequences[: len(chunk)]:
            # Left padding means every prompt ends at column prompt_width.
            text = tokenizer.decode(sequence[prompt_width:], skip_special_tokens=True)
            parsed = parse_generated_task_completion(text, max_support_edges=max_support_edges)
            parsed_count += 1
            if not parsed.is_valid:
                continue
            tasks.append(
                TaskInstance(
                    task_id=f"adv_r{round_index}_{len(tasks)}",
                    task_type=parsed.task_type,
                    question=parsed.question,
                    answer=parsed.answer,
                    supporting_edges=list(parsed.supporting_edges),
                    metadata={
                        "generated_by": "generator_model",
                        "round": round_index,
                        "difficulty": "hard",
                        "scenario": "adversarial_trace",
                        "grader": {
                            "type": "difficulty_exact_match",
                            "answer_type": "node_id",
                            "case_sensitive": True,
                            "reward_profile": "hard",
                        },
                    },
                )
            )
            if len(tasks) >= count:
                break
        print(
            f"[self_play][sample_generator_legacy] processed={parsed_count}/{prompt_total} "
            f"valid={len(tasks)}/{count} "
            f"elapsed={time.monotonic() - sampling_began:.1f}s",
            flush=True,
        )
    print(
        f"[self_play][sample_generator_legacy] finished generated={len(tasks)}/{count} "
        f"total_elapsed={time.monotonic() - sampling_began:.1f}s",
        flush=True,
    )
    return tasks
def _select_answerer_tasks(
seed_tasks: list[TaskInstance],
generated_tasks: list[TaskInstance],
cfg: SelfPlayTrainingConfig,
rng: random.Random,
) -> list[TaskInstance]:
seed_pick = list(seed_tasks)
gen_pick = list(generated_tasks)
rng.shuffle(seed_pick)
rng.shuffle(gen_pick)
chosen = seed_pick[: max(1, cfg.seed_tasks_per_round)]
chosen.extend(gen_pick[: max(1, cfg.generated_tasks_per_round)])
return chosen
def _save_rows(path: Path, rows: list[dict[str, Any]]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(rows, indent=2, sort_keys=True), encoding="utf-8")
def _save_tasks(path: Path, tasks: list[TaskInstance]) -> None:
payload = []
for task in tasks:
payload.append(
{
"task_id": task.task_id,
"task_type": task.task_type,
"question": task.question,
"answer": task.answer,
"supporting_edges": [
{
"src": edge.src,
"rel": edge.rel,
"dst": edge.dst,
"confidence": float(edge.confidence),
}
for edge in task.supporting_edges
],
"metadata": dict(task.metadata or {}),
}
)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
def _save_payload(path: Path, payload: Any) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
def _generate_answerer_completion_texts_with_model(
    model_name_or_path: str,
    prompts: list[str],
    max_new_tokens: int,
    batch_size: int = 4,
) -> list[str]:
    """Greedily decode one answerer completion per prompt.

    The checkpoint is loaded with ``transformers`` (bf16 with
    ``device_map="auto"`` when CUDA is available). Prompts are left-padded and
    decoded in batches with ``do_sample=False`` and at least 16 new tokens.
    Only the newly generated tokens are decoded, with special tokens skipped.

    Returns:
        One completion string per prompt, in prompt order; ``[]`` for no prompts.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    if not prompts:
        return []
    prompt_total = len(prompts)
    print(
        f"[self_play][sample_answerer] start model={model_name_or_path} "
        f"prompts={prompt_total} max_new_tokens={max_new_tokens}",
        flush=True,
    )
    loading_began = time.monotonic()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Batched decoder-only generation needs a pad token and left padding.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "padding_side", "right") != "left":
        tokenizer.padding_side = "left"
    load_options: dict[str, Any] = {}
    if torch.cuda.is_available():
        load_options.update(device_map="auto", torch_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **load_options)
    model.eval()
    device = next(model.parameters()).device
    print(
        f"[self_play][sample_answerer] model_loaded device={device} "
        f"load_elapsed={time.monotonic() - loading_began:.1f}s",
        flush=True,
    )
    texts: list[str] = []
    decoding_began = time.monotonic()
    step = max(1, int(batch_size or 1))
    decode_budget = max(16, int(max_new_tokens))
    done = 0
    for offset in range(0, prompt_total, step):
        chunk = prompts[offset : offset + step]
        tokenized = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        model_inputs = {name: tensor.to(device) for name, tensor in tokenized.items()}
        prompt_width = model_inputs["input_ids"].shape[1]
        print(
            f"[self_play][sample_answerer] batch_start={offset}/{prompt_total} "
            f"batch_size={len(chunk)} max_new_tokens={decode_budget} "
            f"input_len={prompt_width} generating...",
            flush=True,
        )
        generate_began = time.monotonic()
        with torch.no_grad():
            sequences = model.generate(
                **model_inputs,
                max_new_tokens=decode_budget,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Left padding means every prompt ends at column prompt_width.
        texts.extend(
            tokenizer.decode(sequence[prompt_width:], skip_special_tokens=True)
            for sequence in sequences[: len(chunk)]
        )
        done += len(chunk)
        print(
            f"[self_play][sample_answerer] processed={done}/{prompt_total} "
            f"generate_elapsed={time.monotonic() - generate_began:.1f}s "
            f"elapsed={time.monotonic() - decoding_began:.1f}s",
            flush=True,
        )
    print(
        f"[self_play][sample_answerer] finished completions={len(texts)} "
        f"total_elapsed={time.monotonic() - decoding_began:.1f}s",
        flush=True,
    )
    return texts
def _top_validation_reasons(validation_reports: list[dict[str, Any]]) -> list[tuple[str, int]]:
counts: dict[str, int] = {}
for report in validation_reports:
validation = report.get("validation", {}) if isinstance(report, dict) else {}
reasons = validation.get("reasons", []) if isinstance(validation, dict) else []
for reason in reasons:
token = str(reason).strip()
if not token:
continue
counts[token] = counts.get(token, 0) + 1
return sorted(counts.items(), key=lambda item: (-item[1], item[0]))
def _run_post_training_evaluation(
    env_config: EnvironmentConfig,
    training_config: SelfPlayTrainingConfig,
    generator_model: str,
    answerer_models: dict[str, str],
    output_dir: Path,
    pipeline_mode: str,
    effective_dry_run: bool,
) -> dict[str, Any]:
    """Evaluate answerer checkpoints on a freshly generated task set.

    The trained generator produces evaluation tasks through the swarm_v2 or
    the legacy pipeline. When the generator yields nothing valid, the
    deterministic fallback is used and flagged as ``generator_used_fallback``.
    Every entry of ``answerer_models`` (label -> model path) then answers the
    tasks greedily and is scored with ``AnswererRewardFunction``.
    ``delta_vs_original`` compares the ``finetuned_answerer`` label against
    ``original_answerer``; a missing label counts as zeros.

    Side effects: writes the generated tasks, the generator validation reports
    and the returned payload as JSON under ``output_dir``.

    Evaluation is best-effort. In dry-run mode, or when any step raises, the
    payload comes back with ``skipped=True`` and a ``reason`` instead of
    propagating. The returned dict always has ``path`` pointing at the saved
    payload file.
    """

    def _edge_record(edge: Edge) -> dict[str, Any]:
        return {
            "src": edge.src,
            "rel": edge.rel,
            "dst": edge.dst,
            "confidence": float(edge.confidence),
        }

    def _mean_field(episodes: list[dict[str, Any]], key: str) -> float:
        # Average one numeric episode field; 0.0 for an empty episode list.
        if not episodes:
            return 0.0
        return float(sum(float(row[key]) for row in episodes) / len(episodes))

    tasks_path = output_dir / "post_training_eval_generated_tasks.json"
    validation_path = output_dir / "post_training_eval_validation_reports.json"
    payload_path = output_dir / "post_training_evaluation.json"
    payload: dict[str, Any] = {
        "pipeline_mode": pipeline_mode,
        "generator_model": generator_model,
        "answerer_models": dict(answerer_models),
        "generated_tasks_path": str(tasks_path),
        "validation_reports_path": str(validation_path),
        "skipped": False,
    }
    if effective_dry_run:
        payload.update({"skipped": True, "reason": "dry_run"})
        _save_payload(validation_path, [])
        _save_payload(tasks_path, [])
        _save_payload(payload_path, payload)
        payload["path"] = str(payload_path)
        return payload
    try:
        env = OSINTEnvironment(env_config, llm=build_llm_client(env_config.llm))
        rng = random.Random(env_config.seed + 9973)
        # Evaluation tasks are tagged as one round past the last training round.
        eval_round_index = max(1, training_config.rounds) + 1
        eval_question_count = max(1, training_config.post_training_eval_questions)
        validation_reports: list[dict[str, Any]] = []
        used_fallback = False
        if pipeline_mode == "swarm_v2":
            generator_rows, prompt_canonical_candidates = _build_swarm_v2_generator_rows(env, training_config, rng)
            completion_texts = _sample_swarm_v2_completion_texts_with_model(
                env=env,
                cfg=training_config,
                model_name_or_path=generator_model,
                prompts=[row["prompt"] for row in generator_rows],
                # Over-sample so enough completions survive replay validation.
                count=max(1, training_config.post_training_eval_questions * 2),
                seen_questions=[task.question for task in env.tasks],
            )
            generated_tasks, validation_reports, _, _ = _materialize_swarm_v2_completions(
                env=env,
                cfg=training_config,
                completion_texts=completion_texts,
                round_index=eval_round_index,
                seen_questions=[task.question for task in env.tasks],
                prompt_canonical_candidates=prompt_canonical_candidates,
            )
            if not generated_tasks:
                used_fallback = True
                generated_tasks, validation_reports, _, _ = _materialize_swarm_v2_completions(
                    env=env,
                    cfg=training_config,
                    completion_texts=_fallback_swarm_v2_completion_texts(
                        env=env,
                        cfg=training_config,
                        round_index=eval_round_index,
                        rng=rng,
                    ),
                    round_index=eval_round_index,
                    seen_questions=[task.question for task in env.tasks],
                    prompt_canonical_candidates=None,
                )
            # Count valid tasks BEFORE truncating to the evaluation budget so
            # the valid rate reflects every validated completion.
            valid_task_count = len(generated_tasks)
            generated_tasks = generated_tasks[:eval_question_count]
            answer_rows = _build_swarm_v2_answerer_rows(env, generated_tasks, training_config)
            reward_fn = AnswererRewardFunction(
                graph=env.graph,
                pipeline_mode="swarm_v2",
                parl_max_parallel_hint=training_config.swarm_v2.answerer_swarm.max_agents,
            )
        else:
            generator_rows = _build_generator_rows(env=env, cfg=training_config, rng=rng)
            generated_tasks = _sample_generated_tasks_with_model(
                model_name_or_path=generator_model,
                prompts=[row["prompt"] for row in generator_rows],
                round_index=eval_round_index,
                count=eval_question_count,
                max_support_edges=training_config.max_support_edges,
                max_new_tokens=training_config.generated_task_max_new_tokens,
            )
            if not generated_tasks:
                used_fallback = True
                generated_tasks = _fallback_generated_tasks(
                    base_tasks=list(env.tasks),
                    round_index=eval_round_index,
                    count=eval_question_count,
                    rng=rng,
                )
            valid_task_count = len(generated_tasks)
            answer_rows = _build_answerer_rows(generated_tasks)
            reward_fn = AnswererRewardFunction(graph=env.graph)
        _save_tasks(tasks_path, generated_tasks)
        _save_payload(validation_path, validation_reports)
        model_evaluations: dict[str, dict[str, Any]] = {}
        for model_label, answerer_model in answerer_models.items():
            answerer_completions = _generate_answerer_completion_texts_with_model(
                model_name_or_path=answerer_model,
                prompts=[row["prompt"] for row in answer_rows],
                max_new_tokens=training_config.post_training_eval_answer_max_new_tokens,
            )
            rewards = reward_fn(
                prompts=[row["prompt"] for row in answer_rows],
                completions=answerer_completions,
                answer=[row["answer"] for row in answer_rows],
                question=[row["question"] for row in answer_rows],
                supporting_edges_json=[row["supporting_edges_json"] for row in answer_rows],
                difficulty=[row["difficulty"] for row in answer_rows],
            )
            episodes: list[dict[str, Any]] = []
            for task, row, completion_text, reward in zip(generated_tasks, answer_rows, answerer_completions, rewards):
                support_edges = AnswererRewardFunction._parse_support_edges(row["supporting_edges_json"])
                pred_edges = AnswererRewardFunction._extract_predicted_edges(completion_text, support_edges)
                predicted_answer = normalize_answer(extract_answer_from_completion(completion_text))
                target_answer = normalize_answer(task.answer)
                graph_f1 = compute_graph_f1(pred_edges, support_edges)
                episodes.append(
                    {
                        "task_id": task.task_id,
                        "task_type": task.task_type,
                        "question": task.question,
                        "task_answer": target_answer,
                        "agent_answer": predicted_answer,
                        "reward": float(reward),
                        "graph_f1": float(graph_f1),
                        "success": int(predicted_answer == target_answer),
                        "support_edge_count": len(support_edges),
                        "predicted_edge_count": len(pred_edges),
                        "completion_length": len(completion_text),
                        "pred_edges": [_edge_record(edge) for edge in pred_edges],
                        "truth_edges": [_edge_record(edge) for edge in support_edges],
                    }
                )
            model_evaluations[model_label] = {
                "model_path": answerer_model,
                "episodes": episodes,
                "summary": {
                    "episodes": len(episodes),
                    "task_success_rate": _mean_field(episodes, "success"),
                    "avg_reward": _mean_field(episodes, "reward"),
                    "avg_graph_f1": _mean_field(episodes, "graph_f1"),
                    "avg_completion_length": _mean_field(episodes, "completion_length"),
                },
            }
        final_summary = model_evaluations.get("finetuned_answerer", {}).get("summary", {})
        baseline_summary = model_evaluations.get("original_answerer", {}).get("summary", {})
        summary = {
            "generated_task_count": len(generated_tasks),
            # Legacy mode produces no validation reports; treat it as fully valid.
            "generator_valid_rate": (
                float(valid_task_count / len(validation_reports))
                if validation_reports
                else 1.0
            ),
            "generator_used_fallback": used_fallback,
            "compared_models": sorted(model_evaluations.keys()),
            "finetuned_answerer": dict(final_summary),
            "original_answerer": dict(baseline_summary),
            "delta_vs_original": {
                "task_success_rate": float(final_summary.get("task_success_rate", 0.0) - baseline_summary.get("task_success_rate", 0.0)),
                "avg_reward": float(final_summary.get("avg_reward", 0.0) - baseline_summary.get("avg_reward", 0.0)),
                "avg_graph_f1": float(final_summary.get("avg_graph_f1", 0.0) - baseline_summary.get("avg_graph_f1", 0.0)),
            },
            "top_generator_invalid_reasons": _top_validation_reasons(validation_reports)[:5],
        }
        payload.update(
            {
                "summary": summary,
                "model_evaluations": model_evaluations,
            }
        )
    except Exception as exc:
        # Deliberately best-effort: record the failure instead of aborting the
        # run, but make it visible in the logs rather than silent.
        print(
            f"[self_play][post_training_eval] skipped after error: {type(exc).__name__}: {exc}",
            flush=True,
        )
        payload.update({"skipped": True, "reason": f"{type(exc).__name__}: {exc}"})
        if not tasks_path.exists():
            _save_payload(tasks_path, [])
        if not validation_path.exists():
            _save_payload(validation_path, [])
    _save_payload(payload_path, payload)
    payload["path"] = str(payload_path)
    return payload
def _fallback_swarm_v2_completion_texts(
    env: OSINTEnvironment,
    cfg: SelfPlayTrainingConfig,
    round_index: int,
    rng: random.Random,
) -> list[str]:
    """Build deterministic swarm-v2 generator completions from graph paths.

    Up to ``2 * generated_tasks_per_round`` (at least 1) candidate paths of
    2..``max_path_hops`` hops are drawn from ``env.graph`` with ``rng``. Each
    path is traced, and untraceable ones are dropped. The traced path is then
    turned into the JSON shape a generator completion is expected to have:
    canonical subgraph, question/answer, supporting edges, tool trace,
    sub-agent outputs and orchestrator stats. Each record is serialized with
    sorted keys.
    """
    swarm_cfg = cfg.swarm_v2
    path_candidates = build_swarm_v2_path_candidates(
        env.graph,
        rng=rng,
        count=max(1, cfg.generated_tasks_per_round * 2),
        min_hops=2,
        max_hops=swarm_cfg.validation.max_path_hops,
    )
    serialized: list[str] = []
    for candidate_idx, raw_path in enumerate(path_candidates):
        hops = trace_swarm_v2_path(env.graph, raw_path)
        if not hops:
            continue
        hop_count = len(hops)
        question = emit_swarm_v2_question(hops)
        answer = select_swarm_v2_answer(hops)
        canonical_graph = build_swarm_v2_canonical_subgraph(
            env.graph,
            path_edges=hops,
            max_extra_edges=max(0, swarm_cfg.shared_context.max_edges - hop_count),
        )
        agent_count = min(swarm_cfg.generator_swarm.max_agents, max(1, hop_count + 1))
        supporting_edges = [
            {
                "src": edge.src,
                "rel": edge.rel,
                "dst": edge.dst,
                "confidence": float(edge.confidence),
            }
            for edge in hops
        ]
        tool_trace = build_swarm_v2_tool_trace(env.graph, hops)
        subagent_outputs = [
            f"path_agent_{edge_idx}: {edge.src} --{edge.rel}--> {edge.dst}"
            for edge_idx, edge in enumerate(hops)
        ]
        subagent_outputs.append(f"question_agent: emitted deterministic relation-path question for round {round_index}")
        subagent_outputs.append(f"context_agent: shared context path_size={hop_count} candidate={candidate_idx}")
        orchestrator = {
            "spawn_count": agent_count,
            "finished_subtasks": agent_count,
            "critical_steps": max(1, hop_count),
            "breadth": min(swarm_cfg.generator_swarm.max_breadth, agent_count),
            "depth": min(swarm_cfg.generator_swarm.max_depth, 1 if hop_count <= 2 else 2),
        }
        record = {
            "canonical_graph": canonical_graph,
            "question": question,
            "answer": answer,
            "task_type": f"swarm_v2_{hop_count}hop_trace",
            "supporting_edges": supporting_edges,
            "tool_trace": tool_trace,
            "subagent_outputs": subagent_outputs,
            "orchestrator": orchestrator,
        }
        serialized.append(json.dumps(record, sort_keys=True))
    return serialized
def _sample_swarm_v2_completion_texts_with_model(
    env: OSINTEnvironment,
    cfg: SelfPlayTrainingConfig,
    model_name_or_path: str,
    prompts: list[str],
    count: int,
    seen_questions: list[str],
) -> list[str]:
    """Sample one swarm_v2 generator completion per prompt with a local HF model.

    Every pending prompt is decoded once per entry of ``decode_schedule``
    (progressively lower temperature / top_p). Prompts whose completion passes
    ``SwarmV2ReplayValidator`` are not retried; the others are re-sampled on
    the next attempt. Sampling stops early once ``count`` prompts have a valid
    completion.

    Args:
        env: Environment whose graph is used to replay-validate completions.
        cfg: Self-play config (batch size, prompt length, token budget and
            swarm_v2 validation / shared-context settings).
        model_name_or_path: Checkpoint handed to ``from_pretrained``.
        prompts: Generator prompts.
        count: Number of valid completions to aim for.
        seen_questions: Questions passed to the validator as already seen.

    Returns:
        A list aligned index-for-index with ``prompts``. Each entry is the
        valid completion if one was found, otherwise the highest-scoring
        invalid one, or ``""`` if that prompt was never sampled. Returns ``[]``
        when ``count <= 0`` or ``prompts`` is empty.
    """
    # Bail out before importing/loading anything heavy: a call with no work to
    # do needs neither transformers nor a model on the device.
    if count <= 0 or not prompts:
        return []
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    print(
        f"[self_play][sample_generator] start model={model_name_or_path} "
        f"prompts={len(prompts)} target_valid={count} "
        f"max_new_tokens={cfg.generated_task_max_new_tokens}",
        flush=True,
    )
    load_start = time.monotonic()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Batched decoder-only generation needs a pad token and left padding so
    # each prompt ends exactly where generation starts.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "padding_side", "right") != "left":
        tokenizer.padding_side = "left"
    # Pad finished rows with the same id the inputs were padded with; this is
    # the eos id whenever the tokenizer had no pad token of its own (see above).
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    model_kwargs: dict[str, Any] = {}
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
        model_kwargs["torch_dtype"] = torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
    model.eval()
    device = next(model.parameters()).device
    print(
        f"[self_play][sample_generator] model_loaded device={device} "
        f"load_elapsed={time.monotonic() - load_start:.1f}s",
        flush=True,
    )
    validator = SwarmV2ReplayValidator(
        graph=env.graph,
        validation=cfg.swarm_v2.validation,
        shared_context=cfg.swarm_v2.shared_context,
        seen_questions=seen_questions,
    )
    completions: list[str] = []
    valid_count = 0
    batch_size = max(1, int(getattr(cfg.generator_phase, "generation_batch_size", 4) or 4))
    max_new_tokens = max(64, int(cfg.generated_task_max_new_tokens))
    max_prompt_length = int(getattr(cfg.generator_phase, "max_prompt_length", 1024) or 1024)
    # (temperature, top_p) per attempt; each retry samples more conservatively.
    decode_schedule = [(0.7, 0.9), (0.5, 0.85), (0.3, 0.8)]
    overall_start = time.monotonic()
    pending_indices = list(range(len(prompts)))
    best_completions: dict[int, str] = {}
    best_scores: dict[int, int] = {}
    valid_marks: dict[int, bool] = {}
    for attempt_idx, (temperature, top_p) in enumerate(decode_schedule):
        if not pending_indices:
            break
        attempt_start = time.monotonic()
        next_pending: list[int] = []
        processed = 0
        for batch_start in range(0, len(pending_indices), batch_size):
            batch_indices = pending_indices[batch_start : batch_start + batch_size]
            batch_prompts = [prompts[i] for i in batch_indices]
            encoded = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_prompt_length,
            )
            encoded = {key: value.to(device) for key, value in encoded.items()}
            print(
                f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
                f"batch_start={batch_start}/{len(pending_indices)} "
                f"batch_size={len(batch_indices)} "
                f"max_new_tokens={max_new_tokens} "
                f"input_len={encoded['input_ids'].shape[1]} "
                f"generating...",
                flush=True,
            )
            batch_t0 = time.monotonic()
            with torch.no_grad():
                output = model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                    num_return_sequences=1,
                    pad_token_id=pad_token_id,
                )
            print(
                f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
                f"batch_start={batch_start} generate_elapsed={time.monotonic() - batch_t0:.1f}s",
                flush=True,
            )
            # With left padding all rows share the same prompt length, so the
            # generated tokens start at the same offset in every row.
            input_len = encoded["input_ids"].shape[1]
            for row_offset, prompt_idx in enumerate(batch_indices):
                completion_ids = output[row_offset][input_len:]
                completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
                candidate = parse_generated_task_completion(
                    completion,
                    max_support_edges=cfg.swarm_v2.validation.max_support_edges,
                )
                validation = validator.validate(candidate)
                # Heuristic "how complete is this" score used to keep the best
                # invalid completion when no valid one is found.
                score = (
                    int(bool(candidate.question))
                    + int(bool(candidate.answer))
                    + len(candidate.supporting_edges)
                )
                if validation.is_valid:
                    if not valid_marks.get(prompt_idx):
                        valid_count += 1
                    valid_marks[prompt_idx] = True
                    best_completions[prompt_idx] = completion
                    best_scores[prompt_idx] = score
                else:
                    if score > best_scores.get(prompt_idx, -999):
                        best_scores[prompt_idx] = score
                        best_completions[prompt_idx] = completion
                if not valid_marks.get(prompt_idx):
                    next_pending.append(prompt_idx)
            processed += len(batch_indices)
            print(
                f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
                f"processed={processed}/{len(pending_indices)} "
                f"valid_so_far={valid_count}/{len(prompts)} "
                f"target_valid={count} "
                f"elapsed={time.monotonic() - overall_start:.1f}s",
                flush=True,
            )
            if valid_count >= count:
                break
        print(
            f"[self_play][sample_generator] attempt={attempt_idx + 1} done "
            f"valid={valid_count}/{len(prompts)} "
            f"attempt_elapsed={time.monotonic() - attempt_start:.1f}s",
            flush=True,
        )
        if valid_count >= count:
            break
        pending_indices = next_pending
    for prompt_idx in range(len(prompts)):
        completions.append(best_completions.get(prompt_idx, ""))
    print(
        f"[self_play][sample_generator] finished completions={len(completions)} "
        f"valid={valid_count}/{len(prompts)} target_valid={count} "
        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
        flush=True,
    )
    return completions
def _swarm_v2_canonical_graph_for_candidate(
    env: OSINTEnvironment,
    cfg: SelfPlayTrainingConfig,
    candidate: Any,
) -> dict[str, Any]:
    """Build the canonical graph for a parsed completion outside fixed mode.

    Uses the completion's own canonical nodes/edges when it supplied any
    (falling back to its supporting edges, and deriving nodes from edge
    endpoints when only edges were given). Otherwise it expands the
    supporting edges into a subgraph of ``env.graph`` whose extra-edge budget
    is the shared-context ``max_edges`` minus the supporting-edge count.
    """
    if candidate.canonical_edges or candidate.canonical_nodes:
        canonical_edges = list(candidate.canonical_edges or candidate.supporting_edges)
        canonical_nodes = list(candidate.canonical_nodes)
        if not canonical_nodes:
            canonical_nodes = sorted(
                {edge.src for edge in canonical_edges} | {edge.dst for edge in canonical_edges}
            )
        return {
            "nodes": canonical_nodes,
            "edges": [
                {
                    "src": edge.src,
                    "rel": edge.rel,
                    "dst": edge.dst,
                    "confidence": float(edge.confidence),
                }
                for edge in canonical_edges
            ],
        }
    return build_swarm_v2_canonical_subgraph(
        env.graph,
        candidate.supporting_edges,
        max_extra_edges=max(0, cfg.swarm_v2.shared_context.max_edges - len(candidate.supporting_edges)),
    )


def _materialize_swarm_v2_completions(
    env: OSINTEnvironment,
    cfg: SelfPlayTrainingConfig,
    completion_texts: list[str],
    round_index: int,
    seen_questions: list[str],
    prompt_canonical_candidates: list[dict[str, Any]] | None = None,
) -> tuple[list[TaskInstance], list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]]]:
    """Parse, replay-validate and turn generator completions into tasks.

    Every processed completion yields a canonical-graph record, a replay
    trace and a validation report. Only valid completions become
    ``TaskInstance`` objects, capped at ``max(1, cfg.generated_tasks_per_round)``;
    each accepted question is remembered by the validator.

    When ``cfg.canonical_graph_mode`` is ``"fixed"`` and
    ``prompt_canonical_candidates`` is non-empty, completion *i* uses prompt
    *i*'s canonical graph, and completions beyond the last candidate are
    ignored. Otherwise the graph comes from the completion itself (see
    ``_swarm_v2_canonical_graph_for_candidate``).

    Returns:
        ``(tasks, validation_reports, canonical_graph_candidates, replay_traces)``.
    """
    validator = SwarmV2ReplayValidator(
        graph=env.graph,
        validation=cfg.swarm_v2.validation,
        shared_context=cfg.swarm_v2.shared_context,
        seen_questions=seen_questions,
    )
    tasks: list[TaskInstance] = []
    validation_reports: list[dict[str, Any]] = []
    canonical_graph_candidates: list[dict[str, Any]] = []
    replay_traces: list[dict[str, Any]] = []
    # Both the canonical-graph source and the task cap are fixed for the call.
    use_fixed_canonical = str(cfg.canonical_graph_mode).strip().lower() == "fixed"
    fixed_graphs = prompt_canonical_candidates if use_fixed_canonical and prompt_canonical_candidates else None
    max_tasks = max(1, cfg.generated_tasks_per_round)
    for completion_idx, completion_text in enumerate(completion_texts):
        if fixed_graphs is not None and completion_idx >= len(fixed_graphs):
            # Fixed mode pairs completion i with prompt i; extra completions
            # have no prompt graph to attach to.
            break
        candidate = parse_generated_task_completion(
            completion_text,
            max_support_edges=cfg.swarm_v2.validation.max_support_edges,
        )
        validation = validator.validate(candidate)
        replay_edges = list(validation.replayed_edges or candidate.supporting_edges)
        materialized_tool_trace = _serialize_tool_trace(candidate.tool_trace)
        if not materialized_tool_trace and replay_edges:
            materialized_tool_trace = build_swarm_v2_tool_trace(env.graph, replay_edges)
        if fixed_graphs is not None:
            canonical_graph = dict(fixed_graphs[completion_idx])
        else:
            canonical_graph = _swarm_v2_canonical_graph_for_candidate(env, cfg, candidate)
        canonical_graph_candidates.append(
            {
                "candidate_index": completion_idx,
                "canonical_graph": canonical_graph,
                "question": candidate.question,
                "answer": candidate.answer,
            }
        )
        replay_traces.append(
            {
                "candidate_index": completion_idx,
                "question": candidate.question,
                "tool_trace": materialized_tool_trace,
                "replayed_edges": validation.to_dict()["replayed_edges"],
            }
        )
        validation_reports.append(
            {
                "candidate_index": completion_idx,
                "question": candidate.question,
                "answer": candidate.answer,
                "task_type": candidate.task_type,
                "validation": validation.to_dict(),
                "raw_completion": completion_text,
            }
        )
        if not validation.is_valid:
            continue
        if len(tasks) >= max_tasks:
            # Keep recording reports for the remaining candidates, but stop
            # emitting tasks once the per-round cap is reached.
            continue
        metadata = {
            "generated_by": "swarm_v2_generator",
            "round": round_index,
            "difficulty": "hard",
            "scenario": "swarm_v2_trace",
            "canonical_graph": canonical_graph,
            "tool_trace": materialized_tool_trace,
            "subagent_outputs": list(candidate.subagent_outputs),
            "validation": validation.to_dict(),
            "shared_context_budget": {
                "max_nodes": cfg.swarm_v2.shared_context.max_nodes,
                "max_edges": cfg.swarm_v2.shared_context.max_edges,
                "target_pressure": cfg.swarm_v2.shared_context.target_pressure,
            },
            "grader": {
                "type": "difficulty_exact_match",
                "answer_type": "node_id",
                "case_sensitive": True,
                "reward_profile": "hard",
            },
        }
        tasks.append(
            TaskInstance(
                task_id=f"swarm_v2_r{round_index}_{len(tasks)}",
                task_type=candidate.task_type or "swarm_v2_trace",
                question=candidate.question,
                answer=candidate.answer,
                supporting_edges=replay_edges,
                metadata=metadata,
            )
        )
        validator.remember(candidate.question)
    return tasks, validation_reports, canonical_graph_candidates, replay_traces
def _run_adversarial_self_play_swarm_v2(
    env_config: EnvironmentConfig,
    training_config: SelfPlayTrainingConfig,
    dry_run: bool = False,
) -> dict[str, Any]:
    """Run the swarm_v2 adversarial self-play loop and write its summary.

    Per round:

    1. Optionally GRPO-train the answerer first (schedule
       ``answerer_generator_answerer``).
    2. GRPO-train the generator against a frozen answerer judge.
    3. Sample generator completions and materialize the replay-valid ones
       into tasks.
    4. GRPO-train the answerer on tasks selected from the seed and generated
       sets.

    With ``model_topology == "shared"`` both roles are pointed at the most
    recently trained checkpoint.

    In dry-run mode (``dry_run`` or ``training_config.dry_run``) no training
    or model sampling happens; completions come from
    ``_fallback_swarm_v2_completion_texts``.

    Args:
        env_config: Environment configuration (its ``seed`` and ``llm``
            settings are read here).
        training_config: Self-play configuration.
        dry_run: Force dry-run mode regardless of ``training_config``.

    Returns:
        The summary payload written to ``<output_dir>/self_play_summary.json``,
        with ``summary_path`` added.
    """
    effective_dry_run = bool(dry_run or training_config.dry_run)
    topology = str(training_config.model_topology).strip().lower() or "dual"
    phase_schedule = str(training_config.phase_schedule).strip().lower() or "generator_answerer"
    tuning_mode = str(training_config.tuning_mode).strip().lower() or "full"
    run_dir = Path(training_config.output_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_repo_id = _default_hf_checkpoint_repo_id(run_dir)
    if checkpoint_repo_id and _resolve_hf_upload_token():
        print(f"[self_play][hf_upload] checkpoint uploads enabled -> {checkpoint_repo_id}")
    else:
        print("[self_play][hf_upload] checkpoint uploads disabled; set HF token and/or OSINT_HF_CHECKPOINT_REPO_ID.")
    env = OSINTEnvironment(env_config, llm=build_llm_client(env_config.llm))
    seed_tasks = list(env.tasks)
    seed_questions = [task.question for task in seed_tasks]
    generator_model, answerer_model = _resolve_initial_models(training_config)
    initial_generator_model = str(generator_model)
    initial_answerer_model = str(answerer_model)
    # One seeded RNG feeds every sampling step below, so call order fixes
    # exactly which paths/tasks are drawn in each round.
    rng = random.Random(env_config.seed)
    # Bootstrap a pool of generated tasks so an answerer-first schedule has
    # generated tasks to mix in during round 1.
    bootstrap_completions = _fallback_swarm_v2_completion_texts(
        env=env,
        cfg=training_config,
        round_index=0,
        rng=rng,
    )
    rolling_generated_tasks, _, _, _ = _materialize_swarm_v2_completions(
        env=env,
        cfg=training_config,
        completion_texts=bootstrap_completions,
        round_index=0,
        seen_questions=seed_questions,
    )
    if not rolling_generated_tasks:
        rolling_generated_tasks = list(seed_tasks[: max(1, training_config.generated_tasks_per_round)])
    rounds_payload: list[dict[str, Any]] = []
    for round_index in range(1, max(1, training_config.rounds) + 1):
        round_dir = run_dir / f"round_{round_index:03d}"
        round_dir.mkdir(parents=True, exist_ok=True)
        answerer_pre_tasks: list[TaskInstance] = []
        answerer_pre_dataset_path: Path | None = None
        answerer_pre_train_result: dict[str, Any] | None = None
        if phase_schedule == "answerer_generator_answerer":
            answerer_pre_tasks = _select_answerer_tasks(
                seed_tasks=seed_tasks,
                generated_tasks=rolling_generated_tasks,
                cfg=training_config,
                rng=rng,
            )
            answerer_pre_rows = _build_swarm_v2_answerer_rows(env, answerer_pre_tasks, training_config)
            answerer_pre_dataset_path = round_dir / "answerer_pre_dataset.json"
            _save_rows(answerer_pre_dataset_path, answerer_pre_rows)
            answerer_pre_train_result = {
                "model_path": answerer_model,
                "global_step": 0,
                "training_loss": 0.0,
                "train_rows": len(answerer_pre_rows),
                "skipped": effective_dry_run,
                "tuning_mode": tuning_mode,
            }
            if not effective_dry_run:
                answerer_pre_report_to, answerer_pre_run_name = _resolve_reporting(
                    training_config=training_config,
                    phase_name="answerer-pre",
                    round_index=round_index,
                )
                answerer_pre_reward = AnswererRewardFunction(
                    graph=env.graph,
                    pipeline_mode="swarm_v2",
                    parl_max_parallel_hint=training_config.swarm_v2.answerer_swarm.max_agents,
                )
                answerer_pre_train_result = _train_grpo_phase(
                    model_name_or_path=answerer_model,
                    phase=training_config.answerer_phase,
                    rows=answerer_pre_rows,
                    reward_function=answerer_pre_reward,
                    output_dir=round_dir / f"{training_config.answerer_phase.output_subdir}_pre",
                    tuning_mode=tuning_mode,
                    lora=training_config.lora,
                    report_to=answerer_pre_report_to,
                    run_name=answerer_pre_run_name,
                )
                _maybe_upload_folder_to_hf(
                    round_dir / f"{training_config.answerer_phase.output_subdir}_pre",
                    run_dir,
                    f"Upload answerer-pre checkpoints for round {round_index:03d}",
                )
                answerer_model = str(answerer_pre_train_result["model_path"])
                if topology == "shared":
                    generator_model = answerer_model
        generator_rows, prompt_canonical_candidates = _build_swarm_v2_generator_rows(env, training_config, rng)
        generator_dataset_path = round_dir / "generator_dataset.json"
        _save_rows(generator_dataset_path, generator_rows)
        generator_train_result: dict[str, Any] = {
            "model_path": generator_model,
            "global_step": 0,
            "training_loss": 0.0,
            "train_rows": len(generator_rows),
            "skipped": effective_dry_run,
            "tuning_mode": tuning_mode,
            "frozen_answerer_model": answerer_model,
        }
        if not effective_dry_run:
            generator_report_to, generator_run_name = _resolve_reporting(
                training_config=training_config,
                phase_name="generator",
                round_index=round_index,
            )
            generator_reward = GeneratorRewardFunction(
                graph=env.graph,
                answerer_judge=AnswererJudge(
                    model_name_or_path=answerer_model,
                    max_new_tokens=training_config.answerer_judge_max_new_tokens,
                ),
                weights=training_config.generator_reward_weights,
                max_support_edges=training_config.swarm_v2.validation.max_support_edges,
                pipeline_mode="swarm_v2",
                swarm_v2_validation=training_config.swarm_v2.validation,
                swarm_v2_shared_context=training_config.swarm_v2.shared_context,
                parl_max_parallel_hint=training_config.swarm_v2.generator_swarm.max_agents,
            )
            generator_train_result = _train_grpo_phase(
                model_name_or_path=generator_model,
                phase=training_config.generator_phase,
                rows=generator_rows,
                reward_function=generator_reward,
                output_dir=round_dir / training_config.generator_phase.output_subdir,
                tuning_mode=tuning_mode,
                lora=training_config.lora,
                report_to=generator_report_to,
                run_name=generator_run_name,
            )
            _maybe_upload_folder_to_hf(
                round_dir / training_config.generator_phase.output_subdir,
                run_dir,
                f"Upload generator checkpoints for round {round_index:03d}",
            )
            generator_model = str(generator_train_result["model_path"])
            if topology == "shared":
                answerer_model = generator_model
        # ``rolling_generated_tasks`` is only replaced after materialization,
        # so one snapshot serves every validator built this round; each call
        # still receives its own copy.
        round_seen_questions = seed_questions + [task.question for task in rolling_generated_tasks]
        # Only completions sampled from the ``generator_rows`` prompts line up
        # index-for-index with ``prompt_canonical_candidates``. Fallback
        # completions trace their own random paths and carry their own
        # canonical graphs, so they must not be paired with prompt graphs.
        completions_from_prompts = False
        if effective_dry_run:
            completion_texts = _fallback_swarm_v2_completion_texts(
                env=env,
                cfg=training_config,
                round_index=round_index,
                rng=rng,
            )
        else:
            completion_texts = _sample_swarm_v2_completion_texts_with_model(
                env=env,
                cfg=training_config,
                model_name_or_path=generator_model,
                prompts=[row["prompt"] for row in generator_rows],
                count=max(1, training_config.generated_tasks_per_round * 2),
                seen_questions=list(round_seen_questions),
            )
            completions_from_prompts = bool(completion_texts)
            if not completion_texts:
                completion_texts = _fallback_swarm_v2_completion_texts(
                    env=env,
                    cfg=training_config,
                    round_index=round_index,
                    rng=rng,
                )
        generated_tasks, validation_reports, canonical_graph_candidates, replay_traces = _materialize_swarm_v2_completions(
            env=env,
            cfg=training_config,
            completion_texts=completion_texts,
            round_index=round_index,
            seen_questions=list(round_seen_questions),
            prompt_canonical_candidates=prompt_canonical_candidates if completions_from_prompts else None,
        )
        if not generated_tasks:
            fallback_completions = _fallback_swarm_v2_completion_texts(
                env=env,
                cfg=training_config,
                round_index=round_index,
                rng=rng,
            )
            generated_tasks, validation_reports, canonical_graph_candidates, replay_traces = _materialize_swarm_v2_completions(
                env=env,
                cfg=training_config,
                completion_texts=fallback_completions,
                round_index=round_index,
                seen_questions=list(round_seen_questions),
                prompt_canonical_candidates=None,
            )
        if generated_tasks:
            rolling_generated_tasks = list(generated_tasks)
        canonical_graph_candidates_path = round_dir / "canonical_graph_candidates.json"
        replay_traces_path = round_dir / "replay_traces.json"
        validation_reports_path = round_dir / "validation_reports.json"
        generated_tasks_path = round_dir / "generated_tasks.json"
        _save_payload(canonical_graph_candidates_path, prompt_canonical_candidates or canonical_graph_candidates)
        _save_payload(replay_traces_path, replay_traces)
        _save_payload(validation_reports_path, validation_reports)
        _save_tasks(generated_tasks_path, generated_tasks)
        answerer_tasks = _select_answerer_tasks(
            seed_tasks=seed_tasks,
            generated_tasks=generated_tasks,
            cfg=training_config,
            rng=rng,
        )
        answerer_rows = _build_swarm_v2_answerer_rows(env, answerer_tasks, training_config)
        answerer_dataset_path = round_dir / "answerer_dataset.json"
        _save_rows(answerer_dataset_path, answerer_rows)
        answerer_train_result: dict[str, Any] = {
            "model_path": answerer_model,
            "global_step": 0,
            "training_loss": 0.0,
            "train_rows": len(answerer_rows),
            "skipped": effective_dry_run,
            "tuning_mode": tuning_mode,
        }
        if not effective_dry_run:
            answerer_report_to, answerer_run_name = _resolve_reporting(
                training_config=training_config,
                phase_name="answerer",
                round_index=round_index,
            )
            answerer_reward = AnswererRewardFunction(
                graph=env.graph,
                pipeline_mode="swarm_v2",
                parl_max_parallel_hint=training_config.swarm_v2.answerer_swarm.max_agents,
            )
            answerer_train_result = _train_grpo_phase(
                model_name_or_path=answerer_model,
                phase=training_config.answerer_phase,
                rows=answerer_rows,
                reward_function=answerer_reward,
                output_dir=round_dir / training_config.answerer_phase.output_subdir,
                tuning_mode=tuning_mode,
                lora=training_config.lora,
                report_to=answerer_report_to,
                run_name=answerer_run_name,
            )
            _maybe_upload_folder_to_hf(
                round_dir / training_config.answerer_phase.output_subdir,
                run_dir,
                f"Upload answerer checkpoints for round {round_index:03d}",
            )
            answerer_model = str(answerer_train_result["model_path"])
            if topology == "shared":
                generator_model = answerer_model
        rounds_payload.append(
            {
                "round": round_index,
                "dry_run": effective_dry_run,
                "pipeline_mode": "swarm_v2",
                "phase_schedule": phase_schedule,
                "generator": generator_train_result,
                "answerer": answerer_train_result,
                "answerer_pre": answerer_pre_train_result,
                "generated_task_count": len(generated_tasks),
                "answerer_task_count": len(answerer_tasks),
                "answerer_pre_task_count": len(answerer_pre_tasks),
                "artifacts": {
                    "generator_dataset": str(generator_dataset_path),
                    "answerer_dataset": str(answerer_dataset_path),
                    "generated_tasks": str(generated_tasks_path),
                    "canonical_graph_candidates": str(canonical_graph_candidates_path),
                    "replay_traces": str(replay_traces_path),
                    "validation_reports": str(validation_reports_path),
                    "answerer_pre_dataset": str(answerer_pre_dataset_path) if answerer_pre_dataset_path else "",
                },
            }
        )
    post_training_evaluation = _run_post_training_evaluation(
        env_config=env_config,
        training_config=training_config,
        generator_model=generator_model,
        answerer_models={
            "finetuned_answerer": answerer_model,
            "original_answerer": initial_answerer_model,
        },
        output_dir=run_dir,
        pipeline_mode="swarm_v2",
        effective_dry_run=effective_dry_run,
    )
    final_payload = {
        "dry_run": effective_dry_run,
        "pipeline_mode": "swarm_v2",
        "output_dir": str(run_dir),
        "model_topology": topology,
        "phase_schedule": phase_schedule,
        "tuning_mode": tuning_mode,
        "canonical_graph_mode": str(training_config.canonical_graph_mode).strip().lower() or "generate",
        "rounds": rounds_payload,
        "final_models": {
            "generator": generator_model,
            "answerer": answerer_model,
        },
        "initial_models": {
            "generator": initial_generator_model,
            "answerer": initial_answerer_model,
        },
        "post_training_evaluation": post_training_evaluation,
        "kimi_objective_mapping": {
            "grouped_rollouts": "TRL GRPO num_generations",
            "mean_centered_advantage": "GRPO relative reward baseline",
            "token_level_clipping": "GRPO epsilon clipping over policy ratios",
            "reference_regularization": "GRPO beta KL term",
            "toggle_self_play": "Alternating generator and answerer rounds",
            "parallel_orchestration": "PARL-inspired auxiliary reward over generator and answerer swarms",
        },
    }
    summary_path = run_dir / "self_play_summary.json"
    summary_path.write_text(json.dumps(final_payload, indent=2, sort_keys=True), encoding="utf-8")
    final_payload["summary_path"] = str(summary_path)
    _maybe_upload_file_to_hf(summary_path, run_dir, "Upload self-play summary")
    _maybe_upload_file_to_hf(run_dir / "post_training_evaluation.json", run_dir, "Upload post-training evaluation")
    return final_payload
def run_adversarial_self_play(
    env_config: EnvironmentConfig,
    training_config: SelfPlayTrainingConfig,
    dry_run: bool = False,
) -> dict[str, Any]:
    """Run adversarial generator/answerer self-play and write a run summary.

    Dispatches to ``_run_adversarial_self_play_swarm_v2`` when
    ``training_config.pipeline_mode`` is ``"swarm_v2"``. Otherwise it runs the
    legacy pipeline; each round:

    1. Optionally GRPO-train the answerer first (schedule
       ``answerer_generator_answerer``).
    2. GRPO-train the generator with an answerer judge.
    3. Sample tasks from the generator, falling back to
       ``_fallback_generated_tasks``.
    4. GRPO-train the answerer on tasks selected from the seed and generated
       sets.

    With ``model_topology == "shared"`` both roles are pointed at the most
    recently trained checkpoint. In dry-run mode (``dry_run`` or
    ``training_config.dry_run``) no training or model sampling happens.

    Args:
        env_config: Environment configuration (its ``seed`` and ``llm``
            settings are read here).
        training_config: Self-play configuration.
        dry_run: Force dry-run mode regardless of ``training_config``.

    Returns:
        The summary payload written to ``<output_dir>/self_play_summary.json``,
        with ``summary_path`` added.
    """
    if str(training_config.pipeline_mode).strip().lower() == "swarm_v2":
        return _run_adversarial_self_play_swarm_v2(
            env_config=env_config,
            training_config=training_config,
            dry_run=dry_run,
        )
    # Everything below is the legacy pipeline.
    effective_dry_run = bool(dry_run or training_config.dry_run)
    topology = str(training_config.model_topology).strip().lower() or "dual"
    phase_schedule = str(training_config.phase_schedule).strip().lower() or "generator_answerer"
    tuning_mode = str(training_config.tuning_mode).strip().lower() or "full"
    run_dir = Path(training_config.output_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_repo_id = _default_hf_checkpoint_repo_id(run_dir)
    if checkpoint_repo_id and _resolve_hf_upload_token():
        print(f"[self_play][hf_upload] checkpoint uploads enabled -> {checkpoint_repo_id}")
    else:
        print("[self_play][hf_upload] checkpoint uploads disabled; set HF token and/or OSINT_HF_CHECKPOINT_REPO_ID.")
    env = OSINTEnvironment(env_config, llm=build_llm_client(env_config.llm))
    seed_tasks = list(env.tasks)
    generator_model, answerer_model = _resolve_initial_models(training_config)
    initial_generator_model = str(generator_model)
    initial_answerer_model = str(answerer_model)
    # One seeded RNG is shared by every sampling call below; reordering those
    # calls changes the sampled sequence.
    rng = random.Random(env_config.seed)
    rounds_payload: list[dict[str, Any]] = []
    # Initial pool of generated tasks; used by the answerer-first phase in
    # round 1 before the generator has produced anything.
    rolling_generated_tasks = _fallback_generated_tasks(
        base_tasks=seed_tasks,
        round_index=0,
        count=training_config.generated_tasks_per_round,
        rng=rng,
    )
    if not rolling_generated_tasks:
        rolling_generated_tasks = list(seed_tasks[: max(1, training_config.generated_tasks_per_round)])
    for round_index in range(1, max(1, training_config.rounds) + 1):
        round_dir = run_dir / f"round_{round_index:03d}"
        round_dir.mkdir(parents=True, exist_ok=True)
        answerer_pre_tasks: list[TaskInstance] = []
        answerer_pre_dataset_path: Path | None = None
        answerer_pre_train_result: dict[str, Any] | None = None
        # Optional answerer warm-up phase before the generator trains.
        if phase_schedule == "answerer_generator_answerer":
            answerer_pre_tasks = _select_answerer_tasks(
                seed_tasks=seed_tasks,
                generated_tasks=rolling_generated_tasks,
                cfg=training_config,
                rng=rng,
            )
            answerer_pre_rows = _build_answerer_rows(answerer_pre_tasks)
            answerer_pre_dataset_path = round_dir / "answerer_pre_dataset.json"
            _save_rows(answerer_pre_dataset_path, answerer_pre_rows)
            # Placeholder result, reported as-is when training is skipped (dry run).
            answerer_pre_train_result = {
                "model_path": answerer_model,
                "global_step": 0,
                "training_loss": 0.0,
                "train_rows": len(answerer_pre_rows),
                "skipped": effective_dry_run,
                "tuning_mode": tuning_mode,
            }
            if not effective_dry_run:
                answerer_pre_report_to, answerer_pre_run_name = _resolve_reporting(
                    training_config=training_config,
                    phase_name="answerer-pre",
                    round_index=round_index,
                )
                answerer_pre_reward = AnswererRewardFunction(graph=env.graph)
                answerer_pre_train_result = _train_grpo_phase(
                    model_name_or_path=answerer_model,
                    phase=training_config.answerer_phase,
                    rows=answerer_pre_rows,
                    reward_function=answerer_pre_reward,
                    output_dir=round_dir / f"{training_config.answerer_phase.output_subdir}_pre",
                    tuning_mode=tuning_mode,
                    lora=training_config.lora,
                    report_to=answerer_pre_report_to,
                    run_name=answerer_pre_run_name,
                )
                _maybe_upload_folder_to_hf(
                    round_dir / f"{training_config.answerer_phase.output_subdir}_pre",
                    run_dir,
                    f"Upload answerer-pre checkpoints for round {round_index:03d}",
                )
                answerer_model = str(answerer_pre_train_result["model_path"])
                # Shared topology: both roles continue from the newest checkpoint.
                if topology == "shared":
                    generator_model = answerer_model
        # Generator phase: rewarded via a judge built on the current answerer model.
        generator_rows = _build_generator_rows(env=env, cfg=training_config, rng=rng)
        generator_dataset_path = round_dir / "generator_dataset.json"
        _save_rows(generator_dataset_path, generator_rows)
        generator_train_result: dict[str, Any] = {
            "model_path": generator_model,
            "global_step": 0,
            "training_loss": 0.0,
            "train_rows": len(generator_rows),
            "skipped": effective_dry_run,
            "tuning_mode": tuning_mode,
        }
        if not effective_dry_run:
            generator_report_to, generator_run_name = _resolve_reporting(
                training_config=training_config,
                phase_name="generator",
                round_index=round_index,
            )
            generator_reward = GeneratorRewardFunction(
                graph=env.graph,
                answerer_judge=AnswererJudge(
                    model_name_or_path=answerer_model,
                    max_new_tokens=training_config.answerer_judge_max_new_tokens,
                ),
                weights=training_config.generator_reward_weights,
                max_support_edges=training_config.max_support_edges,
            )
            generator_train_result = _train_grpo_phase(
                model_name_or_path=generator_model,
                phase=training_config.generator_phase,
                rows=generator_rows,
                reward_function=generator_reward,
                output_dir=round_dir / training_config.generator_phase.output_subdir,
                tuning_mode=tuning_mode,
                lora=training_config.lora,
                report_to=generator_report_to,
                run_name=generator_run_name,
            )
            _maybe_upload_folder_to_hf(
                round_dir / training_config.generator_phase.output_subdir,
                run_dir,
                f"Upload generator checkpoints for round {round_index:03d}",
            )
            generator_model = str(generator_train_result["model_path"])
            if topology == "shared":
                answerer_model = generator_model
        # Produce this round's tasks: sample from the generator, or fall back
        # to _fallback_generated_tasks in dry runs / when sampling yields nothing.
        generated_tasks: list[TaskInstance]
        if effective_dry_run:
            generated_tasks = _fallback_generated_tasks(
                base_tasks=seed_tasks,
                round_index=round_index,
                count=training_config.generated_tasks_per_round,
                rng=rng,
            )
        else:
            generated_tasks = _sample_generated_tasks_with_model(
                model_name_or_path=generator_model,
                prompts=[row["prompt"] for row in generator_rows],
                round_index=round_index,
                count=training_config.generated_tasks_per_round,
                max_support_edges=training_config.max_support_edges,
                max_new_tokens=training_config.generated_task_max_new_tokens,
            )
            if not generated_tasks:
                generated_tasks = _fallback_generated_tasks(
                    base_tasks=seed_tasks,
                    round_index=round_index,
                    count=training_config.generated_tasks_per_round,
                    rng=rng,
                )
        if generated_tasks:
            rolling_generated_tasks = list(generated_tasks)
        generated_tasks_path = round_dir / "generated_tasks.json"
        _save_tasks(generated_tasks_path, generated_tasks)
        # Answerer phase on tasks selected from the seed and freshly generated sets.
        answerer_tasks = _select_answerer_tasks(
            seed_tasks=seed_tasks,
            generated_tasks=generated_tasks,
            cfg=training_config,
            rng=rng,
        )
        answerer_rows = _build_answerer_rows(answerer_tasks)
        answerer_dataset_path = round_dir / "answerer_dataset.json"
        _save_rows(answerer_dataset_path, answerer_rows)
        answerer_train_result: dict[str, Any] = {
            "model_path": answerer_model,
            "global_step": 0,
            "training_loss": 0.0,
            "train_rows": len(answerer_rows),
            "skipped": effective_dry_run,
            "tuning_mode": tuning_mode,
        }
        if not effective_dry_run:
            answerer_report_to, answerer_run_name = _resolve_reporting(
                training_config=training_config,
                phase_name="answerer",
                round_index=round_index,
            )
            answerer_reward = AnswererRewardFunction(graph=env.graph)
            answerer_train_result = _train_grpo_phase(
                model_name_or_path=answerer_model,
                phase=training_config.answerer_phase,
                rows=answerer_rows,
                reward_function=answerer_reward,
                output_dir=round_dir / training_config.answerer_phase.output_subdir,
                tuning_mode=tuning_mode,
                lora=training_config.lora,
                report_to=answerer_report_to,
                run_name=answerer_run_name,
            )
            _maybe_upload_folder_to_hf(
                round_dir / training_config.answerer_phase.output_subdir,
                run_dir,
                f"Upload answerer checkpoints for round {round_index:03d}",
            )
            answerer_model = str(answerer_train_result["model_path"])
            if topology == "shared":
                generator_model = answerer_model
        artifacts = _RoundArtifacts(
            round_index=round_index,
            generator_dataset_path=str(generator_dataset_path),
            answerer_dataset_path=str(answerer_dataset_path),
            generated_tasks_path=str(generated_tasks_path),
        )
        rounds_payload.append(
            {
                "round": round_index,
                "dry_run": effective_dry_run,
                "pipeline_mode": "legacy",
                "phase_schedule": phase_schedule,
                "generator": generator_train_result,
                "answerer": answerer_train_result,
                "answerer_pre": answerer_pre_train_result,
                "generated_task_count": len(generated_tasks),
                "answerer_task_count": len(answerer_tasks),
                "answerer_pre_task_count": len(answerer_pre_tasks),
                "artifacts": {
                    "generator_dataset": artifacts.generator_dataset_path,
                    "answerer_dataset": artifacts.answerer_dataset_path,
                    "generated_tasks": artifacts.generated_tasks_path,
                    "answerer_pre_dataset": str(answerer_pre_dataset_path) if answerer_pre_dataset_path else "",
                },
            }
        )
    # Compare the fine-tuned answerer with the initial one after all rounds.
    post_training_evaluation = _run_post_training_evaluation(
        env_config=env_config,
        training_config=training_config,
        generator_model=generator_model,
        answerer_models={
            "finetuned_answerer": answerer_model,
            "original_answerer": initial_answerer_model,
        },
        output_dir=run_dir,
        pipeline_mode="legacy",
        effective_dry_run=effective_dry_run,
    )
    final_payload = {
        "dry_run": effective_dry_run,
        "pipeline_mode": "legacy",
        "output_dir": str(run_dir),
        "model_topology": topology,
        "phase_schedule": phase_schedule,
        "tuning_mode": tuning_mode,
        "canonical_graph_mode": str(training_config.canonical_graph_mode).strip().lower() or "generate",
        "rounds": rounds_payload,
        "final_models": {
            "generator": generator_model,
            "answerer": answerer_model,
        },
        "initial_models": {
            "generator": initial_generator_model,
            "answerer": initial_answerer_model,
        },
        "post_training_evaluation": post_training_evaluation,
        "kimi_objective_mapping": {
            "grouped_rollouts": "TRL GRPO num_generations",
            "mean_centered_advantage": "GRPO relative reward baseline",
            "token_level_clipping": "GRPO epsilon clipping over policy ratios",
            "reference_regularization": "GRPO beta KL term",
            "toggle_self_play": "Alternating generator and answerer rounds",
        },
    }
    summary_path = run_dir / "self_play_summary.json"
    summary_path.write_text(json.dumps(final_payload, indent=2, sort_keys=True), encoding="utf-8")
    final_payload["summary_path"] = str(summary_path)
    _maybe_upload_file_to_hf(summary_path, run_dir, "Upload self-play summary")
    _maybe_upload_file_to_hf(run_dir / "post_training_evaluation.json", run_dir, "Upload post-training evaluation")
    return final_payload