Stephanwu committed on
Commit
15cf95f
·
verified ·
1 Parent(s): 153f1c7

Add CSV upload support and comprehensive UI

Browse files
Files changed (1) hide show
  1. app.py +427 -57
app.py CHANGED
@@ -1,18 +1,21 @@
1
- """保险APP 用户行为分析 - Gradio Space"""
2
- import os, json, math, warnings, datetime, random
3
- from collections import Counter
 
 
4
  from dataclasses import dataclass, field
5
  from typing import List, Dict, Optional
6
 
7
  warnings.filterwarnings('ignore')
8
  import numpy as np
9
  import pandas as pd
10
- from sklearn.model_selection import train_test_split
11
- from sklearn.preprocessing import StandardScaler
12
  from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
13
  from sklearn.metrics import (
14
  roc_auc_score, f1_score, confusion_matrix,
15
- average_precision_score, precision_recall_curve, classification_report
 
16
  )
17
  import matplotlib
18
  matplotlib.use('Agg')
@@ -21,6 +24,10 @@ import seaborn as sns
21
 
22
  import gradio as gr
23
 
 
 
 
 
24
  INSURANCE_EVENT_TYPES = {
25
  "page_view", "product_view", "product_compare", "premium_calculator",
26
  "faq_view", "article_read", "quote_request", "quote_result_view",
@@ -31,6 +38,12 @@ INSURANCE_EVENT_TYPES = {
31
  "policy_cancel", "app_uninstall", "login", "logout",
32
  }
33
 
 
 
 
 
 
 
34
  @dataclass
35
  class InsuranceAppEvent:
36
  event_id: str; user_id: str; session_id: str; timestamp: int
@@ -88,6 +101,54 @@ class InsuranceFeatureEngineer:
88
  }
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def generate_synthetic_data(n_users=2000, n_events_per_user=50):
92
  event_types = list(INSURANCE_EVENT_TYPES)
93
  products = ["health_basic","health_premium","critical_illness","term_life",
@@ -119,31 +180,55 @@ def generate_synthetic_data(n_users=2000, n_events_per_user=50):
119
  return data
120
 
121
 
122
- def train_model(n_users, n_events, test_size, random_state):
123
- data = generate_synthetic_data(n_users=n_users, n_events_per_user=n_events)
124
- engineer = InsuranceFeatureEngineer()
125
- features_list, labels = [], []
126
- for profile, label in data:
127
- f = engineer.extract_user_features(profile)
128
- if f: features_list.append(f); labels.append(label)
129
  df = pd.DataFrame(features_list)
130
  df_full = df.copy()
 
 
 
 
 
 
 
 
131
  for c in df.columns:
132
  if df[c].dtype == 'object':
133
  df[c] = pd.to_numeric(df[c], errors='coerce').fillna(0)
134
  df = df.fillna(0).replace([np.inf, -np.inf], 0)
 
135
  X = df.values; y = np.array(labels)
136
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)
 
 
 
 
 
137
  scaler = StandardScaler()
138
- X_train_s = scaler.fit_transform(X_train); X_test_s = scaler.transform(X_test)
 
139
 
140
- gbdt = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1, subsample=0.8, random_state=random_state)
 
 
 
 
141
  gbdt.fit(X_train_s, y_train)
142
- y_pred_gbdt = gbdt.predict(X_test_s); y_prob_gbdt = gbdt.predict_proba(X_test_s)[:,1]
 
143
 
144
- rf = RandomForestClassifier(n_estimators=100, max_depth=10, class_weight='balanced', random_state=random_state, n_jobs=-1)
 
 
 
 
145
  rf.fit(X_train_s, y_train)
146
- y_prob_rf = rf.predict_proba(X_test_s)[:,1]; y_pred_rf = rf.predict(X_test_s)
 
147
 
148
  auc_gbdt = float(roc_auc_score(y_test, y_prob_gbdt))
149
  f1_gbdt = float(f1_score(y_test, y_pred_gbdt))
@@ -151,15 +236,26 @@ def train_model(n_users, n_events, test_size, random_state):
151
  auc_rf = float(roc_auc_score(y_test, y_prob_rf))
152
  ap_rf = float(average_precision_score(y_test, y_prob_rf))
153
 
154
- fi = pd.DataFrame({'feature': list(df.columns), 'importance': rf.feature_importances_}).sort_values('importance', ascending=False)
 
 
 
 
 
 
 
 
 
155
 
 
156
  os.makedirs("outputs", exist_ok=True)
157
 
158
  fig, ax = plt.subplots(figsize=(12,8))
159
  top = fi.head(15)
160
- ax.barh(top['feature'][::-1], top['importance'][::-1], color='steelblue')
161
- ax.set_title('Insurance APP - Top 15 Feature Importance')
162
- ax.set_xlabel('Importance')
 
163
  plt.tight_layout()
164
  fig_path1 = "outputs/feature_importance.png"
