yonghao committed on
Commit
a7e77d8
·
verified ·
1 Parent(s): 6cf7f4b

Add credit bureau model template (TabM+PLE+LightGBM)

Browse files
Files changed (1) hide show
  1. credit_bureau_model.py +723 -0
credit_bureau_model.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 征信结构化数据 风控模型 — 完整代码模板
3
+ ========================================
4
+ 方法: TabM (ICLR 2025) + PLE 数值编码 + LightGBM 集成
5
+ 论文: arxiv:2410.24210 (TabM), arxiv:2203.05556 (PLE), arxiv:2106.11959 (FT-Transformer)
6
+ 依据: TabM 在 46 个数据集上 DL SOTA,配合 LightGBM 集成效果最佳
7
+
8
+ 使用方式:
9
+ 1. 替换 `load_credit_data()` 为你自己的征信数据加载逻辑
10
+ 2. 配置 `CREDIT_CONFIG` 中的特征列名
11
+ 3. 运行完整 pipeline: 预处理→训练→评估→集成
12
+
13
+ 依赖: pip install torch scikit-learn lightgbm pandas numpy scipy
14
+ 可选: pip install rtdl_num_embeddings rtdl_revisiting_models pytorch-tabular
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
21
+ import numpy as np
22
+ import pandas as pd
23
+ from sklearn.preprocessing import QuantileTransformer, LabelEncoder
24
+ from sklearn.model_selection import train_test_split
25
+ from sklearn.metrics import roc_auc_score, classification_report
26
+ from scipy.stats import ks_2samp
27
+ from typing import List, Dict, Tuple, Optional
28
+ import logging
29
+ import json
30
+
31
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================
# CONFIG
# ============================================================
# Numerical input columns (replace with your actual credit-bureau fields).
_NUMERICAL_FEATURES = [
    "age",                       # applicant age
    "monthly_income",            # monthly income
    "debt_to_income_ratio",      # debt-to-income ratio
    "total_credit_limit",        # total credit limit
    "total_balance",             # total outstanding balance
    "num_open_accounts",         # number of open accounts
    "num_delinquent_accounts",   # number of delinquent accounts
    "months_since_last_delinq",  # months since the most recent delinquency
    "credit_utilization",        # credit utilization
    "num_inquiries_6m",          # number of inquiries in the last 6 months
    "longest_credit_history",    # longest credit history (months)
    "num_credit_cards",          # number of credit cards
    "max_delinquency_amount",    # largest delinquent amount
    "avg_monthly_payment",       # average monthly payment
    "payment_to_income_ratio",   # payment-to-income ratio
]

# Categorical input columns (replace with your actual credit-bureau fields).
_CATEGORICAL_FEATURES = [
    "education_level",   # education level
    "employment_type",   # employment type
    "marital_status",    # marital status
    "housing_type",      # housing type
    "province",          # province
]

