""" models/audit_oracle.py ====================== Two oracle baselines for motif validity checking: AuditOracleWrapper Reads audit columns (motif_hit_count, label_delay, etc.) directly. Requires NO learning. In calib mode this should achieve ROC-AUC ~1.0. If AuditOracle fails → evaluation / label-alignment is broken. RawMotifOracleWrapper Alias of OracleMotifWrapper with an explicit name so the gate can distinguish it. Reconstructs the motif from raw timestamps+receivers. If AuditOracle passes but RawMotifOracle fails → motif reconstruction broken. """ from __future__ import annotations from typing import List import numpy as np import pandas as pd from models.base import TemporalModel from models.oracle_motif import OracleMotifWrapper # --------------------------------------------------------------------------- # AuditOracle # --------------------------------------------------------------------------- class AuditOracleWrapper(TemporalModel): """Direct-read oracle: scores users by their stored motif_hit_count. Allowed to read ALL oracle/audit columns. Requires no training. In calib mode every fraud twin has motif_hit_count >= 1 and every benign twin has motif_hit_count == 0, so this oracle should be near-perfect. """ @property def name(self) -> str: return "AuditOracle" @property def is_temporal(self) -> bool: return False # no memory; pure lookup def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None: pass # no training needed def train_node_classifier_on_prefix( self, df_prefix: pd.DataFrame, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150, ) -> None: pass # no training needed def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray: """Score = normalised motif_hit_count per user. Falls back to label_delay-based score if motif_hit_count is absent. """ scores = np.zeros(len(eval_nodes), dtype=np.float32) if "motif_hit_count" in df_eval.columns: grp = df_eval.groupby("sender_id")["motif_hit_count"].max() raw = np.array([float(grp.get(n, 0.0)) for n in eval_nodes], dtype=np.float32) max_val = raw.max() scores = raw / max_val if max_val > 0.0 else raw elif "label_delay" in df_eval.columns: # Fallback: any user with a valid delay entry is a fraud twin pos_nodes = set( df_eval.loc[ (df_eval["is_fraud"] == 1) & (df_eval["label_delay"] >= 0), "sender_id", ].unique().tolist() ) scores = np.array( [1.0 if n in pos_nodes else 0.0 for n in eval_nodes], dtype=np.float32, ) return scores def reset_memory(self) -> None: pass # --------------------------------------------------------------------------- # RawMotifOracle (= OracleMotifWrapper with a distinct name for the gate) # --------------------------------------------------------------------------- class RawMotifOracleWrapper(OracleMotifWrapper): """Reconstructs motif from raw timestamps+receivers (no audit columns). This is identical to OracleMotifWrapper but carries a distinct .name so the validity-check gate can log and gate it separately. """ @property def name(self) -> str: return "RawMotifOracle"