hamverbot commited on
Commit
431ef2b
·
verified ·
1 Parent(s): 1f4c5b5

Upload src/price/torchsurv_model.py

Browse files
Files changed (1) hide show
  1. src/price/torchsurv_model.py +266 -0
src/price/torchsurv_model.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TorchSurv-based Clearing Price Distribution Model
3
+ Uses deep survival analysis for censored market price prediction.
4
+
5
+ Right-censored data problem in first-price auctions:
6
+ - When you WIN: you observe the exact clearing price (your bid) → uncensored
7
+ - When you LOSE: you only know clearing price > your bid → right-censored
8
+
9
+ This maps exactly to survival analysis:
10
+ - "Event" = winning (price observed)
11
+ - "Time" = market price
12
+ - "Censoring" = losing (only lower bound)
13
+
14
+ Library: TorchSurv (Novartis, arXiv:2404.10761)
15
+ Install: pip install torchsurv
16
+ """
17
+ import torch
18
+ import torch.nn as nn
19
+ import numpy as np
20
+ from torch.utils.data import DataLoader, TensorDataset
21
+
22
+
23
class MarketPriceModel(nn.Module):
    """
    Feed-forward network scoring an impression for the Cox PH model.

    The single output is interpreted as a log-hazard; the implied survival
    function S(b|x) = P(market_price > b | features) yields the win
    probability as 1 - S(b|x).
    """

    def __init__(self, input_dim, hidden_dims=(256, 128, 64), dropout=0.2):
        super().__init__()
        stack = []
        widths = (input_dim,) + tuple(hidden_dims)
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
        # Final projection down to a single log-hazard value.
        stack.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        # Drop the trailing singleton dimension: (n, 1) -> (n,).
        return self.net(x).squeeze(-1)
51
+
52
+
53
class WinProbabilityModel(nn.Module):
    """
    Binary classifier estimating P(win | bid_price, features).

    Lighter-weight alternative to the full survival model when only the win
    probability (not the whole price distribution) is required.
    """

    def __init__(self, input_dim, hidden_dims=(256, 128, 64), dropout=0.2):
        super().__init__()
        blocks = []
        prev = input_dim + 1  # the bid price is appended to each feature row
        for width in hidden_dims:
            blocks.append(nn.Linear(prev, width))
            blocks.append(nn.ReLU())
            blocks.append(nn.Dropout(dropout))
            prev = width
        blocks.append(nn.Linear(prev, 1))
        blocks.append(nn.Sigmoid())
        self.net = nn.Sequential(*blocks)

    def forward(self, features, bid_price):
        # Concatenate the scalar bid onto each feature row, then classify.
        joined = torch.cat([features, bid_price.unsqueeze(-1)], dim=-1)
        return self.net(joined).squeeze(-1)
79
+
80
+
81
class CensoredPriceDataProcessor:
    """
    Prepare censored first-price auction data for survival-model training.

    Mapping to survival analysis:
    - won=1: event occurred (clearing price observed) -> uncensored
    - won=0: lost (only know market_price > bid)      -> right-censored

    For the Cox PH model:
    - event: 1 if won (uncensored), 0 if lost (censored)
    - time: the observed price for wins (clearing price when available,
      otherwise the bid), and the bid (a lower bound) for losses
    """

    def __init__(self):
        pass

    @staticmethod
    def prepare_from_auction_log(features, bids, won, prices=None):
        """
        Convert raw auction-log arrays into (features, time, event) tensors.

        Args:
            features: (n, d) impression features
            bids: (n,) bid prices submitted
            won: (n,) boolean, True if won
            prices: (n,) observed clearing prices, or None to fall back to
                bids as a proxy for wins
        Returns:
            (features_tensor, time_tensor, event_tensor), all float32
        """
        features = np.asarray(features, dtype=np.float32)
        bids = np.asarray(bids, dtype=np.float32)
        won = np.asarray(won, dtype=np.float32)

        # "Time" is the observed value: the bid is always a valid lower
        # bound; for wins we prefer the actual clearing price if supplied.
        # (Bug fix: `prices` was previously accepted but never used.)
        time = bids.copy()
        if prices is not None:
            prices = np.asarray(prices, dtype=np.float32)
            won_mask = won > 0
            time[won_mask] = prices[won_mask]

        # event: 1 if won (we observed the clearing price), 0 if censored.
        event = won.copy()

        return torch.tensor(features), torch.tensor(time), torch.tensor(event)

    @staticmethod
    def create_dataloader(features, time, event, batch_size=256, shuffle=True):
        """Bundle prepared tensors into a (by default shuffled) DataLoader."""
        ds = TensorDataset(features, time, event)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
123
+
124
+
125
def train_market_price_model(
    model, train_loader, val_loader=None,
    epochs=20, lr=1e-3, device='cuda',
    save_path='/app/models/market_price_model.pt'
):
    """
    Train the market price model with the Cox PH loss (negative partial
    log-likelihood) on censored auction data.

    Args:
        model: network producing one log-hazard per sample (e.g. MarketPriceModel)
        train_loader: yields (features, time, event) batches
        val_loader: currently unused; kept for interface compatibility
        epochs: number of passes over train_loader
        lr: Adam learning rate
        device: target device; falls back to CPU if CUDA is unavailable
        save_path: where the best (lowest training loss) weights are written
    Returns:
        The model with the best-epoch weights loaded.
    """
    import os

    try:
        from torchsurv.loss import cox
    except ImportError:
        print("torchsurv not installed. Using BCE-based fallback.")
        return train_win_prob_fallback(model, train_loader, val_loader, epochs, lr, device, save_path)

    # Robustness fix: a hard 'cuda' default crashed on CPU-only hosts.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    # Robustness fix: make sure the checkpoint directory exists before saving.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch_features, batch_time, batch_event in train_loader:
            batch_features = batch_features.to(device)
            batch_time = batch_time.to(device)
            batch_event = batch_event.to(device)

            optimizer.zero_grad()
            log_hazard = model(batch_features)

            # Cox PH negative partial log-likelihood
            loss = cox.neg_partial_log_likelihood(
                log_hazard,
                event=batch_event,
                time=batch_time
            )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")

        # Checkpoint whenever the training loss improves.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), save_path)

    # Reload the best checkpoint; map_location guards against device mismatch.
    model.load_state_dict(torch.load(save_path, map_location=device))
    return model
177
+
178
+
179
def train_win_prob_fallback(model, train_loader, val_loader, epochs, lr, device, save_path):
    """
    Fallback trainer used when TorchSurv is unavailable: fit the model's
    backbone as a plain win/loss binary classifier.

    Bug fixes vs. the previous version:
    - BCEWithLogitsLoss was applied on top of a Sigmoid-wrapped model
      (a double sigmoid, distorting gradients); the loss is now computed
      on raw logits and the sigmoid applies internally.
    - the event labels were never moved to `device`, crashing on GPU.

    Returns:
        nn.Sequential(backbone, Sigmoid), so callers still get probabilities.
    """
    # Train on raw logits; BCEWithLogitsLoss applies the sigmoid itself.
    criterion = nn.BCEWithLogitsLoss()
    backbone = model.net.to(device)
    optimizer = torch.optim.Adam(backbone.parameters(), lr=lr)

    for epoch in range(epochs):
        backbone.train()
        total_loss = 0.0
        for batch_features, batch_time, batch_event in train_loader:
            batch_features = batch_features.to(device)
            batch_event = batch_event.to(device)
            optimizer.zero_grad()
            logits = backbone(batch_features).squeeze(-1)
            loss = criterion(logits, batch_event)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs} | BCE Loss: {total_loss/len(train_loader):.4f}")

    # Expose probabilities at inference time, matching the old interface
    # (state-dict keys are identical to the previous Sequential layout).
    model_win = nn.Sequential(backbone, nn.Sigmoid())
    torch.save(model_win.state_dict(), save_path)
    return model_win
200
+
201
+
202
class MarketPricePredictor:
    """
    Inference wrapper: win probability and bid optimisation using a trained
    log-hazard model.
    """

    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.model.eval()

    def predict_win_probability(self, features, bid_prices):
        """
        Estimate P(win | features) via the rough Cox survival shortcut
        S ≈ exp(-exp(log_hazard)), so win_prob = 1 - S.

        NOTE(review): `bid_prices` is currently not used — this shortcut
        ignores the bid level entirely (a full Breslow baseline-hazard
        estimate would be needed to condition on the bid). Kept in the
        signature for interface stability; confirm before relying on
        bid-sensitive outputs.

        Args:
            features: (n, d) or (d,) feature tensor/array
            bid_prices: (n,) or scalar bid price(s) (ignored, see note)
        Returns:
            win_prob: (n,) array, or a plain float for 0-d output
        """
        x = torch.as_tensor(features, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            log_hazard = self.model(x)
            # S(t) = exp(-H(t)); here H is approximated by exp(log_hazard).
            win_prob = 1.0 - torch.exp(-torch.exp(log_hazard))

        out = win_prob.cpu().numpy()
        if out.ndim == 0:
            return float(out.item())
        return out.squeeze()

    def find_optimal_bid(self, features, v, lambd, bid_range=None, n_candidates=50):
        """
        Grid-search the bid maximising
        (v - b) * P(win|b,x) - λ * b * P(win|b,x).

        Args:
            features: (d,) feature vector for this impression
            v: value of winning (pCTR × value_per_click)
            lambd: dual multiplier on spend
            bid_range: (low, high) search interval; defaults to (0.1, 2v)
            n_candidates: grid resolution
        Returns:
            The best-scoring bid from the candidate grid.
        """
        lo, hi = bid_range if bid_range is not None else (0.1, v * 2.0)

        grid = np.linspace(lo, hi, n_candidates)
        repeated = np.tile(features, (n_candidates, 1))

        p_win = self.predict_win_probability(repeated, grid)

        # Expected surplus minus the λ-weighted expected spend.
        utility = (v - grid) * p_win - lambd * grid * p_win
        return grid[int(np.argmax(utility))]
261
+
262
+
263
if __name__ == '__main__':
    # Usage hints shown when the module is executed directly.
    banner = [
        "Market Price Model module loaded.",
        "Use train_market_price_model() with censored auction data.",
        "Or use EmpiricalCDF for the simpler non-parametric baseline.",
    ]
    for line in banner:
        print(line)