"""
experiments/run_all.py
======================

Leakage-free experiment runner for the UPI-Sim temporal fraud benchmark.

Key protocol changes
--------------------
- Strict prefix evaluation: models only see events up to cutoff t.
- Horizon-specific retraining: each horizon uses fresh model instances.
- Causal ablation trains/evaluates on globally shuffled chronology.
- XGBoost uses the real xgboost library with aligned node-level labels.
- All experiments support multi-seed aggregation with mean ± std outputs.
"""
from __future__ import annotations

import argparse
import hashlib
import os
import random
import sys
import time
from typing import Dict, Iterable, List, Sequence

# These environment variables are set BEFORE numpy / torch are imported below.
# NOTE(review): native BLAS/OpenMP runtimes generally read them when they are
# first loaded, which is presumably why this ordering matters — keep them above
# the numpy/torch imports. setdefault() leaves any value the caller exported.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
# cuBLAS workspace setting that PyTorch's reproducibility notes require for
# deterministic cuBLAS kernels when deterministic algorithms are requested.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, brier_score_loss, roc_auc_score
from xgboost import XGBClassifier

# Put the repository root (parent of this file's directory) at the front of
# sys.path so the `src.*` and `models.*` imports below resolve when this file
# is executed directly as a script.
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from src.core.config_loader import load_config
from src.generators.user_generator import generate_users
from src.generators.transaction_generator import generate_transactions
from src.fraud.fraud_engine import FraudEngine, ORACLE_ONLY_COLS
from src.graph.graph_builder import build_edge_features
from src.risk.risk_engine import apply_risk_engine
from models.base import TemporalModel
from models.dyrep import DyRepWrapper
from models.jodie import JODIEWrapper
from models.audit_oracle import AuditOracleWrapper, RawMotifOracleWrapper
from models.oracle_motif import OracleMotifWrapper
from models.sequence_gru import SequenceGRUWrapper
from models.static_gnn import StaticGNNWrapper
from models.tgat import TGATWrapper
from models.tgn_wrapper import TGNWrapper
from models.xgboost_model import XGBoostWrapper
torch.set_num_threads(1) if hasattr(torch, "set_num_interop_threads"): try: torch.set_num_interop_threads(1) except RuntimeError: pass # Oracle models that are allowed to receive unstripped audit columns _ORACLE_MODEL_NAMES: frozenset = frozenset({"OracleMotif", "AuditOracle", "RawMotifOracle"}) # --------------------------------------------------------------------------- # Oracle / audit column stripping # --------------------------------------------------------------------------- def strip_oracle_cols(df: pd.DataFrame) -> pd.DataFrame: """Remove audit/oracle columns before passing data to learned baselines.""" cols_to_drop = [c for c in df.columns if c in ORACLE_ONLY_COLS] if cols_to_drop: return df.drop(columns=cols_to_drop) return df DEFAULT_HORIZONS = [0.01, 0.05, 0.10, 0.20] DEFAULT_SEEDS = [0, 1, 2, 3, 4] _TWIN_DIFFICULTY_USER_SEED_OFFSETS = {"easy": 11, "medium": 23, "hard": 37} MODEL_ORDER = [ "OracleMotif", "SeqGRU", "TGN", "TGAT", "DyRep", "JODIE", "StaticGNN", "XGBoost", ] def stable_int_hash(*parts: object, modulo: int = 2**32) -> int: """Deterministic integer hash for seed derivation across Python processes.""" seed_material = "::".join(map(str, parts)) digest = hashlib.sha256(seed_material.encode("utf-8")).hexdigest() return int(digest[:16], 16) % modulo def seed_python_numpy(seed: int) -> None: random.seed(seed) np.random.seed(seed) def set_global_determinism(seed: int) -> None: seed_python_numpy(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) if hasattr(torch.backends, "cudnn"): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if hasattr(torch.backends.cudnn, "allow_tf32"): torch.backends.cudnn.allow_tf32 = False if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"): torch.backends.cuda.matmul.allow_tf32 = False try: torch.use_deterministic_algorithms(True) except Exception: pass def derived_seed(base_seed: int, *parts: object, modulo: int = 2**31 - 
1) -> int: return int((int(base_seed) + stable_int_hash(*parts, modulo=modulo)) % modulo) def _is_oracle_calib_mode(benchmark_mode: str) -> bool: return benchmark_mode == "temporal_twins_oracle_calib" def _oracle_metric_labels(benchmark_mode: str) -> dict[str, str]: if _is_oracle_calib_mode(benchmark_mode): return { "audit": "AuditOracle", "raw": "RawMotifOracle", "table": "Oracle Debug Table", } return { "audit": "MotifProbe", "raw": "RawMotifProbe", "table": "Probe Debug Table", } def _attach_probe_aliases(report: dict, benchmark_mode: str) -> None: """Expose standard-mode probe names without breaking old oracle-key consumers.""" labels = _oracle_metric_labels(benchmark_mode) report["audit_metric_label"] = labels["audit"] report["raw_metric_label"] = labels["raw"] if _is_oracle_calib_mode(benchmark_mode): return alias_map = { "motif_probe_roc_auc": "audit_roc_auc", "motif_probe_pair_sep": "audit_pair_sep", "motif_probe_n_examples": "audit_n_examples", "motif_probe_auc_bootstrap_std": "audit_auc_bootstrap_std", "motif_probe_auc_ci_lo": "audit_auc_ci_lo", "motif_probe_auc_ci_hi": "audit_auc_ci_hi", "raw_motif_probe_roc_auc": "raw_roc_auc", "raw_motif_probe_pair_sep": "raw_pair_sep", "raw_motif_probe_n_examples": "raw_n_examples", "raw_motif_probe_auc_bootstrap_std": "raw_auc_bootstrap_std", "raw_motif_probe_auc_ci_lo": "raw_auc_ci_lo", "raw_motif_probe_auc_ci_hi": "raw_auc_ci_hi", } for alias_key, source_key in alias_map.items(): if source_key in report: report[alias_key] = report[source_key] # --------------------------------------------------------------------------- # Data generation # --------------------------------------------------------------------------- def generate_difficulty( config, users: pd.DataFrame, difficulty: str, seed: int, time_offset: float = 0.0, benchmark_mode: str = "standard", ) -> pd.DataFrame: """Generate one difficulty slice with a global timestamp offset.""" df = generate_transactions(users, config) df = apply_risk_engine(df, users, 
config) engine_seed = seed + stable_int_hash("FraudEngine", difficulty, benchmark_mode, modulo=10_000) engine = FraudEngine( seed=engine_seed, difficulty=difficulty, benchmark_mode=benchmark_mode, ) df = engine.apply(df) df = df.sort_values("timestamp").reset_index(drop=True) if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"): diff_offset = {"easy": 0, "medium": 1_000_000, "hard": 2_000_000}[difficulty] df["sender_id"] = df["sender_id"].astype(np.int64) + diff_offset df["receiver_id"] = df["receiver_id"].astype(np.int64) + diff_offset if "twin_pair_id" in df.columns: df["twin_pair_id"] = df["twin_pair_id"].astype(np.int64) valid = df["twin_pair_id"] >= 0 df.loc[valid, "twin_pair_id"] = df.loc[valid, "twin_pair_id"] + diff_offset if "template_id" in df.columns: df["template_id"] = df["template_id"].astype(np.int64) valid = df["template_id"] >= 0 df.loc[valid, "template_id"] = df.loc[valid, "template_id"] + diff_offset df["timestamp"] = df["timestamp"] + time_offset return df def generate_all(config, seed: int = 42, benchmark_mode: str = "standard"): """Generate Easy/Medium/Hard datasets.""" seed_python_numpy(seed) gap = 1_000.0 if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"): seed_python_numpy(seed + 11) users_easy = generate_users(config) seed_python_numpy(seed + 23) users_medium = generate_users(config) seed_python_numpy(seed + 37) users_hard = generate_users(config) else: shared_users = generate_users(config) users_easy = shared_users users_medium = shared_users users_hard = shared_users df_easy = generate_difficulty( config, users_easy, "easy", seed, time_offset=0.0, benchmark_mode=benchmark_mode, ) t_after_easy = float(df_easy["timestamp"].max()) + gap df_medium = generate_difficulty( config, users_medium, "medium", seed, time_offset=t_after_easy, benchmark_mode=benchmark_mode, ) t_after_medium = float(df_medium["timestamp"].max()) + gap df_hard = generate_difficulty( config, users_hard, "hard", seed, 
time_offset=t_after_medium, benchmark_mode=benchmark_mode, ) return df_easy, df_medium, df_hard def generate_single_difficulty( config, difficulty: str, seed: int = 42, benchmark_mode: str = "standard", ) -> pd.DataFrame: """Generate one difficulty slice using the same user-seed scheme as generate_all().""" seed_python_numpy(seed) if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"): user_seed = seed + _TWIN_DIFFICULTY_USER_SEED_OFFSETS[difficulty] seed_python_numpy(user_seed) users = generate_users(config) else: users = generate_users(config) return generate_difficulty( config, users, difficulty, seed, time_offset=0.0, benchmark_mode=benchmark_mode, ) # --------------------------------------------------------------------------- # Metrics # --------------------------------------------------------------------------- def compute_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10) -> float: bins = np.linspace(0.0, 1.0, n_bins + 1) ece = 0.0 for lo, hi in zip(bins[:-1], bins[1:]): mask = (y_prob >= lo) & (y_prob < hi if hi < 1.0 else y_prob <= hi) if not mask.any(): continue frac = float(mask.mean()) avg_conf = float(y_prob[mask].mean()) avg_acc = float(y_true[mask].mean()) ece += frac * abs(avg_conf - avg_acc) return float(ece) def safe_roc_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float: if len(np.unique(y_true)) < 2: return 0.5 return float(roc_auc_score(y_true, y_prob)) def safe_pr_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float: positives = float(np.sum(y_true == 1)) negatives = float(np.sum(y_true == 0)) if positives == 0.0: return 0.0 if negatives == 0.0: return 1.0 return float(average_precision_score(y_true, y_prob)) def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> dict: y_true = np.asarray(y_true, dtype=np.float32) y_prob = np.nan_to_num(np.asarray(y_prob, dtype=np.float32), nan=0.5, posinf=1.0, neginf=0.0) y_prob = np.clip(y_prob, 0.0, 1.0) return { "roc_auc": safe_roc_auc(y_true, y_prob), "pr_auc": 
safe_pr_auc(y_true, y_prob), "brier": float(brier_score_loss(y_true, y_prob)), "ece": compute_ece(y_true, y_prob), } def safe_pearson(x: np.ndarray, y: np.ndarray) -> float: x = np.asarray(x, dtype=np.float32) y = np.asarray(y, dtype=np.float32) if len(x) == 0 or len(y) == 0: return 0.0 if np.std(x) < 1e-8 or np.std(y) < 1e-8: return 0.0 return float(np.corrcoef(x, y)[0, 1]) def build_node_audit_table(df: pd.DataFrame) -> pd.DataFrame: df = df.sort_values("timestamp").reset_index(drop=True).copy() df["_dt"] = df.groupby("sender_id")["timestamp"].diff().fillna(0.0) df["_phase"] = df["timestamp"] % 86400.0 df["_burst"] = (df["_dt"] > 0.0) & (df["_dt"] < 600.0) df["_quiet"] = df["_dt"] > 3600.0 grp = df.groupby("sender_id", sort=False) node_df = pd.DataFrame({ "txn_count": grp["sender_id"].count(), "receiver_count": grp["receiver_id"].nunique(), "retry_count": grp["is_retry"].sum() if "is_retry" in df.columns else 0.0, "failed_count": grp["failed"].sum() if "failed" in df.columns else 0.0, "burst_count": grp["_burst"].sum(), "quiet_count": grp["_quiet"].sum(), "dt_mean": grp["_dt"].mean(), "dt_std": grp["_dt"].std().fillna(0.0), "amount_mean": grp["amount"].mean(), "amount_std": grp["amount"].std().fillna(0.0), "phase_std": grp["_phase"].std().fillna(0.0), }) recv_counts = ( df.groupby(["sender_id", "receiver_id"]) .size() .reset_index(name="_n") ) recv_counts["_tot"] = recv_counts.groupby("sender_id")["_n"].transform("sum") recv_counts["_p"] = recv_counts["_n"] / recv_counts["_tot"] recv_counts["_h"] = -recv_counts["_p"] * np.log2(recv_counts["_p"] + 1e-9) node_df["recv_entropy"] = recv_counts.groupby("sender_id")["_h"].sum() if "twin_pair_id" in df.columns: node_df["twin_pair_id"] = grp["twin_pair_id"].first().astype(np.int32) else: node_df["twin_pair_id"] = -1 if "twin_label" in df.columns: node_df["label"] = grp["twin_label"].max().astype(np.int32) else: node_df["label"] = grp["is_fraud"].max().astype(np.int32) return node_df.fillna(0.0).reset_index() def 
with_local_event_idx(df: pd.DataFrame) -> pd.DataFrame: out = df.sort_values("timestamp").reset_index(drop=True).copy() out["local_event_idx"] = ( out.groupby("sender_id").cumcount().astype(np.int32) ) return out def build_matched_control_tables( df: pd.DataFrame, ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Build matched fraud/benign evaluation examples at the same local index k.""" required = {"twin_pair_id", "twin_role", "twin_label", "label_event_idx"} if not required.issubset(df.columns): empty = pd.DataFrame() return empty, empty, empty twin_df = with_local_event_idx(df[df["twin_pair_id"] >= 0].copy()) if twin_df.empty: empty = pd.DataFrame() return empty, empty, empty sender_meta = ( twin_df.groupby("sender_id") .agg( twin_pair_id=("twin_pair_id", "first"), twin_role=("twin_role", "first"), twin_label=("twin_label", "max"), template_id=("template_id", "first") if "template_id" in twin_df.columns else ("twin_pair_id", "first"), total_txn_count=("sender_id", "size"), sender_start_time=("timestamp", "min"), motif_hit_count=("motif_hit_count", "max") if "motif_hit_count" in twin_df.columns else ("sender_id", "size"), ) .reset_index() ) if "motif_hit_count" not in twin_df.columns: sender_meta["motif_hit_count"] = 0 pair_rows: list[dict] = [] example_rows: list[dict] = [] pair_count_rows: list[dict] = [] pair_event_id = 0 sender_groups = { int(sender_id): group.reset_index(drop=True).copy() for sender_id, group in twin_df.groupby("sender_id", sort=False) } for pair_id, pair_meta in sender_meta.groupby("twin_pair_id", sort=True): if len(pair_meta) != 2 or set(pair_meta["twin_role"]) != {"fraud", "benign"}: continue fraud_meta = pair_meta[pair_meta["twin_role"] == "fraud"].iloc[0] benign_meta = pair_meta[pair_meta["twin_role"] == "benign"].iloc[0] fraud_sender = int(fraud_meta["sender_id"]) benign_sender = int(benign_meta["sender_id"]) template_id = int(fraud_meta["template_id"]) fraud_total = int(fraud_meta["total_txn_count"]) benign_total = 
int(benign_meta["total_txn_count"]) pair_count_rows.append({ "twin_pair_id": int(pair_id), "template_id": template_id, "fraud_sender_id": fraud_sender, "benign_sender_id": benign_sender, "fraud_total_txn_count": fraud_total, "benign_total_txn_count": benign_total, "pair_total_txn_count_diff": abs(fraud_total - benign_total), "fraud_motif_hit_count": int(fraud_meta["motif_hit_count"]), "benign_motif_hit_count": int(benign_meta["motif_hit_count"]), }) fraud_group = sender_groups[fraud_sender] benign_group = sender_groups[benign_sender] benign_by_idx = benign_group.set_index("local_event_idx", drop=False) fraud_positives = fraud_group[ (fraud_group["is_fraud"] == 1) & (fraud_group["label_event_idx"] >= 0) ].copy() for row in fraud_positives.itertuples(index=False): k = int(row.label_event_idx) if k not in benign_by_idx.index: continue benign_row = benign_by_idx.loc[k] if isinstance(benign_row, pd.DataFrame): benign_row = benign_row.iloc[0] fraud_age = float(row.timestamp - fraud_meta["sender_start_time"]) benign_age = float(benign_row["timestamp"] - benign_meta["sender_start_time"]) prefix_txn_count = k + 1 pair_rows.append({ "pair_event_id": pair_event_id, "twin_pair_id": int(pair_id), "template_id": template_id, "fraud_sender_id": fraud_sender, "benign_sender_id": benign_sender, "fraud_label_event_idx": k, "benign_eval_event_idx": int(benign_row["local_event_idx"]), "fraud_eval_timestamp": float(row.timestamp), "benign_eval_timestamp": float(benign_row["timestamp"]), "fraud_active_age": fraud_age, "benign_active_age": benign_age, "active_age_diff": abs(fraud_age - benign_age), "timestamp_diff": abs(float(row.timestamp) - float(benign_row["timestamp"])), "prefix_txn_count": prefix_txn_count, "fraud_total_txn_count": fraud_total, "benign_total_txn_count": benign_total, "pair_total_txn_count_diff": abs(fraud_total - benign_total), "fraud_motif_hit_count": int(fraud_meta["motif_hit_count"]), "benign_motif_hit_count": int(benign_meta["motif_hit_count"]), "label_delay": 
int(row.label_delay) if hasattr(row, "label_delay") else -1, }) common = { "pair_event_id": pair_event_id, "twin_pair_id": int(pair_id), "template_id": template_id, "eval_local_event_idx": k, "prefix_txn_count": prefix_txn_count, } example_rows.append({ **common, "sender_id": fraud_sender, "label": 1, "twin_role": "fraud", "matched_sender_id": benign_sender, "total_txn_count": fraud_total, "eval_timestamp": float(row.timestamp), # The simulator has no separate account-creation time, so # account_age equals active_age for twin-control audits. "account_age": fraud_age, "active_age": fraud_age, }) example_rows.append({ **common, "sender_id": benign_sender, "label": 0, "twin_role": "benign", "matched_sender_id": fraud_sender, "total_txn_count": benign_total, "eval_timestamp": float(benign_row["timestamp"]), "account_age": benign_age, "active_age": benign_age, }) pair_event_id += 1 return ( pd.DataFrame(example_rows), pd.DataFrame(pair_rows), pd.DataFrame(pair_count_rows), ) def _sender_prefix_feature_row(prefix: pd.DataFrame) -> dict: prefix = prefix.sort_values("timestamp").reset_index(drop=True) timestamps = prefix["timestamp"].to_numpy(dtype=np.float64) dts = np.diff(timestamps, prepend=timestamps[0]) if len(prefix) else np.zeros(0, dtype=np.float64) dts = np.maximum(dts, 0.0) phase = timestamps % 86400.0 if len(prefix) else np.zeros(0, dtype=np.float64) burst = ((dts > 0.0) & (dts < 600.0)).astype(np.float32) quiet = (dts > 3600.0).astype(np.float32) recv_counts = prefix["receiver_id"].value_counts().to_numpy(dtype=np.float64) recv_p = recv_counts / max(float(recv_counts.sum()), 1.0) recv_entropy = float(-np.sum(recv_p * np.log2(recv_p + 1e-9))) if len(recv_counts) else 0.0 return { "txn_count": float(len(prefix)), "txn_cnt10_last": float(min(len(prefix), 10)), "receiver_count": float(prefix["receiver_id"].nunique()) if len(prefix) else 0.0, "retry_count": float(prefix["is_retry"].sum()) if "is_retry" in prefix.columns else 0.0, "failed_count": 
float(prefix["failed"].sum()) if "failed" in prefix.columns else 0.0, "burst_count": float(burst.sum()), "quiet_count": float(quiet.sum()), "amount_mean": float(prefix["amount"].mean()) if len(prefix) else 0.0, "amount_std": float(prefix["amount"].std(ddof=1)) if len(prefix) > 1 else 0.0, "amount_max": float(prefix["amount"].max()) if len(prefix) else 0.0, "td_mean": float(dts.mean()) if len(dts) else 0.0, "td_std": float(dts.std(ddof=1)) if len(dts) > 1 else 0.0, "dt_mean": float(dts.mean()) if len(dts) else 0.0, "dt_std": float(dts.std(ddof=1)) if len(dts) > 1 else 0.0, "phase_std": float(np.std(phase, ddof=1)) if len(phase) > 1 else 0.0, "recv_entropy": recv_entropy, "fail_rate": float(prefix["failed"].mean()) if "failed" in prefix.columns and len(prefix) else 0.0, "retry_rate": float(prefix["is_retry"].mean()) if "is_retry" in prefix.columns and len(prefix) else 0.0, "pair_freq_mean": float(prefix["pair_freq"].mean()) if "pair_freq" in prefix.columns and len(prefix) else 0.0, } def build_matched_prefix_feature_table( df: pd.DataFrame, examples: pd.DataFrame, ) -> pd.DataFrame: if examples.empty: return pd.DataFrame() indexed_df = with_local_event_idx(df) sender_groups = { int(sender_id): group.reset_index(drop=True).copy() for sender_id, group in indexed_df.groupby("sender_id", sort=False) } rows: list[dict] = [] for example in examples.itertuples(index=False): sender_id = int(example.sender_id) end_idx = int(example.eval_local_event_idx) sender_prefix = sender_groups[sender_id] prefix = sender_prefix.iloc[: end_idx + 1].copy() rows.append({ "pair_event_id": int(example.pair_event_id), "twin_pair_id": int(example.twin_pair_id), "template_id": int(example.template_id), "sender_id": sender_id, "label": int(example.label), "eval_local_event_idx": int(example.eval_local_event_idx), "prefix_txn_count": int(example.prefix_txn_count), "total_txn_count": int(example.total_txn_count), "eval_timestamp": float(example.eval_timestamp), "account_age": 
float(example.account_age), "active_age": float(example.active_age), **_sender_prefix_feature_row(prefix), }) return pd.DataFrame(rows).fillna(0.0) def report_matched_control_audits( test_examples: pd.DataFrame, test_pair_rows: pd.DataFrame, test_pair_counts: pd.DataFrame, ) -> dict: if test_examples.empty: return {} y = test_examples["label"].to_numpy(dtype=np.float32) audit = { "pair_total_txn_count_diff_mean": float(test_pair_counts["pair_total_txn_count_diff"].mean()) if not test_pair_counts.empty else 0.0, "pair_total_txn_count_diff_max": float(test_pair_counts["pair_total_txn_count_diff"].max()) if not test_pair_counts.empty else 0.0, "auc_total_txn_count": safe_roc_auc(y, test_examples["total_txn_count"].to_numpy(dtype=np.float32)), "auc_local_event_idx": safe_roc_auc(y, test_examples["eval_local_event_idx"].to_numpy(dtype=np.float32)), "auc_prefix_txn_count": safe_roc_auc(y, test_examples["prefix_txn_count"].to_numpy(dtype=np.float32)), "auc_timestamp": safe_roc_auc(y, test_examples["eval_timestamp"].to_numpy(dtype=np.float32)), "auc_account_age": safe_roc_auc(y, test_examples["account_age"].to_numpy(dtype=np.float32)), "auc_active_age": safe_roc_auc(y, test_examples["active_age"].to_numpy(dtype=np.float32)), "fraud_label_event_idx_mean": float(test_pair_rows["fraud_label_event_idx"].mean()) if not test_pair_rows.empty else 0.0, "fraud_label_event_idx_max": float(test_pair_rows["fraud_label_event_idx"].max()) if not test_pair_rows.empty else 0.0, "benign_eval_event_idx_mean": float(test_pair_rows["benign_eval_event_idx"].mean()) if not test_pair_rows.empty else 0.0, "benign_eval_event_idx_max": float(test_pair_rows["benign_eval_event_idx"].max()) if not test_pair_rows.empty else 0.0, "pair_event_idx_diff_mean": float((test_pair_rows["fraud_label_event_idx"] - test_pair_rows["benign_eval_event_idx"]).abs().mean()) if not test_pair_rows.empty else 0.0, "pair_event_idx_diff_max": float((test_pair_rows["fraud_label_event_idx"] - 
test_pair_rows["benign_eval_event_idx"]).abs().max()) if not test_pair_rows.empty else 0.0, "pair_active_age_diff_mean": float(test_pair_rows["active_age_diff"].mean()) if not test_pair_rows.empty else 0.0, "pair_active_age_diff_max": float(test_pair_rows["active_age_diff"].max()) if not test_pair_rows.empty else 0.0, "pair_timestamp_diff_mean": float(test_pair_rows["timestamp_diff"].mean()) if not test_pair_rows.empty else 0.0, "pair_timestamp_diff_max": float(test_pair_rows["timestamp_diff"].max()) if not test_pair_rows.empty else 0.0, "benign_motif_hit_rate": float((test_pair_counts["benign_motif_hit_count"] > 0).mean()) if not test_pair_counts.empty else 0.0, "benign_motif_hit_pairs": int((test_pair_counts["benign_motif_hit_count"] > 0).sum()) if not test_pair_counts.empty else 0, "matched_control_examples": int(len(test_examples)), "matched_control_pair_events": int(len(test_pair_rows)), } print("\n--- Matched-Control Shortcut Audit ---") for key in ( "pair_total_txn_count_diff_mean", "pair_total_txn_count_diff_max", "auc_total_txn_count", "auc_local_event_idx", "auc_prefix_txn_count", "auc_timestamp", "auc_account_age", "auc_active_age", "benign_motif_hit_rate", "benign_motif_hit_pairs", ): print(f" {key:<30}: {audit[key]}") if not test_pair_rows.empty: print("\n label_event_idx distribution (fraud twins):") print(test_pair_rows["fraud_label_event_idx"].describe().to_string()) print("\n pseudo-label idx distribution (benign twins):") print(test_pair_rows["benign_eval_event_idx"].describe().to_string()) print("\n per-pair fraud-vs-benign evaluation indices:") cols = [ "twin_pair_id", "fraud_label_event_idx", "benign_eval_event_idx", "active_age_diff", "timestamp_diff", ] print(test_pair_rows[cols].head(20).to_string(index=False)) return audit def bootstrap_auc_summary( y_true: np.ndarray, y_score: np.ndarray, seed: int, n_bootstrap: int = 200, ) -> dict: y_true = np.asarray(y_true, dtype=np.float32) y_score = np.asarray(y_score, dtype=np.float32) if 
len(y_true) == 0 or len(np.unique(y_true)) < 2: return { "bootstrap_std": float("nan"), "ci_lo": float("nan"), "ci_hi": float("nan"), "n_bootstrap": 0, } rng = np.random.default_rng(seed) aucs: list[float] = [] n = len(y_true) for _ in range(n_bootstrap): idx = rng.integers(0, n, size=n) sample_y = y_true[idx] if len(np.unique(sample_y)) < 2: continue aucs.append(safe_roc_auc(sample_y, y_score[idx])) if not aucs: return { "bootstrap_std": float("nan"), "ci_lo": float("nan"), "ci_hi": float("nan"), "n_bootstrap": 0, } auc_arr = np.asarray(aucs, dtype=np.float32) return { "bootstrap_std": float(np.std(auc_arr, ddof=1)) if len(auc_arr) > 1 else 0.0, "ci_lo": float(np.quantile(auc_arr, 0.025)), "ci_hi": float(np.quantile(auc_arr, 0.975)), "n_bootstrap": int(len(auc_arr)), } def make_auc_result( y_true: np.ndarray, y_score: np.ndarray, seed: int, extra: dict | None = None, ) -> dict: y_true = np.asarray(y_true, dtype=np.float32) y_score = np.asarray(y_score, dtype=np.float32) result = { "auc": safe_roc_auc(y_true, y_score), "y_true": y_true, "y_score": y_score, "n_examples": int(len(y_true)), "n_pos": int(np.sum(y_true == 1)), "n_neg": int(np.sum(y_true == 0)), } result.update(bootstrap_auc_summary(y_true, y_score, seed=seed)) if extra: result.update(extra) return result def attach_auc_result(report: dict, prefix: str, result: dict) -> None: report[f"{prefix}_roc_auc"] = float(result["auc"]) report[f"{prefix}_n_examples"] = int(result["n_examples"]) report[f"{prefix}_n_pos"] = int(result["n_pos"]) report[f"{prefix}_n_neg"] = int(result["n_neg"]) report[f"{prefix}_auc_bootstrap_std"] = float(result["bootstrap_std"]) report[f"{prefix}_auc_ci_lo"] = float(result["ci_lo"]) report[f"{prefix}_auc_ci_hi"] = float(result["ci_hi"]) def _standardize_train_test( train_df: pd.DataFrame, test_df: pd.DataFrame, feature_cols: Sequence[str], ) -> tuple[np.ndarray, np.ndarray]: x_train = train_df[list(feature_cols)].to_numpy(dtype=np.float32) x_test = 
test_df[list(feature_cols)].to_numpy(dtype=np.float32) mean = x_train.mean(axis=0, keepdims=True) std = x_train.std(axis=0, keepdims=True) + 1e-6 return (x_train - mean) / std, (x_test - mean) / std def compute_matched_static_aggregate_auc( train_features: pd.DataFrame, test_features: pd.DataFrame, seed: int, verbose: bool = True, ) -> dict: feature_cols = [ "txn_count", "receiver_count", "retry_count", "failed_count", "burst_count", "quiet_count", "dt_mean", "dt_std", "amount_mean", "amount_std", "phase_std", "recv_entropy", ] if train_features.empty or test_features.empty: return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) if train_features["label"].nunique() < 2 or test_features["label"].nunique() < 2: y_test = test_features["label"].to_numpy(dtype=np.float32) probs = np.full(len(y_test), 0.5, dtype=np.float32) return make_auc_result(y_test, probs, seed=seed) x_train, x_test = _standardize_train_test(train_features, test_features, feature_cols) clf = LogisticRegression( max_iter=2000, class_weight="balanced", random_state=42, solver="liblinear", ) clf.fit(x_train, train_features["label"].to_numpy(dtype=np.int32)) probs = clf.predict_proba(x_test)[:, 1] y_test = test_features["label"].to_numpy(dtype=np.float32) if verbose: coefs = np.abs(clf.coef_[0]) ranked = np.argsort(coefs)[::-1] print("\n Top matched static aggregate predictors:") for rank_i in ranked[:5]: print(f" {feature_cols[rank_i]:<20}: |coef|={coefs[rank_i]:.4f}") return make_auc_result( y_test, probs.astype(np.float32), seed=seed, ) def compute_matched_xgboost_auc( train_features: pd.DataFrame, test_features: pd.DataFrame, seed: int, ) -> dict: feature_cols = [ "txn_count", "txn_cnt10_last", "amount_mean", "amount_std", "amount_max", "td_mean", "td_std", "fail_rate", "retry_rate", "recv_entropy", "pair_freq_mean", ] if train_features.empty or test_features.empty: return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), 
seed=seed) y_train = train_features["label"].to_numpy(dtype=np.int32) y_test = test_features["label"].to_numpy(dtype=np.int32) if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2: probs = np.full(len(y_test), 0.5, dtype=np.float32) return make_auc_result(y_test.astype(np.float32), probs, seed=seed) x_train = train_features[feature_cols].to_numpy(dtype=np.float32) x_test = test_features[feature_cols].to_numpy(dtype=np.float32) scale_pos_weight = max(1.0, float((y_train == 0).sum()) / max(float((y_train == 1).sum()), 1.0)) model = XGBClassifier( n_estimators=200, max_depth=6, learning_rate=0.05, objective="binary:logistic", eval_metric="logloss", scale_pos_weight=scale_pos_weight, random_state=42, verbosity=0, n_jobs=1, tree_method="exact", ) model.fit(x_train, y_train) probs = model.predict_proba(x_test)[:, 1] importances = model.feature_importances_ ranked = np.argsort(importances)[::-1] print(" [Matched XGBoost] Top-5 feature importances:") for idx in ranked[:5]: print(f" {feature_cols[idx]:<20}: {importances[idx]:.4f}") return make_auc_result( y_test.astype(np.float32), probs.astype(np.float32), seed=seed, ) def _build_example_prefix( df_full: pd.DataFrame, sender_id: int, eval_local_event_idx: int, eval_timestamp: float, ) -> pd.DataFrame: prefix = df_full[df_full["timestamp"] <= eval_timestamp].copy() if "local_event_idx" not in prefix.columns: prefix = with_local_event_idx(prefix) sender_mask = prefix["sender_id"] == sender_id if sender_mask.any(): prefix = prefix[(~sender_mask) | (prefix["local_event_idx"] <= eval_local_event_idx)].copy() return prefix.sort_values("timestamp").reset_index(drop=True) def build_static_gnn_example_embeddings( model: StaticGNNWrapper, df_full: pd.DataFrame, examples: pd.DataFrame, ) -> tuple[np.ndarray, dict]: if examples.empty: return np.zeros((0, model.hidden_dim), dtype=np.float32), { "matched_examples": 0, "unique_prefix_cutoffs": 0, "graph_builds": 0, "cache_hit_rate": float("nan"), "eval_time_sec": 0.0, } 
clean_full = strip_oracle_cols( df_full.sort_values("timestamp").reset_index(drop=True) ) start = time.perf_counter() sender_ids = clean_full["sender_id"].to_numpy(dtype=np.int64) receiver_ids = clean_full["receiver_id"].to_numpy(dtype=np.int64) timestamps = clean_full["timestamp"].to_numpy(dtype=np.float64) edge_feats = build_edge_features(clean_full).astype(np.float32) ns = model._norm_stats edge_feats = (edge_feats - ns["ea_mean"]) / ns["ea_std"] max_sender = int(sender_ids.max()) if len(sender_ids) else 0 max_receiver = int(receiver_ids.max()) if len(receiver_ids) else 0 n_nodes = max(max(max_sender, max_receiver) + 1, model._n_nodes) feat_sum = np.zeros((n_nodes, edge_feats.shape[1]), dtype=np.float32) feat_count = np.zeros(n_nodes, dtype=np.float32) node_feat = np.zeros((n_nodes, edge_feats.shape[1]), dtype=np.float32) device = model.device x_t = torch.zeros((n_nodes, edge_feats.shape[1]), dtype=torch.float32, device=device) edge_index_full = torch.tensor( np.vstack([sender_ids, receiver_ids]), dtype=torch.long, device=device, ) examples_reset = examples.reset_index(drop=True).copy() grouped = examples_reset.groupby("eval_timestamp", sort=True).indices grouped_items = sorted( [(float(ts), idxs) for ts, idxs in grouped.items()], key=lambda item: item[0], ) ordered_cutoffs = [item[0] for item in grouped_items] cutoff_ends = np.searchsorted(timestamps, np.asarray(ordered_cutoffs, dtype=np.float64), side="right") out = np.zeros((len(examples_reset), model.hidden_dim), dtype=np.float32) prev_end = 0 graph_builds = 0 for (cutoff, row_indices), end_idx in zip(grouped_items, cutoff_ends.tolist()): if end_idx > prev_end: batch_senders = sender_ids[prev_end:end_idx] batch_feats = edge_feats[prev_end:end_idx] np.add.at(feat_sum, batch_senders, batch_feats) np.add.at(feat_count, batch_senders, 1.0) changed_nodes = np.unique(batch_senders) node_feat[changed_nodes] = feat_sum[changed_nodes] / feat_count[changed_nodes, None] changed_t = torch.tensor(changed_nodes, 
dtype=torch.long, device=device) x_t[changed_t] = torch.tensor(node_feat[changed_nodes], dtype=torch.float32, device=device) prev_end = end_idx edge_index = edge_index_full[:, :end_idx] model._encoder.eval() with torch.no_grad(): prefix_emb = model._encoder(x_t, edge_index) graph_builds += 1 sender_batch = examples_reset.loc[row_indices, "sender_id"].to_numpy(dtype=np.int64) sender_t = torch.tensor(sender_batch, dtype=torch.long, device=device) out[row_indices] = prefix_emb[sender_t].detach().cpu().numpy().astype(np.float32) matched_examples = int(len(examples_reset)) unique_cutoffs = int(len(ordered_cutoffs)) hits = max(0, matched_examples - graph_builds) diagnostics = { "matched_examples": matched_examples, "unique_prefix_cutoffs": unique_cutoffs, "graph_builds": int(graph_builds), "cache_hit_rate": float(hits / matched_examples) if matched_examples > 0 else float("nan"), "eval_time_sec": float(time.perf_counter() - start), } return out.astype(np.float32), diagnostics def compute_matched_static_gnn_auc( df_train: pd.DataFrame, df_test: pd.DataFrame, train_examples: pd.DataFrame, test_examples: pd.DataFrame, device: str, num_epochs: int, seed: int, ) -> dict: if train_examples.empty or test_examples.empty: return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) if train_examples["label"].nunique() < 2 or test_examples["label"].nunique() < 2: y_test = test_examples["label"].to_numpy(dtype=np.float32) probs = np.full(len(y_test), 0.5, dtype=np.float32) return make_auc_result(y_test, probs, seed=seed) static_seed = derived_seed(seed, "StaticGNN", "matched_prefix") set_global_determinism(static_seed) model = StaticGNNWrapper(hidden_dim=64, n_snapshots=10, device=device) model.fit(strip_oracle_cols(df_train), num_epochs=num_epochs) eval_start = time.perf_counter() train_emb, train_diag = build_static_gnn_example_embeddings(model, df_train, train_examples) full_test_df = ( pd.concat([df_train, df_test], ignore_index=True) 
.sort_values("timestamp") .reset_index(drop=True) ) test_emb, test_diag = build_static_gnn_example_embeddings(model, full_test_df, test_examples) y_train = train_examples["label"].to_numpy(dtype=np.int32) y_test = test_examples["label"].to_numpy(dtype=np.int32) mean = train_emb.mean(axis=0, keepdims=True) std = train_emb.std(axis=0, keepdims=True) + 1e-6 train_emb = (train_emb - mean) / std test_emb = (test_emb - mean) / std clf = LogisticRegression( max_iter=2000, class_weight="balanced", random_state=42, solver="liblinear", ) clf.fit(train_emb, y_train) probs = clf.predict_proba(test_emb)[:, 1] return make_auc_result( y_test.astype(np.float32), probs.astype(np.float32), seed=seed, extra={ "auc_flipped": safe_roc_auc(y_test.astype(np.float32), (1.0 - probs).astype(np.float32)), "score_mean_pos": float(probs[y_test == 1].mean()) if np.any(y_test == 1) else float("nan"), "score_mean_neg": float(probs[y_test == 0].mean()) if np.any(y_test == 0) else float("nan"), "score_std": float(np.std(probs)), "zero_emb_frac": float(np.mean(np.linalg.norm(test_emb, axis=1) < 1e-8)), "train_examples": int(len(train_examples)), "test_examples": int(len(test_examples)), "matched_examples": int(train_diag["matched_examples"] + test_diag["matched_examples"]), "unique_prefix_cutoffs": int(train_diag["unique_prefix_cutoffs"] + test_diag["unique_prefix_cutoffs"]), "graph_builds": int(train_diag["graph_builds"] + test_diag["graph_builds"]), "cache_hit_rate": float( ( max(0, train_diag["matched_examples"] - train_diag["graph_builds"]) + max(0, test_diag["matched_examples"] - test_diag["graph_builds"]) ) / max(1, train_diag["matched_examples"] + test_diag["matched_examples"]) ), "eval_time_sec": float(time.perf_counter() - eval_start), "train_unique_prefix_cutoffs": int(train_diag["unique_prefix_cutoffs"]), "test_unique_prefix_cutoffs": int(test_diag["unique_prefix_cutoffs"]), "train_graph_builds": int(train_diag["graph_builds"]), "test_graph_builds": int(test_diag["graph_builds"]), 
"train_eval_time_sec": float(train_diag["eval_time_sec"]), "test_eval_time_sec": float(test_diag["eval_time_sec"]), }, ) def compute_matched_seqgru_metrics( df_train: pd.DataFrame, df_test: pd.DataFrame, train_examples: pd.DataFrame, test_examples: pd.DataFrame, device: str, seed: int, max_epochs: int, hidden_dim: int = 96, receiver_buckets: int = 512, ) -> dict: if train_examples.empty or test_examples.empty: empty = make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) return { "clean": empty, "shuffled": empty, "delta": float("nan"), "clean_fit": {}, "shuffled_fit": {}, } clean_train_df = strip_oracle_cols(df_train) clean_test_df = strip_oracle_cols(df_test) y_test = test_examples["label"].to_numpy(dtype=np.float32) if train_examples["label"].nunique() < 2 or test_examples["label"].nunique() < 2: flat_probs = np.full(len(y_test), 0.5, dtype=np.float32) flat = make_auc_result(y_test, flat_probs, seed=seed) flat["pr_auc"] = compute_metrics(y_test, flat_probs)["pr_auc"] return { "clean": flat, "shuffled": flat, "delta": 0.0, "clean_fit": {}, "shuffled_fit": {}, } def build_model() -> SequenceGRUWrapper: return SequenceGRUWrapper( hidden_dim=hidden_dim, receiver_buckets=receiver_buckets, device=device, ) clean_seed = derived_seed(seed, "SeqGRU", "clean") shuffled_seed = derived_seed(seed, "SeqGRU", "shuffled") set_global_determinism(clean_seed) clean_model = build_model() clean_model.fit(clean_train_df, num_epochs=1) clean_fit = clean_model.fit_matched_prefix_examples( clean_train_df, train_examples, seed=clean_seed, max_epochs=max_epochs, patience=6, valid_frac=0.20, pair_batch_size=64, learning_rate=2e-3, weight_decay=1e-4, shuffle_within_sequence=False, ) clean_probs = clean_model.predict_matched_prefix_examples( clean_test_df, test_examples, seed=clean_seed, shuffle_within_sequence=False, ) clean_metrics = compute_metrics(y_test, clean_probs) clean_result = make_auc_result( y_test, clean_probs.astype(np.float32), seed=seed, 
extra={ "pr_auc": float(clean_metrics["pr_auc"]), "brier": float(clean_metrics["brier"]), "ece": float(clean_metrics["ece"]), }, ) set_global_determinism(shuffled_seed) shuffled_model = build_model() shuffled_model.fit(clean_train_df, num_epochs=1) shuffled_fit = shuffled_model.fit_matched_prefix_examples( clean_train_df, train_examples, seed=shuffled_seed, max_epochs=max_epochs, patience=6, valid_frac=0.20, pair_batch_size=64, learning_rate=2e-3, weight_decay=1e-4, shuffle_within_sequence=True, ) shuffled_probs = shuffled_model.predict_matched_prefix_examples( clean_test_df, test_examples, seed=shuffled_seed, shuffle_within_sequence=True, ) shuffled_metrics = compute_metrics(y_test, shuffled_probs) shuffled_result = make_auc_result( y_test, shuffled_probs.astype(np.float32), seed=seed, extra={ "pr_auc": float(shuffled_metrics["pr_auc"]), "brier": float(shuffled_metrics["brier"]), "ece": float(shuffled_metrics["ece"]), }, ) return { "clean": clean_result, "shuffled": shuffled_result, "delta": float(shuffled_result["auc"] - clean_result["auc"]), "clean_fit": clean_fit, "shuffled_fit": shuffled_fit, } def _combine_matched_examples( train_examples: pd.DataFrame, test_examples: pd.DataFrame, ) -> pd.DataFrame: tagged_train = train_examples.copy() tagged_train["example_split"] = "train" tagged_test = test_examples.copy() tagged_test["example_split"] = "test" return pd.concat([tagged_train, tagged_test], ignore_index=True) def _fit_embedding_probe( train_emb: np.ndarray, test_emb: np.ndarray, y_train: np.ndarray, y_test: np.ndarray, seed: int, ) -> dict: if len(y_train) == 0 or len(y_test) == 0: return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2: probs = np.full(len(y_test), 0.5, dtype=np.float32) metrics = compute_metrics(y_test, probs) return make_auc_result( y_test.astype(np.float32), probs, seed=seed, extra={ "pr_auc": float(metrics["pr_auc"]), "brier": 
float(metrics["brier"]), "ece": float(metrics["ece"]), }, ) mean = train_emb.mean(axis=0, keepdims=True) std = train_emb.std(axis=0, keepdims=True) + 1e-6 train_emb = (train_emb - mean) / std test_emb = (test_emb - mean) / std clf = LogisticRegression( max_iter=2000, class_weight="balanced", random_state=seed, solver="liblinear", ) clf.fit(train_emb, y_train.astype(np.int32)) probs = clf.predict_proba(test_emb)[:, 1].astype(np.float32) metrics = compute_metrics(y_test.astype(np.float32), probs) return make_auc_result( y_test.astype(np.float32), probs, seed=seed, extra={ "pr_auc": float(metrics["pr_auc"]), "brier": float(metrics["brier"]), "ece": float(metrics["ece"]), }, ) def compute_matched_temporal_gnn_metrics( model_name: str, model_builder, df_train: pd.DataFrame, df_test: pd.DataFrame, train_examples: pd.DataFrame, test_examples: pd.DataFrame, seed: int, num_epochs: int, ) -> dict: if train_examples.empty or test_examples.empty: empty = make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) return { "clean": empty, "shuffled": empty, "delta": float("nan"), } clean_train = strip_oracle_cols(df_train) clean_test = strip_oracle_cols(df_test) all_examples = _combine_matched_examples(train_examples, test_examples) train_mask = all_examples["example_split"].to_numpy() == "train" test_mask = ~train_mask y_train = all_examples.loc[train_mask, "label"].to_numpy(dtype=np.float32) y_test = all_examples.loc[test_mask, "label"].to_numpy(dtype=np.float32) clean_model_seed = derived_seed(seed, model_name, "clean_model") shuffled_model_seed = derived_seed(seed, model_name, "shuffled_model") set_global_determinism(clean_model_seed) clean_model = model_builder() clean_model.fit(clean_train, num_epochs=num_epochs) clean_full = ( pd.concat([clean_train, clean_test], ignore_index=True) .sort_values("timestamp") .reset_index(drop=True) ) clean_emb = clean_model.extract_prefix_embeddings(clean_full, all_examples) clean_result = 
_fit_embedding_probe( clean_emb[train_mask], clean_emb[test_mask], y_train, y_test, seed=seed, ) shuffled_train = shuffle_chronology(clean_train, seed=seed + 101) shuffled_test = shuffle_chronology(clean_test, seed=seed + 211) set_global_determinism(shuffled_model_seed) shuffled_model = model_builder() shuffled_model.fit(shuffled_train, num_epochs=num_epochs) shuffled_full = ( pd.concat([shuffled_train, shuffled_test], ignore_index=True) .sort_values("timestamp") .reset_index(drop=True) ) shuffled_emb = shuffled_model.extract_prefix_embeddings(shuffled_full, all_examples) shuffled_result = _fit_embedding_probe( shuffled_emb[train_mask], shuffled_emb[test_mask], y_train, y_test, seed=seed, ) return { "clean": clean_result, "shuffled": shuffled_result, "delta": float(shuffled_result["auc"] - clean_result["auc"]), "train_examples": int(train_mask.sum()), "test_examples": int(test_mask.sum()), "model_name": model_name, } def ks_distance(x: np.ndarray, y: np.ndarray) -> float: x = np.sort(np.asarray(x, dtype=np.float64)) y = np.sort(np.asarray(y, dtype=np.float64)) if len(x) == 0 or len(y) == 0: return 0.0 values = np.sort(np.concatenate([x, y])) cdf_x = np.searchsorted(x, values, side="right") / len(x) cdf_y = np.searchsorted(y, values, side="right") / len(y) return float(np.max(np.abs(cdf_x - cdf_y))) def compute_static_aggregate_auc(node_df: pd.DataFrame, seed: int, verbose: bool = True) -> float: feature_cols = [ "txn_count", "receiver_count", "retry_count", "failed_count", "burst_count", "quiet_count", "dt_mean", "dt_std", "amount_mean", "amount_std", "phase_std", "recv_entropy", ] audit_df = node_df[node_df["twin_pair_id"] >= 0].copy() if audit_df.empty or audit_df["label"].nunique() < 2: return 0.5 pair_ids = audit_df["twin_pair_id"].unique() if len(pair_ids) < 4: return 0.5 rng = np.random.default_rng(seed) pair_ids = rng.permutation(pair_ids) split = max(1, int(0.7 * len(pair_ids))) train_ids = set(pair_ids[:split]) test_ids = set(pair_ids[split:]) if not 
test_ids: test_ids = set(pair_ids[-1:]) train_ids = set(pair_ids[:-1]) train_df = audit_df[audit_df["twin_pair_id"].isin(train_ids)] test_df = audit_df[audit_df["twin_pair_id"].isin(test_ids)] if train_df["label"].nunique() < 2 or test_df["label"].nunique() < 2: return 0.5 x_train = train_df[feature_cols].to_numpy(dtype=np.float32) x_test = test_df[feature_cols].to_numpy(dtype=np.float32) mean = x_train.mean(axis=0, keepdims=True) std = x_train.std(axis=0, keepdims=True) + 1e-6 x_train = (x_train - mean) / std x_test = (x_test - mean) / std clf = LogisticRegression( max_iter=2000, class_weight="balanced", random_state=seed, solver="liblinear", ) clf.fit(x_train, train_df["label"].to_numpy(dtype=np.int32)) probs = clf.predict_proba(x_test)[:, 1] auc = safe_roc_auc(test_df["label"].to_numpy(dtype=np.float32), probs.astype(np.float32)) if verbose: # Top predictors by absolute coefficient coefs = np.abs(clf.coef_[0]) ranked = np.argsort(coefs)[::-1] print("\n Top static aggregate predictors:") for rank_i in ranked[:5]: print(f" {feature_cols[rank_i]:<20}: |coef|={coefs[rank_i]:.4f}") return auc def compute_aggregate_ks(node_df: pd.DataFrame) -> tuple[float, float]: fraud_df = node_df[(node_df["twin_pair_id"] >= 0) & (node_df["label"] == 1)] benign_df = node_df[(node_df["twin_pair_id"] >= 0) & (node_df["label"] == 0)] if fraud_df.empty or benign_df.empty: return 0.0, 0.0 feature_cols = [ "txn_count", "receiver_count", "retry_count", "burst_count", "dt_mean", "dt_std", "recv_entropy", ] distances = [ ks_distance(fraud_df[col].to_numpy(), benign_df[col].to_numpy()) for col in feature_cols ] if not distances: return 0.0, 0.0 return float(np.mean(distances)), float(np.max(distances)) def evaluate_matched_pair_separability( model: TemporalModel, df_train: pd.DataFrame, df_test: pd.DataFrame, delta_time: float, n_checkpoints: int, ) -> tuple[float, int]: if "twin_pair_id" not in df_test.columns or "twin_label" not in df_test.columns: return 0.0, 0 checkpoints = 
make_checkpoints(df_test, delta_time, n_checkpoints=n_checkpoints) if not checkpoints: return 0.0, 0 cutoff_time = checkpoints[-1] df_full = ( pd.concat([df_train, df_test], ignore_index=True) .sort_values("timestamp") .reset_index(drop=True) ) prefix_df = df_full[df_full["timestamp"] <= cutoff_time].copy() active_nodes = sorted(df_test[df_test["timestamp"] <= cutoff_time]["sender_id"].unique()) if not active_nodes: return 0.0, 0 if model.is_temporal: model.reset_memory() probs = model.predict(prefix_df, active_nodes) score_map = {int(node_id): float(prob) for node_id, prob in zip(active_nodes, probs)} meta = ( df_full.groupby("sender_id")[["twin_pair_id", "twin_label"]] .first() .reset_index() ) meta = meta[(meta["sender_id"].isin(active_nodes)) & (meta["twin_pair_id"] >= 0)] pair_scores = [] for _, pair_df in meta.groupby("twin_pair_id"): if len(pair_df) != 2 or set(pair_df["twin_label"]) != {0, 1}: continue fraud_node = int(pair_df.loc[pair_df["twin_label"] == 1, "sender_id"].iloc[0]) benign_node = int(pair_df.loc[pair_df["twin_label"] == 0, "sender_id"].iloc[0]) if fraud_node not in score_map or benign_node not in score_map: continue pair_scores.append(float(score_map[fraud_node] > score_map[benign_node])) if not pair_scores: return 0.0, 0 return float(np.mean(pair_scores)), int(len(pair_scores)) def compute_split_leakage(df_train: pd.DataFrame, df_test: pd.DataFrame) -> dict: train_users = set(df_train["sender_id"].unique().tolist()) test_users = set(df_test["sender_id"].unique().tolist()) leakage = { "sender_overlap_count": int(len(train_users & test_users)), "pair_overlap_count": 0, "template_overlap_count": 0, "receiver_pair_overlap_count": 0, } if "twin_pair_id" in df_train.columns and "twin_pair_id" in df_test.columns: train_pairs = set(df_train.loc[df_train["twin_pair_id"] >= 0, "twin_pair_id"].unique().tolist()) test_pairs = set(df_test.loc[df_test["twin_pair_id"] >= 0, "twin_pair_id"].unique().tolist()) leakage["pair_overlap_count"] = 
int(len(train_pairs & test_pairs)) if "template_id" in df_train.columns and "template_id" in df_test.columns: train_templates = set(df_train.loc[df_train["template_id"] >= 0, "template_id"].unique().tolist()) test_templates = set(df_test.loc[df_test["template_id"] >= 0, "template_id"].unique().tolist()) leakage["template_overlap_count"] = int(len(train_templates & test_templates)) # Receiver-pair overlap: distinct (sender_id, receiver_id) tuples train_rpairs = set(zip( df_train["sender_id"].tolist(), df_train["receiver_id"].tolist() )) test_rpairs = set(zip( df_test["sender_id"].tolist(), df_test["receiver_id"].tolist() )) leakage["receiver_pair_overlap_count"] = int(len(train_rpairs & test_rpairs)) return leakage # --------------------------------------------------------------------------- # Prefix-only evaluation guard # --------------------------------------------------------------------------- def assert_prefix_only(df_prefix: pd.DataFrame, cutoff_time: float) -> None: """Warn if any future event slipped into the prefix. Uses 1.0s tolerance to absorb float32-vs-float64 precision gaps. """ if df_prefix.empty: return actual_max = float(df_prefix["timestamp"].max()) if actual_max > cutoff_time + 1.0: print( f"[PREFIX LEAK] df_prefix max timestamp {actual_max:.2f} > cutoff {cutoff_time:.2f}!" ) # --------------------------------------------------------------------------- # Label-source audit # --------------------------------------------------------------------------- def build_label_source_audit_table(df: pd.DataFrame) -> pd.DataFrame: """Return per-positive-event audit table. 
Required audit columns (populated by FraudEngine): fraud_source, motif_source, motif_hit_count, trigger_event_idx, label_event_idx, label_delay, is_fallback_label """ fraud_rows = df[df["is_fraud"] == 1].copy() if fraud_rows.empty: return pd.DataFrame() audit_cols = [ "sender_id", "twin_pair_id", "twin_role", "fraud_source", "motif_source", "motif_hit_count", "trigger_event_idx", "label_event_idx", "label_delay", "is_fallback_label", ] available = [c for c in audit_cols if c in fraud_rows.columns] return fraud_rows[available].reset_index(drop=True) def compute_motif_label_consistency(df: pd.DataFrame, calib_mode: bool = False) -> dict: """Compute and print motif/label consistency statistics.""" has_motif = "motif_hit_count" in df.columns has_fraud = "is_fraud" in df.columns if not (has_motif and has_fraud): return {} # Restrict to twin users only if "twin_pair_id" in df.columns: twin_df = df[df["twin_pair_id"] >= 0].copy() else: twin_df = df.copy() if twin_df.empty: return {} # Node-level aggregation node_grp = twin_df.groupby("sender_id") node_label = node_grp["is_fraud"].max() node_hit = node_grp["motif_hit_count"].max() node_role = node_grp["twin_role"].first() if "twin_role" in twin_df.columns else None has_hit = (node_hit >= 1) label_pos = (node_label == 1) p_label_given_hit = float(label_pos[has_hit].mean()) if has_hit.any() else float("nan") p_label_given_nohit = float(label_pos[~has_hit].mean()) if (~has_hit).any() else float("nan") p_hit_given_label = float(has_hit[label_pos].mean()) if label_pos.any() else float("nan") if node_role is not None: benign_mask = (node_role == "benign") accidental_motif_rate = float(has_hit[benign_mask].mean()) if benign_mask.any() else float("nan") avg_hits_fraud = float(node_hit[~benign_mask & label_pos].mean()) if (label_pos & ~benign_mask).any() else float("nan") avg_hits_benign = float(node_hit[benign_mask].mean()) if benign_mask.any() else float("nan") else: accidental_motif_rate = float("nan") avg_hits_fraud = 
float("nan") avg_hits_benign = float("nan") result = { "p_label_given_hit": p_label_given_hit, "p_label_given_nohit": p_label_given_nohit, "p_hit_given_label": p_hit_given_label, "accidental_benign_motif": accidental_motif_rate, "avg_hits_fraud_twin": avg_hits_fraud, "avg_hits_benign_twin": avg_hits_benign, } print("\n--- Motif-Label Consistency ---") for k, v in result.items(): print(f" {k:<30}: {v:.4f}" if not (isinstance(v, float) and v != v) else f" {k:<30}: N/A") if calib_mode: # In calib mode, verify no fallback positives exist if "is_fallback_label" in df.columns: fallback_pos = int(df.loc[df["is_fraud"] == 1, "is_fallback_label"].sum()) print(f" {'fallback_positives':<30}: {fallback_pos}") if fallback_pos > 0: print(" [CALIB VIOLATION] Fallback positives found! is_fallback_label.sum() must be 0.") result["fallback_positives"] = fallback_pos return result def compute_label_delay_stats(df: pd.DataFrame) -> dict: """Print and return min/mean/max label_delay for positive events.""" if "label_delay" not in df.columns: return {} delays = df.loc[(df["is_fraud"] == 1) & (df["label_delay"] >= 0), "label_delay"] if delays.empty: print(" label_delay: no valid delay data.") return {"delay_min": float("nan"), "delay_mean": float("nan"), "delay_max": float("nan")} result = { "delay_min": float(delays.min()), "delay_mean": float(delays.mean()), "delay_max": float(delays.max()), } print(f" label_delay min={result['delay_min']:.1f} mean={result['delay_mean']:.1f} max={result['delay_max']:.1f}") return result # --------------------------------------------------------------------------- # Prefix-task helpers # --------------------------------------------------------------------------- def uses_twin_pairs(df: pd.DataFrame) -> bool: return "twin_pair_id" in df.columns and bool((df["twin_pair_id"] >= 0).any()) def get_eval_nodes(df: pd.DataFrame) -> List[int]: if uses_twin_pairs(df): pair_df = df[df["twin_pair_id"] >= 0] return sorted(pair_df["sender_id"].unique().tolist()) 
return sorted(df["sender_id"].unique().tolist()) def remap_node_ids(*dfs: pd.DataFrame) -> list[pd.DataFrame]: non_empty = [df for df in dfs if df is not None and not df.empty] if not non_empty: return [df.copy() for df in dfs] all_ids = np.unique( np.concatenate( [ np.concatenate( [ df["sender_id"].to_numpy(dtype=np.int64), df["receiver_id"].to_numpy(dtype=np.int64), ] ) for df in non_empty ] ) ) id_map = {int(node_id): idx for idx, node_id in enumerate(all_ids.tolist())} remapped = [] for df in dfs: if df is None: remapped.append(df) continue out = df.copy() out["sender_id"] = out["sender_id"].map(id_map).astype(np.int64) out["receiver_id"] = out["receiver_id"].map(id_map).astype(np.int64) remapped.append(out) return remapped def augment_with_placeholder_nodes(df_train: pd.DataFrame, df_test: pd.DataFrame) -> pd.DataFrame: train_nodes = set( np.concatenate( [ df_train["sender_id"].to_numpy(dtype=np.int64), df_train["receiver_id"].to_numpy(dtype=np.int64), ] ).tolist() ) test_nodes = set( np.concatenate( [ df_test["sender_id"].to_numpy(dtype=np.int64), df_test["receiver_id"].to_numpy(dtype=np.int64), ] ).tolist() ) unseen_nodes = sorted(test_nodes - train_nodes) if not unseen_nodes: return df_train base_time = float(min(df_train["timestamp"].min(), df_test["timestamp"].min())) - 1.0 rows = [] for offset, node_id in enumerate(unseen_nodes): row = {} for col in df_train.columns: if col in {"sender_id", "receiver_id"}: row[col] = int(node_id) elif col == "timestamp": row[col] = base_time - offset elif col in {"fraud_type", "twin_role"}: row[col] = "placeholder" elif col in {"txn_id", "twin_pair_id", "template_id"}: row[col] = -1 elif col in {"twin_label", "is_fraud", "is_retry", "failed"}: row[col] = 0 else: row[col] = 0.0 rows.append(row) placeholder_df = pd.DataFrame(rows, columns=df_train.columns) out = pd.concat([placeholder_df, df_train], ignore_index=True) return out.sort_values("timestamp").reset_index(drop=True) def split_temporally(df: pd.DataFrame, 
train_ratio: float = 0.7) -> tuple[pd.DataFrame, pd.DataFrame, float]: df = df.sort_values("timestamp").reset_index(drop=True) if uses_twin_pairs(df): pair_meta = ( df[df["twin_pair_id"] >= 0] .groupby("twin_pair_id")["timestamp"] .min() .sort_values() ) if len(pair_meta) >= 2: split_idx = max(1, min(len(pair_meta) - 1, int(train_ratio * len(pair_meta)))) train_ids = set(pair_meta.index[:split_idx].tolist()) test_ids = set(pair_meta.index[split_idx:].tolist()) df_train = df[(df["twin_pair_id"] < 0) | (df["twin_pair_id"].isin(train_ids))].copy() df_test = df[df["twin_pair_id"].isin(test_ids)].copy() split_time = float(df_test["timestamp"].min()) if not df_test.empty else float(df["timestamp"].quantile(train_ratio)) return df_train.sort_values("timestamp").reset_index(drop=True), df_test.sort_values("timestamp").reset_index(drop=True), split_time split_time = float(df["timestamp"].quantile(train_ratio)) df_train = df[df["timestamp"] <= split_time].copy() df_test = df[df["timestamp"] > split_time].copy() return df_train, df_test, split_time def horizon_to_delta(df_test: pd.DataFrame, horizon: float) -> float: if df_test.empty: return 1e-6 t_min = float(df_test["timestamp"].min()) t_max = float(df_test["timestamp"].max()) return max(1e-6, horizon * max(t_max - t_min, 1e-6)) def build_window_labels( df: pd.DataFrame, cutoff_time: float, eval_nodes: Sequence[int], delta_time: float, ) -> np.ndarray: future = df[(df["timestamp"] > cutoff_time) & (df["timestamp"] <= cutoff_time + delta_time)] fraud_map = future.groupby("sender_id")["is_fraud"].max() return np.array([int(fraud_map.get(node_id, 0)) for node_id in eval_nodes], dtype=np.float32) def build_window_state( df: pd.DataFrame, cutoff_time: float, eval_nodes: Sequence[int], delta_time: float, ) -> np.ndarray: future = df[(df["timestamp"] > cutoff_time) & (df["timestamp"] <= cutoff_time + delta_time)] if "dynamic_fraud_state" in future.columns: state_map = future.groupby("sender_id")["dynamic_fraud_state"].mean() else: 
state_map = future.groupby("sender_id")["is_fraud"].mean() return np.array([float(state_map.get(node_id, 0.0)) for node_id in eval_nodes], dtype=np.float32) def choose_anchor_time(df_train: pd.DataFrame, delta_time: float) -> float: t_min = float(df_train["timestamp"].min()) max_anchor = float(df_train["timestamp"].max()) - delta_time if max_anchor <= t_min: return t_min candidate_quantiles = [0.80, 0.75, 0.70, 0.65, 0.60, 0.55] for quantile in candidate_quantiles: anchor_time = min(float(df_train["timestamp"].quantile(quantile)), max_anchor) prefix_nodes = get_eval_nodes(df_train[df_train["timestamp"] <= anchor_time]) if not prefix_nodes: continue y_anchor = build_window_labels(df_train, anchor_time, prefix_nodes, delta_time) if len(np.unique(y_anchor)) >= 2: return anchor_time return min(float(df_train["timestamp"].quantile(0.80)), max_anchor) def make_checkpoints(df_test: pd.DataFrame, delta_time: float, n_checkpoints: int) -> List[float]: if df_test.empty: return [] t_max = float(df_test["timestamp"].max()) valid = df_test[df_test["timestamp"] <= t_max - delta_time].sort_values("timestamp") if valid.empty: return [] timestamps = valid["timestamp"].to_numpy(dtype=np.float64) idx = np.unique( np.linspace(0, len(timestamps) - 1, num=min(n_checkpoints, len(timestamps)), dtype=int) ) checkpoints = [float(timestamps[i]) for i in idx] return sorted(set(checkpoints)) def train_node_head( model: TemporalModel, df_anchor_prefix: pd.DataFrame, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150, ) -> None: if hasattr(model, "train_node_classifier_on_prefix"): model.train_node_classifier_on_prefix( df_anchor_prefix, eval_nodes, y_labels, num_epochs=num_epochs ) return if model.is_temporal: model.reset_memory() if len(df_anchor_prefix) > 0 and len(eval_nodes) > 0: model.predict(df_anchor_prefix, eval_nodes) if hasattr(model, "train_node_classifier"): model.train_node_classifier(eval_nodes, y_labels, num_epochs=num_epochs) if isinstance(model, TGNWrapper): 
assert model._node_head_fitted, "TGN node classifier was not fitted." return raise ValueError(f"Model {model.name} does not expose node-head training.") def fit_model_for_horizon( model: TemporalModel, df_train: pd.DataFrame, delta_time: float, num_epochs: int, node_epochs: int, ) -> dict: # Strip oracle columns from all non-oracle models train_input = df_train if model.name in _ORACLE_MODEL_NAMES else strip_oracle_cols(df_train) model.fit(train_input, num_epochs=num_epochs) anchor_time = choose_anchor_time(train_input, delta_time) df_anchor_prefix = train_input[train_input["timestamp"] <= anchor_time].copy() assert_prefix_only(df_anchor_prefix, anchor_time) anchor_nodes = get_eval_nodes(df_anchor_prefix) y_anchor = build_window_labels(train_input, anchor_time, anchor_nodes, delta_time) train_node_head( model, df_anchor_prefix=df_anchor_prefix, eval_nodes=anchor_nodes, y_labels=y_anchor, num_epochs=node_epochs, ) return { "anchor_time": anchor_time, "anchor_nodes": len(anchor_nodes), "anchor_fraud_rate": float(y_anchor.mean()) if len(y_anchor) else 0.0, } def collect_prefix_predictions( model: TemporalModel, df_train: pd.DataFrame, df_test: pd.DataFrame, delta_time: float, n_checkpoints: int, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: checkpoints = make_checkpoints(df_test, delta_time, n_checkpoints=n_checkpoints) if not checkpoints: return ( np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), ) df_full = ( pd.concat([df_train, df_test], ignore_index=True) .sort_values("timestamp") .reset_index(drop=True) ) y_chunks: List[np.ndarray] = [] p_chunks: List[np.ndarray] = [] s_chunks: List[np.ndarray] = [] is_oracle = model.name in _ORACLE_MODEL_NAMES for cutoff_time in checkpoints: active_nodes = get_eval_nodes(df_test[df_test["timestamp"] <= cutoff_time]) if not active_nodes: continue prefix_df = df_full[df_full["timestamp"] <= cutoff_time].copy() assert_prefix_only(prefix_df, cutoff_time) eval_df = prefix_df if is_oracle 
else strip_oracle_cols(prefix_df) if model.is_temporal: model.reset_memory() probs = model.predict(eval_df, active_nodes) y_true = build_window_labels(df_full, cutoff_time, active_nodes, delta_time) state = build_window_state(df_full, cutoff_time, active_nodes, delta_time) y_chunks.append(y_true) p_chunks.append(np.asarray(probs, dtype=np.float32)) s_chunks.append(state) if not y_chunks: return ( np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), ) return ( np.concatenate(y_chunks).astype(np.float32), np.concatenate(p_chunks).astype(np.float32), np.concatenate(s_chunks).astype(np.float32), ) def evaluate_model( model: TemporalModel, df_train: pd.DataFrame, df_test: pd.DataFrame, delta_time: float, n_checkpoints: int, ) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray]: y_true, probs, states = collect_prefix_predictions( model=model, df_train=df_train, df_test=df_test, delta_time=delta_time, n_checkpoints=n_checkpoints, ) metrics = compute_metrics(y_true, probs) if len(y_true) else compute_metrics(np.array([0.0]), np.array([0.5])) metrics["n_predictions"] = int(len(y_true)) return metrics, y_true, probs, states def shuffle_chronology(df: pd.DataFrame, seed: int) -> pd.DataFrame: """Break temporal order while preserving the event table.""" rng = np.random.default_rng(seed) shuffled = df.copy() shuffled["timestamp"] = rng.permutation(shuffled["timestamp"].to_numpy(dtype=np.float64)) return shuffled.sort_values("timestamp").reset_index(drop=True) # --------------------------------------------------------------------------- # Experiments (single seed) # --------------------------------------------------------------------------- def run_ood_single( df_easy: pd.DataFrame, df_medium: pd.DataFrame, df_hard: pd.DataFrame, device: str, num_epochs: int, node_epochs: int, n_checkpoints: int, horizon: float = 0.10, ) -> pd.DataFrame: df_train = ( pd.concat([df_easy, df_medium], ignore_index=True) .sort_values("timestamp") 
.reset_index(drop=True) ) df_test = df_hard.sort_values("timestamp").reset_index(drop=True) df_train, df_test = remap_node_ids(df_train, df_test) df_train = augment_with_placeholder_nodes(df_train, df_test) delta_time = horizon_to_delta(df_test, horizon) rows = [] models = build_models(device=device) for model_name in MODEL_ORDER: model = models[model_name] fit_info = fit_model_for_horizon(model, df_train, delta_time, num_epochs, node_epochs) metrics, _, _, _ = evaluate_model(model, df_train, df_test, delta_time, n_checkpoints) rows.append({ "model": model_name, **metrics, **fit_info, }) df_out = pd.DataFrame(rows) xgb_roc = float(df_out.loc[df_out["model"] == "XGBoost", "roc_auc"].iloc[0]) df_out["gap_vs_xgb"] = df_out["roc_auc"] - xgb_roc return df_out def run_causal_single( df_hard: pd.DataFrame, device: str, num_epochs: int, node_epochs: int, n_checkpoints: int, seed: int, horizon: float = 0.10, ) -> pd.DataFrame: df_clean = df_hard.sort_values("timestamp").reset_index(drop=True) df_shuffled = shuffle_chronology(df_clean, seed=seed + 17) df_train_clean, df_test_clean, _ = split_temporally(df_clean) df_train_shuf, df_test_shuf, _ = split_temporally(df_shuffled) df_train_clean, df_test_clean = remap_node_ids(df_train_clean, df_test_clean) df_train_shuf, df_test_shuf = remap_node_ids(df_train_shuf, df_test_shuf) df_train_clean = augment_with_placeholder_nodes(df_train_clean, df_test_clean) df_train_shuf = augment_with_placeholder_nodes(df_train_shuf, df_test_shuf) delta_time_clean = horizon_to_delta(df_test_clean, horizon) delta_time_shuf = horizon_to_delta(df_test_shuf, horizon) rows = [] clean_models = build_models(device=device) shuffled_models = build_models(device=device) for model_name in MODEL_ORDER: clean_model = clean_models[model_name] shuffled_model = shuffled_models[model_name] fit_model_for_horizon(clean_model, df_train_clean, delta_time_clean, num_epochs, node_epochs) clean_metrics, _, _, _ = evaluate_model( clean_model, df_train_clean, 
df_test_clean, delta_time_clean, n_checkpoints ) fit_model_for_horizon(shuffled_model, df_train_shuf, delta_time_shuf, num_epochs, node_epochs) shuffled_metrics, _, _, _ = evaluate_model( shuffled_model, df_train_shuf, df_test_shuf, delta_time_shuf, n_checkpoints ) rows.append({ "model": model_name, "roc_auc_clean": clean_metrics["roc_auc"], "pr_auc_clean": clean_metrics["pr_auc"], "brier_clean": clean_metrics["brier"], "ece_clean": clean_metrics["ece"], "roc_auc_shuffled": shuffled_metrics["roc_auc"], "pr_auc_shuffled": shuffled_metrics["pr_auc"], "brier_shuffled": shuffled_metrics["brier"], "ece_shuffled": shuffled_metrics["ece"], "delta": shuffled_metrics["roc_auc"] - clean_metrics["roc_auc"], }) return pd.DataFrame(rows) def run_horizon_single( df_medium: pd.DataFrame, device: str, num_epochs: int, node_epochs: int, n_checkpoints: int, horizons: Sequence[float], ) -> pd.DataFrame: df_train, df_test, _ = split_temporally(df_medium) df_train, df_test = remap_node_ids(df_train, df_test) df_train = augment_with_placeholder_nodes(df_train, df_test) rows = [] for horizon in horizons: delta_time = horizon_to_delta(df_test, horizon) models = build_models(device=device) for model_name in MODEL_ORDER: model = models[model_name] fit_model_for_horizon(model, df_train, delta_time, num_epochs, node_epochs) metrics, _, _, _ = evaluate_model(model, df_train, df_test, delta_time, n_checkpoints) rows.append({ "horizon": float(horizon), "model": model_name, **metrics, }) return pd.DataFrame(rows) def run_mechanistic_single( df_hard: pd.DataFrame, device: str, num_epochs: int, node_epochs: int, n_checkpoints: int, horizon: float = 0.10, ) -> pd.DataFrame: df_train, df_test, _ = split_temporally(df_hard) df_train, df_test = remap_node_ids(df_train, df_test) df_train = augment_with_placeholder_nodes(df_train, df_test) delta_time = horizon_to_delta(df_test, horizon) rows = [] models = build_models(device=device) for model_name in MODEL_ORDER: model = models[model_name] 
fit_model_for_horizon(model, df_train, delta_time, num_epochs, node_epochs) _, _, probs, states = evaluate_model(model, df_train, df_test, delta_time, n_checkpoints) rows.append({ "model": model_name, "pearson_r": safe_pearson(states, probs), }) return pd.DataFrame(rows) def run_audit_single( df_hard: pd.DataFrame, device: str, num_epochs: int, node_epochs: int, n_checkpoints: int, seed: int, horizon: float = 0.10, benchmark_mode: str = "temporal_twins", ) -> pd.DataFrame: node_audit = build_node_audit_table(df_hard) ks_mean, ks_max = compute_aggregate_ks(node_audit) paired_pairs = int(node_audit.loc[node_audit["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) paired_nodes = int((node_audit["twin_pair_id"] >= 0).sum()) # --- Label-source audit --- calib_mode = benchmark_mode == "temporal_twins_oracle_calib" print("\n--- Label-Source Audit ---") audit_tbl = build_label_source_audit_table(df_hard) if not audit_tbl.empty: print(audit_tbl.to_string(index=False, max_rows=20)) consistency = compute_motif_label_consistency(df_hard, calib_mode=calib_mode) compute_label_delay_stats(df_hard) df_train, df_test, _ = split_temporally(df_hard) leakage = compute_split_leakage(df_train, df_test) # --- Split integrity report --- print("\n--- Split Integrity ---") for k, v in leakage.items(): status = "[OK]" if v == 0 else "[WARN]" print(f" {status} {k}: {v}") df_train, df_test = remap_node_ids(df_train, df_test) train_examples, train_pair_rows, train_pair_counts = build_matched_control_tables(df_train) test_examples, test_pair_rows, test_pair_counts = build_matched_control_tables(df_test) matched_train_features = build_matched_prefix_feature_table(df_train, train_examples) matched_test_features = build_matched_prefix_feature_table(df_test, test_examples) matched_audit = report_matched_control_audits( test_examples=test_examples, test_pair_rows=test_pair_rows, test_pair_counts=test_pair_counts, ) static_agg_result = compute_matched_static_aggregate_auc( matched_train_features, 
matched_test_features, seed=seed, verbose=True, ) xgb_result = compute_matched_xgboost_auc( matched_train_features, matched_test_features, seed=seed, ) static_gnn_result = compute_matched_static_gnn_auc( df_train=df_train, df_test=df_test, train_examples=train_examples, test_examples=test_examples, device=device, num_epochs=num_epochs, seed=seed, ) df_train_eval = augment_with_placeholder_nodes(df_train, df_test) delta_time = horizon_to_delta(df_test, horizon) models = build_models(device=device) rows = [] for model_name in MODEL_ORDER: model = models[model_name] fit_model_for_horizon(model, df_train_eval, delta_time, num_epochs, node_epochs) metrics, _, probs, states = evaluate_model(model, df_train_eval, df_test, delta_time, n_checkpoints) pair_sep, eval_pairs = evaluate_matched_pair_separability( model, df_train=df_train_eval, df_test=df_test, delta_time=delta_time, n_checkpoints=n_checkpoints, ) matched_control_roc_auc = float("nan") if model_name == "XGBoost": matched_control_roc_auc = float(xgb_result["auc"]) elif model_name == "StaticGNN": matched_control_roc_auc = float(static_gnn_result["auc"]) rows.append({ "model": model_name, **metrics, "pearson_r": safe_pearson(states, probs), "matched_pair_sep": pair_sep, "matched_pair_eval_pairs": eval_pairs, "matched_control_roc_auc": matched_control_roc_auc, "static_agg_auc": float(static_agg_result["auc"]), "static_agg_auc_bootstrap_std": float(static_agg_result["bootstrap_std"]), "xgb_auc_bootstrap_std": float(xgb_result["bootstrap_std"]), "static_gnn_auc_bootstrap_std": float(static_gnn_result["bootstrap_std"]), "ks_mean": ks_mean, "ks_max": ks_max, "paired_pairs": paired_pairs, "paired_nodes": paired_nodes, **leakage, **matched_audit, }) return pd.DataFrame(rows) # --------------------------------------------------------------------------- # Aggregation / plotting outputs # --------------------------------------------------------------------------- def summarise_mean_std(df: pd.DataFrame, group_cols: 
Sequence[str], value_cols: Sequence[str]) -> pd.DataFrame: summary = df.groupby(list(group_cols)).agg({ value_col: ["mean", "std"] for value_col in value_cols }) summary.columns = [ f"{value_col}_{stat}" for value_col, stat in summary.columns.to_flat_index() ] summary = summary.reset_index() return summary.fillna(0.0) def save_experiment_outputs( raw_frames: Dict[str, List[pd.DataFrame]], results_dir: str, ) -> None: os.makedirs(results_dir, exist_ok=True) raw_causal = pd.concat(raw_frames["causal"], ignore_index=True) if raw_frames["causal"] else None if raw_frames["ood"]: raw_ood = pd.concat(raw_frames["ood"], ignore_index=True) raw_ood.to_csv(os.path.join(results_dir, "ood_raw.csv"), index=False) ood_summary = summarise_mean_std( raw_ood, group_cols=["model"], value_cols=["roc_auc", "pr_auc", "brier", "ece", "gap_vs_xgb"], ) ood_summary.to_csv(os.path.join(results_dir, "ood.csv"), index=False) if raw_frames["causal"]: assert raw_causal is not None raw_causal.to_csv(os.path.join(results_dir, "causal_raw.csv"), index=False) causal_summary = summarise_mean_std( raw_causal, group_cols=["model"], value_cols=[ "roc_auc_clean", "pr_auc_clean", "brier_clean", "ece_clean", "roc_auc_shuffled", "pr_auc_shuffled", "brier_shuffled", "ece_shuffled", "delta", ], ) causal_summary.to_csv(os.path.join(results_dir, "causal.csv"), index=False) if raw_frames["horizon"]: raw_horizon = pd.concat(raw_frames["horizon"], ignore_index=True) raw_horizon.to_csv(os.path.join(results_dir, "horizon_raw.csv"), index=False) horizon_summary = summarise_mean_std( raw_horizon, group_cols=["horizon", "model"], value_cols=["roc_auc", "pr_auc", "brier", "ece"], ) horizon_summary.to_csv(os.path.join(results_dir, "horizon.csv"), index=False) if raw_frames["mechanistic"]: raw_mech = pd.concat(raw_frames["mechanistic"], ignore_index=True) raw_mech.to_csv(os.path.join(results_dir, "mechanistic_raw.csv"), index=False) mech_summary = summarise_mean_std( raw_mech, group_cols=["model"], 
value_cols=["pearson_r"], ) mech_summary.to_csv(os.path.join(results_dir, "mechanistic.csv"), index=False) if raw_frames.get("audit"): raw_audit = pd.concat(raw_frames["audit"], ignore_index=True) raw_audit.to_csv(os.path.join(results_dir, "audit_raw.csv"), index=False) audit_summary = summarise_mean_std( raw_audit, group_cols=["model"], value_cols=[ "roc_auc", "pr_auc", "brier", "ece", "pearson_r", "matched_pair_sep", "matched_pair_eval_pairs", "matched_control_roc_auc", "static_agg_auc", "ks_mean", "ks_max", "paired_pairs", "paired_nodes", "sender_overlap_count", "pair_overlap_count", "template_overlap_count", "pair_total_txn_count_diff_mean", "pair_total_txn_count_diff_max", "auc_total_txn_count", "auc_local_event_idx", "auc_prefix_txn_count", "auc_timestamp", "auc_account_age", "auc_active_age", "fraud_label_event_idx_mean", "fraud_label_event_idx_max", "benign_eval_event_idx_mean", "benign_eval_event_idx_max", "pair_event_idx_diff_mean", "pair_event_idx_diff_max", "pair_active_age_diff_mean", "pair_active_age_diff_max", "pair_timestamp_diff_mean", "pair_timestamp_diff_max", "benign_motif_hit_rate", "benign_motif_hit_pairs", "matched_control_examples", "matched_control_pair_events", ], ) if raw_causal is not None: causal_delta = summarise_mean_std( raw_causal, group_cols=["model"], value_cols=["delta"], )[["model", "delta_mean", "delta_std"]] audit_summary = audit_summary.merge(causal_delta, on="model", how="left") audit_summary[["delta_mean", "delta_std"]] = audit_summary[ ["delta_mean", "delta_std"] ].fillna(0.0) audit_summary.to_csv(os.path.join(results_dir, "audit.csv"), index=False) # --------------------------------------------------------------------------- # Node-level oracle evaluation helpers (twin_label, not window label) # --------------------------------------------------------------------------- def _twin_labels_for_nodes(df_full: pd.DataFrame, nodes: List[int]) -> np.ndarray: """Return twin_label (1=fraud twin, 0=benign) per node. 
Falls back to is_fraud if twin_label is absent.""" col = "twin_label" if "twin_label" in df_full.columns else "is_fraud" label_series = df_full.groupby("sender_id")[col].max() return np.array([float(label_series.get(n, 0.0)) for n in nodes], dtype=np.float32) def evaluate_oracle_node_level( model: TemporalModel, df_full: pd.DataFrame, eval_nodes: List[int], ) -> float: """ROC-AUC of oracle scored against twin_label (user-level, not window-level). For oracle-type models we pass the FULL df (with audit columns). For AuditOracle, predict() directly reads motif_hit_count — no training. For RawMotifOracle, train_node_classifier_on_prefix must be called first with twin_labels so it learns the node-level task. """ if not eval_nodes: return float("nan") y_true = _twin_labels_for_nodes(df_full, eval_nodes) probs = model.predict(df_full, eval_nodes) return safe_roc_auc(y_true, probs.astype(np.float32)) def evaluate_oracle_pair_sep_node_level( model: TemporalModel, df_full: pd.DataFrame, eval_nodes: List[int], ) -> float: """Matched-pair separability: P(score_fraud > score_benign) using twin_label.""" if not eval_nodes or "twin_pair_id" not in df_full.columns: return float("nan") probs = model.predict(df_full, eval_nodes) score_map = {n: float(p) for n, p in zip(eval_nodes, probs)} meta = ( df_full[df_full["sender_id"].isin(eval_nodes) & (df_full["twin_pair_id"] >= 0)] .groupby("sender_id") .agg(twin_pair_id=("twin_pair_id", "first"), twin_label=("twin_label", "max")) .reset_index() ) pair_results: List[float] = [] for _, grp in meta.groupby("twin_pair_id"): if len(grp) != 2 or set(grp["twin_label"]) != {0, 1}: continue fraud_node = int(grp.loc[grp["twin_label"] == 1, "sender_id"].iloc[0]) benign_node = int(grp.loc[grp["twin_label"] == 0, "sender_id"].iloc[0]) if fraud_node in score_map and benign_node in score_map: pair_results.append(float(score_map[fraud_node] > score_map[benign_node])) return float(np.mean(pair_results)) if pair_results else float("nan") def 
build_oracle_debug_table( df_full: pd.DataFrame, eval_nodes: List[int], oracle_scores: dict[str, np.ndarray], y_twin: np.ndarray, n_sample: int = 20, primary_score_name: str = "AuditOracle", table_title: str = "Oracle Debug Table", ) -> pd.DataFrame: """Print a per-node debug table for oracle/probe scores vs ground-truth.""" audit_cols = [ "twin_pair_id", "twin_role", "motif_hit_count", "trigger_event_idx", "label_event_idx", "label_delay", "is_fallback_label", ] available = [c for c in audit_cols if c in df_full.columns] meta = ( df_full[df_full["sender_id"].isin(eval_nodes)] .groupby("sender_id")[available] .first() .reset_index() # sender_id becomes a column here ) meta["twin_label"] = y_twin meta["_idx"] = meta["sender_id"].map({n: i for i, n in enumerate(eval_nodes)}) for name, scores in oracle_scores.items(): meta[f"score_{name}"] = meta["_idx"].map( {i: float(scores[i]) for i in range(len(scores))} ) meta = meta.drop(columns=["_idx"]) # Sample: top n_sample/2 by the primary motif score + bottom n_sample/2 sort_col = f"score_{primary_score_name}" if sort_col not in meta.columns: sort_col = meta.columns[-1] meta = meta.sort_values(sort_col, ascending=False) sample = pd.concat([meta.head(n_sample // 2), meta.tail(n_sample // 2)]).drop_duplicates() print(f"\n--- {table_title} (top & bottom by {primary_score_name} score) ---") print(sample.to_string(index=False)) return sample # Gate volume targets / budgets _FAST_GATE_MIN_MATCHED_PAIRS = 500 _FULL_GATE_MIN_MATCHED_PAIRS = 2000 _GATE_MIN_CLASS_EXAMPLES = 500 _GATE_MIN_UNIQUE_USERS = 50 _GATE_POS_RATE_RANGE = (0.35, 0.65) _GATE_BOOTSTRAP_ROUNDS = 200 _GATE_PACK_NAMESPACE = 10_000_000 _GATE_MAX_EXTRA_PACKS = 6 def _subsample_for_gate( df: pd.DataFrame, rng: np.random.Generator, max_pairs: int | None = None, ) -> pd.DataFrame: """Keep at most max_pairs twin pairs for the gate.""" if "twin_pair_id" not in df.columns: return df pair_ids = df.loc[df["twin_pair_id"] >= 0, "twin_pair_id"].unique() if max_pairs is None or 
max_pairs <= 0 or len(pair_ids) <= max_pairs: return df[df["twin_pair_id"] >= 0].copy() chosen = set(rng.choice(pair_ids, size=max_pairs, replace=False).tolist()) return df[df["twin_pair_id"].isin(chosen)].copy() def gate_volume_thresholds(fast_mode: bool) -> dict: return { "matched_eval_pairs_min": _FAST_GATE_MIN_MATCHED_PAIRS if fast_mode else _FULL_GATE_MIN_MATCHED_PAIRS, "positives_min": _GATE_MIN_CLASS_EXAMPLES, "negatives_min": _GATE_MIN_CLASS_EXAMPLES, "unique_fraud_users_min": _GATE_MIN_UNIQUE_USERS, "unique_benign_users_min": _GATE_MIN_UNIQUE_USERS, "positive_rate_lo": _GATE_POS_RATE_RANGE[0], "positive_rate_hi": _GATE_POS_RATE_RANGE[1], } def summarize_gate_volume( test_examples: pd.DataFrame, test_pair_rows: pd.DataFrame, eval_nodes: Sequence[int], ) -> dict: positives = int(test_examples["label"].sum()) if not test_examples.empty else 0 total_examples = int(len(test_examples)) negatives = int(total_examples - positives) fraud_users = int(test_examples.loc[test_examples["label"] == 1, "sender_id"].nunique()) if not test_examples.empty else 0 benign_users = int(test_examples.loc[test_examples["label"] == 0, "sender_id"].nunique()) if not test_examples.empty else 0 unique_templates = int(test_examples["template_id"].nunique()) if ("template_id" in test_examples.columns and not test_examples.empty) else 0 positive_rate = float(positives / max(total_examples, 1)) return { "matched_eval_pairs": int(len(test_pair_rows)), "positives": positives, "negatives": negatives, "unique_fraud_users": fraud_users, "unique_benign_users": benign_users, "unique_templates": unique_templates, "positive_rate": positive_rate, "audit_n_examples": int(len(eval_nodes)), "raw_n_examples": int(len(eval_nodes)), "xgb_n_examples": total_examples, "static_gnn_n_examples": total_examples, } def gate_volume_violations(volume: dict, fast_mode: bool) -> list[str]: thresholds = gate_volume_thresholds(fast_mode) violations: list[str] = [] if volume.get("matched_eval_pairs", 0) < 
thresholds["matched_eval_pairs_min"]: violations.append( f"matched_eval_pairs {volume.get('matched_eval_pairs', 0)} < {thresholds['matched_eval_pairs_min']}" ) if volume.get("positives", 0) < thresholds["positives_min"]: violations.append(f"positives {volume.get('positives', 0)} < {thresholds['positives_min']}") if volume.get("negatives", 0) < thresholds["negatives_min"]: violations.append(f"negatives {volume.get('negatives', 0)} < {thresholds['negatives_min']}") if volume.get("unique_fraud_users", 0) < thresholds["unique_fraud_users_min"]: violations.append( f"unique_fraud_users {volume.get('unique_fraud_users', 0)} < {thresholds['unique_fraud_users_min']}" ) if volume.get("unique_benign_users", 0) < thresholds["unique_benign_users_min"]: violations.append( f"unique_benign_users {volume.get('unique_benign_users', 0)} < {thresholds['unique_benign_users_min']}" ) pos_rate = float(volume.get("positive_rate", 0.0)) if pos_rate < thresholds["positive_rate_lo"] or pos_rate > thresholds["positive_rate_hi"]: violations.append( f"positive_rate {pos_rate:.4f} outside [{thresholds['positive_rate_lo']:.2f}, {thresholds['positive_rate_hi']:.2f}]" ) return violations def gate_volume_is_sufficient(volume: dict, fast_mode: bool) -> bool: return len(gate_volume_violations(volume, fast_mode)) == 0 def offset_gate_namespace(df: pd.DataFrame, pack_idx: int) -> pd.DataFrame: if pack_idx == 0: return df.copy() out = df.copy() offset = pack_idx * _GATE_PACK_NAMESPACE out["sender_id"] = out["sender_id"].astype(np.int64) + offset out["receiver_id"] = out["receiver_id"].astype(np.int64) + offset for col in ("twin_pair_id", "template_id"): if col in out.columns: valid = out[col].astype(np.int64) >= 0 out.loc[valid, col] = out.loc[valid, col].astype(np.int64) + offset return out def build_gate_pool_from_frames(frames: Sequence[pd.DataFrame]) -> pd.DataFrame: non_empty = [frame for frame in frames if frame is not None and not frame.empty] if not non_empty: return pd.DataFrame() return ( 
pd.concat(non_empty, ignore_index=True) .sort_values("timestamp") .reset_index(drop=True) ) def gate_pair_budget_candidates(total_pairs: int, fast_mode: bool) -> list[int | None]: if total_pairs <= 0: return [0] target_budget = 900 if fast_mode else 3500 budgets = [min(total_pairs, target_budget)] if total_pairs > budgets[0]: budgets.append(total_pairs) return [int(budget) for budget in dict.fromkeys(budgets)] def prepare_gate_subset( df_pool: pd.DataFrame, seed: int, fast_mode: bool, ) -> dict: total_pairs = int(df_pool.loc[df_pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) if "twin_pair_id" in df_pool.columns else 0 if total_pairs == 0: empty = pd.DataFrame() return { "pair_budget": 0, "df_gate": empty, "df_train": empty, "df_test": empty, "df_train_eval": empty, "df_full": empty, "eval_nodes": [], "train_examples": empty, "train_pair_rows": empty, "train_pair_counts": empty, "test_examples": empty, "test_pair_rows": empty, "test_pair_counts": empty, "volume": summarize_gate_volume(empty, empty, []), } best: dict | None = None for pair_budget in gate_pair_budget_candidates(total_pairs, fast_mode): gate_rng = np.random.default_rng(seed + int(pair_budget)) df_gate = _subsample_for_gate(df_pool, gate_rng, max_pairs=pair_budget) df_train, df_test, _ = split_temporally(df_gate) df_train, df_test = remap_node_ids(df_train, df_test) train_examples, train_pair_rows, train_pair_counts = build_matched_control_tables(df_train) test_examples, test_pair_rows, test_pair_counts = build_matched_control_tables(df_test) df_train_eval = augment_with_placeholder_nodes(df_train, df_test) df_full = ( pd.concat([df_train_eval, df_test], ignore_index=True) .sort_values("timestamp") .reset_index(drop=True) ) eval_nodes = get_eval_nodes(df_full) volume = summarize_gate_volume(test_examples, test_pair_rows, eval_nodes) candidate = { "pair_budget": int(pair_budget) if pair_budget is not None else total_pairs, "df_gate": df_gate, "df_train": df_train, "df_test": df_test, 
"df_train_eval": df_train_eval, "df_full": df_full, "eval_nodes": eval_nodes, "train_examples": train_examples, "train_pair_rows": train_pair_rows, "train_pair_counts": train_pair_counts, "test_examples": test_examples, "test_pair_rows": test_pair_rows, "test_pair_counts": test_pair_counts, "volume": volume, } best = candidate if gate_volume_is_sufficient(volume, fast_mode): return candidate assert best is not None return best def ensure_gate_volume( df_pool: pd.DataFrame, config, seed: int, benchmark_mode: str, fast_mode: bool, ) -> dict: pool = df_pool.copy() gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) pack_count = 1 while (not gate_volume_is_sufficient(gate["volume"], fast_mode)) and pack_count <= _GATE_MAX_EXTRA_PACKS: extra_seed = seed + pack_count * 10_007 extra_easy, extra_medium, extra_hard = generate_all( config, seed=extra_seed, benchmark_mode=benchmark_mode, ) extra_pack = build_gate_pool_from_frames([ offset_gate_namespace(extra_easy, pack_count), offset_gate_namespace(extra_medium, pack_count), offset_gate_namespace(extra_hard, pack_count), ]) pool = build_gate_pool_from_frames([pool, extra_pack]) gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) pack_count += 1 gate["source_pool_events"] = int(len(pool)) gate["source_pool_pairs"] = int(pool.loc[pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) if "twin_pair_id" in pool.columns else 0 gate["source_pool_packs"] = int(pack_count) return gate def ensure_gate_volume_for_difficulty( config, difficulty: str, seed: int, benchmark_mode: str, fast_mode: bool, initial_pool: pd.DataFrame | None = None, ) -> dict: """Build a reliable-volume gate pool using repeated packs of a single difficulty.""" if initial_pool is None: pool = generate_single_difficulty( config, difficulty=difficulty, seed=seed, benchmark_mode=benchmark_mode, ) else: pool = initial_pool.copy() gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) pack_count = 1 while (not 
gate_volume_is_sufficient(gate["volume"], fast_mode)) and pack_count <= _GATE_MAX_EXTRA_PACKS: extra_seed = seed + pack_count * 10_007 extra_pack = generate_single_difficulty( config, difficulty=difficulty, seed=extra_seed, benchmark_mode=benchmark_mode, ) extra_pack = offset_gate_namespace(extra_pack, pack_count) pool = build_gate_pool_from_frames([pool, extra_pack]) gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) pack_count += 1 gate["source_pool_events"] = int(len(pool)) gate["source_pool_pairs"] = int(pool.loc[pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) if "twin_pair_id" in pool.columns else 0 gate["source_pool_packs"] = int(pack_count) return gate # --------------------------------------------------------------------------- # Motif Validity Check (req #11) # --------------------------------------------------------------------------- def run_motif_validity_check( df: pd.DataFrame, config, seed: int, device: str, num_epochs: int, node_epochs: int, n_checkpoints: int, hard_abort: bool = True, horizon: float = 0.10, benchmark_mode: str = "temporal_twins_oracle_calib", fast_mode: bool = False, force_temporal_models: bool = False, prebuilt_gate: dict | None = None, ) -> tuple[bool, dict]: """Run the staged MOTIF VALIDITY CHECK gate. Stage 1 — AuditOracle: reads audit cols directly. >= 0.99 required. Stage 2 — RawMotifOracle: reconstructs motif. >= 0.95 required. Stage 3 — Static ceilings: XGB <= 0.65, StaticGNN <= 0.70. Stage 4 — SeqGRU: >= 0.80 (calib mode only). Oracles are evaluated against twin_label (NOT window label) to avoid the target-alignment bug where late windows have no upcoming fraud events. 
""" calib_mode = _is_oracle_calib_mode(benchmark_mode) metric_labels = _oracle_metric_labels(benchmark_mode) # Dataset-wide stats computed on the FULL df before subsampling consistency = compute_motif_label_consistency(df, calib_mode=calib_mode) delay_stats = compute_label_delay_stats(df) node_audit = build_node_audit_table(df) ks_mean, ks_max = compute_aggregate_ks(node_audit) gate = prebuilt_gate if gate is None: gate = ensure_gate_volume( df_pool=df, config=config, seed=seed, benchmark_mode=benchmark_mode, fast_mode=fast_mode, ) df_gate = gate["df_gate"] df_train = gate["df_train"] df_test = gate["df_test"] df_train_eval = gate["df_train_eval"] df_full = gate["df_full"] eval_nodes = gate["eval_nodes"] train_examples = gate["train_examples"] train_pair_rows = gate["train_pair_rows"] train_pair_counts = gate["train_pair_counts"] test_examples = gate["test_examples"] test_pair_rows = gate["test_pair_rows"] test_pair_counts = gate["test_pair_counts"] gate_volume = gate["volume"] print( f" [gate] Using {df_gate['twin_pair_id'].nunique()} pairs " f"({len(df_gate):,} events) for model stages from " f"{gate['source_pool_packs']} pack(s), source pairs={gate['source_pool_pairs']:,}." 
) leakage = compute_split_leakage(df_train, df_test) matched_train_features = build_matched_prefix_feature_table(df_train, train_examples) matched_test_features = build_matched_prefix_feature_table(df_test, test_examples) matched_audit = report_matched_control_audits( test_examples=test_examples, test_pair_rows=test_pair_rows, test_pair_counts=test_pair_counts, ) static_agg_result = compute_matched_static_aggregate_auc( matched_train_features, matched_test_features, seed=seed, verbose=False, ) delta_time = horizon_to_delta(df_test, horizon) y_twin = _twin_labels_for_nodes(df_full, eval_nodes) report: dict = { "ks_mean": ks_mean, "ks_max": ks_max, "static_agg_auc": float(static_agg_result["auc"]), **delay_stats, **{k: v for k, v in consistency.items()}, **matched_audit, **gate_volume, "gate_pair_budget": int(gate["pair_budget"]), "gate_source_pool_events": int(gate["source_pool_events"]), "gate_source_pool_pairs": int(gate["source_pool_pairs"]), "gate_source_pool_packs": int(gate["source_pool_packs"]), } attach_auc_result(report, "static_agg", static_agg_result) oracle_scores: dict[str, np.ndarray] = {} # Stage 1 — AuditOracle / MotifProbe (no training; reads motif_hit_count directly) audit_oracle = AuditOracleWrapper() audit_probs = audit_oracle.predict(df_full, eval_nodes) oracle_scores[metric_labels["audit"]] = audit_probs audit_result = make_auc_result(y_twin, audit_probs.astype(np.float32), seed=seed) attach_auc_result(report, "audit", audit_result) report["audit_pair_sep"] = evaluate_oracle_pair_sep_node_level( audit_oracle, df_full, eval_nodes ) # Stage 2 — RawMotifOracle / RawMotifProbe (trained on twin_label, not window label) raw_oracle = RawMotifOracleWrapper() raw_oracle.fit(df_train_eval, num_epochs=num_epochs) train_nodes_raw = get_eval_nodes(df_train_eval) y_train_twin_raw = _twin_labels_for_nodes(df_train_eval, train_nodes_raw) train_node_head( raw_oracle, df_anchor_prefix=df_train_eval, eval_nodes=train_nodes_raw, y_labels=y_train_twin_raw, 
num_epochs=node_epochs, ) raw_probs = raw_oracle.predict(df_full, eval_nodes) oracle_scores[metric_labels["raw"]] = raw_probs raw_result = make_auc_result(y_twin, raw_probs.astype(np.float32), seed=seed) attach_auc_result(report, "raw", raw_result) report["raw_pair_sep"] = evaluate_oracle_pair_sep_node_level( raw_oracle, df_full, eval_nodes ) _attach_probe_aliases(report, benchmark_mode) # Oracle/probe debug table build_oracle_debug_table( df_full, eval_nodes, oracle_scores, y_twin, primary_score_name=metric_labels["audit"], table_title=metric_labels["table"], ) # Stage 3 — Static baselines (window-label eval, as in main benchmark) xgb_result = compute_matched_xgboost_auc( matched_train_features, matched_test_features, seed=seed, ) attach_auc_result(report, "xgb", xgb_result) static_gnn_result = compute_matched_static_gnn_auc( df_train=df_train, df_test=df_test, train_examples=train_examples, test_examples=test_examples, device=device, num_epochs=num_epochs, seed=seed, ) attach_auc_result(report, "static_gnn", static_gnn_result) report["xgb_roc_auc"] = float(xgb_result["auc"]) report["static_gnn_roc"] = float(static_gnn_result["auc"]) report["static_gnn_auc_flipped"] = float(static_gnn_result.get("auc_flipped", float("nan"))) report["static_gnn_score_mean_pos"] = float(static_gnn_result.get("score_mean_pos", float("nan"))) report["static_gnn_score_mean_neg"] = float(static_gnn_result.get("score_mean_neg", float("nan"))) report["static_gnn_score_std"] = float(static_gnn_result.get("score_std", float("nan"))) report["static_gnn_zero_emb_frac"] = float(static_gnn_result.get("zero_emb_frac", float("nan"))) report["static_gnn_matched_examples"] = int(static_gnn_result.get("matched_examples", 0)) report["static_gnn_unique_prefix_cutoffs"] = int(static_gnn_result.get("unique_prefix_cutoffs", 0)) report["static_gnn_graph_builds"] = int(static_gnn_result.get("graph_builds", 0)) report["static_gnn_cache_hit_rate"] = float(static_gnn_result.get("cache_hit_rate", 
float("nan"))) report["static_gnn_eval_time_sec"] = float(static_gnn_result.get("eval_time_sec", float("nan"))) report["static_gnn_train_unique_prefix_cutoffs"] = int(static_gnn_result.get("train_unique_prefix_cutoffs", 0)) report["static_gnn_test_unique_prefix_cutoffs"] = int(static_gnn_result.get("test_unique_prefix_cutoffs", 0)) report["static_gnn_train_graph_builds"] = int(static_gnn_result.get("train_graph_builds", 0)) report["static_gnn_test_graph_builds"] = int(static_gnn_result.get("test_graph_builds", 0)) report["static_gnn_train_eval_time_sec"] = float(static_gnn_result.get("train_eval_time_sec", float("nan"))) report["static_gnn_test_eval_time_sec"] = float(static_gnn_result.get("test_eval_time_sec", float("nan"))) # Stage 4 — SeqGRU (calib mode only) run_temporal_models = calib_mode or force_temporal_models if run_temporal_models: seqgru_result = compute_matched_seqgru_metrics( df_train=df_train, df_test=df_test, train_examples=train_examples, test_examples=test_examples, device=device, seed=seed, max_epochs=max(24, min(72, node_epochs)), ) seqgru_clean = seqgru_result["clean"] seqgru_shuffled = seqgru_result["shuffled"] report["seqgru_roc_auc"] = float(seqgru_clean["auc"]) report["seqgru_pr_auc"] = float(seqgru_clean.get("pr_auc", float("nan"))) report["seqgru_brier"] = float(seqgru_clean.get("brier", float("nan"))) report["seqgru_ece"] = float(seqgru_clean.get("ece", float("nan"))) report["seqgru_n_examples"] = int(seqgru_clean.get("n_examples", 0)) report["seqgru_shuffle_roc_auc"] = float(seqgru_shuffled["auc"]) report["seqgru_shuffle_pr_auc"] = float(seqgru_shuffled.get("pr_auc", float("nan"))) report["seqgru_shuffle_delta"] = float(seqgru_result["delta"]) report["seqgru_best_epoch"] = int(seqgru_result["clean_fit"].get("best_epoch", 0)) report["seqgru_best_valid_roc_auc"] = float(seqgru_result["clean_fit"].get("best_valid_roc_auc", float("nan"))) report["seqgru_best_valid_pr_auc"] = float(seqgru_result["clean_fit"].get("best_valid_pr_auc", 
float("nan"))) report["seqgru_shuffle_best_epoch"] = int(seqgru_result["shuffled_fit"].get("best_epoch", 0)) report["seqgru_shuffle_best_valid_roc_auc"] = float(seqgru_result["shuffled_fit"].get("best_valid_roc_auc", float("nan"))) temporal_gnn_specs = [ ("TGN", "tgn", lambda: TGNWrapper(device=device)), ("TGAT", "tgat", lambda: TGATWrapper(device=device)), ("DyRep", "dyrep", lambda: DyRepWrapper(device=device)), ("JODIE", "jodie", lambda: JODIEWrapper(device=device)), ] temporal_num_epochs = max(2, num_epochs) for model_label, key_prefix, builder in temporal_gnn_specs: temporal_result = compute_matched_temporal_gnn_metrics( model_name=model_label, model_builder=builder, df_train=df_train, df_test=df_test, train_examples=train_examples, test_examples=test_examples, seed=seed, num_epochs=temporal_num_epochs, ) clean_result = temporal_result["clean"] shuffled_result = temporal_result["shuffled"] report[f"{key_prefix}_roc_auc"] = float(clean_result["auc"]) report[f"{key_prefix}_pr_auc"] = float(clean_result.get("pr_auc", float("nan"))) report[f"{key_prefix}_n_examples"] = int(clean_result.get("n_examples", 0)) report[f"{key_prefix}_shuffle_roc_auc"] = float(shuffled_result["auc"]) report[f"{key_prefix}_shuffle_pr_auc"] = float(shuffled_result.get("pr_auc", float("nan"))) report[f"{key_prefix}_shuffle_delta"] = float(temporal_result["delta"]) else: report["seqgru_roc_auc"] = float("nan") report["seqgru_pr_auc"] = float("nan") report["seqgru_n_examples"] = 0 report["seqgru_shuffle_roc_auc"] = float("nan") report["seqgru_shuffle_pr_auc"] = float("nan") report["seqgru_shuffle_delta"] = float("nan") report["seqgru_best_epoch"] = 0 report["seqgru_best_valid_roc_auc"] = float("nan") report["seqgru_best_valid_pr_auc"] = float("nan") report["seqgru_shuffle_best_epoch"] = 0 report["seqgru_shuffle_best_valid_roc_auc"] = float("nan") for key_prefix in ("tgn", "tgat", "dyrep", "jodie"): report[f"{key_prefix}_roc_auc"] = float("nan") report[f"{key_prefix}_pr_auc"] = float("nan") 
report[f"{key_prefix}_n_examples"] = 0 report[f"{key_prefix}_shuffle_roc_auc"] = float("nan") report[f"{key_prefix}_shuffle_pr_auc"] = float("nan") report[f"{key_prefix}_shuffle_delta"] = float("nan") # Gate table gate_items = [ (f"{metric_labels['audit']} ROC-AUC", "audit_roc_auc", "ge", 0.99, "label-alignment bug"), (f"{metric_labels['audit']} pair-sep", "audit_pair_sep", "ge", 0.99, "pair construction bug"), (f"{metric_labels['raw']} ROC-AUC", "raw_roc_auc", "ge", 0.95, "motif reconstruction bug"), (f"{metric_labels['raw']} pair-sep", "raw_pair_sep", "ge", 0.90, "motif reconstruction bug"), ("static_agg_auc", "static_agg_auc", "le", 0.60, "static leakage"), ("XGBoost ROC-AUC", "xgb_roc_auc", "le", 0.65, "static leakage"), ("StaticGNN ROC-AUC", "static_gnn_roc", "le", 0.70, "static leakage"), ] if run_temporal_models: gate_items.extend([ ("SeqGRU ROC-AUC", "seqgru_roc_auc", "ge", 0.80, "learning/input bug"), ("SeqGRU shuffle delta", "seqgru_shuffle_delta", "le", -0.10, "order signal missing"), ]) temporal_gnn_items = [ ("TGN ROC-AUC", "tgn_roc_auc", "ge", 0.75, "temporal learnability"), ("TGN shuffle delta", "tgn_shuffle_delta", "le", -0.10, "order signal missing"), ("TGAT ROC-AUC", "tgat_roc_auc", "ge", 0.75, "temporal learnability"), ("TGAT shuffle delta", "tgat_shuffle_delta", "le", -0.10, "order signal missing"), ("DyRep ROC-AUC", "dyrep_roc_auc", "ge", 0.75, "temporal learnability"), ("DyRep shuffle delta", "dyrep_shuffle_delta", "le", -0.10, "order signal missing"), ("JODIE ROC-AUC", "jodie_roc_auc", "ge", 0.75, "temporal learnability"), ("JODIE shuffle delta", "jodie_shuffle_delta", "le", -0.10, "order signal missing"), ] print("\n" + "=" * 72) print(" MOTIF VALIDITY CHECK") print("=" * 72) print(" Gate Volume") print(f" matched_eval_pairs : {report['matched_eval_pairs']}") print(f" positives / negatives : {report['positives']} / {report['negatives']}") print(f" unique fraud / benign : {report['unique_fraud_users']} / {report['unique_benign_users']}") 
print(f" unique templates : {report['unique_templates']}") print(f" positive rate : {report['positive_rate']:.4f}") print( " model examples : " f"{metric_labels['audit']}={report['audit_n_examples']} " f"{metric_labels['raw']}={report['raw_n_examples']} " f"XGB={report['xgb_n_examples']} StaticGNN={report['static_gnn_n_examples']} " f"SeqGRU={report['seqgru_n_examples']} TGN={report['tgn_n_examples']} " f"TGAT={report['tgat_n_examples']} DyRep={report['dyrep_n_examples']} " f"JODIE={report['jodie_n_examples']}" ) print(f" gate source packs/pairs : {report['gate_source_pool_packs']} / {report['gate_source_pool_pairs']}") print(f" gate pair budget : {report['gate_pair_budget']}") volume_violations = gate_volume_violations(report, fast_mode) if volume_violations: print(" INSUFFICIENT_GATE_VOLUME") for violation in volume_violations: print(f" - {violation}") all_pass = True if volume_violations: all_pass = False for label, key, op, thresh, hint in gate_items: val = report.get(key, float("nan")) is_nan = val != val ok = (not is_nan) and ((val >= thresh) if op == "ge" else (val <= thresh)) status = "N/A " if is_nan else ("PASS" if ok else "FAIL") if not ok: all_pass = False tstr = f"{'>='+str(thresh) if op=='ge' else '<='+str(thresh)}" print(f" {label:<28}: {val:>7.4f} [{status} {tstr}] {'<-- '+hint if not ok else ''}") for label, key, op, thresh, hint in temporal_gnn_items: val = report.get(key, float("nan")) is_nan = val != val ok = (not is_nan) and ((val >= thresh) if op == "ge" else (val <= thresh)) status = "N/A " if is_nan else ("PASS" if ok else "FAIL") tstr = f"{'>='+str(thresh) if op=='ge' else '<='+str(thresh)}" suffix = "" if ok else f" [advisory: {hint}]" print(f" {label:<28}: {val:>7.4f} [{status} {tstr}]{suffix}") audit_ci_label = f"{metric_labels['audit']} AUC std/CI" raw_ci_label = f"{metric_labels['raw']} std/CI" print(f" {audit_ci_label:<28}: {report['audit_auc_bootstrap_std']:.4f} [{report['audit_auc_ci_lo']:.4f}, {report['audit_auc_ci_hi']:.4f}]") 
print(f" {raw_ci_label:<28}: {report['raw_auc_bootstrap_std']:.4f} [{report['raw_auc_ci_lo']:.4f}, {report['raw_auc_ci_hi']:.4f}]") print(f" {'XGBoost AUC std/CI':<28}: {report['xgb_auc_bootstrap_std']:.4f} [{report['xgb_auc_ci_lo']:.4f}, {report['xgb_auc_ci_hi']:.4f}]") print(f" {'StaticGNN AUC std/CI':<28}: {report['static_gnn_auc_bootstrap_std']:.4f} [{report['static_gnn_auc_ci_lo']:.4f}, {report['static_gnn_auc_ci_hi']:.4f}]") print(f" {'static_agg_auc std/CI':<28}: {report['static_agg_auc_bootstrap_std']:.4f} [{report['static_agg_auc_ci_lo']:.4f}, {report['static_agg_auc_ci_hi']:.4f}]") print(f" {'StaticGNN flip check':<28}: auc={report['static_gnn_roc']:.4f} flipped={report['static_gnn_auc_flipped']:.4f} zero_emb={report['static_gnn_zero_emb_frac']:.4f}") print(f" {'StaticGNN score means':<28}: pos={report['static_gnn_score_mean_pos']:.4f} neg={report['static_gnn_score_mean_neg']:.4f} std={report['static_gnn_score_std']:.4f}") print( f" {'StaticGNN runtime':<28}: " f"examples={report['static_gnn_matched_examples']} " f"cutoffs={report['static_gnn_unique_prefix_cutoffs']} " f"builds={report['static_gnn_graph_builds']} " f"hit_rate={report['static_gnn_cache_hit_rate']:.4f} " f"time={report['static_gnn_eval_time_sec']:.2f}s" ) print( f" {'StaticGNN train/test rt':<28}: " f"train_cutoffs={report['static_gnn_train_unique_prefix_cutoffs']} " f"test_cutoffs={report['static_gnn_test_unique_prefix_cutoffs']} " f"train_builds={report['static_gnn_train_graph_builds']} " f"test_builds={report['static_gnn_test_graph_builds']} " f"train={report['static_gnn_train_eval_time_sec']:.2f}s " f"test={report['static_gnn_test_eval_time_sec']:.2f}s" ) print(f" {'SeqGRU PR-AUC':<28}: {report['seqgru_pr_auc']:.4f} [informational]") print(f" {'SeqGRU shuffled ROC-AUC':<28}: {report['seqgru_shuffle_roc_auc']:.4f} [informational]") print(f" {'SeqGRU shuffled PR-AUC':<28}: {report['seqgru_shuffle_pr_auc']:.4f} [informational]") print(f" {'SeqGRU early stop':<28}: 
epoch={report['seqgru_best_epoch']} valid_roc={report['seqgru_best_valid_roc_auc']:.4f} valid_pr={report['seqgru_best_valid_pr_auc']:.4f}") print(f" {'SeqGRU shuffled stop':<28}: epoch={report['seqgru_shuffle_best_epoch']} valid_roc={report['seqgru_shuffle_best_valid_roc_auc']:.4f}") print(f" {'TGN PR/shuffled ROC':<28}: pr={report['tgn_pr_auc']:.4f} shuffled={report['tgn_shuffle_roc_auc']:.4f}") print(f" {'TGAT PR/shuffled ROC':<28}: pr={report['tgat_pr_auc']:.4f} shuffled={report['tgat_shuffle_roc_auc']:.4f}") print(f" {'DyRep PR/shuffled ROC':<28}: pr={report['dyrep_pr_auc']:.4f} shuffled={report['dyrep_shuffle_roc_auc']:.4f}") print(f" {'JODIE PR/shuffled ROC':<28}: pr={report['jodie_pr_auc']:.4f} shuffled={report['jodie_shuffle_roc_auc']:.4f}") print(f" {'P(label|hit>=1)':<28}: {report.get('p_label_given_hit', float('nan')):>7.4f} [informational]") print(f" {'P(label|hit=0)':<28}: {report.get('p_label_given_nohit', float('nan')):>7.4f} [informational]") print(f" {'accidental_benign_motif':<28}: {report.get('accidental_benign_motif', float('nan')):>7.4f} [informational]") print(f" {'KS mean/max':<28}: {ks_mean:>7.4f} / {ks_max:.4f}") print(f" {'delay min/mean/max':<28}: " f"{report.get('delay_min',float('nan')):.1f} / " f"{report.get('delay_mean',float('nan')):.1f} / " f"{report.get('delay_max',float('nan')):.1f}") print("=" * 72) if all_pass: print(" [GATE] All thresholds met. Proceeding to full run.") else: msg = "[GATE] One or more thresholds FAILED." 
if hard_abort: print(f" {msg} Aborting (hard gate).") sys.exit(1) else: print(f" {msg} Continuing as soft diagnostic.") return all_pass, report # --------------------------------------------------------------------------- # Model factory # --------------------------------------------------------------------------- def build_models(device: str = "cpu") -> Dict[str, TemporalModel]: return { "OracleMotif": OracleMotifWrapper(), "SeqGRU": SequenceGRUWrapper(hidden_dim=64, receiver_buckets=256, device=device), "TGN": TGNWrapper(memory_dim=64, time_dim=16, device=device), "TGAT": TGATWrapper(memory_dim=64, time_dim=8, num_heads=4, n_neighbors=10, device=device), "DyRep": DyRepWrapper(memory_dim=64, time_dim=8, device=device), "JODIE": JODIEWrapper(memory_dim=64, time_emb_dim=16, device=device), "StaticGNN": StaticGNNWrapper(hidden_dim=64, n_snapshots=10, device=device), "XGBoost": XGBoostWrapper(n_estimators=200, max_depth=6), } # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_seed_list(seed_string: str) -> List[int]: return [int(token.strip()) for token in seed_string.split(",") if token.strip()] def parse_args(): parser = argparse.ArgumentParser(description="Leakage-free UPI-Sim benchmark runner") parser.add_argument("--fast", action="store_true", help="Fast mode: 1 epoch and fewer checkpoints.") parser.add_argument("--seed", type=int, default=None, help="Run a single seed.") parser.add_argument( "--seeds", nargs="+", type=int, default=None, help="Space-separated seed list, e.g. --seeds 0 1 2 3 4", ) parser.add_argument( "--config", type=str, default="config/default.yaml", help="Path to config YAML.", ) parser.add_argument( "--device", type=str, default="cpu", help='Torch device ("cpu" or "cuda").', ) parser.add_argument( "--benchmark-mode", type=str, default=None, help='Benchmark mode override, e.g. 
"standard" or "temporal_twins".', ) parser.add_argument( "--experiments", nargs="+", type=str, default=None, help="Space-separated list of experiments to run, e.g. --experiments ood causal horizon mechanistic audit", ) return parser.parse_args() def main(): args = parse_args() # Support both space-separated (nargs=+) and comma-separated experiment lists if args.experiments is None: experiments_to_run = {"ood", "causal", "horizon", "mechanistic", "audit"} elif isinstance(args.experiments, list): experiments_to_run = set(args.experiments) else: experiments_to_run = {exp.strip() for exp in args.experiments.split(",") if exp.strip()} num_epochs = 1 if args.fast else 3 node_epochs = 60 if args.fast else 150 n_checkpoints = 4 if args.fast else 8 if args.seed is not None: seeds = [args.seed] elif args.seeds is not None: # Already parsed as List[int] via nargs="+" seeds = args.seeds else: seeds = [0, 1, 2, 3, 4] config = load_config(args.config) benchmark_mode = args.benchmark_mode or getattr(config, "benchmark_mode", "standard") print("=" * 60) print(" UPI-Sim Multi-Model Benchmark (Leakage-Free)") print(f" epochs={num_epochs} node_epochs={node_epochs} checkpoints={n_checkpoints}") print(f" seeds={seeds} device={args.device} mode={benchmark_mode}") print("=" * 60) raw_frames: Dict[str, List[pd.DataFrame]] = { "ood": [], "causal": [], "horizon": [], "mechanistic": [], "audit": [], } import torch is_twin_mode = benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib") calib_mode = benchmark_mode == "temporal_twins_oracle_calib" for seed in seeds: set_global_determinism(seed) print(f"\n[data] Generating datasets for seed={seed}...") df_easy, df_medium, df_hard = generate_all( config, seed=seed, benchmark_mode=benchmark_mode, ) print(f" Easy : {len(df_easy):,} events | fraud={df_easy['is_fraud'].mean():.3f}") print(f" Medium: {len(df_medium):,} events | fraud={df_medium['is_fraud'].mean():.3f}") print(f" Hard : {len(df_hard):,} events | 
fraud={df_hard['is_fraud'].mean():.3f}") if is_twin_mode: # Run validity check: hard-abort in calib mode, soft diagnostic otherwise. gate_df = build_gate_pool_from_frames([df_easy, df_medium, df_hard]) run_motif_validity_check( df=gate_df, config=config, seed=seed, device=args.device, num_epochs=num_epochs, node_epochs=node_epochs, n_checkpoints=n_checkpoints, hard_abort=calib_mode, benchmark_mode=benchmark_mode, fast_mode=args.fast, ) if "ood" in experiments_to_run: print(f"\n[seed={seed}] OOD generalisation") df_ood = run_ood_single( df_easy=df_easy, df_medium=df_medium, df_hard=df_hard, device=args.device, num_epochs=num_epochs, node_epochs=node_epochs, n_checkpoints=n_checkpoints, ) df_ood["seed"] = seed raw_frames["ood"].append(df_ood) if "causal" in experiments_to_run: print(f"\n[seed={seed}] Causal chronology shuffle") df_causal = run_causal_single( df_hard=df_hard, device=args.device, num_epochs=num_epochs, node_epochs=node_epochs, n_checkpoints=n_checkpoints, seed=seed, ) df_causal["seed"] = seed raw_frames["causal"].append(df_causal) if "horizon" in experiments_to_run: print(f"\n[seed={seed}] Horizon sweep") df_horizon = run_horizon_single( df_medium=df_medium, device=args.device, num_epochs=num_epochs, node_epochs=node_epochs, n_checkpoints=n_checkpoints, horizons=DEFAULT_HORIZONS, ) df_horizon["seed"] = seed raw_frames["horizon"].append(df_horizon) if "mechanistic" in experiments_to_run: print(f"\n[seed={seed}] Mechanistic correlation") df_mech = run_mechanistic_single( df_hard=df_hard, device=args.device, num_epochs=num_epochs, node_epochs=node_epochs, n_checkpoints=n_checkpoints, ) df_mech["seed"] = seed raw_frames["mechanistic"].append(df_mech) if "audit" in experiments_to_run: print(f"\n[seed={seed}] Temporal twins audit") df_audit = run_audit_single( df_hard=df_hard, device=args.device, num_epochs=num_epochs, node_epochs=node_epochs, n_checkpoints=n_checkpoints, seed=seed, benchmark_mode=benchmark_mode, ) df_audit["seed"] = seed 
raw_frames["audit"].append(df_audit) save_experiment_outputs(raw_frames, results_dir="results") print("\n" + "=" * 60) print(" All requested experiments completed.") print(" Saved raw + summary CSVs in results/") print(" Run: python -m plots.plot_results") print("=" * 60) if __name__ == "__main__": main()