Builder-Neekhil committed on
Commit e434e59 · verified · 1 Parent(s): 703ee9e

Upload phase3_integration.py with huggingface_hub

Files changed (1)
  1. phase3_integration.py +695 -0
phase3_integration.py ADDED
@@ -0,0 +1,695 @@
+ """
+ Phase 3: Integration — Augment Original Model with Phase 1 & Phase 2 Signals
+ =============================================================================
+ Goal: Add Gottman behavioral risk features + longitudinal survival priors
+ to the original speed dating model and measure improvement.
+
+ We create "proxy" Gottman features from the speed dating data by mapping
+ the existing personality/perception features to Gottman dimensions. This
+ is a cross-domain feature transfer approach.
+ """
+
+ import os
+ import json
+ import warnings
+ import numpy as np
+ import pandas as pd
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from datasets import load_dataset
+ from sklearn.model_selection import StratifiedKFold
+ from sklearn.metrics import (
+     roc_auc_score, accuracy_score, f1_score, classification_report,
+     precision_score, recall_score, average_precision_score,
+     brier_score_loss, precision_recall_curve, roc_curve
+ )
+ from sklearn.preprocessing import LabelEncoder
+ from xgboost import XGBClassifier
+ from lightgbm import LGBMClassifier
+ from catboost import CatBoostClassifier
+ import joblib
+ import shap
+
+ warnings.filterwarnings('ignore')
+ np.random.seed(42)
+
+ OUTPUT_DIR = "/app/phase3_output"
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+ os.makedirs(f"{OUTPUT_DIR}/figures", exist_ok=True)
+
+ # ============================================================
+ # 1. LOAD ORIGINAL MODEL BASELINE
+ # ============================================================
+ print("=" * 70)
+ print("PHASE 3: INTEGRATION — MEASURE IMPROVEMENTS")
+ print("=" * 70)
+
+ # Load original data
+ ds = load_dataset("mstz/speeddating", "dating", split="train")
+ df = ds.to_pandas()
+
+ # Load phase outputs
+ with open("/app/phase1_output/gottman_recipe.json") as f:
+     gottman_recipe = json.load(f)
+ with open("/app/phase2_output/survival_recipe.json") as f:
+     survival_recipe = json.load(f)
+ with open("/app/phase2_output/longevity_priors.json") as f:
+     longevity_priors = json.load(f)
+
+ print(f"Speed dating dataset: {df.shape}")
+ print(f"Gottman dimensions: {list(gottman_recipe['dimensions'].keys())}")
+ print(f"Survival priors: {list(longevity_priors.keys())}")
+
+ # ============================================================
+ # 2. REPRODUCE ORIGINAL FEATURES (BASELINE)
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 2: Reproducing Original Baseline Features")
+ print("=" * 70)
+
+ # Same feature engineering as original model
+ traits = ['attractiveness', 'sincerity', 'intelligence', 'humor', 'ambition']
+
+ for trait in traits:
+     dater_rates_partner = f'reported_{trait}_of_dated_from_dater'
+     partner_rates_dater = f'{trait}_score_of_dater_from_dated'
+     if dater_rates_partner in df.columns and partner_rates_dater in df.columns:
+         df[f'{trait}_perception_gap'] = df[dater_rates_partner] - df[partner_rates_dater]
+         df[f'{trait}_mutual_score'] = (df[dater_rates_partner] + df[partner_rates_dater]) / 2
+         df[f'{trait}_perception_product'] = df[dater_rates_partner] * df[partner_rates_dater]
+
+ for trait in traits:
+     importance_col = f'{trait}_importance_for_dater'
+     score_col = f'{trait}_score_of_dater_from_dated'
+     if importance_col in df.columns and score_col in df.columns:
+         df[f'{trait}_value_fulfillment_dater'] = df[importance_col] * df[score_col] / 100
+
+ for trait in traits:
+     self_col = f'self_reported_{trait}_of_dater'
+     partner_score_col = f'{trait}_score_of_dater_from_dated'
+     if self_col in df.columns and partner_score_col in df.columns:
+         df[f'{trait}_self_awareness_gap'] = df[self_col] - df[partner_score_col]
+
+ df['total_perception_gap'] = sum(df[f'{t}_perception_gap'].fillna(0) for t in traits) / len(traits)
+ df['total_mutual_score'] = sum(df[f'{t}_mutual_score'].fillna(0) for t in traits) / len(traits)
+ df['total_value_fulfillment'] = sum(df[f'{t}_value_fulfillment_dater'].fillna(0) for t in traits)
+ df['total_self_awareness_gap'] = sum(df[f'{t}_self_awareness_gap'].fillna(0) for t in traits) / len(traits)
+
+ df['expectation_meets_reality'] = df['expected_satisfaction_of_dater'] * df['dater_liked_dated']
+ df['confidence_calibration'] = (
+     df['expected_number_of_likes_of_dater_from_20_people'] / 20 -
+     df['probability_dated_wants_to_date'] / 10
+ )
+
+ df['age_gap_abs'] = df['age_difference']
+ df['age_gap_squared'] = df['age_difference'] ** 2
+ df['dater_is_older'] = (df['dater_age'] > df['dated_age']).astype(int)
+ df['combined_age'] = df['dater_age'] + df['dated_age']
+
+ interest_cols = [c for c in df.columns if c.startswith('dater_interest_in_')]
+ if interest_cols:
+     df['interest_diversity'] = df[interest_cols].std(axis=1)
+     df['interest_intensity'] = df[interest_cols].mean(axis=1)
+     df['max_interest'] = df[interest_cols].max(axis=1)
+     df['min_interest'] = df[interest_cols].min(axis=1)
+     df['interest_range'] = df['max_interest'] - df['min_interest']
+
+ importance_dater_cols = [
+     'attractiveness_importance_for_dater', 'sincerity_importance_for_dater',
+     'intelligence_importance_for_dater', 'humor_importance_for_dater',
+     'ambition_importance_for_dater', 'shared_interests_importance_for_dater'
+ ]
+ importance_dated_cols = [
+     'attractiveness_importance_for_dated', 'sincerity_importance_for_dated',
+     'intelligence_importance_for_dated', 'humor_importance_for_dated',
+     'ambition_importance_for_dated', 'shared_interests_importance_for_dated'
+ ]
+
+ df['importance_concentration_dater'] = df[importance_dater_cols].std(axis=1)
+ df['max_importance_dater'] = df[importance_dater_cols].max(axis=1)
+ df['importance_concentration_dated'] = df[importance_dated_cols].std(axis=1)
+
+ for i, (d1, d2) in enumerate(zip(importance_dater_cols, importance_dated_cols)):
+     df[f'importance_alignment_{i}'] = abs(df[d1] - df[d2])
+ df['total_importance_alignment'] = sum(
+     abs(df[d1] - df[d2]) for d1, d2 in zip(importance_dater_cols, importance_dated_cols)
+ )
+
+ # Fit the race encoder on both columns so transform() never hits an unseen label
+ le_race = LabelEncoder()
+ le_race.fit(pd.concat([df['dater_race'], df['dated_race']]).fillna('Unknown'))
+ df['dater_race_encoded'] = le_race.transform(df['dater_race'].fillna('Unknown'))
+ df['dated_race_encoded'] = le_race.transform(df['dated_race'].fillna('Unknown'))
+ df['race_match'] = (df['dater_race'] == df['dated_race']).astype(int)
+
+ df['is_dater_male_int'] = df['is_dater_male'].astype(int)
+ df['are_same_race_int'] = df['are_same_race'].astype(int)
+ df['already_met_int'] = df['already_met_before'].astype(int)
+
+ # Original feature set
+ exclude_cols = [
+     'is_match', 'dater_wants_to_date', 'dated_wants_to_date',
+     'dater_race', 'dated_race', 'already_met_before', 'is_dater_male',
+     'are_same_race', 'decision_agreement'
+ ]
+
+ original_feature_cols = [c for c in df.columns if c not in exclude_cols
+                          and c not in ['decision_agreement']
+                          and df[c].dtype in ['float64', 'int64', 'int32', 'float32']]
+
+ # Remove any new features we're about to add
+ original_feature_cols = [c for c in original_feature_cols if not c.startswith('gottman_')
+                          and not c.startswith('survival_') and not c.startswith('prior_')]
+
+ print(f"Original features: {len(original_feature_cols)}")
+
+ # ============================================================
+ # 3. ADD PHASE 1 FEATURES — GOTTMAN PROXY SCORES
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 3: Adding Gottman Proxy Features (Phase 1)")
+ print("=" * 70)
+
+ # Map speed dating features to Gottman dimensions
+ # This is cross-domain feature transfer: we use the SHAP insights from the
+ # Gottman model to create proxy scores from available speed dating features
+
+ # --- CONTEMPT PROXY ---
+ # Gottman finding: Contempt (mutual disrespect, low regard) is the #1 divorce predictor
+ # Speed dating proxy: Low mutual scores, high perception gaps (I see you as worse than you see me)
+ df['gottman_proxy_contempt'] = (
+     -df['total_mutual_score'] +                        # Low mutual regard → contempt-like
+     abs(df['total_perception_gap']) +                  # Asymmetric perception → disrespect
+     abs(df['total_self_awareness_gap']) * 0.5          # Low self-awareness → unrealistic expectations
+ )
+
+ # --- CRITICISM PROXY ---
+ # Gottman: Attacking character. Speed dating: Harsh gap between what you expect vs what you see
+ df['gottman_proxy_criticism'] = (
+     df['total_importance_alignment'] * 0.1 +           # Misaligned values = source of criticism
+     abs(df['total_perception_gap'])                    # I rate you lower than you rate me = implicit criticism
+ )
+
+ # --- DEFENSIVENESS PROXY ---
+ # Gottman: Counter-attacking, refusing to accept influence
+ # Proxy: High self-ratings vs low partner ratings (inflated self-view)
+ df['gottman_proxy_defensiveness'] = (
+     df['total_self_awareness_gap'].clip(lower=0)       # I think I'm better than you think I am
+ )
+
+ # --- STONEWALLING PROXY ---
+ # Gottman: Withdrawing, shutting down
+ # Proxy: Low expected satisfaction, low engagement (low liked score despite meeting)
+ df['gottman_proxy_stonewalling'] = (
+     (10 - df['dater_liked_dated'].fillna(5)) * 0.3 +                # Low liking = withdrawal
+     (10 - df['probability_dated_wants_to_date'].fillna(5)) * 0.2 +  # Expected rejection
+     (1 - df['interests_correlation'].fillna(0.5))                   # No shared interests = no engagement
+ )
+
+ # --- LOVE MAPS PROXY ---
+ # Gottman: Knowing partner's inner world.
+ # Proxy: Interest correlation + shared interests score + mutual perception accuracy
+ df['gottman_proxy_love_maps'] = (
+     df['interests_correlation'].fillna(0) * 2 +
+     df['shared_interests_score_of_dater_from_dated'].fillna(5) * 0.3 +
+     df['reported_shared_interests_of_dated_from_dater'].fillna(5) * 0.3 -
+     abs(df['total_perception_gap']) * 0.5              # Accurate mutual perception = knowing each other
+ )
+
+ # --- SHARED GOALS PROXY ---
+ # Proxy: Value alignment + similar importance weights
+ df['gottman_proxy_shared_goals'] = (
+     -df['total_importance_alignment'] * 0.1 +          # Similar values → shared goals
+     df['total_value_fulfillment'] * 0.5 +              # Partner meets your values → aligned
+     df['interests_correlation'].fillna(0) * 2          # Shared interests → shared life direction
+ )
+
+ # --- COMBINED GOTTMAN SCORES ---
+ # Four Horsemen combined (higher = worse)
+ df['gottman_proxy_horsemen'] = (
+     df['gottman_proxy_contempt'] +
+     df['gottman_proxy_criticism'] +
+     df['gottman_proxy_defensiveness'] +
+     df['gottman_proxy_stonewalling']
+ )
+
+ # Positive combined (higher = better)
+ df['gottman_proxy_positive'] = (
+     df['gottman_proxy_love_maps'] +
+     df['gottman_proxy_shared_goals']
+ )
+
+ # Gottman Ratio (the famous 5:1 positive to negative ratio)
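+ # (the +10 offsets below guard against zero or tiny denominators, since the proxy scores are unbounded and can be negative)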
+ df['gottman_proxy_ratio'] = (
+     (df['gottman_proxy_positive'] + 10) /
+     (df['gottman_proxy_horsemen'] + 10)
+ )
+
+ # Horsemen interactions (from Phase 1 SHAP: contempt × stonewalling was top predictor)
+ df['gottman_proxy_contempt_x_stonewalling'] = df['gottman_proxy_contempt'] * df['gottman_proxy_stonewalling']
+ df['gottman_proxy_criticism_x_defensiveness'] = df['gottman_proxy_criticism'] * df['gottman_proxy_defensiveness']
+ df['gottman_proxy_love_x_goals'] = df['gottman_proxy_love_maps'] * df['gottman_proxy_shared_goals']
+
+ # Horsemen minus Positive (net risk)
+ df['gottman_proxy_net_risk'] = df['gottman_proxy_horsemen'] - df['gottman_proxy_positive']
+
+ gottman_proxy_features = [c for c in df.columns if c.startswith('gottman_proxy_')]
+ print(f"Gottman proxy features added: {len(gottman_proxy_features)}")
+ for f in gottman_proxy_features:
+     print(f" {f}: mean={df[f].mean():.3f}, std={df[f].std():.3f}")
+
+ # ============================================================
+ # 4. ADD PHASE 2 FEATURES — SURVIVAL PRIORS
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 4: Adding Survival Prior Features (Phase 2)")
+ print("=" * 70)
+
+ # Survival priors from the Vedastro longitudinal data
+ # Key findings from Phase 2:
+ cox_hazard_ratios = survival_recipe.get('cox_summary', {})
+
+ # Age-at-relationship features (from Cox PH: age_at_marriage HR=0.96, significant)
+ # Younger couples face higher divorce risk
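+ # (a Cox hazard ratio below 1 is protective: HR=0.96 per extra year of age at marriage, HR=0.26 for first marriages)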
+ df['survival_age_risk_dater'] = np.where(
+     df['dater_age'] < 22, longevity_priors['age_at_marriage_young']['divorce_rate'],
+     np.where(df['dater_age'] < 30, longevity_priors['age_at_marriage_prime']['divorce_rate'],
+              np.where(df['dater_age'] < 40, longevity_priors['age_at_marriage_mature']['divorce_rate'],
+                       longevity_priors['age_at_marriage_late']['divorce_rate']))
+ )
+
+ # Average age risk for the couple
+ mean_age = (df['dater_age'] + df['dated_age']) / 2
+ df['survival_couple_age_risk'] = np.where(
+     mean_age < 22, longevity_priors['age_at_marriage_young']['divorce_rate'],
+     np.where(mean_age < 30, longevity_priors['age_at_marriage_prime']['divorce_rate'],
+              np.where(mean_age < 40, longevity_priors['age_at_marriage_mature']['divorce_rate'],
+                       longevity_priors['age_at_marriage_late']['divorce_rate']))
+ )
+
+ # First vs subsequent relationship risk (from Cox PH: is_first_marriage HR=0.26, huge effect)
+ # We use already_met as a weak proxy for prior relationship history
+ df['survival_prior_relationship_risk'] = np.where(
+     df['already_met_int'] == 1,
+     longevity_priors['marriage_second']['divorce_rate'],   # Already know each other → not "first"
+     longevity_priors['marriage_first']['divorce_rate']     # First meeting → first relationship proxy
+ )
+
+ # Divorce timing hazard (from Phase 2: 41% of divorces at 3-7 years, 32% at 8-14)
+ # Age gap as a risk amplifier (larger gaps → earlier divorce)
+ divorce_timing = survival_recipe['divorce_timing']
+ df['survival_early_risk'] = (
+     divorce_timing['honeymoon_crisis_0_2yr'] +
+     divorce_timing['seven_year_itch_3_7yr']
+ )  # Base rate: 54.4% of divorces happen in first 7 years
+
+ # Overall base divorce rate
+ df['survival_base_divorce_rate'] = longevity_priors['overall']['divorce_rate']
+
+ # Age gap interaction with survival (from Cox: age matters)
+ df['survival_age_gap_risk'] = (
+     df['survival_couple_age_risk'] *
+     (1 + df['age_gap_abs'] * 0.02)   # Each year of age gap increases risk by 2%
+ )
+
+ # Combined survival risk score
+ df['survival_combined_risk'] = (
+     df['survival_couple_age_risk'] * 0.4 +
+     df['survival_prior_relationship_risk'] * 0.3 +
+     df['survival_age_gap_risk'] * 0.3
+ )
+
+ survival_features = [c for c in df.columns if c.startswith('survival_')]
+ print(f"Survival prior features added: {len(survival_features)}")
+ for f in survival_features:
+     print(f" {f}: mean={df[f].mean():.4f}, std={df[f].std():.4f}")
+
+ # ============================================================
+ # 5. TRAIN ENHANCED MODEL & COMPARE
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 5: Training Enhanced Model & Comparing to Baseline")
+ print("=" * 70)
+
+ y = df['is_match'].values
+ scale_pos_weight = (y == 0).sum() / (y == 1).sum()
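+ # scale_pos_weight = negatives/positives; XGBoost and LightGBM use it to upweight the rare "match" class, while CatBoost uses auto_class_weights='Balanced' instead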
+
+ # Define feature sets
+ enhanced_feature_cols = original_feature_cols + gottman_proxy_features + survival_features
+
+ # Remove any duplicates
+ enhanced_feature_cols = list(dict.fromkeys(enhanced_feature_cols))
+
+ print(f"\nFeature comparison:")
+ print(f" Original: {len(original_feature_cols)} features")
+ print(f" + Gottman: +{len(gottman_proxy_features)} features")
+ print(f" + Survival:+{len(survival_features)} features")
+ print(f" Enhanced: {len(enhanced_feature_cols)} features")
+
+ X_original = df[original_feature_cols].fillna(df[original_feature_cols].median()).values
+ X_enhanced = df[enhanced_feature_cols].fillna(df[enhanced_feature_cols].median()).values
+
+ # Train both models with same hyperparameters
+ n_splits = 5
+ skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
+
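+ # Note: train_and_evaluate scores out-of-fold predictions; the fitted model objects it returns come from the last CV fold only (deployable models are retrained on the full data in Step 7).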
+ def train_and_evaluate(X, y, label, feature_names):
+     """Train XGB+LGB+CAT ensemble with 5-fold CV."""
+     oof_xgb = np.zeros(len(y))
+     oof_lgb = np.zeros(len(y))
+     oof_cat = np.zeros(len(y))
+
+     for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
+         X_train, X_val = X[train_idx], X[val_idx]
+         y_train, y_val = y[train_idx], y[val_idx]
+
+         # XGBoost
+         xgb = XGBClassifier(
+             n_estimators=1500, max_depth=7, learning_rate=0.03,
+             colsample_bytree=0.8, subsample=0.8, min_child_weight=3,
+             gamma=0.1, reg_alpha=0.1, reg_lambda=1.0,
+             scale_pos_weight=scale_pos_weight,
+             use_label_encoder=False, eval_metric='auc',
+             tree_method='hist', random_state=42, n_jobs=-1
+         )
+         xgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
+         oof_xgb[val_idx] = xgb.predict_proba(X_val)[:, 1]
+
+         # LightGBM
+         lgb = LGBMClassifier(
+             n_estimators=1500, max_depth=7, learning_rate=0.03,
+             colsample_bytree=0.8, subsample=0.8, min_child_samples=10,
+             reg_alpha=0.1, reg_lambda=1.0,
+             scale_pos_weight=scale_pos_weight,
+             random_state=42, n_jobs=-1, verbose=-1
+         )
+         lgb.fit(X_train, y_train, eval_set=[(X_val, y_val)])
+         oof_lgb[val_idx] = lgb.predict_proba(X_val)[:, 1]
+
+         # CatBoost
+         cat = CatBoostClassifier(
+             iterations=1500, depth=7, learning_rate=0.03,
+             l2_leaf_reg=3.0, auto_class_weights='Balanced',
+             random_seed=42, verbose=0
+         )
+         cat.fit(X_train, y_train, eval_set=(X_val, y_val))
+         oof_cat[val_idx] = cat.predict_proba(X_val)[:, 1]
+
+     # Ensemble
+     oof_ens = 0.4 * oof_xgb + 0.35 * oof_lgb + 0.25 * oof_cat
+
+     # Compute metrics
+     results = {}
+     for name, preds in [('XGBoost', oof_xgb), ('LightGBM', oof_lgb),
+                         ('CatBoost', oof_cat), ('Ensemble', oof_ens)]:
+         auc = roc_auc_score(y, preds)
+         ap = average_precision_score(y, preds)
+         brier = brier_score_loss(y, preds)
+
+         precision_curve, recall_curve, thresholds = precision_recall_curve(y, preds)
+         f1_scores = 2 * (precision_curve * recall_curve) / (precision_curve + recall_curve + 1e-10)
+         # precision_recall_curve returns one more point than thresholds; drop the final
+         # (recall=0) point so the argmax index stays within the thresholds array
+         optimal_threshold = thresholds[np.argmax(f1_scores[:-1])]
+         y_pred = (preds >= optimal_threshold).astype(int)
+
+         results[name] = {
+             'AUC-ROC': auc, 'AUC-PR': ap, 'Brier': brier,
+             'Accuracy': accuracy_score(y, y_pred),
+             'F1': f1_score(y, y_pred),
+             'Precision': precision_score(y, y_pred),
+             'Recall': recall_score(y, y_pred),
+             'Threshold': optimal_threshold
+         }
+
+     return results, oof_ens, xgb, lgb, cat
+
+ print("\nTraining ORIGINAL model (baseline)...")
+ baseline_results, baseline_preds, _, _, _ = train_and_evaluate(
+     X_original, y, "Original", original_feature_cols)
+
+ print("\nTraining ENHANCED model (+ Gottman + Survival)...")
+ enhanced_results, enhanced_preds, final_xgb, final_lgb, final_cat = train_and_evaluate(
+     X_enhanced, y, "Enhanced", enhanced_feature_cols)
+
+ # ============================================================
+ # 6. IMPROVEMENT ANALYSIS
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 6: IMPROVEMENT ANALYSIS")
+ print("=" * 70)
+
+ print("\n" + "=" * 70)
+ print(f"{'METRIC':<20} {'BASELINE':>12} {'ENHANCED':>12} {'DELTA':>12} {'% CHANGE':>12}")
+ print("=" * 70)
+
+ improvements = {}
+ for metric in ['AUC-ROC', 'AUC-PR', 'Brier', 'Accuracy', 'F1', 'Precision', 'Recall']:
+     base_val = baseline_results['Ensemble'][metric]
+     enh_val = enhanced_results['Ensemble'][metric]
+     delta = enh_val - base_val
+     pct = delta / base_val * 100 if base_val != 0 else 0
+
+     # For Brier, lower is better
+     if metric == 'Brier':
+         direction = '✅' if delta < 0 else '❌'
+     else:
+         direction = '✅' if delta > 0 else '❌' if delta < 0 else '➖'
+
+     print(f"{metric:<20} {base_val:>12.4f} {enh_val:>12.4f} {delta:>+12.4f} {pct:>+11.2f}% {direction}")
+     improvements[metric] = {'baseline': base_val, 'enhanced': enh_val, 'delta': delta, 'pct_change': pct}
+
+ # Per-model breakdown
+ print(f"\n\nPer-model AUC-ROC comparison:")
+ print(f"{'Model':<12} {'Baseline':>12} {'Enhanced':>12} {'Delta':>12}")
+ print("-" * 50)
+ for model in ['XGBoost', 'LightGBM', 'CatBoost', 'Ensemble']:
+     base = baseline_results[model]['AUC-ROC']
+     enh = enhanced_results[model]['AUC-ROC']
+     delta = enh - base
+     direction = '✅' if delta > 0 else '❌'
+     print(f"{model:<12} {base:>12.4f} {enh:>12.4f} {delta:>+12.4f} {direction}")
+
+ # ============================================================
+ # 7. TRAIN FINAL ENHANCED MODELS ON FULL DATA
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 7: Training Final Enhanced Models on Full Data")
+ print("=" * 70)
+
+ X_full = df[enhanced_feature_cols].fillna(df[enhanced_feature_cols].median())
+
+ final_xgb_full = XGBClassifier(
+     n_estimators=2000, max_depth=7, learning_rate=0.03,
+     colsample_bytree=0.8, subsample=0.8, min_child_weight=3,
+     gamma=0.1, reg_alpha=0.1, reg_lambda=1.0,
+     scale_pos_weight=scale_pos_weight,
+     use_label_encoder=False, eval_metric='auc',
+     tree_method='hist', random_state=42, n_jobs=-1
+ )
+ final_xgb_full.fit(X_full, y)
+
+ final_lgb_full = LGBMClassifier(
+     n_estimators=2000, max_depth=7, learning_rate=0.03,
+     colsample_bytree=0.8, subsample=0.8, min_child_samples=10,
+     reg_alpha=0.1, reg_lambda=1.0,
+     scale_pos_weight=scale_pos_weight,
+     random_state=42, n_jobs=-1, verbose=-1
+ )
+ final_lgb_full.fit(X_full, y)
+
+ final_cat_full = CatBoostClassifier(
+     iterations=2000, depth=7, learning_rate=0.03,
+     l2_leaf_reg=3.0, auto_class_weights='Balanced',
+     random_seed=42, verbose=0
+ )
+ final_cat_full.fit(X_full, y)
+
+ # Save enhanced models
+ joblib.dump(final_xgb_full, f"{OUTPUT_DIR}/enhanced_xgb.joblib")
+ joblib.dump(final_lgb_full, f"{OUTPUT_DIR}/enhanced_lgb.joblib")
+ final_cat_full.save_model(f"{OUTPUT_DIR}/enhanced_cat.cbm")
+ joblib.dump(enhanced_feature_cols, f"{OUTPUT_DIR}/enhanced_feature_columns.joblib")
+
+ # ============================================================
+ # 8. SHAP ANALYSIS ON ENHANCED MODEL
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 8: SHAP Analysis on Enhanced Model")
+ print("=" * 70)
+
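+ # TreeExplainer on an XGBoost classifier returns SHAP values on the log-odds (margin) scale by default; mean |SHAP| is used here only as an importance ranking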
+ explainer = shap.TreeExplainer(final_xgb_full)
+ shap_values = explainer.shap_values(X_full)
+
+ mean_shap = np.abs(shap_values).mean(axis=0)
+ shap_df = pd.DataFrame({
+     'feature': enhanced_feature_cols,
+     'mean_abs_shap': mean_shap,
+     'source': ['original' if f not in gottman_proxy_features + survival_features
+                else 'gottman' if f in gottman_proxy_features
+                else 'survival' for f in enhanced_feature_cols]
+ }).sort_values('mean_abs_shap', ascending=False)
+
+ print("\nTop 30 Features in Enhanced Model:")
+ for i, row in shap_df.head(30).iterrows():
+     marker = {'original': ' ', 'gottman': '🔴', 'survival': '🔵'}[row['source']]
+     print(f" {marker} {row['feature']:50s} SHAP={row['mean_abs_shap']:.4f} [{row['source']}]")
+
+ # New features contribution
+ new_features_shap = shap_df[shap_df['source'] != 'original']
+ print(f"\nNew features in top 30: {(shap_df.head(30)['source'] != 'original').sum()}")
+ print(f"Total SHAP from Gottman features: {shap_df[shap_df['source']=='gottman']['mean_abs_shap'].sum():.4f}")
+ print(f"Total SHAP from Survival features: {shap_df[shap_df['source']=='survival']['mean_abs_shap'].sum():.4f}")
+ print(f"Total SHAP from Original features: {shap_df[shap_df['source']=='original']['mean_abs_shap'].sum():.4f}")
+
+ shap_df.to_csv(f"{OUTPUT_DIR}/enhanced_shap_importance.csv", index=False)
+
+ # SHAP summary plot
+ fig, ax = plt.subplots(figsize=(12, 12))
+ shap.summary_plot(shap_values, X_full, feature_names=enhanced_feature_cols, max_display=30, show=False)
+ plt.tight_layout()
+ plt.savefig(f"{OUTPUT_DIR}/figures/enhanced_shap_summary.png", dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # ============================================================
+ # 9. COMPARISON VISUALIZATIONS
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("Step 9: Comparison Visualizations")
+ print("=" * 70)
+
+ # ROC curves comparison
+ fig, ax = plt.subplots(figsize=(9, 8))
+
+ fpr_base, tpr_base, _ = roc_curve(y, baseline_preds)
+ fpr_enh, tpr_enh, _ = roc_curve(y, enhanced_preds)
+
+ ax.plot(fpr_base, tpr_base, label=f'Baseline Ensemble (AUC={baseline_results["Ensemble"]["AUC-ROC"]:.4f})',
+         linewidth=2, color='#95a5a6', linestyle='--')
+ ax.plot(fpr_enh, tpr_enh, label=f'Enhanced Ensemble (AUC={enhanced_results["Ensemble"]["AUC-ROC"]:.4f})',
+         linewidth=2.5, color='#e74c3c')
+ ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
+ ax.set_xlabel('False Positive Rate', fontsize=12)
+ ax.set_ylabel('True Positive Rate', fontsize=12)
+ ax.set_title('ROC Curves: Baseline vs Enhanced Model\n(+Gottman Behavioral + Survival Priors)', fontsize=14)
+ ax.legend(fontsize=11, loc='lower right')
+ ax.grid(True, alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(f"{OUTPUT_DIR}/figures/roc_comparison.png", dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # Feature source contribution bar chart
+ fig, ax = plt.subplots(figsize=(8, 5))
+ source_shap = shap_df.groupby('source')['mean_abs_shap'].agg(['sum', 'count', 'mean'])
+ colors = {'original': '#3498db', 'gottman': '#e74c3c', 'survival': '#2ecc71'}
+ bars = ax.bar(source_shap.index, source_shap['sum'], color=[colors[s] for s in source_shap.index])
+ ax.set_ylabel('Total SHAP Importance', fontsize=12)
+ ax.set_title('Feature Source Contribution to Enhanced Model', fontsize=14)
+ for bar, (idx, row) in zip(bars, source_shap.iterrows()):
+     ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
+             f'n={int(row["count"])}', ha='center', fontsize=10)
+ plt.tight_layout()
+ plt.savefig(f"{OUTPUT_DIR}/figures/source_contribution.png", dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # Improvement metrics bar chart
+ fig, ax = plt.subplots(figsize=(10, 6))
+ metrics = ['AUC-ROC', 'AUC-PR', 'Accuracy', 'F1', 'Precision', 'Recall']
+ baseline_vals = [baseline_results['Ensemble'][m] for m in metrics]
+ enhanced_vals = [enhanced_results['Ensemble'][m] for m in metrics]
+
+ x = np.arange(len(metrics))
+ width = 0.35
+ bars1 = ax.bar(x - width/2, baseline_vals, width, label='Baseline', color='#95a5a6', alpha=0.8)
+ bars2 = ax.bar(x + width/2, enhanced_vals, width, label='Enhanced', color='#e74c3c', alpha=0.8)
+
+ ax.set_ylabel('Score', fontsize=12)
+ ax.set_title('Baseline vs Enhanced Model Metrics', fontsize=14)
+ ax.set_xticks(x)
+ ax.set_xticklabels(metrics, fontsize=10)
+ ax.legend(fontsize=11)
+ ax.set_ylim(0.4, 1.0)
+ ax.grid(True, alpha=0.3, axis='y')
+
+ # Add delta annotations
+ for i, (b, e) in enumerate(zip(baseline_vals, enhanced_vals)):
+     delta = e - b
+     if delta > 0:
+         ax.annotate(f'+{delta:.3f}', xy=(x[i] + width/2, e),
+                     xytext=(0, 5), textcoords='offset points',
+                     ha='center', fontsize=8, color='green', fontweight='bold')
+
+ plt.tight_layout()
+ plt.savefig(f"{OUTPUT_DIR}/figures/metrics_comparison.png", dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # ============================================================
+ # 10. SAVE ENHANCED CONFIG
+ # ============================================================
+
+ best_threshold = enhanced_results['Ensemble']['Threshold']
+ enhanced_config = {
+     'model_version': 'v2.0-enhanced',
+     'weights': {'xgboost': 0.4, 'lightgbm': 0.35, 'catboost': 0.25},
+     'optimal_threshold': float(best_threshold),
+     'feature_columns': enhanced_feature_cols,
+     'feature_sources': {
+         'original': [f for f in enhanced_feature_cols if f not in gottman_proxy_features + survival_features],
+         'gottman_proxy': gottman_proxy_features,
+         'survival_prior': survival_features,
+     },
+     'metrics': {
+         'auc_roc': float(enhanced_results['Ensemble']['AUC-ROC']),
+         'auc_pr': float(enhanced_results['Ensemble']['AUC-PR']),
+         'f1': float(enhanced_results['Ensemble']['F1']),
+         'accuracy': float(enhanced_results['Ensemble']['Accuracy']),
+         'brier': float(enhanced_results['Ensemble']['Brier']),
+     },
+     'improvements_over_baseline': improvements,
+     'data_sources': {
+         'primary': 'mstz/speeddating (1048 encounters)',
+         'gottman_behavioral': 'andrewmvd/divorce-prediction (170 couples, Kaggle)',
+         'survival_longitudinal': 'vedastro-org/15000-Famous-People-Marriage-Divorce-Info (14688 marriages)',
+     }
+ }
+
+ with open(f"{OUTPUT_DIR}/enhanced_config.json", "w") as f:
+     json.dump(enhanced_config, f, indent=2)
+
+ # ============================================================
+ # FINAL SUMMARY
+ # ============================================================
+ print("\n" + "=" * 70)
+ print("PHASE 3 — INTEGRATION COMPLETE: IMPROVEMENT SUMMARY")
+ print("=" * 70)
+ print(f"""
+ Model Enhancement: v1.0 (baseline) → v2.0 (enhanced)
+ =====================================================
+
+ Data Sources Added:
+ Phase 1: Gottman Behavioral Model (54 Q divorce predictors → {len(gottman_proxy_features)} proxy features)
+ Phase 2: Marriage Duration Survival (14,688 marriages → {len(survival_features)} prior features)
+
+ Feature Count: {len(original_feature_cols)} → {len(enhanced_feature_cols)} (+{len(enhanced_feature_cols) - len(original_feature_cols)} new features)
+
+ PERFORMANCE COMPARISON (5-Fold CV, Ensemble):
+ """)
+
+ print(f"{'Metric':<20} {'v1.0 Baseline':>14} {'v2.0 Enhanced':>14} {'Change':>14}")
+ print("-" * 65)
+ for metric in ['AUC-ROC', 'AUC-PR', 'Brier', 'Accuracy', 'F1', 'Precision', 'Recall']:
+     b = improvements[metric]['baseline']
+     e = improvements[metric]['enhanced']
+     d = improvements[metric]['delta']
+     print(f"{metric:<20} {b:>14.4f} {e:>14.4f} {d:>+14.4f}")
+
+ print(f"""
+ Files Saved:
+ {OUTPUT_DIR}/enhanced_xgb.joblib
+ {OUTPUT_DIR}/enhanced_lgb.joblib
+ {OUTPUT_DIR}/enhanced_cat.cbm
+ {OUTPUT_DIR}/enhanced_config.json
+ {OUTPUT_DIR}/enhanced_feature_columns.joblib
+ {OUTPUT_DIR}/enhanced_shap_importance.csv
+ {OUTPUT_DIR}/figures/*.png
+
+ DONE!
+ """)