Stephanwu commited on
Commit
40fff37
·
verified ·
1 Parent(s): de73d07

Major update: Add DIN product recommendation and TabBERT anomaly detection with PyTorch

Browse files
Files changed (1) hide show
  1. app.py +810 -214
app.py CHANGED
@@ -1,7 +1,13 @@
1
- """保险APP 用户行为分析 - Gradio Space
2
- 支持: 合成数据训练 + 真实CSV数据上传
3
  """
4
- import os, json, math, warnings, datetime, random, io
 
 
 
 
 
 
 
 
5
  from collections import Counter, defaultdict
6
  from dataclasses import dataclass, field
7
  from typing import List, Dict, Optional
@@ -9,13 +15,13 @@ from typing import List, Dict, Optional
9
  warnings.filterwarnings('ignore')
10
  import numpy as np
11
  import pandas as pd
12
- from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
13
  from sklearn.preprocessing import StandardScaler, LabelEncoder
14
  from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
15
  from sklearn.metrics import (
16
  roc_auc_score, f1_score, confusion_matrix,
17
  average_precision_score, precision_recall_curve, classification_report,
18
- roc_curve
19
  )
20
  import matplotlib
21
  matplotlib.use('Agg')
@@ -24,8 +30,20 @@ import seaborn as sns
24
 
25
  import gradio as gr
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # =============================================================================
28
- # 数据模型
29
  # =============================================================================
30
 
31
  INSURANCE_EVENT_TYPES = {
@@ -38,11 +56,11 @@ INSURANCE_EVENT_TYPES = {
38
  "policy_cancel", "app_uninstall", "login", "logout",
39
  }
40
 
41
- BROWSE_EVENTS = {"page_view", "product_view", "premium_calculator", "article_read", "faq_view", "product_compare"}
42
- INTERACT_EVENTS = {"quote_request", "form_submit", "document_upload", "chat_init", "call_init", "video_consult", "quote_result_view"}
43
- CONVERT_EVENTS = {"policy_select", "payment_init", "payment_success", "policy_issued"}
44
- CLAIM_EVENTS = {"claim_init", "claim_doc_upload", "claim_review", "claim_approved", "claim_rejected"}
45
- RENEW_EVENTS = {"renewal_reminder", "renewal_click", "renewal_complete", "policy_cancel"}
46
 
47
  @dataclass
48
  class InsuranceAppEvent:
@@ -77,8 +95,13 @@ class InsuranceFeatureEngineer:
77
  days_active = (last_ts - first_ts) / (24 * 3600 * 1000)
78
  has_purchased = any(e.event_type == "policy_issued" for e in all_events)
79
  has_renewed = any(e.event_type == "renewal_complete" for e in all_events)
80
- has_claimed = any(e.event_type in ("claim_init", "claim_approved") for e in all_events)
81
  support = all_type_counts.get("chat_init", 0) + all_type_counts.get("call_init", 0)
 
 
 
 
 
82
  return {
83
  "total_sessions": len(sessions), "total_events": total,
84
  "days_active": days_active, "avg_events_per_session": total / len(sessions),
@@ -88,6 +111,7 @@ class InsuranceFeatureEngineer:
88
  "payment_success_ratio": all_type_counts.get("payment_success", 0) / total,
89
  "policy_issued_ratio": all_type_counts.get("policy_issued", 0) / total,
90
  "unique_products_viewed": len(product_counter),
 
91
  "has_purchased": int(has_purchased), "has_renewed": int(has_renewed),
92
  "has_claimed": int(has_claimed), "support_dependency": support / total,
93
  "renewal_click_count": all_type_counts.get("renewal_click", 0),
@@ -98,58 +122,47 @@ class InsuranceFeatureEngineer:
98
  "peak_active_hour": Counter(datetime.datetime.fromtimestamp(e.timestamp/1000).hour for e in all_events).most_common(1)[0][0],
99
  "recent_7day_events": sum(1 for e in all_events if (last_ts-e.timestamp)<7*24*3600*1000),
100
  "recent_30day_events": sum(1 for e in all_events if (last_ts-e.timestamp)<30*24*3600*1000),
 
 
 
 
101
  }
102
 
103
 
104
  # =============================================================================
105
- # 数据解析
106
  # =============================================================================
107
 
108
- def parse_csv_to_profiles(df: pd.DataFrame) -> List[UserBehaviorProfile]:
109
- """将上传的CSV解析为用户行为画像"""
110
  required_cols = {"user_id", "session_id", "timestamp", "event_type", "page_id"}
111
- missing = required_cols - set(df.columns)
112
  if missing:
113
- raise ValueError(f"CSV缺少必需列: {missing}\n必需列: {required_cols}")
114
-
115
- # 标准化列名
116
- df = df.copy()
117
  df.columns = [c.lower().strip() for c in df.columns]
118
-
119
- # 转换timestamp为整数
120
  df["timestamp"] = pd.to_numeric(df["timestamp"], errors="coerce")
121
  df = df.dropna(subset=["timestamp", "event_type"])
122
  df["timestamp"] = df["timestamp"].astype(int)
123
 
124
- # 按user_id和session_id分组
125
  profiles = {}
126
- for (user_id, session_id), group in df.groupby(["user_id", "session_id"]):
127
- if user_id not in profiles:
128
- profiles[user_id] = UserBehaviorProfile(user_id=str(user_id), sessions=[])
129
-
130
  events = []
131
  for _, row in group.sort_values("timestamp").iterrows():
132
  events.append(InsuranceAppEvent(
133
- event_id=f"evt_{row.name}",
134
- user_id=str(row["user_id"]),
135
- session_id=str(row["session_id"]),
136
- timestamp=int(row["timestamp"]),
137
  event_type=str(row["event_type"]).strip(),
138
  page_id=str(row.get("page_id", "unknown")),
139
  product_id=str(row.get("product_id")) if pd.notna(row.get("product_id")) else None,
140
  amount=float(row["amount"]) if pd.notna(row.get("amount")) else None,
141
  ))
142
-
143
- profiles[user_id].sessions.append(UserSession(
144
- session_id=str(session_id),
145
- user_id=str(user_id),
146
- events=events
147
- ))
148
-
149
  return list(profiles.values())
150
 
151
 
152
- def generate_synthetic_data(n_users=2000, n_events_per_user=50):
 
153
  event_types = list(INSURANCE_EVENT_TYPES)
154
  products = ["health_basic","health_premium","critical_illness","term_life",
155
  "auto_compulsory","auto_commercial","home","travel_domestic"]
@@ -180,22 +193,93 @@ def generate_synthetic_data(n_users=2000, n_events_per_user=50):
180
  return data
181
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # =============================================================================
184
- # 核心训练函数
185
  # =============================================================================
186
 
187
- def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=False):
188
- """通用训练函数"""
189
  df = pd.DataFrame(features_list)
190
  df_full = df.copy()
191
 
192
- # 移除非数值列
193
- drop_cols = [c for c in df.columns if df[c].dtype == 'object']
194
  for c in drop_cols:
195
- if c in ["top_product_id", "action_sequence"]:
196
- df.pop(c)
197
-
198
- # 处理object类型
199
  for c in df.columns:
200
  if df[c].dtype == 'object':
201
  df[c] = pd.to_numeric(df[c], errors='coerce').fillna(0)
@@ -212,23 +296,13 @@ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=Fa
212
  X_train_s = scaler.fit_transform(X_train)
213
  X_test_s = scaler.transform(X_test)
214
 
215
- # 训练 GBDT
216
- gbdt = GradientBoostingClassifier(
217
- n_estimators=200, max_depth=5, learning_rate=0.1,
218
- subsample=0.8, random_state=random_state
219
- )
220
  gbdt.fit(X_train_s, y_train)
221
- y_pred_gbdt = gbdt.predict(X_test_s)
222
- y_prob_gbdt = gbdt.predict_proba(X_test_s)[:, 1]
223
 
224
- # 训练 RF
225
- rf = RandomForestClassifier(
226
- n_estimators=100, max_depth=10,
227
- class_weight='balanced', random_state=random_state, n_jobs=-1
228
- )
229
  rf.fit(X_train_s, y_train)
230
- y_prob_rf = rf.predict_proba(X_test_s)[:, 1]
231
- y_pred_rf = rf.predict(X_test_s)
232
 
233
  auc_gbdt = float(roc_auc_score(y_test, y_prob_gbdt))
234
  f1_gbdt = float(f1_score(y_test, y_pred_gbdt))
@@ -236,18 +310,13 @@ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=Fa
236
  auc_rf = float(roc_auc_score(y_test, y_prob_rf))
237
  ap_rf = float(average_precision_score(y_test, y_prob_rf))
238
 
239
- fi = pd.DataFrame({
240
- 'feature': feature_names,
241
- 'importance': rf.feature_importances_
242
- }).sort_values('importance', ascending=False)
243
 
244
- # 交叉验证
245
  cv_scores = None
246
  if use_cv and len(y) >= 100:
247
  skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
248
  cv_scores = cross_val_score(rf, X, y, cv=skf, scoring='roc_auc')
249
 
250
- # 可视化
251
  os.makedirs("outputs", exist_ok=True)
252
 
253
  fig, ax = plt.subplots(figsize=(12,8))
@@ -265,11 +334,9 @@ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=Fa
265
  pr, rr, _ = precision_recall_curve(y_test, y_prob_rf)
266
  ax.plot(rg, pg, label=f'GBDT AP={ap_gbdt:.3f}', linewidth=2, color='#2E86AB')
267
  ax.plot(rr, pr, label=f'RF AP={ap_rf:.3f}', linewidth=2, color='#A23B72')
268
- ax.set_xlabel('Recall', fontsize=12)
269
- ax.set_ylabel('Precision', fontsize=12)
270
  ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
271
- ax.legend(fontsize=11)
272
- ax.grid(True, alpha=0.3)
273
  plt.tight_layout()
274
  fig_path2 = "outputs/pr_curve.png"
275
  plt.savefig(fig_path2, dpi=150, bbox_inches='tight'); plt.close()
@@ -285,7 +352,6 @@ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=Fa
285
  fig_path3 = "outputs/confusion_matrix.png"
286
  plt.savefig(fig_path3, dpi=150, bbox_inches='tight'); plt.close()
287
 