CREDIT_CONFIG = {
    # ---- Feature configuration ----
    "numerical_features": _NUMERICAL_FEATURES,
    "categorical_features": _CATEGORICAL_FEATURES,
    "target_column": "is_default",  # binary target: 0/1

    # ---- Model hyper-parameters ----
    # TabM (ICLR 2025)
    "tabm_hidden_dim": 256,
    "tabm_num_blocks": 4,
    "tabm_ensemble_k": 32,
    "tabm_dropout": 0.1,

    # PLE numerical encoding
    "ple_num_bins": 32,

    # FT-Transformer (alternative model)
    "ft_num_layers": 3,
    "ft_num_heads": 8,
    "ft_d_model": 192,
    "ft_dropout": 0.2,

    # Training
    "learning_rate": 3e-4,
    "weight_decay": 1e-5,
    "batch_size": 512,
    "max_epochs": 100,
    "patience": 16,

    # LightGBM
    "lgb_lr": 0.05,
    "lgb_num_leaves": 63,
    "lgb_max_depth": 7,
    "lgb_num_boost_round": 1000,

    # Ensemble weights
    "ensemble_weight_tabm": 0.5,
    "ensemble_weight_lgb": 0.5,
}
100
+
101
+
102
+ # ============================================================
103
+ # 数据预处理 Pipeline
104
+ # ============================================================
105
class CreditDataPreprocessor:
    """Turn a raw credit-bureau DataFrame into model-ready arrays.

    Pipeline (statistics are learned in ``fit_transform`` and replayed in
    ``transform``):

    1. Missing values: each numerical column is median-imputed and a 0/1
       ``is_missing`` indicator column is appended for it; categorical
       columns are filled with the literal string ``"MISSING"``.
    2. The numerical block (features followed by their indicators) is passed
       through a ``QuantileTransformer`` with a normal output distribution.
    3. Categorical columns are integer-encoded with ``LabelEncoder``.
    4. Quantile bin edges for piecewise-linear encoding (arXiv:2203.05556)
       are computed on the transformed numerical block.

    Categories not seen during fitting are encoded in ``transform`` as the
    index ``len(classes_)`` (one past the last known class). Models in this
    file size their embeddings as ``cardinality + 1``, so that index is valid.
    """

    def __init__(self):
        self.num_features = CREDIT_CONFIG['numerical_features']
        self.cat_features = CREDIT_CONFIG['categorical_features']
        self.target = CREDIT_CONFIG['target_column']
        self.qt = None              # fitted QuantileTransformer
        self.label_encoders = {}    # column name -> fitted LabelEncoder
        self.medians = {}           # column name -> imputation value
        self.cat_cardinalities = []  # number of known classes per categorical column
        self.ple_bins = None        # (n_num_columns, n_bins + 1) bin edges

    def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Fit all preprocessing statistics on ``df`` and transform it.

        Returns:
            ``(X_num, X_cat, y)`` where ``X_num`` is float32 with the numerical
            features followed by their missing indicators, ``X_cat`` is int64
            label codes and ``y`` is the float32 target column.
        """
        df = df.copy()

        # Start from a clean slate so that refitting the same instance does not
        # accumulate stale encoders or duplicate cardinalities.
        self.label_encoders = {}
        self.medians = {}
        self.cat_cardinalities = []

        X_num_raw = self._impute_and_stack_numeric(df, fit=True)
        self._fill_categorical(df)

        # NOTE(review): the 0/1 missing indicators are quantile-transformed
        # together with the real features; confirm this is intended.
        self.qt = QuantileTransformer(output_distribution='normal', random_state=42)
        X_num = self.qt.fit_transform(X_num_raw).astype(np.float32)

        cat_columns = []
        for col in self.cat_features:
            le = LabelEncoder()
            cat_columns.append(le.fit_transform(df[col]))
            self.label_encoders[col] = le
            self.cat_cardinalities.append(len(le.classes_))

        X_cat = self._stack_categorical(cat_columns, len(df))
        y = df[self.target].values.astype(np.float32)

        self.ple_bins = self._compute_ple_bins(X_num)

        logger.info(f"Preprocessed: {X_num.shape[0]} samples, "
                    f"{X_num.shape[1]} numerical (incl. {len(self.num_features)} missing indicators), "
                    f"{X_cat.shape[1]} categorical")
        logger.info(f"Default rate: {y.mean()*100:.2f}%")

        return X_num, X_cat, y

    def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Apply the fitted preprocessing to new data.

        Returns the same ``(X_num, X_cat, y)`` triple as ``fit_transform``.

        Raises:
            RuntimeError: if called before ``fit_transform``.
        """
        if self.qt is None:
            raise RuntimeError("CreditDataPreprocessor.transform() called before fit_transform()")

        df = df.copy()

        X_num_raw = self._impute_and_stack_numeric(df, fit=False)
        self._fill_categorical(df)
        X_num = self.qt.transform(X_num_raw).astype(np.float32)

        cat_columns = []
        for col in self.cat_features:
            le = self.label_encoders[col]
            # One dict lookup per row instead of one LabelEncoder.transform
            # call per row.
            lookup = {str(cls): idx for idx, cls in enumerate(le.classes_)}
            # Unseen categories get their own index rather than aliasing class 0.
            unknown_idx = len(le.classes_)
            codes = df[col].map(lookup).fillna(unknown_idx).astype(np.int64).values
            cat_columns.append(codes)

        X_cat = self._stack_categorical(cat_columns, len(df))
        y = df[self.target].values.astype(np.float32)

        return X_num, X_cat, y

    def _impute_and_stack_numeric(self, df: pd.DataFrame, fit: bool) -> np.ndarray:
        """Median-impute numerical columns of ``df`` in place and append indicators.

        When ``fit`` is true the medians are learned from ``df`` and stored;
        otherwise the stored medians are used. Returns a float32 array laid out
        as ``[features..., is_missing indicators...]``.
        """
        indicators = []
        for col in self.num_features:
            indicators.append(df[col].isna().astype(np.float32).values)
            if fit:
                median_val = df[col].median()
                if pd.isna(median_val):
                    # Column is entirely missing: fall back to 0.0 so NaN does
                    # not leak into the transformer and the PLE bin edges.
                    median_val = 0.0
                self.medians[col] = median_val
            df[col] = df[col].fillna(self.medians[col])

        values = df[self.num_features].values.astype(np.float32)
        return np.concatenate([values, np.stack(indicators, axis=1)], axis=1)

    def _fill_categorical(self, df: pd.DataFrame) -> None:
        """Replace missing categorical values with ``"MISSING"`` and cast to str, in place."""
        for col in self.cat_features:
            df[col] = df[col].fillna("MISSING").astype(str)

    @staticmethod
    def _stack_categorical(columns: List[np.ndarray], n_rows: int) -> np.ndarray:
        """Stack per-column code arrays into ``(n_rows, n_columns)`` int64."""
        if not columns:
            # No categorical features configured: return an empty block
            # instead of letting np.stack fail on an empty list.
            return np.zeros((n_rows, 0), dtype=np.int64)
        return np.stack(columns, axis=1).astype(np.int64)

    def _compute_ple_bins(self, X_num: np.ndarray) -> np.ndarray:
        """Compute per-column quantile bin edges for piecewise-linear encoding.

        Returns an array of shape ``(n_columns, ple_num_bins + 1)``.
        """
        n_bins = CREDIT_CONFIG['ple_num_bins']
        quantiles = np.linspace(0, 1, n_bins + 1)
        # A single call along axis 0 gives (n_bins + 1, n_columns); transpose
        # to one row of edges per column.
        return np.quantile(X_num, quantiles, axis=0).T.astype(np.float64)
