| """ |
| models/audit_oracle.py |
| ====================== |
| Two oracle baselines for motif validity checking: |
| |
| AuditOracleWrapper |
| Reads audit columns (motif_hit_count, label_delay, etc.) directly. |
| Requires NO learning. In calib mode this should achieve ROC-AUC ~1.0. |
| If AuditOracle fails → evaluation / label-alignment is broken. |
| |
| RawMotifOracleWrapper |
| Alias of OracleMotifWrapper with an explicit name so the gate can |
| distinguish it. Reconstructs the motif from raw timestamps+receivers. |
| If AuditOracle passes but RawMotifOracle fails → motif reconstruction broken. |
| """ |
| from __future__ import annotations |
|
|
| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from models.base import TemporalModel |
| from models.oracle_motif import OracleMotifWrapper |
|
|
|
|
| |
| |
| |
|
|
class AuditOracleWrapper(TemporalModel):
    """Direct-read oracle: scores users by their stored ``motif_hit_count``.

    Allowed to read ALL oracle/audit columns and requires no training; the
    ``fit`` / ``train_node_classifier_on_prefix`` / ``reset_memory`` hooks are
    deliberate no-ops. Per the design note for this baseline, in calib mode
    every fraud twin has ``motif_hit_count >= 1`` and every benign twin has
    ``motif_hit_count == 0``, so this oracle should be near-perfect there.
    """

    @property
    def name(self) -> str:
        """Return the identifier string of this oracle."""
        return "AuditOracle"

    @property
    def is_temporal(self) -> bool:
        """Return ``False``: this oracle reports itself as non-temporal."""
        return False

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Do nothing; the oracle has no parameters to learn."""

    def train_node_classifier_on_prefix(
        self,
        df_prefix: pd.DataFrame,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Do nothing; the oracle has no classifier head to train."""

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Score each node in ``eval_nodes`` from the audit columns of ``df_eval``.

        Primary path: the score is the per-``sender_id`` maximum of
        ``motif_hit_count``, divided by the largest value among
        ``eval_nodes`` (left unscaled when that largest value is not
        positive). Nodes absent from ``df_eval`` and nodes whose hit count is
        missing (NaN) score 0.

        Fallback path (no ``motif_hit_count`` column): a node scores 1.0 when
        it has at least one row with ``is_fraud == 1`` and
        ``label_delay >= 0``, else 0.0. This path needs both the
        ``label_delay`` and ``is_fraud`` columns.

        Returns:
            A ``float32`` array aligned with ``eval_nodes``. All zeros when
            no usable audit column is present; empty when ``eval_nodes`` is
            empty.
        """
        n_nodes = len(eval_nodes)
        if n_nodes == 0:
            # Reducing a zero-size array with .max() raises ValueError.
            return np.zeros(0, dtype=np.float32)

        columns = df_eval.columns

        if "motif_hit_count" in columns:
            hits_per_user = df_eval.groupby("sender_id")["motif_hit_count"].max()
            # Label-based alignment; users with no rows get a hit count of 0.
            raw = hits_per_user.reindex(eval_nodes, fill_value=0.0).to_numpy(
                dtype=np.float32
            )
            # A user whose hit counts are all NaN would otherwise poison
            # raw.max() and leak NaN scores into downstream metrics.
            raw = np.where(np.isnan(raw), np.float32(0.0), raw).astype(
                np.float32, copy=False
            )
            peak = raw.max()
            return raw / peak if peak > 0.0 else raw

        if "label_delay" in columns and "is_fraud" in columns:
            observed_fraud = (df_eval["is_fraud"] == 1) & (df_eval["label_delay"] >= 0)
            pos_nodes = set(df_eval.loc[observed_fraud, "sender_id"].unique().tolist())
            return np.array(
                [1.0 if node in pos_nodes else 0.0 for node in eval_nodes],
                dtype=np.float32,
            )

        # No usable audit column: every node gets a neutral zero score.
        return np.zeros(n_nodes, dtype=np.float32)

    def reset_memory(self) -> None:
        """Do nothing; the oracle keeps no state between calls."""
|
|
|
|
| |
| |
| |
|
|
class RawMotifOracleWrapper(OracleMotifWrapper):
    """Separately named variant of ``OracleMotifWrapper``.

    Inherits everything from ``OracleMotifWrapper`` and overrides only
    ``name``, so that the validity-check gate can log and gate this oracle
    on its own. It is intended to reconstruct the motif from raw
    timestamps + receivers, without audit columns.
    """

    @property
    def name(self) -> str:
        """Return the identifier string of this oracle."""
        return "RawMotifOracle"
|
|