| """ | |
| Comparative Benchmarking Suite | |
| =============================== | |
| Runs Random, Heuristic, and LLM agents across 100 episodes each | |
| and compares performance across difficulty levels. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import time | |
| from collections import defaultdict | |
| from immunoorg.environment import ImmunoOrgEnvironment | |
| from immunoorg.agents.llm_agent import ImmunoDefenderAgent | |
| from immunoorg.agents.baseline_agents import RandomAgent, HeuristicAgent | |
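
# This script only relies on the interfaces exercised below (all inferred from
# the calls in this file rather than from the package's documentation):
#   - ImmunoOrgEnvironment(difficulty=..., seed=...), .reset() -> obs,
#     .step(action) -> (obs, reward, terminated), plus .state and .belief_map
#   - agent constructors accepting seed=... (and model_path=... for the LLM
#     agent), with .act(obs) -> action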


class BenchmarkSuite:
    def __init__(self, num_episodes: int = 100):
        self.num_episodes = num_episodes
        self.results = defaultdict(list)

    def run_benchmark(self, model_path: str | None = None) -> dict:
        """Run benchmark across all agents and difficulty levels."""
        # Baseline agents are plain classes; the LLM agent is wrapped in a
        # lambda so model_path is bound while keeping the same call signature.
        agents = {
            "random": RandomAgent,
            "heuristic": HeuristicAgent,
            "llm": (lambda seed: ImmunoDefenderAgent(seed=seed, model_path=model_path)),
        }
        difficulties = [1, 2]  # Keep it fast: levels 1-2 only

        print("Starting Comparative Benchmark...")
        print(f"Episodes per config: {self.num_episodes}")
        print(f"Agents: {list(agents.keys())}")
        print(f"Difficulties: {difficulties}\n")

        for difficulty in difficulties:
            print(f"\n=== Difficulty Level {difficulty} ===")
            for agent_name, agent_init in agents.items():
                print(f" {agent_name.upper()}...", end="", flush=True)
                self.run_agent_episodes(agent_name, agent_init, difficulty)
                print(" DONE")

        return self._aggregate_results()

    def run_agent_episodes(self, agent_name: str, agent_init, difficulty: int):
        """Run n episodes for a given agent at a given difficulty."""
        for ep in range(self.num_episodes):
            env = ImmunoOrgEnvironment(difficulty=difficulty, seed=ep)
            agent = agent_init(seed=ep)
            obs = env.reset()

            terminated = False
            steps = 0
            max_steps = 200
            while not terminated and steps < max_steps:
                action = agent.act(obs)
                obs, reward, terminated = env.step(action)
                steps += 1

            # Collect metrics
            metrics = {
                "reward": env.state.cumulative_reward,
                "steps_to_termination": steps,
                "final_threat_level": env.state.threat_level,
                "total_downtime": env.state.total_downtime,
                "belief_accuracy": env.belief_map.calculate_belief_accuracy() if env.belief_map else 0.0,
                "org_chaos_score": env.state.org_chaos_score,
                "attacks_contained": len(env.state.contained_attacks),
                "attacks_active": len(env.state.active_attacks),
            }
            self.results[f"d{difficulty}_{agent_name}"].append(metrics)

    def _aggregate_results(self) -> dict:
        """Aggregate and compute statistics."""
        aggregated = {}
        for key, metrics_list in self.results.items():
            if not metrics_list:
                continue
            # Compute stats
            stats = {
                "count": len(metrics_list),
                "avg_reward": sum(m["reward"] for m in metrics_list) / len(metrics_list),
                "max_reward": max(m["reward"] for m in metrics_list),
                "min_reward": min(m["reward"] for m in metrics_list),
                "avg_steps": sum(m["steps_to_termination"] for m in metrics_list) / len(metrics_list),
                "avg_downtime": sum(m["total_downtime"] for m in metrics_list) / len(metrics_list),
                "avg_belief_accuracy": sum(m["belief_accuracy"] for m in metrics_list) / len(metrics_list),
                "avg_contained": sum(m["attacks_contained"] for m in metrics_list) / len(metrics_list),
                "win_rate": sum(1 for m in metrics_list if m["attacks_active"] == 0) / len(metrics_list),
            }
            aggregated[key] = stats
        return aggregated
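
    # A sketch of the aggregated mapping returned above and written by
    # save_results(); keys come from run_agent_episodes, values are the stats
    # computed in _aggregate_results (placeholders, not measured numbers):
    #   {
    #       "d1_random": {"count": <episodes>, "avg_reward": <float>, ..., "win_rate": <0.0-1.0>},
    #       "d2_llm": {"count": <episodes>, "avg_reward": <float>, ..., "win_rate": <0.0-1.0>},
    #   }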

    def save_results(self, filename: str = "benchmark_results.json"):
        """Save results to JSON."""
        aggregated = self._aggregate_results()
        with open(filename, "w") as f:
            json.dump(aggregated, f, indent=2)
        print(f"\nResults saved to {filename}")

    def print_summary(self):
        """Print a human-readable summary."""
        aggregated = self._aggregate_results()
        print("\n" + "=" * 80)
        print("BENCHMARK RESULTS SUMMARY")
        print("=" * 80)
        for difficulty in [1, 2]:
            print(f"\nDifficulty Level {difficulty}:")
            print("-" * 80)
            for agent_name in ["random", "heuristic", "llm"]:
                key = f"d{difficulty}_{agent_name}"
                if key not in aggregated:
                    continue
                stats = aggregated[key]
                print(f"\n {agent_name.upper()}")
                print(f" Episodes: {stats['count']}")
                print(f" Avg Reward: {stats['avg_reward']:+.2f} (range: {stats['min_reward']:+.2f} to {stats['max_reward']:+.2f})")
                print(f" Avg Steps: {stats['avg_steps']:.1f}")
                print(f" Win Rate: {stats['win_rate']:.0%}")
                print(f" Avg Belief Accuracy: {stats['avg_belief_accuracy']:.1%}")
                print(f" Avg Attacks Contained: {stats['avg_contained']:.1f}")
                print(f" Avg Downtime: {stats['avg_downtime']:.1f}")


if __name__ == "__main__":
    import sys

    num_episodes = int(sys.argv[1]) if len(sys.argv) > 1 else 100
    model_path = sys.argv[2] if len(sys.argv) > 2 else None

    suite = BenchmarkSuite(num_episodes=num_episodes)

    start = time.time()
    suite.run_benchmark(model_path=model_path)
    elapsed = time.time() - start

    suite.print_summary()
    suite.save_results()
    print(f"\nBenchmark completed in {elapsed:.1f}s")