213
+
214
+
215
+ # ============================================================
216
+ # PLE (Piecewise Linear Encoding) — arxiv:2203.05556
217
+ # ============================================================
218
class PiecewiseLinearEncoding(nn.Module):
    """Piecewise-linear encoding of scalar features (arXiv:2203.05556).

    Each scalar is expanded into one value per bin: 0 when the scalar lies
    below the bin, 1 when it lies above it, and the fractional position
    inside the bin otherwise.

    Args:
        bins: NumPy array of shape ``(n_features, n_bins + 1)`` holding the
            bin edges of every feature.
    """

    def __init__(self, bins: np.ndarray):
        super().__init__()
        edges = torch.from_numpy(bins).float()
        # Registered as a buffer so the edges follow the module across devices
        # and are stored in the state dict without being trained.
        self.register_buffer('bins', edges)
        self.n_features = bins.shape[0]
        self.n_bins = bins.shape[1] - 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (batch, n_features) into (batch, n_features, n_bins)."""
        lower = self.bins[:, :-1].unsqueeze(0)
        upper = self.bins[:, 1:].unsqueeze(0)
        # The small constant keeps the division finite for zero-width bins
        # (duplicate edges).
        span = upper - lower + 1e-8
        position = (x.unsqueeze(-1) - lower) / span
        return position.clamp(0, 1)
244
+
245
+
246
+ # ============================================================
247
+ # TabM: MLP + BatchEnsemble (ICLR 2025)
248
+ # ============================================================
249
class BatchEnsembleLinear(nn.Module):
    """Linear layer shared by ``k`` ensemble members (BatchEnsemble).

    All members share one weight matrix ``W`` and one bias. Member ``j``
    additionally owns an input scaling vector ``r[j]`` and an output scaling
    vector ``s[j]``, so the extra parameter count is
    ``k * (in_features + out_features)``.

    Args:
        in_features: size of each input vector.
        out_features: size of each output vector.
        k: number of ensemble members.
    """

    def __init__(self, in_features: int, out_features: int, k: int = 32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Per-member scaling vectors; the values are overwritten by the
        # truncated-normal initialisation right below.
        self.r = nn.Parameter(torch.ones(k, in_features))
        self.s = nn.Parameter(torch.ones(k, out_features))

        nn.init.trunc_normal_(self.r, mean=1.0, std=0.5)
        nn.init.trunc_normal_(self.s, mean=1.0, std=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer for every ensemble member.

        Args:
            x: either ``(batch, in_features)``, in which case the same input
                is fed to all ``k`` members, or ``(batch, k, in_features)``
                with one input per member.

        Returns:
            Tensor of shape ``(batch, k, out_features)``.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 3-D with ``k`` members.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (batch, 1, in) broadcasts against r: (k, in)
        elif x.dim() != 3 or x.shape[1] != self.k:
            raise ValueError(
                f"expected input of shape (batch, {self.in_features}) or "
                f"(batch, {self.k}, {self.in_features}), got {tuple(x.shape)}"
            )

        scaled_in = x * self.r                       # (batch, k, in)
        out = torch.matmul(scaled_in, self.weight)   # (batch, k, out)
        return out * self.s + self.bias
276
+
277
+
278
class TabM(nn.Module):
    """MLP with BatchEnsemble layers and PLE numerical encoding (after TabM, ICLR 2025).

    Numerical inputs are piecewise-linearly encoded, categorical inputs are
    embedded, and the concatenation is projected to the hidden size. A stack
    of residual BatchEnsemble blocks follows, and a BatchEnsemble head
    produces one logit per ensemble member; the member logits are averaged.

    NOTE(review): between blocks the ``k`` member activations are averaged
    before being fed to the next BatchEnsemble layer, so members do not keep
    fully independent paths. This looks like a simplification of the
    published architecture -- confirm it is intended.

    Args:
        n_num_features: number of numerical input columns.
        cat_cardinalities: number of classes per categorical column; each
            embedding gets one extra row.
        ple_bins: bin edges of shape ``(n_num_features, n_bins + 1)``.

    Raises:
        ValueError: if ``ple_bins`` does not have one row per numerical feature.
    """

    def __init__(self, n_num_features: int, cat_cardinalities: List[int], ple_bins: np.ndarray):
        super().__init__()

        self.ple = PiecewiseLinearEncoding(ple_bins)
        if self.ple.n_features != n_num_features:
            raise ValueError(
                f"ple_bins has {self.ple.n_features} rows but n_num_features={n_num_features}"
            )
        # Take the bin count from the bins actually supplied rather than from
        # the global config, so the input width always matches what the PLE
        # layer emits.
        ple_input_dim = n_num_features * self.ple.n_bins

        embed_dims = [min(50, (card + 1) // 2 + 1) for card in cat_cardinalities]
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(card + 1, dim)
            for card, dim in zip(cat_cardinalities, embed_dims)
        ])

        input_dim = ple_input_dim + sum(embed_dims)
        hidden_dim = CREDIT_CONFIG['tabm_hidden_dim']
        n_blocks = CREDIT_CONFIG['tabm_num_blocks']
        dropout = CREDIT_CONFIG['tabm_dropout']
        # Remember k on the instance: forward() must use the value the layers
        # were built with even if the config is changed afterwards.
        self.k = CREDIT_CONFIG['tabm_ensemble_k']

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.input_norm = nn.LayerNorm(hidden_dim)

        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(nn.ModuleDict({
                'be_linear': BatchEnsembleLinear(hidden_dim, hidden_dim, k=self.k),
                'norm': nn.LayerNorm(hidden_dim),
                'dropout': nn.Dropout(dropout),
            }))

        self.output_head = BatchEnsembleLinear(hidden_dim, 1, k=self.k)

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        """Return one logit per row.

        Args:
            x_num: ``(batch, n_num_features)`` float tensor.
            x_cat: ``(batch, n_cat_features)`` integer tensor of category codes.

        Returns:
            ``(batch,)`` tensor of logits averaged over the ensemble members.
        """
        ple_encoded = self.ple(x_num)
        # reshape (not view) so a non-contiguous encoding cannot raise.
        ple_flat = ple_encoded.reshape(ple_encoded.shape[0], -1)

        cat_embeds = [embed(x_cat[:, i]) for i, embed in enumerate(self.cat_embeddings)]
        if cat_embeds:
            cat_concat = torch.cat(cat_embeds, dim=-1)
        else:
            cat_concat = x_num.new_zeros(x_num.shape[0], 0)

        x = torch.cat([ple_flat, cat_concat], dim=-1)
        x = F.relu(self.input_norm(self.input_proj(x)))

        for block in self.blocks:
            residual = x
            # First block sees the shared (batch, hidden) input; later blocks
            # see the member average of the previous block's output.
            out = block['be_linear'](x if x.dim() == 2 else x.mean(dim=1))
            out = block['norm'](out)
            out = F.relu(out)
            out = block['dropout'](out)

            if residual.dim() == 2:
                residual = residual.unsqueeze(1).expand(-1, self.k, -1)
            x = out + residual

        # With zero blocks x is still (batch, hidden) and must not be reduced.
        x_mean = x.mean(dim=1) if x.dim() == 3 else x
        logits = self.output_head(x_mean)           # (batch, k, 1)
        return logits.squeeze(-1).mean(dim=-1)      # (batch,)
346
+
347
+
348
+ # ============================================================
349
+ # FT-Transformer (备选方案)
350
+ # ============================================================
351
class FTTransformer(nn.Module):
    """FT-Transformer (NeurIPS 2021), kept as an alternative to TabM.

    Every feature becomes its own token (a linear map per numerical column,
    an embedding per categorical column); a learned CLS token is prepended
    and a Transformer encoder models the feature interactions. The logit is
    read from the CLS position.

    Args:
        n_num_features: number of numerical input columns.
        cat_cardinalities: number of classes per categorical column; each
            embedding gets one extra row.
    """

    def __init__(self, n_num_features: int, cat_cardinalities: List[int]):
        super().__init__()
        width = CREDIT_CONFIG['ft_d_model']

        self.num_tokenizers = nn.ModuleList(
            [nn.Linear(1, width) for _ in range(n_num_features)]
        )
        self.cat_tokenizers = nn.ModuleList(
            [nn.Embedding(n_classes + 1, width) for n_classes in cat_cardinalities]
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, width) * 0.02)

        layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=CREDIT_CONFIG['ft_num_heads'],
            dim_feedforward=width * 4,
            dropout=CREDIT_CONFIG['ft_dropout'],
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            layer, num_layers=CREDIT_CONFIG['ft_num_layers']
        )

        self.head = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, width // 2),
            nn.ReLU(),
            nn.Linear(width // 2, 1),
        )

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch,)`` tensor of logits for the given feature tensors."""
        n_rows = x_num.shape[0]

        # Token order: CLS, numerical features, categorical features.
        sequence = [self.cls_token.expand(n_rows, -1, -1)]
        sequence.extend(
            tok(x_num[:, j:j + 1]).unsqueeze(1)
            for j, tok in enumerate(self.num_tokenizers)
        )
        sequence.extend(
            emb(x_cat[:, j]).unsqueeze(1)
            for j, emb in enumerate(self.cat_tokenizers)
        )

        encoded = self.transformer(torch.cat(sequence, dim=1))
        return self.head(encoded[:, 0]).squeeze(-1)