165
  plt.savefig(fig_path1, dpi=150, bbox_inches='tight'); plt.close()
@@ -167,29 +263,55 @@ def train_model(n_users, n_events, test_size, random_state):
167
  fig, ax = plt.subplots(figsize=(8,6))
168
  pg, rg, _ = precision_recall_curve(y_test, y_prob_gbdt)
169
  pr, rr, _ = precision_recall_curve(y_test, y_prob_rf)
170
- ax.plot(rg, pg, label=f'GBDT AP={ap_gbdt:.3f}')
171
- ax.plot(rr, pr, label=f'RF AP={ap_rf:.3f}')
172
- ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
173
- ax.set_title('Precision-Recall Curve'); ax.legend()
 
 
 
174
  plt.tight_layout()
175
  fig_path2 = "outputs/pr_curve.png"
176
  plt.savefig(fig_path2, dpi=150, bbox_inches='tight'); plt.close()
177
 
178
  fig, axs = plt.subplots(1,2,figsize=(12,5))
179
- sns.heatmap(confusion_matrix(y_test, y_pred_gbdt), annot=True, fmt='d', cmap='Blues', ax=axs[0])
180
- axs[0].set_title(f'GBDT (AUC={auc_gbdt:.3f})')
181
- sns.heatmap(confusion_matrix(y_test, y_pred_rf), annot=True, fmt='d', cmap='Greens', ax=axs[1])
182
- axs[1].set_title(f'RF (AUC={auc_rf:.3f})')
 
 
183
  plt.tight_layout()
184
  fig_path3 = "outputs/confusion_matrix.png"
185
  plt.savefig(fig_path3, dpi=150, bbox_inches='tight'); plt.close()
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  fi_str = fi.head(15).to_string(index=False)
188
  report = classification_report(y_test, y_pred_gbdt, digits=4)
189
 
 
 
 
 
190
  result_text = f"""=== 模型训练结果 ===
191
- 样本数: {n_users} | 特征数: {len(df.columns)}
192
  训练集: {len(y_train)} | 测试集: {len(y_test)}
 
193
 
194
  --- GBDT ---
195
  AUC: {auc_gbdt:.4f}
@@ -199,6 +321,7 @@ AP: {ap_gbdt:.4f}
199
  --- Random Forest ---
200
  AUC: {auc_rf:.4f}
201
  AP: {ap_rf:.4f}
 
202
 
203
  --- Top 15 特征重要性 ---
204
  {fi_str}
@@ -206,38 +329,285 @@ AP: {ap_rf:.4f}
206
  --- 分类报告 (GBDT) ---
207
  {report}"""
208
 
209
- return result_text, fig_path1, fig_path2, fig_path3, df_full
210
 
211
 
