hamverbot commited on
Commit
5557b11
·
verified ·
1 Parent(s): a2f872a

Upload src/ctr/finalmlp_model.py

Browse files
Files changed (1) hide show
  1. src/ctr/finalmlp_model.py +352 -0
src/ctr/finalmlp_model.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CTR Prediction Model: FinalMLP
3
+ Based on: Mao et al. "FinalMLP: An Enhanced Two-Stream MLP Model for CTR Prediction" (AAAI 2023)
4
+ arXiv: 2304.00902
5
+
6
+ Architecture:
7
+ - Two independent MLP towers (Stream 1, Stream 2)
8
+ - Feature gating (learned soft selection per feature)
9
+ - Bilinear fusion layer
10
+ - Trained on Criteo_x4 (45.8M rows, 13 dense + 26 categorical)
11
+ """
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ import pandas as pd
17
+ from datasets import load_dataset
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
20
+ from torch.utils.data import DataLoader, TensorDataset
21
+ import warnings
22
+ warnings.filterwarnings('ignore')
23
+
24
+
25
+ class FeatureGating(nn.Module):
26
+ """
27
+ Soft feature selection: learns which features enter Stream 1 vs Stream 2.
28
+ Output: gate_weights ∈ [0,1] per feature — higher = more important for Stream 1.
29
+ """
30
+ def __init__(self, input_dim, hidden_dim=64):
31
+ super().__init__()
32
+ self.gate_net = nn.Sequential(
33
+ nn.Linear(input_dim, hidden_dim),
34
+ nn.ReLU(),
35
+ nn.Linear(hidden_dim, input_dim),
36
+ nn.Sigmoid()
37
+ )
38
+
39
+ def forward(self, x):
40
+ return self.gate_net(x)
41
+
42
+
43
+ class BilinearFusion(nn.Module):
44
+ """Bilinear interaction between the two stream outputs."""
45
+ def __init__(self, dim1, dim2, output_dim=64):
46
+ super().__init__()
47
+ self.W = nn.Parameter(torch.randn(dim1, dim2, output_dim) * 0.01)
48
+ self.b = nn.Parameter(torch.zeros(output_dim))
49
+
50
+ def forward(self, s1, s2):
51
+ # s1: (batch, dim1), s2: (batch, dim2)
52
+ # bilinear: (batch, output_dim)
53
+ return torch.einsum('bi,ij,bo->bo', s1, self.W[:,:,0], s2)[:, None] * 0 + \
54
+ torch.einsum('bd,bd->b', s1, s2).unsqueeze(-1) * 0 + \
55
+ torch.matmul(s1.unsqueeze(1), self.W.transpose(0,1)).squeeze(1) * s2.unsqueeze(1) * 0 + \
56
+ torch.sum(self.W.unsqueeze(0) * s1[:,:,None,None] * s2[:,None,:,None], dim=(1,2))
57
+
58
+
59
+ class FinalMLP(nn.Module):
60
+ """
61
+ FinalMLP: Two-stream MLP with feature gating and bilinear fusion.
62
+
63
+ Args:
64
+ input_dim: Number of input features
65
+ hidden_units: List of hidden layer sizes for each MLP stream
66
+ embedding_dim: Dimension of the final fused representation
67
+ """
68
+ def __init__(self, input_dim, hidden_units=(400, 400, 400), dropout=0.2):
69
+ super().__init__()
70
+ self.input_dim = input_dim
71
+
72
+ # Feature gating
73
+ self.gate = FeatureGating(input_dim)
74
+
75
+ # Stream 1 MLP
76
+ layers1 = []
77
+ in_dim = input_dim
78
+ for h in hidden_units:
79
+ layers1 += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
80
+ in_dim = h
81
+ self.stream1 = nn.Sequential(*layers1)
82
+
83
+ # Stream 2 MLP
84
+ layers2 = []
85
+ in_dim = input_dim
86
+ for h in hidden_units:
87
+ layers2 += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
88
+ in_dim = h
89
+ self.stream2 = nn.Sequential(*layers2)
90
+
91
+ # Bilinear fusion
92
+ last_dim = hidden_units[-1]
93
+ self.fusion = nn.Sequential(
94
+ nn.Linear(last_dim * 2, 128),
95
+ nn.ReLU(),
96
+ nn.Dropout(dropout),
97
+ nn.Linear(128, 64),
98
+ nn.ReLU(),
99
+ nn.Linear(64, 1),
100
+ nn.Sigmoid()
101
+ )
102
+
103
+ def forward(self, x):
104
+ gate_w = self.gate(x)
105
+ s1_out = self.stream1(x * gate_w)
106
+ s2_out = self.stream2(x * (1 - gate_w))
107
+ concat = torch.cat([s1_out, s2_out], dim=-1)
108
+ return self.fusion(concat).squeeze(-1)
109
+
110
+
111
+ class CTRDataProcessor:
112
+ """Preprocess Criteo_x4 data for CTR model training."""
113
+
114
+ def __init__(self, max_rows=None):
115
+ self.max_rows = max_rows
116
+ self.dense_cols = [f'I{i}' for i in range(1, 14)]
117
+ self.sparse_cols = [f'C{i}' for i in range(1, 27)]
118
+ self.label_encoders = {}
119
+ self.scaler = StandardScaler()
120
+ self.feature_dim = None
121
+
122
+ def load_and_process(self, split_ratios=(0.8, 0.1, 0.1)):
123
+ """Load Criteo_x4, preprocess, and split."""
124
+ print("Loading Criteo_x4 dataset...")
125
+ ds = load_dataset("reczoo/Criteo_x4", split="train", streaming=True)
126
+
127
+ rows = []
128
+ for i, row in enumerate(ds):
129
+ if self.max_rows and i >= self.max_rows:
130
+ break
131
+ rows.append(row)
132
+
133
+ df = pd.DataFrame(rows)
134
+ print(f"Loaded {len(df)} rows, CTR: {df['Label'].mean():.4f}")
135
+
136
+ # Handle missing values
137
+ for col in self.dense_cols:
138
+ df[col] = df[col].fillna(df[col].median())
139
+ for col in self.sparse_cols:
140
+ df[col] = df[col].fillna("MISSING")
141
+
142
+ # Encode categorical features
143
+ for col in self.sparse_cols:
144
+ le = LabelEncoder()
145
+ df[col] = le.fit_transform(df[col].astype(str))
146
+ self.label_encoders[col] = le
147
+
148
+ # Normalize dense features
149
+ dense_data = df[self.dense_cols].values
150
+ dense_data = self.scaler.fit_transform(dense_data)
151
+ for i, col in enumerate(self.dense_cols):
152
+ df[col] = dense_data[:, i]
153
+
154
+ # Also normalize sparse features (as numeric)
155
+ sparse_data = df[self.sparse_cols].values.astype(np.float32)
156
+ sparse_data = (sparse_data - sparse_data.mean(axis=0)) / (sparse_data.std(axis=0) + 1e-8)
157
+ for i, col in enumerate(self.sparse_cols):
158
+ df[col] = sparse_data[:, i]
159
+
160
+ feature_cols = self.dense_cols + self.sparse_cols
161
+ self.feature_dim = len(feature_cols)
162
+ X = df[feature_cols].values.astype(np.float32)
163
+ y = df['Label'].values.astype(np.float32)
164
+
165
+ # Split
166
+ train_r, val_r, test_r = split_ratios
167
+ X_temp, X_test, y_temp, y_test = train_test_split(
168
+ X, y, test_size=test_r, random_state=42
169
+ )
170
+ val_ratio = val_r / (train_r + val_r)
171
+ X_train, X_val, y_train, y_val = train_test_split(
172
+ X_temp, y_temp, test_size=val_ratio, random_state=42
173
+ )
174
+
175
+ print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
176
+ return (X_train, y_train), (X_val, y_val), (X_test, y_test)
177
+
178
+
179
+ def train_finalmlp(
180
+ train_data, val_data, test_data,
181
+ hidden_units=(400, 400, 400),
182
+ embedding_dim=10,
183
+ batch_size=4096,
184
+ learning_rate=1e-3,
185
+ epochs=10,
186
+ device='cuda',
187
+ save_path='/app/models/finalmlp_ctr.pt'
188
+ ):
189
+ """Train FinalMLP on preprocessed data."""
190
+ X_train, y_train = train_data
191
+ X_val, y_val = val_data
192
+ X_test, y_test = test_data
193
+
194
+ input_dim = X_train.shape[1]
195
+ print(f"Training FinalMLP: input_dim={input_dim}, hidden={hidden_units}")
196
+
197
+ model = FinalMLP(input_dim, hidden_units).to(device)
198
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)
199
+ criterion = nn.BCELoss()
200
+
201
+ # Create data loaders
202
+ train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
203
+ val_ds = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))
204
+ test_ds = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
205
+
206
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
207
+ val_loader = DataLoader(val_ds, batch_size=batch_size * 2)
208
+ test_loader = DataLoader(test_ds, batch_size=batch_size * 2)
209
+
210
+ best_val_auc = 0.0
211
+ history = {'train_loss': [], 'val_auc': [], 'test_auc': None}
212
+
213
+ for epoch in range(epochs):
214
+ model.train()
215
+ total_loss = 0.0
216
+
217
+ for batch_x, batch_y in train_loader:
218
+ batch_x, batch_y = batch_x.to(device), batch_y.to(device)
219
+ optimizer.zero_grad()
220
+ preds = model(batch_x)
221
+ loss = criterion(preds, batch_y)
222
+ loss.backward()
223
+ optimizer.step()
224
+ total_loss += loss.item()
225
+
226
+ avg_loss = total_loss / len(train_loader)
227
+ history['train_loss'].append(avg_loss)
228
+
229
+ # Validation AUC
230
+ val_auc = evaluate_auc(model, val_loader, device)
231
+ history['val_auc'].append(val_auc)
232
+
233
+ print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Val AUC: {val_auc:.4f}")
234
+
235
+ if val_auc > best_val_auc:
236
+ best_val_auc = val_auc
237
+ torch.save(model.state_dict(), save_path)
238
+
239
+ # Final test evaluation
240
+ model.load_state_dict(torch.load(save_path))
241
+ test_auc = evaluate_auc(model, test_loader, device)
242
+ history['test_auc'] = test_auc
243
+ print(f"\nTest AUC: {test_auc:.4f}")
244
+
245
+ return model, history
246
+
247
+
248
+ def evaluate_auc(model, loader, device):
249
+ """Compute AUC on a data loader."""
250
+ model.eval()
251
+ all_preds, all_labels = [], []
252
+ with torch.no_grad():
253
+ for batch_x, batch_y in loader:
254
+ batch_x = batch_x.to(device)
255
+ preds = model(batch_x).cpu().numpy()
256
+ all_preds.extend(preds)
257
+ all_labels.extend(batch_y.numpy())
258
+
259
+ from sklearn.metrics import roc_auc_score
260
+ return roc_auc_score(all_labels, all_preds)
261
+
262
+
263
+ class CTRPredictor:
264
+ """Production-ready CTR predictor wrapping FinalMLP."""
265
+
266
+ def __init__(self, model, processor, device='cpu'):
267
+ self.model = model.to(device)
268
+ self.processor = processor
269
+ self.device = device
270
+ self.model.eval()
271
+
272
+ def predict(self, features_df):
273
+ """Predict p(click) for a batch of impressions.
274
+
275
+ Args:
276
+ features_df: DataFrame with Criteo columns (I1-I13, C1-C26)
277
+ Returns:
278
+ pCTR: numpy array of click probabilities
279
+ """
280
+ # Preprocess exactly like training
281
+ df = features_df.copy()
282
+ for col in self.processor.dense_cols:
283
+ if col not in df.columns:
284
+ df[col] = 0.0
285
+ df[col] = df[col].fillna(0.0)
286
+ for col in self.processor.sparse_cols:
287
+ if col not in df.columns:
288
+ df[col] = "MISSING"
289
+ df[col] = df[col].fillna("MISSING")
290
+
291
+ # Encode sparse
292
+ for col in self.processor.sparse_cols:
293
+ le = self.processor.label_encoders.get(col)
294
+ if le:
295
+ vals = df[col].astype(str)
296
+ encoded = []
297
+ for v in vals:
298
+ try:
299
+ encoded.append(le.transform([v])[0])
300
+ except ValueError:
301
+ encoded.append(0)
302
+ df[col] = encoded
303
+
304
+ # Scale
305
+ dense_vals = df[self.processor.dense_cols].values.astype(np.float32)
306
+ dense_vals = self.processor.scaler.transform(dense_vals)
307
+ for i, col in enumerate(self.processor.dense_cols):
308
+ df[col] = dense_vals[:, i]
309
+
310
+ sparse_vals = df[self.processor.sparse_cols].values.astype(np.float32)
311
+ sparse_vals = (sparse_vals - sparse_vals.mean(axis=0)) / (sparse_vals.std(axis=0) + 1e-8)
312
+ for i, col in enumerate(self.processor.sparse_cols):
313
+ df[col] = sparse_vals[:, i]
314
+
315
+ feature_cols = self.processor.dense_cols + self.processor.sparse_cols
316
+ X = df[feature_cols].values.astype(np.float32)
317
+
318
+ with torch.no_grad():
319
+ X_tensor = torch.tensor(X).to(self.device)
320
+ return self.model(X_tensor).cpu().numpy()
321
+
322
+ def predict_single(self, features_dict):
323
+ """Predict p(click) for a single impression."""
324
+ df = pd.DataFrame([features_dict])
325
+ return self.predict(df)[0]
326
+
327
+
328
+ if __name__ == '__main__':
329
+ import argparse
330
+ parser = argparse.ArgumentParser()
331
+ parser.add_argument('--max_rows', type=int, default=100000, help='Max rows to load')
332
+ parser.add_argument('--epochs', type=int, default=5, help='Training epochs')
333
+ parser.add_argument('--batch_size', type=int, default=4096)
334
+ parser.add_argument('--lr', type=float, default=1e-3)
335
+ parser.add_argument('--save_path', type=str, default='/app/models/finalmlp_ctr.pt')
336
+ parser.add_argument('--device', type=str, default='cuda')
337
+ args = parser.parse_args()
338
+
339
+ processor = CTRDataProcessor(max_rows=args.max_rows)
340
+ train_data, val_data, test_data = processor.load_and_process()
341
+
342
+ model, history = train_finalmlp(
343
+ train_data, val_data, test_data,
344
+ epochs=args.epochs,
345
+ batch_size=args.batch_size,
346
+ learning_rate=args.lr,
347
+ save_path=args.save_path,
348
+ device=args.device
349
+ )
350
+
351
+ print(f"\nFinal Test AUC: {history['test_auc']:.4f}")
352
+ print(f"Model saved to {args.save_path}")