Spaces:
Running
Running
| """Day-1 baseline runner: run N episodes with scripted agents and log to WandB. | |
| This establishes the "no-LLM" baseline we beat on Day 2. | |
| Usage: | |
| python -m training.run_scripted_baseline --episodes 100 | |
| python -m training.run_scripted_baseline --episodes 100 --no-wandb | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from collections import Counter | |
| from pathlib import Path | |
| from chakravyuh_env import ChakravyuhEnv | |
| from chakravyuh_env.schemas import VictimProfile | |
| logger = logging.getLogger("chakravyuh.baseline") | |
| def run_baseline( | |
| episodes: int, | |
| seed_base: int = 1000, | |
| profile: VictimProfile = VictimProfile.SEMI_URBAN, | |
| gullibility: float = 1.0, | |
| wandb_project: str | None = "chakravyuh-run-1", | |
| log_path: Path | None = None, | |
| ) -> dict[str, float | int]: | |
| wandb_run = None | |
| if wandb_project: | |
| try: | |
| import wandb # type: ignore | |
| wandb_run = wandb.init( | |
| project=wandb_project, | |
| name=f"scripted-baseline-n{episodes}", | |
| config={ | |
| "episodes": episodes, | |
| "profile": profile.value, | |
| "gullibility": gullibility, | |
| "agents": "all-scripted", | |
| }, | |
| ) | |
| except Exception as e: | |
| logger.warning("WandB init failed (%s); continuing without WandB.", e) | |
| env = ChakravyuhEnv(victim_profile=profile, gullibility=gullibility) | |
| category_counts: Counter[str] = Counter() | |
| outcomes_summary = { | |
| "money_extracted": 0, | |
| "victim_refused": 0, | |
| "analyzer_flagged": 0, | |
| "bank_flagged": 0, | |
| "bank_froze": 0, | |
| "sought_verification": 0, | |
| } | |
| detection_turns: list[int] = [] | |
| rewards_analyzer: list[float] = [] | |
| rewards_scammer: list[float] = [] | |
| log_rows: list[dict] = [] | |
| for i in range(episodes): | |
| obs = env.reset(seed=seed_base + i) | |
| done = False | |
| reward = None | |
| info: dict = {} | |
| while not done: | |
| obs, reward, done, info = env.step() | |
| outcome = info["outcome"] | |
| category_counts[outcome.scam_category.value] += 1 | |
| outcomes_summary["money_extracted"] += int(outcome.money_extracted) | |
| outcomes_summary["victim_refused"] += int(outcome.victim_refused) | |
| outcomes_summary["analyzer_flagged"] += int(outcome.analyzer_flagged) | |
| outcomes_summary["bank_flagged"] += int(outcome.bank_flagged) | |
| outcomes_summary["bank_froze"] += int(outcome.bank_froze) | |
| outcomes_summary["sought_verification"] += int(outcome.victim_sought_verification) | |
| if outcome.detected_by_turn is not None: | |
| detection_turns.append(outcome.detected_by_turn) | |
| if reward is not None: | |
| rewards_analyzer.append(reward.analyzer) | |
| rewards_scammer.append(reward.scammer) | |
| if wandb_run is not None: | |
| wandb_run.log( | |
| { | |
| "ep": i, | |
| "reward/analyzer": reward.analyzer, | |
| "reward/scammer": reward.scammer, | |
| "detection/flagged": int(outcome.analyzer_flagged), | |
| "detection/turn": outcome.detected_by_turn or -1, | |
| "outcome/money_extracted": int(outcome.money_extracted), | |
| } | |
| ) | |
| log_rows.append( | |
| { | |
| "ep": i, | |
| "seed": seed_base + i, | |
| "category": outcome.scam_category.value, | |
| "analyzer_flagged": outcome.analyzer_flagged, | |
| "detected_by_turn": outcome.detected_by_turn, | |
| "money_extracted": outcome.money_extracted, | |
| "victim_refused": outcome.victim_refused, | |
| "reward_analyzer": reward.analyzer if reward else None, | |
| "reward_scammer": reward.scammer if reward else None, | |
| } | |
| ) | |
| detection_rate = outcomes_summary["analyzer_flagged"] / episodes | |
| extraction_rate = outcomes_summary["money_extracted"] / episodes | |
| avg_detection_turn = ( | |
| sum(detection_turns) / len(detection_turns) if detection_turns else -1 | |
| ) | |
| avg_reward_analyzer = ( | |
| sum(rewards_analyzer) / len(rewards_analyzer) if rewards_analyzer else 0.0 | |
| ) | |
| summary = { | |
| "episodes": episodes, | |
| "detection_rate": round(detection_rate, 4), | |
| "extraction_rate": round(extraction_rate, 4), | |
| "avg_detection_turn": round(avg_detection_turn, 2), | |
| "avg_reward_analyzer": round(avg_reward_analyzer, 4), | |
| **outcomes_summary, | |
| **{f"category/{k}": v for k, v in category_counts.items()}, | |
| } | |
| logger.info("=== Baseline Summary ===") | |
| for k, v in summary.items(): | |
| logger.info(" %s: %s", k, v) | |
| if log_path is not None: | |
| log_path.parent.mkdir(parents=True, exist_ok=True) | |
| log_path.write_text(json.dumps({"summary": summary, "rows": log_rows}, indent=2)) | |
| logger.info("Wrote log to %s", log_path) | |
| if wandb_run is not None: | |
| wandb_run.summary.update(summary) | |
| wandb_run.finish() | |
| return summary | |
| def main(argv: list[str] | None = None) -> int: | |
| parser = argparse.ArgumentParser(description="Scripted baseline runner") | |
| parser.add_argument("--episodes", type=int, default=100) | |
| parser.add_argument("--seed-base", type=int, default=1000) | |
| parser.add_argument( | |
| "--profile", | |
| type=str, | |
| default="semi_urban", | |
| choices=["senior", "young_urban", "semi_urban"], | |
| ) | |
| parser.add_argument("--gullibility", type=float, default=1.0) | |
| parser.add_argument("--no-wandb", action="store_true") | |
| parser.add_argument("--log-path", type=Path, default=Path("logs/baseline_day1.json")) | |
| args = parser.parse_args(argv) | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", | |
| ) | |
| profile = VictimProfile(args.profile) | |
| run_baseline( | |
| episodes=args.episodes, | |
| seed_base=args.seed_base, | |
| profile=profile, | |
| gullibility=args.gullibility, | |
| wandb_project=None if args.no_wandb else "chakravyuh-run-1", | |
| log_path=args.log_path, | |
| ) | |
| return 0 | |
| if __name__ == "__main__": | |
| sys.exit(main()) | |