# temporal-twins-code / models / oracle_motif.py
# Anonymous "Temporal Twins" code release (commit a3682cf, verified).
from __future__ import annotations
from typing import List
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from models.base import TemporalModel
from src.fraud.fraud_engine import temporal_twin_motif_trace
def _motif_features_for_user(user_df: pd.DataFrame) -> dict:
user_df = user_df.sort_values("timestamp").reset_index(drop=True)
n = len(user_df)
if n == 0:
return {
"chain_last": 0.0,
"chain_max": 0.0,
"motif_last": 0.0,
"motif_mean_last8": 0.0,
"source_count": 0.0,
"source_recent8": 0.0,
"source_recent16": 0.0,
"source_recent24": 0.0,
"last_source_age": 999.0,
"quiet_sum": 0.0,
"accel_sum": 0.0,
"revisit_sum": 0.0,
"burst_release_burst": 0.0,
"revisit_recent8": 0.0,
"brb_recent8": 0.0,
"txn_count": 0.0,
}
timestamps = user_df["timestamp"].to_numpy(dtype=np.float64)
receivers = user_df["receiver_id"].to_numpy(dtype=np.int64)
trace = temporal_twin_motif_trace(timestamps, receivers)
chain_vals = trace["chain"].tolist()
motif_vals = trace["motif_strength"].tolist()
source_positions = np.flatnonzero(trace["source"]).tolist()
last8 = motif_vals[-8:] if motif_vals else [0.0]
recent8_cutoff = max(0, n - 8)
recent16_cutoff = max(0, n - 16)
recent24_cutoff = max(0, n - 24)
last_source_age = float(n - 1 - source_positions[-1]) if source_positions else float(n + 1)
return {
"chain_last": float(chain_vals[-1]) if chain_vals else 0.0,
"chain_max": float(max(chain_vals)) if chain_vals else 0.0,
"motif_last": float(motif_vals[-1]) if motif_vals else 0.0,
"motif_mean_last8": float(np.mean(last8)),
"source_count": float(len(source_positions)),
"source_recent8": float(sum(pos >= recent8_cutoff for pos in source_positions)),
"source_recent16": float(sum(pos >= recent16_cutoff for pos in source_positions)),
"source_recent24": float(sum(pos >= recent24_cutoff for pos in source_positions)),
"last_source_age": last_source_age,
"quiet_sum": float(np.sum(trace["quiet"])),
"accel_sum": float(np.sum(trace["accel"])),
"revisit_sum": float(np.sum(trace["revisit"])),
"burst_release_burst": float(np.sum(trace["burst_release_burst"])),
"revisit_recent8": float(np.sum(trace["revisit"][recent8_cutoff:])),
"brb_recent8": float(np.sum(trace["burst_release_burst"][recent8_cutoff:])),
"txn_count": float(n),
}
class OracleMotifWrapper(TemporalModel):
    """Logistic-regression classifier over oracle temporal-twin motif features.

    Per-sender features are produced by ``_motif_features_for_user``; a
    balanced ``LogisticRegression`` is trained on each evaluation prefix via
    ``train_node_classifier_on_prefix``. When the prefix labels are
    degenerate (empty, single-class, or yielding zero feature columns), the
    model falls back to predicting a constant probability.
    """

    def __init__(self):
        self._model: LogisticRegression | None = None
        # Fallback probability used when no classifier could be fit.
        self._constant_prob: float | None = None
        # Feature column order captured at training time; predict() aligns to it.
        self._feature_cols: list[str] = []
        # Per-feature standardization statistics from the training matrix.
        self._mean: np.ndarray | None = None
        self._std: np.ndarray | None = None

    @property
    def name(self) -> str:
        return "OracleMotif"

    @property
    def is_temporal(self) -> bool:
        return True

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Reset all learned state.

        Actual training happens per-prefix in
        ``train_node_classifier_on_prefix``; ``df_train`` and ``num_epochs``
        are intentionally ignored here.
        """
        self._model = None
        self._constant_prob = None
        self._feature_cols = []
        self._mean = None
        self._std = None

    @staticmethod
    def _extract_features(df: pd.DataFrame) -> pd.DataFrame:
        """Return one motif-feature row per ``sender_id``, indexed by sender."""
        rows = []
        for sender_id, group in df.groupby("sender_id", sort=False):
            feats = _motif_features_for_user(group)
            feats["sender_id"] = int(sender_id)
            rows.append(feats)
        if not rows:
            # Keep the sender_id index contract even with no data.
            return pd.DataFrame(columns=["sender_id"])
        return pd.DataFrame(rows).set_index("sender_id").sort_index()

    def train_node_classifier_on_prefix(
        self,
        df_prefix: pd.DataFrame,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Fit a standardized logistic regression on prefix features.

        Falls back to a constant-probability predictor when labels are empty,
        single-class, or when the prefix produced no feature columns (e.g. an
        empty ``df_prefix``) — LogisticRegression cannot fit a zero-width
        design matrix.
        """
        X = self._extract_features(df_prefix).reindex(eval_nodes).fillna(0.0)
        y = np.asarray(y_labels, dtype=np.int64)
        self._feature_cols = list(X.columns)
        if len(y) == 0 or len(np.unique(y)) < 2 or X.shape[1] == 0:
            self._model = None
            self._constant_prob = float(y.mean()) if len(y) else 0.0
            return
        x_train = X.to_numpy(dtype=np.float32)
        self._mean = x_train.mean(axis=0, keepdims=True)
        # Epsilon guards against division by zero for constant features.
        self._std = x_train.std(axis=0, keepdims=True) + 1e-6
        x_train = (x_train - self._mean) / self._std
        self._model = LogisticRegression(
            max_iter=2000,
            class_weight="balanced",
            solver="liblinear",
            random_state=42,
        )
        self._model.fit(x_train, y)
        self._constant_prob = None

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Return P(positive) per eval node, aligned to ``eval_nodes`` order."""
        X = self._extract_features(df_eval).reindex(eval_nodes).fillna(0.0)
        if self._constant_prob is not None:
            return np.full(len(eval_nodes), self._constant_prob, dtype=np.float32)
        assert self._model is not None and self._mean is not None and self._std is not None
        # Align eval columns to the training-time order; columns missing from
        # an empty/partial eval frame are zero-filled so the matrix width
        # always matches what the classifier was fit on.
        X = X.reindex(columns=self._feature_cols, fill_value=0.0)
        x_eval = (X.to_numpy(dtype=np.float32) - self._mean) / self._std
        probs = self._model.predict_proba(x_eval)[:, 1]
        return probs.astype(np.float32)

    def reset_memory(self) -> None:
        """No recurrent state to clear for this model."""
        pass