Add deep learning models: DIN, TabularBERT, Transformer, FocalLoss
models.py
ADDED
@@ -0,0 +1,643 @@
"""
Deep learning model definitions for the insurance app
- InsuranceProductDIN: insurance product recommendation (Deep Interest Network)
- TabularBERT: anomalous behavior detection (hierarchical Transformer)
- ChurnPredictionTransformer: user churn prediction (Transformer)
- FocalLoss: loss function tailored to imbalanced data

References:
- DIN: Deep Interest Network (KDD 2018, arxiv:1706.06978)
- TabBERT: Tabular Transformers (arxiv:2011.01843)
- Focal Loss: RetinaNet (ICCV 2017, arxiv:1708.02002)
"""

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# =============================================================================
# 1. Insurance product recommendation — DIN (Deep Interest Network)
# =============================================================================

class LocalActivationUnit(nn.Module):
    """
    Core of DIN: the local activation unit.
    Produces a weighted sum over the user's behavior history, with weights
    decided dynamically by the candidate product.

    Input: [candidate_emb, behavior_emb, candidate-behavior, candidate*behavior]
    Output: the attention-weighted user interest vector
    """
    def __init__(self, embedding_dim: int, hidden_dims: tuple = (128, 64)):
        # Tuple default avoids the mutable-default-argument pitfall
        super().__init__()
        layers = []
        input_dim = embedding_dim * 4
        for dim in hidden_dims:
            layers.extend([
                nn.Linear(input_dim, dim),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
            input_dim = dim
        layers.append(nn.Linear(input_dim, 1))
        self.attention = nn.Sequential(*layers)

    def forward(self, candidate_emb, behavior_embs, mask=None):
        """
        Args:
            candidate_emb: (B, D) candidate product embedding
            behavior_embs: (B, L, D) user behavior-history embeddings
            mask: (B, L) valid-behavior mask (True = valid)
        Returns:
            interest_vector: (B, D) attention-weighted interest vector
        """
        B, L, D = behavior_embs.shape

        # Broadcast the candidate embedding along the history length
        candidate_expanded = candidate_emb.unsqueeze(1).expand(B, L, D)

        # Four-way interaction features: [c, b, c-b, c*b]
        diff = candidate_expanded - behavior_embs
        prod = candidate_expanded * behavior_embs
        attention_input = torch.cat([candidate_expanded, behavior_embs, diff, prod], dim=-1)

        # Attention weights
        attention_weights = self.attention(attention_input).squeeze(-1)  # (B, L)

        # Mask out padded positions before the softmax
        if mask is not None:
            attention_weights = attention_weights.masked_fill(~mask, -1e9)

        attention_weights = F.softmax(attention_weights, dim=1)  # (B, L)

        # Weighted sum over the history
        interest_vector = (behavior_embs * attention_weights.unsqueeze(-1)).sum(dim=1)  # (B, D)

        return interest_vector
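
# Illustrative sketch (not part of the original commit): exercises the local
# activation unit on random tensors to show the expected shapes, with a boolean
# mask marking two trailing positions of the first sequence as padding.
# The batch/length/dimension sizes are assumptions.
def _local_activation_demo():
    torch.manual_seed(0)
    B, L, D = 2, 5, 64
    unit = LocalActivationUnit(embedding_dim=D).eval()
    candidate = torch.randn(B, D)
    history = torch.randn(B, L, D)
    mask = torch.tensor([[True, True, True, False, False],
                         [True, True, True, True, True]])
    interest = unit(candidate, history, mask)
    assert interest.shape == (B, D)
    return interest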

class InsuranceProductDIN(nn.Module):
    """
    DIN model for insurance product recommendation.

    Architecture: embeddings + local activation attention + MLP
    Use case: recommend insurance products from a user's behavior sequence
    and predict the purchase probability.
    """
    def __init__(
        self,
        num_users: int = 10000,
        num_products: int = 100,
        num_event_types: int = 40,
        num_user_features: int = 20,
        embedding_dim: int = 64,
        mlp_dims: tuple = (512, 256, 128),
        max_seq_len: int = 50,
        dropout: float = 0.3,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len

        # Embedding layers
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.product_embedding = nn.Embedding(num_products, embedding_dim)
        self.event_embedding = nn.Embedding(num_event_types, embedding_dim // 2)

        # Projection for the user's aggregate statistical features
        self.user_feature_proj = nn.Linear(num_user_features, embedding_dim)

        # Local activation unit (the core of DIN)
        self.attention = LocalActivationUnit(embedding_dim)

        # MLP prediction head
        input_dim = embedding_dim * 4 + num_user_features
        layers = []
        for dim in mlp_dims:
            layers.extend([
                nn.Linear(input_dim, dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.BatchNorm1d(dim),
            ])
            input_dim = dim
        layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, user_ids, user_features, behavior_events, behavior_products, behavior_mask, candidate_product):
        """
        Args:
            user_ids: (B,) user IDs
            user_features: (B, num_user_features) user statistical features
            behavior_events: (B, L) historical event-type IDs
            behavior_products: (B, L) historical product IDs
            behavior_mask: (B, L) valid-history mask
            candidate_product: (B,) candidate product ID
        Returns:
            logits: (B,) purchase logits (apply sigmoid for probabilities)
        """
        # User embedding
        user_emb = self.user_embedding(user_ids)           # (B, D)
        user_feat = self.user_feature_proj(user_features)  # (B, D)
        user_repr = user_emb + user_feat                   # (B, D)

        # Behavior-history embedding: event_emb + product_emb
        beh_event_emb = self.event_embedding(behavior_events)     # (B, L, D/2)
        beh_prod_emb = self.product_embedding(behavior_products)  # (B, L, D)
        # Zero-pad the event embedding up to D so the two can be summed
        beh_event_pad = F.pad(beh_event_emb, (0, self.embedding_dim - beh_event_emb.size(-1)))
        behavior_emb = beh_event_pad + beh_prod_emb               # (B, L, D)

        # Candidate product embedding
        candidate_emb = self.product_embedding(candidate_product)  # (B, D)

        # Attention-weighted interest vector
        interest = self.attention(candidate_emb, behavior_emb, behavior_mask)  # (B, D)

        # Interaction features
        user_item_prod = user_repr * candidate_emb  # (B, D)

        # Concatenate all features
        combined = torch.cat([
            user_repr,       # user profile
            interest,        # dynamic interest
            candidate_emb,   # candidate product
            user_item_prod,  # interaction
            user_features,   # raw statistical features
        ], dim=-1)

        # MLP prediction
        logits = self.mlp(combined).squeeze(-1)  # (B,)
        return logits
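
# Minimal smoke test (an illustrative sketch, not from the original commit):
# one forward pass of InsuranceProductDIN on random inputs using its default
# hyper-parameters. eval() is used so BatchNorm1d relies on running statistics.
def _din_demo():
    torch.manual_seed(0)
    B, L = 4, 50
    model = InsuranceProductDIN().eval()
    logits = model(
        torch.randint(0, 10000, (B,)),        # user_ids
        torch.randn(B, 20),                   # user_features
        torch.randint(0, 40, (B, L)),         # behavior_events
        torch.randint(0, 100, (B, L)),        # behavior_products
        torch.ones(B, L, dtype=torch.bool),   # behavior_mask (True = valid)
        torch.randint(0, 100, (B,)),          # candidate_product
    )
    print(logits.shape, torch.sigmoid(logits))  # torch.Size([4]) + probabilities
    return logits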

# =============================================================================
# 2. Anomalous behavior detection — TabularBERT
# =============================================================================

class PositionalEncoding(nn.Module):
    """Sinusoidal Transformer positional encoding."""
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]


class TabularBERT(nn.Module):
    """
    Hierarchical BERT for insurance claim/transaction anomaly detection.

    Level 1: field Transformer (field interactions within a single record)
    Level 2: sequence Transformer (temporal interactions across records)

    Use cases: claim-fraud detection, anomalous-transaction identification
    """
    def __init__(
        self,
        num_fields: int = 15,
        field_vocab_sizes: list = None,
        d_model: int = 128,
        nhead: int = 8,
        num_field_layers: int = 2,
        num_seq_layers: int = 4,
        dim_feedforward: int = 512,
        dropout: float = 0.2,
        max_seq_len: int = 100,
    ):
        super().__init__()

        self.num_fields = num_fields
        self.d_model = d_model

        # Field embeddings
        if field_vocab_sizes is None:
            field_vocab_sizes = [1000] * num_fields

        self.field_embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, d_model) for vocab_size in field_vocab_sizes
        ])

        # Field-type embedding
        self.field_type_embedding = nn.Embedding(num_fields, d_model)

        # Level 1: field Transformer (intra-record)
        field_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True
        )
        self.field_transformer = nn.TransformerEncoder(field_encoder_layer, num_field_layers)

        # Level 2: sequence Transformer (inter-record)
        seq_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True
        )
        self.seq_transformer = nn.TransformerEncoder(seq_encoder_layer, num_seq_layers)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

        # Anomaly-detection head
        self.anomaly_head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # MLM pre-training heads (one per field)
        self.mlm_heads = nn.ModuleList([
            nn.Linear(d_model, vocab_size) for vocab_size in field_vocab_sizes
        ])

    def forward(self, field_ids, mask=None, return_mlm=False):
        """
        Args:
            field_ids: (B, seq_len, num_fields) token IDs for each field
            mask: (B, seq_len) sequence mask (True = valid)
            return_mlm: whether to also return MLM predictions
        Returns:
            anomaly_score: (B,) anomaly score (pre-sigmoid)
            mlm_logits: optional, used for pre-training
        """
        # Named n_fields rather than F, which would shadow torch.nn.functional
        B, L, n_fields = field_ids.shape
        assert n_fields == self.num_fields

        # Field embeddings: per-field embedding + field-type embedding
        field_embs = []
        for i in range(n_fields):
            emb = self.field_embeddings[i](field_ids[:, :, i])  # (B, L, D)
            type_emb = self.field_type_embedding(torch.tensor(i, device=field_ids.device))
            emb = emb + type_emb.unsqueeze(0).unsqueeze(0)
            field_embs.append(emb)

        # Merge: (B, L, n_fields, D) → (B*L, n_fields, D)
        x = torch.stack(field_embs, dim=2)
        x = x.view(B * L, n_fields, self.d_model)

        # Field-level attention
        x = self.field_transformer(x)  # (B*L, n_fields, D)

        # Pool into record-level representations
        record_emb = x.mean(dim=1)  # (B*L, D)
        record_emb = record_emb.view(B, L, self.d_model)

        # Positional encoding + sequence-level attention
        record_emb = self.pos_encoding(record_emb)
        if mask is not None:
            x = self.seq_transformer(record_emb, src_key_padding_mask=~mask)
        else:
            x = self.seq_transformer(record_emb)

        # Global pooling
        if mask is not None:
            mask_float = mask.float().unsqueeze(-1)  # (B, L, 1)
            seq_emb = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1).clamp(min=1)
        else:
            seq_emb = x.mean(dim=1)

        # Anomaly score
        anomaly_score = self.anomaly_head(seq_emb).squeeze(-1)  # (B,)

        if return_mlm:
            mlm_logits = []
            record_emb_flat = record_emb.view(B * L, self.d_model)
            for head in self.mlm_heads:
                mlm_logits.append(head(record_emb_flat))
            return anomaly_score, mlm_logits

        return anomaly_score
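
# Illustrative sketch (assumed sizes, not part of the original commit): one
# forward pass of TabularBERT on random field IDs, plus the MLM path used for
# pre-training. Each MLM head predicts over its field's vocabulary per record.
def _tabular_bert_demo():
    torch.manual_seed(0)
    B, L, n_fields = 2, 8, 15
    model = TabularBERT(num_fields=n_fields).eval()
    field_ids = torch.randint(0, 1000, (B, L, n_fields))
    mask = torch.ones(B, L, dtype=torch.bool)
    score = model(field_ids, mask)  # (B,) anomaly logits
    score, mlm_logits = model(field_ids, mask, return_mlm=True)
    print(score.shape, len(mlm_logits), mlm_logits[0].shape)  # (2,) 15 (16, 1000)
    return score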

# =============================================================================
# 3. User churn prediction — Transformer
# =============================================================================

class ChurnPredictionTransformer(nn.Module):
    """
    Transformer-based user churn / policy-renewal prediction.

    Reference: Early Churn Prediction from Large Scale User-Product
    Interaction Time Series (arXiv 2309.14390)

    Input: embedding sequence of the user's most recent N actions
    Output: churn probability
    """
    def __init__(
        self,
        num_event_types: int = 40,
        num_products: int = 100,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        dropout: float = 0.3,
        max_seq_len: int = 100,
        num_continuous_features: int = 20,
    ):
        super().__init__()

        # Embedding layers
        self.event_embedding = nn.Embedding(num_event_types, d_model // 2)
        self.product_embedding = nn.Embedding(num_products, d_model // 2)

        # Projection for continuous features
        self.continuous_proj = nn.Linear(num_continuous_features, d_model)

        # Time-interval encoding (log-transformed)
        self.time_proj = nn.Linear(1, d_model // 4)

        # Feature fusion: item_emb (D) + continuous (D) + time (D/4)
        self.fusion = nn.Linear(d_model * 2 + d_model // 4, d_model)

        # Transformer
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    def forward(self, event_ids, product_ids, continuous_features, time_intervals, mask=None):
        """
        Args:
            event_ids: (B, L)
            product_ids: (B, L)
            continuous_features: (B, L, num_continuous)
            time_intervals: (B, L) inter-event gaps in seconds
            mask: (B, L) padding mask (True = valid)
        """
        B, L = event_ids.shape

        # Embeddings
        e_emb = self.event_embedding(event_ids)       # (B, L, D/2)
        p_emb = self.product_embedding(product_ids)   # (B, L, D/2)
        item_emb = torch.cat([e_emb, p_emb], dim=-1)  # (B, L, D)

        # Continuous features
        c_emb = self.continuous_proj(continuous_features)  # (B, L, D)

        # Time intervals
        time_log = torch.log1p(time_intervals.unsqueeze(-1).clamp(min=0))
        t_emb = self.time_proj(time_log)  # (B, L, D/4)

        # Fusion
        fused = torch.cat([item_emb, c_emb, t_emb], dim=-1)  # (B, L, 2D + D/4)
        x = self.fusion(fused)  # (B, L, D)

        # Positional encoding + Transformer
        x = self.pos_encoding(x)
        if mask is not None:
            x = self.transformer(x, src_key_padding_mask=~mask)
        else:
            x = self.transformer(x)

        # Global average pooling
        if mask is not None:
            mask_float = mask.float().unsqueeze(-1)
            x = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1).clamp(min=1)
        else:
            x = x.mean(dim=1)

        logits = self.classifier(x).squeeze(-1)
        return logits
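
# Illustrative sketch (assumed sizes, not in the original commit): a forward
# pass of ChurnPredictionTransformer on random inputs. Note the boolean mask:
# the model applies ~mask as the padding mask, so True must mean "valid".
def _churn_demo():
    torch.manual_seed(0)
    B, L = 2, 10
    model = ChurnPredictionTransformer().eval()
    logits = model(
        torch.randint(0, 40, (B, L)),        # event_ids
        torch.randint(0, 100, (B, L)),       # product_ids
        torch.randn(B, L, 20),               # continuous_features
        torch.rand(B, L) * 3600,             # time_intervals (seconds)
        torch.ones(B, L, dtype=torch.bool),  # mask (True = valid)
    )
    print(torch.sigmoid(logits))  # churn probabilities, shape (2,)
    return logits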

# =============================================================================
# 4. Loss function — Focal Loss (imbalanced data)
# =============================================================================

class FocalLoss(nn.Module):
    """
    Focal Loss for imbalanced classification.

    Down-weights easy examples so training focuses on hard ones.
    Suitable for insurance fraud detection (fraud < 1%) and churn
    prediction (churn < 5%).
    """
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (B,) raw logits
            targets: (B,) 0/1 labels
        """
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce)  # probability assigned to the true class
        # alpha_t as in the RetinaNet paper: alpha for positives, 1-alpha for negatives
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_weight = alpha_t * (1 - pt) ** self.gamma
        loss = focal_weight * bce

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
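
# Quick numeric check (illustrative, not in the original commit): compared to
# plain BCE, focal loss shrinks the contribution of confident correct
# predictions while hard, ambiguous examples keep most of their weight.
def _focal_loss_demo():
    logits = torch.tensor([4.0, -4.0, 0.1])  # easy positive, easy negative, hard positive
    targets = torch.tensor([1.0, 0.0, 1.0])
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    focal = FocalLoss(reduction='none')(logits, targets)
    print(bce)    # per-sample BCE
    print(focal)  # per-sample focal loss: the easy samples are heavily down-weighted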

# =============================================================================
# 5. Dataset definitions
# =============================================================================

class BehaviorSequenceDataset(Dataset):
    """Behavior-sequence dataset (for the Transformer models)."""
    def __init__(self, features, event_sequences, product_sequences, labels, max_len=50):
        self.features = np.array(features, dtype=np.float32)
        self.event_seqs = event_sequences
        self.product_seqs = product_sequences
        self.labels = np.array(labels, dtype=np.float32)
        self.max_len = max_len

        # Build vocabularies
        all_events = set()
        for seq in event_sequences:
            all_events.update(seq)
        self.event_vocab = {e: i + 1 for i, e in enumerate(sorted(all_events))}  # 0 = PAD

        all_products = set()
        for seq in product_sequences:
            all_products.update(p for p in seq if p)
        self.product_vocab = {p: i + 1 for i, p in enumerate(sorted(all_products))}

    def __len__(self):
        return len(self.labels)

    def pad_sequence(self, seq, vocab, max_len):
        """Left-pad (or truncate) a sequence to max_len."""
        ids = [vocab.get(x, 0) for x in seq[-max_len:]]
        if len(ids) < max_len:
            ids = [0] * (max_len - len(ids)) + ids
        return ids, len(seq[-max_len:]) if seq else 0

    def __getitem__(self, idx):
        e_ids, e_len = self.pad_sequence(self.event_seqs[idx], self.event_vocab, self.max_len)
        p_ids, p_len = self.pad_sequence(self.product_seqs[idx], self.product_vocab, self.max_len)

        # True for the last e_len (valid) positions; bool so ~mask works downstream
        mask = [i >= self.max_len - e_len for i in range(self.max_len)]

        return {
            'features': torch.tensor(self.features[idx]),
            'event_ids': torch.tensor(e_ids, dtype=torch.long),
            'product_ids': torch.tensor(p_ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.bool),
            'label': torch.tensor(self.labels[idx]),
            'time_intervals': torch.zeros(self.max_len),  # simplified version
        }


class ProductInteractionDataset(Dataset):
    """Product-interaction dataset (for the DIN model)."""
    def __init__(self, user_ids, user_features, behavior_events, behavior_products,
                 behavior_masks, candidate_products, labels, max_len=50):
        self.user_ids = np.array(user_ids, dtype=np.int64)
        self.user_features = np.array(user_features, dtype=np.float32)
        self.behavior_events = behavior_events
        self.behavior_products = behavior_products
        self.behavior_masks = behavior_masks
        self.candidate_products = np.array(candidate_products, dtype=np.int64)
        self.labels = np.array(labels, dtype=np.float32)
        self.max_len = max_len

    def __len__(self):
        return len(self.labels)

    def pad_seq(self, seq, max_len):
        if len(seq) >= max_len:
            return seq[-max_len:], [1] * max_len
        else:
            pad_len = max_len - len(seq)
            return [0] * pad_len + seq, [0] * pad_len + [1] * len(seq)

    def __getitem__(self, idx):
        e_seq, e_mask = self.pad_seq(self.behavior_events[idx], self.max_len)
        p_seq, p_mask = self.pad_seq(self.behavior_products[idx], self.max_len)

        return {
            'user_id': torch.tensor(self.user_ids[idx]),
            'user_features': torch.tensor(self.user_features[idx]),
            'behavior_events': torch.tensor(e_seq, dtype=torch.long),
            'behavior_products': torch.tensor(p_seq, dtype=torch.long),
            'behavior_mask': torch.tensor(e_mask, dtype=torch.bool),
            'candidate_product': torch.tensor(self.candidate_products[idx]),
            'label': torch.tensor(self.labels[idx]),
        }


def build_vocab(values, offset=1):
    """Build a vocabulary from nested value lists (0 is reserved for PAD)."""
    unique = sorted(set(v for sublist in values for v in sublist if v))
    return {v: i + offset for i, v in enumerate(unique)}
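
# Usage sketch (made-up toy data, not part of the original commit): wraps a
# few hand-written behavior sequences in BehaviorSequenceDataset and iterates
# with a DataLoader, the same way train_epoch below consumes batches.
def _dataset_demo():
    events = [['view', 'click', 'quote'], ['view'], ['click', 'buy']]
    products = [['life', 'auto', 'auto'], ['life'], ['auto', 'auto']]
    features = [[0.1] * 20, [0.2] * 20, [0.3] * 20]
    labels = [1, 0, 1]
    ds = BehaviorSequenceDataset(features, events, products, labels, max_len=5)
    loader = DataLoader(ds, batch_size=2, shuffle=False)
    batch = next(iter(loader))
    print(batch['event_ids'].shape, batch['mask'].dtype)  # (2, 5) torch.bool
    return batch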

# =============================================================================
# 6. Training utilities
# =============================================================================

def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train for a single epoch."""
    model.train()
    total_loss = 0
    for batch in dataloader:
        optimizer.zero_grad()

        # Dispatch inputs by model type
        if isinstance(model, InsuranceProductDIN):
            outputs = model(
                batch['user_id'].to(device),
                batch['user_features'].to(device),
                batch['behavior_events'].to(device),
                batch['behavior_products'].to(device),
                batch['behavior_mask'].to(device),
                batch['candidate_product'].to(device),
            )
        elif isinstance(model, ChurnPredictionTransformer):
            outputs = model(
                batch['event_ids'].to(device),
                batch['product_ids'].to(device),
                batch['features'].unsqueeze(1).expand(-1, batch['event_ids'].size(1), -1).to(device),
                batch['time_intervals'].to(device),
                batch['mask'].to(device),
            )
        else:  # TabularBERT
            # Simplified demo: random field_ids; the field count must match model.num_fields
            B = batch['features'].size(0)
            field_ids = torch.randint(0, 100, (B, 10, model.num_fields)).to(device)
            outputs = model(field_ids)

        labels = batch['label'].to(device)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


def evaluate_model(model, dataloader, device):
    """Evaluate a model and report AUC / AP / F1."""
    from sklearn.metrics import roc_auc_score, f1_score, average_precision_score

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            if isinstance(model, InsuranceProductDIN):
                outputs = model(
                    batch['user_id'].to(device),
                    batch['user_features'].to(device),
                    batch['behavior_events'].to(device),
                    batch['behavior_products'].to(device),
                    batch['behavior_mask'].to(device),
                    batch['candidate_product'].to(device),
                )
            elif isinstance(model, ChurnPredictionTransformer):
                outputs = model(
                    batch['event_ids'].to(device),
                    batch['product_ids'].to(device),
                    batch['features'].unsqueeze(1).expand(-1, batch['event_ids'].size(1), -1).to(device),
                    batch['time_intervals'].to(device),
                    batch['mask'].to(device),
                )
            else:  # TabularBERT (simplified demo, as in train_epoch)
                B = batch['features'].size(0)
                field_ids = torch.randint(0, 100, (B, 10, model.num_fields)).to(device)
                outputs = model(field_ids)

            all_preds.extend(torch.sigmoid(outputs).cpu().numpy())
            all_labels.extend(batch['label'].numpy())

    preds = np.array(all_preds)
    labels = np.array(all_labels)

    auc = roc_auc_score(labels, preds)
    ap = average_precision_score(labels, preds)
    f1 = f1_score(labels, preds > 0.5)

    return {'auc': auc, 'ap': ap, 'f1': f1, 'preds': preds, 'labels': labels}
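
# End-to-end sketch (illustrative glue code, not part of the original commit):
# trains InsuranceProductDIN with FocalLoss on a tiny synthetic dataset for two
# epochs, then evaluates. All sizes and hyper-parameters below are assumptions.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    n = 64
    dataset = ProductInteractionDataset(
        user_ids=rng.integers(0, 10000, n),
        user_features=rng.normal(size=(n, 20)),
        behavior_events=[rng.integers(1, 40, 8).tolist() for _ in range(n)],
        behavior_products=[rng.integers(1, 100, 8).tolist() for _ in range(n)],
        behavior_masks=None,  # masks are rebuilt per item in __getitem__
        candidate_products=rng.integers(0, 100, n),
        labels=rng.integers(0, 2, n),
        max_len=50,
    )
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = InsuranceProductDIN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = FocalLoss()

    for epoch in range(2):
        avg_loss = train_epoch(model, loader, optimizer, criterion, device)
        print(f'epoch {epoch}: loss={avg_loss:.4f}')

    print(evaluate_model(model, loader, device))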