390
+
391
+
392
+ # ============================================================
393
+ # Dataset
394
+ # ============================================================
395
class CreditDataset(Dataset):
    """Torch dataset over preprocessed numerical, categorical and target arrays.

    Args:
        X_num: NumPy array of numerical features, one row per sample.
        X_cat: NumPy array of integer category codes, one row per sample.
        y: NumPy array of targets, one entry per sample.

    Raises:
        ValueError: if the three arrays do not have the same number of rows.
    """

    def __init__(self, X_num, X_cat, y):
        if not (len(X_num) == len(X_cat) == len(y)):
            # Fail early: otherwise extra feature rows are silently ignored or
            # an IndexError surfaces mid-epoch inside the DataLoader.
            raise ValueError(
                f"row count mismatch: X_num={len(X_num)}, X_cat={len(X_cat)}, y={len(y)}"
            )
        self.X_num = torch.from_numpy(X_num).float()
        self.X_cat = torch.from_numpy(X_cat).long()
        self.y = torch.from_numpy(y).float()

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(x_num, x_cat, y)`` tensors of sample ``idx``."""
        return self.X_num[idx], self.X_cat[idx], self.y[idx]
406
+
407
+
408
+ # ============================================================
409
+ # 训练 Pipeline
410
+ # ============================================================
411
def compute_ks_statistic(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the two-sample KS statistic between positive and negative scores.

    Scores are split by label (``y_true == 1`` versus ``y_true == 0``). If
    either group is empty the function returns 0.0.
    """
    scores_pos = y_pred[y_true == 1]
    scores_neg = y_pred[y_true == 0]
    if min(len(scores_pos), len(scores_neg)) == 0:
        return 0.0
    result = ks_2samp(scores_pos, scores_neg)
    return result.statistic
