temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
"""
experiments/run_all.py
======================
Leakage-free experiment runner for the UPI-Sim temporal fraud benchmark.
Key protocol changes
--------------------
- Strict prefix evaluation: models only see events up to cutoff t.
- Horizon-specific retraining: each horizon uses fresh model instances.
- Causal ablation trains/evaluates on globally shuffled chronology.
- XGBoost uses the real xgboost library with aligned node-level labels.
- All experiments support multi-seed aggregation with mean ± std outputs.
"""
from __future__ import annotations
import argparse
import hashlib
import os
import random
import sys
import time
from typing import Dict, Iterable, List, Sequence
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, brier_score_loss, roc_auc_score
from xgboost import XGBClassifier
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _ROOT not in sys.path:
sys.path.insert(0, _ROOT)
from src.core.config_loader import load_config
from src.generators.user_generator import generate_users
from src.generators.transaction_generator import generate_transactions
from src.fraud.fraud_engine import FraudEngine, ORACLE_ONLY_COLS
from src.graph.graph_builder import build_edge_features
from src.risk.risk_engine import apply_risk_engine
from models.base import TemporalModel
from models.dyrep import DyRepWrapper
from models.jodie import JODIEWrapper
from models.audit_oracle import AuditOracleWrapper, RawMotifOracleWrapper
from models.oracle_motif import OracleMotifWrapper
from models.sequence_gru import SequenceGRUWrapper
from models.static_gnn import StaticGNNWrapper
from models.tgat import TGATWrapper
from models.tgn_wrapper import TGNWrapper
from models.xgboost_model import XGBoostWrapper
# Pin PyTorch to a single intra-op thread, in line with the OMP/MKL
# single-thread environment defaults set before the numeric imports.
torch.set_num_threads(1)
if hasattr(torch, "set_num_interop_threads"):
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError:
        # Best effort: if PyTorch refuses the change, keep its current setting.
        pass