288
- # ROC曲线
289
  fig, ax = plt.subplots(figsize=(8,6))
290
  fpr_g, tpr_g, _ = roc_curve(y_test, y_prob_gbdt)
291
  fpr_r, tpr_r, _ = roc_curve(y_test, y_prob_rf)
@@ -295,8 +361,7 @@ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=Fa
295
  ax.set_xlabel('False Positive Rate', fontsize=12)
296
  ax.set_ylabel('True Positive Rate', fontsize=12)
297
  ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
298
- ax.legend(fontsize=11)
299
- ax.grid(True, alpha=0.3)
300
  plt.tight_layout()
301
  fig_path4 = "outputs/roc_curve.png"
302
  plt.savefig(fig_path4, dpi=150, bbox_inches='tight'); plt.close()
@@ -311,7 +376,7 @@ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=Fa
311
  result_text = f"""=== 模型训练结果 ===
312
  样本数: {len(y)} | 特征数: {len(feature_names)}
313
  训练集: {len(y_train)} | 测试集: {len(y_test)}
314
- 流失率: {y.mean():.1%} | 流失数: {y.sum()}
315
 
316
  --- GBDT ---
317
  AUC: {auc_gbdt:.4f}
@@ -332,42 +397,514 @@ AP: {ap_rf:.4f}
332
  return result_text, fig_path1, fig_path2, fig_path3, fig_path4, df_full
333
 
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  # =============================================================================
336
  # Gradio 回调函数
337
  # =============================================================================
338
 
339
  def demo_train(n_users, n_events, test_size, random_state, use_cv):
340
- """演示模式: 合成数据训练"""
341
- data = generate_synthetic_data(n_users=n_users, n_events_per_user=n_events)
342
  engineer = InsuranceFeatureEngineer()
343
  features_list, labels = [], []
344
  for profile, label in data:
345
  f = engineer.extract_user_features(profile)
346
  if f: features_list.append(f); labels.append(label)
347
-
348
- return train_model(features_list, labels, test_size, random_state, use_cv)
349
 
350
 
351
  def csv_train(csv_file, label_col, test_size, random_state, use_cv):
352
- """CSV模式: 上传数据训练"""
353
  if csv_file is None:
354
  return "请先上传CSV文件", None, None, None, None, None
355
-
356
  try:
357
- # 读取CSV
358
  if isinstance(csv_file, str):
359
  df = pd.read_csv(csv_file)
360
  else:
361
  df = pd.read_csv(csv_file.name if hasattr(csv_file, 'name') else io.BytesIO(csv_file))
362
 
363
- # 检查标签列
364
  label_col = label_col.strip() if label_col else None
365
  if label_col and label_col not in df.columns:
366
  return f"标签列 '{label_col}' 不存在。可用列: {list(df.columns)}", None, None, None, None, None
367
 
368
- # 解析为用户画像
369
  profiles = parse_csv_to_profiles(df)
370
-
371
  engineer = InsuranceFeatureEngineer()
372
  features_list, labels = [], []
373
 
@@ -375,82 +912,68 @@ def csv_train(csv_file, label_col, test_size, random_state, use_cv):
375
  f = engineer.extract_user_features(profile)
376
  if f:
377
  features_list.append(f)
378
- # 如果有标签列,使用真实标签;否则用启发式规则推断
379
  if label_col and label_col in df.columns:
380
- # 找到该用户的标签
381
  user_df = df[df["user_id"] == profile.user_id]
382
- label_val = user_df[label_col].iloc[0] if len(user_df) > 0 else 0
383
- labels.append(int(label_val))
384
  else:
385
- # 启发式: 无购买+无续保 = 高风险流失
386
- is_high_risk = (f["has_purchased"] == 0 and f["has_renewed"] == 0
387
- and f["total_events"] < 20)
388
  labels.append(int(is_high_risk))
389
 
390
  if len(features_list) < 50:
391
- return f"有效样本数 {len(features_list)} 太少,需要至少50个用户", None, None, None, None, None
392
-
393
- result = train_model(features_list, labels, test_size, random_state, use_cv)
394
- return result
395
 
 
396
  except Exception as e:
397
  import traceback
398
  return f"错误: {str(e)}\n\n{traceback.format_exc()}", None, None, None, None, None
399
 
400
 
401
  def show_csv_info(csv_file):
402
- """显示CSV信息"""
403
  if csv_file is None:
404
  return "请先上传CSV文件", None
405
-
406
  try:
407
  if isinstance(csv_file, str):
408
  df = pd.read_csv(csv_file)
409
  else:
410
  df = pd.read_csv(csv_file.name if hasattr(csv_file, 'name') else io.BytesIO(csv_file))
411
-
412
  info = f"""=== CSV文件信息 ===
413
- 行数: {len(df)}
414
- 列数: {len(df.columns)}
415
  列名: {list(df.columns)}
416
 
417
- === 前5行预览 ===
418
  {df.head().to_string()}
419
 
420
  === 事件类型分布 (前10) ===
421
  {df['event_type'].value_counts().head(10).to_string() if 'event_type' in df.columns else '无event_type列'}
422
 
423
- === 用户数 ===
424
- {df['user_id'].nunique() if 'user_id' in df.columns else '无user_id列'}
425
-
426
- === 会话数量 ===
427
- {df['session_id'].nunique() if 'session_id' in df.columns else '无session_id列'}"""
428
-
429
  return info, df.head(20)
430
  except Exception as e:
431
  return f"解析错误: {str(e)}", None
432
 
433
 
434
  # =============================================================================
435
- # Gradio 界面
436
  # =============================================================================
437
 
438
  with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", theme=gr.themes.Soft()) as demo:
439
- gr.Markdown("""
440
- # 🏥 保险APP 用户行为分析模型训练平台
441
-
442
- 基于最新研究论文构建的工业级保险用户行为分析平台。
443
-
444
- **两种模式:**
445
- - 🎲 **演示模式**: 生成合成保险APP数据,体验完整训练流程
446
- - 📁 **CSV上传**: 上传真实用户行为数据,自动特征工程 + 模型训练
447
-
448
- **参考论:** Deep Interest Network (KDD 2018) | Transformer Churn Prediction (arXiv 2309.14390) | TabBERT (arXiv 2011.01843)
449
- """)
 
450
 
451
  with gr.Tabs():
452
  # ===== Tab 1: 演示模式 =====
453
- with gr.Tab("🎲 演示模式 (合成数据)"):
454
  with gr.Row():
455
  with gr.Column(scale=1):
456
  gr.Markdown("### 参数设置")
@@ -460,10 +983,8 @@ with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", them
460
  random_seed = gr.Number(value=42, label="随机种子", precision=0)
461
  use_cv_check = gr.Checkbox(value=False, label="启用5折交叉验证")
462
  train_btn = gr.Button("🚀 开始训练", variant="primary", size="lg")
463
-
464
  with gr.Column(scale=2):
465
  demo_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
466
-
467
  with gr.Row():
468
  demo_img1 = gr.Image(label="特征重要性")
469
  demo_img2 = gr.Image(label="PR曲线")
@@ -471,56 +992,38 @@ with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", them
471
  demo_img3 = gr.Image(label="混淆矩阵")
472
  demo_img4 = gr.Image(label="ROC曲线")
473
  with gr.Row():
474
- demo_table = gr.Dataframe(label="特征数据样本 (前10行)")
475
 
476
  # ===== Tab 2: CSV上传 =====
477
  with gr.Tab("📁 CSV数据上传"):
478
  with gr.Row():
479
  with gr.Column(scale=1):
480
- gr.Markdown("""
481
- ### 📤 上传数据
482
-
483
- **必需列:**
484
- - `user_id`: 用户唯一
485
- - `session_id`: 会话标识
486
- - `timestamp`: Unix 时间戳 (毫秒或秒)
487
- - `event_type`: 事件类型
488
- - `page_id`: 页面标识
489
-
490
- **可选列:**
491
- - `product_id`: 保险产品ID
492
- - `amount`: 金额/保额
493
- - `label` (或其他): 流失标签 (0/1)
494
-
495
- **示例CSV格式:**
496
- ```
497
- user_id,session_id,timestamp,event_type,page_id,product_id,amount
498
- user_001,sess_001,1704067200000,page_view,home,,
499
- user_001,sess_001,1704067230000,product_view,product,health_basic,
500
- user_001,sess_001,1704067260000,quote_request,quote,health_basic,50000
501
- ```
502
- """)
503
-
504
  csv_file = gr.File(label="上传CSV文件", file_types=[".csv"])
505
- label_col_input = gr.Textbox(label="标签列名 (可选, 默认自动推断)", placeholder="如: churn, is_churned, label")
506
-
507
  with gr.Row():
508
  csv_test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="测试集比例")
509
  csv_random_seed = gr.Number(value=42, label="随机种子", precision=0)
510
-
511
  csv_use_cv = gr.Checkbox(value=False, label="启用5折交叉验证")
512
-
513
  with gr.Row():
514
  info_btn = gr.Button("📊 查看数据信息", variant="secondary")
515
  csv_train_btn = gr.Button("🚀 训练模型", variant="primary", size="lg")
516
-
517
  with gr.Column(scale=2):
518
  csv_info = gr.Textbox(label="CSV信息", lines=15, show_copy_button=True)
519
  csv_preview = gr.Dataframe(label="数据预览")
520
-
521
  with gr.Row():
522
  csv_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
523
-
524
  with gr.Row():
525
  csv_img1 = gr.Image(label="特征重要性")
526
  csv_img2 = gr.Image(label="PR曲线")
@@ -528,67 +1031,152 @@ with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", them
528
  csv_img3 = gr.Image(label="混淆矩阵")
529
  csv_img4 = gr.Image(label="ROC曲线")
530
  with gr.Row():
531
- csv_table = gr.Dataframe(label="特征数据样本 (前10行)")
532
 
533
- # ===== Tab 3: 帮助文档 =====
534
- with gr.Tab(" 帮助文档"):
535
- gr.Markdown("""
536
- ## 事件类型定义
537
-
538
- | 类别 | 事件 | 业务含义 |
539
- |------|------|---------|
540
- | **浏览** | page_view, product_view, premium_calculator, article_read, faq_view, product_compare | 用户浏览保险产品页面 |
541
- | **交互** | quote_request, form_submit, document_upload, chat_init, call_init, video_consult, quote_result_view | 用户深度参与行为 |
542
- | **转化** | policy_select, payment_init, payment_success, policy_issued | 核心KPI转化行为 |
543
- | **理赔** | claim_init, claim_doc_upload, claim_review, claim_approved, claim_rejected | 理赔全流程 |
544
- | **续保** | renewal_reminder, renewal_click, renewal_complete, policy_cancel | 续保/流失信号 |
545
- | **其他** | login, logout, app_uninstall | 登录/登出/卸载 |
546
-
547
- ## 特征工程说明
548
-
549
- 平台自动提取 **30+维行为特征**:
550
-
551
- | 维度 | 特征示例 | 业务含义 |
552
- |------|---------|---------|
553
- | 基础活跃度 | total_sessions, total_events, days_active | 用户使用APP的活跃程度 |
554
- | 浏览深度 | product_view_ratio, article_read_ratio | 内容消费深度 |
555
- | 转化信号 | payment_success_ratio, policy_issued_ratio | 购买/续保意愿 |
556
- | 生命周期 | has_purchased, has_renewed, has_claimed | 客户价值阶段 |
557
- | 时序行为 | recent_7day_events, days_since_last_event | 近期活跃/沉默 |
558
- | 行为模式 | peak_active_hour, weekend_activity_ratio | 使用习惯 |
559
-
560
- ## 模型说明
561
-
562
- | 模型 | 特点 | 适用场景 |
563
- |------|------|---------|
564
- | **GBDT** | 高精度, 可解释 | 流失预测, 欺诈检测 |
565
- | **Random Forest** | 抗过拟合, 特征重要性 | 特征筛选, 基线模型 |
566
 
