temporal-twins-code / models /audit_oracle.py
temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
"""
models/audit_oracle.py
======================
Two oracle baselines for motif validity checking:
AuditOracleWrapper
Reads audit columns (motif_hit_count, label_delay, etc.) directly.
Requires NO learning. In calib mode this should achieve ROC-AUC ~1.0.
If AuditOracle fails → evaluation / label-alignment is broken.
RawMotifOracleWrapper
    Subclass of OracleMotifWrapper that only overrides ``name`` so the gate can
    distinguish it. Reconstructs the motif from raw timestamps+receivers.
    If AuditOracle passes but RawMotifOracle fails → motif reconstruction is broken.
"""
from __future__ import annotations
from typing import List
import numpy as np
import pandas as pd
from models.base import TemporalModel
from models.oracle_motif import OracleMotifWrapper
# ---------------------------------------------------------------------------
# AuditOracle
# ---------------------------------------------------------------------------
class AuditOracleWrapper(TemporalModel):
    """Direct-read oracle: scores users by their stored ``motif_hit_count``.

    Allowed to read ALL oracle/audit columns and requires no training.
    In calib mode every fraud twin has ``motif_hit_count >= 1`` and every
    benign twin has ``motif_hit_count == 0``, so this oracle should be
    near-perfect. If it is not, evaluation / label alignment is suspect.
    """

    @property
    def name(self) -> str:
        """Identifier used when logging / gating this baseline."""
        return "AuditOracle"

    @property
    def is_temporal(self) -> bool:
        """Always ``False``: no memory, the oracle is a pure lookup."""
        return False

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """No-op: the oracle has nothing to learn."""

    def train_node_classifier_on_prefix(
        self,
        df_prefix: pd.DataFrame,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """No-op: the oracle has nothing to learn."""

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Score each node in ``eval_nodes`` directly from audit columns.

        Scoring, in order of preference:

        1. ``motif_hit_count`` present: each node's score is the maximum
           ``motif_hit_count`` over its rows (matched on ``sender_id``),
           divided by the largest such value among ``eval_nodes`` when that
           value is positive. Nodes with no rows, or only NaN counts, get 0.
        2. Otherwise, ``label_delay`` present: 1.0 for nodes with at least one
           row where ``label_delay >= 0`` (and ``is_fraud == 1`` when an
           ``is_fraud`` column exists), else 0.0.
        3. Otherwise every score is 0.0.

        Args:
            df_eval: Evaluation events. ``sender_id`` is required whenever one
                of the audit columns above is present.
            eval_nodes: Node ids to score, in output order.

        Returns:
            ``float32`` array of shape ``(len(eval_nodes),)``.
        """
        n_nodes = len(eval_nodes)
        scores = np.zeros(n_nodes, dtype=np.float32)
        if n_nodes == 0:
            # Nothing to score; also avoids max() on a zero-size array below.
            return scores

        if "motif_hit_count" in df_eval.columns:
            # An all-NaN sender yields a NaN max; fill it so NaN cannot leak
            # into raw.max() and from there into every score.
            grp = df_eval.groupby("sender_id")["motif_hit_count"].max().fillna(0.0)
            raw = np.array(
                [float(grp.get(n, 0.0)) for n in eval_nodes], dtype=np.float32
            )
            max_val = raw.max()
            scores = raw / max_val if max_val > 0.0 else raw
        elif "label_delay" in df_eval.columns:
            # Fallback: a user with a valid (non-negative) delay entry is a
            # fraud twin; additionally require is_fraud == 1 when available.
            positive_rows = df_eval["label_delay"] >= 0
            if "is_fraud" in df_eval.columns:
                positive_rows = positive_rows & (df_eval["is_fraud"] == 1)
            pos_nodes = set(
                df_eval.loc[positive_rows, "sender_id"].unique().tolist()
            )
            scores = np.array(
                [1.0 if n in pos_nodes else 0.0 for n in eval_nodes],
                dtype=np.float32,
            )
        return scores

    def reset_memory(self) -> None:
        """No-op: the oracle keeps no state between calls."""
# ---------------------------------------------------------------------------
# RawMotifOracle (= OracleMotifWrapper with a distinct name for the gate)
# ---------------------------------------------------------------------------
class RawMotifOracleWrapper(OracleMotifWrapper):
    """OracleMotifWrapper re-labelled for the validity-check gate.

    All scoring behaviour is inherited unchanged from OracleMotifWrapper
    (motif reconstruction from raw timestamps+receivers, no audit columns);
    only ``name`` differs, so the gate can log and gate it separately.
    """

    @property
    def name(self) -> str:
        """Gate/log identifier, distinct from the parent wrapper's name."""
        label = "RawMotifOracle"
        return label