# fraud-detection-system / ae_model.py
# Uploaded by: rajvivan
# Commit 408a9b2 (verified): "Complete fraud detection system: code, figures, models, paper"
# Original file size: 1.4 kB
"""Shared autoencoder wrapper class for pickle compatibility."""
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
class Autoencoder(nn.Module):
    """Symmetric fully connected autoencoder.

    Layer widths: input_dim -> 64 -> 32 -> 16 (bottleneck) -> 32 -> 64 -> input_dim,
    with ReLU activations and a 0.2 dropout after the first layer of each half.
    The decoder's final Linear has no activation.

    The attribute names ``encoder``/``decoder`` and the order of layers inside
    each ``nn.Sequential`` define the state_dict keys, and the class name is what
    pickles reference, so none of these should change.
    """

    def __init__(self, input_dim):
        super().__init__()
        encoder_layers = [
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the 16-wide bottleneck and decode it back to input width."""
        latent = self.encoder(x)
        return self.decoder(latent)
class AutoencoderWrapper:
    """Wrapper to make autoencoder compatible with sklearn interface.

    Exposes ``classes_``, ``predict_proba`` and ``predict`` so a reconstruction
    model can be used where a binary sklearn classifier is expected. The class-1
    probability increases monotonically with per-row reconstruction error.

    Note: scores are *batch-relative*. Each call centres the errors on the
    median error of the rows passed in that call, so a row's score depends on
    what it is batched with, and a single-row call always scores exactly 0.5.
    """

    def __init__(self, model):
        # model: a torch nn.Module whose output has the same shape as its input
        # (e.g. an autoencoder); only ``model`` and ``classes_`` are stored, so
        # previously pickled wrappers remain compatible.
        self.model = model
        self.classes_ = np.array([0, 1])

    def predict_proba(self, X):
        """Return an (n_rows, 2) array of [P(class 0), P(class 1)] per row.

        Args:
            X: 2-D array-like or pandas DataFrame of numeric features.

        Returns:
            ndarray of shape (n_rows, 2); each row sums to 1. Empty input yields
            an array of shape (0, 2).
        """
        Xn = X.values if isinstance(X, pd.DataFrame) else X

        # Build the input on the model's own device/dtype so GPU or float64
        # models work; fall back to CPU float32 for a parameter-less model.
        param = next(self.model.parameters(), None)
        dtype = param.dtype if param is not None else torch.float32
        device = param.device if param is not None else torch.device("cpu")

        # Evaluate with dropout disabled, but restore the caller's mode so a
        # call made mid-training does not silently leave the model in eval().
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                Xt = torch.as_tensor(np.asarray(Xn), dtype=dtype, device=device)
                out = self.model(Xt)
                re = torch.mean((out - Xt) ** 2, dim=1).cpu().numpy()
        finally:
            self.model.train(was_training)

        if re.size == 0:
            # np.median of an empty array warns and returns nan; nothing to score.
            return np.empty((0, 2), dtype=re.dtype)

        # Logistic squashing of the median-centred error with slope 10.
        # Computed via exp(-|z|) so exp never overflows for rows far below the
        # median; z == 0 still maps to exactly 0.5, as before.
        z = 10 * (re - np.median(re))
        e = np.exp(-np.abs(z))
        scores = np.where(z >= 0, 1 / (1 + e), e / (1 + e))
        return np.column_stack([1 - scores, scores])

    def predict(self, X, threshold=0.5):
        """Return int labels: 1 where P(class 1) >= ``threshold``, else 0."""
        return (self.predict_proba(X)[:, 1] >= threshold).astype(int)