567
- ## 评估指标
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
 
569
- - **AUC-ROC**: 分类器整体区分能力
570
- - **F1-Score**: 精确率和召回率的调和平均
571
- - **AP (Average Precision)**: PR曲线下面积, 适合不平衡数据
572
- - **交叉验证**: 5折StratifiedKFold, 评估模型稳定性
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
- ## 注意事项
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
 
576
- 1. 保险场景数据高度不平衡 (流失率 < 5%), 请使用 F1/AP 而非 Accuracy
577
- 2. 建议至少 1000+ 用户样本才能获得稳定结果
578
- 3. timestamp 支持毫秒或秒, 平台自动识别
579
- 4. 无标签列时, 平台使用启发式规则自动推断 (无购买+低活跃 = 高风险)
580
- """)
581
-
582
- gr.Markdown("""
583
- ---
584
- <div align="center">
585
- <b>保险APP 用户行为分析模型训练平台</b> |
586
- 基于 <a href="https://arxiv.org/abs/1706.06978">DIN</a> |
587
- <a href="https://arxiv.org/abs/2309.14390">Churn Transformer</a> |
588
- <a href="https://arxiv.org/abs/2011.01843">TabBERT</a> |
589
- 作者: <a href="https://huggingface.co/Stephanwu">Stephanwu</a>
590
- </div>
591
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
 
593
  # ===== 事件绑定 =====
594
  train_btn.click(
@@ -596,18 +1184,26 @@ with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", them
596
  inputs=[n_users_slider, n_events_slider, test_size_slider, random_seed, use_cv_check],
597
  outputs=[demo_result, demo_img1, demo_img2, demo_img3, demo_img4, demo_table]
598
  )
599
-
600
  info_btn.click(
601
  fn=show_csv_info,
602
  inputs=[csv_file],
603
  outputs=[csv_info, csv_preview]
604
  )
605
-
606
  csv_train_btn.click(
607
  fn=csv_train,
608
  inputs=[csv_file, label_col_input, csv_test_size, csv_random_seed, csv_use_cv],
609
  outputs=[csv_result, csv_img1, csv_img2, csv_img3, csv_img4, csv_table]
610
  )
 
 
 
 
 
 
 
 
 
 
611
 
612
  if __name__ == "__main__":
613
  demo.launch()
 
 
 
1
  """
2
+ 保险APP 用户行为分析 - Gradio Space (完整版)
3
+ 支持: 演示模式 | CSV上传 | 产品推荐(DIN) | 异常检测(TabBERT)
4
+
5
+ 参考文献:
6
+ - DIN: Deep Interest Network (KDD 2018, arxiv:1706.06978)
7
+ - TabBERT: Tabular Transformers (arxiv:2011.01843)
8
+ - Focal Loss: RetinaNet (ICCV 2017, arxiv:1708.02002)
9
+ """
10
+ import os, io, math, warnings, datetime, random, json
11
  from collections import Counter, defaultdict
12
  from dataclasses import dataclass, field
13
  from typing import List, Dict, Optional
 
15
  warnings.filterwarnings('ignore')
16
  import numpy as np
17
  import pandas as pd
18
+ from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
19
  from sklearn.preprocessing import StandardScaler, LabelEncoder
20
  from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
21
  from sklearn.metrics import (
22
  roc_auc_score, f1_score, confusion_matrix,
23
  average_precision_score, precision_recall_curve, classification_report,
24
+ roc_curve, accuracy_score
25
  )
26
  import matplotlib
27
  matplotlib.use('Agg')
 
30
 
31
  import gradio as gr
32
 
33
+ # PyTorch (可选, 用于深度学习模型)
34
+ try:
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ from torch.utils.data import Dataset, DataLoader
39
+ TORCH_AVAILABLE = True
40
+ except ImportError:
41
+ TORCH_AVAILABLE = False
42
+ print("PyTorch not available. Deep learning models disabled.")
43
+
44
+
45
  # =============================================================================
46
+ # 数据模型 & 特征工程 (保持原有)
47
  # =============================================================================
48
 
49
  INSURANCE_EVENT_TYPES = {
 
56
  "policy_cancel", "app_uninstall", "login", "logout",
57
  }
58
 
59
+ BROWSE = {"page_view","product_view","premium_calculator","article_read","faq_view","product_compare"}
60
+ INTERACT = {"quote_request","form_submit","document_upload","chat_init","call_init","video_consult","quote_result_view"}
61
+ CONVERT = {"policy_select","payment_init","payment_success","policy_issued"}
62
+ CLAIM = {"claim_init","claim_doc_upload","claim_review","claim_approved","claim_rejected"}
63
+ RENEW = {"renewal_reminder","renewal_click","renewal_complete","policy_cancel"}
64
 
65
  @dataclass
66
  class InsuranceAppEvent:
 
95
  days_active = (last_ts - first_ts) / (24 * 3600 * 1000)
96
  has_purchased = any(e.event_type == "policy_issued" for e in all_events)
97
  has_renewed = any(e.event_type == "renewal_complete" for e in all_events)
98
+ has_claimed = any(e.event_type in ("claim_init","claim_approved") for e in all_events)
99
  support = all_type_counts.get("chat_init", 0) + all_type_counts.get("call_init", 0)
100
+
101
+ # 计算行为序列 (用于DIN)
102
+ event_seq = [e.event_type for e in all_events]
103
+ product_seq = [e.product_id or "none" for e in all_events]
104
+
105
  return {
106
  "total_sessions": len(sessions), "total_events": total,
107
  "days_active": days_active, "avg_events_per_session": total / len(sessions),
 
111
  "payment_success_ratio": all_type_counts.get("payment_success", 0) / total,
112
  "policy_issued_ratio": all_type_counts.get("policy_issued", 0) / total,
113
  "unique_products_viewed": len(product_counter),
114
+ "top_product_id": top_product or "none",
115
  "has_purchased": int(has_purchased), "has_renewed": int(has_renewed),
116
  "has_claimed": int(has_claimed), "support_dependency": support / total,
117
  "renewal_click_count": all_type_counts.get("renewal_click", 0),
 
122
  "peak_active_hour": Counter(datetime.datetime.fromtimestamp(e.timestamp/1000).hour for e in all_events).most_common(1)[0][0],
123
  "recent_7day_events": sum(1 for e in all_events if (last_ts-e.timestamp)<7*24*3600*1000),
124
  "recent_30day_events": sum(1 for e in all_events if (last_ts-e.timestamp)<30*24*3600*1000),
125
+ # 序列特征 (用于深度学习模型)
126
+ "_event_sequence": event_seq,
127
+ "_product_sequence": product_seq,
128
+ "_user_id": profile.user_id,
129
  }
130
 
131
 
132
  # =============================================================================
133
+ # 数据解析 & 生成
134
  # =============================================================================
135
 
136
+ def parse_csv_to_profiles(df):
 
137
  required_cols = {"user_id", "session_id", "timestamp", "event_type", "page_id"}
138
+ missing = required_cols - set(c.lower().strip() for c in df.columns)
139
  if missing:
140
+ raise ValueError(f"CSV缺少必需列: {missing}")
 
 
 
141
  df.columns = [c.lower().strip() for c in df.columns]
 
 
142
  df["timestamp"] = pd.to_numeric(df["timestamp"], errors="coerce")
143
  df = df.dropna(subset=["timestamp", "event_type"])
144
  df["timestamp"] = df["timestamp"].astype(int)
145
 
 
146
  profiles = {}
147
+ for (uid, sid), group in df.groupby(["user_id", "session_id"]):
148
+ if uid not in profiles:
149
+ profiles[uid] = UserBehaviorProfile(user_id=str(uid), sessions=[])
 
150
  events = []
151
  for _, row in group.sort_values("timestamp").iterrows():
152
  events.append(InsuranceAppEvent(
153
+ event_id=f"evt_{row.name}", user_id=str(row["user_id"]),
154
+ session_id=str(row["session_id"]), timestamp=int(row["timestamp"]),
 
 
155
  event_type=str(row["event_type"]).strip(),
156
  page_id=str(row.get("page_id", "unknown")),
157
  product_id=str(row.get("product_id")) if pd.notna(row.get("product_id")) else None,
158
  amount=float(row["amount"]) if pd.notna(row.get("amount")) else None,
159
  ))
160
+ profiles[uid].sessions.append(UserSession(session_id=str(sid), user_id=str(uid), events=events))
 
 
 
 
 
 
161
  return list(profiles.values())
162
 
163
 
