# src/osint_env/cli.py
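"""Command-line interface for the OSINT environment.

Subcommands: ``demo`` (single debug episode), ``eval`` (aggregate metrics),
``benchmark`` (eval + leaderboard + dashboard), ``leaderboard`` (ranked table),
``benchmark-sweep`` (benchmark across multiple seeds), ``viz`` (interactive
graph/database explorer), and ``train-self-play`` (adversarial self-play
fine-tuning scaffold).
"""
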
from __future__ import annotations
import argparse
import json
from pathlib import Path
from osint_env.agents.single_agent import SingleAgentRunner
from osint_env.agents.swarm_agent import SwarmAgentRunner
from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config
from osint_env.domain.models import EnvironmentConfig
from osint_env.env.environment import OSINTEnvironment
from osint_env.env.reward import compute_graph_f1
from osint_env.eval.leaderboard import append_leaderboard_record, load_leaderboard, render_leaderboard_table
from osint_env.eval.runner import run_evaluation
from osint_env.llm import build_llm_client
from osint_env.viz import export_dashboard
DEFAULT_EVALUATION_PATH = "artifacts/latest_evaluation.json"
def _save_evaluation(path: str, payload: dict) -> None:
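    """Write an evaluation payload to ``path`` as indented, key-sorted JSON, creating parent dirs."""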
out = Path(path)
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
def _load_evaluation(path: str) -> dict | None:
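    """Return the saved evaluation dict at ``path``, or None if it is missing, unparsable, or not a dict."""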
file_path = Path(path)
if not file_path.exists():
return None
try:
data = json.loads(file_path.read_text(encoding="utf-8"))
except json.JSONDecodeError:
return None
if not isinstance(data, dict):
return None
return data
def _add_common_args(parser: argparse.ArgumentParser) -> None:
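    """Register the flags shared by every subcommand: config/seed paths, agent mode, LLM and MetaQA overrides."""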
parser.add_argument("--config", type=str, default="config/shared_config.json")
parser.add_argument("--seed-file", type=str, default="")
parser.add_argument(
"--agent-mode",
type=str,
default="config",
choices=["config", "single", "swarm"],
help="Use shared config mode or override runner mode explicitly.",
)
parser.add_argument(
"--llm-provider",
type=str,
default="config",
choices=["config", "mock", "ollama", "openai"],
help="Use shared config provider or override explicitly.",
)
parser.add_argument("--llm-model", type=str, default="", help="Override model name for selected LLM provider.")
parser.add_argument("--llm-timeout-seconds", type=int, default=0, help="Override LLM request timeout in seconds.")
parser.add_argument("--ollama-base-url", type=str, default="", help="Override Ollama base URL.")
parser.add_argument("--openai-base-url", type=str, default="", help="Override OpenAI base URL.")
parser.add_argument("--openai-api-key", type=str, default="", help="OpenAI API key override.")
parser.add_argument(
"--openai-api-key-env",
type=str,
default="",
help="Environment variable name for OpenAI API key.",
)
parser.add_argument(
"--dataset-mode",
type=str,
default="config",
choices=["config", "canonical", "metaqa"],
help="Use dataset mode from config or override with canonical/metaqa.",
)
parser.add_argument("--metaqa-root", type=str, default="", help="Override MetaQA dataset root directory.")
parser.add_argument(
"--metaqa-kb-path",
type=str,
default="",
help="Override MetaQA KB triples file path. Defaults to <metaqa-root>/kb.txt.",
)
parser.add_argument(
"--metaqa-variant",
type=str,
default="",
choices=["", "vanilla", "ntm"],
help="Override MetaQA QA variant.",
)
parser.add_argument(
"--metaqa-hops",
type=str,
default="",
help="Comma-separated hop buckets for MetaQA mode (example: 1-hop,2-hop,3-hop).",
)
parser.add_argument(
"--metaqa-splits",
type=str,
default="",
help="Comma-separated splits for MetaQA mode (example: train,dev,test).",
)
def build_parser() -> argparse.ArgumentParser:
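    """Build the ``osint-env`` argument parser with all subcommands and their flags."""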
parser = argparse.ArgumentParser(prog="osint-env")
sub = parser.add_subparsers(dest="cmd", required=True)
d = sub.add_parser("demo", help="Run one episode and print debug info.")
_add_common_args(d)
e = sub.add_parser("eval", help="Run multiple episodes and show aggregate metrics.")
_add_common_args(e)
e.add_argument("--episodes", type=int, default=0)
e.add_argument("--dashboard", type=str, default="")
b = sub.add_parser("benchmark", help="Run eval, update leaderboard, and export interactive dashboard.")
_add_common_args(b)
b.add_argument("--episodes", type=int, default=0)
b.add_argument("--name", type=str, default="")
b.add_argument("--leaderboard", type=str, default="")
b.add_argument("--dashboard", type=str, default="")
l = sub.add_parser("leaderboard", help="Print ranked benchmark leaderboard.")
_add_common_args(l)
l.add_argument("--leaderboard", type=str, default="")
l.add_argument("--top", type=int, default=20)
l.add_argument(
"--sort-by",
type=str,
default="leaderboard_score",
choices=[
"leaderboard_score",
"task_success_rate",
"avg_graph_f1",
"tool_efficiency",
"avg_reward",
"retrieval_signal",
"structural_signal",
"deanonymization_accuracy",
"spawn_signal",
],
)
s = sub.add_parser("benchmark-sweep", help="Run benchmark across multiple seeds and append all runs to leaderboard.")
_add_common_args(s)
s.add_argument("--episodes", type=int, default=0)
s.add_argument("--seeds", type=str, default="7,11,17,23,31")
s.add_argument("--name-prefix", type=str, default="sweep")
s.add_argument("--leaderboard", type=str, default="")
s.add_argument("--dashboard-dir", type=str, default="")
v = sub.add_parser("viz", help="Export an interactive graph/database explorer.")
_add_common_args(v)
v.add_argument("--output", type=str, default="artifacts/osint_explorer.html")
v.add_argument("--with-demo", action="store_true")
v.add_argument("--leaderboard", type=str, default="")
v.add_argument(
"--evaluation",
type=str,
default=DEFAULT_EVALUATION_PATH,
help="Path to a saved evaluation payload with episode details.",
)
t = sub.add_parser(
"train-self-play",
help="Run adversarial self-play fine-tuning scaffold with Hugging Face TRL (Kimi-style alternating phases).",
)
_add_common_args(t)
t.add_argument(
"--train-config",
type=str,
default="config/self_play_training_example.json",
help="Path to self-play training JSON config.",
)
t.add_argument(
"--train-output-dir",
type=str,
default="",
help="Optional output dir override for self-play artifacts and checkpoints.",
)
t.add_argument(
"--train-rounds",
type=int,
default=0,
help="Optional override for the number of self-play rounds.",
)
t.add_argument(
"--dry-run",
action="store_true",
help="Skip actual GRPO updates and only materialize datasets/round artifacts.",
)
return parser
def _resolve_environment_config(args: argparse.Namespace) -> tuple[EnvironmentConfig, dict[str, str | int]]:
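    """Load the shared config, apply CLI overrides, and return the environment config plus runtime defaults."""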
shared = load_shared_config(args.config)
env_cfg = clone_environment_config(shared.environment)
if args.seed_file:
env_cfg.seeding = load_seeding_config(args.seed_file)
if args.llm_provider != "config":
env_cfg.llm.provider = args.llm_provider
if args.llm_model:
env_cfg.llm.model = args.llm_model
if int(args.llm_timeout_seconds) > 0:
env_cfg.llm.timeout_seconds = int(args.llm_timeout_seconds)
if args.ollama_base_url:
env_cfg.llm.ollama_base_url = args.ollama_base_url
if args.openai_base_url:
env_cfg.llm.openai_base_url = args.openai_base_url
if args.openai_api_key:
env_cfg.llm.openai_api_key = args.openai_api_key
if args.openai_api_key_env:
env_cfg.llm.openai_api_key_env = args.openai_api_key_env
if args.dataset_mode != "config":
env_cfg.dataset_mode = args.dataset_mode
if args.metaqa_root:
env_cfg.metaqa_root = args.metaqa_root
if args.metaqa_kb_path:
env_cfg.metaqa_kb_path = args.metaqa_kb_path
if args.metaqa_variant:
env_cfg.metaqa_variant = args.metaqa_variant
if args.metaqa_hops:
env_cfg.metaqa_hops = [item.strip() for item in str(args.metaqa_hops).split(",") if item.strip()]
if args.metaqa_splits:
env_cfg.metaqa_splits = [item.strip() for item in str(args.metaqa_splits).split(",") if item.strip()]
if args.agent_mode == "single":
env_cfg.swarm.enabled = False
elif args.agent_mode == "swarm":
env_cfg.swarm.enabled = True
runtime = {
"default_episodes": shared.runtime.default_episodes,
"leaderboard_path": shared.runtime.leaderboard_path,
"dashboard_path": shared.runtime.dashboard_path,
"sweep_dashboard_dir": shared.runtime.sweep_dashboard_dir,
}
return env_cfg, runtime
def _runner_for(env: OSINTEnvironment) -> SingleAgentRunner | SwarmAgentRunner:
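    """Return a SwarmAgentRunner when swarm mode is enabled, otherwise a SingleAgentRunner."""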
if env.config.swarm.enabled:
return SwarmAgentRunner(env, llm=build_llm_client(env.config.llm))
return SingleAgentRunner(env, llm=build_llm_client(env.config.llm))
def main() -> None:
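    """Parse CLI arguments, resolve the environment config, and dispatch to the selected subcommand."""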
args = build_parser().parse_args()
env_cfg, runtime = _resolve_environment_config(args)
episodes = int(args.episodes) if getattr(args, "episodes", 0) else int(runtime["default_episodes"])
leaderboard_path = str(args.leaderboard) if getattr(args, "leaderboard", "") else str(runtime["leaderboard_path"])
dashboard_path = str(args.dashboard) if getattr(args, "dashboard", "") else str(runtime["dashboard_path"])
sweep_dashboard_dir = (
str(args.dashboard_dir) if getattr(args, "dashboard_dir", "") else str(runtime["sweep_dashboard_dir"])
)
evaluation_path = str(getattr(args, "evaluation", "") or DEFAULT_EVALUATION_PATH)
if args.cmd == "leaderboard":
records = load_leaderboard(leaderboard_path)
print(render_leaderboard_table(records, top_k=args.top, sort_by=args.sort_by))
return
if args.cmd == "benchmark-sweep":
seed_values = [int(x.strip()) for x in args.seeds.split(",") if x.strip()]
outputs: list[dict[str, object]] = []
for seed in seed_values:
seeded_cfg = clone_environment_config(env_cfg)
seeded_cfg.seed = seed
env = OSINTEnvironment(seeded_cfg, llm=build_llm_client(seeded_cfg.llm))
evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=build_llm_client(seeded_cfg.llm))
summary = evaluation["summary"]
run_name = f"{args.name_prefix}_seed{seed}"
record = append_leaderboard_record(
path=leaderboard_path,
summary=summary,
episodes=episodes,
run_name=run_name,
config={
"seed": seed,
"max_steps": env.config.max_steps,
"swarm_enabled": env.config.swarm.enabled,
"max_agents": env.config.swarm.max_agents,
"max_breadth": env.config.swarm.max_breadth,
"max_width": env.config.swarm.max_width,
"max_depth": env.config.swarm.max_depth,
"seeded_questions": len(env.config.seeding.seeded_questions),
},
)
            seed_dashboard_path = export_dashboard(
                env=env,
                evaluation=evaluation,
                leaderboard_records=load_leaderboard(leaderboard_path),
                output_path=f"{sweep_dashboard_dir}/{run_name}.html",
            )
            # The default evaluation path is overwritten on every seed; per-seed dashboards keep their own files.
            _save_evaluation(DEFAULT_EVALUATION_PATH, evaluation)
            outputs.append({"seed": seed, "record": record, "dashboard": seed_dashboard_path, "summary": summary})
records = load_leaderboard(leaderboard_path)
print(
json.dumps(
{
"runs": outputs,
"leaderboard_preview": render_leaderboard_table(records, top_k=min(10, len(records))),
},
indent=2,
sort_keys=True,
)
)
return
if args.cmd == "train-self-play":
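        # Local import: the training stack is only loaded when this subcommand runs.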
from osint_env.training import load_self_play_config, run_adversarial_self_play
train_cfg = load_self_play_config(args.train_config)
if str(args.train_output_dir).strip():
train_cfg.output_dir = str(args.train_output_dir).strip()
if int(args.train_rounds) > 0:
train_cfg.rounds = int(args.train_rounds)
payload = run_adversarial_self_play(
env_config=env_cfg,
training_config=train_cfg,
dry_run=bool(args.dry_run),
)
print(json.dumps(payload, indent=2, sort_keys=True))
return
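    # The remaining subcommands (demo, eval, benchmark, viz) share one environment and LLM client.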
llm_client = build_llm_client(env_cfg.llm)
env = OSINTEnvironment(env_cfg, llm=llm_client)
if args.cmd == "demo":
info = _runner_for(env).run_episode()
print(json.dumps(info, indent=2, sort_keys=True))
elif args.cmd == "eval":
evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=llm_client)
_save_evaluation(DEFAULT_EVALUATION_PATH, evaluation)
leaderboard = load_leaderboard(leaderboard_path)
export_dashboard(
env=env,
evaluation=evaluation,
leaderboard_records=leaderboard,
output_path=dashboard_path,
)
print(json.dumps(evaluation["summary"], indent=2, sort_keys=True))
elif args.cmd == "benchmark":
evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=llm_client)
summary = evaluation["summary"]
record = append_leaderboard_record(
path=leaderboard_path,
summary=summary,
episodes=episodes,
run_name=args.name or None,
config={
"seed": env.config.seed,
"max_steps": env.config.max_steps,
"swarm_enabled": env.config.swarm.enabled,
"max_agents": env.config.swarm.max_agents,
"max_breadth": env.config.swarm.max_breadth,
"max_width": env.config.swarm.max_width,
"max_depth": env.config.swarm.max_depth,
"seeded_questions": len(env.config.seeding.seeded_questions),
},
)
leaderboard = load_leaderboard(leaderboard_path)
dashboard_path = export_dashboard(
env=env,
evaluation=evaluation,
leaderboard_records=leaderboard,
output_path=dashboard_path,
)
_save_evaluation(DEFAULT_EVALUATION_PATH, evaluation)
payload = {
"record": record,
"summary": summary,
"dashboard": dashboard_path,
}
print(json.dumps(payload, indent=2, sort_keys=True))
elif args.cmd == "viz":
evaluation: dict | None = _load_evaluation(evaluation_path)
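        # --with-demo runs one episode and builds a minimal single-episode evaluation from the final state.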
if args.with_demo:
_runner_for(env).run_episode()
info = {
"agent_answer": env.state.answer if env.state else "",
"task_answer": env.state.task.answer if env.state else "",
"total_reward": env.state.total_reward if env.state else 0.0,
"step_count": env.state.step_count if env.state else 0,
"tool_calls": env.state.tool_calls if env.state else 0,
}
evaluation = {
"summary": {
"task_success_rate": float(info["agent_answer"] == info["task_answer"]),
"tool_efficiency": 0.0,
"avg_graph_f1": 0.0,
"avg_steps_to_solution": float(info["step_count"]),
"deanonymization_accuracy": 0.0,
"avg_reward": float(info["total_reward"]),
"leaderboard_score": 0.0,
},
"episodes": [
{
"task_id": env.state.task.task_id if env.state else "n/a",
"task_type": env.state.task.task_type if env.state else "n/a",
"question": env.state.task.question if env.state else "n/a",
"task_answer": str(info["task_answer"]),
"agent_answer": str(info["agent_answer"]),
"graph_f1": 0.0,
"reward": float(info["total_reward"]),
"steps": int(info["step_count"]),
"tool_calls": int(info["tool_calls"]),
"success": int(info["agent_answer"] == info["task_answer"]),
}
],
}
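        # Fallback: without a saved evaluation or a demo episode, derive a minimal summary from the current state.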
graph_f1 = 0.0
if env.state is not None:
graph_f1 = compute_graph_f1(env.memory_graph.edges, env.state.task.supporting_edges)
if evaluation is None:
summary = {
"task_success_rate": 0.0,
"tool_efficiency": 0.0,
"avg_graph_f1": graph_f1,
"avg_steps_to_solution": float(env.state.step_count) if env.state else 0.0,
"deanonymization_accuracy": 0.0,
"avg_reward": float(env.state.total_reward) if env.state else 0.0,
"leaderboard_score": 0.0,
}
evaluation = {"summary": summary, "episodes": []}
leaderboard = load_leaderboard(leaderboard_path)
out = export_dashboard(env=env, evaluation=evaluation, leaderboard_records=leaderboard, output_path=args.output)
print(json.dumps({"dashboard": out, "evaluation": evaluation_path}, indent=2, sort_keys=True))
if __name__ == "__main__":
main()