# Oracle models that are allowed to receive unstripped audit columns
_ORACLE_MODEL_NAMES: frozenset = frozenset({"OracleMotif", "AuditOracle", "RawMotifOracle"})
# ---------------------------------------------------------------------------
# Oracle / audit column stripping
# ---------------------------------------------------------------------------
def strip_oracle_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` without any column listed in ``ORACLE_ONLY_COLS``.

    Used before handing data to learned baselines. When no such column is
    present the input frame itself is returned (no copy is made).
    """
    oracle_cols = [name for name in df.columns if name in ORACLE_ONLY_COLS]
    return df.drop(columns=oracle_cols) if oracle_cols else df
# Default evaluation horizons.
# NOTE(review): presumably fractions used as prefix cutoffs — confirm against the CLI/runner code.
DEFAULT_HORIZONS = [0.01, 0.05, 0.10, 0.20]
# Default seed list.
# NOTE(review): presumably the seeds aggregated in multi-seed runs — confirm against the CLI/runner code.
DEFAULT_SEEDS = [0, 1, 2, 3, 4]
# Per-difficulty offsets added to the base seed before user generation in the
# temporal-twins benchmark modes (see generate_single_difficulty).
_TWIN_DIFFICULTY_USER_SEED_OFFSETS = {"easy": 11, "medium": 23, "hard": 37}
# Model names in a fixed order.
# NOTE(review): presumably the reporting/iteration order — confirm against usage.
MODEL_ORDER = [
    "OracleMotif",
    "SeqGRU",
    "TGN",
    "TGAT",
    "DyRep",
    "JODIE",
    "StaticGNN",
    "XGBoost",
]
def stable_int_hash(*parts: object, modulo: int = 2**32) -> int:
    """Hash ``parts`` to an integer in ``[0, modulo)`` that is stable across processes.

    The parts are stringified, joined with ``"::"``, hashed with SHA-256, and
    the first 16 hex digits of the digest are reduced modulo ``modulo``.
    Unlike the builtin ``hash``, the result does not depend on hash
    randomisation.
    """
    joined = "::".join(str(part) for part in parts)
    hex_digest = hashlib.sha256(joined.encode("utf-8")).hexdigest()
    leading_bits = int(hex_digest[:16], 16)
    return leading_bits % modulo
def seed_python_numpy(seed: int) -> None:
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
def set_global_determinism(seed: int) -> None:
    """Seed every RNG in use and switch PyTorch to deterministic settings.

    Seeds Python/NumPy (via ``seed_python_numpy``) and PyTorch (CPU, plus all
    CUDA devices when CUDA is available), turns off cuDNN benchmarking and
    TF32 where those switches exist, and requests deterministic algorithms.
    """
    seed_python_numpy(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    backends = torch.backends
    if hasattr(backends, "cudnn"):
        cudnn = backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
        if hasattr(cudnn, "allow_tf32"):
            cudnn.allow_tf32 = False
    if hasattr(backends, "cuda") and hasattr(backends.cuda, "matmul"):
        backends.cuda.matmul.allow_tf32 = False

    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        # Best effort: a failure to enable deterministic algorithms is ignored.
        pass
def derived_seed(base_seed: int, *parts: object, modulo: int = 2**31 - 1) -> int:
    """Derive a child seed in ``[0, modulo)`` from ``base_seed`` and labelling ``parts``.

    The offset is ``stable_int_hash(*parts, modulo=modulo)``; the sum with
    ``base_seed`` is reduced modulo ``modulo`` again.
    """
    offset = stable_int_hash(*parts, modulo=modulo)
    shifted = int(base_seed) + offset
    return int(shifted % modulo)
def _is_oracle_calib_mode(benchmark_mode: str) -> bool:
    """Return True only for the ``temporal_twins_oracle_calib`` benchmark mode."""
    oracle_calib_mode = "temporal_twins_oracle_calib"
    return benchmark_mode == oracle_calib_mode
def _oracle_metric_labels(benchmark_mode: str) -> dict[str, str]:
    """Return display labels for the audit/raw metrics and their debug table.

    Oracle-calibration mode uses the "Oracle" names; every other mode uses
    the "Probe" names. A fresh dict is returned on each call.
    """
    oracle_labels = {
        "audit": "AuditOracle",
        "raw": "RawMotifOracle",
        "table": "Oracle Debug Table",
    }
    probe_labels = {
        "audit": "MotifProbe",
        "raw": "RawMotifProbe",
        "table": "Probe Debug Table",
    }
    return oracle_labels if _is_oracle_calib_mode(benchmark_mode) else probe_labels
def _attach_probe_aliases(report: dict, benchmark_mode: str) -> None:
    """Add metric labels and, outside oracle-calib mode, probe-named key aliases.

    ``report`` is mutated in place. The original ``audit_*`` / ``raw_*`` keys
    are kept, so consumers of the old oracle key names keep working; aliases
    are only created for source keys that are present in ``report``.
    """
    labels = _oracle_metric_labels(benchmark_mode)
    report["audit_metric_label"] = labels["audit"]
    report["raw_metric_label"] = labels["raw"]
    if _is_oracle_calib_mode(benchmark_mode):
        return

    # alias key -> existing source key
    source_for_alias = {
        "motif_probe_roc_auc": "audit_roc_auc",
        "motif_probe_pair_sep": "audit_pair_sep",
        "motif_probe_n_examples": "audit_n_examples",
        "motif_probe_auc_bootstrap_std": "audit_auc_bootstrap_std",
        "motif_probe_auc_ci_lo": "audit_auc_ci_lo",
        "motif_probe_auc_ci_hi": "audit_auc_ci_hi",
        "raw_motif_probe_roc_auc": "raw_roc_auc",
        "raw_motif_probe_pair_sep": "raw_pair_sep",
        "raw_motif_probe_n_examples": "raw_n_examples",
        "raw_motif_probe_auc_bootstrap_std": "raw_auc_bootstrap_std",
        "raw_motif_probe_auc_ci_lo": "raw_auc_ci_lo",
        "raw_motif_probe_auc_ci_hi": "raw_auc_ci_hi",
    }
    # Alias and source key sets are disjoint, so a single bulk update is
    # equivalent to assigning the aliases one by one.
    report.update(
        {
            alias: report[source]
            for alias, source in source_for_alias.items()
            if source in report
        }
    )
# ---------------------------------------------------------------------------
# Data generation
# ---------------------------------------------------------------------------
def generate_difficulty(
    config,
    users: pd.DataFrame,
    difficulty: str,
    seed: int,
    time_offset: float = 0.0,
    benchmark_mode: str = "standard",
) -> pd.DataFrame:
    """Generate one difficulty slice, shifted by a global timestamp offset.

    Transactions are generated for ``users``, passed through the risk engine
    and a ``FraudEngine`` seeded from ``seed`` plus a stable hash of
    (difficulty, benchmark_mode), then sorted by timestamp. In the
    temporal-twins modes, sender/receiver ids and the non-negative
    ``twin_pair_id`` / ``template_id`` values are shifted by a
    per-difficulty offset. Finally ``time_offset`` is added to every
    timestamp.
    """
    transactions = generate_transactions(users, config)
    transactions = apply_risk_engine(transactions, users, config)

    seed_shift = stable_int_hash("FraudEngine", difficulty, benchmark_mode, modulo=10_000)
    fraud_engine = FraudEngine(
        seed=seed + seed_shift,
        difficulty=difficulty,
        benchmark_mode=benchmark_mode,
    )
    labelled = fraud_engine.apply(transactions)
    labelled = labelled.sort_values("timestamp").reset_index(drop=True)

    if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"):
        id_shift = {"easy": 0, "medium": 1_000_000, "hard": 2_000_000}[difficulty]
        for endpoint in ("sender_id", "receiver_id"):
            labelled[endpoint] = labelled[endpoint].astype(np.int64) + id_shift
        for id_col in ("twin_pair_id", "template_id"):
            if id_col not in labelled.columns:
                continue
            labelled[id_col] = labelled[id_col].astype(np.int64)
            # Negative ids are left untouched; only ids >= 0 are shifted.
            assigned = labelled[id_col] >= 0
            labelled.loc[assigned, id_col] = labelled.loc[assigned, id_col] + id_shift

    labelled["timestamp"] = labelled["timestamp"] + time_offset
    return labelled
def generate_all(config, seed: int = 42, benchmark_mode: str = "standard"):
    """Generate the Easy/Medium/Hard datasets on one shared timeline.

    In the temporal-twins modes each difficulty gets its own user population,
    generated after re-seeding with
    ``seed + _TWIN_DIFFICULTY_USER_SEED_OFFSETS[difficulty]`` — the same
    constant ``generate_single_difficulty`` uses, so the two entry points
    cannot drift apart. In every other mode a single user population is
    generated once and shared by all three difficulties.

    Slices are generated in the order easy, medium, hard; each later slice is
    offset to start ``gap`` time units after the previous slice's last
    timestamp.

    Returns:
        Tuple ``(df_easy, df_medium, df_hard)``.
    """
    seed_python_numpy(seed)
    gap = 1_000.0
    difficulties = ("easy", "medium", "hard")

    if benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib"):
        users_by_difficulty = {}
        for difficulty in difficulties:
            seed_python_numpy(seed + _TWIN_DIFFICULTY_USER_SEED_OFFSETS[difficulty])
            users_by_difficulty[difficulty] = generate_users(config)
    else:
        shared_users = generate_users(config)
        users_by_difficulty = dict.fromkeys(difficulties, shared_users)

    slices = []
    time_offset = 0.0
    for position, difficulty in enumerate(difficulties):
        slice_df = generate_difficulty(
            config,
            users_by_difficulty[difficulty],
            difficulty,
            seed,
            time_offset=time_offset,
            benchmark_mode=benchmark_mode,
        )
        slices.append(slice_df)
        # Timestamps already include the offset, so the next slice starts
        # `gap` after this slice's last event. Not needed after the last one.
        if position < len(difficulties) - 1:
            time_offset = float(slice_df["timestamp"].max()) + gap

    df_easy, df_medium, df_hard = slices
    return df_easy, df_medium, df_hard
def generate_single_difficulty(
    config,
    difficulty: str,
    seed: int = 42,
    benchmark_mode: str = "standard",
) -> pd.DataFrame:
    """Generate one difficulty slice using the same user-seed scheme as generate_all().

    The slice is produced with ``time_offset=0.0``. In the temporal-twins
    modes the RNGs are re-seeded with the difficulty-specific user seed
    before users are generated.
    """
    seed_python_numpy(seed)
    twin_mode = benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib")
    if twin_mode:
        seed_python_numpy(seed + _TWIN_DIFFICULTY_USER_SEED_OFFSETS[difficulty])
    users = generate_users(config)
    return generate_difficulty(
        config,
        users,
        difficulty,
        seed,
        time_offset=0.0,
        benchmark_mode=benchmark_mode,
    )
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
def compute_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10) -> float:
    """Expected calibration error over ``n_bins`` equal-width bins on [0, 1].

    Each bin is half-open ``[lo, hi)`` except the last, which also includes
    1.0. Empty bins contribute nothing; each non-empty bin contributes its
    share of samples times ``|mean(prob) - mean(label)|``.
    """
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for left, right in zip(edges[:-1], edges[1:]):
        below_right = (y_prob < right) if right < 1.0 else (y_prob <= right)
        in_bin = (y_prob >= left) & below_right
        if in_bin.any():
            weight = float(in_bin.mean())
            confidence = float(y_prob[in_bin].mean())
            accuracy = float(y_true[in_bin].mean())
            total += weight * abs(confidence - accuracy)
    return float(total)
def safe_roc_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float:
    """ROC-AUC that returns the chance value 0.5 when fewer than two classes are present."""
    has_both_classes = len(np.unique(y_true)) >= 2
    if has_both_classes:
        return float(roc_auc_score(y_true, y_prob))
    return 0.5
def safe_pr_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float:
    """Average precision with fixed values for single-class inputs.

    Returns 0.0 when there are no positives (label 1) and 1.0 when there are
    positives but no negatives (label 0); otherwise delegates to
    ``average_precision_score``.
    """
    if not np.any(y_true == 1):
        return 0.0
    if not np.any(y_true == 0):
        return 1.0
    return float(average_precision_score(y_true, y_prob))
def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> dict:
    """Compute ROC-AUC, PR-AUC, Brier score and ECE for binary predictions.

    Scores are sanitised first: NaN becomes 0.5, +inf becomes 1.0, -inf
    becomes 0.0, and everything is clipped to [0, 1].
    """
    labels = np.asarray(y_true, dtype=np.float32)
    scores = np.asarray(y_prob, dtype=np.float32)
    scores = np.nan_to_num(scores, nan=0.5, posinf=1.0, neginf=0.0)
    scores = np.clip(scores, 0.0, 1.0)

    metrics = {}
    metrics["roc_auc"] = safe_roc_auc(labels, scores)
    metrics["pr_auc"] = safe_pr_auc(labels, scores)
    metrics["brier"] = float(brier_score_loss(labels, scores))
    metrics["ece"] = compute_ece(labels, scores)
    return metrics
def safe_pearson(x: np.ndarray, y: np.ndarray) -> float:
    """Pearson correlation that returns 0.0 for empty or (near-)constant inputs.

    Inputs are cast to float32. A standard deviation below 1e-8 on either
    side counts as constant.
    """
    xs = np.asarray(x, dtype=np.float32)
    ys = np.asarray(y, dtype=np.float32)
    if min(len(xs), len(ys)) == 0:
        return 0.0
    if any(np.std(values) < 1e-8 for values in (xs, ys)):
        return 0.0
    correlation_matrix = np.corrcoef(xs, ys)
    return float(correlation_matrix[0, 1])
def build_node_audit_table(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate transactions into one row of summary statistics per sender.

    Events are sorted by timestamp; per-sender inter-event gaps (first gap
    set to 0), time-of-day phase (timestamp mod 86400), burst flags
    (0 < gap < 600) and quiet flags (gap > 3600) are derived and then
    aggregated together with counts, amount statistics and receiver entropy.
    ``label`` is the per-sender max of ``twin_label`` when that column
    exists, otherwise of ``is_fraud``; ``twin_pair_id`` is -1 when the column
    is absent. Missing values are filled with 0.0 and ``sender_id`` is
    returned as a regular column.
    """
    events = df.sort_values("timestamp").reset_index(drop=True).copy()
    gaps = events.groupby("sender_id")["timestamp"].diff().fillna(0.0)
    events["_dt"] = gaps
    events["_phase"] = events["timestamp"] % 86400.0
    events["_burst"] = (gaps > 0.0) & (gaps < 600.0)
    events["_quiet"] = gaps > 3600.0

    by_sender = events.groupby("sender_id", sort=False)
    has_retry = "is_retry" in events.columns
    has_failed = "failed" in events.columns
    node_df = pd.DataFrame({
        "txn_count": by_sender["sender_id"].count(),
        "receiver_count": by_sender["receiver_id"].nunique(),
        "retry_count": by_sender["is_retry"].sum() if has_retry else 0.0,
        "failed_count": by_sender["failed"].sum() if has_failed else 0.0,
        "burst_count": by_sender["_burst"].sum(),
        "quiet_count": by_sender["_quiet"].sum(),
        "dt_mean": by_sender["_dt"].mean(),
        "dt_std": by_sender["_dt"].std().fillna(0.0),
        "amount_mean": by_sender["amount"].mean(),
        "amount_std": by_sender["amount"].std().fillna(0.0),
        "phase_std": by_sender["_phase"].std().fillna(0.0),
    })

    # Shannon entropy (bits) of each sender's receiver distribution.
    pair_counts = (
        events.groupby(["sender_id", "receiver_id"])
        .size()
        .reset_index(name="_n")
    )
    pair_counts["_tot"] = pair_counts.groupby("sender_id")["_n"].transform("sum")
    pair_counts["_p"] = pair_counts["_n"] / pair_counts["_tot"]
    pair_counts["_h"] = -pair_counts["_p"] * np.log2(pair_counts["_p"] + 1e-9)
    node_df["recv_entropy"] = pair_counts.groupby("sender_id")["_h"].sum()

    if "twin_pair_id" in events.columns:
        node_df["twin_pair_id"] = by_sender["twin_pair_id"].first().astype(np.int32)
    else:
        node_df["twin_pair_id"] = -1

    label_source = "twin_label" if "twin_label" in events.columns else "is_fraud"
    node_df["label"] = by_sender[label_source].max().astype(np.int32)
    return node_df.fillna(0.0).reset_index()
def with_local_event_idx(df: pd.DataFrame) -> pd.DataFrame:
    """Return a timestamp-sorted copy of ``df`` with a per-sender event counter.

    ``local_event_idx`` is the 0-based position (int32) of each row among its
    sender's rows in timestamp order. The input frame is not modified.
    """
    ordered = df.sort_values("timestamp").reset_index(drop=True).copy()
    position_within_sender = ordered.groupby("sender_id").cumcount()
    ordered["local_event_idx"] = position_within_sender.astype(np.int32)
    return ordered
def build_matched_control_tables(
    df: pd.DataFrame,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Build matched fraud/benign evaluation examples at the same local index k.

    Only rows with ``twin_pair_id >= 0`` are used. A twin pair is kept when it
    has exactly two senders whose roles are ``{"fraud", "benign"}``. For every
    fraud-twin row with ``is_fraud == 1`` and ``label_event_idx >= 0``, the
    benign twin's event with ``local_event_idx == label_event_idx`` is looked
    up; if it exists, one pair-event is emitted.

    Returns:
        ``(examples, pair_rows, pair_counts)``:
        * ``examples`` — two rows per pair-event (fraud label 1, benign label 0).
        * ``pair_rows`` — one row per pair-event with both sides' indices,
          timestamps, ages and their differences.
        * ``pair_counts`` — one row per kept twin pair with total transaction
          counts and motif hit counts.
        Three empty frames are returned when the required columns are missing
        or no twin rows exist.
    """
    required = {"twin_pair_id", "twin_role", "twin_label", "label_event_idx"}
    if not required.issubset(df.columns):
        empty = pd.DataFrame()
        return empty, empty, empty
    twin_df = with_local_event_idx(df[df["twin_pair_id"] >= 0].copy())
    if twin_df.empty:
        empty = pd.DataFrame()
        return empty, empty, empty
    # Per-sender metadata. When an optional column is missing, a placeholder
    # aggregation is used so that .agg() still receives a valid spec:
    # template_id falls back to twin_pair_id, and motif_hit_count is
    # overwritten with 0 right after.
    sender_meta = (
        twin_df.groupby("sender_id")
        .agg(
            twin_pair_id=("twin_pair_id", "first"),
            twin_role=("twin_role", "first"),
            twin_label=("twin_label", "max"),
            template_id=("template_id", "first") if "template_id" in twin_df.columns else ("twin_pair_id", "first"),
            total_txn_count=("sender_id", "size"),
            sender_start_time=("timestamp", "min"),
            motif_hit_count=("motif_hit_count", "max") if "motif_hit_count" in twin_df.columns else ("sender_id", "size"),
        )
        .reset_index()
    )
    if "motif_hit_count" not in twin_df.columns:
        sender_meta["motif_hit_count"] = 0
    pair_rows: list[dict] = []
    example_rows: list[dict] = []
    pair_count_rows: list[dict] = []
    # Running id shared by the two example rows and the pair row of one match.
    pair_event_id = 0
    sender_groups = {
        int(sender_id): group.reset_index(drop=True).copy()
        for sender_id, group in twin_df.groupby("sender_id", sort=False)
    }
    for pair_id, pair_meta in sender_meta.groupby("twin_pair_id", sort=True):
        # Skip incomplete or malformed pairs.
        if len(pair_meta) != 2 or set(pair_meta["twin_role"]) != {"fraud", "benign"}:
            continue
        fraud_meta = pair_meta[pair_meta["twin_role"] == "fraud"].iloc[0]
        benign_meta = pair_meta[pair_meta["twin_role"] == "benign"].iloc[0]
        fraud_sender = int(fraud_meta["sender_id"])
        benign_sender = int(benign_meta["sender_id"])
        # The pair's template id is taken from the fraud twin.
        template_id = int(fraud_meta["template_id"])
        fraud_total = int(fraud_meta["total_txn_count"])
        benign_total = int(benign_meta["total_txn_count"])
        pair_count_rows.append({
            "twin_pair_id": int(pair_id),
            "template_id": template_id,
            "fraud_sender_id": fraud_sender,
            "benign_sender_id": benign_sender,
            "fraud_total_txn_count": fraud_total,
            "benign_total_txn_count": benign_total,
            "pair_total_txn_count_diff": abs(fraud_total - benign_total),
            "fraud_motif_hit_count": int(fraud_meta["motif_hit_count"]),
            "benign_motif_hit_count": int(benign_meta["motif_hit_count"]),
        })
        fraud_group = sender_groups[fraud_sender]
        benign_group = sender_groups[benign_sender]
        benign_by_idx = benign_group.set_index("local_event_idx", drop=False)
        fraud_positives = fraud_group[
            (fraud_group["is_fraud"] == 1) & (fraud_group["label_event_idx"] >= 0)
        ].copy()
        for row in fraud_positives.itertuples(index=False):
            # NOTE(review): k comes from label_event_idx while the fraud
            # timestamp comes from this row; this presumably assumes
            # label_event_idx equals the row's own local_event_idx — confirm.
            k = int(row.label_event_idx)
            # No benign event at the same local index: no matched control.
            if k not in benign_by_idx.index:
                continue
            benign_row = benign_by_idx.loc[k]
            # .loc returns a DataFrame if the index value is duplicated; take the first row.
            if isinstance(benign_row, pd.DataFrame):
                benign_row = benign_row.iloc[0]
            fraud_age = float(row.timestamp - fraud_meta["sender_start_time"])
            benign_age = float(benign_row["timestamp"] - benign_meta["sender_start_time"])
            # Local indices are 0-based, so a prefix ending at k has k + 1 events.
            prefix_txn_count = k + 1
            pair_rows.append({
                "pair_event_id": pair_event_id,
                "twin_pair_id": int(pair_id),
                "template_id": template_id,
                "fraud_sender_id": fraud_sender,
                "benign_sender_id": benign_sender,
                "fraud_label_event_idx": k,
                "benign_eval_event_idx": int(benign_row["local_event_idx"]),
                "fraud_eval_timestamp": float(row.timestamp),
                "benign_eval_timestamp": float(benign_row["timestamp"]),
                "fraud_active_age": fraud_age,
                "benign_active_age": benign_age,
                "active_age_diff": abs(fraud_age - benign_age),
                "timestamp_diff": abs(float(row.timestamp) - float(benign_row["timestamp"])),
                "prefix_txn_count": prefix_txn_count,
                "fraud_total_txn_count": fraud_total,
                "benign_total_txn_count": benign_total,
                "pair_total_txn_count_diff": abs(fraud_total - benign_total),
                "fraud_motif_hit_count": int(fraud_meta["motif_hit_count"]),
                "benign_motif_hit_count": int(benign_meta["motif_hit_count"]),
                "label_delay": int(row.label_delay) if hasattr(row, "label_delay") else -1,
            })
            # Fields shared by the fraud and benign example rows of this match.
            common = {
                "pair_event_id": pair_event_id,
                "twin_pair_id": int(pair_id),
                "template_id": template_id,
                "eval_local_event_idx": k,
                "prefix_txn_count": prefix_txn_count,
            }
            example_rows.append({
                **common,
                "sender_id": fraud_sender,
                "label": 1,
                "twin_role": "fraud",
                "matched_sender_id": benign_sender,
                "total_txn_count": fraud_total,
                "eval_timestamp": float(row.timestamp),
                # The simulator has no separate account-creation time, so
                # account_age equals active_age for twin-control audits.
                "account_age": fraud_age,
                "active_age": fraud_age,
            })
            example_rows.append({
                **common,
                "sender_id": benign_sender,
                "label": 0,
                "twin_role": "benign",
                "matched_sender_id": fraud_sender,
                "total_txn_count": benign_total,
                "eval_timestamp": float(benign_row["timestamp"]),
                "account_age": benign_age,
                "active_age": benign_age,
            })
            pair_event_id += 1
    return (
        pd.DataFrame(example_rows),
        pd.DataFrame(pair_rows),
        pd.DataFrame(pair_count_rows),
    )
def _sender_prefix_feature_row(prefix: pd.DataFrame) -> dict:
    """Summarise one sender's event prefix as a flat dict of float features.

    ``prefix`` is sorted by timestamp first. Inter-event gaps use a leading
    gap of 0 and are clamped at 0; a burst is a gap in (0, 600) and a quiet
    period a gap above 3600. ``td_*`` and ``dt_*`` carry the same gap
    statistics under two names. Features whose source column (``is_retry``,
    ``failed``, ``pair_freq``) is absent, and all features of an empty
    prefix, are 0.0.
    """
    ordered = prefix.sort_values("timestamp").reset_index(drop=True)
    n_events = len(ordered)
    times = ordered["timestamp"].to_numpy(dtype=np.float64)

    if n_events:
        gaps = np.diff(times, prepend=times[0])
        day_phase = times % 86400.0
    else:
        gaps = np.zeros(0, dtype=np.float64)
        day_phase = np.zeros(0, dtype=np.float64)
    gaps = np.maximum(gaps, 0.0)
    burst_flags = ((gaps > 0.0) & (gaps < 600.0)).astype(np.float32)
    quiet_flags = (gaps > 3600.0).astype(np.float32)

    per_receiver = ordered["receiver_id"].value_counts().to_numpy(dtype=np.float64)
    receiver_share = per_receiver / max(float(per_receiver.sum()), 1.0)
    if len(per_receiver):
        recv_entropy = float(-np.sum(receiver_share * np.log2(receiver_share + 1e-9)))
    else:
        recv_entropy = 0.0

    gap_mean = float(gaps.mean()) if len(gaps) else 0.0
    gap_std = float(gaps.std(ddof=1)) if len(gaps) > 1 else 0.0

    def optional_sum(column: str) -> float:
        # Total of an optional column; 0.0 when the column is absent.
        return float(ordered[column].sum()) if column in ordered.columns else 0.0

    def optional_mean(column: str) -> float:
        # Mean of an optional column; 0.0 when absent or the prefix is empty.
        if column in ordered.columns and n_events:
            return float(ordered[column].mean())
        return 0.0

    return {
        "txn_count": float(n_events),
        "txn_cnt10_last": float(min(n_events, 10)),
        "receiver_count": float(ordered["receiver_id"].nunique()) if n_events else 0.0,
        "retry_count": optional_sum("is_retry"),
        "failed_count": optional_sum("failed"),
        "burst_count": float(burst_flags.sum()),
        "quiet_count": float(quiet_flags.sum()),
        "amount_mean": float(ordered["amount"].mean()) if n_events else 0.0,
        "amount_std": float(ordered["amount"].std(ddof=1)) if n_events > 1 else 0.0,
        "amount_max": float(ordered["amount"].max()) if n_events else 0.0,
        "td_mean": gap_mean,
        "td_std": gap_std,
        "dt_mean": gap_mean,
        "dt_std": gap_std,
        "phase_std": float(np.std(day_phase, ddof=1)) if len(day_phase) > 1 else 0.0,
        "recv_entropy": recv_entropy,
        "fail_rate": optional_mean("failed"),
        "retry_rate": optional_mean("is_retry"),
        "pair_freq_mean": optional_mean("pair_freq"),
    }
def build_matched_prefix_feature_table(
    df: pd.DataFrame,
    examples: pd.DataFrame,
) -> pd.DataFrame:
    """Build one prefix-feature row per matched example.

    For each example, the sender's first ``eval_local_event_idx + 1`` events
    (in timestamp order) are summarised by ``_sender_prefix_feature_row`` and
    combined with the example's identifying fields. Returns an empty frame
    when ``examples`` is empty; missing values are filled with 0.0.
    """
    if examples.empty:
        return pd.DataFrame()

    events = with_local_event_idx(df)
    history_by_sender = {
        int(sender): history.reset_index(drop=True).copy()
        for sender, history in events.groupby("sender_id", sort=False)
    }

    def feature_row(example) -> dict:
        sender = int(example.sender_id)
        last_idx = int(example.eval_local_event_idx)
        visible = history_by_sender[sender].iloc[: last_idx + 1].copy()
        identity = {
            "pair_event_id": int(example.pair_event_id),
            "twin_pair_id": int(example.twin_pair_id),
            "template_id": int(example.template_id),
            "sender_id": sender,
            "label": int(example.label),
            "eval_local_event_idx": int(example.eval_local_event_idx),
            "prefix_txn_count": int(example.prefix_txn_count),
            "total_txn_count": int(example.total_txn_count),
            "eval_timestamp": float(example.eval_timestamp),
            "account_age": float(example.account_age),
            "active_age": float(example.active_age),
        }
        identity.update(_sender_prefix_feature_row(visible))
        return identity

    rows = [feature_row(example) for example in examples.itertuples(index=False)]
    return pd.DataFrame(rows).fillna(0.0)
def report_matched_control_audits(
    test_examples: pd.DataFrame,
    test_pair_rows: pd.DataFrame,
    test_pair_counts: pd.DataFrame,
) -> dict:
    """Compute and print shortcut-audit statistics for matched control examples.

    For each trivial covariate (transaction counts, local index, timestamp,
    ages) the ROC-AUC of that covariate alone against the label is reported,
    together with mean/max pair-level differences and the benign motif hit
    rate. Statistics that depend on an empty ``test_pair_rows`` or
    ``test_pair_counts`` are 0. Returns ``{}`` (and prints nothing) when
    ``test_examples`` is empty.

    NOTE(review): covariates are cast to float32 before the AUC; for large
    timestamps this may merge nearby values into ties — confirm it is intended.
    """
    if test_examples.empty:
        return {}

    labels = test_examples["label"].to_numpy(dtype=np.float32)

    def shortcut_auc(column: str) -> float:
        # AUC obtained by ranking examples on a single covariate.
        return safe_roc_auc(labels, test_examples[column].to_numpy(dtype=np.float32))

    def mean_and_max(frame: pd.DataFrame, column: str) -> tuple[float, float]:
        # (mean, max) of a column, or zeros when the frame is empty.
        if frame.empty:
            return 0.0, 0.0
        return float(frame[column].mean()), float(frame[column].max())

    count_diff_mean, count_diff_max = mean_and_max(test_pair_counts, "pair_total_txn_count_diff")
    fraud_idx_mean, fraud_idx_max = mean_and_max(test_pair_rows, "fraud_label_event_idx")
    benign_idx_mean, benign_idx_max = mean_and_max(test_pair_rows, "benign_eval_event_idx")
    age_diff_mean, age_diff_max = mean_and_max(test_pair_rows, "active_age_diff")
    ts_diff_mean, ts_diff_max = mean_and_max(test_pair_rows, "timestamp_diff")

    if test_pair_rows.empty:
        idx_gap_mean, idx_gap_max = 0.0, 0.0
    else:
        idx_gap = (
            test_pair_rows["fraud_label_event_idx"] - test_pair_rows["benign_eval_event_idx"]
        ).abs()
        idx_gap_mean, idx_gap_max = float(idx_gap.mean()), float(idx_gap.max())

    if test_pair_counts.empty:
        benign_hit_rate, benign_hit_pairs = 0.0, 0
    else:
        benign_hit = test_pair_counts["benign_motif_hit_count"] > 0
        benign_hit_rate, benign_hit_pairs = float(benign_hit.mean()), int(benign_hit.sum())

    audit = {
        "pair_total_txn_count_diff_mean": count_diff_mean,
        "pair_total_txn_count_diff_max": count_diff_max,
        "auc_total_txn_count": shortcut_auc("total_txn_count"),
        "auc_local_event_idx": shortcut_auc("eval_local_event_idx"),
        "auc_prefix_txn_count": shortcut_auc("prefix_txn_count"),
        "auc_timestamp": shortcut_auc("eval_timestamp"),
        "auc_account_age": shortcut_auc("account_age"),
        "auc_active_age": shortcut_auc("active_age"),
        "fraud_label_event_idx_mean": fraud_idx_mean,
        "fraud_label_event_idx_max": fraud_idx_max,
        "benign_eval_event_idx_mean": benign_idx_mean,
        "benign_eval_event_idx_max": benign_idx_max,
        "pair_event_idx_diff_mean": idx_gap_mean,
        "pair_event_idx_diff_max": idx_gap_max,
        "pair_active_age_diff_mean": age_diff_mean,
        "pair_active_age_diff_max": age_diff_max,
        "pair_timestamp_diff_mean": ts_diff_mean,
        "pair_timestamp_diff_max": ts_diff_max,
        "benign_motif_hit_rate": benign_hit_rate,
        "benign_motif_hit_pairs": benign_hit_pairs,
        "matched_control_examples": int(len(test_examples)),
        "matched_control_pair_events": int(len(test_pair_rows)),
    }

    print("\n--- Matched-Control Shortcut Audit ---")
    headline_keys = (
        "pair_total_txn_count_diff_mean",
        "pair_total_txn_count_diff_max",
        "auc_total_txn_count",
        "auc_local_event_idx",
        "auc_prefix_txn_count",
        "auc_timestamp",
        "auc_account_age",
        "auc_active_age",
        "benign_motif_hit_rate",
        "benign_motif_hit_pairs",
    )
    for key in headline_keys:
        print(f" {key:<30}: {audit[key]}")

    if not test_pair_rows.empty:
        print("\n label_event_idx distribution (fraud twins):")
        print(test_pair_rows["fraud_label_event_idx"].describe().to_string())
        print("\n pseudo-label idx distribution (benign twins):")
        print(test_pair_rows["benign_eval_event_idx"].describe().to_string())
        print("\n per-pair fraud-vs-benign evaluation indices:")
        preview_cols = [
            "twin_pair_id",
            "fraud_label_event_idx",
            "benign_eval_event_idx",
            "active_age_diff",
            "timestamp_diff",
        ]
        print(test_pair_rows[preview_cols].head(20).to_string(index=False))
    return audit
def bootstrap_auc_summary(
    y_true: np.ndarray,
    y_score: np.ndarray,
    seed: int,
    n_bootstrap: int = 200,
) -> dict:
    """Bootstrap the ROC-AUC and summarise its spread.

    Draws ``n_bootstrap`` resamples with replacement using a generator seeded
    with ``seed``; resamples containing a single class are skipped. Returns
    the sample standard deviation, the 2.5%/97.5% quantiles and the number of
    usable resamples. When the input is empty or single-class, or no resample
    is usable, the statistics are NaN and ``n_bootstrap`` is 0.
    """
    labels = np.asarray(y_true, dtype=np.float32)
    scores = np.asarray(y_score, dtype=np.float32)
    undefined = {
        "bootstrap_std": float("nan"),
        "ci_lo": float("nan"),
        "ci_hi": float("nan"),
        "n_bootstrap": 0,
    }
    if len(labels) == 0 or len(np.unique(labels)) < 2:
        return undefined

    rng = np.random.default_rng(seed)
    n_samples = len(labels)
    resampled_aucs: list[float] = []
    for _ in range(n_bootstrap):
        draw = rng.integers(0, n_samples, size=n_samples)
        drawn_labels = labels[draw]
        if len(np.unique(drawn_labels)) >= 2:
            resampled_aucs.append(safe_roc_auc(drawn_labels, scores[draw]))
    if not resampled_aucs:
        return undefined

    auc_values = np.asarray(resampled_aucs, dtype=np.float32)
    spread = float(np.std(auc_values, ddof=1)) if len(auc_values) > 1 else 0.0
    return {
        "bootstrap_std": spread,
        "ci_lo": float(np.quantile(auc_values, 0.025)),
        "ci_hi": float(np.quantile(auc_values, 0.975)),
        "n_bootstrap": int(len(auc_values)),
    }
def make_auc_result(
    y_true: np.ndarray,
    y_score: np.ndarray,
    seed: int,
    extra: dict | None = None,
) -> dict:
    """Bundle an AUC, its inputs, class counts and bootstrap summary in one dict.

    Labels and scores are cast to float32 and stored in the result. Entries
    of ``extra`` (if given and non-empty) are merged last and may therefore
    override earlier keys.
    """
    labels = np.asarray(y_true, dtype=np.float32)
    scores = np.asarray(y_score, dtype=np.float32)
    result = dict(
        auc=safe_roc_auc(labels, scores),
        y_true=labels,
        y_score=scores,
        n_examples=int(len(labels)),
        n_pos=int(np.sum(labels == 1)),
        n_neg=int(np.sum(labels == 0)),
    )
    result.update(bootstrap_auc_summary(labels, scores, seed=seed))
    if extra:
        result.update(extra)
    return result
def attach_auc_result(report: dict, prefix: str, result: dict) -> None:
    """Copy the scalar fields of an AUC result into ``report`` under ``prefix``.

    ``report`` is mutated in place; keys have the form ``f"{prefix}_<name>"``.
    """
    # (report-key suffix, result key, cast) in insertion order
    field_specs = (
        ("roc_auc", "auc", float),
        ("n_examples", "n_examples", int),
        ("n_pos", "n_pos", int),
        ("n_neg", "n_neg", int),
        ("auc_bootstrap_std", "bootstrap_std", float),
        ("auc_ci_lo", "ci_lo", float),
        ("auc_ci_hi", "ci_hi", float),
    )
    for suffix, result_key, cast in field_specs:
        report[f"{prefix}_{suffix}"] = cast(result[result_key])
def _standardize_train_test(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    feature_cols: Sequence[str],
) -> tuple[np.ndarray, np.ndarray]:
    """Z-score the selected columns of both frames using train statistics only.

    Returns ``(x_train, x_test)`` as float32 arrays. The per-column mean and
    standard deviation come from the training rows; 1e-6 is added to the
    standard deviation to avoid division by zero.
    """
    columns = list(feature_cols)
    train_matrix = train_df[columns].to_numpy(dtype=np.float32)
    test_matrix = test_df[columns].to_numpy(dtype=np.float32)
    center = train_matrix.mean(axis=0, keepdims=True)
    scale = train_matrix.std(axis=0, keepdims=True) + 1e-6
    scaled_train = (train_matrix - center) / scale
    scaled_test = (test_matrix - center) / scale
    return scaled_train, scaled_test
def compute_matched_static_aggregate_auc(
    train_features: pd.DataFrame,
    test_features: pd.DataFrame,
    seed: int,
    verbose: bool = True,
) -> dict:
    """Fit a logistic regression on static aggregate features and score the test set.

    Features are standardised with training statistics. Returns the dict
    built by ``make_auc_result`` (``seed`` drives its bootstrap). If either
    frame is empty the result has zero examples; if either label column has a
    single class, all test scores are a constant 0.5. With ``verbose`` the
    five largest absolute coefficients are printed.
    """
    feature_cols = [
        "txn_count",
        "receiver_count",
        "retry_count",
        "failed_count",
        "burst_count",
        "quiet_count",
        "dt_mean",
        "dt_std",
        "amount_mean",
        "amount_std",
        "phase_std",
        "recv_entropy",
    ]
    if train_features.empty or test_features.empty:
        no_labels = np.zeros(0, dtype=np.float32)
        no_scores = np.zeros(0, dtype=np.float32)
        return make_auc_result(no_labels, no_scores, seed=seed)

    single_class = (
        train_features["label"].nunique() < 2 or test_features["label"].nunique() < 2
    )
    if single_class:
        test_labels = test_features["label"].to_numpy(dtype=np.float32)
        chance_scores = np.full(len(test_labels), 0.5, dtype=np.float32)
        return make_auc_result(test_labels, chance_scores, seed=seed)

    x_train, x_test = _standardize_train_test(train_features, test_features, feature_cols)
    classifier = LogisticRegression(
        max_iter=2000,
        class_weight="balanced",
        random_state=42,
        solver="liblinear",
    )
    classifier.fit(x_train, train_features["label"].to_numpy(dtype=np.int32))
    fraud_prob = classifier.predict_proba(x_test)[:, 1]
    test_labels = test_features["label"].to_numpy(dtype=np.float32)

    if verbose:
        weights = np.abs(classifier.coef_[0])
        strongest_first = np.argsort(weights)[::-1]
        print("\n Top matched static aggregate predictors:")
        for col_idx in strongest_first[:5]:
            print(f" {feature_cols[col_idx]:<20}: |coef|={weights[col_idx]:.4f}")

    return make_auc_result(
        test_labels,
        fraud_prob.astype(np.float32),
        seed=seed,
    )
def compute_matched_xgboost_auc(
    train_features: pd.DataFrame,
    test_features: pd.DataFrame,
    seed: int,
    verbose: bool = True,
) -> dict:
    """Fit an XGBoost classifier on matched prefix features and score the test set.

    Args:
        train_features: Feature table with a ``label`` column used for fitting.
        test_features: Feature table with a ``label`` column used for scoring.
        seed: Seed forwarded to ``make_auc_result`` for the bootstrap summary
            (the classifier itself uses a fixed ``random_state``).
        verbose: When True (the default, matching the previous unconditional
            behaviour) print the five most important features. Mirrors the
            ``verbose`` switch of ``compute_matched_static_aggregate_auc``.

    Returns:
        The dict built by ``make_auc_result``. If either frame is empty the
        result has zero examples; if either label set has a single class, all
        test scores are a constant 0.5.
    """
    feature_cols = [
        "txn_count",
        "txn_cnt10_last",
        "amount_mean",
        "amount_std",
        "amount_max",
        "td_mean",
        "td_std",
        "fail_rate",
        "retry_rate",
        "recv_entropy",
        "pair_freq_mean",
    ]
    if train_features.empty or test_features.empty:
        return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed)
    y_train = train_features["label"].to_numpy(dtype=np.int32)
    y_test = test_features["label"].to_numpy(dtype=np.int32)
    if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
        probs = np.full(len(y_test), 0.5, dtype=np.float32)
        return make_auc_result(y_test.astype(np.float32), probs, seed=seed)
    x_train = train_features[feature_cols].to_numpy(dtype=np.float32)
    x_test = test_features[feature_cols].to_numpy(dtype=np.float32)
    # Up-weight positives by the negative/positive ratio, never below 1.
    n_pos = max(float((y_train == 1).sum()), 1.0)
    n_neg = float((y_train == 0).sum())
    scale_pos_weight = max(1.0, n_neg / n_pos)
    model = XGBClassifier(
        n_estimators=200,
        max_depth=6,
        learning_rate=0.05,
        objective="binary:logistic",
        eval_metric="logloss",
        scale_pos_weight=scale_pos_weight,
        random_state=42,
        verbosity=0,
        n_jobs=1,
        tree_method="exact",
    )
    model.fit(x_train, y_train)
    probs = model.predict_proba(x_test)[:, 1]
    if verbose:
        importances = model.feature_importances_
        ranked = np.argsort(importances)[::-1]
        print(" [Matched XGBoost] Top-5 feature importances:")
        for idx in ranked[:5]:
            print(f" {feature_cols[idx]:<20}: {importances[idx]:.4f}")
    return make_auc_result(
        y_test.astype(np.float32),
        probs.astype(np.float32),
        seed=seed,
    )
def _build_example_prefix(
    df_full: pd.DataFrame,
    sender_id: int,
    eval_local_event_idx: int,
    eval_timestamp: float,
) -> pd.DataFrame:
    """Return the events visible when evaluating one sender at one local index.

    Keeps all events with ``timestamp <= eval_timestamp``; for the evaluated
    sender, additionally drops events whose ``local_event_idx`` exceeds
    ``eval_local_event_idx``. ``local_event_idx`` is computed on the
    time-filtered rows when the column is missing. The result is a
    timestamp-sorted copy with a fresh index.
    """
    visible = df_full.loc[df_full["timestamp"] <= eval_timestamp].copy()
    if "local_event_idx" not in visible.columns:
        visible = with_local_event_idx(visible)

    from_sender = visible["sender_id"] == sender_id
    if from_sender.any():
        within_prefix = visible["local_event_idx"] <= eval_local_event_idx
        visible = visible[~from_sender | within_prefix].copy()
    return visible.sort_values("timestamp").reset_index(drop=True)
def build_static_gnn_example_embeddings(
    model: StaticGNNWrapper,
    df_full: pd.DataFrame,
    examples: pd.DataFrame,
) -> tuple[np.ndarray, dict]:
    """Embed each example's sender using only events up to the example's cutoff.

    Events are sorted by timestamp and oracle columns are stripped. Examples
    are grouped by ``eval_timestamp``; cutoffs are processed in increasing
    order so per-sender running means of the normalised edge features can be
    updated incrementally. For each distinct cutoff the encoder runs once on
    the edges with ``timestamp <= cutoff`` and the embeddings of that
    cutoff's senders are read out.

    Returns:
        ``(embeddings, diagnostics)`` where ``embeddings`` is a float32 array
        of shape ``(len(examples), model.hidden_dim)`` in the row order of
        ``examples``, and ``diagnostics`` reports example/cutoff counts, the
        number of encoder passes, the share of examples that reused a pass
        (``cache_hit_rate``) and the elapsed time.

    NOTE(review): relies on ``model._norm_stats``, ``model._n_nodes`` and
    ``model._encoder`` being populated, presumably by ``model.fit`` — confirm.
    """
    if examples.empty:
        return np.zeros((0, model.hidden_dim), dtype=np.float32), {
            "matched_examples": 0,
            "unique_prefix_cutoffs": 0,
            "graph_builds": 0,
            "cache_hit_rate": float("nan"),
            "eval_time_sec": 0.0,
        }
    # Sorting here is what makes the searchsorted/prefix slicing below valid.
    clean_full = strip_oracle_cols(
        df_full.sort_values("timestamp").reset_index(drop=True)
    )
    start = time.perf_counter()
    sender_ids = clean_full["sender_id"].to_numpy(dtype=np.int64)
    receiver_ids = clean_full["receiver_id"].to_numpy(dtype=np.int64)
    timestamps = clean_full["timestamp"].to_numpy(dtype=np.float64)
    edge_feats = build_edge_features(clean_full).astype(np.float32)
    # Normalise edge features with the statistics stored on the model.
    ns = model._norm_stats
    edge_feats = (edge_feats - ns["ea_mean"]) / ns["ea_std"]
    max_sender = int(sender_ids.max()) if len(sender_ids) else 0
    max_receiver = int(receiver_ids.max()) if len(receiver_ids) else 0
    # Cover every node id seen here and at least model._n_nodes rows.
    n_nodes = max(max(max_sender, max_receiver) + 1, model._n_nodes)
    # Running per-sender sums/counts, from which mean node features are derived.
    feat_sum = np.zeros((n_nodes, edge_feats.shape[1]), dtype=np.float32)
    feat_count = np.zeros(n_nodes, dtype=np.float32)
    node_feat = np.zeros((n_nodes, edge_feats.shape[1]), dtype=np.float32)
    device = model.device
    x_t = torch.zeros((n_nodes, edge_feats.shape[1]), dtype=torch.float32, device=device)
    edge_index_full = torch.tensor(
        np.vstack([sender_ids, receiver_ids]),
        dtype=torch.long,
        device=device,
    )
    examples_reset = examples.reset_index(drop=True).copy()
    # Positional row indices of the examples sharing each cutoff timestamp.
    grouped = examples_reset.groupby("eval_timestamp", sort=True).indices
    grouped_items = sorted(
        [(float(ts), idxs) for ts, idxs in grouped.items()],
        key=lambda item: item[0],
    )
    ordered_cutoffs = [item[0] for item in grouped_items]
    # side="right": events with timestamp == cutoff are part of the prefix.
    cutoff_ends = np.searchsorted(timestamps, np.asarray(ordered_cutoffs, dtype=np.float64), side="right")
    out = np.zeros((len(examples_reset), model.hidden_dim), dtype=np.float32)
    prev_end = 0
    graph_builds = 0
    for (cutoff, row_indices), end_idx in zip(grouped_items, cutoff_ends.tolist()):
        if end_idx > prev_end:
            # Fold the events that became visible since the previous cutoff
            # into the running means, touching only the affected senders.
            batch_senders = sender_ids[prev_end:end_idx]
            batch_feats = edge_feats[prev_end:end_idx]
            np.add.at(feat_sum, batch_senders, batch_feats)
            np.add.at(feat_count, batch_senders, 1.0)
            changed_nodes = np.unique(batch_senders)
            node_feat[changed_nodes] = feat_sum[changed_nodes] / feat_count[changed_nodes, None]
            changed_t = torch.tensor(changed_nodes, dtype=torch.long, device=device)
            x_t[changed_t] = torch.tensor(node_feat[changed_nodes], dtype=torch.float32, device=device)
            prev_end = end_idx
        # Edges are in timestamp order, so the first end_idx columns are the prefix graph.
        edge_index = edge_index_full[:, :end_idx]
        model._encoder.eval()
        with torch.no_grad():
            prefix_emb = model._encoder(x_t, edge_index)
        graph_builds += 1
        sender_batch = examples_reset.loc[row_indices, "sender_id"].to_numpy(dtype=np.int64)
        sender_t = torch.tensor(sender_batch, dtype=torch.long, device=device)
        out[row_indices] = prefix_emb[sender_t].detach().cpu().numpy().astype(np.float32)
    matched_examples = int(len(examples_reset))
    unique_cutoffs = int(len(ordered_cutoffs))
    # Examples beyond the first at each cutoff reuse that cutoff's encoder pass.
    hits = max(0, matched_examples - graph_builds)
    diagnostics = {
        "matched_examples": matched_examples,
        "unique_prefix_cutoffs": unique_cutoffs,
        "graph_builds": int(graph_builds),
        "cache_hit_rate": float(hits / matched_examples) if matched_examples > 0 else float("nan"),
        "eval_time_sec": float(time.perf_counter() - start),
    }
    return out.astype(np.float32), diagnostics
def compute_matched_static_gnn_auc(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    train_examples: pd.DataFrame,
    test_examples: pd.DataFrame,
    device: str,
    num_epochs: int,
    seed: int,
) -> dict:
    """Score a StaticGNN encoder on matched prefix examples via a linear probe.

    A ``StaticGNNWrapper`` is fitted on ``df_train`` (after ``strip_oracle_cols``),
    per-example embeddings are built for the train and test examples, the
    embeddings are standardized with train statistics, and a class-balanced
    logistic regression is fitted on the train embeddings.

    Args:
        df_train: Training events.
        df_test: Test events; concatenated with ``df_train`` and sorted by
            ``timestamp`` before test embeddings are built.
        train_examples: Matched examples with a ``label`` column used to fit the probe.
        test_examples: Matched examples with a ``label`` column used for scoring.
        device: Device string forwarded to ``StaticGNNWrapper``.
        num_epochs: Epochs forwarded to ``model.fit``.
        seed: Base seed; the encoder seed is derived from it.

    Returns:
        The dict produced by ``make_auc_result``. On the full path it carries
        extra score statistics and the embedding-build diagnostics. If either
        example table is empty an empty result is returned; if either side has
        fewer than two label classes, a constant 0.5 score is evaluated.
    """
    if train_examples.empty or test_examples.empty:
        return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed)
    if train_examples["label"].nunique() < 2 or test_examples["label"].nunique() < 2:
        y_test = test_examples["label"].to_numpy(dtype=np.float32)
        probs = np.full(len(y_test), 0.5, dtype=np.float32)
        return make_auc_result(y_test, probs, seed=seed)

    static_seed = derived_seed(seed, "StaticGNN", "matched_prefix")
    set_global_determinism(static_seed)
    model = StaticGNNWrapper(hidden_dim=64, n_snapshots=10, device=device)
    model.fit(strip_oracle_cols(df_train), num_epochs=num_epochs)

    eval_start = time.perf_counter()
    train_emb, train_diag = build_static_gnn_example_embeddings(model, df_train, train_examples)
    full_test_df = (
        pd.concat([df_train, df_test], ignore_index=True)
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    test_emb, test_diag = build_static_gnn_example_embeddings(model, full_test_df, test_examples)
    y_train = train_examples["label"].to_numpy(dtype=np.int32)
    y_test = test_examples["label"].to_numpy(dtype=np.int32)

    # Measure the share of all-zero test embeddings on the RAW encoder output.
    # After standardization a zero row becomes -mean/std, so checking the
    # norm afterwards would report ~0 regardless of how many rows were empty.
    zero_emb_frac = float(np.mean(np.linalg.norm(test_emb, axis=1) < 1e-8))

    # Standardize with train-only statistics so no test information leaks in.
    mean = train_emb.mean(axis=0, keepdims=True)
    std = train_emb.std(axis=0, keepdims=True) + 1e-6
    train_emb = (train_emb - mean) / std
    test_emb = (test_emb - mean) / std

    # NOTE(review): random_state is fixed at 42 here while the other probes in
    # this file pass the run seed; left unchanged to keep reported numbers stable.
    clf = LogisticRegression(
        max_iter=2000,
        class_weight="balanced",
        random_state=42,
        solver="liblinear",
    )
    clf.fit(train_emb, y_train)
    probs = clf.predict_proba(test_emb)[:, 1]

    train_hits = max(0, train_diag["matched_examples"] - train_diag["graph_builds"])
    test_hits = max(0, test_diag["matched_examples"] - test_diag["graph_builds"])
    total_matched = train_diag["matched_examples"] + test_diag["matched_examples"]
    return make_auc_result(
        y_test.astype(np.float32),
        probs.astype(np.float32),
        seed=seed,
        extra={
            "auc_flipped": safe_roc_auc(y_test.astype(np.float32), (1.0 - probs).astype(np.float32)),
            "score_mean_pos": float(probs[y_test == 1].mean()) if np.any(y_test == 1) else float("nan"),
            "score_mean_neg": float(probs[y_test == 0].mean()) if np.any(y_test == 0) else float("nan"),
            "score_std": float(np.std(probs)),
            "zero_emb_frac": zero_emb_frac,
            "train_examples": int(len(train_examples)),
            "test_examples": int(len(test_examples)),
            "matched_examples": int(total_matched),
            "unique_prefix_cutoffs": int(train_diag["unique_prefix_cutoffs"] + test_diag["unique_prefix_cutoffs"]),
            "graph_builds": int(train_diag["graph_builds"] + test_diag["graph_builds"]),
            "cache_hit_rate": float((train_hits + test_hits) / max(1, total_matched)),
            "eval_time_sec": float(time.perf_counter() - eval_start),
            "train_unique_prefix_cutoffs": int(train_diag["unique_prefix_cutoffs"]),
            "test_unique_prefix_cutoffs": int(test_diag["unique_prefix_cutoffs"]),
            "train_graph_builds": int(train_diag["graph_builds"]),
            "test_graph_builds": int(test_diag["graph_builds"]),
            "train_eval_time_sec": float(train_diag["eval_time_sec"]),
            "test_eval_time_sec": float(test_diag["eval_time_sec"]),
        },
    )
def compute_matched_seqgru_metrics(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    train_examples: pd.DataFrame,
    test_examples: pd.DataFrame,
    device: str,
    seed: int,
    max_epochs: int,
    hidden_dim: int = 96,
    receiver_buckets: int = 512,
) -> dict:
    """Train and score a SequenceGRU on matched prefix examples, clean vs. shuffled.

    Two independent ``SequenceGRUWrapper`` instances are trained on the same
    oracle-stripped data: one with ``shuffle_within_sequence=False`` ("clean")
    and one with ``shuffle_within_sequence=True`` ("shuffled"), each under its
    own derived seed.

    Returns:
        Dict with ``clean`` / ``shuffled`` result dicts, ``delta``
        (shuffled AUC minus clean AUC) and the two fit-info objects. Empty
        inputs yield an empty result with ``delta`` NaN; single-class labels
        yield a constant-0.5 result with ``delta`` 0.0.
    """
    if train_examples.empty or test_examples.empty:
        empty = make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed)
        return {
            "clean": empty,
            "shuffled": empty,
            "delta": float("nan"),
            "clean_fit": {},
            "shuffled_fit": {},
        }

    clean_train_df = strip_oracle_cols(df_train)
    clean_test_df = strip_oracle_cols(df_test)
    y_test = test_examples["label"].to_numpy(dtype=np.float32)

    single_class = train_examples["label"].nunique() < 2 or test_examples["label"].nunique() < 2
    if single_class:
        flat_probs = np.full(len(y_test), 0.5, dtype=np.float32)
        flat = make_auc_result(y_test, flat_probs, seed=seed)
        flat["pr_auc"] = compute_metrics(y_test, flat_probs)["pr_auc"]
        return {
            "clean": flat,
            "shuffled": flat,
            "delta": 0.0,
            "clean_fit": {},
            "shuffled_fit": {},
        }

    def run_variant(variant_seed: int, shuffle: bool) -> tuple:
        # One complete seed -> build -> fit -> predict -> score pass.
        set_global_determinism(variant_seed)
        wrapper = SequenceGRUWrapper(
            hidden_dim=hidden_dim,
            receiver_buckets=receiver_buckets,
            device=device,
        )
        wrapper.fit(clean_train_df, num_epochs=1)
        fit_info = wrapper.fit_matched_prefix_examples(
            clean_train_df,
            train_examples,
            seed=variant_seed,
            max_epochs=max_epochs,
            patience=6,
            valid_frac=0.20,
            pair_batch_size=64,
            learning_rate=2e-3,
            weight_decay=1e-4,
            shuffle_within_sequence=shuffle,
        )
        variant_probs = wrapper.predict_matched_prefix_examples(
            clean_test_df,
            test_examples,
            seed=variant_seed,
            shuffle_within_sequence=shuffle,
        )
        scores = compute_metrics(y_test, variant_probs)
        summary = make_auc_result(
            y_test,
            variant_probs.astype(np.float32),
            seed=seed,
            extra={
                "pr_auc": float(scores["pr_auc"]),
                "brier": float(scores["brier"]),
                "ece": float(scores["ece"]),
            },
        )
        return summary, fit_info

    clean_seed = derived_seed(seed, "SeqGRU", "clean")
    shuffled_seed = derived_seed(seed, "SeqGRU", "shuffled")

    # The clean variant always runs first, then the shuffled one.
    clean_result, clean_fit = run_variant(clean_seed, False)
    shuffled_result, shuffled_fit = run_variant(shuffled_seed, True)

    return {
        "clean": clean_result,
        "shuffled": shuffled_result,
        "delta": float(shuffled_result["auc"] - clean_result["auc"]),
        "clean_fit": clean_fit,
        "shuffled_fit": shuffled_fit,
    }
def _combine_matched_examples(
    train_examples: pd.DataFrame,
    test_examples: pd.DataFrame,
) -> pd.DataFrame:
    """Stack train and test examples, tagging each row's origin in ``example_split``.

    The inputs are not modified; the result has a fresh 0..n-1 index with all
    train rows first.
    """
    tagged_parts = [
        train_examples.assign(example_split="train"),
        test_examples.assign(example_split="test"),
    ]
    return pd.concat(tagged_parts, ignore_index=True)
def _fit_embedding_probe(
    train_emb: np.ndarray,
    test_emb: np.ndarray,
    y_train: np.ndarray,
    y_test: np.ndarray,
    seed: int,
) -> dict:
    """Fit a class-balanced logistic-regression probe on embeddings and score it.

    Embeddings are standardized with train-set statistics before fitting.

    Returns:
        The ``make_auc_result`` dict, extended with ``pr_auc``, ``brier`` and
        ``ece``. With no labels on either side an empty result is returned;
        with fewer than two classes on either side a constant 0.5 score is
        evaluated instead of fitting.
    """

    def package(metric_labels: np.ndarray, scores: np.ndarray) -> dict:
        # Shared result construction for the constant and fitted paths.
        summary = compute_metrics(metric_labels, scores)
        return make_auc_result(
            y_test.astype(np.float32),
            scores,
            seed=seed,
            extra={
                "pr_auc": float(summary["pr_auc"]),
                "brier": float(summary["brier"]),
                "ece": float(summary["ece"]),
            },
        )

    if len(y_train) == 0 or len(y_test) == 0:
        return make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed)

    if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
        constant = np.full(len(y_test), 0.5, dtype=np.float32)
        return package(y_test, constant)

    center = train_emb.mean(axis=0, keepdims=True)
    scale = train_emb.std(axis=0, keepdims=True) + 1e-6
    train_z = (train_emb - center) / scale
    test_z = (test_emb - center) / scale

    probe = LogisticRegression(
        max_iter=2000,
        class_weight="balanced",
        random_state=seed,
        solver="liblinear",
    )
    probe.fit(train_z, y_train.astype(np.int32))
    fitted = probe.predict_proba(test_z)[:, 1].astype(np.float32)
    return package(y_test.astype(np.float32), fitted)
def compute_matched_temporal_gnn_metrics(
    model_name: str,
    model_builder,
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    train_examples: pd.DataFrame,
    test_examples: pd.DataFrame,
    seed: int,
    num_epochs: int,
) -> dict:
    """Probe a temporal GNN's prefix embeddings on clean vs. shuffled chronology.

    For each variant a fresh model from ``model_builder()`` is fitted on the
    (oracle-stripped) training events, prefix embeddings are extracted for all
    matched examples over train+test events, and ``_fit_embedding_probe``
    scores them. The shuffled variant permutes timestamps of train and test
    separately via ``shuffle_chronology``.

    Args:
        model_name: Label used for seed derivation and echoed in the result.
        model_builder: Zero-argument callable returning a fresh model exposing
            ``fit`` and ``extract_prefix_embeddings``.
        df_train: Training events.
        df_test: Test events.
        train_examples: Matched examples (with ``label``) used to fit the probe.
        test_examples: Matched examples (with ``label``) used to score the probe.
        seed: Base seed.
        num_epochs: Epochs forwarded to ``model.fit``.

    Returns:
        Dict with ``clean``, ``shuffled``, ``delta`` (shuffled AUC minus clean
        AUC), ``train_examples``, ``test_examples`` and ``model_name``. The
        same keys are present when either example table is empty, in which
        case ``delta`` is NaN.
    """
    if train_examples.empty or test_examples.empty:
        empty = make_auc_result(np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32), seed=seed)
        # Keep the schema identical to the normal return so callers can read
        # the bookkeeping keys without special-casing empty inputs.
        return {
            "clean": empty,
            "shuffled": empty,
            "delta": float("nan"),
            "train_examples": int(len(train_examples)),
            "test_examples": int(len(test_examples)),
            "model_name": model_name,
        }

    clean_train = strip_oracle_cols(df_train)
    clean_test = strip_oracle_cols(df_test)
    all_examples = _combine_matched_examples(train_examples, test_examples)
    train_mask = all_examples["example_split"].to_numpy() == "train"
    test_mask = ~train_mask
    y_train = all_examples.loc[train_mask, "label"].to_numpy(dtype=np.float32)
    y_test = all_examples.loc[test_mask, "label"].to_numpy(dtype=np.float32)

    clean_model_seed = derived_seed(seed, model_name, "clean_model")
    shuffled_model_seed = derived_seed(seed, model_name, "shuffled_model")

    # --- clean chronology ---
    set_global_determinism(clean_model_seed)
    clean_model = model_builder()
    clean_model.fit(clean_train, num_epochs=num_epochs)
    clean_full = (
        pd.concat([clean_train, clean_test], ignore_index=True)
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    clean_emb = clean_model.extract_prefix_embeddings(clean_full, all_examples)
    clean_result = _fit_embedding_probe(
        clean_emb[train_mask],
        clean_emb[test_mask],
        y_train,
        y_test,
        seed=seed,
    )

    # --- shuffled chronology (train and test permuted independently) ---
    shuffled_train = shuffle_chronology(clean_train, seed=seed + 101)
    shuffled_test = shuffle_chronology(clean_test, seed=seed + 211)
    set_global_determinism(shuffled_model_seed)
    shuffled_model = model_builder()
    shuffled_model.fit(shuffled_train, num_epochs=num_epochs)
    shuffled_full = (
        pd.concat([shuffled_train, shuffled_test], ignore_index=True)
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    shuffled_emb = shuffled_model.extract_prefix_embeddings(shuffled_full, all_examples)
    shuffled_result = _fit_embedding_probe(
        shuffled_emb[train_mask],
        shuffled_emb[test_mask],
        y_train,
        y_test,
        seed=seed,
    )

    return {
        "clean": clean_result,
        "shuffled": shuffled_result,
        "delta": float(shuffled_result["auc"] - clean_result["auc"]),
        "train_examples": int(train_mask.sum()),
        "test_examples": int(test_mask.sum()),
        "model_name": model_name,
    }
def ks_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return the two-sample Kolmogorov-Smirnov statistic between ``x`` and ``y``.

    This is the largest absolute gap between the two empirical CDFs, evaluated
    at every observed value. Returns 0.0 if either sample is empty.
    """
    sample_a = np.sort(np.asarray(x, dtype=np.float64))
    sample_b = np.sort(np.asarray(y, dtype=np.float64))
    if min(len(sample_a), len(sample_b)) == 0:
        return 0.0

    grid = np.sort(np.concatenate([sample_a, sample_b]))

    def ecdf(sample: np.ndarray) -> np.ndarray:
        # Fraction of the sample that is <= each grid point.
        return np.searchsorted(sample, grid, side="right") / len(sample)

    gaps = np.abs(ecdf(sample_a) - ecdf(sample_b))
    return float(np.max(gaps))
def compute_static_aggregate_auc(node_df: pd.DataFrame, seed: int, verbose: bool = True) -> float:
    """AUC of a logistic regression on static per-node aggregates of twin nodes.

    Only rows with ``twin_pair_id >= 0`` are used. Twin pairs are shuffled with
    ``seed`` and split roughly 70/30 by pair id, so both members of a pair land
    on the same side. Features are standardized with train statistics.

    Returns:
        The test AUC from ``safe_roc_auc``, or 0.5 when there are no twin rows,
        fewer than four pairs, or a single label class on either side. When
        ``verbose`` is true the five largest absolute coefficients are printed.
    """
    feature_cols = [
        "txn_count",
        "receiver_count",
        "retry_count",
        "failed_count",
        "burst_count",
        "quiet_count",
        "dt_mean",
        "dt_std",
        "amount_mean",
        "amount_std",
        "phase_std",
        "recv_entropy",
    ]
    chance_level = 0.5

    twins = node_df[node_df["twin_pair_id"] >= 0].copy()
    if twins.empty or twins["label"].nunique() < 2:
        return chance_level

    unique_pairs = twins["twin_pair_id"].unique()
    if len(unique_pairs) < 4:
        return chance_level

    shuffled_pairs = np.random.default_rng(seed).permutation(unique_pairs)
    n_train_pairs = max(1, int(0.7 * len(shuffled_pairs)))
    train_ids = set(shuffled_pairs[:n_train_pairs])
    test_ids = set(shuffled_pairs[n_train_pairs:])
    if not test_ids:
        # Fallback: hold out the last pair if the split left nothing to test on.
        test_ids = set(shuffled_pairs[-1:])
        train_ids = set(shuffled_pairs[:-1])

    train_df = twins[twins["twin_pair_id"].isin(train_ids)]
    test_df = twins[twins["twin_pair_id"].isin(test_ids)]
    if train_df["label"].nunique() < 2 or test_df["label"].nunique() < 2:
        return chance_level

    raw_train = train_df[feature_cols].to_numpy(dtype=np.float32)
    raw_test = test_df[feature_cols].to_numpy(dtype=np.float32)
    center = raw_train.mean(axis=0, keepdims=True)
    scale = raw_train.std(axis=0, keepdims=True) + 1e-6
    x_train = (raw_train - center) / scale
    x_test = (raw_test - center) / scale

    clf = LogisticRegression(
        max_iter=2000,
        class_weight="balanced",
        random_state=seed,
        solver="liblinear",
    )
    clf.fit(x_train, train_df["label"].to_numpy(dtype=np.int32))
    test_scores = clf.predict_proba(x_test)[:, 1]
    auc = safe_roc_auc(test_df["label"].to_numpy(dtype=np.float32), test_scores.astype(np.float32))

    if not verbose:
        return auc

    # Rank features by absolute coefficient, largest first.
    coefs = np.abs(clf.coef_[0])
    print("\n Top static aggregate predictors:")
    for rank_i in np.argsort(coefs)[::-1][:5]:
        print(f" {feature_cols[rank_i]:<20}: |coef|={coefs[rank_i]:.4f}")
    return auc
def compute_aggregate_ks(node_df: pd.DataFrame) -> tuple[float, float]:
    """Mean and max KS distance between fraud and benign twin nodes.

    Compares the two groups (rows with ``twin_pair_id >= 0`` and ``label`` 1
    vs. 0) on a fixed list of aggregate columns using ``ks_distance``.

    Returns:
        ``(mean_ks, max_ks)``; ``(0.0, 0.0)`` if either group is empty.
    """
    is_twin = node_df["twin_pair_id"] >= 0
    fraud_df = node_df[is_twin & (node_df["label"] == 1)]
    benign_df = node_df[is_twin & (node_df["label"] == 0)]
    if fraud_df.empty or benign_df.empty:
        return 0.0, 0.0

    feature_cols = (
        "txn_count",
        "receiver_count",
        "retry_count",
        "burst_count",
        "dt_mean",
        "dt_std",
        "recv_entropy",
    )
    distances = []
    for col in feature_cols:
        distances.append(ks_distance(fraud_df[col].to_numpy(), benign_df[col].to_numpy()))
    if len(distances) == 0:
        return 0.0, 0.0
    return float(np.mean(distances)), float(np.max(distances))
def evaluate_matched_pair_separability(
    model: TemporalModel,
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    delta_time: float,
    n_checkpoints: int,
) -> tuple[float, int]:
    """Fraction of twin pairs in which the fraud twin outscores the benign twin.

    The model is scored once, on the prefix of train+test events up to the
    last checkpoint from ``make_checkpoints``. Only senders active in
    ``df_test`` up to that cutoff are scored.

    Args:
        model: Fitted model exposing ``name``, ``is_temporal``, ``predict``
            and, if temporal, ``reset_memory``.
        df_train: Training events.
        df_test: Test events; must contain ``twin_pair_id`` and ``twin_label``.
        delta_time: Horizon forwarded to ``make_checkpoints``.
        n_checkpoints: Number of checkpoints forwarded to ``make_checkpoints``.

    Returns:
        ``(win_rate, n_pairs)``. A pair counts as a win only when the fraud
        score is strictly greater, so ties count as losses. Returns
        ``(0.0, 0)`` when the twin columns are missing, no checkpoint or active
        node exists, or no complete pair is scored.
    """
    if "twin_pair_id" not in df_test.columns or "twin_label" not in df_test.columns:
        return 0.0, 0
    checkpoints = make_checkpoints(df_test, delta_time, n_checkpoints=n_checkpoints)
    if not checkpoints:
        return 0.0, 0

    cutoff_time = checkpoints[-1]
    df_full = (
        pd.concat([df_train, df_test], ignore_index=True)
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    prefix_df = df_full[df_full["timestamp"] <= cutoff_time].copy()
    active_nodes = sorted(df_test[df_test["timestamp"] <= cutoff_time]["sender_id"].unique())
    if not active_nodes:
        return 0.0, 0

    if model.is_temporal:
        model.reset_memory()
    # Same guard as collect_prefix_predictions: only oracle models may see the
    # oracle-only columns; every other model gets the stripped prefix.
    eval_df = prefix_df if model.name in _ORACLE_MODEL_NAMES else strip_oracle_cols(prefix_df)
    probs = model.predict(eval_df, active_nodes)
    score_map = {int(node_id): float(prob) for node_id, prob in zip(active_nodes, probs)}

    # Twin metadata is read from the unstripped frame (first row per sender).
    meta = (
        df_full.groupby("sender_id")[["twin_pair_id", "twin_label"]]
        .first()
        .reset_index()
    )
    meta = meta[(meta["sender_id"].isin(active_nodes)) & (meta["twin_pair_id"] >= 0)]

    pair_scores = []
    for _, pair_df in meta.groupby("twin_pair_id"):
        # Keep only complete pairs: exactly one fraud and one benign member.
        if len(pair_df) != 2 or set(pair_df["twin_label"]) != {0, 1}:
            continue
        fraud_node = int(pair_df.loc[pair_df["twin_label"] == 1, "sender_id"].iloc[0])
        benign_node = int(pair_df.loc[pair_df["twin_label"] == 0, "sender_id"].iloc[0])
        if fraud_node not in score_map or benign_node not in score_map:
            continue
        pair_scores.append(float(score_map[fraud_node] > score_map[benign_node]))

    if not pair_scores:
        return 0.0, 0
    return float(np.mean(pair_scores)), int(len(pair_scores))
def compute_split_leakage(df_train: pd.DataFrame, df_test: pd.DataFrame) -> dict:
    """Count identifiers shared between the train and test event tables.

    Returns:
        Dict with the number of overlapping sender ids, twin pair ids,
        template ids (the latter two only counting ids ``>= 0`` and only when
        the column exists in both frames, else 0) and distinct
        ``(sender_id, receiver_id)`` tuples.
    """

    def in_both(column: str) -> bool:
        return column in df_train.columns and column in df_test.columns

    def valid_ids(frame: pd.DataFrame, column: str) -> set:
        # Negative ids mark "not assigned" and are excluded.
        return set(frame.loc[frame[column] >= 0, column].unique().tolist())

    def edge_set(frame: pd.DataFrame) -> set:
        return set(zip(frame["sender_id"].tolist(), frame["receiver_id"].tolist()))

    shared_senders = set(df_train["sender_id"].unique().tolist()) & set(df_test["sender_id"].unique().tolist())

    shared_pairs = 0
    if in_both("twin_pair_id"):
        shared_pairs = len(valid_ids(df_train, "twin_pair_id") & valid_ids(df_test, "twin_pair_id"))

    shared_templates = 0
    if in_both("template_id"):
        shared_templates = len(valid_ids(df_train, "template_id") & valid_ids(df_test, "template_id"))

    # Receiver-pair overlap: distinct (sender_id, receiver_id) tuples
    shared_edges = edge_set(df_train) & edge_set(df_test)

    return {
        "sender_overlap_count": int(len(shared_senders)),
        "pair_overlap_count": int(shared_pairs),
        "template_overlap_count": int(shared_templates),
        "receiver_pair_overlap_count": int(len(shared_edges)),
    }
# ---------------------------------------------------------------------------
# Prefix-only evaluation guard
# ---------------------------------------------------------------------------
def assert_prefix_only(df_prefix: pd.DataFrame, cutoff_time: float) -> None:
    """Print a warning when the prefix contains events later than the cutoff.

    Despite the name this never raises; it only prints. A slack of 1.0 is
    allowed on top of ``cutoff_time`` to absorb float32-vs-float64 precision
    gaps. Empty prefixes are accepted silently.
    """
    if df_prefix.empty:
        return

    tolerance = 1.0
    actual_max = float(df_prefix["timestamp"].max())
    within_cutoff = actual_max <= cutoff_time + tolerance
    if within_cutoff:
        return
    print(
        f"[PREFIX LEAK] df_prefix max timestamp {actual_max:.2f} > cutoff {cutoff_time:.2f}!"
    )
# ---------------------------------------------------------------------------
# Label-source audit
# ---------------------------------------------------------------------------
def build_label_source_audit_table(df: pd.DataFrame) -> pd.DataFrame:
    """Return one row per positive (``is_fraud == 1``) event with audit columns.

    Only those of the following columns that exist in ``df`` are kept, in this
    order: sender_id, twin_pair_id, twin_role, fraud_source, motif_source,
    motif_hit_count, trigger_event_idx, label_event_idx, label_delay,
    is_fallback_label. An empty DataFrame is returned if there are no
    positive events.
    """
    positives = df[df["is_fraud"] == 1]
    if positives.empty:
        return pd.DataFrame()

    wanted = (
        "sender_id",
        "twin_pair_id",
        "twin_role",
        "fraud_source",
        "motif_source",
        "motif_hit_count",
        "trigger_event_idx",
        "label_event_idx",
        "label_delay",
        "is_fallback_label",
    )
    present = [name for name in wanted if name in positives.columns]
    return positives[present].reset_index(drop=True)
def compute_motif_label_consistency(df: pd.DataFrame, calib_mode: bool = False) -> dict:
    """Compute and print node-level agreement between motif hits and fraud labels.

    Events are restricted to twin users (``twin_pair_id >= 0``) when that
    column exists, then aggregated per ``sender_id`` by taking the max of
    ``is_fraud`` and ``motif_hit_count``.

    Returns:
        Dict of conditional rates and average hit counts (NaN where the
        conditioning set is empty). Empty dict if the required columns are
        missing or no twin rows exist. In ``calib_mode``, when
        ``is_fallback_label`` exists, ``fallback_positives`` is added and a
        violation message is printed if it is non-zero.
    """
    if "motif_hit_count" not in df.columns or "is_fraud" not in df.columns:
        return {}

    # Restrict to twin users only
    twin_df = df[df["twin_pair_id"] >= 0].copy() if "twin_pair_id" in df.columns else df.copy()
    if twin_df.empty:
        return {}

    # Node-level aggregation
    per_node = twin_df.groupby("sender_id")
    node_label = per_node["is_fraud"].max()
    node_hit = per_node["motif_hit_count"].max()
    node_role = per_node["twin_role"].first() if "twin_role" in twin_df.columns else None

    has_hit = node_hit >= 1
    label_pos = node_label == 1
    nan = float("nan")

    def mean_where(values, mask) -> float:
        # Mean of values over mask, NaN when the mask selects nothing.
        return float(values[mask].mean()) if mask.any() else nan

    if node_role is None:
        accidental_motif_rate = avg_hits_fraud = avg_hits_benign = nan
    else:
        benign_mask = node_role == "benign"
        accidental_motif_rate = mean_where(has_hit, benign_mask)
        avg_hits_fraud = mean_where(node_hit, ~benign_mask & label_pos)
        avg_hits_benign = mean_where(node_hit, benign_mask)

    result = {
        "p_label_given_hit": mean_where(label_pos, has_hit),
        "p_label_given_nohit": mean_where(label_pos, ~has_hit),
        "p_hit_given_label": mean_where(has_hit, label_pos),
        "accidental_benign_motif": accidental_motif_rate,
        "avg_hits_fraud_twin": avg_hits_fraud,
        "avg_hits_benign_twin": avg_hits_benign,
    }

    print("\n--- Motif-Label Consistency ---")
    for k, v in result.items():
        if isinstance(v, float) and np.isnan(v):
            print(f" {k:<30}: N/A")
        else:
            print(f" {k:<30}: {v:.4f}")

    # In calib mode, verify no fallback positives exist
    if calib_mode and "is_fallback_label" in df.columns:
        fallback_pos = int(df.loc[df["is_fraud"] == 1, "is_fallback_label"].sum())
        print(f" {'fallback_positives':<30}: {fallback_pos}")
        if fallback_pos > 0:
            print(" [CALIB VIOLATION] Fallback positives found! is_fallback_label.sum() must be 0.")
        result["fallback_positives"] = fallback_pos
    return result
def compute_label_delay_stats(df: pd.DataFrame) -> dict:
    """Print and return min/mean/max ``label_delay`` over positive events.

    Only rows with ``is_fraud == 1`` and a non-negative ``label_delay`` count.

    Returns:
        ``{}`` when the ``label_delay`` column is absent; otherwise a dict with
        ``delay_min``, ``delay_mean`` and ``delay_max`` (all NaN when no row
        qualifies).
    """
    if "label_delay" not in df.columns:
        return {}

    qualifies = (df["is_fraud"] == 1) & (df["label_delay"] >= 0)
    delays = df.loc[qualifies, "label_delay"]
    if delays.empty:
        print(" label_delay: no valid delay data.")
        nan = float("nan")
        return {"delay_min": nan, "delay_mean": nan, "delay_max": nan}

    result = {
        "delay_min": float(delays.min()),
        "delay_mean": float(delays.mean()),
        "delay_max": float(delays.max()),
    }
    print(f" label_delay min={result['delay_min']:.1f} mean={result['delay_mean']:.1f} max={result['delay_max']:.1f}")
    return result
# ---------------------------------------------------------------------------
# Prefix-task helpers
# ---------------------------------------------------------------------------
def uses_twin_pairs(df: pd.DataFrame) -> bool:
    """Return True if ``df`` has a ``twin_pair_id`` column with any value >= 0."""
    if "twin_pair_id" not in df.columns:
        return False
    has_assigned_pair = (df["twin_pair_id"] >= 0).any()
    return bool(has_assigned_pair)
def get_eval_nodes(df: pd.DataFrame) -> List[int]:
    """Return the sorted sender ids to evaluate.

    When the frame uses twin pairs, only senders of rows with
    ``twin_pair_id >= 0`` are returned; otherwise every sender is returned.
    """
    candidates = df[df["twin_pair_id"] >= 0] if uses_twin_pairs(df) else df
    return sorted(candidates["sender_id"].unique().tolist())
def remap_node_ids(*dfs: pd.DataFrame) -> list[pd.DataFrame]:
    """Remap ``sender_id`` / ``receiver_id`` to a shared dense 0..N-1 id space.

    The mapping is built from the sorted union of all ids appearing in the
    non-empty frames, so the same original id maps to the same new id in every
    returned frame.

    Args:
        *dfs: Event frames with ``sender_id`` and ``receiver_id`` columns.
            ``None`` entries are tolerated and passed through unchanged.

    Returns:
        A list aligned with ``dfs`` holding remapped copies (inputs are not
        modified). If every input is ``None`` or empty, plain copies are
        returned without remapping.
    """
    non_empty = [df for df in dfs if df is not None and not df.empty]
    if not non_empty:
        # None has no .copy(); pass it through as the main loop below does.
        return [None if df is None else df.copy() for df in dfs]

    all_ids = np.unique(
        np.concatenate(
            [
                df[column].to_numpy(dtype=np.int64)
                for df in non_empty
                for column in ("sender_id", "receiver_id")
            ]
        )
    )
    id_map = {int(node_id): idx for idx, node_id in enumerate(all_ids.tolist())}

    remapped = []
    for df in dfs:
        if df is None:
            remapped.append(df)
            continue
        out = df.copy()
        out["sender_id"] = out["sender_id"].map(id_map).astype(np.int64)
        out["receiver_id"] = out["receiver_id"].map(id_map).astype(np.int64)
        remapped.append(out)
    return remapped
def augment_with_placeholder_nodes(df_train: pd.DataFrame, df_test: pd.DataFrame) -> pd.DataFrame:
    """Prepend one synthetic self-loop event per node that only appears in test.

    Each placeholder row uses the unseen node as both sender and receiver and
    carries a timestamp earlier than every real event in either frame. Other
    columns get neutral fill values. The result is sorted by ``timestamp``.

    Returns:
        ``df_train`` itself when test introduces no new node; otherwise a new
        frame with the placeholder rows added.
    """

    def node_set(frame: pd.DataFrame) -> set:
        ids = np.concatenate(
            [
                frame["sender_id"].to_numpy(dtype=np.int64),
                frame["receiver_id"].to_numpy(dtype=np.int64),
            ]
        )
        return set(ids.tolist())

    unseen_nodes = sorted(node_set(df_test) - node_set(df_train))
    if not unseen_nodes:
        return df_train

    # Start strictly before the earliest real event and step further back per node.
    base_time = float(min(df_train["timestamp"].min(), df_test["timestamp"].min())) - 1.0

    id_cols = {"sender_id", "receiver_id"}
    text_cols = {"fraud_type", "twin_role"}
    unset_cols = {"txn_id", "twin_pair_id", "template_id"}
    flag_cols = {"twin_label", "is_fraud", "is_retry", "failed"}

    def fill_value(col, node_id, stamp):
        if col in id_cols:
            return int(node_id)
        if col == "timestamp":
            return stamp
        if col in text_cols:
            return "placeholder"
        if col in unset_cols:
            return -1
        if col in flag_cols:
            return 0
        return 0.0

    rows = [
        {col: fill_value(col, node_id, base_time - offset) for col in df_train.columns}
        for offset, node_id in enumerate(unseen_nodes)
    ]
    placeholder_df = pd.DataFrame(rows, columns=df_train.columns)
    combined = pd.concat([placeholder_df, df_train], ignore_index=True)
    return combined.sort_values("timestamp").reset_index(drop=True)
def split_temporally(df: pd.DataFrame, train_ratio: float = 0.7) -> tuple[pd.DataFrame, pd.DataFrame, float]:
    """Split events into train and test frames.

    With at least two twin pairs, the split is by pair: pairs are ordered by
    their first timestamp, the earliest ``train_ratio`` share goes to train
    together with all non-twin rows, and the remaining pairs form the test
    set. Otherwise events are split at the ``train_ratio`` timestamp quantile.

    Returns:
        ``(df_train, df_test, split_time)``. In the pair-based case
        ``split_time`` is the earliest test timestamp.
    """
    ordered = df.sort_values("timestamp").reset_index(drop=True)

    if uses_twin_pairs(ordered):
        first_seen = (
            ordered[ordered["twin_pair_id"] >= 0]
            .groupby("twin_pair_id")["timestamp"]
            .min()
            .sort_values()
        )
        n_pairs = len(first_seen)
        if n_pairs >= 2:
            # Clamp so both sides receive at least one pair.
            cut = max(1, min(n_pairs - 1, int(train_ratio * n_pairs)))
            train_ids = set(first_seen.index[:cut].tolist())
            test_ids = set(first_seen.index[cut:].tolist())

            pair_col = ordered["twin_pair_id"]
            train_part = ordered[(pair_col < 0) | (pair_col.isin(train_ids))]
            test_part = ordered[pair_col.isin(test_ids)]

            if test_part.empty:
                split_time = float(ordered["timestamp"].quantile(train_ratio))
            else:
                split_time = float(test_part["timestamp"].min())
            return (
                train_part.sort_values("timestamp").reset_index(drop=True),
                test_part.sort_values("timestamp").reset_index(drop=True),
                split_time,
            )

    split_time = float(ordered["timestamp"].quantile(train_ratio))
    early = ordered["timestamp"] <= split_time
    df_train = ordered[early].copy()
    df_test = ordered[ordered["timestamp"] > split_time].copy()
    return df_train, df_test, split_time
def horizon_to_delta(df_test: pd.DataFrame, horizon: float) -> float:
    """Convert a fractional horizon into an absolute time delta.

    The delta is ``horizon`` times the timestamp span of ``df_test``; both the
    span and the result are floored at 1e-6. An empty frame yields 1e-6.
    """
    floor = 1e-6
    if df_test.empty:
        return floor
    stamps = df_test["timestamp"]
    span = float(stamps.max()) - float(stamps.min())
    return max(floor, horizon * max(span, floor))
def build_window_labels(
    df: pd.DataFrame,
    cutoff_time: float,
    eval_nodes: Sequence[int],
    delta_time: float,
) -> np.ndarray:
    """Label each node by whether it sends a fraud event in the future window.

    The window is ``(cutoff_time, cutoff_time + delta_time]``. For every node
    in ``eval_nodes`` the label is the max ``is_fraud`` among its events in
    that window, or 0 if it has none. Returned as a float32 array aligned
    with ``eval_nodes``.
    """
    window_end = cutoff_time + delta_time
    in_window = (df["timestamp"] > cutoff_time) & (df["timestamp"] <= window_end)
    per_sender_max = df[in_window].groupby("sender_id")["is_fraud"].max()

    labels = []
    for node_id in eval_nodes:
        labels.append(int(per_sender_max.get(node_id, 0)))
    return np.array(labels, dtype=np.float32)
def build_window_state(
    df: pd.DataFrame,
    cutoff_time: float,
    eval_nodes: Sequence[int],
    delta_time: float,
) -> np.ndarray:
    """Average per-node fraud state over the future window.

    The window is ``(cutoff_time, cutoff_time + delta_time]``. The averaged
    column is ``dynamic_fraud_state`` when present, otherwise ``is_fraud``.
    Nodes without events in the window get 0.0. Returned as a float32 array
    aligned with ``eval_nodes``.
    """
    window_end = cutoff_time + delta_time
    in_window = (df["timestamp"] > cutoff_time) & (df["timestamp"] <= window_end)
    future = df[in_window]

    state_col = "dynamic_fraud_state" if "dynamic_fraud_state" in future.columns else "is_fraud"
    per_sender_mean = future.groupby("sender_id")[state_col].mean()

    states = []
    for node_id in eval_nodes:
        states.append(float(per_sender_mean.get(node_id, 0.0)))
    return np.array(states, dtype=np.float32)
def choose_anchor_time(df_train: pd.DataFrame, delta_time: float) -> float:
    """Pick a training anchor time whose future window has both label classes.

    Candidate anchors are timestamp quantiles from 0.80 down to 0.55, each
    capped so that a full ``delta_time`` window fits before the last training
    event. The first candidate whose window labels contain at least two
    distinct values wins. If none qualifies, the (capped) 0.80 quantile is
    returned. If no window fits at all, the earliest timestamp is returned.
    """
    stamps = df_train["timestamp"]
    t_min = float(stamps.min())
    max_anchor = float(stamps.max()) - delta_time
    if max_anchor <= t_min:
        return t_min

    for quantile in (0.80, 0.75, 0.70, 0.65, 0.60, 0.55):
        anchor_time = min(float(df_train["timestamp"].quantile(quantile)), max_anchor)
        prefix_nodes = get_eval_nodes(df_train[df_train["timestamp"] <= anchor_time])
        if prefix_nodes:
            y_anchor = build_window_labels(df_train, anchor_time, prefix_nodes, delta_time)
            if len(np.unique(y_anchor)) >= 2:
                return anchor_time

    fallback = float(df_train["timestamp"].quantile(0.80))
    return min(fallback, max_anchor)
def make_checkpoints(df_test: pd.DataFrame, delta_time: float, n_checkpoints: int) -> List[float]:
    """Choose up to ``n_checkpoints`` evaluation cutoffs from test timestamps.

    Only timestamps that leave a full ``delta_time`` window before the last
    test event are eligible. Cutoffs are picked at evenly spaced positions in
    the sorted eligible timestamps and returned sorted and de-duplicated.
    Returns an empty list when the frame is empty or nothing is eligible.
    """
    if df_test.empty:
        return []

    last_stamp = float(df_test["timestamp"].max())
    eligible_rows = df_test[df_test["timestamp"] <= last_stamp - delta_time].sort_values("timestamp")
    if eligible_rows.empty:
        return []

    stamps = eligible_rows["timestamp"].to_numpy(dtype=np.float64)
    n_eligible = len(stamps)
    positions = np.linspace(0, n_eligible - 1, num=min(n_checkpoints, n_eligible), dtype=int)
    picked = {float(stamps[pos]) for pos in np.unique(positions)}
    return sorted(picked)
def train_node_head(
    model: TemporalModel,
    df_anchor_prefix: pd.DataFrame,
    eval_nodes: List[int],
    y_labels: np.ndarray,
    num_epochs: int = 150,
) -> None:
    """Fit the model's node-level classifier on the anchor prefix.

    Models exposing ``train_node_classifier_on_prefix`` are handed the prefix
    directly. Otherwise temporal models have their memory reset, the model is
    run once over the prefix via ``predict`` (when both the prefix and the node
    list are non-empty), and ``train_node_classifier`` is then called.

    Raises:
        ValueError: If the model exposes neither training method.
        AssertionError: If a ``TGNWrapper`` reports its node head as unfitted
            after training.
    """
    if hasattr(model, "train_node_classifier_on_prefix"):
        model.train_node_classifier_on_prefix(
            df_anchor_prefix, eval_nodes, y_labels, num_epochs=num_epochs
        )
        return

    if model.is_temporal:
        model.reset_memory()
    if len(df_anchor_prefix) > 0 and len(eval_nodes) > 0:
        # Result is discarded; the call runs before the node head is trained.
        model.predict(df_anchor_prefix, eval_nodes)

    if not hasattr(model, "train_node_classifier"):
        raise ValueError(f"Model {model.name} does not expose node-head training.")

    model.train_node_classifier(eval_nodes, y_labels, num_epochs=num_epochs)
    if isinstance(model, TGNWrapper):
        assert model._node_head_fitted, "TGN node classifier was not fitted."
def fit_model_for_horizon(
    model: TemporalModel,
    df_train: pd.DataFrame,
    delta_time: float,
    num_epochs: int,
    node_epochs: int,
) -> dict:
    """Fit a model and its node head for one prediction horizon.

    The model is fitted on the training events, an anchor time is chosen with
    ``choose_anchor_time``, and the node head is trained on the prefix up to
    the anchor with labels taken from the following ``delta_time`` window.

    Returns:
        Dict with ``anchor_time``, ``anchor_nodes`` (count) and
        ``anchor_fraud_rate`` (0.0 when there are no anchor nodes).
    """
    # Strip oracle columns from all non-oracle models
    if model.name in _ORACLE_MODEL_NAMES:
        train_input = df_train
    else:
        train_input = strip_oracle_cols(df_train)
    model.fit(train_input, num_epochs=num_epochs)

    anchor_time = choose_anchor_time(train_input, delta_time)
    before_anchor = train_input["timestamp"] <= anchor_time
    df_anchor_prefix = train_input[before_anchor].copy()
    assert_prefix_only(df_anchor_prefix, anchor_time)

    anchor_nodes = get_eval_nodes(df_anchor_prefix)
    y_anchor = build_window_labels(train_input, anchor_time, anchor_nodes, delta_time)
    train_node_head(
        model,
        df_anchor_prefix=df_anchor_prefix,
        eval_nodes=anchor_nodes,
        y_labels=y_anchor,
        num_epochs=node_epochs,
    )

    fraud_rate = float(y_anchor.mean()) if len(y_anchor) else 0.0
    return {
        "anchor_time": anchor_time,
        "anchor_nodes": len(anchor_nodes),
        "anchor_fraud_rate": fraud_rate,
    }
def collect_prefix_predictions(
    model: TemporalModel,
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    delta_time: float,
    n_checkpoints: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run strict prefix evaluation over all checkpoints and pool the results.

    At each checkpoint the model only receives train+test events with
    ``timestamp <= cutoff`` (oracle columns stripped for non-oracle models),
    and temporal models start from reset memory. Labels and states come from
    the ``delta_time`` window after the cutoff.

    Returns:
        ``(y_true, probs, states)`` as concatenated float32 arrays; three
        empty arrays if there are no checkpoints or no active nodes.
    """

    def nothing() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        return (
            np.zeros(0, dtype=np.float32),
            np.zeros(0, dtype=np.float32),
            np.zeros(0, dtype=np.float32),
        )

    checkpoints = make_checkpoints(df_test, delta_time, n_checkpoints=n_checkpoints)
    if not checkpoints:
        return nothing()

    df_full = (
        pd.concat([df_train, df_test], ignore_index=True)
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    label_parts: List[np.ndarray] = []
    prob_parts: List[np.ndarray] = []
    state_parts: List[np.ndarray] = []
    is_oracle = model.name in _ORACLE_MODEL_NAMES

    for cutoff_time in checkpoints:
        # Only nodes already seen in the test split by the cutoff are scored.
        active_nodes = get_eval_nodes(df_test[df_test["timestamp"] <= cutoff_time])
        if not active_nodes:
            continue

        prefix_df = df_full[df_full["timestamp"] <= cutoff_time].copy()
        assert_prefix_only(prefix_df, cutoff_time)
        eval_df = prefix_df if is_oracle else strip_oracle_cols(prefix_df)

        if model.is_temporal:
            model.reset_memory()
        cutoff_probs = model.predict(eval_df, active_nodes)

        label_parts.append(build_window_labels(df_full, cutoff_time, active_nodes, delta_time))
        state_for_cutoff = build_window_state(df_full, cutoff_time, active_nodes, delta_time)
        prob_parts.append(np.asarray(cutoff_probs, dtype=np.float32))
        state_parts.append(state_for_cutoff)

    if not label_parts:
        return nothing()
    return (
        np.concatenate(label_parts).astype(np.float32),
        np.concatenate(prob_parts).astype(np.float32),
        np.concatenate(state_parts).astype(np.float32),
    )
def evaluate_model(
    model: TemporalModel,
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    delta_time: float,
    n_checkpoints: int,
) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray]:
    """Collect prefix predictions for a model and compute its metrics.

    Returns:
        ``(metrics, y_true, probs, states)``. ``metrics`` also carries
        ``n_predictions``. When no predictions were collected, metrics are
        computed on a single dummy example (label 0.0, score 0.5).
    """
    y_true, probs, states = collect_prefix_predictions(
        model=model,
        df_train=df_train,
        df_test=df_test,
        delta_time=delta_time,
        n_checkpoints=n_checkpoints,
    )
    if len(y_true):
        metrics = compute_metrics(y_true, probs)
    else:
        metrics = compute_metrics(np.array([0.0]), np.array([0.5]))
    metrics["n_predictions"] = int(len(y_true))
    return metrics, y_true, probs, states
def shuffle_chronology(df: pd.DataFrame, seed: int) -> pd.DataFrame:
    """Break temporal order while preserving the event table.

    The ``timestamp`` values are randomly permuted across rows (seeded by
    ``seed``); all other columns stay attached to their rows. The result is a
    new frame sorted by the permuted timestamps; ``df`` is left unchanged.
    """
    generator = np.random.default_rng(seed)
    original_stamps = df["timestamp"].to_numpy(dtype=np.float64)
    reassigned = df.copy()
    reassigned["timestamp"] = generator.permutation(original_stamps)
    return reassigned.sort_values("timestamp").reset_index(drop=True)
# ---------------------------------------------------------------------------
# Experiments (single seed)
# ---------------------------------------------------------------------------
def run_ood_single(
    df_easy: pd.DataFrame,
    df_medium: pd.DataFrame,
    df_hard: pd.DataFrame,
    device: str,
    num_epochs: int,
    node_epochs: int,
    n_checkpoints: int,
    horizon: float = 0.10,
) -> pd.DataFrame:
    """Out-of-distribution experiment: train on easy+medium, test on hard.

    Node ids are remapped to a shared dense space and placeholder events are
    added to train for nodes that only occur in test. Every model in
    ``MODEL_ORDER`` is fitted for the horizon and evaluated with prefix-only
    checkpoints.

    Returns:
        One row per model with its metrics, fit info and ``gap_vs_xgb``
        (ROC-AUC minus the XGBoost row's ROC-AUC).
    """
    train_events = (
        pd.concat([df_easy, df_medium], ignore_index=True)
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    test_events = df_hard.sort_values("timestamp").reset_index(drop=True)
    train_events, test_events = remap_node_ids(train_events, test_events)
    train_events = augment_with_placeholder_nodes(train_events, test_events)
    delta_time = horizon_to_delta(test_events, horizon)

    records = []
    models = build_models(device=device)
    for model_name in MODEL_ORDER:
        current = models[model_name]
        fit_info = fit_model_for_horizon(current, train_events, delta_time, num_epochs, node_epochs)
        metrics, _, _, _ = evaluate_model(current, train_events, test_events, delta_time, n_checkpoints)
        record = {"model": model_name}
        record.update(metrics)
        record.update(fit_info)
        records.append(record)

    df_out = pd.DataFrame(records)
    is_xgb = df_out["model"] == "XGBoost"
    xgb_roc = float(df_out.loc[is_xgb, "roc_auc"].iloc[0])
    df_out["gap_vs_xgb"] = df_out["roc_auc"] - xgb_roc
    return df_out
def run_causal_single(
    df_hard: pd.DataFrame,
    device: str,
    num_epochs: int,
    node_epochs: int,
    n_checkpoints: int,
    seed: int,
    horizon: float = 0.10,
) -> pd.DataFrame:
    """Causal (chronology-shuffle) ablation for one seed.

    Two independent model sets are built: one is fitted and evaluated on the
    time-sorted hard frame, the other on a copy whose timestamps were
    shuffled with ``seed + 17``.

    Returns one row per model with ``*_clean`` and ``*_shuffled`` metrics and
    ``delta`` (shuffled ROC-AUC minus clean ROC-AUC).
    """
    df_clean = df_hard.sort_values("timestamp").reset_index(drop=True)
    df_shuffled = shuffle_chronology(df_clean, seed=seed + 17)

    df_train_clean, df_test_clean, _ = split_temporally(df_clean)
    df_train_shuf, df_test_shuf, _ = split_temporally(df_shuffled)

    df_train_clean, df_test_clean = remap_node_ids(df_train_clean, df_test_clean)
    df_train_shuf, df_test_shuf = remap_node_ids(df_train_shuf, df_test_shuf)

    df_train_clean = augment_with_placeholder_nodes(df_train_clean, df_test_clean)
    df_train_shuf = augment_with_placeholder_nodes(df_train_shuf, df_test_shuf)

    delta_time_clean = horizon_to_delta(df_test_clean, horizon)
    delta_time_shuf = horizon_to_delta(df_test_shuf, horizon)

    clean_models = build_models(device=device)
    shuffled_models = build_models(device=device)

    reported_metrics = ("roc_auc", "pr_auc", "brier", "ece")
    records = []
    for model_name in MODEL_ORDER:
        clean_model = clean_models[model_name]
        shuffled_model = shuffled_models[model_name]

        fit_model_for_horizon(clean_model, df_train_clean, delta_time_clean, num_epochs, node_epochs)
        clean_metrics, _, _, _ = evaluate_model(
            clean_model, df_train_clean, df_test_clean, delta_time_clean, n_checkpoints
        )

        fit_model_for_horizon(shuffled_model, df_train_shuf, delta_time_shuf, num_epochs, node_epochs)
        shuffled_metrics, _, _, _ = evaluate_model(
            shuffled_model, df_train_shuf, df_test_shuf, delta_time_shuf, n_checkpoints
        )

        record = {"model": model_name}
        for suffix, metrics in (("clean", clean_metrics), ("shuffled", shuffled_metrics)):
            for metric in reported_metrics:
                record[f"{metric}_{suffix}"] = metrics[metric]
        record["delta"] = shuffled_metrics["roc_auc"] - clean_metrics["roc_auc"]
        records.append(record)

    return pd.DataFrame(records)
def run_horizon_single(
    df_medium: pd.DataFrame,
    device: str,
    num_epochs: int,
    node_epochs: int,
    n_checkpoints: int,
    horizons: Sequence[float],
) -> pd.DataFrame:
    """Horizon sweep on the medium frame for one seed.

    The frame is split once; for each horizon a fresh set of models is built,
    fitted and evaluated, so no model instance is reused across horizons.

    Returns one row per (horizon, model) with the evaluation metrics.
    """
    df_train, df_test, _ = split_temporally(df_medium)
    df_train, df_test = remap_node_ids(df_train, df_test)
    df_train = augment_with_placeholder_nodes(df_train, df_test)

    records = []
    for horizon in horizons:
        delta_time = horizon_to_delta(df_test, horizon)
        fresh_models = build_models(device=device)
        for model_name in MODEL_ORDER:
            candidate = fresh_models[model_name]
            fit_model_for_horizon(candidate, df_train, delta_time, num_epochs, node_epochs)
            metrics, _, _, _ = evaluate_model(candidate, df_train, df_test, delta_time, n_checkpoints)
            record = {"horizon": float(horizon), "model": model_name}
            record.update(metrics)
            records.append(record)
    return pd.DataFrame(records)
def run_mechanistic_single(
    df_hard: pd.DataFrame,
    device: str,
    num_epochs: int,
    node_epochs: int,
    n_checkpoints: int,
    horizon: float = 0.10,
) -> pd.DataFrame:
    """Mechanistic probe on the hard frame for one seed.

    Each model in ``MODEL_ORDER`` is fitted and evaluated; the row for a model
    holds ``pearson_r``, the ``safe_pearson`` correlation between the states
    and the probabilities returned by ``evaluate_model``.
    """
    df_train, df_test, _ = split_temporally(df_hard)
    df_train, df_test = remap_node_ids(df_train, df_test)
    df_train = augment_with_placeholder_nodes(df_train, df_test)
    delta_time = horizon_to_delta(df_test, horizon)

    registry = build_models(device=device)
    records = []
    for model_name in MODEL_ORDER:
        candidate = registry[model_name]
        fit_model_for_horizon(candidate, df_train, delta_time, num_epochs, node_epochs)
        _, _, probs, states = evaluate_model(candidate, df_train, df_test, delta_time, n_checkpoints)
        correlation = safe_pearson(states, probs)
        records.append({"model": model_name, "pearson_r": correlation})
    return pd.DataFrame(records)
def run_audit_single(
    df_hard: pd.DataFrame,
    device: str,
    num_epochs: int,
    node_epochs: int,
    n_checkpoints: int,
    seed: int,
    horizon: float = 0.10,
    benchmark_mode: str = "temporal_twins",
) -> pd.DataFrame:
    """Audit experiment on the hard frame for one seed.

    Prints the label-source audit and the split-integrity report, computes
    matched-control baselines (static aggregate, XGBoost, StaticGNN), then
    fits and evaluates every model in ``MODEL_ORDER``.

    Returns one row per model: its evaluation metrics, ``pearson_r``,
    matched-pair separability, and the seed-level audit values (KS
    statistics, pair counts, split leakage, matched-control audits) repeated
    on every row.  ``matched_control_roc_auc`` is NaN except on the
    ``"XGBoost"`` and ``"StaticGNN"`` rows.
    """
    # Dataset-level twin statistics, taken before any splitting.
    node_audit = build_node_audit_table(df_hard)
    ks_mean, ks_max = compute_aggregate_ks(node_audit)
    is_paired = node_audit["twin_pair_id"] >= 0
    paired_pairs = int(node_audit.loc[is_paired, "twin_pair_id"].nunique())
    paired_nodes = int(is_paired.sum())

    # --- Label-source audit ---
    print("\n--- Label-Source Audit ---")
    label_source_table = build_label_source_audit_table(df_hard)
    if not label_source_table.empty:
        print(label_source_table.to_string(index=False, max_rows=20))
    # The return values of the next two calls are not used in this function.
    compute_motif_label_consistency(
        df_hard, calib_mode=(benchmark_mode == "temporal_twins_oracle_calib")
    )
    compute_label_delay_stats(df_hard)

    df_train, df_test, _ = split_temporally(df_hard)
    leakage = compute_split_leakage(df_train, df_test)

    # --- Split integrity report ---
    print("\n--- Split Integrity ---")
    for name, count in leakage.items():
        status = "[WARN]" if count != 0 else "[OK]"
        print(f" {status} {name}: {count}")

    df_train, df_test = remap_node_ids(df_train, df_test)

    # Matched-control tables and baselines use the split without placeholders.
    train_examples, _, _ = build_matched_control_tables(df_train)
    test_examples, test_pair_rows, test_pair_counts = build_matched_control_tables(df_test)
    matched_train_features = build_matched_prefix_feature_table(df_train, train_examples)
    matched_test_features = build_matched_prefix_feature_table(df_test, test_examples)
    matched_audit = report_matched_control_audits(
        test_examples=test_examples,
        test_pair_rows=test_pair_rows,
        test_pair_counts=test_pair_counts,
    )
    static_agg_result = compute_matched_static_aggregate_auc(
        matched_train_features,
        matched_test_features,
        seed=seed,
        verbose=True,
    )
    xgb_result = compute_matched_xgboost_auc(
        matched_train_features,
        matched_test_features,
        seed=seed,
    )
    static_gnn_result = compute_matched_static_gnn_auc(
        df_train=df_train,
        df_test=df_test,
        train_examples=train_examples,
        test_examples=test_examples,
        device=device,
        num_epochs=num_epochs,
        seed=seed,
    )

    df_train_eval = augment_with_placeholder_nodes(df_train, df_test)
    delta_time = horizon_to_delta(df_test, horizon)
    registry = build_models(device=device)

    # Values that are identical on every row of this seed's output.
    seed_level = {
        "static_agg_auc": float(static_agg_result["auc"]),
        "static_agg_auc_bootstrap_std": float(static_agg_result["bootstrap_std"]),
        "xgb_auc_bootstrap_std": float(xgb_result["bootstrap_std"]),
        "static_gnn_auc_bootstrap_std": float(static_gnn_result["bootstrap_std"]),
        "ks_mean": ks_mean,
        "ks_max": ks_max,
        "paired_pairs": paired_pairs,
        "paired_nodes": paired_nodes,
        **leakage,
        **matched_audit,
    }

    records = []
    for model_name in MODEL_ORDER:
        candidate = registry[model_name]
        fit_model_for_horizon(candidate, df_train_eval, delta_time, num_epochs, node_epochs)
        metrics, _, probs, states = evaluate_model(
            candidate, df_train_eval, df_test, delta_time, n_checkpoints
        )
        pair_sep, eval_pairs = evaluate_matched_pair_separability(
            candidate,
            df_train=df_train_eval,
            df_test=df_test,
            delta_time=delta_time,
            n_checkpoints=n_checkpoints,
        )

        # Only the two static baselines have a matched-control AUC.
        if model_name == "XGBoost":
            baseline = xgb_result
        elif model_name == "StaticGNN":
            baseline = static_gnn_result
        else:
            baseline = None
        matched_control_roc_auc = float("nan") if baseline is None else float(baseline["auc"])

        records.append({
            "model": model_name,
            **metrics,
            "pearson_r": safe_pearson(states, probs),
            "matched_pair_sep": pair_sep,
            "matched_pair_eval_pairs": eval_pairs,
            "matched_control_roc_auc": matched_control_roc_auc,
            **seed_level,
        })
    return pd.DataFrame(records)
# ---------------------------------------------------------------------------
# Aggregation / plotting outputs
# ---------------------------------------------------------------------------
def summarise_mean_std(df: pd.DataFrame, group_cols: Sequence[str], value_cols: Sequence[str]) -> pd.DataFrame:
    """Aggregate ``value_cols`` to mean and std within each ``group_cols`` group.

    The result has the group columns followed by ``<col>_mean`` and
    ``<col>_std`` for every value column.  Every NaN in the result is
    replaced by 0.0; this covers the std of single-row groups and also any
    mean that is NaN.
    """
    stats_per_column = dict.fromkeys(value_cols, ["mean", "std"])
    grouped = df.groupby(list(group_cols)).agg(stats_per_column)
    grouped.columns = [f"{name}_{stat}" for name, stat in grouped.columns]
    flat = grouped.reset_index()
    return flat.fillna(0.0)
def _concat_and_write_raw(
    frames: List[pd.DataFrame] | None,
    results_dir: str,
    stem: str,
) -> pd.DataFrame | None:
    """Concatenate ``frames`` and write ``<stem>_raw.csv``; return None when there is nothing to write."""
    if not frames:
        return None
    raw = pd.concat(frames, ignore_index=True)
    raw.to_csv(os.path.join(results_dir, f"{stem}_raw.csv"), index=False)
    return raw


def save_experiment_outputs(
    raw_frames: Dict[str, List[pd.DataFrame]],
    results_dir: str,
) -> None:
    """Write raw and mean/std-summarised CSVs for every experiment that has results.

    For each experiment name ``<name>`` in ``ood``, ``causal``, ``horizon``,
    ``mechanistic`` and ``audit`` whose entry in ``raw_frames`` is a non-empty
    list, the per-seed frames are concatenated into ``<name>_raw.csv`` and a
    mean/std summary is written to ``<name>.csv`` inside ``results_dir``
    (created if missing).  When causal results exist, their ``delta_mean`` and
    ``delta_std`` are merged into the audit summary per model, with models
    missing from the causal results filled with 0.0.

    A missing key is treated the same as an empty list, so callers may pass
    only the experiments they actually ran.
    """
    os.makedirs(results_dir, exist_ok=True)

    base_metrics = ["roc_auc", "pr_auc", "brier", "ece"]
    summary_specs = (
        ("ood", ["model"], base_metrics + ["gap_vs_xgb"]),
        (
            "causal",
            ["model"],
            [
                "roc_auc_clean",
                "pr_auc_clean",
                "brier_clean",
                "ece_clean",
                "roc_auc_shuffled",
                "pr_auc_shuffled",
                "brier_shuffled",
                "ece_shuffled",
                "delta",
            ],
        ),
        ("horizon", ["horizon", "model"], base_metrics),
        ("mechanistic", ["model"], ["pearson_r"]),
    )

    raw_causal: pd.DataFrame | None = None
    for stem, group_cols, value_cols in summary_specs:
        raw = _concat_and_write_raw(raw_frames.get(stem), results_dir, stem)
        if raw is None:
            continue
        if stem == "causal":
            # Kept so the audit summary can pick up the shuffle delta.
            raw_causal = raw
        summary = summarise_mean_std(raw, group_cols=group_cols, value_cols=value_cols)
        summary.to_csv(os.path.join(results_dir, f"{stem}.csv"), index=False)

    raw_audit = _concat_and_write_raw(raw_frames.get("audit"), results_dir, "audit")
    if raw_audit is None:
        return

    audit_summary = summarise_mean_std(
        raw_audit,
        group_cols=["model"],
        value_cols=[
            "roc_auc",
            "pr_auc",
            "brier",
            "ece",
            "pearson_r",
            "matched_pair_sep",
            "matched_pair_eval_pairs",
            "matched_control_roc_auc",
            "static_agg_auc",
            "ks_mean",
            "ks_max",
            "paired_pairs",
            "paired_nodes",
            "sender_overlap_count",
            "pair_overlap_count",
            "template_overlap_count",
            "pair_total_txn_count_diff_mean",
            "pair_total_txn_count_diff_max",
            "auc_total_txn_count",
            "auc_local_event_idx",
            "auc_prefix_txn_count",
            "auc_timestamp",
            "auc_account_age",
            "auc_active_age",
            "fraud_label_event_idx_mean",
            "fraud_label_event_idx_max",
            "benign_eval_event_idx_mean",
            "benign_eval_event_idx_max",
            "pair_event_idx_diff_mean",
            "pair_event_idx_diff_max",
            "pair_active_age_diff_mean",
            "pair_active_age_diff_max",
            "pair_timestamp_diff_mean",
            "pair_timestamp_diff_max",
            "benign_motif_hit_rate",
            "benign_motif_hit_pairs",
            "matched_control_examples",
            "matched_control_pair_events",
        ],
    )
    if raw_causal is not None:
        causal_delta = summarise_mean_std(
            raw_causal,
            group_cols=["model"],
            value_cols=["delta"],
        )[["model", "delta_mean", "delta_std"]]
        audit_summary = audit_summary.merge(causal_delta, on="model", how="left")
        # Models absent from the causal run get a zero delta rather than NaN.
        delta_cols = ["delta_mean", "delta_std"]
        audit_summary[delta_cols] = audit_summary[delta_cols].fillna(0.0)
    audit_summary.to_csv(os.path.join(results_dir, "audit.csv"), index=False)
# ---------------------------------------------------------------------------
# Node-level oracle evaluation helpers (twin_label, not window label)
# ---------------------------------------------------------------------------
def _twin_labels_for_nodes(df_full: pd.DataFrame, nodes: List[int]) -> np.ndarray:
"""Return twin_label (1=fraud twin, 0=benign) per node. Falls back to
is_fraud if twin_label is absent."""
col = "twin_label" if "twin_label" in df_full.columns else "is_fraud"
label_series = df_full.groupby("sender_id")[col].max()
return np.array([float(label_series.get(n, 0.0)) for n in nodes], dtype=np.float32)
def evaluate_oracle_node_level(
    model: TemporalModel,
    df_full: pd.DataFrame,
    eval_nodes: List[int],
) -> float:
    """Score an oracle-type model against node-level twin labels.

    ``model.predict`` is called on the full frame (audit columns included)
    for ``eval_nodes`` and the scores are compared with the labels from
    ``_twin_labels_for_nodes`` via ``safe_roc_auc``.  Returns NaN when
    ``eval_nodes`` is empty.

    This function does no training.  Per the file's notes, AuditOracle needs
    none, while RawMotifOracle must already have had its node head trained on
    twin labels before being passed in.
    """
    if eval_nodes:
        labels = _twin_labels_for_nodes(df_full, eval_nodes)
        scores = model.predict(df_full, eval_nodes).astype(np.float32)
        return safe_roc_auc(labels, scores)
    return float("nan")
def evaluate_oracle_pair_sep_node_level(
    model: TemporalModel,
    df_full: pd.DataFrame,
    eval_nodes: List[int],
) -> float:
    """Matched-pair separability: fraction of twin pairs with score_fraud > score_benign.

    Only pairs that resolve to exactly two evaluated senders, one with label
    1 and one with label 0, are counted.  A tie counts as a miss because the
    comparison is strict.

    The node label is the per-sender maximum of ``twin_label``; when that
    column is absent ``is_fraud`` is used, matching ``_twin_labels_for_nodes``.

    Returns NaN when ``eval_nodes`` is empty, when ``twin_pair_id`` or both
    label columns are missing, or when no pair qualifies.
    """
    if not eval_nodes or "twin_pair_id" not in df_full.columns:
        return float("nan")
    label_col = next(
        (name for name in ("twin_label", "is_fraud") if name in df_full.columns),
        None,
    )
    if label_col is None:
        return float("nan")

    probs = model.predict(df_full, eval_nodes)
    score_map = {node: float(prob) for node, prob in zip(eval_nodes, probs)}

    in_scope = df_full["sender_id"].isin(eval_nodes) & (df_full["twin_pair_id"] >= 0)
    meta = (
        df_full[in_scope]
        .groupby("sender_id")
        .agg(twin_pair_id=("twin_pair_id", "first"), twin_label=(label_col, "max"))
        .reset_index()
    )

    outcomes: List[float] = []
    for _, pair in meta.groupby("twin_pair_id"):
        # Skip incomplete pairs and pairs that are not one-fraud/one-benign.
        if len(pair) != 2 or set(pair["twin_label"]) != {0, 1}:
            continue
        fraud_node = int(pair.loc[pair["twin_label"] == 1, "sender_id"].iloc[0])
        benign_node = int(pair.loc[pair["twin_label"] == 0, "sender_id"].iloc[0])
        if fraud_node in score_map and benign_node in score_map:
            outcomes.append(float(score_map[fraud_node] > score_map[benign_node]))
    return float(np.mean(outcomes)) if outcomes else float("nan")
def build_oracle_debug_table(
    df_full: pd.DataFrame,
    eval_nodes: List[int],
    oracle_scores: dict[str, np.ndarray],
    y_twin: np.ndarray,
    n_sample: int = 20,
    primary_score_name: str = "AuditOracle",
    table_title: str = "Oracle Debug Table",
) -> pd.DataFrame:
    """Print and return a per-node sample of oracle/probe scores next to the labels.

    ``y_twin`` and every array in ``oracle_scores`` are indexed like
    ``eval_nodes``.  The table has one row per evaluated node that occurs as
    a sender in ``df_full``: the first value of each available audit column,
    ``twin_label``, and one ``score_<name>`` column per entry of
    ``oracle_scores``.  Rows are sorted by ``score_<primary_score_name>``
    descending (falling back to the last column when that score is absent),
    and the top and bottom ``n_sample // 2`` rows are printed and returned.
    """
    audit_cols = [
        "twin_pair_id", "twin_role",
        "motif_hit_count", "trigger_event_idx", "label_event_idx",
        "label_delay", "is_fallback_label",
    ]
    available = [col for col in audit_cols if col in df_full.columns]
    meta = (
        df_full[df_full["sender_id"].isin(eval_nodes)]
        .groupby("sender_id")[available]
        .first()
        .reset_index()  # sender_id becomes a column here
    )

    # groupby sorts by sender_id and drops nodes absent from df_full, so the
    # rows of `meta` are not in eval_nodes order.  Labels and scores are
    # therefore both looked up through each row's position in eval_nodes.
    position_of = {node: idx for idx, node in enumerate(eval_nodes)}
    positions = meta["sender_id"].map(position_of)

    meta["twin_label"] = positions.map(
        {idx: float(y_twin[idx]) for idx in range(len(y_twin))}
    )
    for name, scores in oracle_scores.items():
        meta[f"score_{name}"] = positions.map(
            {idx: float(scores[idx]) for idx in range(len(scores))}
        )

    # Sample: top n_sample/2 by the primary motif score + bottom n_sample/2
    sort_col = f"score_{primary_score_name}"
    if sort_col not in meta.columns:
        sort_col = meta.columns[-1]
    meta = meta.sort_values(sort_col, ascending=False)
    half = n_sample // 2
    # drop_duplicates removes rows picked by both head and tail on small tables.
    sample = pd.concat([meta.head(half), meta.tail(half)]).drop_duplicates()
    print(f"\n--- {table_title} (top & bottom by {primary_score_name} score) ---")
    print(sample.to_string(index=False))
    return sample
# Gate volume targets / budgets
# Minimum number of matched evaluation pairs required in fast mode
# (see gate_volume_thresholds).
_FAST_GATE_MIN_MATCHED_PAIRS = 500
# Minimum number of matched evaluation pairs required when fast mode is off.
_FULL_GATE_MIN_MATCHED_PAIRS = 2000
# Minimum count required for positives and, separately, for negatives.
_GATE_MIN_CLASS_EXAMPLES = 500
# Minimum number of distinct fraud users and, separately, distinct benign users.
_GATE_MIN_UNIQUE_USERS = 50
# Inclusive (low, high) bounds on the positive rate of the matched test examples.
_GATE_POS_RATE_RANGE = (0.35, 0.65)
# Not referenced in this part of the file; presumably the number of bootstrap
# resamples behind the gate's AUC estimates — TODO confirm against its users.
_GATE_BOOTSTRAP_ROUNDS = 200
# ID offset per extra pack: offset_gate_namespace adds pack_idx * this value
# to sender, receiver, twin-pair and template IDs.
_GATE_PACK_NAMESPACE = 10_000_000
# Maximum number of extra packs the ensure_gate_volume* loops may add.
_GATE_MAX_EXTRA_PACKS = 6
def _subsample_for_gate(
df: pd.DataFrame,
rng: np.random.Generator,
max_pairs: int | None = None,
) -> pd.DataFrame:
"""Keep at most max_pairs twin pairs for the gate."""
if "twin_pair_id" not in df.columns:
return df
pair_ids = df.loc[df["twin_pair_id"] >= 0, "twin_pair_id"].unique()
if max_pairs is None or max_pairs <= 0 or len(pair_ids) <= max_pairs:
return df[df["twin_pair_id"] >= 0].copy()
chosen = set(rng.choice(pair_ids, size=max_pairs, replace=False).tolist())
return df[df["twin_pair_id"].isin(chosen)].copy()
def gate_volume_thresholds(fast_mode: bool) -> dict:
    """Return the volume requirements the gate's evaluation subset must meet.

    Only the matched-pair minimum depends on ``fast_mode``; the class-count,
    unique-user and positive-rate limits are the same in both modes.
    """
    if fast_mode:
        min_pairs = _FAST_GATE_MIN_MATCHED_PAIRS
    else:
        min_pairs = _FULL_GATE_MIN_MATCHED_PAIRS
    rate_lo, rate_hi = _GATE_POS_RATE_RANGE
    return dict(
        matched_eval_pairs_min=min_pairs,
        positives_min=_GATE_MIN_CLASS_EXAMPLES,
        negatives_min=_GATE_MIN_CLASS_EXAMPLES,
        unique_fraud_users_min=_GATE_MIN_UNIQUE_USERS,
        unique_benign_users_min=_GATE_MIN_UNIQUE_USERS,
        positive_rate_lo=rate_lo,
        positive_rate_hi=rate_hi,
    )
def summarize_gate_volume(
    test_examples: pd.DataFrame,
    test_pair_rows: pd.DataFrame,
    eval_nodes: Sequence[int],
) -> dict:
    """Count what the gate's evaluation subset contains.

    Reports the number of matched pair rows, positive/negative examples,
    distinct fraud/benign senders, distinct templates, the positive rate, and
    per-stage example counts (node count for the audit/raw stages, example
    count for the XGBoost/StaticGNN stages).  An empty ``test_examples``
    yields zeros for every example-derived value.
    """
    if test_examples.empty:
        positives = fraud_users = benign_users = unique_templates = 0
    else:
        labels = test_examples["label"]
        positives = int(labels.sum())
        fraud_users = int(test_examples.loc[labels == 1, "sender_id"].nunique())
        benign_users = int(test_examples.loc[labels == 0, "sender_id"].nunique())
        if "template_id" in test_examples.columns:
            unique_templates = int(test_examples["template_id"].nunique())
        else:
            unique_templates = 0

    total_examples = int(len(test_examples))
    n_nodes = int(len(eval_nodes))
    return {
        "matched_eval_pairs": int(len(test_pair_rows)),
        "positives": positives,
        "negatives": int(total_examples - positives),
        "unique_fraud_users": fraud_users,
        "unique_benign_users": benign_users,
        "unique_templates": unique_templates,
        # max(..., 1) avoids dividing by zero on an empty example table.
        "positive_rate": float(positives / max(total_examples, 1)),
        "audit_n_examples": n_nodes,
        "raw_n_examples": n_nodes,
        "xgb_n_examples": total_examples,
        "static_gnn_n_examples": total_examples,
    }
def gate_volume_violations(volume: dict, fast_mode: bool) -> list[str]:
    """List, as human-readable strings, every threshold that ``volume`` fails.

    Count-type entries are compared with their ``<key>_min`` threshold (a
    missing entry counts as 0); the positive rate must lie within the
    configured bounds.  An empty list means the volume is sufficient.
    """
    limits = gate_volume_thresholds(fast_mode)
    problems: list[str] = []

    count_keys = (
        "matched_eval_pairs",
        "positives",
        "negatives",
        "unique_fraud_users",
        "unique_benign_users",
    )
    for key in count_keys:
        observed = volume.get(key, 0)
        required = limits[f"{key}_min"]
        if observed < required:
            problems.append(f"{key} {observed} < {required}")

    rate = float(volume.get("positive_rate", 0.0))
    rate_lo = limits["positive_rate_lo"]
    rate_hi = limits["positive_rate_hi"]
    if rate < rate_lo or rate > rate_hi:
        problems.append(
            f"positive_rate {rate:.4f} outside [{rate_lo:.2f}, {rate_hi:.2f}]"
        )
    return problems
def gate_volume_is_sufficient(volume: dict, fast_mode: bool) -> bool:
    """Return True when ``volume`` fails none of the gate's volume thresholds."""
    return not gate_volume_violations(volume, fast_mode)
def offset_gate_namespace(df: pd.DataFrame, pack_idx: int) -> pd.DataFrame:
    """Return a copy of ``df`` whose IDs are shifted into pack ``pack_idx``'s range.

    Pack 0 is returned as a plain copy.  For any other pack, sender and
    receiver IDs are shifted by ``pack_idx * _GATE_PACK_NAMESPACE``; twin-pair
    and template IDs (when those columns exist) are shifted by the same
    amount, but only where they are non-negative, so negative values stay
    as they are.
    """
    shifted = df.copy()
    if pack_idx == 0:
        return shifted

    offset = pack_idx * _GATE_PACK_NAMESPACE
    for node_col in ("sender_id", "receiver_id"):
        shifted[node_col] = shifted[node_col].astype(np.int64) + offset

    for id_col in ("twin_pair_id", "template_id"):
        if id_col not in shifted.columns:
            continue
        as_int = shifted[id_col].astype(np.int64)
        is_assigned = as_int >= 0
        shifted.loc[is_assigned, id_col] = as_int[is_assigned] + offset
    return shifted
def build_gate_pool_from_frames(frames: Sequence[pd.DataFrame]) -> pd.DataFrame:
    """Merge event frames into one pool sorted by timestamp.

    ``None`` entries and empty frames are ignored.  When nothing usable
    remains, an empty DataFrame is returned.
    """
    usable = []
    for frame in frames:
        if frame is None or frame.empty:
            continue
        usable.append(frame)

    if usable:
        merged = pd.concat(usable, ignore_index=True)
        return merged.sort_values("timestamp").reset_index(drop=True)
    return pd.DataFrame()
def gate_pair_budget_candidates(total_pairs: int, fast_mode: bool) -> list[int | None]:
    """Return the pair budgets to try, smallest first.

    The first candidate is the mode's target budget (900 in fast mode, 3500
    otherwise) capped at ``total_pairs``; when the pool is larger than that,
    the full pool size is offered as a second candidate.  A pool without
    pairs yields ``[0]``.
    """
    if total_pairs <= 0:
        return [0]
    target = 900 if fast_mode else 3500
    if total_pairs <= target:
        return [int(total_pairs)]
    return [int(target), int(total_pairs)]
def prepare_gate_subset(
    df_pool: pd.DataFrame,
    seed: int,
    fast_mode: bool,
) -> dict:
    """Build the train/test artefacts the gate needs from a pool of events.

    Each pair budget from ``gate_pair_budget_candidates`` is tried in order:
    the pool is subsampled with a generator seeded by ``seed + budget``,
    split temporally, remapped, and turned into matched-control tables.  The
    first candidate whose volume is sufficient is returned; if none is, the
    last candidate tried is returned.  A pool without twin pairs yields a
    dict of empty frames with ``pair_budget`` 0.
    """
    if "twin_pair_id" in df_pool.columns:
        paired_ids = df_pool.loc[df_pool["twin_pair_id"] >= 0, "twin_pair_id"]
        total_pairs = int(paired_ids.nunique())
    else:
        total_pairs = 0

    if total_pairs == 0:
        empty = pd.DataFrame()
        return {
            "pair_budget": 0,
            "df_gate": empty,
            "df_train": empty,
            "df_test": empty,
            "df_train_eval": empty,
            "df_full": empty,
            "eval_nodes": [],
            "train_examples": empty,
            "train_pair_rows": empty,
            "train_pair_counts": empty,
            "test_examples": empty,
            "test_pair_rows": empty,
            "test_pair_counts": empty,
            "volume": summarize_gate_volume(empty, empty, []),
        }

    candidate: dict | None = None
    for pair_budget in gate_pair_budget_candidates(total_pairs, fast_mode):
        # A budget-specific seed makes each candidate's subsample reproducible.
        gate_rng = np.random.default_rng(seed + int(pair_budget))
        df_gate = _subsample_for_gate(df_pool, gate_rng, max_pairs=pair_budget)

        df_train, df_test, _ = split_temporally(df_gate)
        df_train, df_test = remap_node_ids(df_train, df_test)
        train_examples, train_pair_rows, train_pair_counts = build_matched_control_tables(df_train)
        test_examples, test_pair_rows, test_pair_counts = build_matched_control_tables(df_test)

        df_train_eval = augment_with_placeholder_nodes(df_train, df_test)
        combined = pd.concat([df_train_eval, df_test], ignore_index=True)
        df_full = combined.sort_values("timestamp").reset_index(drop=True)
        eval_nodes = get_eval_nodes(df_full)
        volume = summarize_gate_volume(test_examples, test_pair_rows, eval_nodes)

        candidate = {
            "pair_budget": int(pair_budget) if pair_budget is not None else total_pairs,
            "df_gate": df_gate,
            "df_train": df_train,
            "df_test": df_test,
            "df_train_eval": df_train_eval,
            "df_full": df_full,
            "eval_nodes": eval_nodes,
            "train_examples": train_examples,
            "train_pair_rows": train_pair_rows,
            "train_pair_counts": train_pair_counts,
            "test_examples": test_examples,
            "test_pair_rows": test_pair_rows,
            "test_pair_counts": test_pair_counts,
            "volume": volume,
        }
        if gate_volume_is_sufficient(volume, fast_mode):
            break

    assert candidate is not None
    return candidate
def ensure_gate_volume(
    df_pool: pd.DataFrame,
    config,
    seed: int,
    benchmark_mode: str,
    fast_mode: bool,
) -> dict:
    """Grow the gate pool with freshly generated packs until its volume suffices.

    Starting from a copy of ``df_pool``, up to ``_GATE_MAX_EXTRA_PACKS`` extra
    packs are generated with ``generate_all`` (seed ``seed + k * 10_007`` for
    pack ``k``), shifted into their own ID namespace, and appended.  After
    each addition the gate subset is rebuilt.  The returned gate dict also
    records the final pool's event count, pair count and number of packs.
    """
    pool = df_pool.copy()
    gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode)

    packs_used = 1
    for pack_idx in range(1, _GATE_MAX_EXTRA_PACKS + 1):
        if gate_volume_is_sufficient(gate["volume"], fast_mode):
            break
        generated = generate_all(
            config,
            seed=seed + pack_idx * 10_007,
            benchmark_mode=benchmark_mode,
        )
        extra_easy, extra_medium, extra_hard = generated
        # All three difficulty frames of one pack share the same ID offset.
        # NOTE(review): confirm IDs cannot collide across difficulties within a pack.
        extra_pack = build_gate_pool_from_frames([
            offset_gate_namespace(extra_easy, pack_idx),
            offset_gate_namespace(extra_medium, pack_idx),
            offset_gate_namespace(extra_hard, pack_idx),
        ])
        pool = build_gate_pool_from_frames([pool, extra_pack])
        gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode)
        packs_used = pack_idx + 1

    if "twin_pair_id" in pool.columns:
        pool_pairs = int(pool.loc[pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique())
    else:
        pool_pairs = 0
    gate["source_pool_events"] = int(len(pool))
    gate["source_pool_pairs"] = pool_pairs
    gate["source_pool_packs"] = int(packs_used)
    return gate
def ensure_gate_volume_for_difficulty(
    config,
    difficulty: str,
    seed: int,
    benchmark_mode: str,
    fast_mode: bool,
    initial_pool: pd.DataFrame | None = None,
) -> dict:
    """Grow a single-difficulty gate pool until its volume suffices.

    The pool starts from a copy of ``initial_pool`` or, when that is None,
    from one generated pack of ``difficulty``.  Up to
    ``_GATE_MAX_EXTRA_PACKS`` further packs of the same difficulty are
    generated (seed ``seed + k * 10_007`` for pack ``k``), shifted into their
    own ID namespace and appended, rebuilding the gate subset each time.  The
    returned gate dict also records the final pool's event count, pair count
    and number of packs.
    """
    if initial_pool is not None:
        pool = initial_pool.copy()
    else:
        pool = generate_single_difficulty(
            config,
            difficulty=difficulty,
            seed=seed,
            benchmark_mode=benchmark_mode,
        )
    gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode)

    packs_used = 1
    for pack_idx in range(1, _GATE_MAX_EXTRA_PACKS + 1):
        if gate_volume_is_sufficient(gate["volume"], fast_mode):
            break
        fresh_pack = generate_single_difficulty(
            config,
            difficulty=difficulty,
            seed=seed + pack_idx * 10_007,
            benchmark_mode=benchmark_mode,
        )
        fresh_pack = offset_gate_namespace(fresh_pack, pack_idx)
        pool = build_gate_pool_from_frames([pool, fresh_pack])
        gate = prepare_gate_subset(pool, seed=seed, fast_mode=fast_mode)
        packs_used = pack_idx + 1

    if "twin_pair_id" in pool.columns:
        pool_pairs = int(pool.loc[pool["twin_pair_id"] >= 0, "twin_pair_id"].nunique())
    else:
        pool_pairs = 0
    gate["source_pool_events"] = int(len(pool))
    gate["source_pool_pairs"] = pool_pairs
    gate["source_pool_packs"] = int(packs_used)
    return gate
# ---------------------------------------------------------------------------
# Motif Validity Check (req #11)
# ---------------------------------------------------------------------------
def run_motif_validity_check(
df: pd.DataFrame,
config,
seed: int,
device: str,
num_epochs: int,
node_epochs: int,
n_checkpoints: int,
hard_abort: bool = True,
horizon: float = 0.10,
benchmark_mode: str = "temporal_twins_oracle_calib",
fast_mode: bool = False,
force_temporal_models: bool = False,
prebuilt_gate: dict | None = None,
) -> tuple[bool, dict]:
"""Run the staged MOTIF VALIDITY CHECK gate.
Stage 1 — AuditOracle: reads audit cols directly. >= 0.99 required.
Stage 2 — RawMotifOracle: reconstructs motif. >= 0.95 required.
Stage 3 — Static ceilings: XGB <= 0.65, StaticGNN <= 0.70.
Stage 4 — SeqGRU: >= 0.80 (calib mode only).
Oracles are evaluated against twin_label (NOT window label) to avoid
the target-alignment bug where late windows have no upcoming fraud events.
"""
calib_mode = _is_oracle_calib_mode(benchmark_mode)
metric_labels = _oracle_metric_labels(benchmark_mode)
# Dataset-wide stats computed on the FULL df before subsampling
consistency = compute_motif_label_consistency(df, calib_mode=calib_mode)
delay_stats = compute_label_delay_stats(df)
node_audit = build_node_audit_table(df)
ks_mean, ks_max = compute_aggregate_ks(node_audit)
gate = prebuilt_gate
if gate is None:
gate = ensure_gate_volume(
df_pool=df,
config=config,
seed=seed,
benchmark_mode=benchmark_mode,
fast_mode=fast_mode,
)
df_gate = gate["df_gate"]
df_train = gate["df_train"]
df_test = gate["df_test"]
df_train_eval = gate["df_train_eval"]
df_full = gate["df_full"]
eval_nodes = gate["eval_nodes"]
train_examples = gate["train_examples"]
train_pair_rows = gate["train_pair_rows"]
train_pair_counts = gate["train_pair_counts"]
test_examples = gate["test_examples"]
test_pair_rows = gate["test_pair_rows"]
test_pair_counts = gate["test_pair_counts"]
gate_volume = gate["volume"]
print(
f" [gate] Using {df_gate['twin_pair_id'].nunique()} pairs "
f"({len(df_gate):,} events) for model stages from "
f"{gate['source_pool_packs']} pack(s), source pairs={gate['source_pool_pairs']:,}."
)
leakage = compute_split_leakage(df_train, df_test)
matched_train_features = build_matched_prefix_feature_table(df_train, train_examples)
matched_test_features = build_matched_prefix_feature_table(df_test, test_examples)
matched_audit = report_matched_control_audits(
test_examples=test_examples,
test_pair_rows=test_pair_rows,
test_pair_counts=test_pair_counts,
)
static_agg_result = compute_matched_static_aggregate_auc(
matched_train_features,
matched_test_features,
seed=seed,
verbose=False,
)
delta_time = horizon_to_delta(df_test, horizon)
y_twin = _twin_labels_for_nodes(df_full, eval_nodes)
report: dict = {
"ks_mean": ks_mean, "ks_max": ks_max,
"static_agg_auc": float(static_agg_result["auc"]),
**delay_stats,
**{k: v for k, v in consistency.items()},
**matched_audit,
**gate_volume,
"gate_pair_budget": int(gate["pair_budget"]),
"gate_source_pool_events": int(gate["source_pool_events"]),
"gate_source_pool_pairs": int(gate["source_pool_pairs"]),
"gate_source_pool_packs": int(gate["source_pool_packs"]),
}
attach_auc_result(report, "static_agg", static_agg_result)
oracle_scores: dict[str, np.ndarray] = {}
# Stage 1 — AuditOracle / MotifProbe (no training; reads motif_hit_count directly)
audit_oracle = AuditOracleWrapper()
audit_probs = audit_oracle.predict(df_full, eval_nodes)
oracle_scores[metric_labels["audit"]] = audit_probs
audit_result = make_auc_result(y_twin, audit_probs.astype(np.float32), seed=seed)
attach_auc_result(report, "audit", audit_result)
report["audit_pair_sep"] = evaluate_oracle_pair_sep_node_level(
audit_oracle, df_full, eval_nodes
)
# Stage 2 — RawMotifOracle / RawMotifProbe (trained on twin_label, not window label)
raw_oracle = RawMotifOracleWrapper()
raw_oracle.fit(df_train_eval, num_epochs=num_epochs)
train_nodes_raw = get_eval_nodes(df_train_eval)
y_train_twin_raw = _twin_labels_for_nodes(df_train_eval, train_nodes_raw)
train_node_head(
raw_oracle,
df_anchor_prefix=df_train_eval,
eval_nodes=train_nodes_raw,
y_labels=y_train_twin_raw,
num_epochs=node_epochs,
)
raw_probs = raw_oracle.predict(df_full, eval_nodes)
oracle_scores[metric_labels["raw"]] = raw_probs
raw_result = make_auc_result(y_twin, raw_probs.astype(np.float32), seed=seed)
attach_auc_result(report, "raw", raw_result)
report["raw_pair_sep"] = evaluate_oracle_pair_sep_node_level(
raw_oracle, df_full, eval_nodes
)
_attach_probe_aliases(report, benchmark_mode)
# Oracle/probe debug table
build_oracle_debug_table(
df_full,
eval_nodes,
oracle_scores,
y_twin,
primary_score_name=metric_labels["audit"],
table_title=metric_labels["table"],
)
# Stage 3 — Static baselines (window-label eval, as in main benchmark)
xgb_result = compute_matched_xgboost_auc(
matched_train_features,
matched_test_features,
seed=seed,
)
attach_auc_result(report, "xgb", xgb_result)
static_gnn_result = compute_matched_static_gnn_auc(
df_train=df_train,
df_test=df_test,
train_examples=train_examples,
test_examples=test_examples,
device=device,
num_epochs=num_epochs,
seed=seed,
)
attach_auc_result(report, "static_gnn", static_gnn_result)
report["xgb_roc_auc"] = float(xgb_result["auc"])
report["static_gnn_roc"] = float(static_gnn_result["auc"])
report["static_gnn_auc_flipped"] = float(static_gnn_result.get("auc_flipped", float("nan")))
report["static_gnn_score_mean_pos"] = float(static_gnn_result.get("score_mean_pos", float("nan")))
report["static_gnn_score_mean_neg"] = float(static_gnn_result.get("score_mean_neg", float("nan")))
report["static_gnn_score_std"] = float(static_gnn_result.get("score_std", float("nan")))
report["static_gnn_zero_emb_frac"] = float(static_gnn_result.get("zero_emb_frac", float("nan")))
report["static_gnn_matched_examples"] = int(static_gnn_result.get("matched_examples", 0))
report["static_gnn_unique_prefix_cutoffs"] = int(static_gnn_result.get("unique_prefix_cutoffs", 0))
report["static_gnn_graph_builds"] = int(static_gnn_result.get("graph_builds", 0))
report["static_gnn_cache_hit_rate"] = float(static_gnn_result.get("cache_hit_rate", float("nan")))
report["static_gnn_eval_time_sec"] = float(static_gnn_result.get("eval_time_sec", float("nan")))
report["static_gnn_train_unique_prefix_cutoffs"] = int(static_gnn_result.get("train_unique_prefix_cutoffs", 0))
report["static_gnn_test_unique_prefix_cutoffs"] = int(static_gnn_result.get("test_unique_prefix_cutoffs", 0))
report["static_gnn_train_graph_builds"] = int(static_gnn_result.get("train_graph_builds", 0))
report["static_gnn_test_graph_builds"] = int(static_gnn_result.get("test_graph_builds", 0))
report["static_gnn_train_eval_time_sec"] = float(static_gnn_result.get("train_eval_time_sec", float("nan")))
report["static_gnn_test_eval_time_sec"] = float(static_gnn_result.get("test_eval_time_sec", float("nan")))
# Stage 4 — SeqGRU (calib mode only)
run_temporal_models = calib_mode or force_temporal_models
if run_temporal_models:
seqgru_result = compute_matched_seqgru_metrics(
df_train=df_train,
df_test=df_test,
train_examples=train_examples,
test_examples=test_examples,
device=device,
seed=seed,
max_epochs=max(24, min(72, node_epochs)),
)
seqgru_clean = seqgru_result["clean"]
seqgru_shuffled = seqgru_result["shuffled"]
report["seqgru_roc_auc"] = float(seqgru_clean["auc"])
report["seqgru_pr_auc"] = float(seqgru_clean.get("pr_auc", float("nan")))
report["seqgru_brier"] = float(seqgru_clean.get("brier", float("nan")))
report["seqgru_ece"] = float(seqgru_clean.get("ece", float("nan")))
report["seqgru_n_examples"] = int(seqgru_clean.get("n_examples", 0))
report["seqgru_shuffle_roc_auc"] = float(seqgru_shuffled["auc"])
report["seqgru_shuffle_pr_auc"] = float(seqgru_shuffled.get("pr_auc", float("nan")))
report["seqgru_shuffle_delta"] = float(seqgru_result["delta"])
report["seqgru_best_epoch"] = int(seqgru_result["clean_fit"].get("best_epoch", 0))
report["seqgru_best_valid_roc_auc"] = float(seqgru_result["clean_fit"].get("best_valid_roc_auc", float("nan")))
report["seqgru_best_valid_pr_auc"] = float(seqgru_result["clean_fit"].get("best_valid_pr_auc", float("nan")))
report["seqgru_shuffle_best_epoch"] = int(seqgru_result["shuffled_fit"].get("best_epoch", 0))
report["seqgru_shuffle_best_valid_roc_auc"] = float(seqgru_result["shuffled_fit"].get("best_valid_roc_auc", float("nan")))
temporal_gnn_specs = [
("TGN", "tgn", lambda: TGNWrapper(device=device)),
("TGAT", "tgat", lambda: TGATWrapper(device=device)),
("DyRep", "dyrep", lambda: DyRepWrapper(device=device)),
("JODIE", "jodie", lambda: JODIEWrapper(device=device)),
]
temporal_num_epochs = max(2, num_epochs)
for model_label, key_prefix, builder in temporal_gnn_specs:
temporal_result = compute_matched_temporal_gnn_metrics(
model_name=model_label,
model_builder=builder,
df_train=df_train,
df_test=df_test,
train_examples=train_examples,
test_examples=test_examples,
seed=seed,
num_epochs=temporal_num_epochs,
)
clean_result = temporal_result["clean"]
shuffled_result = temporal_result["shuffled"]
report[f"{key_prefix}_roc_auc"] = float(clean_result["auc"])
report[f"{key_prefix}_pr_auc"] = float(clean_result.get("pr_auc", float("nan")))
report[f"{key_prefix}_n_examples"] = int(clean_result.get("n_examples", 0))
report[f"{key_prefix}_shuffle_roc_auc"] = float(shuffled_result["auc"])
report[f"{key_prefix}_shuffle_pr_auc"] = float(shuffled_result.get("pr_auc", float("nan")))
report[f"{key_prefix}_shuffle_delta"] = float(temporal_result["delta"])
else:
report["seqgru_roc_auc"] = float("nan")
report["seqgru_pr_auc"] = float("nan")
report["seqgru_n_examples"] = 0
report["seqgru_shuffle_roc_auc"] = float("nan")
report["seqgru_shuffle_pr_auc"] = float("nan")
report["seqgru_shuffle_delta"] = float("nan")
report["seqgru_best_epoch"] = 0
report["seqgru_best_valid_roc_auc"] = float("nan")
report["seqgru_best_valid_pr_auc"] = float("nan")
report["seqgru_shuffle_best_epoch"] = 0
report["seqgru_shuffle_best_valid_roc_auc"] = float("nan")
for key_prefix in ("tgn", "tgat", "dyrep", "jodie"):
report[f"{key_prefix}_roc_auc"] = float("nan")
report[f"{key_prefix}_pr_auc"] = float("nan")
report[f"{key_prefix}_n_examples"] = 0
report[f"{key_prefix}_shuffle_roc_auc"] = float("nan")
report[f"{key_prefix}_shuffle_pr_auc"] = float("nan")
report[f"{key_prefix}_shuffle_delta"] = float("nan")
# Gate table
gate_items = [
(f"{metric_labels['audit']} ROC-AUC", "audit_roc_auc", "ge", 0.99, "label-alignment bug"),
(f"{metric_labels['audit']} pair-sep", "audit_pair_sep", "ge", 0.99, "pair construction bug"),
(f"{metric_labels['raw']} ROC-AUC", "raw_roc_auc", "ge", 0.95, "motif reconstruction bug"),
(f"{metric_labels['raw']} pair-sep", "raw_pair_sep", "ge", 0.90, "motif reconstruction bug"),
("static_agg_auc", "static_agg_auc", "le", 0.60, "static leakage"),
("XGBoost ROC-AUC", "xgb_roc_auc", "le", 0.65, "static leakage"),
("StaticGNN ROC-AUC", "static_gnn_roc", "le", 0.70, "static leakage"),
]
if run_temporal_models:
gate_items.extend([
("SeqGRU ROC-AUC", "seqgru_roc_auc", "ge", 0.80, "learning/input bug"),
("SeqGRU shuffle delta", "seqgru_shuffle_delta", "le", -0.10, "order signal missing"),
])
temporal_gnn_items = [
("TGN ROC-AUC", "tgn_roc_auc", "ge", 0.75, "temporal learnability"),
("TGN shuffle delta", "tgn_shuffle_delta", "le", -0.10, "order signal missing"),
("TGAT ROC-AUC", "tgat_roc_auc", "ge", 0.75, "temporal learnability"),
("TGAT shuffle delta", "tgat_shuffle_delta", "le", -0.10, "order signal missing"),
("DyRep ROC-AUC", "dyrep_roc_auc", "ge", 0.75, "temporal learnability"),
("DyRep shuffle delta", "dyrep_shuffle_delta", "le", -0.10, "order signal missing"),
("JODIE ROC-AUC", "jodie_roc_auc", "ge", 0.75, "temporal learnability"),
("JODIE shuffle delta", "jodie_shuffle_delta", "le", -0.10, "order signal missing"),
]
print("\n" + "=" * 72)
print(" MOTIF VALIDITY CHECK")
print("=" * 72)
print(" Gate Volume")
print(f" matched_eval_pairs : {report['matched_eval_pairs']}")
print(f" positives / negatives : {report['positives']} / {report['negatives']}")
print(f" unique fraud / benign : {report['unique_fraud_users']} / {report['unique_benign_users']}")
print(f" unique templates : {report['unique_templates']}")
print(f" positive rate : {report['positive_rate']:.4f}")
print(
" model examples : "
f"{metric_labels['audit']}={report['audit_n_examples']} "
f"{metric_labels['raw']}={report['raw_n_examples']} "
f"XGB={report['xgb_n_examples']} StaticGNN={report['static_gnn_n_examples']} "
f"SeqGRU={report['seqgru_n_examples']} TGN={report['tgn_n_examples']} "
f"TGAT={report['tgat_n_examples']} DyRep={report['dyrep_n_examples']} "
f"JODIE={report['jodie_n_examples']}"
)
print(f" gate source packs/pairs : {report['gate_source_pool_packs']} / {report['gate_source_pool_pairs']}")
print(f" gate pair budget : {report['gate_pair_budget']}")
volume_violations = gate_volume_violations(report, fast_mode)
if volume_violations:
print(" INSUFFICIENT_GATE_VOLUME")
for violation in volume_violations:
print(f" - {violation}")
all_pass = True
if volume_violations:
all_pass = False
for label, key, op, thresh, hint in gate_items:
val = report.get(key, float("nan"))
is_nan = val != val
ok = (not is_nan) and ((val >= thresh) if op == "ge" else (val <= thresh))
status = "N/A " if is_nan else ("PASS" if ok else "FAIL")
if not ok:
all_pass = False
tstr = f"{'>='+str(thresh) if op=='ge' else '<='+str(thresh)}"
print(f" {label:<28}: {val:>7.4f} [{status} {tstr}] {'<-- '+hint if not ok else ''}")
for label, key, op, thresh, hint in temporal_gnn_items:
val = report.get(key, float("nan"))
is_nan = val != val
ok = (not is_nan) and ((val >= thresh) if op == "ge" else (val <= thresh))
status = "N/A " if is_nan else ("PASS" if ok else "FAIL")
tstr = f"{'>='+str(thresh) if op=='ge' else '<='+str(thresh)}"
suffix = "" if ok else f" [advisory: {hint}]"
print(f" {label:<28}: {val:>7.4f} [{status} {tstr}]{suffix}")
audit_ci_label = f"{metric_labels['audit']} AUC std/CI"
raw_ci_label = f"{metric_labels['raw']} std/CI"
print(f" {audit_ci_label:<28}: {report['audit_auc_bootstrap_std']:.4f} [{report['audit_auc_ci_lo']:.4f}, {report['audit_auc_ci_hi']:.4f}]")
print(f" {raw_ci_label:<28}: {report['raw_auc_bootstrap_std']:.4f} [{report['raw_auc_ci_lo']:.4f}, {report['raw_auc_ci_hi']:.4f}]")
print(f" {'XGBoost AUC std/CI':<28}: {report['xgb_auc_bootstrap_std']:.4f} [{report['xgb_auc_ci_lo']:.4f}, {report['xgb_auc_ci_hi']:.4f}]")
print(f" {'StaticGNN AUC std/CI':<28}: {report['static_gnn_auc_bootstrap_std']:.4f} [{report['static_gnn_auc_ci_lo']:.4f}, {report['static_gnn_auc_ci_hi']:.4f}]")
print(f" {'static_agg_auc std/CI':<28}: {report['static_agg_auc_bootstrap_std']:.4f} [{report['static_agg_auc_ci_lo']:.4f}, {report['static_agg_auc_ci_hi']:.4f}]")
print(f" {'StaticGNN flip check':<28}: auc={report['static_gnn_roc']:.4f} flipped={report['static_gnn_auc_flipped']:.4f} zero_emb={report['static_gnn_zero_emb_frac']:.4f}")
print(f" {'StaticGNN score means':<28}: pos={report['static_gnn_score_mean_pos']:.4f} neg={report['static_gnn_score_mean_neg']:.4f} std={report['static_gnn_score_std']:.4f}")
print(
f" {'StaticGNN runtime':<28}: "
f"examples={report['static_gnn_matched_examples']} "
f"cutoffs={report['static_gnn_unique_prefix_cutoffs']} "
f"builds={report['static_gnn_graph_builds']} "
f"hit_rate={report['static_gnn_cache_hit_rate']:.4f} "
f"time={report['static_gnn_eval_time_sec']:.2f}s"
)
print(
f" {'StaticGNN train/test rt':<28}: "
f"train_cutoffs={report['static_gnn_train_unique_prefix_cutoffs']} "
f"test_cutoffs={report['static_gnn_test_unique_prefix_cutoffs']} "
f"train_builds={report['static_gnn_train_graph_builds']} "
f"test_builds={report['static_gnn_test_graph_builds']} "
f"train={report['static_gnn_train_eval_time_sec']:.2f}s "
f"test={report['static_gnn_test_eval_time_sec']:.2f}s"
)
print(f" {'SeqGRU PR-AUC':<28}: {report['seqgru_pr_auc']:.4f} [informational]")
print(f" {'SeqGRU shuffled ROC-AUC':<28}: {report['seqgru_shuffle_roc_auc']:.4f} [informational]")
print(f" {'SeqGRU shuffled PR-AUC':<28}: {report['seqgru_shuffle_pr_auc']:.4f} [informational]")
print(f" {'SeqGRU early stop':<28}: epoch={report['seqgru_best_epoch']} valid_roc={report['seqgru_best_valid_roc_auc']:.4f} valid_pr={report['seqgru_best_valid_pr_auc']:.4f}")
print(f" {'SeqGRU shuffled stop':<28}: epoch={report['seqgru_shuffle_best_epoch']} valid_roc={report['seqgru_shuffle_best_valid_roc_auc']:.4f}")
print(f" {'TGN PR/shuffled ROC':<28}: pr={report['tgn_pr_auc']:.4f} shuffled={report['tgn_shuffle_roc_auc']:.4f}")
print(f" {'TGAT PR/shuffled ROC':<28}: pr={report['tgat_pr_auc']:.4f} shuffled={report['tgat_shuffle_roc_auc']:.4f}")
print(f" {'DyRep PR/shuffled ROC':<28}: pr={report['dyrep_pr_auc']:.4f} shuffled={report['dyrep_shuffle_roc_auc']:.4f}")
print(f" {'JODIE PR/shuffled ROC':<28}: pr={report['jodie_pr_auc']:.4f} shuffled={report['jodie_shuffle_roc_auc']:.4f}")
print(f" {'P(label|hit>=1)':<28}: {report.get('p_label_given_hit', float('nan')):>7.4f} [informational]")
print(f" {'P(label|hit=0)':<28}: {report.get('p_label_given_nohit', float('nan')):>7.4f} [informational]")
print(f" {'accidental_benign_motif':<28}: {report.get('accidental_benign_motif', float('nan')):>7.4f} [informational]")
print(f" {'KS mean/max':<28}: {ks_mean:>7.4f} / {ks_max:.4f}")
print(f" {'delay min/mean/max':<28}: "
f"{report.get('delay_min',float('nan')):.1f} / "
f"{report.get('delay_mean',float('nan')):.1f} / "
f"{report.get('delay_max',float('nan')):.1f}")
print("=" * 72)
if all_pass:
print(" [GATE] All thresholds met. Proceeding to full run.")
else:
msg = "[GATE] One or more thresholds FAILED."
if hard_abort:
print(f" {msg} Aborting (hard gate).")
sys.exit(1)
else:
print(f" {msg} Continuing as soft diagnostic.")
return all_pass, report
# ---------------------------------------------------------------------------
# Model factory
# ---------------------------------------------------------------------------
def build_models(device: str = "cpu") -> Dict[str, TemporalModel]:
    """Construct a new instance of every benchmarked model wrapper.

    Args:
        device: Device string forwarded to each wrapper that takes a
            ``device`` argument.

    Returns:
        Mapping from model display name to a newly constructed wrapper.
        Insertion order is OracleMotif, SeqGRU, TGN, TGAT, DyRep, JODIE,
        StaticGNN, XGBoost.
    """
    # Keyword arguments shared by all four memory-based temporal GNN wrappers.
    memory_kwargs = {"memory_dim": 64, "device": device}

    registry: Dict[str, TemporalModel] = {}
    registry["OracleMotif"] = OracleMotifWrapper()
    registry["SeqGRU"] = SequenceGRUWrapper(
        hidden_dim=64,
        receiver_buckets=256,
        device=device,
    )
    registry["TGN"] = TGNWrapper(time_dim=16, **memory_kwargs)
    registry["TGAT"] = TGATWrapper(
        time_dim=8,
        num_heads=4,
        n_neighbors=10,
        **memory_kwargs,
    )
    registry["DyRep"] = DyRepWrapper(time_dim=8, **memory_kwargs)
    registry["JODIE"] = JODIEWrapper(time_emb_dim=16, **memory_kwargs)
    registry["StaticGNN"] = StaticGNNWrapper(
        hidden_dim=64,
        n_snapshots=10,
        device=device,
    )
    registry["XGBoost"] = XGBoostWrapper(n_estimators=200, max_depth=6)
    return registry
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_seed_list(seed_string: str) -> List[int]:
    """Turn a comma-separated string of integers into a list of ints.

    Whitespace around each field is ignored and blank fields (for example
    the one produced by a trailing comma) are skipped.

    Raises:
        ValueError: If a non-blank field is not a valid integer literal.
    """
    seeds: List[int] = []
    for field in seed_string.split(","):
        cleaned = field.strip()
        if not cleaned:
            # Blank field: nothing to convert.
            continue
        seeds.append(int(cleaned))
    return seeds
def parse_args():
parser = argparse.ArgumentParser(description="Leakage-free UPI-Sim benchmark runner")
parser.add_argument("--fast", action="store_true", help="Fast mode: 1 epoch and fewer checkpoints.")
parser.add_argument("--seed", type=int, default=None, help="Run a single seed.")
parser.add_argument(
"--seeds",
nargs="+",
type=int,
default=None,
help="Space-separated seed list, e.g. --seeds 0 1 2 3 4",
)
parser.add_argument(
"--config",
type=str,
default="config/default.yaml",
help="Path to config YAML.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help='Torch device ("cpu" or "cuda").',
)
parser.add_argument(
"--benchmark-mode",
type=str,
default=None,
help='Benchmark mode override, e.g. "standard" or "temporal_twins".',
)
parser.add_argument(
"--experiments",
nargs="+",
type=str,
default=None,
help="Space-separated list of experiments to run, e.g. --experiments ood causal horizon mechanistic audit",
)
return parser.parse_args()
# Experiment names understood by main(), in the order they run for each seed.
_KNOWN_EXPERIMENTS = ("ood", "causal", "horizon", "mechanistic", "audit")


def _resolve_experiments(requested: Iterable[str] | str | None) -> set:
    """Normalise the ``--experiments`` value into a set of known experiment names.

    ``None`` selects every experiment. Otherwise each token is split on commas,
    so ``ood causal``, ``ood,causal`` and mixtures of both forms are accepted.
    Names that are not recognised are reported and dropped.
    """
    if requested is None:
        return set(_KNOWN_EXPERIMENTS)

    tokens = [requested] if isinstance(requested, str) else list(requested)
    names = set()
    for token in tokens:
        for name in token.split(","):
            name = name.strip()
            if name:
                names.add(name)

    unknown = sorted(names.difference(_KNOWN_EXPERIMENTS))
    if unknown:
        # An unrecognised name would otherwise match no experiment and be
        # skipped without any feedback to the user.
        print(f"[warn] Ignoring unknown experiment(s): {', '.join(unknown)}")
    return names.intersection(_KNOWN_EXPERIMENTS)


def main():
    """Run the requested experiments for every seed and save the results.

    For each seed this seeds the global RNG state, generates the
    easy/medium/hard datasets, runs the motif validity check when a
    temporal-twins benchmark mode is active, and then runs each selected
    experiment, tagging its result frame with the seed. The accumulated
    frames are passed to ``save_experiment_outputs``.
    """
    args = parse_args()

    # Support both space-separated (nargs=+) and comma-separated experiment lists
    experiments_to_run = _resolve_experiments(args.experiments)

    num_epochs = 1 if args.fast else 3
    node_epochs = 60 if args.fast else 150
    n_checkpoints = 4 if args.fast else 8

    # --seed takes precedence over --seeds; with neither, five default seeds run.
    if args.seed is not None:
        seeds = [args.seed]
    elif args.seeds is not None:
        # Already parsed as List[int] via nargs="+"
        seeds = args.seeds
    else:
        seeds = [0, 1, 2, 3, 4]

    config = load_config(args.config)
    benchmark_mode = args.benchmark_mode or getattr(config, "benchmark_mode", "standard")

    print("=" * 60)
    print(" UPI-Sim Multi-Model Benchmark (Leakage-Free)")
    print(f" epochs={num_epochs} node_epochs={node_epochs} checkpoints={n_checkpoints}")
    print(f" seeds={seeds} device={args.device} mode={benchmark_mode}")
    print("=" * 60)

    # One list of per-seed result frames per experiment; every key is present
    # even when that experiment was not requested.
    raw_frames: Dict[str, List[pd.DataFrame]] = {name: [] for name in _KNOWN_EXPERIMENTS}

    is_twin_mode = benchmark_mode in ("temporal_twins", "temporal_twins_oracle_calib")
    calib_mode = benchmark_mode == "temporal_twins_oracle_calib"

    for seed in seeds:
        set_global_determinism(seed)
        print(f"\n[data] Generating datasets for seed={seed}...")
        df_easy, df_medium, df_hard = generate_all(
            config,
            seed=seed,
            benchmark_mode=benchmark_mode,
        )
        print(f" Easy : {len(df_easy):,} events | fraud={df_easy['is_fraud'].mean():.3f}")
        print(f" Medium: {len(df_medium):,} events | fraud={df_medium['is_fraud'].mean():.3f}")
        print(f" Hard : {len(df_hard):,} events | fraud={df_hard['is_fraud'].mean():.3f}")

        if is_twin_mode:
            # Run validity check: hard-abort in calib mode, soft diagnostic otherwise.
            gate_df = build_gate_pool_from_frames([df_easy, df_medium, df_hard])
            run_motif_validity_check(
                df=gate_df,
                config=config,
                seed=seed,
                device=args.device,
                num_epochs=num_epochs,
                node_epochs=node_epochs,
                n_checkpoints=n_checkpoints,
                hard_abort=calib_mode,
                benchmark_mode=benchmark_mode,
                fast_mode=args.fast,
            )

        if "ood" in experiments_to_run:
            print(f"\n[seed={seed}] OOD generalisation")
            df_ood = run_ood_single(
                df_easy=df_easy,
                df_medium=df_medium,
                df_hard=df_hard,
                device=args.device,
                num_epochs=num_epochs,
                node_epochs=node_epochs,
                n_checkpoints=n_checkpoints,
            )
            df_ood["seed"] = seed
            raw_frames["ood"].append(df_ood)

        if "causal" in experiments_to_run:
            print(f"\n[seed={seed}] Causal chronology shuffle")
            df_causal = run_causal_single(
                df_hard=df_hard,
                device=args.device,
                num_epochs=num_epochs,
                node_epochs=node_epochs,
                n_checkpoints=n_checkpoints,
                seed=seed,
            )
            df_causal["seed"] = seed
            raw_frames["causal"].append(df_causal)

        if "horizon" in experiments_to_run:
            print(f"\n[seed={seed}] Horizon sweep")
            df_horizon = run_horizon_single(
                df_medium=df_medium,
                device=args.device,
                num_epochs=num_epochs,
                node_epochs=node_epochs,
                n_checkpoints=n_checkpoints,
                horizons=DEFAULT_HORIZONS,
            )
            df_horizon["seed"] = seed
            raw_frames["horizon"].append(df_horizon)

        if "mechanistic" in experiments_to_run:
            print(f"\n[seed={seed}] Mechanistic correlation")
            df_mech = run_mechanistic_single(
                df_hard=df_hard,
                device=args.device,
                num_epochs=num_epochs,
                node_epochs=node_epochs,
                n_checkpoints=n_checkpoints,
            )
            df_mech["seed"] = seed
            raw_frames["mechanistic"].append(df_mech)

        if "audit" in experiments_to_run:
            print(f"\n[seed={seed}] Temporal twins audit")
            df_audit = run_audit_single(
                df_hard=df_hard,
                device=args.device,
                num_epochs=num_epochs,
                node_epochs=node_epochs,
                n_checkpoints=n_checkpoints,
                seed=seed,
                benchmark_mode=benchmark_mode,
            )
            df_audit["seed"] = seed
            raw_frames["audit"].append(df_audit)

    save_experiment_outputs(raw_frames, results_dir="results")

    print("\n" + "=" * 60)
    print(" All requested experiments completed.")
    print(" Saved raw + summary CSVs in results/")
    print(" Run: python -m plots.plot_results")
    print("=" * 60)
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()