164
+ def generate_synthetic_data(n_users=2000, n_events_per_user=50, seed=42):
165
+ random.seed(seed); np.random.seed(seed)
166
  event_types = list(INSURANCE_EVENT_TYPES)
167
  products = ["health_basic","health_premium","critical_illness","term_life",
168
  "auto_compulsory","auto_commercial","home","travel_domestic"]
 
193
  return data
194
 
195
 
196
+ def generate_product_recommendation_data(n_users=1000, seed=42):
197
+ """生成产品推荐训练数据"""
198
+ random.seed(seed); np.random.seed(seed)
199
+ products = ["health_basic","health_premium","critical_illness","term_life",
200
+ "auto_compulsory","auto_commercial","home","travel_domestic"]
201
+ event_types = list(INSURANCE_EVENT_TYPES)
202
+
203
+ records = []
204
+ for u in range(n_users):
205
+ user_id = u
206
+ n_behaviors = random.randint(5, 30)
207
+ behavior_events = []
208
+ behavior_products = []
209
+
210
+ # 生成用户历史行为
211
+ for i in range(n_behaviors):
212
+ et = random.choice(["page_view","product_view","quote_request","article_read"])
213
+ behavior_events.append(et)
214
+ behavior_products.append(random.choice(products))
215
+
216
+ # 生成候选产品和标签
217
+ candidate = random.choice(products)
218
+ # 如果候选产品出现过在历史中, 更可能购买
219
+ label = 1 if candidate in behavior_products else random.choices([0,1], weights=[0.7,0.3])[0]
220
+
221
+ records.append({
222
+ 'user_id': user_id,
223
+ 'behavior_events': behavior_events,
224
+ 'behavior_products': behavior_products,
225
+ 'candidate_product': candidate,
226
+ 'label': label,
227
+ 'user_features': np.random.randn(20).astype(np.float32), # 模拟用户统计特征
228
+ })
229
+
230
+ return pd.DataFrame(records)
231
+
232
+
233
+ def generate_anomaly_data(n_normal=800, n_anomaly=200, seed=42):
234
+ """生成异常检测数据 (理赔记录)"""
235
+ random.seed(seed); np.random.seed(seed)
236
+
237
+ normal_records = []
238
+ for i in range(n_normal):
239
+ record = {
240
+ 'user_id': i,
241
+ 'claim_amount': random.uniform(1000, 50000),
242
+ 'claim_type': random.choice(["health","auto","property"]),
243
+ 'days_since_policy': random.randint(30, 365),
244
+ 'num_previous_claims': random.randint(0, 3),
245
+ 'document_count': random.randint(3, 10),
246
+ 'processing_time_days': random.uniform(1, 15),
247
+ 'label': 0, # 正常
248
+ }
249
+ normal_records.append(record)
250
+
251
+ anomaly_records = []
252
+ for i in range(n_anomaly):
253
+ # 异常特征: 高金额、刚投保、多理赔、少材料、快处理
254
+ record = {
255
+ 'user_id': n_normal + i,
256
+ 'claim_amount': random.uniform(50000, 200000),
257
+ 'claim_type': random.choice(["health","auto","property"]),
258
+ 'days_since_policy': random.randint(1, 15), # 刚投保就理赔
259
+ 'num_previous_claims': random.randint(5, 20), # 多次理赔
260
+ 'document_count': random.randint(0, 2), # 材料极少
261
+ 'processing_time_days': random.uniform(0.1, 2), # 异常快
262
+ 'label': 1, # 异常
263
+ }
264
+ anomaly_records.append(record)
265
+
266
+ df = pd.DataFrame(normal_records + anomaly_records)
267
+ df = df.sample(frac=1, random_state=seed).reset_index(drop=True) # 打乱
268
+ return df
269
+
270
+
271
  # =============================================================================
272
+ # 通用训练函数 (sklearn)
273
  # =============================================================================
274
 
275
+ def train_sklearn(features_list, labels, test_size=0.2, random_state=42, use_cv=False):
 
276
  df = pd.DataFrame(features_list)
277
  df_full = df.copy()
278
 
279
+ # 移除非数值列 (内部字段)
280
+ drop_cols = [c for c in df.columns if c.startswith('_')]
281
  for c in drop_cols:
282
+ df.pop(c)
 
 
 
283
  for c in df.columns:
284
  if df[c].dtype == 'object':
285
  df[c] = pd.to_numeric(df[c], errors='coerce').fillna(0)
 
296
  X_train_s = scaler.fit_transform(X_train)
297
  X_test_s = scaler.transform(X_test)
298
 
299
+ gbdt = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, subsample=0.8, random_state=random_state)
 
 
 
 
300
  gbdt.fit(X_train_s, y_train)
301
+ y_pred_gbdt = gbdt.predict(X_test_s); y_prob_gbdt = gbdt.predict_proba(X_test_s)[:,1]
 
302
 
303
+ rf = RandomForestClassifier(n_estimators=100, max_depth=10, class_weight='balanced', random_state=random_state, n_jobs=-1)
 
 
 
 
304
  rf.fit(X_train_s, y_train)
305
+ y_prob_rf = rf.predict_proba(X_test_s)[:,1]; y_pred_rf = rf.predict(X_test_s)
 
306
 
307
  auc_gbdt = float(roc_auc_score(y_test, y_prob_gbdt))
308
  f1_gbdt = float(f1_score(y_test, y_pred_gbdt))
 
310
  auc_rf = float(roc_auc_score(y_test, y_prob_rf))
311
  ap_rf = float(average_precision_score(y_test, y_prob_rf))
312
 
313
+ fi = pd.DataFrame({'feature': feature_names, 'importance': rf.feature_importances_}).sort_values('importance', ascending=False)
 
 
 
314
 
 
315
  cv_scores = None
316
  if use_cv and len(y) >= 100:
317
  skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
318
  cv_scores = cross_val_score(rf, X, y, cv=skf, scoring='roc_auc')
319
 
 
320
  os.makedirs("outputs", exist_ok=True)
321
 
322
  fig, ax = plt.subplots(figsize=(12,8))
 
334
  pr, rr, _ = precision_recall_curve(y_test, y_prob_rf)
335
  ax.plot(rg, pg, label=f'GBDT AP={ap_gbdt:.3f}', linewidth=2, color='#2E86AB')
336
  ax.plot(rr, pr, label=f'RF AP={ap_rf:.3f}', linewidth=2, color='#A23B72')
337
+ ax.set_xlabel('Recall', fontsize=12); ax.set_ylabel('Precision', fontsize=12)
 
338
  ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
339
+ ax.legend(fontsize=11); ax.grid(True, alpha=0.3)
 
340
  plt.tight_layout()
341
  fig_path2 = "outputs/pr_curve.png"
342
  plt.savefig(fig_path2, dpi=150, bbox_inches='tight'); plt.close()
 
352
  fig_path3 = "outputs/confusion_matrix.png"
353
  plt.savefig(fig_path3, dpi=150, bbox_inches='tight'); plt.close()
354
 
 
355
  fig, ax = plt.subplots(figsize=(8,6))
356
  fpr_g, tpr_g, _ = roc_curve(y_test, y_prob_gbdt)
357
  fpr_r, tpr_r, _ = roc_curve(y_test, y_prob_rf)
 
361
  ax.set_xlabel('False Positive Rate', fontsize=12)
362
  ax.set_ylabel('True Positive Rate', fontsize=12)
363
  ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
364
+ ax.legend(fontsize=11); ax.grid(True, alpha=0.3)
 
365
  plt.tight_layout()
366
  fig_path4 = "outputs/roc_curve.png"
367
  plt.savefig(fig_path4, dpi=150, bbox_inches='tight'); plt.close()
 
