from __future__ import annotations

import argparse
import json
from pathlib import Path

from osint_env.agents.single_agent import SingleAgentRunner
from osint_env.agents.swarm_agent import SwarmAgentRunner
from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config
from osint_env.domain.models import EnvironmentConfig
from osint_env.env.environment import OSINTEnvironment
from osint_env.env.reward import compute_graph_f1
from osint_env.eval.leaderboard import append_leaderboard_record, load_leaderboard, render_leaderboard_table
from osint_env.eval.runner import run_evaluation
from osint_env.llm import build_llm_client
from osint_env.viz import export_dashboard

DEFAULT_EVALUATION_PATH = "artifacts/latest_evaluation.json"
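
# Example invocations (a sketch, not an exhaustive reference; these assume the module is
# exposed as the `osint-env` console script named via `prog` below -- substitute
# `python -m <package>` for your own checkout):
#   osint-env demo --llm-provider mock
#   osint-env eval --episodes 10 --agent-mode swarm
#   osint-env benchmark --name baseline --llm-provider ollama
#   osint-env benchmark-sweep --seeds 7,11,17 --name-prefix sweep
#   osint-env leaderboard --top 10 --sort-by avg_graph_f1
#   osint-env viz --with-demo --output artifacts/osint_explorer.html
#   osint-env train-self-play --dry-run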

def _save_evaluation(path: str, payload: dict) -> None:
    """Write an evaluation payload to disk as pretty-printed JSON, creating parent dirs as needed."""
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")

def _load_evaluation(path: str) -> dict | None:
    """Return a previously saved evaluation payload, or None if it is missing or unreadable."""
    file_path = Path(path)
    if not file_path.exists():
        return None
    try:
        data = json.loads(file_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    return data
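
# For reference, a saved evaluation payload is expected to look roughly like the dict that the
# `viz --with-demo` branch in main() synthesizes (keys inferred from this file, not a formal schema):
#   {
#     "summary": {"task_success_rate": ..., "avg_graph_f1": ..., "avg_reward": ..., ...},
#     "episodes": [{"task_id": ..., "agent_answer": ..., "reward": ..., "steps": ..., ...}],
#   }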

def _add_common_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--config", type=str, default="config/shared_config.json")
    parser.add_argument("--seed-file", type=str, default="")
    parser.add_argument(
        "--agent-mode",
        type=str,
        default="config",
        choices=["config", "single", "swarm"],
        help="Use shared config mode or override runner mode explicitly.",
    )
    parser.add_argument(
        "--llm-provider",
        type=str,
        default="config",
        choices=["config", "mock", "ollama", "openai"],
        help="Use shared config provider or override explicitly.",
    )
    parser.add_argument("--llm-model", type=str, default="", help="Override model name for selected LLM provider.")
    parser.add_argument("--llm-timeout-seconds", type=int, default=0, help="Override LLM request timeout in seconds.")
    parser.add_argument("--ollama-base-url", type=str, default="", help="Override Ollama base URL.")
    parser.add_argument("--openai-base-url", type=str, default="", help="Override OpenAI base URL.")
    parser.add_argument("--openai-api-key", type=str, default="", help="OpenAI API key override.")
    parser.add_argument(
        "--openai-api-key-env",
        type=str,
        default="",
        help="Environment variable name for OpenAI API key.",
    )
    parser.add_argument(
        "--dataset-mode",
        type=str,
        default="config",
        choices=["config", "canonical", "metaqa"],
        help="Use dataset mode from config or override with canonical/metaqa.",
    )
    parser.add_argument("--metaqa-root", type=str, default="", help="Override MetaQA dataset root directory.")
    parser.add_argument(
        "--metaqa-kb-path",
        type=str,
        default="",
        help="Override MetaQA KB triples file path. Defaults to <metaqa-root>/kb.txt.",
    )
    parser.add_argument(
        "--metaqa-variant",
        type=str,
        default="",
        choices=["", "vanilla", "ntm"],
        help="Override MetaQA QA variant.",
    )
    parser.add_argument(
        "--metaqa-hops",
        type=str,
        default="",
        help="Comma-separated hop buckets for MetaQA mode (example: 1-hop,2-hop,3-hop).",
    )
    parser.add_argument(
        "--metaqa-splits",
        type=str,
        default="",
        help="Comma-separated splits for MetaQA mode (example: train,dev,test).",
    )

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="osint-env")
    sub = parser.add_subparsers(dest="cmd", required=True)

    d = sub.add_parser("demo", help="Run one episode and print debug info.")
    _add_common_args(d)

    e = sub.add_parser("eval", help="Run multiple episodes and show aggregate metrics.")
    _add_common_args(e)
    e.add_argument("--episodes", type=int, default=0)
    e.add_argument("--dashboard", type=str, default="")

    b = sub.add_parser("benchmark", help="Run eval, update leaderboard, and export interactive dashboard.")
    _add_common_args(b)
    b.add_argument("--episodes", type=int, default=0)
    b.add_argument("--name", type=str, default="")
    b.add_argument("--leaderboard", type=str, default="")
    b.add_argument("--dashboard", type=str, default="")

    l = sub.add_parser("leaderboard", help="Print ranked benchmark leaderboard.")
    _add_common_args(l)
    l.add_argument("--leaderboard", type=str, default="")
    l.add_argument("--top", type=int, default=20)
    l.add_argument(
        "--sort-by",
        type=str,
        default="leaderboard_score",
        choices=[
            "leaderboard_score",
            "task_success_rate",
            "avg_graph_f1",
            "tool_efficiency",
            "avg_reward",
            "retrieval_signal",
            "structural_signal",
            "deanonymization_accuracy",
            "spawn_signal",
        ],
    )

    s = sub.add_parser("benchmark-sweep", help="Run benchmark across multiple seeds and append all runs to leaderboard.")
    _add_common_args(s)
    s.add_argument("--episodes", type=int, default=0)
    s.add_argument("--seeds", type=str, default="7,11,17,23,31")
    s.add_argument("--name-prefix", type=str, default="sweep")
    s.add_argument("--leaderboard", type=str, default="")
    s.add_argument("--dashboard-dir", type=str, default="")

    v = sub.add_parser("viz", help="Export an interactive graph/database explorer.")
    _add_common_args(v)
    v.add_argument("--output", type=str, default="artifacts/osint_explorer.html")
    v.add_argument("--with-demo", action="store_true")
    v.add_argument("--leaderboard", type=str, default="")
    v.add_argument(
        "--evaluation",
        type=str,
        default=DEFAULT_EVALUATION_PATH,
        help="Path to a saved evaluation payload with episode details.",
    )

    t = sub.add_parser(
        "train-self-play",
        help="Run adversarial self-play fine-tuning scaffold with Hugging Face TRL (Kimi-style alternating phases).",
    )
    _add_common_args(t)
    t.add_argument(
        "--train-config",
        type=str,
        default="config/self_play_training_example.json",
        help="Path to self-play training JSON config.",
    )
    t.add_argument(
        "--train-output-dir",
        type=str,
        default="",
        help="Optional output dir override for self-play artifacts and checkpoints.",
    )
    t.add_argument(
        "--train-rounds",
        type=int,
        default=0,
        help="Optional override for the number of self-play rounds.",
    )
    t.add_argument(
        "--dry-run",
        action="store_true",
        help="Skip actual GRPO updates and only materialize datasets/round artifacts.",
    )
    return parser
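
# A minimal sketch of exercising the parser programmatically, e.g. in a test
# (the argument values here are illustrative only):
#   args = build_parser().parse_args(["eval", "--episodes", "3", "--llm-provider", "mock"])
#   assert args.cmd == "eval" and args.episodes == 3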

def _resolve_environment_config(args: argparse.Namespace) -> tuple[EnvironmentConfig, dict[str, str | int]]:
    """Apply CLI overrides to the shared environment config and collect runtime defaults."""
    shared = load_shared_config(args.config)
    env_cfg = clone_environment_config(shared.environment)
    if args.seed_file:
        env_cfg.seeding = load_seeding_config(args.seed_file)
    if args.llm_provider != "config":
        env_cfg.llm.provider = args.llm_provider
    if args.llm_model:
        env_cfg.llm.model = args.llm_model
    if int(args.llm_timeout_seconds) > 0:
        env_cfg.llm.timeout_seconds = int(args.llm_timeout_seconds)
    if args.ollama_base_url:
        env_cfg.llm.ollama_base_url = args.ollama_base_url
    if args.openai_base_url:
        env_cfg.llm.openai_base_url = args.openai_base_url
    if args.openai_api_key:
        env_cfg.llm.openai_api_key = args.openai_api_key
    if args.openai_api_key_env:
        env_cfg.llm.openai_api_key_env = args.openai_api_key_env
    if args.dataset_mode != "config":
        env_cfg.dataset_mode = args.dataset_mode
    if args.metaqa_root:
        env_cfg.metaqa_root = args.metaqa_root
    if args.metaqa_kb_path:
        env_cfg.metaqa_kb_path = args.metaqa_kb_path
    if args.metaqa_variant:
        env_cfg.metaqa_variant = args.metaqa_variant
    if args.metaqa_hops:
        env_cfg.metaqa_hops = [item.strip() for item in str(args.metaqa_hops).split(",") if item.strip()]
    if args.metaqa_splits:
        env_cfg.metaqa_splits = [item.strip() for item in str(args.metaqa_splits).split(",") if item.strip()]
    if args.agent_mode == "single":
        env_cfg.swarm.enabled = False
    elif args.agent_mode == "swarm":
        env_cfg.swarm.enabled = True
    runtime = {
        "default_episodes": shared.runtime.default_episodes,
        "leaderboard_path": shared.runtime.leaderboard_path,
        "dashboard_path": shared.runtime.dashboard_path,
        "sweep_dashboard_dir": shared.runtime.sweep_dashboard_dir,
    }
    return env_cfg, runtime
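
# Precedence is CLI flag over shared config: for example, `--agent-mode single` forces
# `env_cfg.swarm.enabled = False` even if the shared config enables the swarm, while the
# sentinel default "config" leaves the shared setting untouched.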

def _runner_for(env: OSINTEnvironment) -> SingleAgentRunner | SwarmAgentRunner:
    """Pick the swarm or single-agent runner based on the resolved environment config."""
    if env.config.swarm.enabled:
        return SwarmAgentRunner(env, llm=build_llm_client(env.config.llm))
    return SingleAgentRunner(env, llm=build_llm_client(env.config.llm))

def main() -> None:
    args = build_parser().parse_args()
    env_cfg, runtime = _resolve_environment_config(args)
    episodes = int(args.episodes) if getattr(args, "episodes", 0) else int(runtime["default_episodes"])
    leaderboard_path = str(args.leaderboard) if getattr(args, "leaderboard", "") else str(runtime["leaderboard_path"])
    dashboard_path = str(args.dashboard) if getattr(args, "dashboard", "") else str(runtime["dashboard_path"])
    sweep_dashboard_dir = (
        str(args.dashboard_dir) if getattr(args, "dashboard_dir", "") else str(runtime["sweep_dashboard_dir"])
    )
    evaluation_path = str(getattr(args, "evaluation", "") or DEFAULT_EVALUATION_PATH)

    if args.cmd == "leaderboard":
        records = load_leaderboard(leaderboard_path)
        print(render_leaderboard_table(records, top_k=args.top, sort_by=args.sort_by))
        return
| if args.cmd == "benchmark-sweep": | |
| seed_values = [int(x.strip()) for x in args.seeds.split(",") if x.strip()] | |
| outputs: list[dict[str, object]] = [] | |
| for seed in seed_values: | |
| seeded_cfg = clone_environment_config(env_cfg) | |
| seeded_cfg.seed = seed | |
| env = OSINTEnvironment(seeded_cfg, llm=build_llm_client(seeded_cfg.llm)) | |
| evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=build_llm_client(seeded_cfg.llm)) | |
| summary = evaluation["summary"] | |
| run_name = f"{args.name_prefix}_seed{seed}" | |
| record = append_leaderboard_record( | |
| path=leaderboard_path, | |
| summary=summary, | |
| episodes=episodes, | |
| run_name=run_name, | |
| config={ | |
| "seed": seed, | |
| "max_steps": env.config.max_steps, | |
| "swarm_enabled": env.config.swarm.enabled, | |
| "max_agents": env.config.swarm.max_agents, | |
| "max_breadth": env.config.swarm.max_breadth, | |
| "max_width": env.config.swarm.max_width, | |
| "max_depth": env.config.swarm.max_depth, | |
| "seeded_questions": len(env.config.seeding.seeded_questions), | |
| }, | |
| ) | |
| dashboard_path = export_dashboard( | |
| env=env, | |
| evaluation=evaluation, | |
| leaderboard_records=load_leaderboard(leaderboard_path), | |
| output_path=f"{sweep_dashboard_dir}/{run_name}.html", | |
| ) | |
| _save_evaluation(DEFAULT_EVALUATION_PATH, evaluation) | |
| outputs.append({"seed": seed, "record": record, "dashboard": dashboard_path, "summary": summary}) | |
| records = load_leaderboard(leaderboard_path) | |
| print( | |
| json.dumps( | |
| { | |
| "runs": outputs, | |
| "leaderboard_preview": render_leaderboard_table(records, top_k=min(10, len(records))), | |
| }, | |
| indent=2, | |
| sort_keys=True, | |
| ) | |
| ) | |
| return | |
| if args.cmd == "train-self-play": | |
| from osint_env.training import load_self_play_config, run_adversarial_self_play | |
| train_cfg = load_self_play_config(args.train_config) | |
| if str(args.train_output_dir).strip(): | |
| train_cfg.output_dir = str(args.train_output_dir).strip() | |
| if int(args.train_rounds) > 0: | |
| train_cfg.rounds = int(args.train_rounds) | |
| payload = run_adversarial_self_play( | |
| env_config=env_cfg, | |
| training_config=train_cfg, | |
| dry_run=bool(args.dry_run), | |
| ) | |
| print(json.dumps(payload, indent=2, sort_keys=True)) | |
| return | |
    llm_client = build_llm_client(env_cfg.llm)
    env = OSINTEnvironment(env_cfg, llm=llm_client)

    if args.cmd == "demo":
        info = _runner_for(env).run_episode()
        print(json.dumps(info, indent=2, sort_keys=True))
    elif args.cmd == "eval":
        evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=llm_client)
        _save_evaluation(DEFAULT_EVALUATION_PATH, evaluation)
        leaderboard = load_leaderboard(leaderboard_path)
        export_dashboard(
            env=env,
            evaluation=evaluation,
            leaderboard_records=leaderboard,
            output_path=dashboard_path,
        )
        print(json.dumps(evaluation["summary"], indent=2, sort_keys=True))
    elif args.cmd == "benchmark":
        evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=llm_client)
        summary = evaluation["summary"]
        record = append_leaderboard_record(
            path=leaderboard_path,
            summary=summary,
            episodes=episodes,
            run_name=args.name or None,
            config={
                "seed": env.config.seed,
                "max_steps": env.config.max_steps,
                "swarm_enabled": env.config.swarm.enabled,
                "max_agents": env.config.swarm.max_agents,
                "max_breadth": env.config.swarm.max_breadth,
                "max_width": env.config.swarm.max_width,
                "max_depth": env.config.swarm.max_depth,
                "seeded_questions": len(env.config.seeding.seeded_questions),
            },
        )
        leaderboard = load_leaderboard(leaderboard_path)
        dashboard_path = export_dashboard(
            env=env,
            evaluation=evaluation,
            leaderboard_records=leaderboard,
            output_path=dashboard_path,
        )
        _save_evaluation(DEFAULT_EVALUATION_PATH, evaluation)
        payload = {
            "record": record,
            "summary": summary,
            "dashboard": dashboard_path,
        }
        print(json.dumps(payload, indent=2, sort_keys=True))
| elif args.cmd == "viz": | |
| evaluation: dict | None = _load_evaluation(evaluation_path) | |
| if args.with_demo: | |
| _runner_for(env).run_episode() | |
| info = { | |
| "agent_answer": env.state.answer if env.state else "", | |
| "task_answer": env.state.task.answer if env.state else "", | |
| "total_reward": env.state.total_reward if env.state else 0.0, | |
| "step_count": env.state.step_count if env.state else 0, | |
| "tool_calls": env.state.tool_calls if env.state else 0, | |
| } | |
| evaluation = { | |
| "summary": { | |
| "task_success_rate": float(info["agent_answer"] == info["task_answer"]), | |
| "tool_efficiency": 0.0, | |
| "avg_graph_f1": 0.0, | |
| "avg_steps_to_solution": float(info["step_count"]), | |
| "deanonymization_accuracy": 0.0, | |
| "avg_reward": float(info["total_reward"]), | |
| "leaderboard_score": 0.0, | |
| }, | |
| "episodes": [ | |
| { | |
| "task_id": env.state.task.task_id if env.state else "n/a", | |
| "task_type": env.state.task.task_type if env.state else "n/a", | |
| "question": env.state.task.question if env.state else "n/a", | |
| "task_answer": str(info["task_answer"]), | |
| "agent_answer": str(info["agent_answer"]), | |
| "graph_f1": 0.0, | |
| "reward": float(info["total_reward"]), | |
| "steps": int(info["step_count"]), | |
| "tool_calls": int(info["tool_calls"]), | |
| "success": int(info["agent_answer"] == info["task_answer"]), | |
| } | |
| ], | |
| } | |
| graph_f1 = 0.0 | |
| if env.state is not None: | |
| graph_f1 = compute_graph_f1(env.memory_graph.edges, env.state.task.supporting_edges) | |
| if evaluation is None: | |
| summary = { | |
| "task_success_rate": 0.0, | |
| "tool_efficiency": 0.0, | |
| "avg_graph_f1": graph_f1, | |
| "avg_steps_to_solution": float(env.state.step_count) if env.state else 0.0, | |
| "deanonymization_accuracy": 0.0, | |
| "avg_reward": float(env.state.total_reward) if env.state else 0.0, | |
| "leaderboard_score": 0.0, | |
| } | |
| evaluation = {"summary": summary, "episodes": []} | |
| leaderboard = load_leaderboard(leaderboard_path) | |
| out = export_dashboard(env=env, evaluation=evaluation, leaderboard_records=leaderboard, output_path=args.output) | |
| print(json.dumps({"dashboard": out, "evaluation": evaluation_path}, indent=2, sort_keys=True)) | |

if __name__ == "__main__":
    main()