diff --git "a/experiments/run_all.py" "b/experiments/run_all.py" new file mode 100644--- /dev/null +++ "b/experiments/run_all.py" @@ -0,0 +1,3321 @@ +""" +experiments/run_all.py +====================== +Leakage-free experiment runner for the UPI-Sim temporal fraud benchmark. + +Key protocol changes +-------------------- +- Strict prefix evaluation: models only see events up to cutoff t. +- Horizon-specific retraining: each horizon uses fresh model instances. +- Causal ablation trains/evaluates on globally shuffled chronology. +- XGBoost uses the real xgboost library with aligned node-level labels. +- All experiments support multi-seed aggregation with mean ± std outputs. +""" + +from __future__ import annotations + +import argparse +import hashlib +import os +import random +import sys +import time +from typing import Dict, Iterable, List, Sequence + +os.environ.setdefault("OMP_NUM_THREADS", "1") +os.environ.setdefault("MKL_NUM_THREADS", "1") +os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + +import numpy as np +import pandas as pd +import torch +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import average_precision_score, brier_score_loss, roc_auc_score +from xgboost import XGBClassifier + +_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +from src.core.config_loader import load_config +from src.generators.user_generator import generate_users +from src.generators.transaction_generator import generate_transactions +from src.fraud.fraud_engine import FraudEngine, ORACLE_ONLY_COLS +from src.graph.graph_builder import build_edge_features +from src.risk.risk_engine import apply_risk_engine + +from models.base import TemporalModel +from models.dyrep import DyRepWrapper +from models.jodie import JODIEWrapper +from models.audit_oracle import AuditOracleWrapper, RawMotifOracleWrapper +from models.oracle_motif import OracleMotifWrapper +from models.sequence_gru import SequenceGRUWrapper +from models.static_gnn import StaticGNNWrapper +from models.tgat import TGATWrapper +from models.tgn_wrapper import TGNWrapper +from models.xgboost_model import XGBoostWrapper + +torch.set_num_threads(1) +if hasattr(torch, "set_num_interop_threads"): + try: + torch.set_num_interop_threads(1) + except RuntimeError: + pass + +# Oracle models that are allowed to receive unstripped audit columns +_ORACLE_MODEL_NAMES: frozenset = frozenset({"OracleMotif", "AuditOracle", "RawMotifOracle"}) + + +# --------------------------------------------------------------------------- +# Oracle / audit column stripping +# --------------------------------------------------------------------------- + +def strip_oracle_cols(df: pd.DataFrame) -> pd.DataFrame: + """Remove audit/oracle columns before passing data to learned baselines.""" + cols_to_drop = [c for c in df.columns if c in ORACLE_ONLY_COLS] + if cols_to_drop: + return df.drop(columns=cols_to_drop) + return df + + + +DEFAULT_HORIZONS = [0.01, 0.05, 0.10, 0.20] +DEFAULT_SEEDS = [0, 1, 2, 3, 4] +_TWIN_DIFFICULTY_USER_SEED_OFFSETS = {"easy": 11, "medium": 23, "hard": 37} +MODEL_ORDER = [ + "OracleMotif", + "SeqGRU", + "TGN", + "TGAT", + "DyRep", + "JODIE", + "StaticGNN", + "XGBoost", +] + + +def stable_int_hash(*parts: object, modulo: int = 2**32) -> int: + """Deterministic integer hash for seed derivation across Python processes.""" + seed_material = "::".join(map(str, parts)) + digest = hashlib.sha256(seed_material.encode("utf-8")).hexdigest() + return int(digest[:16], 16) % modulo + + +def seed_python_numpy(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + + +def set_global_determinism(seed: int) -> None: + seed_python_numpy(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + if hasattr(torch.backends, "cudnn"): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + if hasattr(torch.backends.cudnn, "allow_tf32"): + torch.backends.cudnn.allow_tf32 = False + if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"): + torch.backends.cuda.matmul.allow_tf32 = False + try: + torch.use_deterministic_algorithms(True) + except Exception: + pass + + +def derived_seed(base_seed: int, *parts: object, modulo: int = 2**31 - 1) -> int: + return int((int(base_seed) + stable_int_hash(*parts, modulo=modulo)) % modulo) + + +def _is_oracle_calib_mode(benchmark_mode: str) -> bool: + return benchmark_mode == "temporal_twins_oracle_calib" + + +def _oracle_metric_labels(benchmark_mode: str) -> dict[str, str]: + if _is_oracle_calib_mode(benchmark_mode): + return { + "audit": "AuditOracle", + "raw": "RawMotifOracle", + "table": "Oracle Debug Table", + } + return { + "audit": "MotifProbe", + "raw": "RawMotifProbe", + "table": "Probe Debug Table", + } + + +def _attach_probe_aliases(report: dict, benchmark_mode: str) -> None: + """Expose standard-mode probe names without breaking old oracle-key consumers.""" + labels = _oracle_metric_labels(benchmark_mode) + report["audit_metric_label"] = labels["audit"] + report["raw_metric_label"] = labels["raw"] + if _is_oracle_calib_mode(benchmark_mode): + return + + alias_map = { + "motif_probe_roc_auc": "audit_roc_auc", + "motif_probe_pair_sep": "audit_pair_sep", + "motif_probe_n_examples": "audit_n_examples", + "motif_probe_auc_bootstrap_std": "audit_auc_bootstrap_std", + "motif_probe_auc_ci_lo": "audit_auc_ci_lo", + "motif_probe_auc_ci_hi": "audit_auc_ci_hi", + "raw_motif_probe_roc_auc": "raw_roc_auc", + "raw_motif_probe_pair_sep": "raw_pair_sep", + "raw_motif_probe_n_examples": "raw_n_examples", + "raw_motif_probe_auc_bootstrap_std": "raw_auc_bootstrap_std", + "raw_motif_probe_auc_ci_lo": "raw_auc_ci_lo", + "raw_motif_probe_auc_ci_hi": "raw_auc_ci_hi", + } + for alias_key, source_key in alias_map.items(): + if source_key in report: + report[alias_key] = report[source_key] + + +# --------------------------------------------------------------------------- +# Data generation +# --------------------------------------------------------------------------- + +def generate_difficulty( + config, + users: pd.DataFrame, + difficulty: str, + seed: int, + time_offset: float = 0.0, + benchmark_mode: str = "standard", +) -> pd.DataFrame: + """Generate one difficulty slice with a global timestamp offset.""" + df = generate_transactions(users, config) + df = apply_risk_engine(df, users, config) + engine_seed = seed + stable_int_hash("FraudEngine", difficulty, benchmark_mode, modulo=10_000) + engine = FraudEngine( + seed=engine_seed, + difficulty=difficulty, + benchmark_mode=benchmark_mode, + ) + df = engine.apply(df) + df = df.sort_values("timestamp").reset_index(drop=True) + if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"): + diff_offset = {"easy": 0, "medium": 1_000_000, "hard": 2_000_000}[difficulty] + df["sender_id"] = df["sender_id"].astype(np.int64) + diff_offset + df["receiver_id"] = df["receiver_id"].astype(np.int64) + diff_offset + if "twin_pair_id" in df.columns: + df["twin_pair_id"] = df["twin_pair_id"].astype(np.int64) + valid = df["twin_pair_id"] >= 0 + df.loc[valid, "twin_pair_id"] = df.loc[valid, "twin_pair_id"] + diff_offset + if "template_id" in df.columns: + df["template_id"] = df["template_id"].astype(np.int64) + valid = df["template_id"] >= 0 + df.loc[valid, "template_id"] = df.loc[valid, "template_id"] + diff_offset + df["timestamp"] = df["timestamp"] + time_offset + return df + + +def generate_all(config, seed: int = 42, benchmark_mode: str = "standard"): + """Generate Easy/Medium/Hard datasets.""" + seed_python_numpy(seed) + + gap = 1_000.0 + if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"): + seed_python_numpy(seed + 11) + users_easy = generate_users(config) + seed_python_numpy(seed + 23) + users_medium = generate_users(config) + seed_python_numpy(seed + 37) + users_hard = generate_users(config) + else: + shared_users = generate_users(config) + users_easy = shared_users + users_medium = shared_users + users_hard = shared_users + + df_easy = generate_difficulty( + config, + users_easy, + "easy", + seed, + time_offset=0.0, + benchmark_mode=benchmark_mode, + ) + t_after_easy = float(df_easy["timestamp"].max()) + gap + + df_medium = generate_difficulty( + config, + users_medium, + "medium", + seed, + time_offset=t_after_easy, + benchmark_mode=benchmark_mode, + ) + t_after_medium = float(df_medium["timestamp"].max()) + gap + + df_hard = generate_difficulty( + config, + users_hard, + "hard", + seed, + time_offset=t_after_medium, + benchmark_mode=benchmark_mode, + ) + return df_easy, df_medium, df_hard + + +def generate_single_difficulty( + config, + difficulty: str, + seed: int = 42, + benchmark_mode: str = "standard", +) -> pd.DataFrame: + """Generate one difficulty slice using the same user-seed scheme as generate_all().""" + seed_python_numpy(seed) + if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"): + user_seed = seed + _TWIN_DIFFICULTY_USER_SEED_OFFSETS[difficulty] + seed_python_numpy(user_seed) + users = generate_users(config) + else: + users = generate_users(config) + return generate_difficulty( + config, + users, + difficulty, + seed, + time_offset=0.0, + benchmark_mode=benchmark_mode, + ) + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + +def compute_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10) -> float: + bins = np.linspace(0.0, 1.0, n_bins + 1) + ece = 0.0 + for lo, hi in zip(bins[:-1], bins[1:]): + mask = (y_prob >= lo) & (y_prob < hi if hi < 1.0 else y_prob <= hi) + if not mask.any(): + continue + frac = float(mask.mean()) + avg_conf = float(y_prob[mask].mean()) + avg_acc = float(y_true[mask].mean()) + ece += frac * abs(avg_conf - avg_acc) + return float(ece) + + +def safe_roc_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float: + if len(np.unique(y_true)) < 2: + return 0.5 + return float(roc_auc_score(y_true, y_prob)) + + +def safe_pr_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float: + positives = float(np.sum(y_true == 1)) + negatives = float(np.sum(y_true == 0)) + if positives == 0.0: + return 0.0 + if negatives == 0.0: + return 1.0 + return float(average_precision_score(y_true, y_prob)) + + +def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> dict: + y_true = np.asarray(y_true, dtype=np.float32) + y_prob = np.nan_to_num(np.asarray(y_prob, dtype=np.float32), nan=0.5, posinf=1.0, neginf=0.0) + y_prob = np.clip(y_prob, 0.0, 1.0) + + return { + "roc_auc": safe_roc_auc(y_true, y_prob), + "pr_auc": safe_pr_auc(y_true, y_prob), + "brier": float(brier_score_loss(y_true, y_prob)), + "ece": compute_ece(y_true, y_prob), + } + + +def safe_pearson(x: np.ndarray, y: np.ndarray) -> float: + x = np.asarray(x, dtype=np.float32) + y = np.asarray(y, dtype=np.float32) + if len(x) == 0 or len(y) == 0: + return 0.0 + if np.std(x) < 1e-8 or np.std(y) < 1e-8: + return 0.0 + return float(np.corrcoef(x, y)[0, 1]) + + +def build_node_audit_table(df: pd.DataFrame) -> pd.DataFrame: + df = df.sort_values("timestamp").reset_index(drop=True).copy() + df["_dt"] = df.groupby("sender_id")["timestamp"].diff().fillna(0.0) + df["_phase"] = df["timestamp"] % 86400.0 + df["_burst"] = (df["_dt"] > 0.0) & (df["_dt"] < 600.0) + df["_quiet"] = df["_dt"] > 3600.0 + + grp = df.groupby("sender_id", sort=False) + node_df = pd.DataFrame({ + "txn_count": grp["sender_id"].count(), + "receiver_count": grp["receiver_id"].nunique(), + "retry_count": grp["is_retry"].sum() if "is_retry" in df.columns else 0.0, + "failed_count": grp["failed"].sum() if "failed" in df.columns else 0.0, + "burst_count": grp["_burst"].sum(), + "quiet_count": grp["_quiet"].sum(), + "dt_mean": grp["_dt"].mean(), + "dt_std": grp["_dt"].std().fillna(0.0), + "amount_mean": grp["amount"].mean(), + "amount_std": grp["amount"].std().fillna(0.0), + "phase_std": grp["_phase"].std().fillna(0.0), + }) + + recv_counts = ( + df.groupby(["sender_id", "receiver_id"]) + .size() + .reset_index(name="_n") + ) + recv_counts["_tot"] = recv_counts.groupby("sender_id")["_n"].transform("sum") + recv_counts["_p"] = recv_counts["_n"] / recv_counts["_tot"] + recv_counts["_h"] = -recv_counts["_p"] * np.log2(recv_counts["_p"] + 1e-9) + node_df["recv_entropy"] = recv_counts.groupby("sender_id")["_h"].sum() + + if "twin_pair_id" in df.columns: + node_df["twin_pair_id"] = grp["twin_pair_id"].first().astype(np.int32) + else: + node_df["twin_pair_id"] = -1 + + if "twin_label" in df.columns: + node_df["label"] = grp["twin_label"].max().astype(np.int32) + else: + node_df["label"] = grp["is_fraud"].max().astype(np.int32) + + return node_df.fillna(0.0).reset_index() + + +def with_local_event_idx(df: pd.DataFrame) -> pd.DataFrame: + out = df.sort_values("timestamp").reset_index(drop=True).copy() + out["local_event_idx"] = ( + out.groupby("sender_id").cumcount().astype(np.int32) + ) + return out + + +def build_matched_control_tables( + df: pd.DataFrame, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Build matched fraud/benign evaluation examples at the same local index k.""" + required = {"twin_pair_id", "twin_role", "twin_label", "label_event_idx"} + if not required.issubset(df.columns): + empty = pd.DataFrame() + return empty, empty, empty + + twin_df = with_local_event_idx(df[df["twin_pair_id"] >= 0].copy()) + if twin_df.empty: + empty = pd.DataFrame() + return empty, empty, empty + + sender_meta = ( + twin_df.groupby("sender_id") + .agg( + twin_pair_id=("twin_pair_id", "first"), + twin_role=("twin_role", "first"), + twin_label=("twin_label", "max"), + template_id=("template_id", "first") if "template_id" in twin_df.columns else ("twin_pair_id", "first"), + total_txn_count=("sender_id", "size"), + sender_start_time=("timestamp", "min"), + motif_hit_count=("motif_hit_count", "max") if "motif_hit_count" in twin_df.columns else ("sender_id", "size"), + ) + .reset_index() + ) + if "motif_hit_count" not in twin_df.columns: + sender_meta["motif_hit_count"] = 0 + + pair_rows: list[dict] = [] + example_rows: list[dict] = [] + pair_count_rows: list[dict] = [] + pair_event_id = 0 + + sender_groups = { + int(sender_id): group.reset_index(drop=True).copy() + for sender_id, group in twin_df.groupby("sender_id", sort=False) + } + + for pair_id, pair_meta in sender_meta.groupby("twin_pair_id", sort=True): + if len(pair_meta) != 2 or set(pair_meta["twin_role"]) != {"fraud", "benign"}: + continue + + fraud_meta = pair_meta[pair_meta["twin_role"] == "fraud"].iloc[0] + benign_meta = pair_meta[pair_meta["twin_role"] == "benign"].iloc[0] + fraud_sender = int(fraud_meta["sender_id"]) + benign_sender = int(benign_meta["sender_id"]) + template_id = int(fraud_meta["template_id"]) + fraud_total = int(fraud_meta["total_txn_count"]) + benign_total = int(benign_meta["total_txn_count"]) + + pair_count_rows.append({ + "twin_pair_id": int(pair_id), + "template_id": template_id, + "fraud_sender_id": fraud_sender, + "benign_sender_id": benign_sender, + "fraud_total_txn_count": fraud_total, + "benign_total_txn_count": benign_total, + "pair_total_txn_count_diff": abs(fraud_total - benign_total), + "fraud_motif_hit_count": int(fraud_meta["motif_hit_count"]), + "benign_motif_hit_count": int(benign_meta["motif_hit_count"]), + }) + + fraud_group = sender_groups[fraud_sender] + benign_group = sender_groups[benign_sender] + benign_by_idx = benign_group.set_index("local_event_idx", drop=False) + fraud_positives = fraud_group[ + (fraud_group["is_fraud"] == 1) & (fraud_group["label_event_idx"] >= 0) + ].copy() + + for row in fraud_positives.itertuples(index=False): + k = int(row.label_event_idx) + if k not in benign_by_idx.index: + continue + + benign_row = benign_by_idx.loc[k] + if isinstance(benign_row, pd.DataFrame): + benign_row = benign_row.iloc[0] + + fraud_age = float(row.timestamp - fraud_meta["sender_start_time"]) + benign_age = float(benign_row["timestamp"] - benign_meta["sender_start_time"]) + prefix_txn_count = k + 1 + + pair_rows.append({ + "pair_event_id": pair_event_id, + "twin_pair_id": int(pair_id), + "template_id": template_id, + "fraud_sender_id": fraud_sender, + "benign_sender_id": benign_sender, + "fraud_label_event_idx": k, + "benign_eval_event_idx": int(benign_row["local_event_idx"]), + "fraud_eval_timestamp": float(row.timestamp), + "benign_eval_timestamp": float(benign_row["timestamp"]), + "fraud_active_age": fraud_age, + "benign_active_age": benign_age, + "active_age_diff": abs(fraud_age - benign_age), + "timestamp_diff": abs(float(row.timestamp) - float(benign_row["timestamp"])), + "prefix_txn_count": prefix_txn_count, + "fraud_total_txn_count": fraud_total, + "benign_total_txn_count": benign_total, + "pair_total_txn_count_diff": abs(fraud_total - benign_total), + "fraud_motif_hit_count": int(fraud_meta["motif_hit_count"]), + "benign_motif_hit_count": int(benign_meta["motif_hit_count"]), + "label_delay": int(row.label_delay) if hasattr(row, "label_delay") else -1, + }) + + common = { + "pair_event_id": pair_event_id, + "twin_pair_id": int(pair_id), + "template_id": template_id, + "eval_local_event_idx": k, + "prefix_txn_count": prefix_txn_count, + } + example_rows.append({ + **common, + "sender_id": fraud_sender, + "label": 1, + "twin_role": "fraud", + "matched_sender_id": benign_sender, + "total_txn_count": fraud_total, + "eval_timestamp": float(row.timestamp), + # The simulator has no separate account-creation time, so + # account_age equals active_age for twin-control audits. + "account_age": fraud_age, + "active_age": fraud_age, + }) + example_rows.append({ + **common, + "sender_id": benign_sender, + "label": 0, + "twin_role": "benign", + "matched_sender_id": fraud_sender, + "total_txn_count": benign_total, + "eval_timestamp": float(benign_row["timestamp"]), + "account_age": benign_age, + "active_age": benign_age, + }) + pair_event_id += 1 + + return ( + pd.DataFrame(example_rows), + pd.DataFrame(pair_rows), + pd.DataFrame(pair_count_rows), + ) + + +def _sender_prefix_feature_row(prefix: pd.DataFrame) -> dict: + prefix = prefix.sort_values("timestamp").reset_index(drop=True) + timestamps = prefix["timestamp"].to_numpy(dtype=np.float64) + dts = np.diff(timestamps, prepend=timestamps[0]) if len(prefix) else np.zeros(0, dtype=np.float64) + dts = np.maximum(dts, 0.0) + phase = timestamps % 86400.0 if len(prefix) else np.zeros(0, dtype=np.float64) + burst = ((dts > 0.0) & (dts < 600.0)).astype(np.float32) + quiet = (dts > 3600.0).astype(np.float32) + + recv_counts = prefix["receiver_id"].value_counts().to_numpy(dtype=np.float64) + recv_p = recv_counts / max(float(recv_counts.sum()), 1.0) + recv_entropy = float(-np.sum(recv_p * np.log2(recv_p + 1e-9))) if len(recv_counts) else 0.0 + + return { + "txn_count": float(len(prefix)), + "txn_cnt10_last": float(min(len(prefix), 10)), + "receiver_count": float(prefix["receiver_id"].nunique()) if len(prefix) else 0.0, + "retry_count": float(prefix["is_retry"].sum()) if "is_retry" in prefix.columns else 0.0, + "failed_count": float(prefix["failed"].sum()) if "failed" in prefix.columns else 0.0, + "burst_count": float(burst.sum()), + "quiet_count": float(quiet.sum()), + "amount_mean": float(prefix["amount"].mean()) if len(prefix) else 0.0, + "amount_std": float(prefix["amount"].std(ddof=1)) if len(prefix) > 1 else 0.0, + "amount_max": float(prefix["amount"].max()) if len(prefix) else 0.0, + "td_mean": float(dts.mean()) if len(dts) else 0.0, + "td_std": float(dts.std(ddof=1)) if len(dts) > 1 else 0.0, + "dt_mean": float(dts.mean()) if len(dts) else 0.0, + "dt_std": float(dts.std(ddof=1)) if len(dts) > 1 else 0.0, + "phase_std": float(np.std(phase, ddof=1)) if len(phase) > 1 else 0.0, + "recv_entropy": recv_entropy, + "fail_rate": float(prefix["failed"].mean()) if "failed" in prefix.columns and len(prefix) else 0.0, + "retry_rate": float(prefix["is_retry"].mean()) if "is_retry" in prefix.columns and len(prefix) else 0.0, + "pair_freq_mean": float(prefix["pair_freq"].mean()) if "pair_freq" in prefix.columns and len(prefix) else 0.0, + } + + +def build_matched_prefix_feature_table( + df: pd.DataFrame, + examples: pd.DataFrame, +) -> pd.DataFrame: + if examples.empty: + return pd.DataFrame() + + indexed_df = with_local_event_idx(df) + sender_groups = { + int(sender_id): group.reset_index(drop=True).copy() + for sender_id, group in indexed_df.groupby("sender_id", sort=False) + } + + rows: list[dict] = [] + for example in examples.itertuples(index=False): + sender_id = int(example.sender_id) + end_idx = int(example.eval_local_event_idx) + sender_prefix = sender_groups[sender_id] + prefix = sender_prefix.iloc[: end_idx + 1].copy() + rows.append({ + "pair_event_id": int(example.pair_event_id), + "twin_pair_id": int(example.twin_pair_id), + "template_id": int(example.template_id), + "sender_id": sender_id, + "label": int(example.label), + "eval_local_event_idx": int(example.eval_local_event_idx), + "prefix_txn_count": int(example.prefix_txn_count), + "total_txn_count": int(example.total_txn_count), + "eval_timestamp": float(example.eval_timestamp), + "account_age": float(example.account_age), + "active_age": float(example.active_age), + **_sender_prefix_feature_row(prefix), + }) + + return pd.DataFrame(rows).fillna(0.0) + + +def report_matched_control_audits( + test_examples: pd.DataFrame, + test_pair_rows: pd.DataFrame, + test_pair_counts: pd.DataFrame, +) -> dict: + if test_examples.empty: + return {} + + y = test_examples["label"].to_numpy(dtype=np.float32) + audit = { + "pair_total_txn_count_diff_mean": float(test_pair_counts["pair_total_txn_count_diff"].mean()) if not test_pair_counts.empty else 0.0, + "pair_total_txn_count_diff_max": float(test_pair_counts["pair_total_txn_count_diff"].max()) if not test_pair_counts.empty else 0.0, + "auc_total_txn_count": safe_roc_auc(y, test_examples["total_txn_count"].to_numpy(dtype=np.float32)), + "auc_local_event_idx": safe_roc_auc(y, test_examples["eval_local_event_idx"].to_numpy(dtype=np.float32)), + "auc_prefix_txn_count": safe_roc_auc(y, test_examples["prefix_txn_count"].to_numpy(dtype=np.float32)), + "auc_timestamp": safe_roc_auc(y, test_examples["eval_timestamp"].to_numpy(dtype=np.float32)), + "auc_account_age": safe_roc_auc(y, test_examples["account_age"].to_numpy(dtype=np.float32)), + "auc_active_age": safe_roc_auc(y, test_examples["active_age"].to_numpy(dtype=np.float32)), + "fraud_label_event_idx_mean": float(test_pair_rows["fraud_label_event_idx"].mean()) if not test_pair_rows.empty else 0.0, + "fraud_label_event_idx_max": float(test_pair_rows["fraud_label_event_idx"].max()) if not test_pair_rows.empty else 0.0, + "benign_eval_event_idx_mean": float(test_pair_rows["benign_eval_event_idx"].mean()) if not test_pair_rows.empty else 0.0, + "benign_eval_event_idx_max": float(test_pair_rows["benign_eval_event_idx"].max()) if not test_pair_rows.empty else 0.0, + "pair_event_idx_diff_mean": float((test_pair_rows["fraud_label_event_idx"] - test_pair_rows["benign_eval_event_idx"]).abs().mean()) if not test_pair_rows.empty else 0.0, + "pair_event_idx_diff_max": float((test_pair_rows["fraud_label_event_idx"] - test_pair_rows["benign_eval_event_idx"]).abs().max()) if not test_pair_rows.empty else 0.0, + "pair_active_age_diff_mean": float(test_pair_rows["active_age_diff"].mean()) if not test_pair_rows.empty else 0.0, + "pair_active_age_diff_max": float(test_pair_rows["active_age_diff"].max()) if not test_pair_rows.empty else 0.0, + "pair_timestamp_diff_mean": float(test_pair_rows["timestamp_diff"].mean()) if not test_pair_rows.empty else 0.0, + "pair_timestamp_diff_max": float(test_pair_rows["timestamp_diff"].max()) if not test_pair_rows.empty else 0.0, + "benign_motif_hit_rate": float((test_pair_counts["benign_motif_hit_count"] > 0).mean()) if not test_pair_counts.empty else 0.0, + "benign_motif_hit_pairs": int((test_pair_counts["benign_motif_hit_count"] > 0).sum()) if not test_pair_counts.empty else 0, + "matched_control_examples": int(len(test_examples)), + "matched_control_pair_events": int(len(test_pair_rows)), + } + + print("\n--- Matched-Control Shortcut Audit ---") + for key in ( + "pair_total_txn_count_diff_mean", + "pair_total_txn_count_diff_max", + "auc_total_txn_count", + "auc_local_event_idx", + "auc_prefix_txn_count", + "auc_timestamp", + "auc_account_age", + "auc_active_age", + "benign_motif_hit_rate", + "benign_motif_hit_pairs", + ): + print(f" {key:<30}: {audit[key]}") + + if not test_pair_rows.empty: + print("\n label_event_idx distribution (fraud twins):") + print(test_pair_rows["fraud_label_event_idx"].describe().to_string()) + print("\n pseudo-label idx distribution (benign twins):") + print(test_pair_rows["benign_eval_event_idx"].describe().to_string()) + print("\n per-pair fraud-vs-benign evaluation indices:") + cols = [ + "twin_pair_id", + "fraud_label_event_idx", + "benign_eval_event_idx", + "active_age_diff", + "timestamp_diff", + ] + print(test_pair_rows[cols].head(20).to_string(index=False)) + + return audit + + +def bootstrap_auc_summary( + y_true: np.ndarray, + y_score: np.ndarray, + seed: int, + n_bootstrap: int = 200, +) -> dict: + y_true = np.asarray(y_true, dtype=np.float32) + y_score = np.asarray(y_score, dtype=np.float32) + if len(y_true) == 0 or len(np.unique(y_true)) < 2: + return { + "bootstrap_std": float("nan"), + "ci_lo": float("nan"), + "ci_hi": float("nan"), + "n_bootstrap": 0, + } + + rng = np.random.default_rng(seed) + aucs: list[float] = [] + n = len(y_true) + for _ in range(n_bootstrap): + idx = rng.integers(0, n, size=n) + sample_y = y_true[idx] + if len(np.unique(sample_y)) < 2: + continue + aucs.append(safe_roc_auc(sample_y, y_score[idx])) + + if not aucs: + return { + "bootstrap_std": float("nan"), + "ci_lo": float("nan"), + "ci_hi": float("nan"), + "n_bootstrap": 0, + } + + auc_arr = np.asarray(aucs, dtype=np.float32) + return { + "bootstrap_std": float(np.std(auc_arr, ddof=1)) if len(auc_arr) > 1 else 0.0, + "ci_lo": float(np.quantile(auc_arr, 0.025)), + "ci_hi": float(np.quantile(auc_arr, 0.975)), + "n_bootstrap": int(len(auc_arr)), + } + + +def make_auc_result( + y_true: np.ndarray, + y_score: np.ndarray, + seed: int, + extra: dict | None = None, +) -> dict: + y_true = np.asarray(y_true, dtype=np.float32) + y_score = np.asarray(y_score, dtype=np.float32) + result = { + "auc": safe_roc_auc(y_true, y_score), + "y_true": y_true, + "y_score": y_score, + "n_examples": int(len(y_true)), + "n_pos": int(np.sum(y_true == 1)), + "n_neg": int(np.sum(y_true == 0)), + } + result.update(bootstrap_auc_summary(y_true, y_score, seed=seed)) + if extra: + result.update(extra) + return result + + +def attach_auc_result(report: dict, prefix: str, result: dict) -> None: + report[f"{prefix}_roc_auc"] = float(result["auc"]) + report[f"{prefix}_n_examples"] = int(result["n_examples"]) + report[f"{prefix}_n_pos"] = int(result["n_pos"]) + report[f"{prefix}_n_neg"] = int(result["n_neg"]) + report[f"{prefix}_auc_bootstrap_std"] = float(result["bootstrap_std"]) + report[f"{prefix}_auc_ci_lo"] = float(result["ci_lo"]) + report[f"{prefix}_auc_ci_hi"] = float(result["ci_hi"]) + + +def _standardize_train_test( + train_df: pd.DataFrame, + test_df: pd.DataFrame, + feature_cols: Sequence[str], +) -> tuple[np.ndarray, np.ndarray]: + x_train = train_df[list(feature_cols)].to_numpy(dtype=np.float32) + x_test = test_df[list(feature_cols)].to_numpy(dtype=np.float32) + mean = x_train.mean(axis=0, keepdims=True) + std = x_train.std(axis=0, keepdims=True) + 1e-6 + return (x_train - mean) / std, (x_test - mean) / std + + +def compute_matched_static_aggregate_auc( + train_features: pd.DataFrame, + test_features: pd.DataFrame, + seed: int, + verbose: bool = True, +) -> dict: + feature_cols = [ + "txn_count", + "receiver_count", + "retry_count", + "failed_count", + "burst_count", + "quiet_count", + "dt_mean", + "dt_std", + "amount_mean", + "amount_std", + "phase_std", + "recv_entropy", + ] + if train_features.empty or test_features.empty: + return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) + if train_features["label"].nunique() < 2 or test_features["label"].nunique() < 2: + y_test = test_features["label"].to_numpy(dtype=np.float32) + probs = np.full(len(y_test), 0.5, dtype=np.float32) + return make_auc_result(y_test, probs, seed=seed) + + x_train, x_test = _standardize_train_test(train_features, test_features, feature_cols) + clf = LogisticRegression( + max_iter=2000, + class_weight="balanced", + random_state=42, + solver="liblinear", + ) + clf.fit(x_train, train_features["label"].to_numpy(dtype=np.int32)) + probs = clf.predict_proba(x_test)[:, 1] + y_test = test_features["label"].to_numpy(dtype=np.float32) + + if verbose: + coefs = np.abs(clf.coef_[0]) + ranked = np.argsort(coefs)[::-1] + print("\n Top matched static aggregate predictors:") + for rank_i in ranked[:5]: + print(f" {feature_cols[rank_i]:<20}: |coef|={coefs[rank_i]:.4f}") + + return make_auc_result( + y_test, + probs.astype(np.float32), + seed=seed, + ) + + +def compute_matched_xgboost_auc( + train_features: pd.DataFrame, + test_features: pd.DataFrame, + seed: int, +) -> dict: + feature_cols = [ + "txn_count", + "txn_cnt10_last", + "amount_mean", + "amount_std", + "amount_max", + "td_mean", + "td_std", + "fail_rate", + "retry_rate", + "recv_entropy", + "pair_freq_mean", + ] + if train_features.empty or test_features.empty: + return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) + y_train = train_features["label"].to_numpy(dtype=np.int32) + y_test = test_features["label"].to_numpy(dtype=np.int32) + if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2: + probs = np.full(len(y_test), 0.5, dtype=np.float32) + return make_auc_result(y_test.astype(np.float32), probs, seed=seed) + + x_train = train_features[feature_cols].to_numpy(dtype=np.float32) + x_test = test_features[feature_cols].to_numpy(dtype=np.float32) + scale_pos_weight = max(1.0, float((y_train == 0).sum()) / max(float((y_train == 1).sum()), 1.0)) + model = XGBClassifier( + n_estimators=200, + max_depth=6, + learning_rate=0.05, + objective="binary:logistic", + eval_metric="logloss", + scale_pos_weight=scale_pos_weight, + random_state=42, + verbosity=0, + n_jobs=1, + tree_method="exact", + ) + model.fit(x_train, y_train) + probs = model.predict_proba(x_test)[:, 1] + + importances = model.feature_importances_ + ranked = np.argsort(importances)[::-1] + print(" [Matched XGBoost] Top-5 feature importances:") + for idx in ranked[:5]: + print(f" {feature_cols[idx]:<20}: {importances[idx]:.4f}") + + return make_auc_result( + y_test.astype(np.float32), + probs.astype(np.float32), + seed=seed, + ) + + +def _build_example_prefix( + df_full: pd.DataFrame, + sender_id: int, + eval_local_event_idx: int, + eval_timestamp: float, +) -> pd.DataFrame: + prefix = df_full[df_full["timestamp"] <= eval_timestamp].copy() + if "local_event_idx" not in prefix.columns: + prefix = with_local_event_idx(prefix) + sender_mask = prefix["sender_id"] == sender_id + if sender_mask.any(): + prefix = prefix[(~sender_mask) | (prefix["local_event_idx"] <= eval_local_event_idx)].copy() + return prefix.sort_values("timestamp").reset_index(drop=True) + + +def build_static_gnn_example_embeddings( + model: StaticGNNWrapper, + df_full: pd.DataFrame, + examples: pd.DataFrame, +) -> tuple[np.ndarray, dict]: + if examples.empty: + return np.zeros((0, model.hidden_dim), dtype=np.float32), { + "matched_examples": 0, + "unique_prefix_cutoffs": 0, + "graph_builds": 0, + "cache_hit_rate": float("nan"), + "eval_time_sec": 0.0, + } + + clean_full = strip_oracle_cols( + df_full.sort_values("timestamp").reset_index(drop=True) + ) + start = time.perf_counter() + + sender_ids = clean_full["sender_id"].to_numpy(dtype=np.int64) + receiver_ids = clean_full["receiver_id"].to_numpy(dtype=np.int64) + timestamps = clean_full["timestamp"].to_numpy(dtype=np.float64) + edge_feats = build_edge_features(clean_full).astype(np.float32) + ns = model._norm_stats + edge_feats = (edge_feats - ns["ea_mean"]) / ns["ea_std"] + + max_sender = int(sender_ids.max()) if len(sender_ids) else 0 + max_receiver = int(receiver_ids.max()) if len(receiver_ids) else 0 + n_nodes = max(max(max_sender, max_receiver) + 1, model._n_nodes) + feat_sum = np.zeros((n_nodes, edge_feats.shape[1]), dtype=np.float32) + feat_count = np.zeros(n_nodes, dtype=np.float32) + node_feat = np.zeros((n_nodes, edge_feats.shape[1]), dtype=np.float32) + + device = model.device + x_t = torch.zeros((n_nodes, edge_feats.shape[1]), dtype=torch.float32, device=device) + edge_index_full = torch.tensor( + np.vstack([sender_ids, receiver_ids]), + dtype=torch.long, + device=device, + ) + + examples_reset = examples.reset_index(drop=True).copy() + grouped = examples_reset.groupby("eval_timestamp", sort=True).indices + grouped_items = sorted( + [(float(ts), idxs) for ts, idxs in grouped.items()], + key=lambda item: item[0], + ) + ordered_cutoffs = [item[0] for item in grouped_items] + cutoff_ends = np.searchsorted(timestamps, np.asarray(ordered_cutoffs, dtype=np.float64), side="right") + out = np.zeros((len(examples_reset), model.hidden_dim), dtype=np.float32) + + prev_end = 0 + graph_builds = 0 + for (cutoff, row_indices), end_idx in zip(grouped_items, cutoff_ends.tolist()): + if end_idx > prev_end: + batch_senders = sender_ids[prev_end:end_idx] + batch_feats = edge_feats[prev_end:end_idx] + np.add.at(feat_sum, batch_senders, batch_feats) + np.add.at(feat_count, batch_senders, 1.0) + changed_nodes = np.unique(batch_senders) + node_feat[changed_nodes] = feat_sum[changed_nodes] / feat_count[changed_nodes, None] + changed_t = torch.tensor(changed_nodes, dtype=torch.long, device=device) + x_t[changed_t] = torch.tensor(node_feat[changed_nodes], dtype=torch.float32, device=device) + prev_end = end_idx + + edge_index = edge_index_full[:, :end_idx] + model._encoder.eval() + with torch.no_grad(): + prefix_emb = model._encoder(x_t, edge_index) + + graph_builds += 1 + sender_batch = examples_reset.loc[row_indices, "sender_id"].to_numpy(dtype=np.int64) + sender_t = torch.tensor(sender_batch, dtype=torch.long, device=device) + out[row_indices] = prefix_emb[sender_t].detach().cpu().numpy().astype(np.float32) + + matched_examples = int(len(examples_reset)) + unique_cutoffs = int(len(ordered_cutoffs)) + hits = max(0, matched_examples - graph_builds) + diagnostics = { + "matched_examples": matched_examples, + "unique_prefix_cutoffs": unique_cutoffs, + "graph_builds": int(graph_builds), + "cache_hit_rate": float(hits / matched_examples) if matched_examples > 0 else float("nan"), + "eval_time_sec": float(time.perf_counter() - start), + } + return out.astype(np.float32), diagnostics + + +def compute_matched_static_gnn_auc( + df_train: pd.DataFrame, + df_test: pd.DataFrame, + train_examples: pd.DataFrame, + test_examples: pd.DataFrame, + device: str, + num_epochs: int, + seed: int, +) -> dict: + if train_examples.empty or test_examples.empty: + return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) + if train_examples["label"].nunique() < 2 or test_examples["label"].nunique() < 2: + y_test = test_examples["label"].to_numpy(dtype=np.float32) + probs = np.full(len(y_test), 0.5, dtype=np.float32) + return make_auc_result(y_test, probs, seed=seed) + + static_seed = derived_seed(seed, "StaticGNN", "matched_prefix") + set_global_determinism(static_seed) + model = StaticGNNWrapper(hidden_dim=64, n_snapshots=10, device=device) + model.fit(strip_oracle_cols(df_train), num_epochs=num_epochs) + + eval_start = time.perf_counter() + train_emb, train_diag = build_static_gnn_example_embeddings(model, df_train, train_examples) + full_test_df = ( + pd.concat([df_train, df_test], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + test_emb, test_diag = build_static_gnn_example_embeddings(model, full_test_df, test_examples) + y_train = train_examples["label"].to_numpy(dtype=np.int32) + y_test = test_examples["label"].to_numpy(dtype=np.int32) + + mean = train_emb.mean(axis=0, keepdims=True) + std = train_emb.std(axis=0, keepdims=True) + 1e-6 + train_emb = (train_emb - mean) / std + test_emb = (test_emb - mean) / std + + clf = LogisticRegression( + max_iter=2000, + class_weight="balanced", + random_state=42, + solver="liblinear", + ) + clf.fit(train_emb, y_train) + probs = clf.predict_proba(test_emb)[:, 1] + return make_auc_result( + y_test.astype(np.float32), + probs.astype(np.float32), + seed=seed, + extra={ + "auc_flipped": safe_roc_auc(y_test.astype(np.float32), (1.0 - probs).astype(np.float32)), + "score_mean_pos": float(probs[y_test == 1].mean()) if np.any(y_test == 1) else float("nan"), + "score_mean_neg": float(probs[y_test == 0].mean()) if np.any(y_test == 0) else float("nan"), + "score_std": float(np.std(probs)), + "zero_emb_frac": float(np.mean(np.linalg.norm(test_emb, axis=1) < 1e-8)), + "train_examples": int(len(train_examples)), + "test_examples": int(len(test_examples)), + "matched_examples": int(train_diag["matched_examples"] + test_diag["matched_examples"]), + "unique_prefix_cutoffs": int(train_diag["unique_prefix_cutoffs"] + test_diag["unique_prefix_cutoffs"]), + "graph_builds": int(train_diag["graph_builds"] + test_diag["graph_builds"]), + "cache_hit_rate": float( + ( + max(0, train_diag["matched_examples"] - train_diag["graph_builds"]) + + max(0, test_diag["matched_examples"] - test_diag["graph_builds"]) + ) + / max(1, train_diag["matched_examples"] + test_diag["matched_examples"]) + ), + "eval_time_sec": float(time.perf_counter() - eval_start), + "train_unique_prefix_cutoffs": int(train_diag["unique_prefix_cutoffs"]), + "test_unique_prefix_cutoffs": int(test_diag["unique_prefix_cutoffs"]), + "train_graph_builds": int(train_diag["graph_builds"]), + "test_graph_builds": int(test_diag["graph_builds"]), + "train_eval_time_sec": float(train_diag["eval_time_sec"]), + "test_eval_time_sec": float(test_diag["eval_time_sec"]), + }, + ) + + +def compute_matched_seqgru_metrics( + df_train: pd.DataFrame, + df_test: pd.DataFrame, + train_examples: pd.DataFrame, + test_examples: pd.DataFrame, + device: str, + seed: int, + max_epochs: int, + hidden_dim: int = 96, + receiver_buckets: int = 512, +) -> dict: + if train_examples.empty or test_examples.empty: + empty = make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) + return { + "clean": empty, + "shuffled": empty, + "delta": float("nan"), + "clean_fit": {}, + "shuffled_fit": {}, + } + + clean_train_df = strip_oracle_cols(df_train) + clean_test_df = strip_oracle_cols(df_test) + y_test = test_examples["label"].to_numpy(dtype=np.float32) + if train_examples["label"].nunique() < 2 or test_examples["label"].nunique() < 2: + flat_probs = np.full(len(y_test), 0.5, dtype=np.float32) + flat = make_auc_result(y_test, flat_probs, seed=seed) + flat["pr_auc"] = compute_metrics(y_test, flat_probs)["pr_auc"] + return { + "clean": flat, + "shuffled": flat, + "delta": 0.0, + "clean_fit": {}, + "shuffled_fit": {}, + } + + def build_model() -> SequenceGRUWrapper: + return SequenceGRUWrapper( + hidden_dim=hidden_dim, + receiver_buckets=receiver_buckets, + device=device, + ) + + clean_seed = derived_seed(seed, "SeqGRU", "clean") + shuffled_seed = derived_seed(seed, "SeqGRU", "shuffled") + + set_global_determinism(clean_seed) + clean_model = build_model() + clean_model.fit(clean_train_df, num_epochs=1) + clean_fit = clean_model.fit_matched_prefix_examples( + clean_train_df, + train_examples, + seed=clean_seed, + max_epochs=max_epochs, + patience=6, + valid_frac=0.20, + pair_batch_size=64, + learning_rate=2e-3, + weight_decay=1e-4, + shuffle_within_sequence=False, + ) + clean_probs = clean_model.predict_matched_prefix_examples( + clean_test_df, + test_examples, + seed=clean_seed, + shuffle_within_sequence=False, + ) + clean_metrics = compute_metrics(y_test, clean_probs) + clean_result = make_auc_result( + y_test, + clean_probs.astype(np.float32), + seed=seed, + extra={ + "pr_auc": float(clean_metrics["pr_auc"]), + "brier": float(clean_metrics["brier"]), + "ece": float(clean_metrics["ece"]), + }, + ) + + set_global_determinism(shuffled_seed) + shuffled_model = build_model() + shuffled_model.fit(clean_train_df, num_epochs=1) + shuffled_fit = shuffled_model.fit_matched_prefix_examples( + clean_train_df, + train_examples, + seed=shuffled_seed, + max_epochs=max_epochs, + patience=6, + valid_frac=0.20, + pair_batch_size=64, + learning_rate=2e-3, + weight_decay=1e-4, + shuffle_within_sequence=True, + ) + shuffled_probs = shuffled_model.predict_matched_prefix_examples( + clean_test_df, + test_examples, + seed=shuffled_seed, + shuffle_within_sequence=True, + ) + shuffled_metrics = compute_metrics(y_test, shuffled_probs) + shuffled_result = make_auc_result( + y_test, + shuffled_probs.astype(np.float32), + seed=seed, + extra={ + "pr_auc": float(shuffled_metrics["pr_auc"]), + "brier": float(shuffled_metrics["brier"]), + "ece": float(shuffled_metrics["ece"]), + }, + ) + + return { + "clean": clean_result, + "shuffled": shuffled_result, + "delta": float(shuffled_result["auc"] - clean_result["auc"]), + "clean_fit": clean_fit, + "shuffled_fit": shuffled_fit, + } + + +def _combine_matched_examples( + train_examples: pd.DataFrame, + test_examples: pd.DataFrame, +) -> pd.DataFrame: + tagged_train = train_examples.copy() + tagged_train["example_split"] = "train" + tagged_test = test_examples.copy() + tagged_test["example_split"] = "test" + return pd.concat([tagged_train, tagged_test], ignore_index=True) + + +def _fit_embedding_probe( + train_emb: np.ndarray, + test_emb: np.ndarray, + y_train: np.ndarray, + y_test: np.ndarray, + seed: int, +) -> dict: + if len(y_train) == 0 or len(y_test) == 0: + return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) + if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2: + probs = np.full(len(y_test), 0.5, dtype=np.float32) + metrics = compute_metrics(y_test, probs) + return make_auc_result( + y_test.astype(np.float32), + probs, + seed=seed, + extra={ + "pr_auc": float(metrics["pr_auc"]), + "brier": float(metrics["brier"]), + "ece": float(metrics["ece"]), + }, + ) + + mean = train_emb.mean(axis=0, keepdims=True) + std = train_emb.std(axis=0, keepdims=True) + 1e-6 + train_emb = (train_emb - mean) / std + test_emb = (test_emb - mean) / std + + clf = LogisticRegression( + max_iter=2000, + class_weight="balanced", + random_state=seed, + solver="liblinear", + ) + clf.fit(train_emb, y_train.astype(np.int32)) + probs = clf.predict_proba(test_emb)[:, 1].astype(np.float32) + metrics = compute_metrics(y_test.astype(np.float32), probs) + return make_auc_result( + y_test.astype(np.float32), + probs, + seed=seed, + extra={ + "pr_auc": float(metrics["pr_auc"]), + "brier": float(metrics["brier"]), + "ece": float(metrics["ece"]), + }, + ) + + +def compute_matched_temporal_gnn_metrics( + model_name: str, + model_builder, + df_train: pd.DataFrame, + df_test: pd.DataFrame, + train_examples: pd.DataFrame, + test_examples: pd.DataFrame, + seed: int, + num_epochs: int, +) -> dict: + if train_examples.empty or test_examples.empty: + empty = make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed) + return { + "clean": empty, + "shuffled": empty, + "delta": float("nan"), + } + + clean_train = strip_oracle_cols(df_train) + clean_test = strip_oracle_cols(df_test) + all_examples = _combine_matched_examples(train_examples, test_examples) + train_mask = all_examples["example_split"].to_numpy() == "train" + test_mask = ~train_mask + y_train = all_examples.loc[train_mask, "label"].to_numpy(dtype=np.float32) + y_test = all_examples.loc[test_mask, "label"].to_numpy(dtype=np.float32) + + clean_model_seed = derived_seed(seed, model_name, "clean_model") + shuffled_model_seed = derived_seed(seed, model_name, "shuffled_model") + + set_global_determinism(clean_model_seed) + clean_model = model_builder() + clean_model.fit(clean_train, num_epochs=num_epochs) + clean_full = ( + pd.concat([clean_train, clean_test], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + clean_emb = clean_model.extract_prefix_embeddings(clean_full, all_examples) + clean_result = _fit_embedding_probe( + clean_emb[train_mask], + clean_emb[test_mask], + y_train, + y_test, + seed=seed, + ) + + shuffled_train = shuffle_chronology(clean_train, seed=seed + 101) + shuffled_test = shuffle_chronology(clean_test, seed=seed + 211) + set_global_determinism(shuffled_model_seed) + shuffled_model = model_builder() + shuffled_model.fit(shuffled_train, num_epochs=num_epochs) + shuffled_full = ( + pd.concat([shuffled_train, shuffled_test], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + shuffled_emb = shuffled_model.extract_prefix_embeddings(shuffled_full, all_examples) + shuffled_result = _fit_embedding_probe( + shuffled_emb[train_mask], + shuffled_emb[test_mask], + y_train, + y_test, + seed=seed, + ) + + return { + "clean": clean_result, + "shuffled": shuffled_result, + "delta": float(shuffled_result["auc"] - clean_result["auc"]), + "train_examples": int(train_mask.sum()), + "test_examples": int(test_mask.sum()), + "model_name": model_name, + } + + +def ks_distance(x: np.ndarray, y: np.ndarray) -> float: + x = np.sort(np.asarray(x, dtype=np.float64)) + y = np.sort(np.asarray(y, dtype=np.float64)) + if len(x) == 0 or len(y) == 0: + return 0.0 + values = np.sort(np.concatenate([x, y])) + cdf_x = np.searchsorted(x, values, side="right") / len(x) + cdf_y = np.searchsorted(y, values, side="right") / len(y) + return float(np.max(np.abs(cdf_x - cdf_y))) + + +def compute_static_aggregate_auc(node_df: pd.DataFrame, seed: int, verbose: bool = True) -> float: + feature_cols = [ + "txn_count", + "receiver_count", + "retry_count", + "failed_count", + "burst_count", + "quiet_count", + "dt_mean", + "dt_std", + "amount_mean", + "amount_std", + "phase_std", + "recv_entropy", + ] + + audit_df = node_df[node_df["twin_pair_id"] >= 0].copy() + if audit_df.empty or audit_df["label"].nunique() < 2: + return 0.5 + + pair_ids = audit_df["twin_pair_id"].unique() + if len(pair_ids) < 4: + return 0.5 + + rng = np.random.default_rng(seed) + pair_ids = rng.permutation(pair_ids) + split = max(1, int(0.7 * len(pair_ids))) + train_ids = set(pair_ids[:split]) + test_ids = set(pair_ids[split:]) + if not test_ids: + test_ids = set(pair_ids[-1:]) + train_ids = set(pair_ids[:-1]) + + train_df = audit_df[audit_df["twin_pair_id"].isin(train_ids)] + test_df = audit_df[audit_df["twin_pair_id"].isin(test_ids)] + if train_df["label"].nunique() < 2 or test_df["label"].nunique() < 2: + return 0.5 + + x_train = train_df[feature_cols].to_numpy(dtype=np.float32) + x_test = test_df[feature_cols].to_numpy(dtype=np.float32) + mean = x_train.mean(axis=0, keepdims=True) + std = x_train.std(axis=0, keepdims=True) + 1e-6 + x_train = (x_train - mean) / std + x_test = (x_test - mean) / std + + clf = LogisticRegression( + max_iter=2000, + class_weight="balanced", + random_state=seed, + solver="liblinear", + ) + clf.fit(x_train, train_df["label"].to_numpy(dtype=np.int32)) + probs = clf.predict_proba(x_test)[:, 1] + auc = safe_roc_auc(test_df["label"].to_numpy(dtype=np.float32), probs.astype(np.float32)) + + if verbose: + # Top predictors by absolute coefficient + coefs = np.abs(clf.coef_[0]) + ranked = np.argsort(coefs)[::-1] + print("\n Top static aggregate predictors:") + for rank_i in ranked[:5]: + print(f" {feature_cols[rank_i]:<20}: |coef|={coefs[rank_i]:.4f}") + + return auc + + + +def compute_aggregate_ks(node_df: pd.DataFrame) -> tuple[float, float]: + fraud_df = node_df[(node_df["twin_pair_id"] >= 0) & (node_df["label"] == 1)] + benign_df = node_df[(node_df["twin_pair_id"] >= 0) & (node_df["label"] == 0)] + if fraud_df.empty or benign_df.empty: + return 0.0, 0.0 + + feature_cols = [ + "txn_count", + "receiver_count", + "retry_count", + "burst_count", + "dt_mean", + "dt_std", + "recv_entropy", + ] + distances = [ + ks_distance(fraud_df[col].to_numpy(), benign_df[col].to_numpy()) + for col in feature_cols + ] + if not distances: + return 0.0, 0.0 + return float(np.mean(distances)), float(np.max(distances)) + + +def evaluate_matched_pair_separability( + model: TemporalModel, + df_train: pd.DataFrame, + df_test: pd.DataFrame, + delta_time: float, + n_checkpoints: int, +) -> tuple[float, int]: + if "twin_pair_id" not in df_test.columns or "twin_label" not in df_test.columns: + return 0.0, 0 + + checkpoints = make_checkpoints(df_test, delta_time, n_checkpoints=n_checkpoints) + if not checkpoints: + return 0.0, 0 + cutoff_time = checkpoints[-1] + + df_full = ( + pd.concat([df_train, df_test], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + prefix_df = df_full[df_full["timestamp"] <= cutoff_time].copy() + active_nodes = sorted(df_test[df_test["timestamp"] <= cutoff_time]["sender_id"].unique()) + if not active_nodes: + return 0.0, 0 + + if model.is_temporal: + model.reset_memory() + probs = model.predict(prefix_df, active_nodes) + score_map = {int(node_id): float(prob) for node_id, prob in zip(active_nodes, probs)} + + meta = ( + df_full.groupby("sender_id")[["twin_pair_id", "twin_label"]] + .first() + .reset_index() + ) + meta = meta[(meta["sender_id"].isin(active_nodes)) & (meta["twin_pair_id"] >= 0)] + + pair_scores = [] + for _, pair_df in meta.groupby("twin_pair_id"): + if len(pair_df) != 2 or set(pair_df["twin_label"]) != {0, 1}: + continue + fraud_node = int(pair_df.loc[pair_df["twin_label"] == 1, "sender_id"].iloc[0]) + benign_node = int(pair_df.loc[pair_df["twin_label"] == 0, "sender_id"].iloc[0]) + if fraud_node not in score_map or benign_node not in score_map: + continue + pair_scores.append(float(score_map[fraud_node] > score_map[benign_node])) + + if not pair_scores: + return 0.0, 0 + return float(np.mean(pair_scores)), int(len(pair_scores)) + + +def compute_split_leakage(df_train: pd.DataFrame, df_test: pd.DataFrame) -> dict: + train_users = set(df_train["sender_id"].unique().tolist()) + test_users = set(df_test["sender_id"].unique().tolist()) + leakage = { + "sender_overlap_count": int(len(train_users & test_users)), + "pair_overlap_count": 0, + "template_overlap_count": 0, + "receiver_pair_overlap_count": 0, + } + + if "twin_pair_id" in df_train.columns and "twin_pair_id" in df_test.columns: + train_pairs = set(df_train.loc[df_train["twin_pair_id"] >= 0, "twin_pair_id"].unique().tolist()) + test_pairs = set(df_test.loc[df_test["twin_pair_id"] >= 0, "twin_pair_id"].unique().tolist()) + leakage["pair_overlap_count"] = int(len(train_pairs & test_pairs)) + + if "template_id" in df_train.columns and "template_id" in df_test.columns: + train_templates = set(df_train.loc[df_train["template_id"] >= 0, "template_id"].unique().tolist()) + test_templates = set(df_test.loc[df_test["template_id"] >= 0, "template_id"].unique().tolist()) + leakage["template_overlap_count"] = int(len(train_templates & test_templates)) + + # Receiver-pair overlap: distinct (sender_id, receiver_id) tuples + train_rpairs = set(zip( + df_train["sender_id"].tolist(), df_train["receiver_id"].tolist() + )) + test_rpairs = set(zip( + df_test["sender_id"].tolist(), df_test["receiver_id"].tolist() + )) + leakage["receiver_pair_overlap_count"] = int(len(train_rpairs & test_rpairs)) + + return leakage + + + +# --------------------------------------------------------------------------- +# Prefix-only evaluation guard +# --------------------------------------------------------------------------- + +def assert_prefix_only(df_prefix: pd.DataFrame, cutoff_time: float) -> None: + """Warn if any future event slipped into the prefix. + Uses 1.0s tolerance to absorb float32-vs-float64 precision gaps. + """ + if df_prefix.empty: + return + actual_max = float(df_prefix["timestamp"].max()) + if actual_max > cutoff_time + 1.0: + print( + f"[PREFIX LEAK] df_prefix max timestamp {actual_max:.2f} > cutoff {cutoff_time:.2f}!" + ) + + +# --------------------------------------------------------------------------- +# Label-source audit +# --------------------------------------------------------------------------- + +def build_label_source_audit_table(df: pd.DataFrame) -> pd.DataFrame: + """Return per-positive-event audit table. + + Required audit columns (populated by FraudEngine): + fraud_source, motif_source, motif_hit_count, trigger_event_idx, + label_event_idx, label_delay, is_fallback_label + """ + fraud_rows = df[df["is_fraud"] == 1].copy() + if fraud_rows.empty: + return pd.DataFrame() + + audit_cols = [ + "sender_id", "twin_pair_id", "twin_role", + "fraud_source", "motif_source", "motif_hit_count", + "trigger_event_idx", "label_event_idx", "label_delay", + "is_fallback_label", + ] + available = [c for c in audit_cols if c in fraud_rows.columns] + return fraud_rows[available].reset_index(drop=True) + + +def compute_motif_label_consistency(df: pd.DataFrame, calib_mode: bool = False) -> dict: + """Compute and print motif/label consistency statistics.""" + has_motif = "motif_hit_count" in df.columns + has_fraud = "is_fraud" in df.columns + if not (has_motif and has_fraud): + return {} + + # Restrict to twin users only + if "twin_pair_id" in df.columns: + twin_df = df[df["twin_pair_id"] >= 0].copy() + else: + twin_df = df.copy() + if twin_df.empty: + return {} + + # Node-level aggregation + node_grp = twin_df.groupby("sender_id") + node_label = node_grp["is_fraud"].max() + node_hit = node_grp["motif_hit_count"].max() + node_role = node_grp["twin_role"].first() if "twin_role" in twin_df.columns else None + + has_hit = (node_hit >= 1) + label_pos = (node_label == 1) + + p_label_given_hit = float(label_pos[has_hit].mean()) if has_hit.any() else float("nan") + p_label_given_nohit = float(label_pos[~has_hit].mean()) if (~has_hit).any() else float("nan") + p_hit_given_label = float(has_hit[label_pos].mean()) if label_pos.any() else float("nan") + + if node_role is not None: + benign_mask = (node_role == "benign") + accidental_motif_rate = float(has_hit[benign_mask].mean()) if benign_mask.any() else float("nan") + avg_hits_fraud = float(node_hit[~benign_mask & label_pos].mean()) if (label_pos & ~benign_mask).any() else float("nan") + avg_hits_benign = float(node_hit[benign_mask].mean()) if benign_mask.any() else float("nan") + else: + accidental_motif_rate = float("nan") + avg_hits_fraud = float("nan") + avg_hits_benign = float("nan") + + result = { + "p_label_given_hit": p_label_given_hit, + "p_label_given_nohit": p_label_given_nohit, + "p_hit_given_label": p_hit_given_label, + "accidental_benign_motif": accidental_motif_rate, + "avg_hits_fraud_twin": avg_hits_fraud, + "avg_hits_benign_twin": avg_hits_benign, + } + + print("\n--- Motif-Label Consistency ---") + for k, v in result.items(): + print(f" {k:<30}: {v:.4f}" if not (isinstance(v, float) and v != v) else f" {k:<30}: N/A") + + if calib_mode: + # In calib mode, verify no fallback positives exist + if "is_fallback_label" in df.columns: + fallback_pos = int(df.loc[df["is_fraud"] == 1, "is_fallback_label"].sum()) + print(f" {'fallback_positives':<30}: {fallback_pos}") + if fallback_pos > 0: + print(" [CALIB VIOLATION] Fallback positives found! is_fallback_label.sum() must be 0.") + result["fallback_positives"] = fallback_pos + + return result + + +def compute_label_delay_stats(df: pd.DataFrame) -> dict: + """Print and return min/mean/max label_delay for positive events.""" + if "label_delay" not in df.columns: + return {} + delays = df.loc[(df["is_fraud"] == 1) & (df["label_delay"] >= 0), "label_delay"] + if delays.empty: + print(" label_delay: no valid delay data.") + return {"delay_min": float("nan"), "delay_mean": float("nan"), "delay_max": float("nan")} + result = { + "delay_min": float(delays.min()), + "delay_mean": float(delays.mean()), + "delay_max": float(delays.max()), + } + print(f" label_delay min={result['delay_min']:.1f} mean={result['delay_mean']:.1f} max={result['delay_max']:.1f}") + return result + + +# --------------------------------------------------------------------------- +# Prefix-task helpers +# --------------------------------------------------------------------------- + +def uses_twin_pairs(df: pd.DataFrame) -> bool: + return "twin_pair_id" in df.columns and bool((df["twin_pair_id"] >= 0).any()) + + +def get_eval_nodes(df: pd.DataFrame) -> List[int]: + if uses_twin_pairs(df): + pair_df = df[df["twin_pair_id"] >= 0] + return sorted(pair_df["sender_id"].unique().tolist()) + return sorted(df["sender_id"].unique().tolist()) + + +def remap_node_ids(*dfs: pd.DataFrame) -> list[pd.DataFrame]: + non_empty = [df for df in dfs if df is not None and not df.empty] + if not non_empty: + return [df.copy() for df in dfs] + + all_ids = np.unique( + np.concatenate( + [ + np.concatenate( + [ + df["sender_id"].to_numpy(dtype=np.int64), + df["receiver_id"].to_numpy(dtype=np.int64), + ] + ) + for df in non_empty + ] + ) + ) + id_map = {int(node_id): idx for idx, node_id in enumerate(all_ids.tolist())} + + remapped = [] + for df in dfs: + if df is None: + remapped.append(df) + continue + out = df.copy() + out["sender_id"] = out["sender_id"].map(id_map).astype(np.int64) + out["receiver_id"] = out["receiver_id"].map(id_map).astype(np.int64) + remapped.append(out) + return remapped + + +def augment_with_placeholder_nodes(df_train: pd.DataFrame, df_test: pd.DataFrame) -> pd.DataFrame: + train_nodes = set( + np.concatenate( + [ + df_train["sender_id"].to_numpy(dtype=np.int64), + df_train["receiver_id"].to_numpy(dtype=np.int64), + ] + ).tolist() + ) + test_nodes = set( + np.concatenate( + [ + df_test["sender_id"].to_numpy(dtype=np.int64), + df_test["receiver_id"].to_numpy(dtype=np.int64), + ] + ).tolist() + ) + unseen_nodes = sorted(test_nodes - train_nodes) + if not unseen_nodes: + return df_train + + base_time = float(min(df_train["timestamp"].min(), df_test["timestamp"].min())) - 1.0 + rows = [] + for offset, node_id in enumerate(unseen_nodes): + row = {} + for col in df_train.columns: + if col in {"sender_id", "receiver_id"}: + row[col] = int(node_id) + elif col == "timestamp": + row[col] = base_time - offset + elif col in {"fraud_type", "twin_role"}: + row[col] = "placeholder" + elif col in {"txn_id", "twin_pair_id", "template_id"}: + row[col] = -1 + elif col in {"twin_label", "is_fraud", "is_retry", "failed"}: + row[col] = 0 + else: + row[col] = 0.0 + rows.append(row) + + placeholder_df = pd.DataFrame(rows, columns=df_train.columns) + out = pd.concat([placeholder_df, df_train], ignore_index=True) + return out.sort_values("timestamp").reset_index(drop=True) + + +def split_temporally(df: pd.DataFrame, train_ratio: float = 0.7) -> tuple[pd.DataFrame, pd.DataFrame, float]: + df = df.sort_values("timestamp").reset_index(drop=True) + if uses_twin_pairs(df): + pair_meta = ( + df[df["twin_pair_id"] >= 0] + .groupby("twin_pair_id")["timestamp"] + .min() + .sort_values() + ) + if len(pair_meta) >= 2: + split_idx = max(1, min(len(pair_meta) - 1, int(train_ratio * len(pair_meta)))) + train_ids = set(pair_meta.index[:split_idx].tolist()) + test_ids = set(pair_meta.index[split_idx:].tolist()) + df_train = df[(df["twin_pair_id"] < 0) | (df["twin_pair_id"].isin(train_ids))].copy() + df_test = df[df["twin_pair_id"].isin(test_ids)].copy() + split_time = float(df_test["timestamp"].min()) if not df_test.empty else float(df["timestamp"].quantile(train_ratio)) + return df_train.sort_values("timestamp").reset_index(drop=True), df_test.sort_values("timestamp").reset_index(drop=True), split_time + split_time = float(df["timestamp"].quantile(train_ratio)) + df_train = df[df["timestamp"] <= split_time].copy() + df_test = df[df["timestamp"] > split_time].copy() + return df_train, df_test, split_time + + +def horizon_to_delta(df_test: pd.DataFrame, horizon: float) -> float: + if df_test.empty: + return 1e-6 + t_min = float(df_test["timestamp"].min()) + t_max = float(df_test["timestamp"].max()) + return max(1e-6, horizon * max(t_max - t_min, 1e-6)) + + +def build_window_labels( + df: pd.DataFrame, + cutoff_time: float, + eval_nodes: Sequence[int], + delta_time: float, +) -> np.ndarray: + future = df[(df["timestamp"] > cutoff_time) & (df["timestamp"] <= cutoff_time + delta_time)] + fraud_map = future.groupby("sender_id")["is_fraud"].max() + return np.array([int(fraud_map.get(node_id, 0)) for node_id in eval_nodes], dtype=np.float32) + + +def build_window_state( + df: pd.DataFrame, + cutoff_time: float, + eval_nodes: Sequence[int], + delta_time: float, +) -> np.ndarray: + future = df[(df["timestamp"] > cutoff_time) & (df["timestamp"] <= cutoff_time + delta_time)] + if "dynamic_fraud_state" in future.columns: + state_map = future.groupby("sender_id")["dynamic_fraud_state"].mean() + else: + state_map = future.groupby("sender_id")["is_fraud"].mean() + return np.array([float(state_map.get(node_id, 0.0)) for node_id in eval_nodes], dtype=np.float32) + + +def choose_anchor_time(df_train: pd.DataFrame, delta_time: float) -> float: + t_min = float(df_train["timestamp"].min()) + max_anchor = float(df_train["timestamp"].max()) - delta_time + if max_anchor <= t_min: + return t_min + + candidate_quantiles = [0.80, 0.75, 0.70, 0.65, 0.60, 0.55] + for quantile in candidate_quantiles: + anchor_time = min(float(df_train["timestamp"].quantile(quantile)), max_anchor) + prefix_nodes = get_eval_nodes(df_train[df_train["timestamp"] <= anchor_time]) + if not prefix_nodes: + continue + y_anchor = build_window_labels(df_train, anchor_time, prefix_nodes, delta_time) + if len(np.unique(y_anchor)) >= 2: + return anchor_time + + return min(float(df_train["timestamp"].quantile(0.80)), max_anchor) + + +def make_checkpoints(df_test: pd.DataFrame, delta_time: float, n_checkpoints: int) -> List[float]: + if df_test.empty: + return [] + + t_max = float(df_test["timestamp"].max()) + valid = df_test[df_test["timestamp"] <= t_max - delta_time].sort_values("timestamp") + if valid.empty: + return [] + + timestamps = valid["timestamp"].to_numpy(dtype=np.float64) + idx = np.unique( + np.linspace(0, len(timestamps) - 1, num=min(n_checkpoints, len(timestamps)), dtype=int) + ) + checkpoints = [float(timestamps[i]) for i in idx] + return sorted(set(checkpoints)) + + +def train_node_head( + model: TemporalModel, + df_anchor_prefix: pd.DataFrame, + eval_nodes: List[int], + y_labels: np.ndarray, + num_epochs: int = 150, +) -> None: + if hasattr(model, "train_node_classifier_on_prefix"): + model.train_node_classifier_on_prefix( + df_anchor_prefix, eval_nodes, y_labels, num_epochs=num_epochs + ) + return + + if model.is_temporal: + model.reset_memory() + if len(df_anchor_prefix) > 0 and len(eval_nodes) > 0: + model.predict(df_anchor_prefix, eval_nodes) + + if hasattr(model, "train_node_classifier"): + model.train_node_classifier(eval_nodes, y_labels, num_epochs=num_epochs) + if isinstance(model, TGNWrapper): + assert model._node_head_fitted, "TGN node classifier was not fitted." + return + + raise ValueError(f"Model {model.name} does not expose node-head training.") + + +def fit_model_for_horizon( + model: TemporalModel, + df_train: pd.DataFrame, + delta_time: float, + num_epochs: int, + node_epochs: int, +) -> dict: + # Strip oracle columns from all non-oracle models + train_input = df_train if model.name in _ORACLE_MODEL_NAMES else strip_oracle_cols(df_train) + model.fit(train_input, num_epochs=num_epochs) + + anchor_time = choose_anchor_time(train_input, delta_time) + df_anchor_prefix = train_input[train_input["timestamp"] <= anchor_time].copy() + assert_prefix_only(df_anchor_prefix, anchor_time) + anchor_nodes = get_eval_nodes(df_anchor_prefix) + y_anchor = build_window_labels(train_input, anchor_time, anchor_nodes, delta_time) + + train_node_head( + model, + df_anchor_prefix=df_anchor_prefix, + eval_nodes=anchor_nodes, + y_labels=y_anchor, + num_epochs=node_epochs, + ) + + return { + "anchor_time": anchor_time, + "anchor_nodes": len(anchor_nodes), + "anchor_fraud_rate": float(y_anchor.mean()) if len(y_anchor) else 0.0, + } + + + +def collect_prefix_predictions( + model: TemporalModel, + df_train: pd.DataFrame, + df_test: pd.DataFrame, + delta_time: float, + n_checkpoints: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + checkpoints = make_checkpoints(df_test, delta_time, n_checkpoints=n_checkpoints) + if not checkpoints: + return ( + np.zeros(0, dtype=np.float32), + np.zeros(0, dtype=np.float32), + np.zeros(0, dtype=np.float32), + ) + + df_full = ( + pd.concat([df_train, df_test], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + y_chunks: List[np.ndarray] = [] + p_chunks: List[np.ndarray] = [] + s_chunks: List[np.ndarray] = [] + + is_oracle = model.name in _ORACLE_MODEL_NAMES + + for cutoff_time in checkpoints: + active_nodes = get_eval_nodes(df_test[df_test["timestamp"] <= cutoff_time]) + if not active_nodes: + continue + + prefix_df = df_full[df_full["timestamp"] <= cutoff_time].copy() + assert_prefix_only(prefix_df, cutoff_time) + eval_df = prefix_df if is_oracle else strip_oracle_cols(prefix_df) + if model.is_temporal: + model.reset_memory() + + probs = model.predict(eval_df, active_nodes) + y_true = build_window_labels(df_full, cutoff_time, active_nodes, delta_time) + state = build_window_state(df_full, cutoff_time, active_nodes, delta_time) + + y_chunks.append(y_true) + p_chunks.append(np.asarray(probs, dtype=np.float32)) + s_chunks.append(state) + + if not y_chunks: + return ( + np.zeros(0, dtype=np.float32), + np.zeros(0, dtype=np.float32), + np.zeros(0, dtype=np.float32), + ) + + return ( + np.concatenate(y_chunks).astype(np.float32), + np.concatenate(p_chunks).astype(np.float32), + np.concatenate(s_chunks).astype(np.float32), + ) + + + +def evaluate_model( + model: TemporalModel, + df_train: pd.DataFrame, + df_test: pd.DataFrame, + delta_time: float, + n_checkpoints: int, +) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray]: + y_true, probs, states = collect_prefix_predictions( + model=model, + df_train=df_train, + df_test=df_test, + delta_time=delta_time, + n_checkpoints=n_checkpoints, + ) + metrics = compute_metrics(y_true, probs) if len(y_true) else compute_metrics(np.array([0.0]), np.array([0.5])) + metrics["n_predictions"] = int(len(y_true)) + return metrics, y_true, probs, states + + +def shuffle_chronology(df: pd.DataFrame, seed: int) -> pd.DataFrame: + """Break temporal order while preserving the event table.""" + rng = np.random.default_rng(seed) + shuffled = df.copy() + shuffled["timestamp"] = rng.permutation(shuffled["timestamp"].to_numpy(dtype=np.float64)) + return shuffled.sort_values("timestamp").reset_index(drop=True) + + +# --------------------------------------------------------------------------- +# Experiments (single seed) +# --------------------------------------------------------------------------- + +def run_ood_single( + df_easy: pd.DataFrame, + df_medium: pd.DataFrame, + df_hard: pd.DataFrame, + device: str, + num_epochs: int, + node_epochs: int, + n_checkpoints: int, + horizon: float = 0.10, +) -> pd.DataFrame: + df_train = ( + pd.concat([df_easy, df_medium], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + df_test = df_hard.sort_values("timestamp").reset_index(drop=True) + df_train, df_test = remap_node_ids(df_train, df_test) + df_train = augment_with_placeholder_nodes(df_train, df_test) + delta_time = horizon_to_delta(df_test, horizon) + + rows = [] + models = build_models(device=device) + for model_name in MODEL_ORDER: + model = models[model_name] + fit_info = fit_model_for_horizon(model, df_train, delta_time, num_epochs, node_epochs) + metrics, _, _, _ = evaluate_model(model, df_train, df_test, delta_time, n_checkpoints) + rows.append({ + "model": model_name, + **metrics, + **fit_info, + }) + + df_out = pd.DataFrame(rows) + xgb_roc = float(df_out.loc[df_out["model"] == "XGBoost", "roc_auc"].iloc[0]) + df_out["gap_vs_xgb"] = df_out["roc_auc"] - xgb_roc + return df_out + + +def run_causal_single( + df_hard: pd.DataFrame, + device: str, + num_epochs: int, + node_epochs: int, + n_checkpoints: int, + seed: int, + horizon: float = 0.10, +) -> pd.DataFrame: + df_clean = df_hard.sort_values("timestamp").reset_index(drop=True) + df_shuffled = shuffle_chronology(df_clean, seed=seed + 17) + + df_train_clean, df_test_clean, _ = split_temporally(df_clean) + df_train_shuf, df_test_shuf, _ = split_temporally(df_shuffled) + df_train_clean, df_test_clean = remap_node_ids(df_train_clean, df_test_clean) + df_train_shuf, df_test_shuf = remap_node_ids(df_train_shuf, df_test_shuf) + df_train_clean = augment_with_placeholder_nodes(df_train_clean, df_test_clean) + df_train_shuf = augment_with_placeholder_nodes(df_train_shuf, df_test_shuf) + delta_time_clean = horizon_to_delta(df_test_clean, horizon) + delta_time_shuf = horizon_to_delta(df_test_shuf, horizon) + + rows = [] + clean_models = build_models(device=device) + shuffled_models = build_models(device=device) + + for model_name in MODEL_ORDER: + clean_model = clean_models[model_name] + shuffled_model = shuffled_models[model_name] + + fit_model_for_horizon(clean_model, df_train_clean, delta_time_clean, num_epochs, node_epochs) + clean_metrics, _, _, _ = evaluate_model( + clean_model, df_train_clean, df_test_clean, delta_time_clean, n_checkpoints + ) + + fit_model_for_horizon(shuffled_model, df_train_shuf, delta_time_shuf, num_epochs, node_epochs) + shuffled_metrics, _, _, _ = evaluate_model( + shuffled_model, df_train_shuf, df_test_shuf, delta_time_shuf, n_checkpoints + ) + + rows.append({ + "model": model_name, + "roc_auc_clean": clean_metrics["roc_auc"], + "pr_auc_clean": clean_metrics["pr_auc"], + "brier_clean": clean_metrics["brier"], + "ece_clean": clean_metrics["ece"], + "roc_auc_shuffled": shuffled_metrics["roc_auc"], + "pr_auc_shuffled": shuffled_metrics["pr_auc"], + "brier_shuffled": shuffled_metrics["brier"], + "ece_shuffled": shuffled_metrics["ece"], + "delta": shuffled_metrics["roc_auc"] - clean_metrics["roc_auc"], + }) + + return pd.DataFrame(rows) + + +def run_horizon_single( + df_medium: pd.DataFrame, + device: str, + num_epochs: int, + node_epochs: int, + n_checkpoints: int, + horizons: Sequence[float], +) -> pd.DataFrame: + df_train, df_test, _ = split_temporally(df_medium) + df_train, df_test = remap_node_ids(df_train, df_test) + df_train = augment_with_placeholder_nodes(df_train, df_test) + rows = [] + + for horizon in horizons: + delta_time = horizon_to_delta(df_test, horizon) + models = build_models(device=device) + for model_name in MODEL_ORDER: + model = models[model_name] + fit_model_for_horizon(model, df_train, delta_time, num_epochs, node_epochs) + metrics, _, _, _ = evaluate_model(model, df_train, df_test, delta_time, n_checkpoints) + rows.append({ + "horizon": float(horizon), + "model": model_name, + **metrics, + }) + + return pd.DataFrame(rows) + + +def run_mechanistic_single( + df_hard: pd.DataFrame, + device: str, + num_epochs: int, + node_epochs: int, + n_checkpoints: int, + horizon: float = 0.10, +) -> pd.DataFrame: + df_train, df_test, _ = split_temporally(df_hard) + df_train, df_test = remap_node_ids(df_train, df_test) + df_train = augment_with_placeholder_nodes(df_train, df_test) + delta_time = horizon_to_delta(df_test, horizon) + rows = [] + models = build_models(device=device) + + for model_name in MODEL_ORDER: + model = models[model_name] + fit_model_for_horizon(model, df_train, delta_time, num_epochs, node_epochs) + _, _, probs, states = evaluate_model(model, df_train, df_test, delta_time, n_checkpoints) + rows.append({ + "model": model_name, + "pearson_r": safe_pearson(states, probs), + }) + + return pd.DataFrame(rows) + + +def run_audit_single( + df_hard: pd.DataFrame, + device: str, + num_epochs: int, + node_epochs: int, + n_checkpoints: int, + seed: int, + horizon: float = 0.10, + benchmark_mode: str = "temporal_twins", +) -> pd.DataFrame: + node_audit = build_node_audit_table(df_hard) + ks_mean, ks_max = compute_aggregate_ks(node_audit) + paired_pairs = int(node_audit.loc[node_audit["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) + paired_nodes = int((node_audit["twin_pair_id"] >= 0).sum()) + + # --- Label-source audit --- + calib_mode = benchmark_mode == "temporal_twins_oracle_calib" + print("\n--- Label-Source Audit ---") + audit_tbl = build_label_source_audit_table(df_hard) + if not audit_tbl.empty: + print(audit_tbl.to_string(index=False, max_rows=20)) + consistency = compute_motif_label_consistency(df_hard, calib_mode=calib_mode) + compute_label_delay_stats(df_hard) + + df_train, df_test, _ = split_temporally(df_hard) + leakage = compute_split_leakage(df_train, df_test) + + # --- Split integrity report --- + print("\n--- Split Integrity ---") + for k, v in leakage.items(): + status = "[OK]" if v == 0 else "[WARN]" + print(f" {status} {k}: {v}") + + df_train, df_test = remap_node_ids(df_train, df_test) + train_examples, train_pair_rows, train_pair_counts = build_matched_control_tables(df_train) + test_examples, test_pair_rows, test_pair_counts = build_matched_control_tables(df_test) + matched_train_features = build_matched_prefix_feature_table(df_train, train_examples) + matched_test_features = build_matched_prefix_feature_table(df_test, test_examples) + matched_audit = report_matched_control_audits( + test_examples=test_examples, + test_pair_rows=test_pair_rows, + test_pair_counts=test_pair_counts, + ) + static_agg_result = compute_matched_static_aggregate_auc( + matched_train_features, + matched_test_features, + seed=seed, + verbose=True, + ) + + xgb_result = compute_matched_xgboost_auc( + matched_train_features, + matched_test_features, + seed=seed, + ) + static_gnn_result = compute_matched_static_gnn_auc( + df_train=df_train, + df_test=df_test, + train_examples=train_examples, + test_examples=test_examples, + device=device, + num_epochs=num_epochs, + seed=seed, + ) + + df_train_eval = augment_with_placeholder_nodes(df_train, df_test) + delta_time = horizon_to_delta(df_test, horizon) + models = build_models(device=device) + rows = [] + + for model_name in MODEL_ORDER: + model = models[model_name] + fit_model_for_horizon(model, df_train_eval, delta_time, num_epochs, node_epochs) + metrics, _, probs, states = evaluate_model(model, df_train_eval, df_test, delta_time, n_checkpoints) + pair_sep, eval_pairs = evaluate_matched_pair_separability( + model, + df_train=df_train_eval, + df_test=df_test, + delta_time=delta_time, + n_checkpoints=n_checkpoints, + ) + matched_control_roc_auc = float("nan") + if model_name == "XGBoost": + matched_control_roc_auc = float(xgb_result["auc"]) + elif model_name == "StaticGNN": + matched_control_roc_auc = float(static_gnn_result["auc"]) + rows.append({ + "model": model_name, + **metrics, + "pearson_r": safe_pearson(states, probs), + "matched_pair_sep": pair_sep, + "matched_pair_eval_pairs": eval_pairs, + "matched_control_roc_auc": matched_control_roc_auc, + "static_agg_auc": float(static_agg_result["auc"]), + "static_agg_auc_bootstrap_std": float(static_agg_result["bootstrap_std"]), + "xgb_auc_bootstrap_std": float(xgb_result["bootstrap_std"]), + "static_gnn_auc_bootstrap_std": float(static_gnn_result["bootstrap_std"]), + "ks_mean": ks_mean, + "ks_max": ks_max, + "paired_pairs": paired_pairs, + "paired_nodes": paired_nodes, + **leakage, + **matched_audit, + }) + + return pd.DataFrame(rows) + + +# --------------------------------------------------------------------------- +# Aggregation / plotting outputs +# --------------------------------------------------------------------------- + +def summarise_mean_std(df: pd.DataFrame, group_cols: Sequence[str], value_cols: Sequence[str]) -> pd.DataFrame: + summary = df.groupby(list(group_cols)).agg({ + value_col: ["mean", "std"] for value_col in value_cols + }) + summary.columns = [ + f"{value_col}_{stat}" + for value_col, stat in summary.columns.to_flat_index() + ] + summary = summary.reset_index() + return summary.fillna(0.0) + + +def save_experiment_outputs( + raw_frames: Dict[str, List[pd.DataFrame]], + results_dir: str, +) -> None: + os.makedirs(results_dir, exist_ok=True) + raw_causal = pd.concat(raw_frames["causal"], ignore_index=True) if raw_frames["causal"] else None + + if raw_frames["ood"]: + raw_ood = pd.concat(raw_frames["ood"], ignore_index=True) + raw_ood.to_csv(os.path.join(results_dir, "ood_raw.csv"), index=False) + ood_summary = summarise_mean_std( + raw_ood, + group_cols=["model"], + value_cols=["roc_auc", "pr_auc", "brier", "ece", "gap_vs_xgb"], + ) + ood_summary.to_csv(os.path.join(results_dir, "ood.csv"), index=False) + + if raw_frames["causal"]: + assert raw_causal is not None + raw_causal.to_csv(os.path.join(results_dir, "causal_raw.csv"), index=False) + causal_summary = summarise_mean_std( + raw_causal, + group_cols=["model"], + value_cols=[ + "roc_auc_clean", + "pr_auc_clean", + "brier_clean", + "ece_clean", + "roc_auc_shuffled", + "pr_auc_shuffled", + "brier_shuffled", + "ece_shuffled", + "delta", + ], + ) + causal_summary.to_csv(os.path.join(results_dir, "causal.csv"), index=False) + + if raw_frames["horizon"]: + raw_horizon = pd.concat(raw_frames["horizon"], ignore_index=True) + raw_horizon.to_csv(os.path.join(results_dir, "horizon_raw.csv"), index=False) + horizon_summary = summarise_mean_std( + raw_horizon, + group_cols=["horizon", "model"], + value_cols=["roc_auc", "pr_auc", "brier", "ece"], + ) + horizon_summary.to_csv(os.path.join(results_dir, "horizon.csv"), index=False) + + if raw_frames["mechanistic"]: + raw_mech = pd.concat(raw_frames["mechanistic"], ignore_index=True) + raw_mech.to_csv(os.path.join(results_dir, "mechanistic_raw.csv"), index=False) + mech_summary = summarise_mean_std( + raw_mech, + group_cols=["model"], + value_cols=["pearson_r"], + ) + mech_summary.to_csv(os.path.join(results_dir, "mechanistic.csv"), index=False) + + if raw_frames.get("audit"): + raw_audit = pd.concat(raw_frames["audit"], ignore_index=True) + raw_audit.to_csv(os.path.join(results_dir, "audit_raw.csv"), index=False) + audit_summary = summarise_mean_std( + raw_audit, + group_cols=["model"], + value_cols=[ + "roc_auc", + "pr_auc", + "brier", + "ece", + "pearson_r", + "matched_pair_sep", + "matched_pair_eval_pairs", + "matched_control_roc_auc", + "static_agg_auc", + "ks_mean", + "ks_max", + "paired_pairs", + "paired_nodes", + "sender_overlap_count", + "pair_overlap_count", + "template_overlap_count", + "pair_total_txn_count_diff_mean", + "pair_total_txn_count_diff_max", + "auc_total_txn_count", + "auc_local_event_idx", + "auc_prefix_txn_count", + "auc_timestamp", + "auc_account_age", + "auc_active_age", + "fraud_label_event_idx_mean", + "fraud_label_event_idx_max", + "benign_eval_event_idx_mean", + "benign_eval_event_idx_max", + "pair_event_idx_diff_mean", + "pair_event_idx_diff_max", + "pair_active_age_diff_mean", + "pair_active_age_diff_max", + "pair_timestamp_diff_mean", + "pair_timestamp_diff_max", + "benign_motif_hit_rate", + "benign_motif_hit_pairs", + "matched_control_examples", + "matched_control_pair_events", + ], + ) + if raw_causal is not None: + causal_delta = summarise_mean_std( + raw_causal, + group_cols=["model"], + value_cols=["delta"], + )[["model", "delta_mean", "delta_std"]] + audit_summary = audit_summary.merge(causal_delta, on="model", how="left") + audit_summary[["delta_mean", "delta_std"]] = audit_summary[ + ["delta_mean", "delta_std"] + ].fillna(0.0) + audit_summary.to_csv(os.path.join(results_dir, "audit.csv"), index=False) + + +# --------------------------------------------------------------------------- +# Node-level oracle evaluation helpers (twin_label, not window label) +# --------------------------------------------------------------------------- + +def _twin_labels_for_nodes(df_full: pd.DataFrame, nodes: List[int]) -> np.ndarray: + """Return twin_label (1=fraud twin, 0=benign) per node. Falls back to + is_fraud if twin_label is absent.""" + col = "twin_label" if "twin_label" in df_full.columns else "is_fraud" + label_series = df_full.groupby("sender_id")[col].max() + return np.array([float(label_series.get(n, 0.0)) for n in nodes], dtype=np.float32) + + +def evaluate_oracle_node_level( + model: TemporalModel, + df_full: pd.DataFrame, + eval_nodes: List[int], +) -> float: + """ROC-AUC of oracle scored against twin_label (user-level, not window-level). + + For oracle-type models we pass the FULL df (with audit columns). + For AuditOracle, predict() directly reads motif_hit_count — no training. + For RawMotifOracle, train_node_classifier_on_prefix must be called first + with twin_labels so it learns the node-level task. + """ + if not eval_nodes: + return float("nan") + y_true = _twin_labels_for_nodes(df_full, eval_nodes) + probs = model.predict(df_full, eval_nodes) + return safe_roc_auc(y_true, probs.astype(np.float32)) + + +def evaluate_oracle_pair_sep_node_level( + model: TemporalModel, + df_full: pd.DataFrame, + eval_nodes: List[int], +) -> float: + """Matched-pair separability: P(score_fraud > score_benign) using twin_label.""" + if not eval_nodes or "twin_pair_id" not in df_full.columns: + return float("nan") + + probs = model.predict(df_full, eval_nodes) + score_map = {n: float(p) for n, p in zip(eval_nodes, probs)} + + meta = ( + df_full[df_full["sender_id"].isin(eval_nodes) & (df_full["twin_pair_id"] >= 0)] + .groupby("sender_id") + .agg(twin_pair_id=("twin_pair_id", "first"), twin_label=("twin_label", "max")) + .reset_index() + ) + + pair_results: List[float] = [] + for _, grp in meta.groupby("twin_pair_id"): + if len(grp) != 2 or set(grp["twin_label"]) != {0, 1}: + continue + fraud_node = int(grp.loc[grp["twin_label"] == 1, "sender_id"].iloc[0]) + benign_node = int(grp.loc[grp["twin_label"] == 0, "sender_id"].iloc[0]) + if fraud_node in score_map and benign_node in score_map: + pair_results.append(float(score_map[fraud_node] > score_map[benign_node])) + + return float(np.mean(pair_results)) if pair_results else float("nan") + + +def build_oracle_debug_table( + df_full: pd.DataFrame, + eval_nodes: List[int], + oracle_scores: dict[str, np.ndarray], + y_twin: np.ndarray, + n_sample: int = 20, + primary_score_name: str = "AuditOracle", + table_title: str = "Oracle Debug Table", +) -> pd.DataFrame: + """Print a per-node debug table for oracle/probe scores vs ground-truth.""" + audit_cols = [ + "twin_pair_id", "twin_role", + "motif_hit_count", "trigger_event_idx", "label_event_idx", + "label_delay", "is_fallback_label", + ] + available = [c for c in audit_cols if c in df_full.columns] + meta = ( + df_full[df_full["sender_id"].isin(eval_nodes)] + .groupby("sender_id")[available] + .first() + .reset_index() # sender_id becomes a column here + ) + meta["twin_label"] = y_twin + meta["_idx"] = meta["sender_id"].map({n: i for i, n in enumerate(eval_nodes)}) + for name, scores in oracle_scores.items(): + meta[f"score_{name}"] = meta["_idx"].map( + {i: float(scores[i]) for i in range(len(scores))} + ) + meta = meta.drop(columns=["_idx"]) + + # Sample: top n_sample/2 by the primary motif score + bottom n_sample/2 + sort_col = f"score_{primary_score_name}" + if sort_col not in meta.columns: + sort_col = meta.columns[-1] + meta = meta.sort_values(sort_col, ascending=False) + sample = pd.concat([meta.head(n_sample // 2), meta.tail(n_sample // 2)]).drop_duplicates() + + print(f"\n--- {table_title} (top & bottom by {primary_score_name} score) ---") + print(sample.to_string(index=False)) + return sample + + +# Gate volume targets / budgets +_FAST_GATE_MIN_MATCHED_PAIRS = 500 +_FULL_GATE_MIN_MATCHED_PAIRS = 2000 +_GATE_MIN_CLASS_EXAMPLES = 500 +_GATE_MIN_UNIQUE_USERS = 50 +_GATE_POS_RATE_RANGE = (0.35, 0.65) +_GATE_BOOTSTRAP_ROUNDS = 200 +_GATE_PACK_NAMESPACE = 10_000_000 +_GATE_MAX_EXTRA_PACKS = 6 + + +def _subsample_for_gate( + df: pd.DataFrame, + rng: np.random.Generator, + max_pairs: int | None = None, +) -> pd.DataFrame: + """Keep at most max_pairs twin pairs for the gate.""" + if "twin_pair_id" not in df.columns: + return df + pair_ids = df.loc[df["twin_pair_id"] >= 0, "twin_pair_id"].unique() + if max_pairs is None or max_pairs <= 0 or len(pair_ids) <= max_pairs: + return df[df["twin_pair_id"] >= 0].copy() + chosen = set(rng.choice(pair_ids, size=max_pairs, replace=False).tolist()) + return df[df["twin_pair_id"].isin(chosen)].copy() + + +def gate_volume_thresholds(fast_mode: bool) -> dict: + return { + "matched_eval_pairs_min": _FAST_GATE_MIN_MATCHED_PAIRS if fast_mode else _FULL_GATE_MIN_MATCHED_PAIRS, + "positives_min": _GATE_MIN_CLASS_EXAMPLES, + "negatives_min": _GATE_MIN_CLASS_EXAMPLES, + "unique_fraud_users_min": _GATE_MIN_UNIQUE_USERS, + "unique_benign_users_min": _GATE_MIN_UNIQUE_USERS, + "positive_rate_lo": _GATE_POS_RATE_RANGE[0], + "positive_rate_hi": _GATE_POS_RATE_RANGE[1], + } + + +def summarize_gate_volume( + test_examples: pd.DataFrame, + test_pair_rows: pd.DataFrame, + eval_nodes: Sequence[int], +) -> dict: + positives = int(test_examples["label"].sum()) if not test_examples.empty else 0 + total_examples = int(len(test_examples)) + negatives = int(total_examples - positives) + fraud_users = int(test_examples.loc[test_examples["label"] == 1, "sender_id"].nunique()) if not test_examples.empty else 0 + benign_users = int(test_examples.loc[test_examples["label"] == 0, "sender_id"].nunique()) if not test_examples.empty else 0 + unique_templates = int(test_examples["template_id"].nunique()) if ("template_id" in test_examples.columns and not test_examples.empty) else 0 + positive_rate = float(positives / max(total_examples, 1)) + return { + "matched_eval_pairs": int(len(test_pair_rows)), + "positives": positives, + "negatives": negatives, + "unique_fraud_users": fraud_users, + "unique_benign_users": benign_users, + "unique_templates": unique_templates, + "positive_rate": positive_rate, + "audit_n_examples": int(len(eval_nodes)), + "raw_n_examples": int(len(eval_nodes)), + "xgb_n_examples": total_examples, + "static_gnn_n_examples": total_examples, + } + + +def gate_volume_violations(volume: dict, fast_mode: bool) -> list[str]: + thresholds = gate_volume_thresholds(fast_mode) + violations: list[str] = [] + if volume.get("matched_eval_pairs", 0) < thresholds["matched_eval_pairs_min"]: + violations.append( + f"matched_eval_pairs {volume.get('matched_eval_pairs', 0)} < {thresholds['matched_eval_pairs_min']}" + ) + if volume.get("positives", 0) < thresholds["positives_min"]: + violations.append(f"positives {volume.get('positives', 0)} < {thresholds['positives_min']}") + if volume.get("negatives", 0) < thresholds["negatives_min"]: + violations.append(f"negatives {volume.get('negatives', 0)} < {thresholds['negatives_min']}") + if volume.get("unique_fraud_users", 0) < thresholds["unique_fraud_users_min"]: + violations.append( + f"unique_fraud_users {volume.get('unique_fraud_users', 0)} < {thresholds['unique_fraud_users_min']}" + ) + if volume.get("unique_benign_users", 0) < thresholds["unique_benign_users_min"]: + violations.append( + f"unique_benign_users {volume.get('unique_benign_users', 0)} < {thresholds['unique_benign_users_min']}" + ) + pos_rate = float(volume.get("positive_rate", 0.0)) + if pos_rate < thresholds["positive_rate_lo"] or pos_rate > thresholds["positive_rate_hi"]: + violations.append( + f"positive_rate {pos_rate:.4f} outside [{thresholds['positive_rate_lo']:.2f}, {thresholds['positive_rate_hi']:.2f}]" + ) + return violations + + +def gate_volume_is_sufficient(volume: dict, fast_mode: bool) -> bool: + return len(gate_volume_violations(volume, fast_mode)) == 0 + + +def offset_gate_namespace(df: pd.DataFrame, pack_idx: int) -> pd.DataFrame: + if pack_idx == 0: + return df.copy() + out = df.copy() + offset = pack_idx * _GATE_PACK_NAMESPACE + out["sender_id"] = out["sender_id"].astype(np.int64) + offset + out["receiver_id"] = out["receiver_id"].astype(np.int64) + offset + for col in ("twin_pair_id", "template_id"): + if col in out.columns: + valid = out[col].astype(np.int64) >= 0 + out.loc[valid, col] = out.loc[valid, col].astype(np.int64) + offset + return out + + +def build_gate_pool_from_frames(frames: Sequence[pd.DataFrame]) -> pd.DataFrame: + non_empty = [frame for frame in frames if frame is not None and not frame.empty] + if not non_empty: + return pd.DataFrame() + return ( + pd.concat(non_empty, ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + + +def gate_pair_budget_candidates(total_pairs: int, fast_mode: bool) -> list[int | None]: + if total_pairs <= 0: + return [0] + target_budget = 900 if fast_mode else 3500 + budgets = [min(total_pairs, target_budget)] + if total_pairs > budgets[0]: + budgets.append(total_pairs) + return [int(budget) for budget in dict.fromkeys(budgets)] + + +def prepare_gate_subset( + df_pool: pd.DataFrame, + seed: int, + fast_mode: bool, +) -> dict: + total_pairs = int(df_pool.loc[df_pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) if "twin_pair_id" in df_pool.columns else 0 + if total_pairs == 0: + empty = pd.DataFrame() + return { + "pair_budget": 0, + "df_gate": empty, + "df_train": empty, + "df_test": empty, + "df_train_eval": empty, + "df_full": empty, + "eval_nodes": [], + "train_examples": empty, + "train_pair_rows": empty, + "train_pair_counts": empty, + "test_examples": empty, + "test_pair_rows": empty, + "test_pair_counts": empty, + "volume": summarize_gate_volume(empty, empty, []), + } + best: dict | None = None + + for pair_budget in gate_pair_budget_candidates(total_pairs, fast_mode): + gate_rng = np.random.default_rng(seed + int(pair_budget)) + df_gate = _subsample_for_gate(df_pool, gate_rng, max_pairs=pair_budget) + df_train, df_test, _ = split_temporally(df_gate) + df_train, df_test = remap_node_ids(df_train, df_test) + train_examples, train_pair_rows, train_pair_counts = build_matched_control_tables(df_train) + test_examples, test_pair_rows, test_pair_counts = build_matched_control_tables(df_test) + + df_train_eval = augment_with_placeholder_nodes(df_train, df_test) + df_full = ( + pd.concat([df_train_eval, df_test], ignore_index=True) + .sort_values("timestamp") + .reset_index(drop=True) + ) + eval_nodes = get_eval_nodes(df_full) + volume = summarize_gate_volume(test_examples, test_pair_rows, eval_nodes) + + candidate = { + "pair_budget": int(pair_budget) if pair_budget is not None else total_pairs, + "df_gate": df_gate, + "df_train": df_train, + "df_test": df_test, + "df_train_eval": df_train_eval, + "df_full": df_full, + "eval_nodes": eval_nodes, + "train_examples": train_examples, + "train_pair_rows": train_pair_rows, + "train_pair_counts": train_pair_counts, + "test_examples": test_examples, + "test_pair_rows": test_pair_rows, + "test_pair_counts": test_pair_counts, + "volume": volume, + } + best = candidate + if gate_volume_is_sufficient(volume, fast_mode): + return candidate + + assert best is not None + return best + + +def ensure_gate_volume( + df_pool: pd.DataFrame, + config, + seed: int, + benchmark_mode: str, + fast_mode: bool, +) -> dict: + pool = df_pool.copy() + gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) + pack_count = 1 + + while (not gate_volume_is_sufficient(gate["volume"], fast_mode)) and pack_count <= _GATE_MAX_EXTRA_PACKS: + extra_seed = seed + pack_count * 10_007 + extra_easy, extra_medium, extra_hard = generate_all( + config, + seed=extra_seed, + benchmark_mode=benchmark_mode, + ) + extra_pack = build_gate_pool_from_frames([ + offset_gate_namespace(extra_easy, pack_count), + offset_gate_namespace(extra_medium, pack_count), + offset_gate_namespace(extra_hard, pack_count), + ]) + pool = build_gate_pool_from_frames([pool, extra_pack]) + gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) + pack_count += 1 + + gate["source_pool_events"] = int(len(pool)) + gate["source_pool_pairs"] = int(pool.loc[pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) if "twin_pair_id" in pool.columns else 0 + gate["source_pool_packs"] = int(pack_count) + return gate + + +def ensure_gate_volume_for_difficulty( + config, + difficulty: str, + seed: int, + benchmark_mode: str, + fast_mode: bool, + initial_pool: pd.DataFrame | None = None, +) -> dict: + """Build a reliable-volume gate pool using repeated packs of a single difficulty.""" + if initial_pool is None: + pool = generate_single_difficulty( + config, + difficulty=difficulty, + seed=seed, + benchmark_mode=benchmark_mode, + ) + else: + pool = initial_pool.copy() + + gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) + pack_count = 1 + + while (not gate_volume_is_sufficient(gate["volume"], fast_mode)) and pack_count <= _GATE_MAX_EXTRA_PACKS: + extra_seed = seed + pack_count * 10_007 + extra_pack = generate_single_difficulty( + config, + difficulty=difficulty, + seed=extra_seed, + benchmark_mode=benchmark_mode, + ) + extra_pack = offset_gate_namespace(extra_pack, pack_count) + pool = build_gate_pool_from_frames([pool, extra_pack]) + gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode) + pack_count += 1 + + gate["source_pool_events"] = int(len(pool)) + gate["source_pool_pairs"] = int(pool.loc[pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique()) if "twin_pair_id" in pool.columns else 0 + gate["source_pool_packs"] = int(pack_count) + return gate + + +# --------------------------------------------------------------------------- +# Motif Validity Check (req #11) +# --------------------------------------------------------------------------- + +def run_motif_validity_check( + df: pd.DataFrame, + config, + seed: int, + device: str, + num_epochs: int, + node_epochs: int, + n_checkpoints: int, + hard_abort: bool = True, + horizon: float = 0.10, + benchmark_mode: str = "temporal_twins_oracle_calib", + fast_mode: bool = False, + force_temporal_models: bool = False, + prebuilt_gate: dict | None = None, +) -> tuple[bool, dict]: + """Run the staged MOTIF VALIDITY CHECK gate. + + Stage 1 — AuditOracle: reads audit cols directly. >= 0.99 required. + Stage 2 — RawMotifOracle: reconstructs motif. >= 0.95 required. + Stage 3 — Static ceilings: XGB <= 0.65, StaticGNN <= 0.70. + Stage 4 — SeqGRU: >= 0.80 (calib mode only). + + Oracles are evaluated against twin_label (NOT window label) to avoid + the target-alignment bug where late windows have no upcoming fraud events. + """ + calib_mode = _is_oracle_calib_mode(benchmark_mode) + metric_labels = _oracle_metric_labels(benchmark_mode) + + # Dataset-wide stats computed on the FULL df before subsampling + consistency = compute_motif_label_consistency(df, calib_mode=calib_mode) + delay_stats = compute_label_delay_stats(df) + node_audit = build_node_audit_table(df) + ks_mean, ks_max = compute_aggregate_ks(node_audit) + + gate = prebuilt_gate + if gate is None: + gate = ensure_gate_volume( + df_pool=df, + config=config, + seed=seed, + benchmark_mode=benchmark_mode, + fast_mode=fast_mode, + ) + df_gate = gate["df_gate"] + df_train = gate["df_train"] + df_test = gate["df_test"] + df_train_eval = gate["df_train_eval"] + df_full = gate["df_full"] + eval_nodes = gate["eval_nodes"] + train_examples = gate["train_examples"] + train_pair_rows = gate["train_pair_rows"] + train_pair_counts = gate["train_pair_counts"] + test_examples = gate["test_examples"] + test_pair_rows = gate["test_pair_rows"] + test_pair_counts = gate["test_pair_counts"] + gate_volume = gate["volume"] + + print( + f" [gate] Using {df_gate['twin_pair_id'].nunique()} pairs " + f"({len(df_gate):,} events) for model stages from " + f"{gate['source_pool_packs']} pack(s), source pairs={gate['source_pool_pairs']:,}." + ) + + leakage = compute_split_leakage(df_train, df_test) + matched_train_features = build_matched_prefix_feature_table(df_train, train_examples) + matched_test_features = build_matched_prefix_feature_table(df_test, test_examples) + matched_audit = report_matched_control_audits( + test_examples=test_examples, + test_pair_rows=test_pair_rows, + test_pair_counts=test_pair_counts, + ) + static_agg_result = compute_matched_static_aggregate_auc( + matched_train_features, + matched_test_features, + seed=seed, + verbose=False, + ) + delta_time = horizon_to_delta(df_test, horizon) + y_twin = _twin_labels_for_nodes(df_full, eval_nodes) + + report: dict = { + "ks_mean": ks_mean, "ks_max": ks_max, + "static_agg_auc": float(static_agg_result["auc"]), + **delay_stats, + **{k: v for k, v in consistency.items()}, + **matched_audit, + **gate_volume, + "gate_pair_budget": int(gate["pair_budget"]), + "gate_source_pool_events": int(gate["source_pool_events"]), + "gate_source_pool_pairs": int(gate["source_pool_pairs"]), + "gate_source_pool_packs": int(gate["source_pool_packs"]), + } + attach_auc_result(report, "static_agg", static_agg_result) + oracle_scores: dict[str, np.ndarray] = {} + + # Stage 1 — AuditOracle / MotifProbe (no training; reads motif_hit_count directly) + audit_oracle = AuditOracleWrapper() + audit_probs = audit_oracle.predict(df_full, eval_nodes) + oracle_scores[metric_labels["audit"]] = audit_probs + audit_result = make_auc_result(y_twin, audit_probs.astype(np.float32), seed=seed) + attach_auc_result(report, "audit", audit_result) + report["audit_pair_sep"] = evaluate_oracle_pair_sep_node_level( + audit_oracle, df_full, eval_nodes + ) + + # Stage 2 — RawMotifOracle / RawMotifProbe (trained on twin_label, not window label) + raw_oracle = RawMotifOracleWrapper() + raw_oracle.fit(df_train_eval, num_epochs=num_epochs) + train_nodes_raw = get_eval_nodes(df_train_eval) + y_train_twin_raw = _twin_labels_for_nodes(df_train_eval, train_nodes_raw) + train_node_head( + raw_oracle, + df_anchor_prefix=df_train_eval, + eval_nodes=train_nodes_raw, + y_labels=y_train_twin_raw, + num_epochs=node_epochs, + ) + raw_probs = raw_oracle.predict(df_full, eval_nodes) + oracle_scores[metric_labels["raw"]] = raw_probs + raw_result = make_auc_result(y_twin, raw_probs.astype(np.float32), seed=seed) + attach_auc_result(report, "raw", raw_result) + report["raw_pair_sep"] = evaluate_oracle_pair_sep_node_level( + raw_oracle, df_full, eval_nodes + ) + _attach_probe_aliases(report, benchmark_mode) + + # Oracle/probe debug table + build_oracle_debug_table( + df_full, + eval_nodes, + oracle_scores, + y_twin, + primary_score_name=metric_labels["audit"], + table_title=metric_labels["table"], + ) + + # Stage 3 — Static baselines (window-label eval, as in main benchmark) + xgb_result = compute_matched_xgboost_auc( + matched_train_features, + matched_test_features, + seed=seed, + ) + attach_auc_result(report, "xgb", xgb_result) + static_gnn_result = compute_matched_static_gnn_auc( + df_train=df_train, + df_test=df_test, + train_examples=train_examples, + test_examples=test_examples, + device=device, + num_epochs=num_epochs, + seed=seed, + ) + attach_auc_result(report, "static_gnn", static_gnn_result) + report["xgb_roc_auc"] = float(xgb_result["auc"]) + report["static_gnn_roc"] = float(static_gnn_result["auc"]) + report["static_gnn_auc_flipped"] = float(static_gnn_result.get("auc_flipped", float("nan"))) + report["static_gnn_score_mean_pos"] = float(static_gnn_result.get("score_mean_pos", float("nan"))) + report["static_gnn_score_mean_neg"] = float(static_gnn_result.get("score_mean_neg", float("nan"))) + report["static_gnn_score_std"] = float(static_gnn_result.get("score_std", float("nan"))) + report["static_gnn_zero_emb_frac"] = float(static_gnn_result.get("zero_emb_frac", float("nan"))) + report["static_gnn_matched_examples"] = int(static_gnn_result.get("matched_examples", 0)) + report["static_gnn_unique_prefix_cutoffs"] = int(static_gnn_result.get("unique_prefix_cutoffs", 0)) + report["static_gnn_graph_builds"] = int(static_gnn_result.get("graph_builds", 0)) + report["static_gnn_cache_hit_rate"] = float(static_gnn_result.get("cache_hit_rate", float("nan"))) + report["static_gnn_eval_time_sec"] = float(static_gnn_result.get("eval_time_sec", float("nan"))) + report["static_gnn_train_unique_prefix_cutoffs"] = int(static_gnn_result.get("train_unique_prefix_cutoffs", 0)) + report["static_gnn_test_unique_prefix_cutoffs"] = int(static_gnn_result.get("test_unique_prefix_cutoffs", 0)) + report["static_gnn_train_graph_builds"] = int(static_gnn_result.get("train_graph_builds", 0)) + report["static_gnn_test_graph_builds"] = int(static_gnn_result.get("test_graph_builds", 0)) + report["static_gnn_train_eval_time_sec"] = float(static_gnn_result.get("train_eval_time_sec", float("nan"))) + report["static_gnn_test_eval_time_sec"] = float(static_gnn_result.get("test_eval_time_sec", float("nan"))) + + # Stage 4 — SeqGRU (calib mode only) + run_temporal_models = calib_mode or force_temporal_models + if run_temporal_models: + seqgru_result = compute_matched_seqgru_metrics( + df_train=df_train, + df_test=df_test, + train_examples=train_examples, + test_examples=test_examples, + device=device, + seed=seed, + max_epochs=max(24, min(72, node_epochs)), + ) + seqgru_clean = seqgru_result["clean"] + seqgru_shuffled = seqgru_result["shuffled"] + report["seqgru_roc_auc"] = float(seqgru_clean["auc"]) + report["seqgru_pr_auc"] = float(seqgru_clean.get("pr_auc", float("nan"))) + report["seqgru_brier"] = float(seqgru_clean.get("brier", float("nan"))) + report["seqgru_ece"] = float(seqgru_clean.get("ece", float("nan"))) + report["seqgru_n_examples"] = int(seqgru_clean.get("n_examples", 0)) + report["seqgru_shuffle_roc_auc"] = float(seqgru_shuffled["auc"]) + report["seqgru_shuffle_pr_auc"] = float(seqgru_shuffled.get("pr_auc", float("nan"))) + report["seqgru_shuffle_delta"] = float(seqgru_result["delta"]) + report["seqgru_best_epoch"] = int(seqgru_result["clean_fit"].get("best_epoch", 0)) + report["seqgru_best_valid_roc_auc"] = float(seqgru_result["clean_fit"].get("best_valid_roc_auc", float("nan"))) + report["seqgru_best_valid_pr_auc"] = float(seqgru_result["clean_fit"].get("best_valid_pr_auc", float("nan"))) + report["seqgru_shuffle_best_epoch"] = int(seqgru_result["shuffled_fit"].get("best_epoch", 0)) + report["seqgru_shuffle_best_valid_roc_auc"] = float(seqgru_result["shuffled_fit"].get("best_valid_roc_auc", float("nan"))) + + temporal_gnn_specs = [ + ("TGN", "tgn", lambda: TGNWrapper(device=device)), + ("TGAT", "tgat", lambda: TGATWrapper(device=device)), + ("DyRep", "dyrep", lambda: DyRepWrapper(device=device)), + ("JODIE", "jodie", lambda: JODIEWrapper(device=device)), + ] + temporal_num_epochs = max(2, num_epochs) + for model_label, key_prefix, builder in temporal_gnn_specs: + temporal_result = compute_matched_temporal_gnn_metrics( + model_name=model_label, + model_builder=builder, + df_train=df_train, + df_test=df_test, + train_examples=train_examples, + test_examples=test_examples, + seed=seed, + num_epochs=temporal_num_epochs, + ) + clean_result = temporal_result["clean"] + shuffled_result = temporal_result["shuffled"] + report[f"{key_prefix}_roc_auc"] = float(clean_result["auc"]) + report[f"{key_prefix}_pr_auc"] = float(clean_result.get("pr_auc", float("nan"))) + report[f"{key_prefix}_n_examples"] = int(clean_result.get("n_examples", 0)) + report[f"{key_prefix}_shuffle_roc_auc"] = float(shuffled_result["auc"]) + report[f"{key_prefix}_shuffle_pr_auc"] = float(shuffled_result.get("pr_auc", float("nan"))) + report[f"{key_prefix}_shuffle_delta"] = float(temporal_result["delta"]) + else: + report["seqgru_roc_auc"] = float("nan") + report["seqgru_pr_auc"] = float("nan") + report["seqgru_n_examples"] = 0 + report["seqgru_shuffle_roc_auc"] = float("nan") + report["seqgru_shuffle_pr_auc"] = float("nan") + report["seqgru_shuffle_delta"] = float("nan") + report["seqgru_best_epoch"] = 0 + report["seqgru_best_valid_roc_auc"] = float("nan") + report["seqgru_best_valid_pr_auc"] = float("nan") + report["seqgru_shuffle_best_epoch"] = 0 + report["seqgru_shuffle_best_valid_roc_auc"] = float("nan") + for key_prefix in ("tgn", "tgat", "dyrep", "jodie"): + report[f"{key_prefix}_roc_auc"] = float("nan") + report[f"{key_prefix}_pr_auc"] = float("nan") + report[f"{key_prefix}_n_examples"] = 0 + report[f"{key_prefix}_shuffle_roc_auc"] = float("nan") + report[f"{key_prefix}_shuffle_pr_auc"] = float("nan") + report[f"{key_prefix}_shuffle_delta"] = float("nan") + + # Gate table + gate_items = [ + (f"{metric_labels['audit']} ROC-AUC", "audit_roc_auc", "ge", 0.99, "label-alignment bug"), + (f"{metric_labels['audit']} pair-sep", "audit_pair_sep", "ge", 0.99, "pair construction bug"), + (f"{metric_labels['raw']} ROC-AUC", "raw_roc_auc", "ge", 0.95, "motif reconstruction bug"), + (f"{metric_labels['raw']} pair-sep", "raw_pair_sep", "ge", 0.90, "motif reconstruction bug"), + ("static_agg_auc", "static_agg_auc", "le", 0.60, "static leakage"), + ("XGBoost ROC-AUC", "xgb_roc_auc", "le", 0.65, "static leakage"), + ("StaticGNN ROC-AUC", "static_gnn_roc", "le", 0.70, "static leakage"), + ] + if run_temporal_models: + gate_items.extend([ + ("SeqGRU ROC-AUC", "seqgru_roc_auc", "ge", 0.80, "learning/input bug"), + ("SeqGRU shuffle delta", "seqgru_shuffle_delta", "le", -0.10, "order signal missing"), + ]) + temporal_gnn_items = [ + ("TGN ROC-AUC", "tgn_roc_auc", "ge", 0.75, "temporal learnability"), + ("TGN shuffle delta", "tgn_shuffle_delta", "le", -0.10, "order signal missing"), + ("TGAT ROC-AUC", "tgat_roc_auc", "ge", 0.75, "temporal learnability"), + ("TGAT shuffle delta", "tgat_shuffle_delta", "le", -0.10, "order signal missing"), + ("DyRep ROC-AUC", "dyrep_roc_auc", "ge", 0.75, "temporal learnability"), + ("DyRep shuffle delta", "dyrep_shuffle_delta", "le", -0.10, "order signal missing"), + ("JODIE ROC-AUC", "jodie_roc_auc", "ge", 0.75, "temporal learnability"), + ("JODIE shuffle delta", "jodie_shuffle_delta", "le", -0.10, "order signal missing"), + ] + + print("\n" + "=" * 72) + print(" MOTIF VALIDITY CHECK") + print("=" * 72) + print(" Gate Volume") + print(f" matched_eval_pairs : {report['matched_eval_pairs']}") + print(f" positives / negatives : {report['positives']} / {report['negatives']}") + print(f" unique fraud / benign : {report['unique_fraud_users']} / {report['unique_benign_users']}") + print(f" unique templates : {report['unique_templates']}") + print(f" positive rate : {report['positive_rate']:.4f}") + print( + " model examples : " + f"{metric_labels['audit']}={report['audit_n_examples']} " + f"{metric_labels['raw']}={report['raw_n_examples']} " + f"XGB={report['xgb_n_examples']} StaticGNN={report['static_gnn_n_examples']} " + f"SeqGRU={report['seqgru_n_examples']} TGN={report['tgn_n_examples']} " + f"TGAT={report['tgat_n_examples']} DyRep={report['dyrep_n_examples']} " + f"JODIE={report['jodie_n_examples']}" + ) + print(f" gate source packs/pairs : {report['gate_source_pool_packs']} / {report['gate_source_pool_pairs']}") + print(f" gate pair budget : {report['gate_pair_budget']}") + volume_violations = gate_volume_violations(report, fast_mode) + if volume_violations: + print(" INSUFFICIENT_GATE_VOLUME") + for violation in volume_violations: + print(f" - {violation}") + + all_pass = True + if volume_violations: + all_pass = False + for label, key, op, thresh, hint in gate_items: + val = report.get(key, float("nan")) + is_nan = val != val + ok = (not is_nan) and ((val >= thresh) if op == "ge" else (val <= thresh)) + status = "N/A " if is_nan else ("PASS" if ok else "FAIL") + if not ok: + all_pass = False + tstr = f"{'>='+str(thresh) if op=='ge' else '<='+str(thresh)}" + print(f" {label:<28}: {val:>7.4f} [{status} {tstr}] {'<-- '+hint if not ok else ''}") + + for label, key, op, thresh, hint in temporal_gnn_items: + val = report.get(key, float("nan")) + is_nan = val != val + ok = (not is_nan) and ((val >= thresh) if op == "ge" else (val <= thresh)) + status = "N/A " if is_nan else ("PASS" if ok else "FAIL") + tstr = f"{'>='+str(thresh) if op=='ge' else '<='+str(thresh)}" + suffix = "" if ok else f" [advisory: {hint}]" + print(f" {label:<28}: {val:>7.4f} [{status} {tstr}]{suffix}") + + audit_ci_label = f"{metric_labels['audit']} AUC std/CI" + raw_ci_label = f"{metric_labels['raw']} std/CI" + print(f" {audit_ci_label:<28}: {report['audit_auc_bootstrap_std']:.4f} [{report['audit_auc_ci_lo']:.4f}, {report['audit_auc_ci_hi']:.4f}]") + print(f" {raw_ci_label:<28}: {report['raw_auc_bootstrap_std']:.4f} [{report['raw_auc_ci_lo']:.4f}, {report['raw_auc_ci_hi']:.4f}]") + print(f" {'XGBoost AUC std/CI':<28}: {report['xgb_auc_bootstrap_std']:.4f} [{report['xgb_auc_ci_lo']:.4f}, {report['xgb_auc_ci_hi']:.4f}]") + print(f" {'StaticGNN AUC std/CI':<28}: {report['static_gnn_auc_bootstrap_std']:.4f} [{report['static_gnn_auc_ci_lo']:.4f}, {report['static_gnn_auc_ci_hi']:.4f}]") + print(f" {'static_agg_auc std/CI':<28}: {report['static_agg_auc_bootstrap_std']:.4f} [{report['static_agg_auc_ci_lo']:.4f}, {report['static_agg_auc_ci_hi']:.4f}]") + print(f" {'StaticGNN flip check':<28}: auc={report['static_gnn_roc']:.4f} flipped={report['static_gnn_auc_flipped']:.4f} zero_emb={report['static_gnn_zero_emb_frac']:.4f}") + print(f" {'StaticGNN score means':<28}: pos={report['static_gnn_score_mean_pos']:.4f} neg={report['static_gnn_score_mean_neg']:.4f} std={report['static_gnn_score_std']:.4f}") + print( + f" {'StaticGNN runtime':<28}: " + f"examples={report['static_gnn_matched_examples']} " + f"cutoffs={report['static_gnn_unique_prefix_cutoffs']} " + f"builds={report['static_gnn_graph_builds']} " + f"hit_rate={report['static_gnn_cache_hit_rate']:.4f} " + f"time={report['static_gnn_eval_time_sec']:.2f}s" + ) + print( + f" {'StaticGNN train/test rt':<28}: " + f"train_cutoffs={report['static_gnn_train_unique_prefix_cutoffs']} " + f"test_cutoffs={report['static_gnn_test_unique_prefix_cutoffs']} " + f"train_builds={report['static_gnn_train_graph_builds']} " + f"test_builds={report['static_gnn_test_graph_builds']} " + f"train={report['static_gnn_train_eval_time_sec']:.2f}s " + f"test={report['static_gnn_test_eval_time_sec']:.2f}s" + ) + print(f" {'SeqGRU PR-AUC':<28}: {report['seqgru_pr_auc']:.4f} [informational]") + print(f" {'SeqGRU shuffled ROC-AUC':<28}: {report['seqgru_shuffle_roc_auc']:.4f} [informational]") + print(f" {'SeqGRU shuffled PR-AUC':<28}: {report['seqgru_shuffle_pr_auc']:.4f} [informational]") + print(f" {'SeqGRU early stop':<28}: epoch={report['seqgru_best_epoch']} valid_roc={report['seqgru_best_valid_roc_auc']:.4f} valid_pr={report['seqgru_best_valid_pr_auc']:.4f}") + print(f" {'SeqGRU shuffled stop':<28}: epoch={report['seqgru_shuffle_best_epoch']} valid_roc={report['seqgru_shuffle_best_valid_roc_auc']:.4f}") + print(f" {'TGN PR/shuffled ROC':<28}: pr={report['tgn_pr_auc']:.4f} shuffled={report['tgn_shuffle_roc_auc']:.4f}") + print(f" {'TGAT PR/shuffled ROC':<28}: pr={report['tgat_pr_auc']:.4f} shuffled={report['tgat_shuffle_roc_auc']:.4f}") + print(f" {'DyRep PR/shuffled ROC':<28}: pr={report['dyrep_pr_auc']:.4f} shuffled={report['dyrep_shuffle_roc_auc']:.4f}") + print(f" {'JODIE PR/shuffled ROC':<28}: pr={report['jodie_pr_auc']:.4f} shuffled={report['jodie_shuffle_roc_auc']:.4f}") + print(f" {'P(label|hit>=1)':<28}: {report.get('p_label_given_hit', float('nan')):>7.4f} [informational]") + print(f" {'P(label|hit=0)':<28}: {report.get('p_label_given_nohit', float('nan')):>7.4f} [informational]") + print(f" {'accidental_benign_motif':<28}: {report.get('accidental_benign_motif', float('nan')):>7.4f} [informational]") + print(f" {'KS mean/max':<28}: {ks_mean:>7.4f} / {ks_max:.4f}") + print(f" {'delay min/mean/max':<28}: " + f"{report.get('delay_min',float('nan')):.1f} / " + f"{report.get('delay_mean',float('nan')):.1f} / " + f"{report.get('delay_max',float('nan')):.1f}") + print("=" * 72) + + if all_pass: + print(" [GATE] All thresholds met. Proceeding to full run.") + else: + msg = "[GATE] One or more thresholds FAILED." + if hard_abort: + print(f" {msg} Aborting (hard gate).") + sys.exit(1) + else: + print(f" {msg} Continuing as soft diagnostic.") + + return all_pass, report + + +# --------------------------------------------------------------------------- +# Model factory +# --------------------------------------------------------------------------- + +def build_models(device: str = "cpu") -> Dict[str, TemporalModel]: + return { + "OracleMotif": OracleMotifWrapper(), + "SeqGRU": SequenceGRUWrapper(hidden_dim=64, receiver_buckets=256, device=device), + "TGN": TGNWrapper(memory_dim=64, time_dim=16, device=device), + "TGAT": TGATWrapper(memory_dim=64, time_dim=8, num_heads=4, n_neighbors=10, device=device), + "DyRep": DyRepWrapper(memory_dim=64, time_dim=8, device=device), + "JODIE": JODIEWrapper(memory_dim=64, time_emb_dim=16, device=device), + "StaticGNN": StaticGNNWrapper(hidden_dim=64, n_snapshots=10, device=device), + "XGBoost": XGBoostWrapper(n_estimators=200, max_depth=6), + } + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def parse_seed_list(seed_string: str) -> List[int]: + return [int(token.strip()) for token in seed_string.split(",") if token.strip()] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Leakage-free UPI-Sim benchmark runner") + parser.add_argument("--fast", action="store_true", help="Fast mode: 1 epoch and fewer checkpoints.") + parser.add_argument("--seed", type=int, default=None, help="Run a single seed.") + parser.add_argument( + "--seeds", + nargs="+", + type=int, + default=None, + help="Space-separated seed list, e.g. --seeds 0 1 2 3 4", + ) + parser.add_argument( + "--config", + type=str, + default="config/default.yaml", + help="Path to config YAML.", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help='Torch device ("cpu" or "cuda").', + ) + parser.add_argument( + "--benchmark-mode", + type=str, + default=None, + help='Benchmark mode override, e.g. "standard" or "temporal_twins".', + ) + parser.add_argument( + "--experiments", + nargs="+", + type=str, + default=None, + help="Space-separated list of experiments to run, e.g. --experiments ood causal horizon mechanistic audit", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + # Support both space-separated (nargs=+) and comma-separated experiment lists + if args.experiments is None: + experiments_to_run = {"ood", "causal", "horizon", "mechanistic", "audit"} + elif isinstance(args.experiments, list): + experiments_to_run = set(args.experiments) + else: + experiments_to_run = {exp.strip() for exp in args.experiments.split(",") if exp.strip()} + num_epochs = 1 if args.fast else 3 + node_epochs = 60 if args.fast else 150 + n_checkpoints = 4 if args.fast else 8 + if args.seed is not None: + seeds = [args.seed] + elif args.seeds is not None: + # Already parsed as List[int] via nargs="+" + seeds = args.seeds + else: + seeds = [0, 1, 2, 3, 4] + + config = load_config(args.config) + benchmark_mode = args.benchmark_mode or getattr(config, "benchmark_mode", "standard") + + print("=" * 60) + print(" UPI-Sim Multi-Model Benchmark (Leakage-Free)") + print(f" epochs={num_epochs} node_epochs={node_epochs} checkpoints={n_checkpoints}") + print(f" seeds={seeds} device={args.device} mode={benchmark_mode}") + print("=" * 60) + + raw_frames: Dict[str, List[pd.DataFrame]] = { + "ood": [], + "causal": [], + "horizon": [], + "mechanistic": [], + "audit": [], + } + + import torch + + is_twin_mode = benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib") + calib_mode = benchmark_mode == "temporal_twins_oracle_calib" + + for seed in seeds: + set_global_determinism(seed) + + print(f"\n[data] Generating datasets for seed={seed}...") + df_easy, df_medium, df_hard = generate_all( + config, + seed=seed, + benchmark_mode=benchmark_mode, + ) + print(f" Easy : {len(df_easy):,} events | fraud={df_easy['is_fraud'].mean():.3f}") + print(f" Medium: {len(df_medium):,} events | fraud={df_medium['is_fraud'].mean():.3f}") + print(f" Hard : {len(df_hard):,} events | fraud={df_hard['is_fraud'].mean():.3f}") + + if is_twin_mode: + # Run validity check: hard-abort in calib mode, soft diagnostic otherwise. + gate_df = build_gate_pool_from_frames([df_easy, df_medium, df_hard]) + run_motif_validity_check( + df=gate_df, + config=config, + seed=seed, + device=args.device, + num_epochs=num_epochs, + node_epochs=node_epochs, + n_checkpoints=n_checkpoints, + hard_abort=calib_mode, + benchmark_mode=benchmark_mode, + fast_mode=args.fast, + ) + + if "ood" in experiments_to_run: + print(f"\n[seed={seed}] OOD generalisation") + df_ood = run_ood_single( + df_easy=df_easy, + df_medium=df_medium, + df_hard=df_hard, + device=args.device, + num_epochs=num_epochs, + node_epochs=node_epochs, + n_checkpoints=n_checkpoints, + ) + df_ood["seed"] = seed + raw_frames["ood"].append(df_ood) + + if "causal" in experiments_to_run: + print(f"\n[seed={seed}] Causal chronology shuffle") + df_causal = run_causal_single( + df_hard=df_hard, + device=args.device, + num_epochs=num_epochs, + node_epochs=node_epochs, + n_checkpoints=n_checkpoints, + seed=seed, + ) + df_causal["seed"] = seed + raw_frames["causal"].append(df_causal) + + if "horizon" in experiments_to_run: + print(f"\n[seed={seed}] Horizon sweep") + df_horizon = run_horizon_single( + df_medium=df_medium, + device=args.device, + num_epochs=num_epochs, + node_epochs=node_epochs, + n_checkpoints=n_checkpoints, + horizons=DEFAULT_HORIZONS, + ) + df_horizon["seed"] = seed + raw_frames["horizon"].append(df_horizon) + + if "mechanistic" in experiments_to_run: + print(f"\n[seed={seed}] Mechanistic correlation") + df_mech = run_mechanistic_single( + df_hard=df_hard, + device=args.device, + num_epochs=num_epochs, + node_epochs=node_epochs, + n_checkpoints=n_checkpoints, + ) + df_mech["seed"] = seed + raw_frames["mechanistic"].append(df_mech) + + if "audit" in experiments_to_run: + print(f"\n[seed={seed}] Temporal twins audit") + df_audit = run_audit_single( + df_hard=df_hard, + device=args.device, + num_epochs=num_epochs, + node_epochs=node_epochs, + n_checkpoints=n_checkpoints, + seed=seed, + benchmark_mode=benchmark_mode, + ) + df_audit["seed"] = seed + raw_frames["audit"].append(df_audit) + + save_experiment_outputs(raw_frames, results_dir="results") + + print("\n" + "=" * 60) + print(" All requested experiments completed.") + print(" Saved raw + summary CSVs in results/") + print(" Run: python -m plots.plot_results") + print("=" * 60) + + +if __name__ == "__main__": + main()