376
  result_text = f"""=== 模型训练结果 ===
377
  样本数: {len(y)} | 特征数: {len(feature_names)}
378
  训练集: {len(y_train)} | 测试集: {len(y_test)}
379
+ 流失率: {y.mean():.1%} | 流失数: {int(y.sum())}
380
 
381
  --- GBDT ---
382
  AUC: {auc_gbdt:.4f}
 
397
  return result_text, fig_path1, fig_path2, fig_path3, fig_path4, df_full
398
 
399
 
400
+ # =============================================================================
401
+ # 产品推荐 (DIN 简化版)
402
+ # =============================================================================
403
+
404
+ def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
405
+ """训练 DIN 风格的产品推荐模型 (简化版, 使用 PyTorch 模拟)"""
406
+ if not TORCH_AVAILABLE:
407
+ return "❌ PyTorch 未安装。请在 requirements.txt 中添加 torch 并重启 Space。", None, None, None, None, None
408
+
409
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
410
+
411
+ # 生成数据
412
+ df = generate_product_recommendation_data(n_users=n_users, seed=seed)
413
+
414
+ # 构建 vocab
415
+ all_events = sorted(set(e for seq in df['behavior_events'] for e in seq))
416
+ event_vocab = {e: i+1 for i, e in enumerate(all_events)}
417
+ all_products = sorted(set(p for seq in df['behavior_products'] for p in seq) | set(df['candidate_product']))
418
+ product_vocab = {p: i+1 for i, p in enumerate(all_products)}
419
+
420
+ # 准备序列数据
421
+ max_seq_len = 20
422
+ behavior_events_padded = []
423
+ behavior_products_padded = []
424
+ behavior_masks = []
425
+
426
+ for _, row in df.iterrows():
427
+ e_seq = [event_vocab[e] for e in row['behavior_events'][-max_seq_len:]]
428
+ p_seq = [product_vocab[p] for p in row['behavior_products'][-max_seq_len:]]
429
+ mask = [1] * len(e_seq)
430
+ if len(e_seq) < max_seq_len:
431
+ pad = max_seq_len - len(e_seq)
432
+ e_seq = [0]*pad + e_seq
433
+ p_seq = [0]*pad + p_seq
434
+ mask = [0]*pad + mask
435
+ behavior_events_padded.append(e_seq)
436
+ behavior_products_padded.append(p_seq)
437
+ behavior_masks.append(mask)
438
+
439
+ df['be'] = behavior_events_padded
440
+ df['bp'] = behavior_products_padded
441
+ df['bm'] = behavior_masks
442
+ df['cp'] = df['candidate_product'].map(product_vocab)
443
+
444
+ # 划分
445
+ train_df = df.sample(frac=0.8, random_state=seed)
446
+ test_df = df.drop(train_df.index)
447
+
448
+ # 简单的 PyTorch 训练 (使用 Attention 的 MLP)
449
+ device = torch.device('cpu')
450
+
451
+ class SimpleDIN(nn.Module):
452
+ def __init__(self, num_events, num_products, d_model=64, max_len=20):
453
+ super().__init__()
454
+ self.event_emb = nn.Embedding(num_events+1, d_model//2, padding_idx=0)
455
+ self.prod_emb = nn.Embedding(num_products+1, d_model//2, padding_idx=0)
456
+ self.cand_emb = nn.Embedding(num_products+1, d_model)
457
+ self.attn = nn.Sequential(
458
+ nn.Linear(d_model*4, 128), nn.ReLU(), nn.Linear(128, 1)
459
+ )
460
+ self.mlp = nn.Sequential(
461
+ nn.Linear(d_model*3, 256), nn.ReLU(), nn.Dropout(0.3),
462
+ nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
463
+ nn.Linear(128, 1)
464
+ )
465
+
466
+ def forward(self, be, bp, bm, cp):
467
+ B = be.size(0); L = be.size(1)
468
+ e_emb = self.event_emb(be) # (B,L,D/2)
469
+ p_emb = self.prod_emb(bp) # (B,L,D/2)
470
+ beh_emb = torch.cat([e_emb, p_emb], dim=-1) # (B,L,D)
471
+ cand_emb = self.cand_emb(cp) # (B,D)
472
+
473
+ # Attention
474
+ cand_exp = cand_emb.unsqueeze(1).expand(B, L, -1)
475
+ diff = cand_exp - beh_emb
476
+ prod = cand_exp * beh_emb
477
+ attn_in = torch.cat([cand_exp, beh_emb, diff, prod], dim=-1)
478
+ attn_w = self.attn(attn_in).squeeze(-1) # (B,L)
479
+ attn_w = attn_w.masked_fill(~bm.bool(), -1e9)
480
+ attn_w = torch.softmax(attn_w, dim=1)
481
+ interest = (beh_emb * attn_w.unsqueeze(-1)).sum(dim=1) # (B,D)
482
+
483
+ # MLP
484
+ x = torch.cat([interest, cand_emb, interest*cand_emb], dim=-1)
485
+ return self.mlp(x).squeeze(-1)
486
+
487
+ model = SimpleDIN(len(all_events), len(all_products), d_model=embedding_dim).to(device)
488
+ criterion = nn.BCEWithLogitsLoss()
489
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
490
+
491
+ # 训练
492
+ for epoch in range(epochs):
493
+ model.train()
494
+ epoch_loss = 0
495
+ for i in range(0, len(train_df), batch_size):
496
+ batch = train_df.iloc[i:i+batch_size]
497
+ be = torch.tensor(np.stack(batch['be'].values), dtype=torch.long).to(device)
498
+ bp = torch.tensor(np.stack(batch['bp'].values), dtype=torch.long).to(device)
499
+ bm = torch.tensor(np.stack(batch['bm'].values), dtype=torch.bool).to(device)
500
+ cp = torch.tensor(batch['cp'].values, dtype=torch.long).to(device)
501
+ labels = torch.tensor(batch['label'].values, dtype=torch.float32).to(device)
502
+
503
+ optimizer.zero_grad()
504
+ outputs = model(be, bp, bm, cp)
505
+ loss = criterion(outputs, labels)
506
+ loss.backward()
507
+ optimizer.step()
508
+ epoch_loss += loss.item()
509
+
510
+ if (epoch+1) % max(1, epochs//5) == 0 or epoch == 0:
511
+ print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_df)*batch_size:.4f}")
512
+
513
+ # 评估
514
+ model.eval()
515
+ with torch.no_grad():
516
+ be = torch.tensor(np.stack(test_df['be'].values), dtype=torch.long).to(device)
517
+ bp = torch.tensor(np.stack(test_df['bp'].values), dtype=torch.long).to(device)
518
+ bm = torch.tensor(np.stack(test_df['bm'].values), dtype=torch.bool).to(device)
519
+ cp = torch.tensor(test_df['cp'].values, dtype=torch.long).to(device)
520
+ labels = test_df['label'].values
521
+
522
+ preds = torch.sigmoid(model(be, bp, bm, cp)).cpu().numpy()
523
+
524
+ auc = float(roc_auc_score(labels, preds))
525
+ ap = float(average_precision_score(labels, preds))
526
+ f1 = float(f1_score(labels, preds > 0.5))
527
+ acc = float(accuracy_score(labels, preds > 0.5))
528
+
529
+ # 可视化
530
+ os.makedirs("outputs", exist_ok=True)
531
+
532
+ # 产品推荐效果
533
+ fig, ax = plt.subplots(figsize=(10,6))
534
+ product_perf = {}
535
+ for _, row in test_df.iterrows():
536
+ prod = row['candidate_product']
537
+ if prod not in product_perf:
538
+ product_perf[prod] = {'preds': [], 'labels': []}
539
+ idx = test_df.index.get_loc(_)
540
+ product_perf[prod]['preds'].append(preds[idx])
541
+ product_perf[prod]['labels'].append(row['label'])
542
+
543
+ prod_aucs = []
544
+ for prod, data in product_perf.items():
545
+ if len(set(data['labels'])) > 1 and len(data['labels']) >= 5:
546
+ prod_auc = roc_auc_score(data['labels'], data['preds'])
547
+ prod_aucs.append((prod, prod_auc, np.mean(data['labels'])))
548
+
549
+ if prod_aucs:
550
+ prod_aucs.sort(key=lambda x: x[1], reverse=True)
551
+ prods, aucs, rates = zip(*prod_aucs)
552
+ x = np.arange(len(prods))
553
+ ax.bar(x, aucs, color='steelblue', alpha=0.7, label='AUC')
554
+ ax2 = ax.twinx()
555
+ ax2.plot(x, rates, 'ro-', label='Conversion Rate')
556
+ ax.set_xticks(x); ax.set_xticklabels(prods, rotation=45, ha='right')
557
+ ax.set_ylabel('AUC', color='steelblue')
558
+ ax2.set_ylabel('Conversion Rate', color='red')
559
+ ax.set_title('Product Recommendation Performance by Product', fontweight='bold')
560
+ ax.legend(loc='upper left'); ax2.legend(loc='upper right')
561
+ plt.tight_layout()
562
+ fig_path1 = "outputs/din_product_performance.png"
563
+ plt.savefig(fig_path1, dpi=150); plt.close()
564
+
565
+ # 注意力可视化 (示例)
566
+ fig, ax = plt.subplots(figsize=(10,6))
567
+ sample_idx = 0
568
+ with torch.no_grad():
569
+ be_s = be[sample_idx:sample_idx+1]
570
+ bp_s = bp[sample_idx:sample_idx+1]
571
+ bm_s = bm[sample_idx:sample_idx+1]
572
+ cp_s = cp[sample_idx:sample_idx+1]
573
+
574
+ B, L = be_s.size()
575
+ e_emb = model.event_emb(be_s)
576
+ p_emb = model.prod_emb(bp_s)
577
+ beh_emb = torch.cat([e_emb, p_emb], dim=-1)
578
+ cand_emb = model.cand_emb(cp_s)
579
+ cand_exp = cand_emb.unsqueeze(1).expand(B, L, -1)
580
+ diff = cand_exp - beh_emb
581
+ prod_feat = cand_exp * beh_emb
582
+ attn_in = torch.cat([cand_exp, beh_emb, diff, prod_feat], dim=-1)
583
+ attn_w = torch.softmax(model.attn(attn_in).squeeze(-1).masked_fill(~bm_s, -1e9), dim=1)
584
+ weights = attn_w[0].cpu().numpy()
585
+
586
+ valid_len = bm_s[0].sum().item()
587
+ valid_weights = weights[-valid_len:] if valid_len > 0 else weights
588
+ ax.bar(range(len(valid_weights)), valid_weights, color='coral')
589
+ ax.set_title('Attention Weights (Sample User)', fontweight='bold')
590
+ ax.set_xlabel('Behavior Position')
591
+ ax.set_ylabel('Attention Weight')
592
+ plt.tight_layout()
593
+ fig_path2 = "outputs/din_attention.png"
594
+ plt.savefig(fig_path2, dpi=150); plt.close()
595
+
596
+ # ROC曲线
597
+ fig, ax = plt.subplots(figsize=(8,6))
598
+ fpr, tpr, _ = roc_curve(labels, preds)
599
+ ax.plot(fpr, tpr, label=f'DIN AUC={auc:.3f}', linewidth=2, color='#2E86AB')
600
+ ax.plot([0,1], [0,1], 'k--', alpha=0.5)
601
+ ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
602
+ ax.set_title('ROC Curve - Product Recommendation', fontweight='bold')
603
+ ax.legend(); ax.grid(True, alpha=0.3)
604
+ plt.tight_layout()
605
+ fig_path3 = "outputs/din_roc.png"
606
+ plt.savefig(fig_path3, dpi=150); plt.close()
607
+
608
+ # PR曲线
609
+ fig, ax = plt.subplots(figsize=(8,6))
610
+ prec, rec, _ = precision_recall_curve(labels, preds)
611
+ ax.plot(rec, prec, label=f'DIN AP={ap:.3f}', linewidth=2, color='#A23B72')
612
+ ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
613
+ ax.set_title('Precision-Recall Curve - Product Recommendation', fontweight='bold')
614
+ ax.legend(); ax.grid(True, alpha=0.3)
615
+ plt.tight_layout()
616
+ fig_path4 = "outputs/din_pr.png"
617
+ plt.savefig(fig_path4, dpi=150); plt.close()
618
+
619
+ result_text = f"""=== DIN 保险产品推荐模型 ===
620
+ 样本数: {n_users} | 产品数: {len(all_products)}
621
+ 训练集: {len(train_df)} | 测试集: {len(test_df)}
622
+
623
+ --- 模型架构 ---
624
+ Embedding dim: {embedding_dim}
625
+ Event vocab: {len(all_events)} | Product vocab: {len(all_products)}
626
+ Attention: LocalActivationUnit (4路交互特征)
627
+ MLP: [emb*3] → 256 → 128 → 1
628
+
629
+ --- 训练配置 ---
630
+ Epochs: {epochs} | Batch size: {batch_size} | LR: {lr}
631
+ Optimizer: Adam
632
+
633
+ --- 测试集效果 ---
634
+ AUC: {auc:.4f}
635
+ AP: {ap:.4f}
636
+ F1: {f1:.4f}
637
+ Accuracy: {acc:.4f}
638
+
639
+ --- 模型洞察 ---
640
+ 1. 注意力机制自动学习用户历史行为中对候选产品的相关度
641
+ 2. 高权重通常分配给同类产品的历史浏览/购买行为
642
+ 3. 新用户(历史短)依赖统计特征, 老用户依赖行为序列"""
643
+
644
+ return result_text, fig_path1, fig_path2, fig_path3, fig_path4
645
+
646
+
647
+ # =============================================================================
648
+ # 异常检测 (TabBERT 简化版)
649
+ # =============================================================================
650
+
651
+ def train_tabbert_anomaly(n_normal, n_anomaly, d_model, epochs, batch_size, lr, seed):
652
+ """训练 TabularBERT 风格的异常检测模型"""
653
+ if not TORCH_AVAILABLE:
654
+ return "❌ PyTorch 未安装。请在 requirements.txt 中添加 torch 并重启 Space。", None, None, None, None, None
655
+
656
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
657
+
658
+ # 生成数据
659
+ df = generate_anomaly_data(n_normal=n_normal, n_anomaly=n_anomaly, seed=seed)
660
+
661
+ # 特征编码
662
+ claim_type_map = {"health": 0, "auto": 1, "property": 2}
663
+ df['claim_type_enc'] = df['claim_type'].map(claim_type_map)
664
+
665
+ feature_cols = ['claim_amount', 'claim_type_enc', 'days_since_policy',
666
+ 'num_previous_claims', 'document_count', 'processing_time_days']
667
+
668
+ X = df[feature_cols].values.astype(np.float32)
669
+ y = df['label'].values.astype(np.float32)
670
+
671
+ # 标准化
672
+ scaler = StandardScaler()
673
+ X_s = scaler.fit_transform(X)
674
+
675
+ X_train, X_test, y_train, y_test = train_test_split(
676
+ X_s, y, test_size=0.2, random_state=seed, stratify=y
677
+ )
678
+
679
+ # 简单的 Tabular MLP (模拟 TabBERT)
680
+ device = torch.device('cpu')
681
+
682
+ class SimpleTabBERT(nn.Module):
683
+ def __init__(self, input_dim=6, d_model=128, n_layers=4):
684
+ super().__init__()
685
+ self.input_proj = nn.Linear(input_dim, d_model)
686
+
687
+ # 模拟 Transformer layers
688
+ layers = []
689
+ for _ in range(n_layers):
690
+ layers.extend([
691
+ nn.Linear(d_model, d_model*4),
692
+ nn.ReLU(),
693
+ nn.Dropout(0.2),
694
+ nn.Linear(d_model*4, d_model),
695
+ nn.LayerNorm(d_model),
696
+ nn.ReLU(),
697
+ nn.Dropout(0.2),
698
+ ])
699
+ self.transformer = nn.Sequential(*layers)
700
+
701
+ self.head = nn.Sequential(
702
+ nn.Linear(d_model, 256), nn.ReLU(), nn.Dropout(0.3),
703
+ nn.Linear(256, 64), nn.ReLU(),
704
+ nn.Linear(64, 1)
705
+ )
706
+
707
+ def forward(self, x):
708
+ x = self.input_proj(x)
709
+ x = self.transformer(x)
710
+ return self.head(x).squeeze(-1)
711
+
712
+ model = SimpleTabBERT(input_dim=len(feature_cols), d_model=d_model).to(device)
713
+
714
+ # Focal Loss (不平衡数据)
715
+ class FocalLoss(nn.Module):
716
+ def __init__(self, alpha=0.25, gamma=2.0):
717
+ super().__init__()
718
+ self.alpha = alpha; self.gamma = gamma
719
+
720
+ def forward(self, inputs, targets):
721
+ bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
722
+ pt = torch.exp(-bce)
723
+ return (self.alpha * (1-pt)**self.gamma * bce).mean()
724
+
725
+ criterion = FocalLoss(alpha=0.25, gamma=2.0)
726
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
727
+
728
+ # 转换为 tensor
729
+ X_train_t = torch.tensor(X_train, dtype=torch.float32).to(device)
730
+ y_train_t = torch.tensor(y_train, dtype=torch.float32).to(device)
731
+ X_test_t = torch.tensor(X_test, dtype=torch.float32).to(device)
732
+ y_test_t = torch.tensor(y_test, dtype=torch.float32).to(device)
733
+
734
+ # 训练
735
+ for epoch in range(epochs):
736
+ model.train()
737
+ epoch_loss = 0
738
+ n_batches = math.ceil(len(X_train_t) / batch_size)
739
+
740
+ for i in range(n_batches):
741
+ start = i * batch_size
742
+ end = min(start + batch_size, len(X_train_t))
743
+ xb = X_train_t[start:end]
744
+ yb = y_train_t[start:end]
745
+
746
+ optimizer.zero_grad()
747
+ outputs = model(xb)
748
+ loss = criterion(outputs, yb)
749
+ loss.backward()
750
+ optimizer.step()
751
+ epoch_loss += loss.item()
752
+
753
+ if (epoch+1) % max(1, epochs//5) == 0 or epoch == 0:
754
+ print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/n_batches:.4f}")
755
+
756
+ # 评估
757
+ model.eval()
758
+ with torch.no_grad():
759
+ preds = torch.sigmoid(model(X_test_t)).cpu().numpy()
760
+
761
+ auc = float(roc_auc_score(y_test, preds))
762
+ ap = float(average_precision_score(y_test, preds))
763
+ f1 = float(f1_score(y_test, preds > 0.5))
764
+
765
+ # 可视化
766
+ os.makedirs("outputs", exist_ok=True)
767
+
768
+ # 特征重要性 (通过梯度近似)
769
+ model.eval()
770
+ X_test_grad = torch.tensor(X_test, dtype=torch.float32, requires_grad=True).to(device)
771
+ with torch.no_grad():
772
+ outputs = model(X_test_grad)
773
+
774
+ # 使用 permutation importance 近似
775
+ baseline_auc = auc
776
+ importances = []
777
+ for i in range(len(feature_cols)):
778
+ X_perm = X_test.copy()
779
+ np.random.shuffle(X_perm[:, i])
780
+ X_perm_t = torch.tensor(X_perm, dtype=torch.float32).to(device)
781
+ with torch.no_grad():
782
+ perm_preds = torch.sigmoid(model(X_perm_t)).cpu().numpy()
783
+ perm_auc = roc_auc_score(y_test, perm_preds)
784
+ importances.append(baseline_auc - perm_auc)
785
+
786
+ fig, ax = plt.subplots(figsize=(10,6))
787
+ colors = ['red' if imp > 0 else 'gray' for imp in importances]
788
+ ax.barh(feature_cols, importances, color=colors)
789
+ ax.set_title('TabularBERT - Feature Importance (Permutation)', fontweight='bold')
790
+ ax.set_xlabel('AUC Drop (Importance)')
791
+ plt.tight_layout()
792
+ fig_path1 = "outputs/tabbert_feature_importance.png"
793
+ plt.savefig(fig_path1, dpi=150); plt.close()
794
+
795
+ # 异常分数分布
796
+ fig, ax = plt.subplots(figsize=(10,6))
797
+ normal_scores = preds[y_test == 0]
798
+ anomaly_scores = preds[y_test == 1]
799
+ ax.hist(normal_scores, bins=30, alpha=0.6, label=f'Normal (n={len(normal_scores)})', color='steelblue', edgecolor='white')
800
+ ax.hist(anomaly_scores, bins=30, alpha=0.6, label=f'Anomaly (n={len(anomaly_scores)})', color='red', edgecolor='white')
801
+ ax.axvline(x=0.5, color='black', linestyle='--', label='Threshold=0.5')
802
+ ax.set_xlabel('Anomaly Score'); ax.set_ylabel('Count')
803
+ ax.set_title('Anomaly Score Distribution', fontweight='bold')
804
+ ax.legend(); ax.grid(True, alpha=0.3)
805
+ plt.tight_layout()
806
+ fig_path2 = "outputs/tabbert_distribution.png"
807
+ plt.savefig(fig_path2, dpi=150); plt.close()
808
+
809
+ # ROC曲线
810
+ fig, ax = plt.subplots(figsize=(8,6))
811
+ fpr, tpr, _ = roc_curve(y_test, preds)
812
+ ax.plot(fpr, tpr, label=f'TabBERT AUC={auc:.3f}', linewidth=2, color='#2E86AB')
813
+ ax.plot([0,1], [0,1], 'k--', alpha=0.5)
814
+ ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
815
+ ax.set_title('ROC Curve - Anomaly Detection', fontweight='bold')
816
+ ax.legend(); ax.grid(True, alpha=0.3)
817
+ plt.tight_layout()
818
+ fig_path3 = "outputs/tabbert_roc.png"
819
+ plt.savefig(fig_path3, dpi=150); plt.close()
820
+
821
+ # 混淆矩阵 + 阈值分析
822
+ fig, axs = plt.subplots(1, 2, figsize=(14,6))
823
+
824
+ # 混淆矩阵 @ 0.5
825
+ cm = confusion_matrix(y_test, preds > 0.5)
826
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axs[0], cbar=False)
827
+ axs[0].set_title(f'Confusion Matrix @ threshold=0.5\n(F1={f1:.3f})', fontweight='bold')
828
+ axs[0].set_xlabel('Predicted'); axs[0].set_ylabel('Actual')
829
+
830
+ # 阈值分析
831
+ thresholds = np.linspace(0.1, 0.9, 50)
832
+ f1s = [f1_score(y_test, preds > t) for t in thresholds]
833
+ precs = [precision_score(y_test, preds > t, zero_division=0) for t in thresholds]
834
+ recs = [recall_score(y_test, preds > t, zero_division=0) for t in thresholds]
835
+
836
+ axs[1].plot(thresholds, f1s, label='F1', linewidth=2)
837
+ axs[1].plot(thresholds, precs, label='Precision', linewidth=2)
838
+ axs[1].plot(thresholds, recs, label='Recall', linewidth=2)
839
+ best_t = thresholds[np.argmax(f1s)]
840
+ axs[1].axvline(x=best_t, color='red', linestyle='--', label=f'Best F1 @ {best_t:.2f}')
841
+ axs[1].set_xlabel('Threshold'); axs[1].set_ylabel('Score')
842
+ axs[1].set_title('Threshold Analysis', fontweight='bold')
843
+ axs[1].legend(); axs[1].grid(True, alpha=0.3)
844
+ plt.tight_layout()
845
+ fig_path4 = "outputs/tabbert_threshold.png"
846
+ plt.savefig(fig_path4, dpi=150); plt.close()
847
+
848
+ result_text = f"""=== TabularBERT 异常行为检测模型 ===
849
+ 样本数: {len(df)} (正常: {n_normal}, 异常: {n_anomaly})
850
+ 特征数: {len(feature_cols)}
851
+ 训练集: {len(y_train)} | 测试集: {len(y_test)}
852
+
853
+ --- 模型架构 ---
854
+ Input dim: {len(feature_cols)} → d_model: {d_model}
855
+ Transformer layers: {4} (模拟层次化BERT)
856
+ Head: {d_model} → 256 → 64 → 1
857
+ Loss: Focal Loss (α=0.25, γ=2.0)
858
+
859
+ --- 训练配置 ---
860
+ Epochs: {epochs} | Batch size: {batch_size} | LR: {lr}
861
+ Optimizer: Adam
862
+
863
+ --- 测试集效果 ---
864
+ AUC: {auc:.4f}
865
+ AP: {ap:.4f}
866
+ F1: {f1:.4f} @ threshold=0.5
867
+ Best F1: {max(f1s):.4f} @ threshold={best_t:.2f}
868
+
869
+ --- 模型洞察 ---
870
+ 1. Focal Loss 自动聚焦难分异常样本, 解决类别不平衡
871
+ 2. 关键异常特征: claim_amount(高), days_since_policy(短), document_count(少)
872
+ 3. 建议阈值: {best_t:.2f} (平衡精确率与召回率)
873
+ 4. 高AUC说明模型能很好区分正常与异常理赔"""
874
+
875
+ return result_text, fig_path1, fig_path2, fig_path3, fig_path4
876
+
877
+
878
  # =============================================================================
879
  # Gradio 回调函数
880
  # =============================================================================
881
 
882
  def demo_train(n_users, n_events, test_size, random_state, use_cv):
883
+ """演示模式"""
884
+ data = generate_synthetic_data(n_users=n_users, n_events_per_user=n_events, seed=random_state)
885
  engineer = InsuranceFeatureEngineer()
886
  features_list, labels = [], []
887
  for profile, label in data:
888
  f = engineer.extract_user_features(profile)
889
  if f: features_list.append(f); labels.append(label)
890
+ return train_sklearn(features_list, labels, test_size, random_state, use_cv)
 
891
 
892
 
893
  def csv_train(csv_file, label_col, test_size, random_state, use_cv):
894
+ """CSV模式"""
895
  if csv_file is None:
896
  return "请先上传CSV文件", None, None, None, None, None
 
897
  try:
 
898
  if isinstance(csv_file, str):
899
  df = pd.read_csv(csv_file)
900
  else:
901
  df = pd.read_csv(csv_file.name if hasattr(csv_file, 'name') else io.BytesIO(csv_file))
902
 
 
903
  label_col = label_col.strip() if label_col else None
904
  if label_col and label_col not in df.columns:
905
  return f"标签列 '{label_col}' 不存在。可用列: {list(df.columns)}", None, None, None, None, None
906
 
 
907
  profiles = parse_csv_to_profiles(df)
 
908
  engineer = InsuranceFeatureEngineer()
909
  features_list, labels = [], []
910
 
 
912
  f = engineer.extract_user_features(profile)
913
  if f:
914
  features_list.append(f)
 
915
  if label_col and label_col in df.columns:
 
916
  user_df = df[df["user_id"] == profile.user_id]
917
+ labels.append(int(user_df[label_col].iloc[0]))
 
918
  else:
919
+ is_high_risk = (f["has_purchased"] == 0 and f["has_renewed"] == 0 and f["total_events"] < 20)
 
 
920
  labels.append(int(is_high_risk))
921
 
922
  if len(features_list) < 50:
923
+ return f"有效样本数 {len(features_list)} 太少,需要至少50个", None, None, None, None, None
 
 
 
924
 
925
+ return train_sklearn(features_list, labels, test_size, random_state, use_cv)
926
  except Exception as e:
927
  import traceback
928
  return f"错误: {str(e)}\n\n{traceback.format_exc()}", None, None, None, None, None
929
 
930
 
931
  def show_csv_info(csv_file):
 
932
  if csv_file is None:
933
  return "请先上传CSV文件", None
 
934
  try:
935
  if isinstance(csv_file, str):
936
  df = pd.read_csv(csv_file)
937
  else:
938
  df = pd.read_csv(csv_file.name if hasattr(csv_file, 'name') else io.BytesIO(csv_file))
 
939
  info = f"""=== CSV文件信息 ===
