from __future__ import annotations import copy from typing import List import numpy as np import pandas as pd import torch import torch.nn as nn from sklearn.metrics import average_precision_score, roc_auc_score from models.base import TemporalModel _BLOCKED_COLS = frozenset({ "motif_hit_count", "motif_source", "trigger_event_idx", "label_event_idx", "label_delay", "is_fallback_label", "fraud_source", "twin_role", "twin_label", "twin_pair_id", "template_id", "dynamic_fraud_state", "motif_chain_state", "motif_strength", }) def _safe_roc_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float: y_true = np.asarray(y_true, dtype=np.float32) y_prob = np.asarray(y_prob, dtype=np.float32) if len(y_true) == 0 or len(np.unique(y_true)) < 2: return 0.5 return float(roc_auc_score(y_true, y_prob)) def _safe_pr_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float: y_true = np.asarray(y_true, dtype=np.float32) y_prob = np.asarray(y_prob, dtype=np.float32) positives = float(np.sum(y_true == 1)) negatives = float(np.sum(y_true == 0)) if positives == 0.0: return 0.0 if negatives == 0.0: return 1.0 return float(average_precision_score(y_true, y_prob)) class _SeqGRU(nn.Module): def __init__( self, num_buckets: int, numeric_dim: int, emb_dim: int = 32, pos_dim: int = 16, time_dim: int = 24, hidden_dim: int = 64, max_positions: int = 256, ): super().__init__() self.receiver_emb = nn.Embedding(num_buckets + 1, emb_dim) self.position_emb = nn.Embedding(max_positions + 1, pos_dim) self.numeric_proj = nn.Sequential( nn.Linear(numeric_dim, time_dim), nn.ReLU(), nn.LayerNorm(time_dim), ) self.input_proj = nn.Sequential( nn.Linear(emb_dim + pos_dim + time_dim, hidden_dim), nn.ReLU(), ) self.gru = nn.GRU( input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=False, ) self.attn = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1), ) self.head = nn.Sequential( nn.LayerNorm(hidden_dim * 3), nn.Linear(hidden_dim * 3, hidden_dim), nn.ReLU(), nn.Dropout(0.10), nn.Linear(hidden_dim, 1), ) def forward( self, receiver_ids: torch.Tensor, numeric_feats: torch.Tensor, positions: torch.Tensor, lengths: torch.Tensor, ) -> torch.Tensor: emb = self.receiver_emb(receiver_ids) pos_emb = self.position_emb(positions) time_repr = self.numeric_proj(numeric_feats) x = torch.cat([emb, pos_emb, time_repr], dim=-1) x = self.input_proj(x) h_seq, _ = self.gru(x) batch_size, seq_len, hidden_dim = h_seq.shape mask = ( torch.arange(seq_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1) ) masked_h = h_seq.masked_fill(~mask.unsqueeze(-1), -1e9) attn_scores = self.attn(h_seq).squeeze(-1).masked_fill(~mask, -1e9) attn_weights = torch.softmax(attn_scores, dim=1) attn_pool = (h_seq * attn_weights.unsqueeze(-1)).sum(dim=1) max_hidden = masked_h.max(dim=1).values sum_hidden = (h_seq * mask.unsqueeze(-1)).sum(dim=1) mean_hidden = sum_hidden / lengths.clamp(min=1).unsqueeze(1) pooled = torch.cat([attn_pool, max_hidden, mean_hidden], dim=-1) logits = self.head(pooled).squeeze(-1) return logits class SequenceGRUWrapper(TemporalModel): def __init__( self, hidden_dim: int = 64, receiver_buckets: int = 256, max_positions: int = 256, device: str = "cpu", ): self.hidden_dim = hidden_dim self.receiver_buckets = receiver_buckets self.max_positions = max_positions self.device = torch.device(device) self._model: _SeqGRU | None = None self._constant_prob: float | None = None @property def name(self) -> str: return "SeqGRU" @property def is_temporal(self) -> bool: return True def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None: self._model = _SeqGRU( num_buckets=self.receiver_buckets, numeric_dim=6, emb_dim=32, hidden_dim=self.hidden_dim, max_positions=self.max_positions, ).to(self.device) self._constant_prob = None def _receiver_token(self, receiver_ids: np.ndarray) -> np.ndarray: receiver_ids = np.asarray(receiver_ids, dtype=np.int64) local_map: dict[int, int] = {} next_token = 1 tokens = np.zeros(len(receiver_ids), dtype=np.int64) for idx, receiver_id in enumerate(receiver_ids.tolist()): if receiver_id not in local_map: local_map[receiver_id] = min(next_token, self.receiver_buckets) next_token += 1 tokens[idx] = local_map[receiver_id] return tokens def _build_event_numeric(self, group: pd.DataFrame) -> np.ndarray: group = group.sort_values("timestamp").reset_index(drop=True) timestamps = group["timestamp"].to_numpy(dtype=np.float64) dts = np.diff(timestamps, prepend=timestamps[0]) dts = np.maximum(dts, 0.0) phase = (timestamps % 86400.0) / 86400.0 amount = group["amount"].to_numpy(dtype=np.float32) if "amount" in group.columns else np.zeros(len(group), dtype=np.float32) retry = group["is_retry"].to_numpy(dtype=np.float32) if "is_retry" in group.columns else np.zeros(len(group), dtype=np.float32) failed = group["failed"].to_numpy(dtype=np.float32) if "failed" in group.columns else np.zeros(len(group), dtype=np.float32) return np.stack( [ np.log1p(dts).astype(np.float32), np.log1p(np.maximum(amount, 0.0)).astype(np.float32), retry.astype(np.float32), failed.astype(np.float32), np.sin(2.0 * np.pi * phase).astype(np.float32), np.cos(2.0 * np.pi * phase).astype(np.float32), ], axis=1, ) def _finalize_sequence( self, receiver_ids: np.ndarray, numeric: np.ndarray, perm: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: receiver_ids = np.asarray(receiver_ids, dtype=np.int64) numeric = np.asarray(numeric, dtype=np.float32) if perm is not None and len(receiver_ids): receiver_ids = receiver_ids[perm] numeric = numeric[perm] receiver_tokens = self._receiver_token(receiver_ids) positions = np.minimum( np.arange(len(receiver_tokens), dtype=np.int64), self.max_positions, ) return receiver_tokens, numeric.astype(np.float32), positions def _pad_example_batch( self, receiver_seqs: list[np.ndarray], numeric_seqs: list[np.ndarray], position_seqs: list[np.ndarray], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: lengths = np.array([len(seq) for seq in receiver_seqs], dtype=np.int64) max_len = int(max(lengths.max() if len(lengths) else 1, 1)) recv_batch = np.zeros((len(receiver_seqs), max_len), dtype=np.int64) feat_batch = np.zeros((len(receiver_seqs), max_len, 6), dtype=np.float32) pos_batch = np.zeros((len(receiver_seqs), max_len), dtype=np.int64) for idx, (receiver_ids, numeric, positions) in enumerate(zip(receiver_seqs, numeric_seqs, position_seqs)): seq_len = len(receiver_ids) recv_batch[idx, :seq_len] = receiver_ids feat_batch[idx, :seq_len, :] = numeric pos_batch[idx, :seq_len] = positions return ( torch.tensor(recv_batch, dtype=torch.long, device=self.device), torch.tensor(feat_batch, dtype=torch.float32, device=self.device), torch.tensor(pos_batch, dtype=torch.long, device=self.device), torch.tensor(lengths, dtype=torch.long, device=self.device), ) def _build_sequences(self, df: pd.DataFrame, eval_nodes: List[int]): leaked = _BLOCKED_COLS & set(df.columns) assert not leaked, f"Oracle columns leaked into SeqGRU: {leaked}" df = df.sort_values("timestamp").reset_index(drop=True).copy() groups = {int(sender_id): group for sender_id, group in df.groupby("sender_id", sort=False)} sequences = [] lengths = [] for node_id in eval_nodes: group = groups.get(int(node_id)) if group is None or group.empty: receiver_ids = np.zeros((1,), dtype=np.int64) numeric = np.zeros((1, 6), dtype=np.float32) else: receiver_ids, numeric, _ = self._finalize_sequence( group["receiver_id"].to_numpy(dtype=np.int64), self._build_event_numeric(group), ) sequences.append((receiver_ids, numeric)) lengths.append(len(receiver_ids)) max_len = max(lengths) if lengths else 1 recv_batch = np.zeros((len(eval_nodes), max_len), dtype=np.int64) feat_batch = np.zeros((len(eval_nodes), max_len, 6), dtype=np.float32) pos_batch = np.zeros((len(eval_nodes), max_len), dtype=np.int64) for idx, (receiver_ids, numeric) in enumerate(sequences): seq_len = len(receiver_ids) recv_batch[idx, :seq_len] = receiver_ids feat_batch[idx, :seq_len, :] = numeric pos_batch[idx, :seq_len] = np.minimum( np.arange(seq_len, dtype=np.int64), self.max_positions, ) return ( torch.tensor(recv_batch, dtype=torch.long, device=self.device), torch.tensor(feat_batch, dtype=torch.float32, device=self.device), torch.tensor(pos_batch, dtype=torch.long, device=self.device), torch.tensor(lengths, dtype=torch.long, device=self.device), ) def _build_matched_example_dataset( self, df: pd.DataFrame, examples: pd.DataFrame, shuffle_within_sequence: bool = False, seed: int = 0, ) -> dict: if examples.empty: return { "receiver_seqs": [], "numeric_seqs": [], "position_seqs": [], "labels": np.zeros(0, dtype=np.float32), "pair_event_ids": np.zeros(0, dtype=np.int64), } df = df.sort_values("timestamp").reset_index(drop=True).copy() if "local_event_idx" not in df.columns: df["local_event_idx"] = df.groupby("sender_id").cumcount().astype(np.int32) groups = { int(sender_id): group.reset_index(drop=True).copy() for sender_id, group in df.groupby("sender_id", sort=False) } receiver_seqs: list[np.ndarray] = [] numeric_seqs: list[np.ndarray] = [] position_seqs: list[np.ndarray] = [] labels: list[float] = [] pair_event_ids: list[int] = [] for row in examples.itertuples(index=False): sender_id = int(row.sender_id) group = groups.get(sender_id) if group is None or group.empty: receiver_tokens = np.zeros((1,), dtype=np.int64) numeric = np.zeros((1, 6), dtype=np.float32) positions = np.zeros((1,), dtype=np.int64) else: end_idx = int(row.eval_local_event_idx) prefix = group.iloc[: end_idx + 1].copy() receiver_ids = prefix["receiver_id"].to_numpy(dtype=np.int64) numeric = self._build_event_numeric(prefix) perm = None if shuffle_within_sequence and len(receiver_ids) > 1: rng = np.random.default_rng(seed + int(row.pair_event_id) * 97 + int(row.label) * 13) perm = rng.permutation(len(receiver_ids)) receiver_tokens, numeric, positions = self._finalize_sequence( receiver_ids, numeric, perm=perm, ) receiver_seqs.append(receiver_tokens) numeric_seqs.append(numeric) position_seqs.append(positions) labels.append(float(row.label)) pair_event_ids.append(int(row.pair_event_id)) return { "receiver_seqs": receiver_seqs, "numeric_seqs": numeric_seqs, "position_seqs": position_seqs, "labels": np.asarray(labels, dtype=np.float32), "pair_event_ids": np.asarray(pair_event_ids, dtype=np.int64), } def _dataset_subset(self, dataset: dict, idx: np.ndarray) -> dict: idx_list = idx.tolist() return { "receiver_seqs": [dataset["receiver_seqs"][i] for i in idx_list], "numeric_seqs": [dataset["numeric_seqs"][i] for i in idx_list], "position_seqs": [dataset["position_seqs"][i] for i in idx_list], "labels": dataset["labels"][idx], "pair_event_ids": dataset["pair_event_ids"][idx], } def _predict_dataset(self, dataset: dict, batch_size: int = 256) -> np.ndarray: if self._constant_prob is not None: return np.full(len(dataset["labels"]), self._constant_prob, dtype=np.float32) assert self._model is not None, "Call fit() first." if len(dataset["labels"]) == 0: return np.zeros(0, dtype=np.float32) self._model.eval() preds: list[np.ndarray] = [] with torch.no_grad(): for start in range(0, len(dataset["labels"]), batch_size): end = min(len(dataset["labels"]), start + batch_size) receiver_ids, numeric_feats, positions, lengths = self._pad_example_batch( dataset["receiver_seqs"][start:end], dataset["numeric_seqs"][start:end], dataset["position_seqs"][start:end], ) logits = self._model(receiver_ids, numeric_feats, positions, lengths) preds.append(torch.sigmoid(logits).cpu().numpy().astype(np.float32)) return np.concatenate(preds, axis=0) def fit_matched_prefix_examples( self, df_train: pd.DataFrame, train_examples: pd.DataFrame, seed: int = 0, max_epochs: int = 32, patience: int = 6, valid_frac: float = 0.20, pair_batch_size: int = 64, learning_rate: float = 2e-3, weight_decay: float = 1e-4, shuffle_within_sequence: bool = False, ) -> dict: assert self._model is not None, "Call fit() first." dataset = self._build_matched_example_dataset( df_train, train_examples, shuffle_within_sequence=shuffle_within_sequence, seed=seed, ) y = dataset["labels"] if len(y) == 0 or len(np.unique(y)) < 2: self._constant_prob = float(y.mean()) if len(y) else 0.0 return { "best_epoch": 0, "best_valid_roc_auc": float("nan"), "best_valid_pr_auc": float("nan"), "train_examples": int(len(y)), "valid_examples": 0, } pair_ids = np.unique(dataset["pair_event_ids"]) rng = np.random.default_rng(seed) shuffled_pair_ids = rng.permutation(pair_ids) valid_pairs = int(max(1, round(len(shuffled_pair_ids) * valid_frac))) if len(shuffled_pair_ids) >= 5 else 0 if valid_pairs >= len(shuffled_pair_ids): valid_pairs = max(1, len(shuffled_pair_ids) - 1) valid_pair_ids = set(shuffled_pair_ids[:valid_pairs].tolist()) if valid_pairs > 0 else set() valid_mask = np.isin(dataset["pair_event_ids"], list(valid_pair_ids)) if valid_pair_ids else np.zeros(len(y), dtype=bool) train_mask = ~valid_mask train_idx = np.flatnonzero(train_mask) valid_idx = np.flatnonzero(valid_mask) if len(train_idx) == 0: train_idx = np.arange(len(y)) valid_idx = np.zeros(0, dtype=np.int64) train_dataset = self._dataset_subset(dataset, train_idx) valid_dataset = self._dataset_subset(dataset, valid_idx) if len(valid_idx) else None train_pair_order = np.unique(train_dataset["pair_event_ids"]) pair_to_indices: dict[int, list[int]] = {} for idx, pair_event_id in enumerate(train_dataset["pair_event_ids"].tolist()): pair_to_indices.setdefault(int(pair_event_id), []).append(idx) optimizer = torch.optim.AdamW( self._model.parameters(), lr=learning_rate, weight_decay=weight_decay, ) loss_fn = nn.BCEWithLogitsLoss() best_state = copy.deepcopy(self._model.state_dict()) best_epoch = 0 best_valid_roc = -np.inf best_valid_pr = float("nan") stale_epochs = 0 n_epochs = max(12, max_epochs) for epoch in range(n_epochs): self._model.train() epoch_pair_ids = rng.permutation(train_pair_order) for start in range(0, len(epoch_pair_ids), pair_batch_size): batch_pair_ids = epoch_pair_ids[start : start + pair_batch_size] batch_indices: list[int] = [] for pair_event_id in batch_pair_ids.tolist(): batch_indices.extend(pair_to_indices[int(pair_event_id)]) receiver_ids, numeric_feats, positions, lengths = self._pad_example_batch( [train_dataset["receiver_seqs"][i] for i in batch_indices], [train_dataset["numeric_seqs"][i] for i in batch_indices], [train_dataset["position_seqs"][i] for i in batch_indices], ) labels = torch.tensor( train_dataset["labels"][batch_indices], dtype=torch.float32, device=self.device, ) logits = self._model(receiver_ids, numeric_feats, positions, lengths) loss = loss_fn(logits, labels) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self._model.parameters(), 1.0) optimizer.step() if valid_dataset is None or len(valid_dataset["labels"]) == 0: best_state = copy.deepcopy(self._model.state_dict()) best_epoch = epoch + 1 continue valid_probs = self._predict_dataset(valid_dataset) valid_roc = _safe_roc_auc(valid_dataset["labels"], valid_probs) valid_pr = _safe_pr_auc(valid_dataset["labels"], valid_probs) if valid_roc > best_valid_roc + 1e-4: best_valid_roc = valid_roc best_valid_pr = valid_pr best_state = copy.deepcopy(self._model.state_dict()) best_epoch = epoch + 1 stale_epochs = 0 else: stale_epochs += 1 if stale_epochs >= patience: break self._model.load_state_dict(best_state) self._model.eval() self._constant_prob = None return { "best_epoch": int(best_epoch), "best_valid_roc_auc": float(best_valid_roc) if best_valid_roc > -np.inf else float("nan"), "best_valid_pr_auc": float(best_valid_pr), "train_examples": int(len(train_dataset["labels"])), "valid_examples": int(len(valid_dataset["labels"])) if valid_dataset is not None else 0, } def predict_matched_prefix_examples( self, df_eval: pd.DataFrame, examples: pd.DataFrame, seed: int = 0, shuffle_within_sequence: bool = False, batch_size: int = 256, ) -> np.ndarray: dataset = self._build_matched_example_dataset( df_eval, examples, shuffle_within_sequence=shuffle_within_sequence, seed=seed, ) return self._predict_dataset(dataset, batch_size=batch_size) def train_node_classifier_on_prefix( self, df_prefix: pd.DataFrame, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150, ) -> None: assert self._model is not None, "Call fit() first." y = np.asarray(y_labels, dtype=np.float32) if len(y) == 0 or len(np.unique(y)) < 2: self._constant_prob = float(y.mean()) if len(y) else 0.0 return receiver_ids, numeric_feats, positions, lengths = self._build_sequences(df_prefix, eval_nodes) y_t = torch.tensor(y, dtype=torch.float32, device=self.device) pos_weight = torch.clamp((y_t == 0).sum() / ((y_t == 1).sum() + 1e-6), max=10.0) loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-3) n_epochs = max(24, min(64, max(1, num_epochs // 2))) self._model.train() for _ in range(n_epochs): logits = self._model(receiver_ids, numeric_feats, positions, lengths) loss = loss_fn(logits, y_t) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self._model.parameters(), 1.0) optimizer.step() self._constant_prob = None self._model.eval() def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray: if self._constant_prob is not None: return np.full(len(eval_nodes), self._constant_prob, dtype=np.float32) assert self._model is not None, "Call fit() first." receiver_ids, numeric_feats, positions, lengths = self._build_sequences(df_eval, eval_nodes) self._model.eval() with torch.no_grad(): logits = self._model(receiver_ids, numeric_feats, positions, lengths) probs = torch.sigmoid(logits).cpu().numpy() return probs.astype(np.float32) def reset_memory(self) -> None: pass