212
- with gr.Blocks(title="保险APP 用户行为分析模型") as demo:
213
- gr.Markdown("""# 🏥 保险APP 用户行为分析模型训练平台
214
- 基于合成数据演示保险APP用户流失预测模型的完整训练流程。
215
- **核心功能:** 生成合成数据 → 自动特征工程 → 训练 GBDT + RF → 可视化结果
216
 
217
- **参考论文:** Deep Interest Network (KDD 2018) | Transformer Churn Prediction (arXiv 2309.14390) | TabBERT (arXiv 2011.01843)""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
- with gr.Row():
220
- with gr.Column(scale=1):
221
- n_users_slider = gr.Slider(500, 5000, value=2000, step=100, label="用户数量")
222
- n_events_slider = gr.Slider(10, 100, value=50, step=5, label="每用户最大事件数")
223
- test_size_slider = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="测试集比例")
224
- random_seed = gr.Number(value=42, label="随机种子", precision=0)
225
- train_btn = gr.Button("🚀 开始训练", variant="primary")
226
- with gr.Column(scale=2):
227
- result_text = gr.Textbox(label="训练结果", lines=20)
 
228
 
229
- with gr.Row():
230
- img1 = gr.Image(label="特征重要性")
231
- img2 = gr.Image(label="PR曲线")
232
- with gr.Row():
233
- img3 = gr.Image(label="混淆矩阵")
234
- data_table = gr.Dataframe(label="数据样本 (前10行)")
235
 
236
- train_btn.click(fn=train_model, inputs=[n_users_slider, n_events_slider, test_size_slider, random_seed],
237
- outputs=[result_text, img1, img2, img3, data_table])
 
 
 
238
 
239
- gr.Markdown("""---
240
- **事件类型:** 浏览(page_view, product_view) | 交互(quote_request, chat_init) | 转化(payment_success, policy_issued) | 理赔(claim_init) | 续保(renewal_click, policy_cancel)""")
 
 
 
241
 
242
  if __name__ == "__main__":
243
  demo.launch()
 
1
+ """保险APP 用户行为分析 - Gradio Space
2
+ 支持: 合成数据训练 + 真实CSV数据上传
3
+ """
4
+ import os, json, math, warnings, datetime, random, io
5
+ from collections import Counter, defaultdict
6
  from dataclasses import dataclass, field
7
  from typing import List, Dict, Optional
8
 
9
  warnings.filterwarnings('ignore')
10
  import numpy as np
11
  import pandas as pd
12
+ from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
13
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
14
  from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
15
  from sklearn.metrics import (
16
  roc_auc_score, f1_score, confusion_matrix,
17
+ average_precision_score, precision_recall_curve, classification_report,
18
+ roc_curve
19
  )
20
  import matplotlib
21
  matplotlib.use('Agg')
 
24
 
25
  import gradio as gr
26
 
27
+ # =============================================================================
28
+ # 数据模型
29
+ # =============================================================================
30
+
31
  INSURANCE_EVENT_TYPES = {
32
  "page_view", "product_view", "product_compare", "premium_calculator",
33
  "faq_view", "article_read", "quote_request", "quote_result_view",
 
38
  "policy_cancel", "app_uninstall", "login", "logout",
39
  }
40
 
41
# Event-type groupings. Each set lists the event_type strings that belong to
# one behavior category.

# Browsing: page, product and content views.
BROWSE_EVENTS = {"page_view", "product_view", "premium_calculator", "article_read", "faq_view", "product_compare"}
# Interaction: quotes, form/document submissions and consultation contacts.
INTERACT_EVENTS = {"quote_request", "form_submit", "document_upload", "chat_init", "call_init", "video_consult", "quote_result_view"}
# Conversion: policy selection through payment and policy issuance.
CONVERT_EVENTS = {"policy_select", "payment_init", "payment_success", "policy_issued"}
# Claims: claim initiation through approval or rejection.
CLAIM_EVENTS = {"claim_init", "claim_doc_upload", "claim_review", "claim_approved", "claim_rejected"}
# Renewal: renewal reminders/actions, plus policy cancellation.
RENEW_EVENTS = {"renewal_reminder", "renewal_click", "renewal_complete", "policy_cancel"}
46
+
47
  @dataclass
48
  class InsuranceAppEvent:
49
  event_id: str; user_id: str; session_id: str; timestamp: int
 
101
  }
102
 
103
 
104
+ # =============================================================================
105
+ # 数据解析
106
+ # =============================================================================
107
+
108
def parse_csv_to_profiles(df: pd.DataFrame) -> "List[UserBehaviorProfile]":
    """Convert an uploaded event-level table into per-user behavior profiles.

    Column names are lower-cased and stripped of surrounding whitespace
    *before* the required-column check, so headers such as ``" User_ID "``
    are accepted. Rows whose ``timestamp`` cannot be parsed as a number, or
    whose ``event_type`` is missing, are dropped. The remaining rows are
    grouped by ``(user_id, session_id)``; each group becomes one
    ``UserSession`` whose events are sorted by timestamp.

    Args:
        df: Raw event table. Must contain ``user_id``, ``session_id``,
            ``timestamp``, ``event_type`` and ``page_id`` (case-insensitive);
            ``product_id`` and ``amount`` are optional. The caller's frame is
            not modified.

    Returns:
        One ``UserBehaviorProfile`` per distinct ``user_id``.

    Raises:
        ValueError: If a required column is absent after header normalization.
    """
    required_cols = {"user_id", "session_id", "timestamp", "event_type", "page_id"}

    # Normalize headers first: validating against the raw headers would reject
    # files whose only problem is header case or stray whitespace.
    # str() keeps non-string headers (e.g. integers) from crashing .lower().
    df = df.copy()
    df.columns = [str(c).lower().strip() for c in df.columns]

    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"CSV缺少必需列: {missing}\n必需列: {required_cols}")

    # Coerce timestamps to integers; unparseable values become NaN and the
    # affected rows are discarded together with rows lacking an event type.
    df["timestamp"] = pd.to_numeric(df["timestamp"], errors="coerce")
    df = df.dropna(subset=["timestamp", "event_type"])
    df["timestamp"] = df["timestamp"].astype(int)

    profiles = {}
    for (user_id, session_id), group in df.groupby(["user_id", "session_id"]):
        profile = profiles.get(user_id)
        if profile is None:
            profile = UserBehaviorProfile(user_id=str(user_id), sessions=[])
            profiles[user_id] = profile

        events = []
        for _, row in group.sort_values("timestamp").iterrows():
            product_id = row.get("product_id")
            amount = row.get("amount")
            events.append(InsuranceAppEvent(
                # row.name is the row's index label in the uploaded table.
                event_id=f"evt_{row.name}",
                user_id=str(row["user_id"]),
                session_id=str(row["session_id"]),
                timestamp=int(row["timestamp"]),
                event_type=str(row["event_type"]).strip(),
                page_id=str(row.get("page_id", "unknown")),
                # Optional columns: absent or NaN values are passed as None.
                product_id=str(product_id) if pd.notna(product_id) else None,
                amount=float(amount) if pd.notna(amount) else None,
            ))

        profile.sessions.append(UserSession(
            session_id=str(session_id),
            user_id=str(user_id),
            events=events,
        ))

    return list(profiles.values())
150
+
151
+
152
  def generate_synthetic_data(n_users=2000, n_events_per_user=50):
153
  event_types = list(INSURANCE_EVENT_TYPES)
154
  products = ["health_basic","health_premium","critical_illness","term_life",
 
180
  return data
181
 
182
 
183
+ # =============================================================================
184
+ # 核心训练函数
185
+ # =============================================================================
186
+
187
+ def train_model(features_list, labels, test_size=0.2, random_state=42, use_cv=False):
188
+ """通用训练函数"""
 
189
  df = pd.DataFrame(features_list)
190
  df_full = df.copy()
191
+
192
+ # 移除非数值列
193
+ drop_cols = [c for c in df.columns if df[c].dtype == 'object']
194
+ for c in drop_cols:
195
+ if c in ["top_product_id", "action_sequence"]:
196
+ df.pop(c)
197
+
198
+ # 处理object类型
199
  for c in df.columns:
200
  if df[c].dtype == 'object':
201
  df[c] = pd.to_numeric(df[c], errors='coerce').fillna(0)
202
  df = df.fillna(0).replace([np.inf, -np.inf], 0)
203
+
204
  X = df.values; y = np.array(labels)
205
+ feature_names = list(df.columns)
206
+
207
+ X_train, X_test, y_train, y_test = train_test_split(
208
+ X, y, test_size=test_size, random_state=random_state, stratify=y
209
+ )
210
+
211
  scaler = StandardScaler()
212
+ X_train_s = scaler.fit_transform(X_train)
213
+ X_test_s = scaler.transform(X_test)
214
 
215
+ # 训练 GBDT
216
+ gbdt = GradientBoostingClassifier(
217
+ n_estimators=200, max_depth=5, learning_rate=0.1,
218
+ subsample=0.8, random_state=random_state
219
+ )
220
  gbdt.fit(X_train_s, y_train)
221
+ y_pred_gbdt = gbdt.predict(X_test_s)
222
+ y_prob_gbdt = gbdt.predict_proba(X_test_s)[:, 1]
223
 
224
+ # 训练 RF
225
+ rf = RandomForestClassifier(
226
+ n_estimators=100, max_depth=10,
227
+ class_weight='balanced', random_state=random_state, n_jobs=-1
228
+ )
229
  rf.fit(X_train_s, y_train)
230
+ y_prob_rf = rf.predict_proba(X_test_s)[:, 1]
231
+ y_pred_rf = rf.predict(X_test_s)
232
 
233
  auc_gbdt = float(roc_auc_score(y_test, y_prob_gbdt))
234
  f1_gbdt = float(f1_score(y_test, y_pred_gbdt))
 
236
  auc_rf = float(roc_auc_score(y_test, y_prob_rf))
237
  ap_rf = float(average_precision_score(y_test, y_prob_rf))
238
 
239
+ fi = pd.DataFrame({
240
+ 'feature': feature_names,
241
+ 'importance': rf.feature_importances_
242
+ }).sort_values('importance', ascending=False)
243
+
244
+ # 交叉验证
245
+ cv_scores = None
246
+ if use_cv and len(y) >= 100:
247
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
248
+ cv_scores = cross_val_score(rf, X, y, cv=skf, scoring='roc_auc')
249
 
250
+ # 可视化
251
  os.makedirs("outputs", exist_ok=True)
252
 
253
  fig, ax = plt.subplots(figsize=(12,8))
254
  top = fi.head(15)
255
+ colors = plt.cm.RdYlGn(np.linspace(0.2, 0.8, len(top)))[::-1]
256
+ ax.barh(top['feature'][::-1], top['importance'][::-1], color=colors)
257
+ ax.set_title('Insurance APP - Top 15 Feature Importance', fontsize=14, fontweight='bold')
258
+ ax.set_xlabel('Importance Score')
259
  plt.tight_layout()
260
  fig_path1 = "outputs/feature_importance.png"
261
  plt.savefig(fig_path1, dpi=150, bbox_inches='tight'); plt.close()
 
263
  fig, ax = plt.subplots(figsize=(8,6))
264
  pg, rg, _ = precision_recall_curve(y_test, y_prob_gbdt)
265
  pr, rr, _ = precision_recall_curve(y_test, y_prob_rf)
266
+ ax.plot(rg, pg, label=f'GBDT AP={ap_gbdt:.3f}', linewidth=2, color='#2E86AB')
267
+ ax.plot(rr, pr, label=f'RF AP={ap_rf:.3f}', linewidth=2, color='#A23B72')
268
+ ax.set_xlabel('Recall', fontsize=12)
269
+ ax.set_ylabel('Precision', fontsize=12)
270
+ ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
271
+ ax.legend(fontsize=11)
272
+ ax.grid(True, alpha=0.3)
273
  plt.tight_layout()
274
  fig_path2 = "outputs/pr_curve.png"
275
  plt.savefig(fig_path2, dpi=150, bbox_inches='tight'); plt.close()
276
 
277
  fig, axs = plt.subplots(1,2,figsize=(12,5))
278
+ sns.heatmap(confusion_matrix(y_test, y_pred_gbdt), annot=True, fmt='d', cmap='Blues', ax=axs[0], cbar=False)
279
+ axs[0].set_title(f'GBDT (AUC={auc_gbdt:.3f})', fontsize=12, fontweight='bold')
280
+ axs[0].set_xlabel('Predicted'); axs[0].set_ylabel('Actual')
281
+ sns.heatmap(confusion_matrix(y_test, y_pred_rf), annot=True, fmt='d', cmap='Greens', ax=axs[1], cbar=False)
282
+ axs[1].set_title(f'RF (AUC={auc_rf:.3f})', fontsize=12, fontweight='bold')
283
+ axs[1].set_xlabel('Predicted'); axs[1].set_ylabel('Actual')
284
  plt.tight_layout()
285
  fig_path3 = "outputs/confusion_matrix.png"
286
  plt.savefig(fig_path3, dpi=150, bbox_inches='tight'); plt.close()
287
 
288
+ # ROC曲线
289
+ fig, ax = plt.subplots(figsize=(8,6))
290
+ fpr_g, tpr_g, _ = roc_curve(y_test, y_prob_gbdt)
291
+ fpr_r, tpr_r, _ = roc_curve(y_test, y_prob_rf)
292
+ ax.plot(fpr_g, tpr_g, label=f'GBDT AUC={auc_gbdt:.3f}', linewidth=2, color='#2E86AB')
293
+ ax.plot(fpr_r, tpr_r, label=f'RF AUC={auc_rf:.3f}', linewidth=2, color='#A23B72')
294
+ ax.plot([0,1], [0,1], 'k--', alpha=0.5)
295
+ ax.set_xlabel('False Positive Rate', fontsize=12)
296
+ ax.set_ylabel('True Positive Rate', fontsize=12)
297
+ ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
298
+ ax.legend(fontsize=11)
299
+ ax.grid(True, alpha=0.3)
300
+ plt.tight_layout()
301
+ fig_path4 = "outputs/roc_curve.png"
302
+ plt.savefig(fig_path4, dpi=150, bbox_inches='tight'); plt.close()
303
+
304
  fi_str = fi.head(15).to_string(index=False)
305
  report = classification_report(y_test, y_pred_gbdt, digits=4)
306
 
307
+ cv_str = ""
308
+ if cv_scores is not None:
309
+ cv_str = f"\n--- 5折交叉验证 (RF AUC) ---\nMean: {cv_scores.mean():.4f} (+/- {cv_scores.std()*2:.4f})\nScores: {cv_scores.round(4).tolist()}"
310
+
311
  result_text = f"""=== 模型训练结果 ===
312
+ 样本数: {len(y)} | 特征数: {len(feature_names)}
313
  训练集: {len(y_train)} | 测试集: {len(y_test)}
314
+ 流失率: {y.mean():.1%} | 流失数: {y.sum()}
315
 
316
  --- GBDT ---
317
  AUC: {auc_gbdt:.4f}
 
321
  --- Random Forest ---
322
  AUC: {auc_rf:.4f}
323
  AP: {ap_rf:.4f}
324
+ {cv_str}
325
 
326
  --- Top 15 特征重要性 ---
327
  {fi_str}
 
329
  --- 分类报告 (GBDT) ---
330
  {report}"""
331
 
332
+ return result_text, fig_path1, fig_path2, fig_path3, fig_path4, df_full
333
 
334
 
335
+ # =============================================================================
336
+ # Gradio 回调函数
337
+ # =============================================================================
 
338
 
339
def demo_train(n_users, n_events, test_size, random_state, use_cv):
    """Demo-mode callback: generate synthetic users, featurize them, and train.

    Profiles for which the feature engineer returns an empty/falsy result are
    left out, together with their labels. The surviving features and labels
    are handed to ``train_model`` and its result is returned unchanged.
    """
    synthetic = generate_synthetic_data(n_users=n_users, n_events_per_user=n_events)
    extractor = InsuranceFeatureEngineer()

    features_list = []
    labels = []
    for profile, label in synthetic:
        feats = extractor.extract_user_features(profile)
        if not feats:
            continue
        features_list.append(feats)
        labels.append(label)

    return train_model(features_list, labels, test_size, random_state, use_cv)
349
+
350
+
351
def _first_label_per_user(df, label_col):
    """Map ``str(user_id)`` to the label on that user's first row in *df*."""
    # Locate the user-id column tolerantly (case / surrounding whitespace).
    user_col = next(
        (c for c in df.columns if str(c).lower().strip() == "user_id"), None
    )
    if user_col is None:
        return {}
    first_rows = df.drop_duplicates(subset=user_col, keep="first")
    # Keys are stringified because profile.user_id is built with str(user_id);
    # comparing a numeric column against that string never matches.
    return dict(zip(first_rows[user_col].astype(str), first_rows[label_col]))


def csv_train(csv_file, label_col, test_size, random_state, use_cv):
    """CSV-mode callback: train the models on an uploaded event CSV.

    Args:
        csv_file: A filesystem path, an object exposing ``.name`` (a path), or
            raw CSV bytes; ``None`` when nothing was uploaded.
        label_col: Optional name of the 0/1 label column. When blank, labels
            are inferred per user with a heuristic (no purchase, no renewal
            and fewer than 20 events => 1).
        test_size, random_state, use_cv: Forwarded to ``train_model``.

    Returns:
        The 6-tuple produced by ``train_model`` on success, otherwise
        ``(message, None, None, None, None, None)`` describing the problem.
    """
    if csv_file is None:
        return "请先上传CSV文件", None, None, None, None, None

    try:
        if isinstance(csv_file, str):
            df = pd.read_csv(csv_file)
        else:
            df = pd.read_csv(csv_file.name if hasattr(csv_file, 'name') else io.BytesIO(csv_file))

        label_col = label_col.strip() if label_col else None
        if label_col and label_col not in df.columns:
            return f"标签列 '{label_col}' 不存在。可用列: {list(df.columns)}", None, None, None, None, None

        profiles = parse_csv_to_profiles(df)

        # Build the user -> label table once instead of filtering the whole
        # frame for every user.
        label_by_user = _first_label_per_user(df, label_col) if label_col else None

        engineer = InsuranceFeatureEngineer()
        features_list, labels = [], []

        for profile in profiles:
            f = engineer.extract_user_features(profile)
            if not f:
                continue
            features_list.append(f)
            if label_by_user is not None:
                # Users without any labelled row default to 0.
                labels.append(int(label_by_user.get(profile.user_id, 0)))
            else:
                # Heuristic: no purchase + no renewal + low activity = churn risk.
                is_high_risk = (f["has_purchased"] == 0 and f["has_renewed"] == 0
                                and f["total_events"] < 20)
                labels.append(int(is_high_risk))

        if len(features_list) < 50:
            return f"有效样本数 {len(features_list)} 太少,需要至少50个用户", None, None, None, None, None

        # A stratified split and ROC-AUC both need two classes; report that
        # clearly instead of surfacing a scikit-learn traceback.
        if len(set(labels)) < 2:
            return "标签只有一个类别,无法训练二分类模型", None, None, None, None, None

        return train_model(features_list, labels, test_size, random_state, use_cv)

    except Exception as e:
        # UI boundary: show the failure in the result box rather than crash.
        import traceback
        return f"错误: {str(e)}\n\n{traceback.format_exc()}", None, None, None, None, None
399
+
400
+
401
def show_csv_info(csv_file):
    """Summarize an uploaded CSV for display before training.

    Returns a ``(text, preview)`` pair: a text report (row/column counts,
    column names, first five rows, top-10 event types, distinct user and
    session counts) and the first 20 rows as a DataFrame. When nothing was
    uploaded or the file cannot be read, the text carries the message and the
    preview is ``None``.
    """
    if csv_file is None:
        return "请先上传CSV文件", None

    try:
        # Accept a path string, an object exposing .name, or raw bytes.
        if isinstance(csv_file, str):
            source = csv_file
        elif hasattr(csv_file, 'name'):
            source = csv_file.name
        else:
            source = io.BytesIO(csv_file)
        df = pd.read_csv(source)

        if 'event_type' in df.columns:
            event_summary = df['event_type'].value_counts().head(10).to_string()
        else:
            event_summary = '无event_type列'

        if 'user_id' in df.columns:
            user_summary = df['user_id'].nunique()
        else:
            user_summary = '无user_id列'

        if 'session_id' in df.columns:
            session_summary = df['session_id'].nunique()
        else:
            session_summary = '无session_id列'

        # NOTE(review): assumes the report lines carry no leading indentation
        # in the original literal — confirm against app.py.
        report_lines = [
            "=== CSV文件信息 ===",
            f"行数: {len(df)}",
            f"列数: {len(df.columns)}",
            f"列名: {list(df.columns)}",
            "",
            "=== 前5行预览 ===",
            f"{df.head().to_string()}",
            "",
            "=== 事件类型分布 (前10) ===",
            f"{event_summary}",
            "",
            "=== 用户数量 ===",
            f"{user_summary}",
            "",
            "=== 会话数量 ===",
            f"{session_summary}",
        ]
        return "\n".join(report_lines), df.head(20)
    except Exception as e:
        return f"解析错误: {str(e)}", None
432
+
433
+
434
+ # =============================================================================
435
+ # Gradio 界面
436
+ # =============================================================================
437
+
438
+ with gr.Blocks(title="🏥 保险APP 用户行为分析模型训练平台", theme=gr.themes.Soft()) as demo:
439
+ gr.Markdown("""
440
+ # 🏥 保险APP 用户行为分析模型训练平台
441
+
442
+ 基于最新研究论文构建的工业级保险用户行为分析平台。
443
+
444
+ **两种模式:**
445
+ - 🎲 **演示模式**: 生成合成保险APP数据,体验完整训练流程
446
+ - 📁 **CSV上传**: 上传真实用户行为数据,自动特征工程 + 模型训练
447
+
448
+ **参考论文:** Deep Interest Network (KDD 2018) | Transformer Churn Prediction (arXiv 2309.14390) | TabBERT (arXiv 2011.01843)
449
+ """)
450
+
451
+ with gr.Tabs():
452
+ # ===== Tab 1: 演示模式 =====
453
+ with gr.Tab("🎲 演示模式 (合成数据)"):
454
+ with gr.Row():
455
+ with gr.Column(scale=1):
456
+ gr.Markdown("### 参数设置")
457
+ n_users_slider = gr.Slider(500, 5000, value=2000, step=100, label="用户数量")
458
+ n_events_slider = gr.Slider(10, 100, value=50, step=5, label="每用户最大事件数")
459
+ test_size_slider = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="测试集比例")
460
+ random_seed = gr.Number(value=42, label="随机种子", precision=0)
461
+ use_cv_check = gr.Checkbox(value=False, label="启用5折交叉验证")
462
+ train_btn = gr.Button("🚀 开始训练", variant="primary", size="lg")
463
+
464
+ with gr.Column(scale=2):
465
+ demo_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
466
+
467
+ with gr.Row():
468
+ demo_img1 = gr.Image(label="特征重要性")
469
+ demo_img2 = gr.Image(label="PR曲线")
470
+ with gr.Row():
471
+ demo_img3 = gr.Image(label="混淆矩阵")
472
+ demo_img4 = gr.Image(label="ROC曲线")
473
+ with gr.Row():
474
+ demo_table = gr.Dataframe(label="特征数据样本 (前10行)")
475
+
476
+ # ===== Tab 2: CSV上传 =====
477
+ with gr.Tab("📁 CSV数据上传"):
478
+ with gr.Row():
479
+ with gr.Column(scale=1):
480
+ gr.Markdown("""
481
+ ### 📤 上传数据
482
+
483
+ **必需列:**
484
+ - `user_id`: 用户唯一标识
485
+ - `session_id`: 会话标识
486
+ - `timestamp`: Unix 时间戳 (毫秒或秒)
487
+ - `event_type`: 事件类型
488
+ - `page_id`: 页面标识
489
+
490
+ **可选列:**
491
+ - `product_id`: 保险产品ID
492
+ - `amount`: 金额/保额
493
+ - `label` (或其他): 流失标签 (0/1)
494
+
495
+ **示例CSV格式:**
496
+ ```
497
+ user_id,session_id,timestamp,event_type,page_id,product_id,amount
498
+ user_001,sess_001,1704067200000,page_view,home,,
499
+ user_001,sess_001,1704067230000,product_view,product,health_basic,
500
+ user_001,sess_001,1704067260000,quote_request,quote,health_basic,50000
501
+ ```
502
+ """)
503
+
504
+ csv_file = gr.File(label="上传CSV文件", file_types=[".csv"])
505
+ label_col_input = gr.Textbox(label="标签列名 (可选, 默认自动推断)", placeholder="如: churn, is_churned, label")
506
+
507
+ with gr.Row():
508
+ csv_test_size = gr.Slider(0.1, 0.4, value=0.2, step=0.05, label="测试集比例")
509
+ csv_random_seed = gr.Number(value=42, label="随机种子", precision=0)
510
+
511
+ csv_use_cv = gr.Checkbox(value=False, label="启用5折交叉验证")
512
+
513
+ with gr.Row():
514
+ info_btn = gr.Button("📊 查看数据信息", variant="secondary")
515
+ csv_train_btn = gr.Button("🚀 训练模型", variant="primary", size="lg")
516
+
517
+ with gr.Column(scale=2):
518
+ csv_info = gr.Textbox(label="CSV信息", lines=15, show_copy_button=True)
519
+ csv_preview = gr.Dataframe(label="数据预览")
520
+
521
+ with gr.Row():
522
+ csv_result = gr.Textbox(label="训练结果", lines=25, show_copy_button=True)
523
+
524
+ with gr.Row():
525
+ csv_img1 = gr.Image(label="特征重要性")
526
+ csv_img2 = gr.Image(label="PR曲线")
527
+ with gr.Row():
528
+ csv_img3 = gr.Image(label="混淆矩阵")
529
+ csv_img4 = gr.Image(label="ROC曲线")
530
+ with gr.Row():
531
+ csv_table = gr.Dataframe(label="特征数据样本 (前10行)")
532
+
533
+ # ===== Tab 3: 帮助文档 =====
534
+ with gr.Tab("❓ 帮助文档"):
535
+ gr.Markdown("""
536
+ ## 事件类型定义
537
+
538
+ | 类别 | 事件 | 业务含义 |
539
+ |------|------|---------|
540
+ | **浏览** | page_view, product_view, premium_calculator, article_read, faq_view, product_compare | 用户浏览保险产品页面 |
541
+ | **交互** | quote_request, form_submit, document_upload, chat_init, call_init, video_consult, quote_result_view | 用户深度参与行为 |
542
+ | **转化** | policy_select, payment_init, payment_success, policy_issued | 核心KPI转化行为 |
543
+ | **理赔** | claim_init, claim_doc_upload, claim_review, claim_approved, claim_rejected | 理赔全流程 |
544
+ | **续保** | renewal_reminder, renewal_click, renewal_complete, policy_cancel | 续保/流失信号 |
545
+ | **其他** | login, logout, app_uninstall | 登录/登出/卸载 |
546
+
547
+ ## 特征工程说明
548
+
549
+ 平台自动提取 **30+维行为特征**:
550
+
551
+ | 维度 | 特征示例 | 业务含义 |
552
+ |------|---------|---------|
553
+ | 基础活跃度 | total_sessions, total_events, days_active | 用户使用APP的活跃程度 |
554
+ | 浏览深度 | product_view_ratio, article_read_ratio | 内容消费深度 |
555
+ | 转化信号 | payment_success_ratio, policy_issued_ratio | 购买/续保意愿 |
556
+ | 生命周期 | has_purchased, has_renewed, has_claimed | 客户价值阶段 |
557
+ | 时序行为 | recent_7day_events, days_since_last_event | 近期活跃/沉默 |
558
+ | 行为模式 | peak_active_hour, weekend_activity_ratio | 使用习惯 |
559
+
560
+ ## 模型说明
561
+
562
+ | 模型 | 特点 | 适用场景 |
563
+ |------|------|---------|
564
+ | **GBDT** | 高精度, 可解释 | 流失预测, 欺诈检测 |
565
+ | **Random Forest** | 抗过拟合, 特征重要性 | 特征筛选, 基线模型 |
566
+
567
+ ## 评估指标
568
+
569
+ - **AUC-ROC**: 分类器整体区分能力
570
+ - **F1-Score**: 精确率和召回率的调和平均
571
+ - **AP (Average Precision)**: PR曲线下面积, 适合不平衡数据
572
+ - **交叉验证**: 5折StratifiedKFold, 评估模型稳定性
573
+
574
+ ## 注意事项
575
+
576
+ 1. 保险场景数据高度不平衡 (流失率 < 5%), 请使用 F1/AP 而非 Accuracy
577
+ 2. 建议至少 1000+ 用户样本才能获得稳定结果
578
+ 3. timestamp 支持毫秒或秒, 平台自动识别
579
+ 4. 无标签列时, 平台使用启发式规则自动推断 (无购买+低活跃 = 高风险)
580
+ """)
581
 
