File size: 6,332 Bytes
03815d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Day-1 baseline runner: run N episodes with scripted agents and log to WandB.

This establishes the "no-LLM" baseline we beat on Day 2.

Usage:
    python -m training.run_scripted_baseline --episodes 100
    python -m training.run_scripted_baseline --episodes 100 --no-wandb
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from collections import Counter
from pathlib import Path

from chakravyuh_env import ChakravyuhEnv
from chakravyuh_env.schemas import VictimProfile

logger = logging.getLogger("chakravyuh.baseline")


def run_baseline(
    episodes: int,
    seed_base: int = 1000,
    profile: VictimProfile = VictimProfile.SEMI_URBAN,
    gullibility: float = 1.0,
    wandb_project: str | None = "chakravyuh-run-1",
    log_path: Path | None = None,
) -> dict[str, float | int]:
    """Run scripted-agent episodes and aggregate their outcomes.

    Each episode resets a ``ChakravyuhEnv`` with seed ``seed_base + i`` and
    calls ``env.step()`` (no action argument) until the env reports ``done``.
    The final ``info["outcome"]`` of every episode is tallied.

    Args:
        episodes: Number of episodes to run. ``0`` is allowed and yields a
            summary with zero rates.
        seed_base: Seed of the first episode; episode ``i`` uses
            ``seed_base + i``.
        profile: Victim profile passed to the environment.
        gullibility: Gullibility value passed to the environment.
        wandb_project: WandB project to log to. A falsy value disables WandB;
            a failed WandB init is logged as a warning and the run continues
            without it.
        log_path: If given, a JSON file with ``{"summary": ..., "rows": ...}``
            is written there (parent directories are created).

    Returns:
        A summary dict with ``episodes``, ``detection_rate``,
        ``extraction_rate``, ``avg_detection_turn`` (``-1`` when nothing was
        detected), ``avg_reward_analyzer``, the raw outcome counters, and one
        ``category/<name>`` count per scam category seen.

    Raises:
        ValueError: If ``episodes`` is negative.
    """
    if episodes < 0:
        raise ValueError(f"episodes must be non-negative, got {episodes}")

    wandb_run = None
    if wandb_project:
        try:
            import wandb  # type: ignore

            wandb_run = wandb.init(
                project=wandb_project,
                name=f"scripted-baseline-n{episodes}",
                config={
                    "episodes": episodes,
                    "profile": profile.value,
                    "gullibility": gullibility,
                    "agents": "all-scripted",
                },
            )
        except Exception as e:
            # WandB is best-effort: the baseline must still run offline.
            logger.warning("WandB init failed (%s); continuing without WandB.", e)

    try:
        env = ChakravyuhEnv(victim_profile=profile, gullibility=gullibility)
        category_counts: Counter[str] = Counter()
        outcomes_summary = {
            "money_extracted": 0,
            "victim_refused": 0,
            "analyzer_flagged": 0,
            "bank_flagged": 0,
            "bank_froze": 0,
            "sought_verification": 0,
        }
        detection_turns: list[int] = []
        rewards_analyzer: list[float] = []

        log_rows: list[dict] = []
        for i in range(episodes):
            env.reset(seed=seed_base + i)
            done = False
            reward = None
            info: dict = {}
            while not done:
                _, reward, done, info = env.step()
            outcome = info["outcome"]
            detected_turn = outcome.detected_by_turn

            category_counts[outcome.scam_category.value] += 1
            outcomes_summary["money_extracted"] += int(outcome.money_extracted)
            outcomes_summary["victim_refused"] += int(outcome.victim_refused)
            outcomes_summary["analyzer_flagged"] += int(outcome.analyzer_flagged)
            outcomes_summary["bank_flagged"] += int(outcome.bank_flagged)
            outcomes_summary["bank_froze"] += int(outcome.bank_froze)
            outcomes_summary["sought_verification"] += int(
                outcome.victim_sought_verification
            )
            if detected_turn is not None:
                detection_turns.append(detected_turn)

            has_reward = reward is not None
            if has_reward:
                rewards_analyzer.append(reward.analyzer)
                if wandb_run is not None:
                    wandb_run.log(
                        {
                            "ep": i,
                            "reward/analyzer": reward.analyzer,
                            "reward/scammer": reward.scammer,
                            "detection/flagged": int(outcome.analyzer_flagged),
                            # Explicit None check: a detection at turn 0 is
                            # falsy and must not be reported as -1.
                            "detection/turn": (
                                detected_turn if detected_turn is not None else -1
                            ),
                            "outcome/money_extracted": int(outcome.money_extracted),
                        }
                    )
            log_rows.append(
                {
                    "ep": i,
                    "seed": seed_base + i,
                    "category": outcome.scam_category.value,
                    "analyzer_flagged": outcome.analyzer_flagged,
                    "detected_by_turn": detected_turn,
                    "money_extracted": outcome.money_extracted,
                    "victim_refused": outcome.victim_refused,
                    "reward_analyzer": reward.analyzer if has_reward else None,
                    "reward_scammer": reward.scammer if has_reward else None,
                }
            )

        # Guard the rates so an empty run reports 0.0 instead of dividing by zero.
        detection_rate = (
            outcomes_summary["analyzer_flagged"] / episodes if episodes else 0.0
        )
        extraction_rate = (
            outcomes_summary["money_extracted"] / episodes if episodes else 0.0
        )
        avg_detection_turn = (
            sum(detection_turns) / len(detection_turns) if detection_turns else -1
        )
        avg_reward_analyzer = (
            sum(rewards_analyzer) / len(rewards_analyzer) if rewards_analyzer else 0.0
        )

        summary = {
            "episodes": episodes,
            "detection_rate": round(detection_rate, 4),
            "extraction_rate": round(extraction_rate, 4),
            "avg_detection_turn": round(avg_detection_turn, 2),
            "avg_reward_analyzer": round(avg_reward_analyzer, 4),
            **outcomes_summary,
            **{f"category/{k}": v for k, v in category_counts.items()},
        }

        logger.info("=== Baseline Summary ===")
        for k, v in summary.items():
            logger.info("  %s: %s", k, v)

        if log_path is not None:
            log_path.parent.mkdir(parents=True, exist_ok=True)
            log_path.write_text(
                json.dumps({"summary": summary, "rows": log_rows}, indent=2),
                encoding="utf-8",
            )
            logger.info("Wrote log to %s", log_path)

        if wandb_run is not None:
            wandb_run.summary.update(summary)
    except BaseException:
        # Close the WandB run as failed rather than leaving it open.
        if wandb_run is not None:
            wandb_run.finish(exit_code=1)
        raise

    if wandb_run is not None:
        wandb_run.finish()

    return summary


