Stephanwu committed on
Commit 32bf7a6 · verified · 1 Parent(s): 8240065

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +178 -92
app.py CHANGED
@@ -338,8 +338,11 @@ AP: {ap_rf:.4f}
 
 
  # =============================================================================
- # DIN product recommendation
  # =============================================================================
 
  def generate_product_recommendation_data(n_users=1000, seed=42):
  random.seed(seed); np.random.seed(seed)
@@ -363,11 +366,33 @@ def generate_product_recommendation_data(n_users=1000, seed=42):
  'user_features': np.random.randn(20).astype(np.float32),
  })
  return pd.DataFrame(records)
 
 
- def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
  if not TORCH_AVAILABLE:
- return "PyTorch is not installed. Please add torch to requirements.txt and restart the Space", None, None, None, None, None
 
  torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
  df = generate_product_recommendation_data(n_users=n_users, seed=seed)
@@ -396,34 +421,67 @@ def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
 
  device = torch.device('cpu')
 
- class SimpleDIN(nn.Module):
  def __init__(self, num_events, num_products, d_model=64, max_len=20):
  super().__init__()
  self.event_emb = nn.Embedding(num_events+1, d_model//2, padding_idx=0)
  self.prod_emb = nn.Embedding(num_products+1, d_model//2, padding_idx=0)
  self.cand_emb = nn.Embedding(num_products+1, d_model)
- self.attn = nn.Sequential(nn.Linear(d_model*4, 128), nn.ReLU(), nn.Linear(128, 1))
- self.mlp = nn.Sequential(nn.Linear(d_model*3, 256), nn.ReLU(), nn.Dropout(0.3),
  nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 1))
  def forward(self, be, bp, bm, cp):
- B = be.size(0); L = be.size(1)
  e_emb = self.event_emb(be)
  p_emb = self.prod_emb(bp)
  beh_emb = torch.cat([e_emb, p_emb], dim=-1)
  cand_emb = self.cand_emb(cp)
- cand_exp = cand_emb.unsqueeze(1).expand(B, L, -1)
- diff = cand_exp - beh_emb; prod = cand_exp * beh_emb
- attn_in = torch.cat([cand_exp, beh_emb, diff, prod], dim=-1)
- attn_w = self.attn(attn_in).squeeze(-1)
- attn_w = attn_w.masked_fill(~bm.bool(), -1e9)
- attn_w = torch.softmax(attn_w, dim=1)
- interest = (beh_emb * attn_w.unsqueeze(-1)).sum(dim=1)
- x = torch.cat([interest, cand_emb, interest*cand_emb], dim=-1)
  return self.mlp(x).squeeze(-1)
 
- model = SimpleDIN(len(all_events), len(all_products), d_model=embedding_dim).to(device)
  criterion = nn.BCEWithLogitsLoss()
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
  for epoch in range(epochs):
  model.train(); epoch_loss = 0
@@ -437,10 +495,13 @@ def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
  optimizer.zero_grad()
  outputs = model(be, bp, bm, cp)
  loss = criterion(outputs, labels)
- loss.backward(); optimizer.step()
  epoch_loss += loss.item()
  if (epoch+1) % max(1, epochs//5) == 0 or epoch == 0:
- print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss*batch_size/len(train_df):.4f}")
 
  model.eval()
  with torch.no_grad():
@@ -458,7 +519,6 @@ def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
 
  os.makedirs("outputs", exist_ok=True)
 
- # Save the PyTorch model
  torch.save({
  'model_state_dict': model.state_dict(),
  'event_vocab': event_vocab,
@@ -468,14 +528,14 @@ def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
  'num_events': len(all_events),
  'num_products': len(all_products),
  'metrics': {'auc': auc, 'ap': ap, 'f1': f1, 'acc': acc}
- }, 'outputs/din_model.pt')
 
  fig, ax = plt.subplots(figsize=(10,6))
  product_perf = {}
- for _, row in test_df.iterrows():
  prod = row['candidate_product']
  if prod not in product_perf: product_perf[prod] = {'preds': [], 'labels': []}
- idx = test_df.index.get_loc(_)
  product_perf[prod]['preds'].append(preds[idx])
  product_perf[prod]['labels'].append(row['label'])
  prod_aucs = []
@@ -492,88 +552,114 @@ def train_din_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
  ax2.plot(x, rates, 'ro-', label='Conversion Rate')
  ax.set_xticks(x); ax.set_xticklabels(prods, rotation=45, ha='right')
  ax.set_ylabel('AUC', color='steelblue'); ax2.set_ylabel('Conversion Rate', color='red')
- ax.set_title('Product Recommendation Performance', fontweight='bold')
  ax.legend(loc='upper left'); ax2.legend(loc='upper right')
  plt.tight_layout()
- fig_path1 = "outputs/din_product_performance.png"
  plt.savefig(fig_path1, dpi=150); plt.close()
 
- fig, ax = plt.subplots(figsize=(10,6))
  sample_idx = 0
  with torch.no_grad():
- be_s = be[sample_idx:sample_idx+1]; bp_s = bp[sample_idx:sample_idx+1]
- bm_s = bm[sample_idx:sample_idx+1]; cp_s = cp[sample_idx:sample_idx+1]
  B, L = be_s.size()
- e_emb = model.event_emb(be_s); p_emb = model.prod_emb(bp_s)
  beh_emb = torch.cat([e_emb, p_emb], dim=-1)
  cand_emb = model.cand_emb(cp_s)
- cand_exp = cand_emb.unsqueeze(1).expand(B, L, -1)
- diff = cand_exp - beh_emb; prod_feat = cand_exp * beh_emb
- attn_in = torch.cat([cand_exp, beh_emb, diff, prod_feat], dim=-1)
- attn_w = torch.softmax(model.attn(attn_in).squeeze(-1).masked_fill(~bm_s, -1e9), dim=1)
- weights = attn_w[0].cpu().numpy()
  valid_len = bm_s[0].sum().item()
- valid_weights = weights[-valid_len:] if valid_len > 0 else weights
- ax.bar(range(len(valid_weights)), valid_weights, color='coral')
- ax.set_title('Attention Weights (Sample User)', fontweight='bold')
- ax.set_xlabel('Behavior Position'); ax.set_ylabel('Attention Weight')
  plt.tight_layout()
- fig_path2 = "outputs/din_attention.png"
  plt.savefig(fig_path2, dpi=150); plt.close()
 
  fig, ax = plt.subplots(figsize=(8,6))
  fpr, tpr, _ = roc_curve(labels, preds)
- ax.plot(fpr, tpr, label=f'DIN AUC={auc:.3f}', linewidth=2, color='#2E86AB')
  ax.plot([0,1], [0,1], 'k--', alpha=0.5)
  ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
- ax.set_title('ROC Curve - Product Recommendation', fontweight='bold')
  ax.legend(); ax.grid(True, alpha=0.3)
  plt.tight_layout()
- fig_path3 = "outputs/din_roc.png"
  plt.savefig(fig_path3, dpi=150); plt.close()
 
  fig, ax = plt.subplots(figsize=(8,6))
  prec, rec, _ = precision_recall_curve(labels, preds)
- ax.plot(rec, prec, label=f'DIN AP={ap:.3f}', linewidth=2, color='#A23B72')
  ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
- ax.set_title('Precision-Recall Curve - Product Recommendation', fontweight='bold')
  ax.legend(); ax.grid(True, alpha=0.3)
  plt.tight_layout()
- fig_path4 = "outputs/din_pr.png"
  plt.savefig(fig_path4, dpi=150); plt.close()
 
- result_text = f"""=== DIN Insurance Product Recommendation Model ===
- Samples: {n_users} | Products: {len(all_products)}
  Event vocab: {len(all_events)} | Product vocab: {len(all_products)}
- Train: {len(train_df)} | Test: {len(test_df)}
-
- --- Model Architecture ---
- Embedding dim: {embedding_dim}
- Attention: LocalActivationUnit (4-way interaction: [c, b, c-b, c*b])
- MLP: [emb*3] 256 → 128 → 1
-
- --- Training Config ---
  Epochs: {epochs} | Batch size: {batch_size} | LR: {lr}
- Optimizer: Adam | Loss: BCEWithLogitsLoss
 
- --- Test Results ---
  AUC: {auc:.4f}
  AP: {ap:.4f}
  F1: {f1:.4f}
  Accuracy: {acc:.4f}
 
- --- Model Insights ---
- 1. The attention mechanism automatically learns how relevant each historical behavior is to the candidate product
- 2. High weights usually go to past views/purchases of products in the same category
- 3. New users (short histories) rely on statistical features, long-tenured users on their behavior sequence
 
- --- Model File ---
- Model saved to: outputs/din_model.pt
- Use the "Model Management" tab to upload it to the Hugging Face Hub"""
 
  return result_text, fig_path1, fig_path2, fig_path3, fig_path4
-
-
  # =============================================================================
  # TabBERT anomaly detection
  # =============================================================================
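For reference, the scoring that the removed DIN code performs can be written out. With c the candidate-product embedding, e_j the j-th behavior embedding, and m the padding mask, the deleted forward pass computes:

$$
\begin{aligned}
s_j &= \mathrm{MLP}_{\mathrm{attn}}\big([\,c;\ e_j;\ c-e_j;\ c\odot e_j\,]\big) \\
a &= \mathrm{softmax}\big(s + (1-m)\cdot(-10^{9})\big) \\
v &= \textstyle\sum_j a_j\, e_j,\qquad \hat{y} = \mathrm{MLP}\big([\,v;\ c;\ v\odot c\,]\big)
\end{aligned}
$$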
@@ -830,12 +916,12 @@ def save_model_to_hub(repo_id, token, model_type, notes):
  joblib.dump(artifacts['sklearn'], tmpdir / "sklearn_model.joblib")
  model_files.append("sklearn_model.joblib")
 
- # Check for DIN model
- din_path = Path("outputs/din_model.pt")
- if din_path.exists():
- artifacts['din'] = torch.load(din_path, map_location='cpu')
- torch.save(artifacts['din'], tmpdir / "din_model.pt")
- model_files.append("din_model.pt")
 
  # Check for TabBERT model
  tab_path = Path("outputs/tabbert_model.pt")
@@ -871,7 +957,7 @@
  | File | Description |
  |------|-------------|
  | `sklearn_model.joblib` | GBDT + Random Forest + Scaler (sklearn) |
- | `din_model.pt` | Deep Interest Network (PyTorch) |
  | `tabbert_model.pt` | TabularBERT Anomaly Detection (PyTorch) |
  | `model_metadata.json` | Model metadata |
 
@@ -887,10 +973,10 @@ model_path = hf_hub_download(repo_id="{repo_id}", filename="sklearn_model.joblib
  artifacts = joblib.load(model_path)
  # artifacts['gbdt'], artifacts['rf'], artifacts['scaler']
 
- # Load DIN
- din_path = hf_hub_download(repo_id="{repo_id}", filename="din_model.pt")
- din_ckpt = torch.load(din_path)
- # din_ckpt['model_state_dict'], din_ckpt['event_vocab'], din_ckpt['product_vocab']
  ```
 
  ## Reference
  ## Reference
@@ -959,16 +1045,16 @@ def load_model_from_hub(repo_id, token, model_type):
959
  plt.savefig(img_path, dpi=150); plt.close()
960
  images.append(img_path)
961
 
962
- # 加载 DIN
963
- if "din_model.pt" in metadata.get('files', []):
964
- din_path = hf_hub_download(repo_id=repo_id, filename="din_model.pt", token=token, repo_type="model")
965
- din_ckpt = torch.load(din_path, map_location='cpu')
966
- metrics = din_ckpt.get('metrics', {})
967
- results.append(f"📦 DIN 模型已加载")
968
  results.append(f" AUC: {metrics.get('auc', 'N/A')}")
969
- results.append(f" Embedding dim: {din_ckpt.get('embedding_dim', 'N/A')}")
970
- results.append(f" Event vocab: {len(din_ckpt.get('event_vocab', {}))}")
971
- results.append(f" Product vocab: {len(din_ckpt.get('product_vocab', {}))}")
972
 
973
  # 加载 TabBERT
974
  if "tabbert_model.pt" in metadata.get('files', []):
@@ -1435,10 +1521,10 @@ with gr.Blocks(title="🏥 Insurance App User Behavior Analysis Model Training Platform v3.0",
  with gr.Row():
  csv_table = gr.Dataframe(label="Feature data sample")
 
- # ===== Tab 3: DIN product recommendation =====
- with gr.Tab("🎯 Product Recommendation (DIN)"):
- gr.Markdown("""### Deep Interest Network - Insurance Product Recommendation
- Based on the user's historical behavior sequence, an attention mechanism dynamically scores interest in the candidate insurance product.""")
  with gr.Row():
  with gr.Column(scale=1):
  din_users = gr.Slider(500, 5000, value=2000, step=100, label="Number of users")
@@ -1447,7 +1533,7 @@ with gr.Blocks(title="🏥 Insurance App User Behavior Analysis Model Training Platform v3.0",
  din_batch = gr.Slider(32, 512, value=128, step=32, label="Batch Size")
  din_lr = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="Learning rate")
  din_seed = gr.Number(value=42, label="Random seed", precision=0)
- din_btn = gr.Button("🚀 Train DIN model", variant="primary", size="lg")
  if not TORCH_AVAILABLE:
  gr.Markdown("⚠️ **PyTorch is not installed**. Add `torch>=2.0.0` to requirements.txt and restart.")
  with gr.Column(scale=2):
@@ -1610,7 +1696,7 @@ product_compare | video_consult | | | |
  info_btn.click(fn=show_csv_info, inputs=[csv_file], outputs=[csv_info, csv_preview])
  csv_train_btn.click(fn=csv_train, inputs=[csv_file, label_col_input, csv_test_size, csv_random_seed, csv_use_cv],
  outputs=[csv_result, csv_img1, csv_img2, csv_img3, csv_img4, csv_table])
- din_btn.click(fn=train_din_recommendation, inputs=[din_users, din_emb, din_epochs, din_batch, din_lr, din_seed],
  outputs=[din_result, din_img1, din_img2, din_img3, din_img4])
  tab_btn.click(fn=train_tabbert_anomaly, inputs=[tab_normal, tab_anomaly, tab_dmodel, tab_epochs, tab_batch, tab_lr, tab_seed],
  outputs=[tab_result, tab_img1, tab_img2, tab_img3, tab_img4])
 
@@ -338,8 +338,11 @@
 
 
  # =============================================================================
+ # DIEN product recommendation (Deep Interest Evolution Network)
  # =============================================================================
+ # DIEN = DIN + Interest Extractor (GRU) + Interest Evolving (AUGRU)
+ # Paper: Deep Interest Evolution Network for Click-Through Rate Prediction (AAAI 2019)
+ # arXiv: https://arxiv.org/abs/1809.03672
 
  def generate_product_recommendation_data(n_users=1000, seed=42):
  random.seed(seed); np.random.seed(seed)
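The attention-gated update that the added AUGRUCell implements (see the model code further down in this diff) can be stated as, with x_t the GRU output for step t, a_t its attention score toward the candidate product, and \odot the element-wise product:

$$
\begin{aligned}
u_t &= \sigma\big(W_u[x_t;\,h_{t-1}] + b_u\big), \qquad r_t = \sigma\big(W_r[x_t;\,h_{t-1}] + b_r\big) \\
\tilde{h}_t &= \tanh\big(W_h[x_t;\ r_t \odot h_{t-1}] + b_h\big) \\
\tilde{u}_t &= a_t \cdot u_t, \qquad h_t = (1-\tilde{u}_t)\odot h_{t-1} + \tilde{u}_t \odot \tilde{h}_t
\end{aligned}
$$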
 
@@ -363,11 +366,33 @@
  'user_features': np.random.randn(20).astype(np.float32),
  })
  return pd.DataFrame(records)
+ def generate_product_recommendation_data(n_users=1000, seed=42):
+ random.seed(seed); np.random.seed(seed)
+ products = ["health_basic","health_premium","critical_illness","term_life",
+ "auto_compulsory","auto_commercial","home","travel_domestic"]
+ records = []
+ for u in range(n_users):
+ n_behaviors = random.randint(5, 30)
+ behavior_events = []
+ behavior_products = []
+ for i in range(n_behaviors):
+ et = random.choice(["page_view","product_view","quote_request","article_read"])
+ behavior_events.append(et)
+ behavior_products.append(random.choice(products))
+ candidate = random.choice(products)
+ label = 1 if candidate in behavior_products else random.choices([0,1], weights=[0.7,0.3])[0]
+ records.append({
+ 'user_id': u, 'behavior_events': behavior_events,
+ 'behavior_products': behavior_products,
+ 'candidate_product': candidate, 'label': label,
+ 'user_features': np.random.randn(20).astype(np.float32),
+ })
+ return pd.DataFrame(records)
 
 
+ def train_dien_recommendation(n_users, embedding_dim, epochs, batch_size, lr, seed):
  if not TORCH_AVAILABLE:
+ return "PyTorch not installed. Please add torch to requirements.txt and restart Space.", None, None, None, None, None
 
  torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
  df = generate_product_recommendation_data(n_users=n_users, seed=seed)
 
@@ -396,34 +421,67 @@
 
  device = torch.device('cpu')
 
+ # ===== DIEN Model Implementation =====
+ class AUGRUCell(nn.Module):
+ """Attentional Update Gate Recurrent Unit - core DIEN component"""
+ def __init__(self, input_dim, hidden_dim):
+ super().__init__()
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.W_ug = nn.Linear(input_dim + hidden_dim, hidden_dim)
+ self.W_rg = nn.Linear(input_dim + hidden_dim, hidden_dim)
+ self.W_cand = nn.Linear(input_dim + hidden_dim, hidden_dim)
+
+ def forward(self, x_t, h_prev, attn_t):
+ concat = torch.cat([x_t, h_prev], dim=-1)
+ r_t = torch.sigmoid(self.W_rg(concat))
+ u_t = torch.sigmoid(self.W_ug(concat))
+ u_t_att = attn_t * u_t
+ r_concat = torch.cat([x_t, r_t * h_prev], dim=-1)
+ h_tilde = torch.tanh(self.W_cand(r_concat))
+ h_t = (1 - u_t_att) * h_prev + u_t_att * h_tilde
+ return h_t
+
+ class SimpleDIEN(nn.Module):
  def __init__(self, num_events, num_products, d_model=64, max_len=20):
  super().__init__()
+ self.d_model = d_model
+ self.max_len = max_len
  self.event_emb = nn.Embedding(num_events+1, d_model//2, padding_idx=0)
  self.prod_emb = nn.Embedding(num_products+1, d_model//2, padding_idx=0)
  self.cand_emb = nn.Embedding(num_products+1, d_model)
+ self.gru = nn.GRU(input_size=d_model, hidden_size=d_model, batch_first=True)
+ self.augru = AUGRUCell(d_model, d_model)
+ self.attn = nn.Sequential(nn.Linear(d_model * 4, 128), nn.ReLU(), nn.Linear(128, 1))
+ self.mlp = nn.Sequential(nn.Linear(d_model * 3, 256), nn.ReLU(), nn.Dropout(0.3),
  nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 1))
+
  def forward(self, be, bp, bm, cp):
+ B, L = be.size()
  e_emb = self.event_emb(be)
  p_emb = self.prod_emb(bp)
  beh_emb = torch.cat([e_emb, p_emb], dim=-1)
  cand_emb = self.cand_emb(cp)
+ gru_out, _ = self.gru(beh_emb)
+ h_t = torch.zeros(B, self.d_model, device=beh_emb.device)
+ for t in range(L):
+ gru_t = gru_out[:, t, :]
+ cand_exp = cand_emb
+ diff = cand_exp - gru_t
+ prod_feat = cand_exp * gru_t
+ attn_in = torch.cat([cand_exp, gru_t, diff, prod_feat], dim=-1)
+ attn_t = torch.sigmoid(self.attn(attn_in))
+ mask_t = bm[:, t:t+1].float()
+ h_new = self.augru(gru_t, h_t, attn_t)
+ h_t = mask_t * h_new + (1 - mask_t) * h_t
+ final_interest = h_t
+ interest_prod = final_interest * cand_emb
+ x = torch.cat([final_interest, cand_emb, interest_prod], dim=-1)
  return self.mlp(x).squeeze(-1)
 
+ model = SimpleDIEN(len(all_events), len(all_products), d_model=embedding_dim).to(device)
  criterion = nn.BCEWithLogitsLoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 
  for epoch in range(epochs):
  model.train(); epoch_loss = 0
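Not part of this diff: a minimal, self-contained sketch of the single attention-gated step that AUGRUCell performs, with illustrative tensor names and sizes.

```python
import torch
import torch.nn as nn

d = 8                               # hidden/embedding size (illustrative)
x_t = torch.randn(2, d)             # GRU output for step t (batch of 2 users)
h_prev = torch.zeros(2, d)          # previous interest state
a_t = torch.tensor([[0.9], [0.1]])  # attention toward the candidate, shape (2, 1)

W_u = nn.Linear(2 * d, d)           # update gate
W_r = nn.Linear(2 * d, d)           # reset gate
W_c = nn.Linear(2 * d, d)           # candidate state

concat = torch.cat([x_t, h_prev], dim=-1)
u_t = torch.sigmoid(W_u(concat))
r_t = torch.sigmoid(W_r(concat))
h_tilde = torch.tanh(W_c(torch.cat([x_t, r_t * h_prev], dim=-1)))
u_att = a_t * u_t                   # AUGRU: attention scales the update gate
h_t = (1 - u_att) * h_prev + u_att * h_tilde
print(h_t.shape)                    # torch.Size([2, 8])
```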
 
@@ -437,10 +495,13 @@
  optimizer.zero_grad()
  outputs = model(be, bp, bm, cp)
  loss = criterion(outputs, labels)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
+ optimizer.step()
  epoch_loss += loss.item()
+ avg_loss = epoch_loss * batch_size / len(train_df)
  if (epoch+1) % max(1, epochs//5) == 0 or epoch == 0:
+ print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
 
  model.eval()
  with torch.no_grad():
 
@@ -458,7 +519,6 @@
 
  os.makedirs("outputs", exist_ok=True)
 
  torch.save({
  'model_state_dict': model.state_dict(),
  'event_vocab': event_vocab,
 
@@ -468,14 +528,14 @@
  'num_events': len(all_events),
  'num_products': len(all_products),
  'metrics': {'auc': auc, 'ap': ap, 'f1': f1, 'acc': acc}
+ }, 'outputs/dien_model.pt')
 
+ # Visualizations
  fig, ax = plt.subplots(figsize=(10,6))
  product_perf = {}
+ for idx, row in test_df.iterrows():
  prod = row['candidate_product']
  if prod not in product_perf: product_perf[prod] = {'preds': [], 'labels': []}
  product_perf[prod]['preds'].append(preds[idx])
  product_perf[prod]['labels'].append(row['label'])
  prod_aucs = []
 
@@ -492,88 +552,114 @@
  ax2.plot(x, rates, 'ro-', label='Conversion Rate')
  ax.set_xticks(x); ax.set_xticklabels(prods, rotation=45, ha='right')
  ax.set_ylabel('AUC', color='steelblue'); ax2.set_ylabel('Conversion Rate', color='red')
+ ax.set_title('DIEN - Product Recommendation Performance', fontweight='bold')
  ax.legend(loc='upper left'); ax2.legend(loc='upper right')
  plt.tight_layout()
+ fig_path1 = "outputs/dien_product_performance.png"
  plt.savefig(fig_path1, dpi=150); plt.close()
 
+ fig, ax = plt.subplots(figsize=(12,6))
  sample_idx = 0
  with torch.no_grad():
+ be_s = be[sample_idx:sample_idx+1]
+ bp_s = bp[sample_idx:sample_idx+1]
+ bm_s = bm[sample_idx:sample_idx+1]
+ cp_s = cp[sample_idx:sample_idx+1]
  B, L = be_s.size()
+ e_emb = model.event_emb(be_s)
+ p_emb = model.prod_emb(bp_s)
  beh_emb = torch.cat([e_emb, p_emb], dim=-1)
  cand_emb = model.cand_emb(cp_s)
+ gru_out, _ = model.gru(beh_emb)
+ h_t = torch.zeros(B, model.d_model, device=beh_emb.device)
+ attn_weights = []
+ interest_norms = []
+ for t in range(L):
+ gru_t = gru_out[:, t, :]
+ cand_exp = cand_emb
+ diff = cand_exp - gru_t
+ prod_feat = cand_exp * gru_t
+ attn_in = torch.cat([cand_exp, gru_t, diff, prod_feat], dim=-1)
+ attn_t = torch.sigmoid(model.attn(attn_in))
+ h_new = model.augru(gru_t, h_t, attn_t)
+ mask_t = bm_s[:, t:t+1].float()
+ h_t = mask_t * h_new + (1 - mask_t) * h_t
+ attn_weights.append(attn_t.item())
+ interest_norms.append(torch.norm(h_t).item())
  valid_len = bm_s[0].sum().item()
+ valid_attn = attn_weights[-valid_len:] if valid_len > 0 else attn_weights
+ valid_norms = interest_norms[-valid_len:] if valid_len > 0 else interest_norms
+ ax.plot(range(len(valid_attn)), valid_attn, 'o-', color='coral', linewidth=2, label='Attention Weight', markersize=6)
+ ax_twin = ax.twinx()
+ ax_twin.plot(range(len(valid_norms)), valid_norms, 's--', color='steelblue', linewidth=2, label='Interest Norm (L2)', markersize=6)
+ ax.set_xlabel('Behavior Position')
+ ax.set_ylabel('Attention Weight', color='coral')
+ ax_twin.set_ylabel('Interest Norm', color='steelblue')
+ ax.set_title('DIEN - Interest Evolution (Sample User)', fontweight='bold')
+ ax.legend(loc='upper left')
+ ax_twin.legend(loc='upper right')
+ ax.grid(True, alpha=0.3)
  plt.tight_layout()
+ fig_path2 = "outputs/dien_interest_evolution.png"
  plt.savefig(fig_path2, dpi=150); plt.close()
 
  fig, ax = plt.subplots(figsize=(8,6))
  fpr, tpr, _ = roc_curve(labels, preds)
+ ax.plot(fpr, tpr, label=f'DIEN AUC={auc:.3f}', linewidth=2, color='#2E86AB')
  ax.plot([0,1], [0,1], 'k--', alpha=0.5)
  ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate')
+ ax.set_title('ROC Curve - DIEN Product Recommendation', fontweight='bold')
  ax.legend(); ax.grid(True, alpha=0.3)
  plt.tight_layout()
+ fig_path3 = "outputs/dien_roc.png"
  plt.savefig(fig_path3, dpi=150); plt.close()
 
  fig, ax = plt.subplots(figsize=(8,6))
  prec, rec, _ = precision_recall_curve(labels, preds)
+ ax.plot(rec, prec, label=f'DIEN AP={ap:.3f}', linewidth=2, color='#A23B72')
  ax.set_xlabel('Recall'); ax.set_ylabel('Precision')
+ ax.set_title('Precision-Recall Curve - DIEN', fontweight='bold')
  ax.legend(); ax.grid(True, alpha=0.3)
  plt.tight_layout()
+ fig_path4 = "outputs/dien_pr.png"
  plt.savefig(fig_path4, dpi=150); plt.close()
 
+ result_text = f"""=== DIEN (Deep Interest Evolution Network) Product Recommendation Model ===
+ Samples: {n_users} | Products: {len(all_products)}
  Event vocab: {len(all_events)} | Product vocab: {len(all_products)}
+ Train: {len(train_df)} | Test: {len(test_df)}
+
+ --- DIEN Architecture (4 layers) ---
+ Layer 1: Embedding
+ - event_emb({len(all_events)+1} -> {embedding_dim//2}) + prod_emb({len(all_products)+1} -> {embedding_dim//2})
+ Layer 2: Interest Extractor (GRU)
+ - Input: concat(event_emb, prod_emb) -> GRU({embedding_dim} -> {embedding_dim})
+ Layer 3: Interest Evolving (AUGRU)
+ - AUGRUCell: Attention-gated recurrent unit
+ - u_t' = a_t * u_t (attention-modulated update gate)
+ Layer 4: MLP
+ - [emb*3] -> 256 -> 128 -> 1
+
+ --- Training Config ---
  Epochs: {epochs} | Batch size: {batch_size} | LR: {lr}
+ Optimizer: Adam (weight_decay=1e-5) | Gradient clip: max_norm=5.0
 
+ --- Test Results ---
  AUC: {auc:.4f}
  AP: {ap:.4f}
  F1: {f1:.4f}
  Accuracy: {acc:.4f}
 
+ --- DIEN vs DIN ---
+ 1. [GRU Interest Extractor] Models temporal dependencies in behavior sequences
+ 2. [AUGRU Interest Evolving] The attention-modulated update gate preserves only target-relevant interest evolution
+ 3. [Better cold-start] Short sequences benefit from GRU temporal modeling
 
+ --- Model File ---
+ Saved to: outputs/dien_model.pt
+ Upload to the HF Hub via the Model Management tab"""
 
  return result_text, fig_path1, fig_path2, fig_path3, fig_path4
  # =============================================================================
  # TabBERT anomaly detection
  # =============================================================================
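Not part of this diff: a short sketch of inspecting the checkpoint that train_dien_recommendation writes, assuming a training run has already produced outputs/dien_model.pt. The keys follow the torch.save call shown above.

```python
import torch

# Load the checkpoint written by train_dien_recommendation.
ckpt = torch.load("outputs/dien_model.pt", map_location="cpu")

print(ckpt["metrics"])            # {'auc': ..., 'ap': ..., 'f1': ..., 'acc': ...}
print(len(ckpt["event_vocab"]),   # vocabulary sizes stored alongside the weights
      len(ckpt["product_vocab"]))
# ckpt["model_state_dict"] holds the SimpleDIEN weights; rebuilding the network
# requires the SimpleDIEN class defined inside train_dien_recommendation above.
```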
 
@@ -830,12 +916,12 @@
  joblib.dump(artifacts['sklearn'], tmpdir / "sklearn_model.joblib")
  model_files.append("sklearn_model.joblib")
 
+ # Check for DIEN model
+ dien_path = Path("outputs/dien_model.pt")
+ if dien_path.exists():
+ artifacts['dien'] = torch.load(dien_path, map_location='cpu')
+ torch.save(artifacts['dien'], tmpdir / "dien_model.pt")
+ model_files.append("dien_model.pt")
 
  # Check for TabBERT model
 
957
  | File | Description |
958
  |------|-------------|
959
  | `sklearn_model.joblib` | GBDT + Random Forest + Scaler (sklearn) |
960
+ | `dien_model.pt` | Deep Interest Evolution Network (PyTorch) |
961
  | `tabbert_model.pt` | TabularBERT Anomaly Detection (PyTorch) |
962
  | `model_metadata.json` | Model metadata |
963
 
 
973
  artifacts = joblib.load(model_path)
974
  # artifacts['gbdt'], artifacts['rf'], artifacts['scaler']
975
 
976
+ # Load DIEN
977
+ dien_path = hf_hub_download(repo_id="{repo_id}", filename="dien_model.pt")
978
+ dien_ckpt = torch.load(dien_path)
979
+ # dien_ckpt['model_state_dict'], dien_ckpt['event_vocab'], dien_ckpt['product_vocab']
980
  ```
981
 
982
  ## Reference
 
@@ -959,16 +1045,16 @@
  plt.savefig(img_path, dpi=150); plt.close()
  images.append(img_path)
 
+ # Load DIEN
+ if "dien_model.pt" in metadata.get('files', []):
+ dien_path = hf_hub_download(repo_id=repo_id, filename="dien_model.pt", token=token, repo_type="model")
+ dien_ckpt = torch.load(dien_path, map_location='cpu')
+ metrics = dien_ckpt.get('metrics', {})
+ results.append(f"📦 DIEN model loaded")
  results.append(f" AUC: {metrics.get('auc', 'N/A')}")
+ results.append(f" Embedding dim: {dien_ckpt.get('embedding_dim', 'N/A')}")
+ results.append(f" Event vocab: {len(dien_ckpt.get('event_vocab', {}))}")
+ results.append(f" Product vocab: {len(dien_ckpt.get('product_vocab', {}))}")
 
  # Load TabBERT
  if "tabbert_model.pt" in metadata.get('files', []):
 
@@ -1435,10 +1521,10 @@
  with gr.Row():
  csv_table = gr.Dataframe(label="Feature data sample")
 
+ # ===== Tab 3: DIEN product recommendation =====
+ with gr.Tab("🎯 Product Recommendation (DIEN)"):
+ gr.Markdown("""### Deep Interest Evolution Network - Insurance Product Recommendation
+ Built on DIEN (AAAI 2019), it models how a user's interest in the candidate insurance product evolves over time, via GRU interest extraction plus AUGRU interest evolving.""")
  with gr.Row():
  with gr.Column(scale=1):
  din_users = gr.Slider(500, 5000, value=2000, step=100, label="Number of users")
 
@@ -1447,7 +1533,7 @@
  din_batch = gr.Slider(32, 512, value=128, step=32, label="Batch Size")
  din_lr = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="Learning rate")
  din_seed = gr.Number(value=42, label="Random seed", precision=0)
+ din_btn = gr.Button("🚀 Train DIEN model", variant="primary", size="lg")
  if not TORCH_AVAILABLE:
  gr.Markdown("⚠️ **PyTorch is not installed**. Add `torch>=2.0.0` to requirements.txt and restart.")
  with gr.Column(scale=2):
 
@@ -1610,7 +1696,7 @@
  info_btn.click(fn=show_csv_info, inputs=[csv_file], outputs=[csv_info, csv_preview])
  csv_train_btn.click(fn=csv_train, inputs=[csv_file, label_col_input, csv_test_size, csv_random_seed, csv_use_cv],
  outputs=[csv_result, csv_img1, csv_img2, csv_img3, csv_img4, csv_table])
+ din_btn.click(fn=train_dien_recommendation, inputs=[din_users, din_emb, din_epochs, din_batch, din_lr, din_seed],
  outputs=[din_result, din_img1, din_img2, din_img3, din_img4])
  tab_btn.click(fn=train_tabbert_anomaly, inputs=[tab_normal, tab_anomaly, tab_dmodel, tab_epochs, tab_batch, tab_lr, tab_seed],
  outputs=[tab_result, tab_img1, tab_img2, tab_img3, tab_img4])
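Not part of this diff: the click wiring above follows the standard Gradio Blocks pattern. A minimal, self-contained sketch of the same pattern with placeholder component and callback names:

```python
import gradio as gr

def train_stub(n_users, lr):
    # stand-in for a training callback such as train_dien_recommendation
    return f"trained on {int(n_users)} users with lr={lr}"

with gr.Blocks() as demo:
    users = gr.Slider(500, 5000, value=2000, step=100, label="Number of users")
    lr = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="Learning rate")
    btn = gr.Button("Train")
    out = gr.Textbox(label="Result")
    # Button.click binds the callback: `inputs` map to arguments, `outputs` receive return values.
    btn.click(fn=train_stub, inputs=[users, lr], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```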