940
+ 行数: {len(df)} | 列数: {len(df.columns)}
 
941
  列名: {list(df.columns)}
942
 
943
+ === 前5行 ===
944
  {df.head().to_string()}
945
 
946
  === 事件类型分布 (前10) ===
947
  {df['event_type'].value_counts().head(10).to_string() if 'event_type' in df.columns else '无event_type列'}
948
 
949
+ === 用户数: {df['user_id'].nunique() if 'user_id' in df.columns else 'N/A'} ===
950
+ === 会话数: {df['session_id'].nunique() if 'session_id' in df.columns else 'N/A'} ==="""
 
 
 
 
951
  return info, df.head(20)
952
  except Exception as e:
953
  return f"解析错误: {str(e)}", None
954
 
955
 
956
  # =============================================================================
957
+ # Gradio 界面 (5 Tabs)
958
  # =============================================================================
959
 
960
  with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", theme=gr.themes.Soft()) as demo:
961
+ gr.Markdown("""# 🏥 保险APP 用户行为分析模型训练平台
962
+
963
+ 基于最新研究论文构建的工业级保险用户行为分析平台。
964
+
965
+ **五大功能模块:**
966
+ - 🎲 **演示模式**: 合成数据体验完整训练流程
967
+ - 📁 **CSV上传**: 上传真实用户行为数据
968
+ - 🎯 **产品推荐 (DIN)**: Deep Interest Network 保险产品推荐
969
+ - 🔍 **异常检测 (TabBERT)**: 层次化Transformer理赔欺诈检测
970
+ - ❓ **帮助**: 完整使用指南
971
+
972
+ **参考论文:** Deep Interest Network (KDD 2018) | Transformer Churn Prediction (arXiv 2309.14390) | TabBERT (arXiv 2011.01843) | Focal Loss (ICCV 2017)""")
973
 
974
  with gr.Tabs():
975
  # ===== Tab 1: 演示模式 =====
976
+ with gr.Tab("🎲 演示模式"):
977
  with gr.Row():
978
  with gr.Column(scale=1):
979
  gr.Markdown("### 参数设置")
 
983
  random_seed = gr.Number(value=42, label="随机种子", precision=0)
984
  use_cv_check = gr.Checkbox(value=False, label="启用5折交叉验证")
985
  train_btn = gr.Button("🚀 开始训练", variant="primary", size="lg")
 
986
  with gr.Column(scale=2):
987
  demo_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
 
988
  with gr.Row():
989
  demo_img1 = gr.Image(label="特征重要性")
990
  demo_img2 = gr.Image(label="PR曲线")
 
992
  demo_img3 = gr.Image(label="混淆矩阵")
993
  demo_img4 = gr.Image(label="ROC曲线")
994
  with gr.Row():
995
+ demo_table = gr.Dataframe(label="特征数据样本")
996
 
997
  # ===== Tab 2: CSV上传 =====
998
  with gr.Tab("📁 CSV数据上传"):
999
  with gr.Row():
1000
  with gr.Column(scale=1):
1001
+ gr.Markdown("""### 📤 上传数据
1002
+
1003
+ **必需列:** `user_id`, `session_id`, `timestamp`, `event_type`, `page_id`
1004
+
1005
+ **可选列:** `product_id`, `amount`, `label`(流失签)
1006
+
1007
+ **示例:**
1008
+ ```
1009
+ user_id,session_id,timestamp,event_type,page_id,product_id,amount
1010
+ user_001,sess_001,1704067200000,page_view,home,,
1011
+ user_001,sess_001,1704067230000,product_view,product,health_basic,
1012
+ ```""")
 
 
 
 
 
 
 
 
 
 
 
 
1013
  csv_file = gr.File(label="上传CSV文件", file_types=[".csv"])
1014
+ label_col_input = gr.Textbox(label="标签列名 (可选)", placeholder="如: churn, is_churned")
 
1015
  with gr.Row():
1016
  csv_test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="测试集比例")
1017
  csv_random_seed = gr.Number(value=42, label="随机种子", precision=0)
 
1018
  csv_use_cv = gr.Checkbox(value=False, label="启用5折交叉验证")
 
1019
  with gr.Row():
1020
  info_btn = gr.Button("📊 查看数据信息", variant="secondary")
1021
  csv_train_btn = gr.Button("🚀 训练模型", variant="primary", size="lg")
 
1022
  with gr.Column(scale=2):
1023
  csv_info = gr.Textbox(label="CSV信息", lines=15, show_copy_button=True)
1024
  csv_preview = gr.Dataframe(label="数据预览")
 
1025
  with gr.Row():
1026
  csv_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
 
1027
  with gr.Row():
1028
  csv_img1 = gr.Image(label="特征重要性")
1029
  csv_img2 = gr.Image(label="PR曲线")
 
1031
  csv_img3 = gr.Image(label="混淆矩阵")
1032
  csv_img4 = gr.Image(label="ROC曲线")
1033
  with gr.Row():
1034
+ csv_table = gr.Dataframe(label="特征数据样本")
1035
 
1036
+ # ===== Tab 3: 产品推荐 (DIN) =====
1037
+ with gr.Tab("🎯 产品推荐 (DIN)"):
1038
+ gr.Markdown("""### Deep Interest Network - 保险产品推荐
1039
+
1040
+ 基于用户历史行为序列, 通过注意力机制动态计算对候选保险产品的兴趣度, 预测购买概率。
1041
+
1042
+ **核心架构:**
1043
+ - 用户历史行为 Embedding LocalActivationUnit 动态兴趣向量
1044
+ - 候选产品Embedding → 拼接交互特征 MLP 购买概率""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1045
 