582
+ gr.Markdown("""
583
+ ---
584
+ <div align="center">
585
+ <b>保险APP 用户行为分析模型训练平台</b> |
586
+ 基于 <a href="https://arxiv.org/abs/1706.06978">DIN</a> |
587
+ <a href="https://arxiv.org/abs/2309.14390">Churn Transformer</a> |
588
+ <a href="https://arxiv.org/abs/2011.01843">TabBERT</a> |
589
+ 作者: <a href="https://huggingface.co/Stephanwu">Stephanwu</a>
590
+ </div>
591
+ """)
592
 
593
+ # ===== 事件绑定 =====
594
+ train_btn.click(
595
+ fn=demo_train,
596
+ inputs=[n_users_slider, n_events_slider, test_size_slider, random_seed, use_cv_check],
597
+ outputs=[demo_result, demo_img1, demo_img2, demo_img3, demo_img4, demo_table]
598
+ )
599
 
600
+ info_btn.click(
601
+ fn=show_csv_info,
602
+ inputs=[csv_file],
603
+ outputs=[csv_info, csv_preview]
604
+ )
605
 
606
+ csv_train_btn.click(
607
+ fn=csv_train,
608
+ inputs=[csv_file, label_col_input, csv_test_size, csv_random_seed, csv_use_cv],
609
+ outputs=[csv_result, csv_img1, csv_img2, csv_img3, csv_img4, csv_table]
610
+ )
611
 
612
  if __name__ == "__main__":
613
  demo.launch()