def _positive_int(text: str) -> int:
    """Parse *text* as an integer >= 1 (argparse ``type=`` callable)."""
    # A non-numeric value raises ValueError, which argparse reports as a
    # normal "invalid value" usage error.
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value}"
        )
    return value


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point for the scripted baseline runner.

    Parses the CLI options, configures INFO-level logging, and calls
    ``run_baseline`` with the parsed values.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        ``0`` after ``run_baseline`` returns.

    Raises:
        SystemExit: Raised by argparse (exit code 2) on invalid arguments,
            including a non-positive ``--episodes``.
    """
    parser = argparse.ArgumentParser(description="Scripted baseline runner")
    # Reject 0 / negative counts up front: a zero-episode run has no rates to
    # report and would otherwise fail deep inside the run.
    parser.add_argument("--episodes", type=_positive_int, default=100)
    parser.add_argument("--seed-base", type=int, default=1000)
    parser.add_argument(
        "--profile",
        type=str,
        default="semi_urban",
        choices=["senior", "young_urban", "semi_urban"],
    )
    parser.add_argument("--gullibility", type=float, default=1.0)
    parser.add_argument("--no-wandb", action="store_true")
    parser.add_argument("--log-path", type=Path, default=Path("logs/baseline_day1.json"))
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    profile = VictimProfile(args.profile)
    run_baseline(
        episodes=args.episodes,
        seed_base=args.seed_base,
        profile=profile,
        gullibility=args.gullibility,
        # --no-wandb disables WandB by passing no project name.
        wandb_project=None if args.no_wandb else "chakravyuh-run-1",
        log_path=args.log_path,
    )
    return 0


if __name__ == "__main__":
    # Script entry: propagate main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)