418
+
419
+
420
def _predict_tabm_proba(model, loader, device) -> np.ndarray:
    """Run ``model`` over ``loader`` in eval mode and return sigmoid probabilities."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for x_num, x_cat, _ in loader:
            logits = model(x_num.to(device), x_cat.to(device))
            chunks.append(torch.sigmoid(logits).cpu().numpy())
    if not chunks:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(chunks)


def train_tabm(X_num_train, X_cat_train, y_train, X_num_val, X_cat_val, y_val, ple_bins: np.ndarray):
    """Train a TabM model with early stopping on validation AUC.

    Args:
        X_num_train, X_cat_train, y_train: preprocessed training arrays.
        X_num_val, X_cat_val, y_val: preprocessed validation arrays.
        ple_bins: PLE bin edges with one row per numerical column.

    Returns:
        ``(model, val_preds, final_auc, final_ks)``: the model restored to its
        best-AUC weights, its validation probabilities, and the validation
        AUC / KS of those probabilities.

    Side effects:
        Writes the best weights to ``best_tabm_model.pt`` in the current
        working directory whenever the validation AUC improves.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Training TabM on {device}")

    train_dataset = CreditDataset(X_num_train, X_cat_train, y_train)
    val_dataset = CreditDataset(X_num_val, X_cat_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=CREDIT_CONFIG['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CREDIT_CONFIG['batch_size'])
    # The validation loader is not shuffled, so predictions line up with y_val.
    val_labels = np.asarray(y_val)

    model = TabM(
        n_num_features=X_num_train.shape[1],
        cat_cardinalities=[int(X_cat_train[:, i].max()) + 1 for i in range(X_cat_train.shape[1])],
        ple_bins=ple_bins
    ).to(device)

    # Up-weight the positive class by the negative/positive ratio.
    num_pos = float(y_train.sum())
    num_neg = float(len(y_train)) - num_pos
    pos_weight = torch.tensor([num_neg / max(num_pos, 1.0)], dtype=torch.float32).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CREDIT_CONFIG['learning_rate'],
        weight_decay=CREDIT_CONFIG['weight_decay'],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CREDIT_CONFIG['max_epochs'])

    best_auc = float('-inf')
    best_state = None
    patience_counter = 0

    for epoch in range(CREDIT_CONFIG['max_epochs']):
        model.train()
        train_loss = 0
        for x_num, x_cat, y in train_loader:
            x_num, x_cat, y = x_num.to(device), x_cat.to(device), y.to(device)
            logits = model(x_num, x_cat)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        val_preds = _predict_tabm_proba(model, val_loader, device)
        val_auc = roc_auc_score(val_labels, val_preds)
        val_ks = compute_ks_statistic(val_labels, val_preds)

        if (epoch + 1) % 5 == 0 or val_auc > best_auc:
            logger.info(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, AUC={val_auc:.4f}, KS={val_ks:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            patience_counter = 0
            # Keep the best weights in memory so the restore below never
            # depends on a checkpoint left behind by an earlier run.
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
            torch.save(model.state_dict(), 'best_tabm_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= CREDIT_CONFIG['patience']:
                logger.info(f"Early stopping at epoch {epoch+1}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    val_preds = _predict_tabm_proba(model, val_loader, device)
    final_auc = roc_auc_score(val_labels, val_preds)
    final_ks = compute_ks_statistic(val_labels, val_preds)
    logger.info(f"TabM Final: AUC={final_auc:.4f}, KS={final_ks:.4f}")
    return model, val_preds, final_auc, final_ks
505
+
506
+
507
def train_lightgbm(X_num_train, X_cat_train, y_train, X_num_val, X_cat_val, y_val):
    """Train the LightGBM baseline with early stopping on validation AUC.

    Numerical and categorical arrays are concatenated column-wise; the
    categorical columns are declared to LightGBM by position.

    Returns:
        ``(model, val_preds, val_auc, val_ks)``, or ``(None, None, 0, 0)`` when
        the ``lightgbm`` package is not installed.
    """
    try:
        import lightgbm as lgb
    except ImportError:
        logger.error("pip install lightgbm")
        return None, None, 0, 0

    X_train = np.concatenate([X_num_train, X_cat_train.astype(np.float32)], axis=1)
    X_val = np.concatenate([X_num_val, X_cat_val.astype(np.float32)], axis=1)

    num_pos = float(y_train.sum())
    num_neg = float(len(y_train)) - num_pos

    params = {
        'objective': 'binary', 'metric': 'auc',
        'learning_rate': CREDIT_CONFIG['lgb_lr'],
        'num_leaves': CREDIT_CONFIG['lgb_num_leaves'],
        'max_depth': CREDIT_CONFIG['lgb_max_depth'],
        'min_child_samples': 20,
        # Plain Python float rather than a numpy scalar.
        'scale_pos_weight': num_neg / max(num_pos, 1.0),
        # Row subsampling only takes effect when a bagging frequency is set;
        # without 'subsample_freq' the 0.8 below is ignored.
        'subsample': 0.8, 'subsample_freq': 1,
        'colsample_bytree': 0.8,
        'reg_alpha': 0.1, 'reg_lambda': 1.0,
        'verbose': -1, 'n_jobs': -1,
    }

    # Categorical columns sit after the numerical block.
    cat_feature_indices = list(range(X_num_train.shape[1], X_train.shape[1]))
    train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=cat_feature_indices)
    val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

    model = lgb.train(
        params, train_data, num_boost_round=CREDIT_CONFIG['lgb_num_boost_round'],
        valid_sets=[val_data],
        callbacks=[lgb.early_stopping(stopping_rounds=50), lgb.log_evaluation(100)]
    )

    # Score with the early-stopped iteration explicitly.
    val_preds = model.predict(X_val, num_iteration=model.best_iteration)
    val_auc = roc_auc_score(y_val, val_preds)
    val_ks = compute_ks_statistic(y_val, val_preds)
    logger.info(f"LightGBM Final: AUC={val_auc:.4f}, KS={val_ks:.4f}")

    importance = model.feature_importance(importance_type='gain')
    feature_names = (
        CREDIT_CONFIG['numerical_features']
        + [f"missing_{f}" for f in CREDIT_CONFIG['numerical_features']]
        + CREDIT_CONFIG['categorical_features']
    )
    # Names are only reported when they line up with the trained columns.
    if len(feature_names) == len(importance):
        top_features = sorted(zip(feature_names, importance), key=lambda x: -x[1])[:10]
        logger.info("Top 10 features by gain:")
        for name, imp in top_features:
            logger.info(f"  {name}: {imp:.0f}")

    return model, val_preds, val_auc, val_ks
557
+
558
+
559
def ensemble_predictions(tabm_preds: np.ndarray, lgb_preds: np.ndarray, y_true: np.ndarray):
    """Blend TabM and LightGBM probabilities with the configured weights.

    The configured weights are normalised to sum to 1 so the blended score
    stays on the same scale as its inputs. A coarse grid search over the
    TabM weight (0.1 .. 0.9) is also run and logged for information only; it
    does not change the returned predictions.

    Returns:
        ``(ensemble_preds, ensemble_auc, ensemble_ks)`` for the configured
        (normalised) weights.

    Raises:
        ValueError: if the configured weights do not sum to a positive number.
    """
    w_tabm = CREDIT_CONFIG['ensemble_weight_tabm']
    w_lgb = CREDIT_CONFIG['ensemble_weight_lgb']
    weight_sum = w_tabm + w_lgb
    if weight_sum <= 0:
        raise ValueError(
            f"ensemble weights must sum to a positive value, got {w_tabm} + {w_lgb}"
        )
    w_tabm, w_lgb = w_tabm / weight_sum, w_lgb / weight_sum

    ensemble_preds = w_tabm * tabm_preds + w_lgb * lgb_preds
    ensemble_auc = roc_auc_score(y_true, ensemble_preds)
    ensemble_ks = compute_ks_statistic(y_true, ensemble_preds)

    logger.info(f"Ensemble (TabM {w_tabm:.1f} + LGB {w_lgb:.1f}): AUC={ensemble_auc:.4f}, KS={ensemble_ks:.4f}")

    # Informational search for the best blend on the same labels.
    best_auc = 0
    best_w = 0.5
    for w in np.arange(0.1, 1.0, 0.1):
        pred = w * tabm_preds + (1 - w) * lgb_preds
        auc = roc_auc_score(y_true, pred)
        if auc > best_auc:
            best_auc = auc
            best_w = w

    logger.info(f"Optimal weight: TabM={best_w:.1f}, LGB={1-best_w:.1f}, AUC={best_auc:.4f}")
    return ensemble_preds, ensemble_auc, ensemble_ks
581
+
582
+
583
+ # ============================================================
584
+ # 阈值校准
585
+ # ============================================================
586
def calibrate_threshold(y_true: np.ndarray, y_pred: np.ndarray, method='ks'):
    """Pick a decision threshold for ``y_pred``.

    Args:
        y_true: binary labels (0/1).
        y_pred: predicted scores.
        method: ``'ks'`` scans thresholds 0.01 .. 0.99 and keeps the one with
            the largest ``|TPR - FPR|``; ``'youden'`` maximises Youden's J
            (``TPR - FPR``) over the thresholds of the ROC curve.

    Returns:
        The selected threshold.

    Raises:
        ValueError: if ``method`` is not ``'ks'`` or ``'youden'``.
    """
    if method == 'ks':
        is_pos = (y_true == 1)
        is_neg = (y_true == 0)
        # Class sizes do not depend on the threshold; floor at 1 to avoid 0/0.
        n_pos = max(is_pos.sum(), 1)
        n_neg = max(is_neg.sum(), 1)

        best_ks = 0
        best_threshold = 0.5
        for t in np.arange(0.01, 1.0, 0.01):
            flagged = (y_pred >= t)
            tpr = (flagged & is_pos).sum() / n_pos
            fpr = (flagged & is_neg).sum() / n_neg
            ks = abs(tpr - fpr)
            if ks > best_ks:
                best_ks = ks
                best_threshold = t
        logger.info(f"KS Threshold: {best_threshold:.3f}, KS={best_ks:.4f}")
        return best_threshold

    if method == 'youden':
        from sklearn.metrics import roc_curve
        fpr, tpr, roc_thresholds = roc_curve(y_true, y_pred)
        j_scores = tpr - fpr
        if len(j_scores) > 1:
            # Skip the first ROC point: its threshold is a sentinel above
            # every score ("predict nothing"), not a usable cut-off.
            best_idx = int(np.argmax(j_scores[1:])) + 1
        else:
            best_idx = 0
        best_threshold = roc_thresholds[best_idx]
        logger.info(f"Youden's J Threshold: {best_threshold:.3f}")
        return best_threshold

    # Previously an unknown method fell through and returned None.
    raise ValueError(f"unknown method {method!r}; expected 'ks' or 'youden'")
616
+
617
+
618
+ # ============================================================
619
+ # PSI 稳定性监控
620
+ # ============================================================
621
def compute_psi(expected: np.ndarray, actual: np.ndarray, n_bins: int = 10) -> float:
    """Population Stability Index between two score samples.

    Bin edges are the quantiles of ``expected``; the outermost edges are
    opened to +/- infinity so every value of ``actual`` falls into a bin.
    Bin shares are floored at 1e-4 before taking the logarithm.

    Rule of thumb: PSI < 0.1 stable, 0.1-0.25 needs attention,
    >= 0.25 significant drift.

    Raises:
        ValueError: if either sample is empty or ``n_bins`` is less than 1.
    """
    expected = np.asarray(expected).ravel()
    actual = np.asarray(actual).ravel()
    if expected.size == 0 or actual.size == 0:
        # Otherwise the result is NaN (0/0) or an opaque quantile error.
        raise ValueError("compute_psi requires non-empty `expected` and `actual` samples")
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    breakpoints = np.quantile(expected, np.linspace(0, 1, n_bins + 1))
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf

    expected_percents = np.histogram(expected, bins=breakpoints)[0] / len(expected)
    actual_percents = np.histogram(actual, bins=breakpoints)[0] / len(actual)

    # Floor empty bins so the log ratio stays finite.
    expected_percents = np.clip(expected_percents, 1e-4, None)
    actual_percents = np.clip(actual_percents, 1e-4, None)

    psi = np.sum((actual_percents - expected_percents) * np.log(actual_percents / expected_percents))
    return float(psi)
635
+
636
+
637
+ # ============================================================
638
+ # 主流程
639
+ # ============================================================
640
def load_credit_data(n_samples: int = 50000, seed: int = 42) -> pd.DataFrame:
    """Return a synthetic credit-bureau table (placeholder data source).

    Replace the body of this function with your own loading logic. The
    returned frame must contain the numerical, categorical and target
    columns named in ``CREDIT_CONFIG``.

    The synthetic target marks the top 3% of a noisy linear risk score as
    defaults, and 30% of two delinquency columns are blanked out to exercise
    the missing-value handling. Seeds NumPy's global random state.
    """
    np.random.seed(seed)

    data = {
        'age': np.random.randint(18, 65, n_samples).astype(float),
        'monthly_income': np.random.lognormal(9, 1, n_samples),
        'debt_to_income_ratio': np.random.beta(2, 5, n_samples),
        'total_credit_limit': np.random.lognormal(10, 1.5, n_samples),
        'total_balance': np.random.lognormal(9, 2, n_samples),
        'num_open_accounts': np.random.poisson(5, n_samples).astype(float),
        'num_delinquent_accounts': np.random.poisson(0.3, n_samples).astype(float),
        'months_since_last_delinq': np.random.exponential(24, n_samples),
        'credit_utilization': np.random.beta(3, 7, n_samples),
        'num_inquiries_6m': np.random.poisson(2, n_samples).astype(float),
        'longest_credit_history': np.random.gamma(5, 12, n_samples),
        'num_credit_cards': np.random.poisson(3, n_samples).astype(float),
        'max_delinquency_amount': np.random.exponential(1000, n_samples),
        'avg_monthly_payment': np.random.lognormal(7, 1, n_samples),
        'payment_to_income_ratio': np.random.beta(3, 7, n_samples),
        'education_level': np.random.choice(['高中', '大专', '本科', '硕士', '博士'], n_samples),
        'employment_type': np.random.choice(['企业', '事业单位', '公务员', '自由职业', '学生'], n_samples),
        'marital_status': np.random.choice(['未婚', '已婚', '离异'], n_samples),
        'housing_type': np.random.choice(['自有', '租房', '父母同住', '单位宿舍'], n_samples),
        'province': np.random.choice([f'省份_{i}' for i in range(30)], n_samples),
    }

    risk_score = (0.3 * data['debt_to_income_ratio'] + 0.2 * data['num_delinquent_accounts'] / 5 +
                  0.2 * data['credit_utilization'] + 0.1 * data['num_inquiries_6m'] / 10 +
                  0.2 * np.random.random(n_samples))
    data['is_default'] = (risk_score > np.quantile(risk_score, 0.97)).astype(int)

    for col in ['months_since_last_delinq', 'max_delinquency_amount']:
        mask = np.random.random(n_samples) < 0.3
        data[col] = np.where(mask, np.nan, data[col])

    return pd.DataFrame(data)


def main():
    """Run the full demo pipeline: load, preprocess, train, ensemble, calibrate, monitor."""
    logger.info("=" * 60)
    logger.info("征信数据风控模型 — 完整训练流程")
    logger.info("=" * 60)

    # Synthetic data (replace load_credit_data with your own loader).
    n_samples = 50000
    df = load_credit_data(n_samples=n_samples, seed=42)
    logger.info(f"Samples: {n_samples}, Default rate: {df['is_default'].mean()*100:.2f}%")

    # Seed torch as well so TabM initialisation and batch shuffling are
    # repeatable on CPU.
    torch.manual_seed(42)

    # Random stratified hold-out split.
    # NOTE: with real data prefer an out-of-time split by application date.
    train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['is_default'], random_state=42)

    # Preprocessing
    preprocessor = CreditDataPreprocessor()
    X_num_train, X_cat_train, y_train = preprocessor.fit_transform(train_df)
    X_num_val, X_cat_val, y_val = preprocessor.transform(val_df)

    # LightGBM
    lgb_model, lgb_preds, lgb_auc, lgb_ks = train_lightgbm(
        X_num_train, X_cat_train, y_train, X_num_val, X_cat_val, y_val
    )

    # TabM
    tabm_model, tabm_preds, tabm_auc, tabm_ks = train_tabm(
        X_num_train, X_cat_train, y_train, X_num_val, X_cat_val, y_val,
        ple_bins=preprocessor.ple_bins
    )

    # Ensemble (only when LightGBM is available)
    if lgb_preds is not None and tabm_preds is not None:
        ensemble_preds, ensemble_auc, ensemble_ks = ensemble_predictions(tabm_preds, lgb_preds, y_val)

    # Threshold calibration.
    # NOTE: the same validation set also drove early stopping, so the
    # calibrated threshold is optimistic; use a separate hold-out in practice.
    best_preds = ensemble_preds if lgb_preds is not None else tabm_preds
    threshold = calibrate_threshold(y_val, best_preds, method='ks')

    # PSI between LightGBM scores on train and validation
    if lgb_model is not None:
        X_train_full = np.concatenate([X_num_train, X_cat_train.astype(np.float32)], axis=1)
        train_preds = lgb_model.predict(X_train_full)
        psi = compute_psi(train_preds, lgb_preds)
        logger.info(f"PSI (train vs val): {psi:.4f} {'✓ Stable' if psi < 0.1 else '⚠ Drift!'}")

    logger.info("=" * 60)
    logger.info("RESULTS SUMMARY")
    logger.info(f"  LightGBM:  AUC={lgb_auc:.4f}, KS={lgb_ks:.4f}")
    logger.info(f"  TabM:      AUC={tabm_auc:.4f}, KS={tabm_ks:.4f}")
    if lgb_preds is not None:
        logger.info(f"  Ensemble:  AUC={ensemble_auc:.4f}, KS={ensemble_ks:.4f}")
    logger.info(f"  Threshold: {threshold:.3f}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()