Stephanwu committed · verified
Commit de73d07 · 1 Parent(s): 08ad82e

Add deep learning models: DIN, TabularBERT, Transformer, FocalLoss

Files changed (1)
  1. models.py +643 -0
models.py ADDED
@@ -0,0 +1,643 @@
+ """
+ Deep learning model definitions for the insurance app
+ - InsuranceProductDIN: insurance product recommendation (Deep Interest Network)
+ - TabularBERT: anomalous-behavior detection (hierarchical Transformer)
+ - ChurnPredictionTransformer: churn / renewal prediction (Transformer)
+ - FocalLoss: loss function for imbalanced data
+
+ References:
+ - DIN: Deep Interest Network (KDD 2018, arXiv:1706.06978)
+ - TabBERT: Tabular Transformers (arXiv:2011.01843)
+ - Focal Loss: RetinaNet (ICCV 2017, arXiv:1708.02002)
+ """
+
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+
+ # =============================================================================
+ # 1. Insurance product recommendation: DIN (Deep Interest Network)
+ # =============================================================================
+
+ class LocalActivationUnit(nn.Module):
+     """
+     DIN core: local activation unit.
+     Attention-weights the user's behavior-history sequence, with weights
+     determined dynamically by the candidate product.
+
+     Input:  [candidate_emb, behavior_emb, candidate - behavior, candidate * behavior]
+     Output: attention-weighted user interest vector
+     """
+     def __init__(self, embedding_dim: int, hidden_dims: list = None):
+         super().__init__()
+         hidden_dims = hidden_dims if hidden_dims is not None else [128, 64]
+         layers = []
+         input_dim = embedding_dim * 4
+         for dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(input_dim, dim),
+                 nn.ReLU(),
+                 nn.Dropout(0.2),
+             ])
+             input_dim = dim
+         layers.append(nn.Linear(input_dim, 1))
+         self.attention = nn.Sequential(*layers)
+
+     def forward(self, candidate_emb, behavior_embs, mask=None):
+         """
+         Args:
+             candidate_emb: (B, D) candidate product embedding
+             behavior_embs: (B, L, D) user behavior-history embeddings
+             mask: (B, L) valid-behavior mask (True = valid)
+         Returns:
+             interest_vector: (B, D) attention-weighted interest vector
+         """
+         B, L, D = behavior_embs.shape
+
+         # Broadcast the candidate across the history length
+         candidate_expanded = candidate_emb.unsqueeze(1).expand(B, L, D)
+
+         # 4-way interaction features: [c, b, c - b, c * b]
+         diff = candidate_expanded - behavior_embs
+         prod = candidate_expanded * behavior_embs
+         attention_input = torch.cat([candidate_expanded, behavior_embs, diff, prod], dim=-1)
+
+         # Attention scores
+         attention_weights = self.attention(attention_input).squeeze(-1)  # (B, L)
+
+         # Mask out padded positions
+         if mask is not None:
+             attention_weights = attention_weights.masked_fill(~mask, -1e9)
+
+         attention_weights = F.softmax(attention_weights, dim=1)  # (B, L)
+
+         # Weighted sum
+         interest_vector = (behavior_embs * attention_weights.unsqueeze(-1)).sum(dim=1)  # (B, D)
+
+         return interest_vector
+
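+ # A minimal sketch exercising LocalActivationUnit on dummy embeddings; the
+ # shapes and values below are illustrative assumptions, not used elsewhere:
+ #
+ #     lau = LocalActivationUnit(embedding_dim=64)
+ #     cand = torch.randn(2, 64)                    # (B, D)
+ #     hist = torch.randn(2, 10, 64)                # (B, L, D)
+ #     mask = torch.ones(2, 10, dtype=torch.bool)   # (B, L)
+ #     interest = lau(cand, hist, mask)             # -> (2, 64)
+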
+ class InsuranceProductDIN(nn.Module):
+     """
+     DIN model for insurance product recommendation.
+
+     Architecture: embeddings + local activation attention + MLP.
+     Use case: recommend insurance products from a user's behavior sequence
+     and predict purchase probability.
+     """
+     def __init__(
+         self,
+         num_users: int = 10000,
+         num_products: int = 100,
+         num_event_types: int = 40,
+         num_user_features: int = 20,
+         embedding_dim: int = 64,
+         mlp_dims: list = None,
+         max_seq_len: int = 50,
+         dropout: float = 0.3,
+     ):
+         super().__init__()
+
+         mlp_dims = mlp_dims if mlp_dims is not None else [512, 256, 128]
+         self.embedding_dim = embedding_dim
+         self.max_seq_len = max_seq_len
+
+         # Embedding layers
+         self.user_embedding = nn.Embedding(num_users, embedding_dim)
+         self.product_embedding = nn.Embedding(num_products, embedding_dim)
+         self.event_embedding = nn.Embedding(num_event_types, embedding_dim // 2)
+
+         # Projection of user statistical features
+         self.user_feature_proj = nn.Linear(num_user_features, embedding_dim)
+
+         # Local activation unit (the core of DIN)
+         self.attention = LocalActivationUnit(embedding_dim)
+
+         # MLP prediction head
+         input_dim = embedding_dim * 4 + num_user_features
+         layers = []
+         for dim in mlp_dims:
+             layers.extend([
+                 nn.Linear(input_dim, dim),
+                 nn.ReLU(),
+                 nn.Dropout(dropout),
+                 nn.BatchNorm1d(dim),
+             ])
+             input_dim = dim
+         layers.append(nn.Linear(input_dim, 1))
+         self.mlp = nn.Sequential(*layers)
+
+     def forward(self, user_ids, user_features, behavior_events, behavior_products, behavior_mask, candidate_product):
+         """
+         Args:
+             user_ids: (B,) user IDs
+             user_features: (B, num_user_features) user statistical features
+             behavior_events: (B, L) historical event-type IDs
+             behavior_products: (B, L) historical product IDs
+             behavior_mask: (B, L) valid-history mask (True = valid)
+             candidate_product: (B,) candidate product IDs
+         Returns:
+             logits: (B,) purchase logits (apply sigmoid for probabilities)
+         """
+         # User embedding
+         user_emb = self.user_embedding(user_ids)  # (B, D)
+         user_feat = self.user_feature_proj(user_features)  # (B, D)
+         user_repr = user_emb + user_feat  # (B, D)
+
+         # Behavior-history embedding: event_emb + product_emb
+         beh_event_emb = self.event_embedding(behavior_events)  # (B, L, D/2)
+         beh_prod_emb = self.product_embedding(behavior_products)  # (B, L, D)
+         # Zero-pad the event embedding up to embedding_dim
+         beh_event_pad = F.pad(beh_event_emb, (0, self.embedding_dim - beh_event_emb.size(-1)))
+         behavior_emb = beh_event_pad + beh_prod_emb  # (B, L, D)
+
+         # Candidate product embedding
+         candidate_emb = self.product_embedding(candidate_product)  # (B, D)
+
+         # Attention-pooled interest vector
+         interest = self.attention(candidate_emb, behavior_emb, behavior_mask)  # (B, D)
+
+         # Interaction feature
+         user_item_prod = user_repr * candidate_emb  # (B, D)
+
+         # Concatenate all features
+         combined = torch.cat([
+             user_repr,       # user profile
+             interest,        # dynamic interest
+             candidate_emb,   # candidate product
+             user_item_prod,  # interaction
+             user_features,   # raw statistical features
+         ], dim=-1)
+
+         # MLP prediction
+         logits = self.mlp(combined).squeeze(-1)  # (B,)
+         return logits
+
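+ # A minimal smoke-test sketch for InsuranceProductDIN with its default sizes;
+ # the batch size, sequence length, and random inputs are illustrative assumptions:
+ #
+ #     model = InsuranceProductDIN()
+ #     B, L = 4, 50
+ #     logits = model(
+ #         torch.randint(0, 10000, (B,)),        # user_ids
+ #         torch.randn(B, 20),                   # user_features
+ #         torch.randint(0, 40, (B, L)),         # behavior_events
+ #         torch.randint(0, 100, (B, L)),        # behavior_products
+ #         torch.ones(B, L, dtype=torch.bool),   # behavior_mask
+ #         torch.randint(0, 100, (B,)),          # candidate_product
+ #     )                                         # -> (B,) purchase logits
+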
+ # =============================================================================
+ # 2. Anomalous-behavior detection: TabularBERT
+ # =============================================================================
+
+ class PositionalEncoding(nn.Module):
+     """Sinusoidal positional encoding for Transformers"""
+     def __init__(self, d_model: int, max_len: int = 5000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe.unsqueeze(0))
+
+     def forward(self, x):
+         return x + self.pe[:, :x.size(1), :]
+
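+ # A minimal sketch of PositionalEncoding (the shapes are illustrative assumptions):
+ #
+ #     pos = PositionalEncoding(d_model=128, max_len=100)
+ #     x = torch.zeros(2, 10, 128)   # (B, L, D)
+ #     y = pos(x)                    # same shape, sin/cos positions added element-wise
+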
+ class TabularBERT(nn.Module):
+     """
+     Hierarchical BERT for insurance claim / transaction anomaly detection.
+
+     Level 1: field Transformer (field interactions within a single record)
+     Level 2: sequence Transformer (temporal interactions across records)
+
+     Use cases: claim fraud detection, anomalous transaction identification
+     """
+     def __init__(
+         self,
+         num_fields: int = 15,
+         field_vocab_sizes: list = None,
+         d_model: int = 128,
+         nhead: int = 8,
+         num_field_layers: int = 2,
+         num_seq_layers: int = 4,
+         dim_feedforward: int = 512,
+         dropout: float = 0.2,
+         max_seq_len: int = 100,
+     ):
+         super().__init__()
+
+         self.num_fields = num_fields
+         self.d_model = d_model
+
+         # Per-field embeddings
+         if field_vocab_sizes is None:
+             field_vocab_sizes = [1000] * num_fields
+
+         self.field_embeddings = nn.ModuleList([
+             nn.Embedding(vocab_size, d_model) for vocab_size in field_vocab_sizes
+         ])
+
+         # Field-type embedding
+         self.field_type_embedding = nn.Embedding(num_fields, d_model)
+
+         # Level 1: field Transformer (intra-record)
+         field_encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
+             dropout=dropout, batch_first=True
+         )
+         self.field_transformer = nn.TransformerEncoder(field_encoder_layer, num_field_layers)
+
+         # Level 2: sequence Transformer (inter-record)
+         seq_encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
+             dropout=dropout, batch_first=True
+         )
+         self.seq_transformer = nn.TransformerEncoder(seq_encoder_layer, num_seq_layers)
+
+         # Positional encoding
+         self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
+
+         # Anomaly-detection head
+         self.anomaly_head = nn.Sequential(
+             nn.Linear(d_model, 256),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(256, 64),
+             nn.ReLU(),
+             nn.Linear(64, 1),
+         )
+
+         # MLM pre-training heads
+         self.mlm_heads = nn.ModuleList([
+             nn.Linear(d_model, vocab_size) for vocab_size in field_vocab_sizes
+         ])
+
+     def forward(self, field_ids, mask=None, return_mlm=False):
+         """
+         Args:
+             field_ids: (B, seq_len, num_fields) token ID of each field
+             mask: (B, seq_len) boolean sequence mask (True = valid)
+             return_mlm: whether to also return MLM predictions
+         Returns:
+             anomaly_score: (B,) anomaly score (pre-sigmoid)
+             mlm_logits: optional, for pre-training
+         """
+         # n_fields rather than F, to avoid shadowing torch.nn.functional
+         B, L, n_fields = field_ids.shape
+         assert n_fields == self.num_fields
+
+         # Field embeddings: independent embedding per field + field-type embedding
+         field_embs = []
+         for i in range(n_fields):
+             emb = self.field_embeddings[i](field_ids[:, :, i])  # (B, L, D)
+             type_emb = self.field_type_embedding(torch.tensor(i, device=field_ids.device))
+             emb = emb + type_emb.unsqueeze(0).unsqueeze(0)
+             field_embs.append(emb)
+
+         # Merge: (B, L, n_fields, D) -> (B*L, n_fields, D)
+         x = torch.stack(field_embs, dim=2)  # (B, L, n_fields, D)
+         x = x.view(B * L, n_fields, self.d_model)
+
+         # Field-level attention
+         x = self.field_transformer(x)  # (B*L, n_fields, D)
+
+         # Pool to a record-level representation
+         record_emb = x.mean(dim=1)  # (B*L, D)
+         record_emb = record_emb.view(B, L, self.d_model)
+
+         # Positional encoding + sequence-level attention
+         record_emb = self.pos_encoding(record_emb)
+         if mask is not None:
+             x = self.seq_transformer(record_emb, src_key_padding_mask=~mask)
+         else:
+             x = self.seq_transformer(record_emb)
+
+         # Masked global pooling
+         if mask is not None:
+             mask_float = mask.float().unsqueeze(-1)  # (B, L, 1)
+             seq_emb = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1).clamp(min=1)
+         else:
+             seq_emb = x.mean(dim=1)
+
+         # Anomaly score
+         anomaly_score = self.anomaly_head(seq_emb).squeeze(-1)  # (B,)
+
+         if return_mlm:
+             mlm_logits = []
+             record_emb_flat = record_emb.view(B * L, self.d_model)
+             for head in self.mlm_heads:
+                 mlm_logits.append(head(record_emb_flat))
+             return anomaly_score, mlm_logits
+
+         return anomaly_score
+
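+ # A minimal smoke-test sketch for TabularBERT with its default sizes; the
+ # shapes and random token IDs below are illustrative assumptions:
+ #
+ #     model = TabularBERT(num_fields=15)
+ #     field_ids = torch.randint(0, 1000, (4, 20, 15))   # (B, seq_len, num_fields)
+ #     mask = torch.ones(4, 20, dtype=torch.bool)
+ #     score = model(field_ids, mask)                    # -> (4,) pre-sigmoid scores
+ #     score, mlm_logits = model(field_ids, mask, return_mlm=True)
+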
+ # =============================================================================
+ # 3. User churn prediction: Transformer
+ # =============================================================================
+
+ class ChurnPredictionTransformer(nn.Module):
+     """
+     Transformer-based user churn / renewal prediction.
+
+     Reference: Early Churn Prediction from Large Scale User-Product
+     Interaction Time Series (arXiv:2309.14390)
+
+     Input:  embedding sequence of the user's most recent N behaviors
+     Output: churn logit (apply sigmoid for probability)
+     """
+     def __init__(
+         self,
+         num_event_types: int = 40,
+         num_products: int = 100,
+         d_model: int = 128,
+         nhead: int = 8,
+         num_layers: int = 6,
+         dim_feedforward: int = 512,
+         dropout: float = 0.3,
+         max_seq_len: int = 100,
+         num_continuous_features: int = 20,
+     ):
+         super().__init__()
+
+         # Embedding layers
+         self.event_embedding = nn.Embedding(num_event_types, d_model // 2)
+         self.product_embedding = nn.Embedding(num_products, d_model // 2)
+
+         # Continuous-feature projection
+         self.continuous_proj = nn.Linear(num_continuous_features, d_model)
+
+         # Time-interval encoding (log transform)
+         self.time_proj = nn.Linear(1, d_model // 4)
+
+         # Feature fusion: item (d_model) + continuous (d_model) + time (d_model/4)
+         self.fusion = nn.Linear(d_model * 2 + d_model // 4, d_model)
+
+         # Transformer
+         self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
+             dropout=dropout, batch_first=True
+         )
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
+
+         # Classification head
+         self.classifier = nn.Sequential(
+             nn.Linear(d_model, 256),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(256, 64),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(64, 1),
+         )
+
+     def forward(self, event_ids, product_ids, continuous_features, time_intervals, mask=None):
+         """
+         Args:
+             event_ids: (B, L)
+             product_ids: (B, L)
+             continuous_features: (B, L, num_continuous)
+             time_intervals: (B, L) inter-event gaps in seconds
+             mask: (B, L) boolean padding mask (True = valid)
+         """
+         B, L = event_ids.shape
+
+         # Embeddings
+         e_emb = self.event_embedding(event_ids)  # (B, L, D/2)
+         p_emb = self.product_embedding(product_ids)  # (B, L, D/2)
+         item_emb = torch.cat([e_emb, p_emb], dim=-1)  # (B, L, D)
+
+         # Continuous features
+         c_emb = self.continuous_proj(continuous_features)  # (B, L, D)
+
+         # Time intervals
+         time_log = torch.log1p(time_intervals.unsqueeze(-1).clamp(min=0))
+         t_emb = self.time_proj(time_log)  # (B, L, D/4)
+
+         # Fuse
+         fused = torch.cat([item_emb, c_emb, t_emb], dim=-1)
+         x = self.fusion(fused)  # (B, L, D)
+
+         # Positional encoding + Transformer
+         x = self.pos_encoding(x)
+         if mask is not None:
+             x = self.transformer(x, src_key_padding_mask=~mask)
+         else:
+             x = self.transformer(x)
+
+         # Masked mean pooling
+         if mask is not None:
+             mask_float = mask.float().unsqueeze(-1)
+             x = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1).clamp(min=1)
+         else:
+             x = x.mean(dim=1)
+
+         logits = self.classifier(x).squeeze(-1)
+         return logits
+
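+ # A minimal smoke-test sketch for ChurnPredictionTransformer with its default
+ # sizes; the shapes and random inputs below are illustrative assumptions:
+ #
+ #     model = ChurnPredictionTransformer()
+ #     B, L = 4, 50
+ #     logits = model(
+ #         torch.randint(0, 40, (B, L)),         # event_ids
+ #         torch.randint(0, 100, (B, L)),        # product_ids
+ #         torch.randn(B, L, 20),                # continuous_features
+ #         torch.rand(B, L) * 3600,              # time_intervals (seconds)
+ #         torch.ones(B, L, dtype=torch.bool),   # mask
+ #     )                                         # -> (B,) churn logits
+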
+ # =============================================================================
+ # 4. Loss function: Focal Loss (imbalanced data)
+ # =============================================================================
+
+ class FocalLoss(nn.Module):
+     """
+     Focal Loss for imbalanced binary classification.
+
+     Down-weights easy examples so training focuses on hard ones.
+     Suited to: insurance fraud detection (fraud < 1%), churn prediction (churn < 5%)
+     """
+     def __init__(self, alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'mean'):
+         super().__init__()
+         self.alpha = alpha
+         self.gamma = gamma
+         self.reduction = reduction
+
+     def forward(self, inputs, targets):
+         """
+         Args:
+             inputs: (B,) raw logits
+             targets: (B,) 0/1 labels
+         """
+         bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
+         pt = torch.exp(-bce)  # probability assigned to the true class
+         # Class-balanced weighting per the RetinaNet paper:
+         # alpha for positives, (1 - alpha) for negatives
+         alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
+         focal_weight = alpha_t * (1 - pt) ** self.gamma
+         loss = focal_weight * bce
+
+         if self.reduction == 'mean':
+             return loss.mean()
+         elif self.reduction == 'sum':
+             return loss.sum()
+         else:
+             return loss
+
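+ # A minimal usage sketch for FocalLoss (the tensors below are illustrative):
+ #
+ #     criterion = FocalLoss(alpha=0.25, gamma=2.0)
+ #     logits = torch.randn(8)
+ #     targets = torch.randint(0, 2, (8,)).float()
+ #     loss = criterion(logits, targets)
+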
+ # =============================================================================
+ # 5. Dataset definitions
+ # =============================================================================
+
+ class BehaviorSequenceDataset(Dataset):
+     """Behavior-sequence dataset (for the Transformer models)"""
+     def __init__(self, features, event_sequences, product_sequences, labels, max_len=50):
+         self.features = np.array(features, dtype=np.float32)
+         self.event_seqs = event_sequences
+         self.product_seqs = product_sequences
+         self.labels = np.array(labels, dtype=np.float32)
+         self.max_len = max_len
+
+         # Build vocabularies (ID 0 is reserved for PAD)
+         all_events = set()
+         for seq in event_sequences:
+             all_events.update(seq)
+         self.event_vocab = {e: i + 1 for i, e in enumerate(sorted(all_events))}
+
+         all_products = set()
+         for seq in product_sequences:
+             all_products.update(p for p in seq if p)
+         self.product_vocab = {p: i + 1 for i, p in enumerate(sorted(all_products))}
+
+     def __len__(self):
+         return len(self.labels)
+
+     def pad_sequence(self, seq, vocab, max_len):
+         """Left-pad / truncate a sequence; returns (ids, valid_length)"""
+         ids = [vocab.get(x, 0) for x in seq[-max_len:]]
+         valid_len = len(ids)
+         if len(ids) < max_len:
+             ids = [0] * (max_len - len(ids)) + ids
+         return ids, valid_len
+
+     def __getitem__(self, idx):
+         e_ids, e_len = self.pad_sequence(self.event_seqs[idx], self.event_vocab, self.max_len)
+         p_ids, _ = self.pad_sequence(self.product_seqs[idx], self.product_vocab, self.max_len)
+
+         # True for valid (right-aligned) positions; sequences are left-padded
+         mask = [i >= self.max_len - e_len for i in range(self.max_len)]
+
+         return {
+             'features': torch.tensor(self.features[idx]),
+             'event_ids': torch.tensor(e_ids, dtype=torch.long),
+             'product_ids': torch.tensor(p_ids, dtype=torch.long),
+             'mask': torch.tensor(mask, dtype=torch.bool),
+             'label': torch.tensor(self.labels[idx]),
+             'time_intervals': torch.zeros(self.max_len),  # simplified: real gaps not tracked
+         }
+
+
+ class ProductInteractionDataset(Dataset):
+     """Product-interaction dataset (for the DIN model)"""
+     def __init__(self, user_ids, user_features, behavior_events, behavior_products,
+                  behavior_masks, candidate_products, labels, max_len=50):
+         self.user_ids = np.array(user_ids, dtype=np.longlong)
+         self.user_features = np.array(user_features, dtype=np.float32)
+         self.behavior_events = behavior_events
+         self.behavior_products = behavior_products
+         self.behavior_masks = behavior_masks
+         self.candidate_products = np.array(candidate_products, dtype=np.longlong)
+         self.labels = np.array(labels, dtype=np.float32)
+         self.max_len = max_len
+
+     def __len__(self):
+         return len(self.labels)
+
+     def pad_seq(self, seq, max_len):
+         """Left-pad / truncate; returns (ids, validity mask)"""
+         if len(seq) >= max_len:
+             return seq[-max_len:], [1] * max_len
+         else:
+             pad_len = max_len - len(seq)
+             return [0] * pad_len + seq, [0] * pad_len + [1] * len(seq)
+
+     def __getitem__(self, idx):
+         e_seq, e_mask = self.pad_seq(self.behavior_events[idx], self.max_len)
+         p_seq, _ = self.pad_seq(self.behavior_products[idx], self.max_len)
+
+         return {
+             'user_id': torch.tensor(self.user_ids[idx]),
+             'user_features': torch.tensor(self.user_features[idx]),
+             'behavior_events': torch.tensor(e_seq, dtype=torch.long),
+             'behavior_products': torch.tensor(p_seq, dtype=torch.long),
+             'behavior_mask': torch.tensor(e_mask, dtype=torch.bool),
+             'candidate_product': torch.tensor(self.candidate_products[idx]),
+             'label': torch.tensor(self.labels[idx]),
+         }
+
+
+ def build_vocab(values, offset=1):
+     """Build a vocabulary mapping values to integer IDs (0 reserved for PAD)"""
+     unique = sorted(set(v for sublist in values for v in sublist if v))
+     return {v: i + offset for i, v in enumerate(unique)}
+
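+ # A minimal sketch wiring ProductInteractionDataset into a DataLoader; the toy
+ # records below are illustrative assumptions:
+ #
+ #     ds = ProductInteractionDataset(
+ #         user_ids=[1, 2], user_features=[[0.0] * 20, [1.0] * 20],
+ #         behavior_events=[[1, 2, 3], [4, 5]],
+ #         behavior_products=[[1, 2, 3], [4, 5]],
+ #         behavior_masks=None, candidate_products=[7, 8], labels=[1.0, 0.0],
+ #     )
+ #     loader = DataLoader(ds, batch_size=2, shuffle=True)
+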
+ # =============================================================================
+ # 6. Training utilities
+ # =============================================================================
+
+ def train_epoch(model, dataloader, optimizer, criterion, device):
+     """Train for a single epoch; returns the mean batch loss"""
+     model.train()
+     total_loss = 0
+     for batch in dataloader:
+         optimizer.zero_grad()
+
+         # Dispatch inputs by model type
+         if hasattr(model, 'attention'):  # DIN
+             outputs = model(
+                 batch['user_id'].to(device),
+                 batch['user_features'].to(device),
+                 batch['behavior_events'].to(device),
+                 batch['behavior_products'].to(device),
+                 batch['behavior_mask'].to(device),
+                 batch['candidate_product'].to(device),
+             )
+         elif hasattr(model, 'transformer'):  # churn Transformer
+             outputs = model(
+                 batch['event_ids'].to(device),
+                 batch['product_ids'].to(device),
+                 batch['features'].unsqueeze(1).expand(-1, batch['event_ids'].size(1), -1).to(device),
+                 batch['time_intervals'].to(device),
+                 batch['mask'].to(device),
+             )
+         else:  # TabularBERT
+             # Simplified: random field_ids, for demonstration only
+             B = batch['features'].size(0)
+             field_ids = torch.randint(0, 100, (B, 10, 5)).to(device)
+             outputs = model(field_ids)
+
+         labels = batch['label'].to(device)
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+
+         total_loss += loss.item()
+
+     return total_loss / len(dataloader)
+
+
+ def evaluate_model(model, dataloader, device):
+     """Evaluate a model; returns AUC, average precision, and F1"""
+     from sklearn.metrics import roc_auc_score, f1_score, average_precision_score
+
+     model.eval()
+     all_preds = []
+     all_labels = []
+
+     with torch.no_grad():
+         for batch in dataloader:
+             if hasattr(model, 'attention'):
+                 outputs = model(
+                     batch['user_id'].to(device),
+                     batch['user_features'].to(device),
+                     batch['behavior_events'].to(device),
+                     batch['behavior_products'].to(device),
+                     batch['behavior_mask'].to(device),
+                     batch['candidate_product'].to(device),
+                 )
+             elif hasattr(model, 'transformer'):
+                 outputs = model(
+                     batch['event_ids'].to(device),
+                     batch['product_ids'].to(device),
+                     batch['features'].unsqueeze(1).expand(-1, batch['event_ids'].size(1), -1).to(device),
+                     batch['time_intervals'].to(device),
+                     batch['mask'].to(device),
+                 )
+             else:
+                 B = batch['features'].size(0)
+                 field_ids = torch.randint(0, 100, (B, 10, 5)).to(device)
+                 outputs = model(field_ids)
+
+             all_preds.extend(torch.sigmoid(outputs).cpu().numpy())
+             all_labels.extend(batch['label'].numpy())
+
+     preds = np.array(all_preds)
+     labels = np.array(all_labels)
+
+     auc = roc_auc_score(labels, preds)
+     ap = average_precision_score(labels, preds)
+     f1 = f1_score(labels, preds > 0.5)
+
+     return {'auc': auc, 'ap': ap, 'f1': f1, 'preds': preds, 'labels': labels}
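+
+
+ # A minimal end-to-end smoke test on synthetic data; every size and all the
+ # random data below are illustrative assumptions, not the production pipeline.
+ if __name__ == "__main__":
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n, hist_len = 64, 20
+     ds = ProductInteractionDataset(
+         user_ids=np.random.randint(0, 10000, n),
+         user_features=np.random.randn(n, 20),
+         behavior_events=[list(np.random.randint(1, 40, hist_len)) for _ in range(n)],
+         behavior_products=[list(np.random.randint(1, 100, hist_len)) for _ in range(n)],
+         behavior_masks=None,
+         candidate_products=np.random.randint(0, 100, n),
+         labels=np.random.randint(0, 2, n),
+     )
+     loader = DataLoader(ds, batch_size=16, shuffle=True)
+     model = InsuranceProductDIN().to(device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+     criterion = FocalLoss()
+     loss = train_epoch(model, loader, optimizer, criterion, device)
+     print(f"smoke-test epoch loss: {loss:.4f}")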