# chakravyuh/training/run_scripted_baseline.py
# UjjwalPardeshi
# deploy: latest main to HF Space
# 03815d6
"""Day-1 baseline runner: run N episodes with scripted agents and log to WandB.

This establishes the "no-LLM" baseline we aim to beat on Day 2.

Usage:
    python -m training.run_scripted_baseline --episodes 100
    python -m training.run_scripted_baseline --episodes 100 --no-wandb
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from collections import Counter
from pathlib import Path
from chakravyuh_env import ChakravyuhEnv
from chakravyuh_env.schemas import VictimProfile
# Module-level logger for the baseline runner; main() sets the level and
# format through logging.basicConfig before any episode is run.
logger = logging.getLogger("chakravyuh.baseline")
def run_baseline(
    episodes: int,
    seed_base: int = 1000,
    profile: VictimProfile = VictimProfile.SEMI_URBAN,
    gullibility: float = 1.0,
    wandb_project: str | None = "chakravyuh-run-1",
    log_path: Path | None = None,
) -> dict[str, float | int]:
    """Run ``episodes`` scripted episodes and aggregate their outcomes.

    Each episode resets a single ``ChakravyuhEnv`` with seed
    ``seed_base + i`` and steps it (with no action argument) until the
    environment reports ``done``. The final ``info["outcome"]`` and the
    last reward returned by ``env.step()`` are folded into the summary.

    Args:
        episodes: Number of episodes to run. Zero (or a negative value)
            runs nothing and yields rates of ``0.0`` instead of raising
            ``ZeroDivisionError``.
        seed_base: Seed of the first episode; episode ``i`` uses
            ``seed_base + i``.
        profile: Victim profile passed to the environment.
        gullibility: Gullibility value passed to the environment.
        wandb_project: WandB project name. A falsy value disables WandB.
            If WandB cannot be imported or initialised, a warning is
            logged and the run continues without it.
        log_path: When given, a JSON file with the summary and one row
            per episode is written here (parent directories are created).

    Returns:
        A dict with ``episodes``, ``detection_rate``, ``extraction_rate``,
        ``avg_detection_turn`` (``-1`` when nothing was detected),
        ``avg_reward_analyzer``, the raw outcome counters, and one
        ``category/<name>`` count per scam category seen.

    Raises:
        Whatever the environment or the log write raises. If a WandB run
        is open at that point it is finished with a failing exit code
        before the exception propagates.
    """
    wandb_run = None
    if wandb_project:
        # WandB is strictly optional: any failure here (missing package,
        # no credentials, no network) degrades to a local-only run.
        try:
            import wandb  # type: ignore

            wandb_run = wandb.init(
                project=wandb_project,
                name=f"scripted-baseline-n{episodes}",
                config={
                    "episodes": episodes,
                    "profile": profile.value,
                    "gullibility": gullibility,
                    "agents": "all-scripted",
                },
            )
        except Exception as e:
            logger.warning("WandB init failed (%s); continuing without WandB.", e)

    try:
        env = ChakravyuhEnv(victim_profile=profile, gullibility=gullibility)
        category_counts: Counter[str] = Counter()
        outcomes_summary = {
            "money_extracted": 0,
            "victim_refused": 0,
            "analyzer_flagged": 0,
            "bank_flagged": 0,
            "bank_froze": 0,
            "sought_verification": 0,
        }
        detection_turns: list[int] = []
        rewards_analyzer: list[float] = []
        log_rows: list[dict] = []

        for i in range(episodes):
            env.reset(seed=seed_base + i)
            done = False
            reward = None
            info: dict = {}
            while not done:
                # Only the values from the final step are used below.
                _obs, reward, done, info = env.step()

            outcome = info["outcome"]
            category = outcome.scam_category.value
            detected_turn = outcome.detected_by_turn
            has_reward = reward is not None

            category_counts[category] += 1
            outcomes_summary["money_extracted"] += int(outcome.money_extracted)
            outcomes_summary["victim_refused"] += int(outcome.victim_refused)
            outcomes_summary["analyzer_flagged"] += int(outcome.analyzer_flagged)
            outcomes_summary["bank_flagged"] += int(outcome.bank_flagged)
            outcomes_summary["bank_froze"] += int(outcome.bank_froze)
            outcomes_summary["sought_verification"] += int(
                outcome.victim_sought_verification
            )

            if detected_turn is not None:
                detection_turns.append(detected_turn)
            if has_reward:
                rewards_analyzer.append(reward.analyzer)

            if wandb_run is not None:
                payload = {
                    "ep": i,
                    "detection/flagged": int(outcome.analyzer_flagged),
                    # Compare against None explicitly so a detection at
                    # turn 0 is not reported as "never detected" (-1).
                    "detection/turn": (
                        detected_turn if detected_turn is not None else -1
                    ),
                    "outcome/money_extracted": int(outcome.money_extracted),
                }
                # Reward metrics are only logged when the final step
                # actually produced a reward object.
                if has_reward:
                    payload["reward/analyzer"] = reward.analyzer
                    payload["reward/scammer"] = reward.scammer
                wandb_run.log(payload)

            log_rows.append(
                {
                    "ep": i,
                    "seed": seed_base + i,
                    "category": category,
                    "analyzer_flagged": outcome.analyzer_flagged,
                    "detected_by_turn": detected_turn,
                    "money_extracted": outcome.money_extracted,
                    "victim_refused": outcome.victim_refused,
                    "reward_analyzer": reward.analyzer if has_reward else None,
                    "reward_scammer": reward.scammer if has_reward else None,
                }
            )

        # Guard the rate denominators so an empty run yields 0.0 rather
        # than a ZeroDivisionError.
        if episodes > 0:
            detection_rate = outcomes_summary["analyzer_flagged"] / episodes
            extraction_rate = outcomes_summary["money_extracted"] / episodes
        else:
            detection_rate = 0.0
            extraction_rate = 0.0
        avg_detection_turn = (
            sum(detection_turns) / len(detection_turns) if detection_turns else -1
        )
        avg_reward_analyzer = (
            sum(rewards_analyzer) / len(rewards_analyzer) if rewards_analyzer else 0.0
        )

        summary = {
            "episodes": episodes,
            "detection_rate": round(detection_rate, 4),
            "extraction_rate": round(extraction_rate, 4),
            "avg_detection_turn": round(avg_detection_turn, 2),
            "avg_reward_analyzer": round(avg_reward_analyzer, 4),
            **outcomes_summary,
            **{f"category/{k}": v for k, v in category_counts.items()},
        }

        logger.info("=== Baseline Summary ===")
        for k, v in summary.items():
            logger.info(" %s: %s", k, v)

        if log_path is not None:
            log_path.parent.mkdir(parents=True, exist_ok=True)
            log_path.write_text(
                json.dumps({"summary": summary, "rows": log_rows}, indent=2),
                encoding="utf-8",
            )
            logger.info("Wrote log to %s", log_path)

        if wandb_run is not None:
            wandb_run.summary.update(summary)
    except BaseException:
        # Do not leave the WandB run open when the baseline fails part-way;
        # mark it as failed and let the original exception propagate.
        if wandb_run is not None:
            wandb_run.finish(exit_code=1)
        raise

    if wandb_run is not None:
        wandb_run.finish()
    return summary
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(description="Scripted baseline runner")
parser.add_argument("--episodes", type=int, default=100)
parser.add_argument("--seed-base", type=int, default=1000)
parser.add_argument(
"--profile",
type=str,
default="semi_urban",
choices=["senior", "young_urban", "semi_urban"],
)
parser.add_argument("--gullibility", type=float, default=1.0)
parser.add_argument("--no-wandb", action="store_true")
parser.add_argument("--log-path", type=Path, default=Path("logs/baseline_day1.json"))
args = parser.parse_args(argv)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
profile = VictimProfile(args.profile)
run_baseline(
episodes=args.episodes,
seed_base=args.seed_base,
profile=profile,
gullibility=args.gullibility,
wandb_project=None if args.no_wandb else "chakravyuh-run-1",
log_path=args.log_path,
)
return 0
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    exit_status = main()
    sys.exit(exit_status)