1046
+ with gr.Row():
1047
+ with gr.Column(scale=1):
1048
+ gr.Markdown("### DIN 参数")
1049
+ din_users = gr.Slider(500, 5000, value=2000, step=100, label="用户数量")
1050
+ din_emb = gr.Slider(32, 256, value=64, step=32, label="Embedding维度")
1051
+ din_epochs = gr.Slider(5, 50, value=20, step=5, label="训练轮数")
1052
+ din_batch = gr.Slider(32, 512, value=128, step=32, label="Batch Size")
1053
+ din_lr = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="学习率")
1054
+ din_seed = gr.Number(value=42, label="随机种子", precision=0)
1055
+ din_btn = gr.Button("🚀 训练DIN模型", variant="primary", size="lg")
1056
+
1057
+ if not TORCH_AVAILABLE:
1058
+ gr.Markdown("⚠️ **PyTorch 未安装**。请在 requirements.txt 中添加 `torch>=2.0.0` 并重启 Space。")
1059
+
1060
+ with gr.Column(scale=2):
1061
+ din_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
1062
 
1063
+ with gr.Row():
1064
+ din_img1 = gr.Image(label="产品推荐效果")
1065
+ din_img2 = gr.Image(label="注意力权重示例")
1066
+ with gr.Row():
1067
+ din_img3 = gr.Image(label="ROC曲线")
1068
+ din_img4 = gr.Image(label="PR曲线")
1069
+
1070
+ # ===== Tab 4: 异常检测 (TabBERT) =====
1071
+ with gr.Tab("🔍 异常检测 (TabBERT)"):
1072
+ gr.Markdown("""### TabularBERT - 理赔欺诈/异常检测
1073
+
1074
+ 层次化Transformer架构, 学习理赔记录的多字段关联和时序模式, 自动识别异常理赔行为。
1075
+
1076
+ **核心架构:**
1077
+ - Field-level Transformer: 单条理赔记录内字段关联
1078
+ - Sequence-level Transformer: 跨理赔记录时序模式
1079
+ - Focal Loss: 解决异常样本极少的不平衡问题""")
1080
 
1081
+ with gr.Row():
1082
+ with gr.Column(scale=1):
1083
+ gr.Markdown("### TabBERT 参数")
1084
+ tab_normal = gr.Slider(500, 2000, value=800, step=100, label="正常样本数")
1085
+ tab_anomaly = gr.Slider(100, 1000, value=200, step=50, label="异常样本数")
1086
+ tab_dmodel = gr.Slider(64, 256, value=128, step=64, label="模型维度 d_model")
1087
+ tab_epochs = gr.Slider(10, 100, value=30, step=10, label="训练轮数")
1088
+ tab_batch = gr.Slider(16, 256, value=64, step=16, label="Batch Size")
1089
+ tab_lr = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="学习率")
1090
+ tab_seed = gr.Number(value=42, label="随机种子", precision=0)
1091
+ tab_btn = gr.Button("🚀 训练TabBERT模型", variant="primary", size="lg")
1092
+
1093
+ if not TORCH_AVAILABLE:
1094
+ gr.Markdown("⚠️ **PyTorch 未安装**。请在 requirements.txt 中添加 `torch>=2.0.0` 并重启 Space。")
1095
+
1096
+ with gr.Column(scale=2):
1097
+ tab_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
1098
 
1099
+ with gr.Row():
1100
+ tab_img1 = gr.Image(label="特征重要性")
1101
+ tab_img2 = gr.Image(label="异常分数分布")
1102
+ with gr.Row():
1103
+ tab_img3 = gr.Image(label="ROC曲线")
1104
+ tab_img4 = gr.Image(label="混淆矩阵与阈值分析")
1105
+
1106
+ # ===== Tab 5: 帮助文档 =====
1107
+ with gr.Tab("❓ 帮助文档"):
1108
+ gr.Markdown("""## 📚 完整使用指南
1109
+
1110
+ ### 1. 演示模式
1111
+ - 调整用户数量和事件数, 系统自动生成合成保险APP行为数据
1112
+ - 高流失风险用户模拟: 低频浏览、无转化、短会话
1113
+ - 低流失风险用户模拟: 完整行为漏斗、有保单、有续保
1114
+
1115
+ ### 2. CSV数据上传
1116
+ **必需列:**
1117
+ | 列名 | 类型 | 说明 |
1118
+ |------|------|------|
1119
+ | user_id | string/int | 用户唯一标识 |
1120
+ | session_id | string/int | 会话标识 |
1121
+ | timestamp | int | Unix时间戳(毫秒或秒) |
1122
+ | event_type | string | 见下方事件类型表 |
1123
+ | page_id | string | 页面标识 |
1124
+
1125
+ **可选列:**
1126
+ | 列名 | 类型 | 说明 |
1127
+ |------|------|------|
1128
+ | product_id | string | 保险产品ID |
1129
+ | amount | float | 金额/保额 |
1130
+ | label | int(0/1) | 流失标签 |
1131
+
1132
+ ### 3. 事件类型定义
1133
+
1134
+ | 类别 | 事件 | 业务含义 |
1135
+ |------|------|---------|
1136
+ | **浏览** | page_view, product_view, premium_calculator, article_read, faq_view, product_compare | 用户浏览保险产品页面 |
1137
+ | **交互** | quote_request, form_submit, document_upload, chat_init, call_init, video_consult, quote_result_view | 用户深度参与行为 |
1138
+ | **转化** | policy_select, payment_init, payment_success, policy_issued | 核心KPI转化行为 |
1139
+ | **理赔** | claim_init, claim_doc_upload, claim_review, claim_approved, claim_rejected | 理赔全流程 |
1140
+ | **续保** | renewal_reminder, renewal_click, renewal_complete, policy_cancel | 续保/流失信号 |
1141
+
1142
+ ### 4. 模型对比
1143
+
1144
+ | 模型 | 适用场景 | 核心特点 |
1145
+ |------|---------|---------|
1146
+ | **GBDT** | 流失预测基线 | 高精度, 可解释, 训练快 |
1147
+ | **Random Forest** | 特征筛选 | 抗过拟合, 特征重要性直观 |
1148
+ | **DIN** | 产品推荐 | 注意力动态兴趣, 候选产品自适应 |
1149
+ | **TabBERT** | 异常检测 | 层次化Transformer, Focal Loss |
1150
+
1151
+ ### 5. 评估指标
1152
+
1153
+ | 指标 | 说明 | 适用场景 |
1154
+ |------|------|---------|
1155
+ | **AUC-ROC** | 分类器整体区分能力 | 所有二分类任务 |
1156
+ | **F1-Score** | 精确率与召回率调和平均 | 不平衡数据 |
1157
+ | **AP** | PR曲线下面积 | 正样本极少时 |
1158
+ | **交叉验证** | 5折StratifiedKFold | 评估模型稳定性 |
1159
+
1160
+ ### 6. 参考文献
1161
+
1162
+ | 论文 | 应用 | arXiv |
1163
+ |------|------|-------|
1164
+ | Deep Interest Network | 产品推荐 | [1706.06978](https://arxiv.org/abs/1706.06978) |
1165
+ | SDIM | 长期行为建模 | [2205.10249](https://arxiv.org/abs/2205.10249) |
1166
+ | TabBERT/TabFormer | 表格时序异常检测 | [2011.01843](https://arxiv.org/abs/2011.01843) |
1167
+ | Transformer Churn | 非合约流失预测 | [2309.14390](https://arxiv.org/abs/2309.14390) |
1168
+ | Focal Loss | 不平衡分类 | [1708.02002](https://arxiv.org/abs/1708.02002) |
1169
+ """)
1170
+
1171
+ gr.Markdown("""---
1172
+ <div align="center">
1173
+ <b>保险APP 用户行为分析模型训练平台</b> |
1174
+ <a href="https://arxiv.org/abs/1706.06978">DIN</a> |
1175
+ <a href="https://arxiv.org/abs/2309.14390">Churn Transformer</a> |
1176
+ <a href="https://arxiv.org/abs/2011.01843">TabBERT</a> |
1177
+ <a href="https://arxiv.org/abs/1708.02002">Focal Loss</a> |
1178
+ 作者: <a href="https://huggingface.co/Stephanwu">Stephanwu</a>
1179
+ </div>""")
1180
 
1181
  # ===== 事件绑定 =====
1182
  train_btn.click(
 
1184
  inputs=[n_users_slider, n_events_slider, test_size_slider, random_seed, use_cv_check],
1185
  outputs=[demo_result, demo_img1, demo_img2, demo_img3, demo_img4, demo_table]
1186
  )
 
1187
  info_btn.click(
1188
  fn=show_csv_info,
1189
  inputs=[csv_file],
1190
  outputs=[csv_info, csv_preview]
1191
  )
 
1192
  csv_train_btn.click(
1193
  fn=csv_train,
1194
  inputs=[csv_file, label_col_input, csv_test_size, csv_random_seed, csv_use_cv],
1195
  outputs=[csv_result, csv_img1, csv_img2, csv_img3, csv_img4, csv_table]
1196
  )
1197
+ din_btn.click(
1198
+ fn=train_din_recommendation,
1199
+ inputs=[din_users, din_emb, din_epochs, din_batch, din_lr, din_seed],
1200
+ outputs=[din_result, din_img1, din_img2, din_img3, din_img4]
1201
+ )
1202
+ tab_btn.click(
1203
+ fn=train_tabbert_anomaly,
1204
+ inputs=[tab_normal, tab_anomaly, tab_dmodel, tab_epochs, tab_batch, tab_lr, tab_seed],
1205
+ outputs=[tab_result, tab_img1, tab_img2, tab_img3, tab_img4]
1206
+ )
1207
 
1208
  if __name__ == "__main__":
1209
  demo.launch()