# relationship-longevity-predictor / phase3_integration.py
# Builder-Neekhil's picture
# Upload phase3_integration.py with huggingface_hub
# e434e59 verified
"""
Phase 3: Integration — Augment Original Model with Phase 1 & Phase 2 Signals
=============================================================================
Goal: Add Gottman behavioral risk features + longitudinal survival priors
to the original speed dating model and measure improvement.
We create "proxy" Gottman features from the speed dating data by mapping
the existing personality/perception features to Gottman dimensions. This
is a cross-domain feature transfer approach.
"""
import os
import json
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
roc_auc_score, accuracy_score, f1_score, classification_report,
precision_score, recall_score, average_precision_score,
brier_score_loss, precision_recall_curve, roc_curve
)
from sklearn.preprocessing import LabelEncoder
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
import joblib
import shap
# Suppress library warnings and seed NumPy's global random generator.
warnings.filterwarnings('ignore')
np.random.seed(42)

# Every artifact of this phase is written below OUTPUT_DIR; make sure both
# the root directory and its figures/ sub-directory exist up front.
OUTPUT_DIR = "/app/phase3_output"
for _required_dir in (OUTPUT_DIR, f"{OUTPUT_DIR}/figures"):
    os.makedirs(_required_dir, exist_ok=True)
# ============================================================
# 1. LOAD ORIGINAL MODEL BASELINE
# ============================================================
print("=" * 70)
print("PHASE 3: INTEGRATION — MEASURE IMPROVEMENTS")
print("=" * 70)


def _load_json(path):
    """Return the parsed content of the JSON file at ``path``.

    The file is opened explicitly as UTF-8 so that decoding does not depend
    on the platform's locale default encoding.
    """
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Load original data
ds = load_dataset("mstz/speeddating", "dating", split="train")
df = ds.to_pandas()

# Load phase outputs (written by the Phase 1 and Phase 2 scripts)
gottman_recipe = _load_json("/app/phase1_output/gottman_recipe.json")
survival_recipe = _load_json("/app/phase2_output/survival_recipe.json")
longevity_priors = _load_json("/app/phase2_output/longevity_priors.json")

print(f"Speed dating dataset: {df.shape}")
print(f"Gottman dimensions: {list(gottman_recipe['dimensions'].keys())}")
print(f"Survival priors: {list(longevity_priors.keys())}")
# ============================================================
# 2. REPRODUCE ORIGINAL FEATURES (BASELINE)
# ============================================================
print("\n" + "=" * 70)
print("Step 2: Reproducing Original Baseline Features")
print("=" * 70)

# Same feature engineering as original model
traits = ['attractiveness', 'sincerity', 'intelligence', 'humor', 'ambition']


def _has_columns(*names):
    """Return True when every column in ``names`` is present in ``df``."""
    return all(name in df.columns for name in names)


# Pass 1: how the two ratings of each trait relate
# (difference, average and product).
for trait in traits:
    given_col = f'reported_{trait}_of_dated_from_dater'
    received_col = f'{trait}_score_of_dater_from_dated'
    if not _has_columns(given_col, received_col):
        continue
    rating_given = df[given_col]
    rating_received = df[received_col]
    df[f'{trait}_perception_gap'] = rating_given - rating_received
    df[f'{trait}_mutual_score'] = (rating_given + rating_received) / 2
    df[f'{trait}_perception_product'] = rating_given * rating_received

# Pass 2: importance the dater assigns to a trait, weighted by the score
# recorded for that trait (importance is divided by 100).
for trait in traits:
    weight_col = f'{trait}_importance_for_dater'
    received_col = f'{trait}_score_of_dater_from_dated'
    if _has_columns(weight_col, received_col):
        df[f'{trait}_value_fulfillment_dater'] = df[weight_col] * df[received_col] / 100

# Pass 3: self-reported trait level minus the score received for it.
for trait in traits:
    own_view_col = f'self_reported_{trait}_of_dater'
    received_col = f'{trait}_score_of_dater_from_dated'
    if _has_columns(own_view_col, received_col):
        df[f'{trait}_self_awareness_gap'] = df[own_view_col] - df[received_col]


def _sum_over_traits(suffix):
    """Add up ``<trait><suffix>`` over all traits, counting missing values as 0."""
    running_total = 0
    for name in traits:
        running_total = running_total + df[f'{name}{suffix}'].fillna(0)
    return running_total


# Aggregates across traits. Value fulfillment is a plain sum; the other
# three are averaged over the number of traits.
n_traits = len(traits)
df['total_perception_gap'] = _sum_over_traits('_perception_gap') / n_traits
df['total_mutual_score'] = _sum_over_traits('_mutual_score') / n_traits
df['total_value_fulfillment'] = _sum_over_traits('_value_fulfillment_dater')
df['total_self_awareness_gap'] = _sum_over_traits('_self_awareness_gap') / n_traits

# Expected satisfaction scaled by how much the dater liked the partner.
df['expectation_meets_reality'] = df['expected_satisfaction_of_dater'] * df['dater_liked_dated']

# Difference between two rescaled expectations: expected likes out of 20
# people versus the estimated probability (out of 10) of being wanted.
expected_like_share = df['expected_number_of_likes_of_dater_from_20_people'] / 20
estimated_interest_share = df['probability_dated_wants_to_date'] / 10
df['confidence_calibration'] = expected_like_share - estimated_interest_share
# Age features.
# `age_gap_abs` is used later as a risk multiplier, so it must be a
# magnitude: take the absolute value explicitly instead of relying on
# `age_difference` already being stored unsigned (a no-op if it is).
df['age_gap_abs'] = df['age_difference'].abs()
df['age_gap_squared'] = df['age_difference'] ** 2
df['dater_is_older'] = (df['dater_age'] > df['dated_age']).astype(int)
df['combined_age'] = df['dater_age'] + df['dated_age']

# Row-wise summary statistics over all `dater_interest_in_*` columns.
interest_cols = [c for c in df.columns if c.startswith('dater_interest_in_')]
if interest_cols:
    df['interest_diversity'] = df[interest_cols].std(axis=1)
    df['interest_intensity'] = df[interest_cols].mean(axis=1)
    df['max_interest'] = df[interest_cols].max(axis=1)
    df['min_interest'] = df[interest_cols].min(axis=1)
    # Kept inside the guard: it needs the max/min columns created just above.
    df['interest_range'] = df['max_interest'] - df['min_interest']
# The six attributes for which both participants state an importance weight.
_importance_attributes = (
    'attractiveness', 'sincerity', 'intelligence',
    'humor', 'ambition', 'shared_interests',
)
importance_dater_cols = [f'{name}_importance_for_dater' for name in _importance_attributes]
importance_dated_cols = [f'{name}_importance_for_dated' for name in _importance_attributes]

# Spread and peak of each participant's importance weights.
dater_weights = df[importance_dater_cols]
dated_weights = df[importance_dated_cols]
df['importance_concentration_dater'] = dater_weights.std(axis=1)
df['max_importance_dater'] = dater_weights.max(axis=1)
df['importance_concentration_dated'] = dated_weights.std(axis=1)

# Absolute difference between the two participants' weights, one column per
# attribute, followed by the sum of those differences.
alignment_cols = []
for position, (dater_col, dated_col) in enumerate(zip(importance_dater_cols, importance_dated_cols)):
    alignment_col = f'importance_alignment_{position}'
    df[alignment_col] = (df[dater_col] - df[dated_col]).abs()
    alignment_cols.append(alignment_col)
df['total_importance_alignment'] = sum(df[col] for col in alignment_cols)
# Integer-encode both race columns with ONE shared encoder.
# The encoder is fitted on the labels of both columns: fitting on
# `dater_race` alone makes `transform` raise ValueError as soon as
# `dated_race` contains a label that never occurs for a dater. When both
# columns hold the same label set the resulting codes are unchanged.
le_race = LabelEncoder()
dater_race_filled = df['dater_race'].fillna('Unknown')
dated_race_filled = df['dated_race'].fillna('Unknown')
le_race.fit(list(dater_race_filled) + list(dated_race_filled))
df['dater_race_encoded'] = le_race.transform(dater_race_filled)
df['dated_race_encoded'] = le_race.transform(dated_race_filled)

# Compared on the raw (unfilled) columns, so a missing race never matches.
df['race_match'] = (df['dater_race'] == df['dated_race']).astype(int)

# Integer copies of the boolean-like flags so they pass the numeric dtype
# filter used to build the feature list.
df['is_dater_male_int'] = df['is_dater_male'].astype(int)
df['are_same_race_int'] = df['are_same_race'].astype(int)
df['already_met_int'] = df['already_met_before'].astype(int)
# Original feature set
# Targets, decision columns and raw columns that have encoded counterparts.
exclude_cols = [
    'is_match', 'dater_wants_to_date', 'dated_wants_to_date',
    'dater_race', 'dated_race', 'already_met_before', 'is_dater_male',
    'are_same_race', 'decision_agreement'
]
_numeric_dtypes = ['float64', 'int64', 'int32', 'float32']
# Prefixes reserved for features added later in this script; they must not
# end up in the baseline feature list.
_new_feature_prefixes = ('gottman_', 'survival_', 'prior_')

original_feature_cols = []
for column in df.columns:
    if column in exclude_cols:
        continue
    if df[column].dtype not in _numeric_dtypes:
        continue
    if column.startswith(_new_feature_prefixes):
        continue
    original_feature_cols.append(column)

print(f"Original features: {len(original_feature_cols)}")
# ============================================================
# 3. ADD PHASE 1 FEATURES — GOTTMAN PROXY SCORES
# ============================================================
print("\n" + "=" * 70)
print("Step 3: Adding Gottman Proxy Features (Phase 1)")
print("=" * 70)

# Cross-domain feature transfer: each Gottman dimension is approximated by a
# hand-weighted combination of speed dating features built earlier.

# Building blocks shared by several proxies.
perception_gap_size = abs(df['total_perception_gap'])
self_awareness_gap_size = abs(df['total_self_awareness_gap'])
value_misalignment = df['total_importance_alignment']
interest_corr = df['interests_correlation']

# --- CONTEMPT PROXY ---
# Low mutual score + size of the perception gap + half the size of the
# self-awareness gap.
low_mutual_regard = -df['total_mutual_score']
df['gottman_proxy_contempt'] = (
    low_mutual_regard + perception_gap_size + self_awareness_gap_size * 0.5
)

# --- CRITICISM PROXY ---
# Down-weighted importance misalignment + size of the perception gap.
df['gottman_proxy_criticism'] = value_misalignment * 0.1 + perception_gap_size

# --- DEFENSIVENESS PROXY ---
# Only the positive part of the self-awareness gap (self-rating above the
# score received); negative gaps are clipped to 0.
df['gottman_proxy_defensiveness'] = df['total_self_awareness_gap'].clip(lower=0)

# --- STONEWALLING PROXY ---
# Low liking, low estimated chance of being wanted and low interest
# correlation; missing inputs are replaced by 5, 5 and 0.5 respectively.
low_liking = (10 - df['dater_liked_dated'].fillna(5)) * 0.3
expected_rejection = (10 - df['probability_dated_wants_to_date'].fillna(5)) * 0.2
no_common_ground = 1 - interest_corr.fillna(0.5)
df['gottman_proxy_stonewalling'] = low_liking + expected_rejection + no_common_ground

# --- LOVE MAPS PROXY ---
# Interest correlation and both shared-interest ratings, reduced by the
# size of the perception gap.
shared_interest_signal = interest_corr.fillna(0) * 2
df['gottman_proxy_love_maps'] = (
    shared_interest_signal
    + df['shared_interests_score_of_dater_from_dated'].fillna(5) * 0.3
    + df['reported_shared_interests_of_dated_from_dater'].fillna(5) * 0.3
    - perception_gap_size * 0.5
)

# --- SHARED GOALS PROXY ---
# Penalty for importance misalignment, plus value fulfillment and the same
# interest-correlation term used for love maps.
df['gottman_proxy_shared_goals'] = (
    -value_misalignment * 0.1
    + df['total_value_fulfillment'] * 0.5
    + shared_interest_signal
)
# --- COMBINED GOTTMAN SCORES ---
# Sum of the four negative proxies (higher = worse).
negative_total = (
    df['gottman_proxy_contempt']
    + df['gottman_proxy_criticism']
    + df['gottman_proxy_defensiveness']
    + df['gottman_proxy_stonewalling']
)
df['gottman_proxy_horsemen'] = negative_total

# Sum of the two positive proxies (higher = better).
positive_total = df['gottman_proxy_love_maps'] + df['gottman_proxy_shared_goals']
df['gottman_proxy_positive'] = positive_total

# Positive-to-negative ratio; both sides are shifted by +10 before dividing.
# NOTE(review): the denominator is not guarded against reaching 0 — confirm
# the value range of the inputs makes that impossible.
df['gottman_proxy_ratio'] = (positive_total + 10) / (negative_total + 10)

# Pairwise interaction terms between individual proxies.
_interaction_pairs = (
    ('contempt_x_stonewalling', 'contempt', 'stonewalling'),
    ('criticism_x_defensiveness', 'criticism', 'defensiveness'),
    ('love_x_goals', 'love_maps', 'shared_goals'),
)
for interaction, left, right in _interaction_pairs:
    df[f'gottman_proxy_{interaction}'] = (
        df[f'gottman_proxy_{left}'] * df[f'gottman_proxy_{right}']
    )

# Net risk: negative total minus positive total.
df['gottman_proxy_net_risk'] = df['gottman_proxy_horsemen'] - df['gottman_proxy_positive']

# Collect everything created in this section, in column order, and report it.
gottman_proxy_features = [c for c in df.columns if c.startswith('gottman_proxy_')]
print(f"Gottman proxy features added: {len(gottman_proxy_features)}")
for feature in gottman_proxy_features:
    print(f" {feature}: mean={df[feature].mean():.3f}, std={df[feature].std():.3f}")
# ============================================================
# 4. ADD PHASE 2 FEATURES — SURVIVAL PRIORS
# ============================================================
print("\n" + "=" * 70)
print("Step 4: Adding Survival Prior Features (Phase 2)")
print("=" * 70)

# Survival priors produced by Phase 2.
# (Loaded for reference; not used by the features below.)
cox_hazard_ratios = survival_recipe.get('cox_summary', {})


def _age_bucket_divorce_rate(age):
    """Map each age to the divorce-rate prior of its age-at-marriage bucket.

    Buckets: < 22 'young', < 30 'prime', < 40 'mature', otherwise 'late'.

    Args:
        age: pandas Series of ages (may contain missing values).

    Returns:
        NumPy array aligned with ``age``. Missing ages stay NaN instead of
        falling through every ``<`` comparison into the 'late' bucket; the
        median imputation applied when the feature matrices are built
        fills them afterwards.
    """
    bucket_rate = np.where(
        age < 22, longevity_priors['age_at_marriage_young']['divorce_rate'],
        np.where(
            age < 30, longevity_priors['age_at_marriage_prime']['divorce_rate'],
            np.where(
                age < 40, longevity_priors['age_at_marriage_mature']['divorce_rate'],
                longevity_priors['age_at_marriage_late']['divorce_rate'],
            ),
        ),
    )
    return np.where(age.isna(), np.nan, bucket_rate)


# Age-bucket risk for the dater alone ...
df['survival_age_risk_dater'] = _age_bucket_divorce_rate(df['dater_age'])

# ... and for the couple, based on the mean of the two ages.
mean_age = (df['dater_age'] + df['dated_age']) / 2
df['survival_couple_age_risk'] = _age_bucket_divorce_rate(mean_age)

# First vs subsequent relationship risk.
# `already_met_int` is used as a weak proxy for prior relationship history:
# pairs who already met get the 'marriage_second' rate, all others the
# 'marriage_first' rate.
df['survival_prior_relationship_risk'] = np.where(
    df['already_met_int'] == 1,
    longevity_priors['marriage_second']['divorce_rate'],
    longevity_priors['marriage_first']['divorce_rate']
)

# Share of divorces in the first two timing bands (0-2 and 3-7 years).
# NOTE: this is a single scalar from the recipe, so the column is constant
# across rows.
divorce_timing = survival_recipe['divorce_timing']
df['survival_early_risk'] = (
    divorce_timing['honeymoon_crisis_0_2yr'] +
    divorce_timing['seven_year_itch_3_7yr']
)

# Overall base divorce rate (also constant across rows).
df['survival_base_divorce_rate'] = longevity_priors['overall']['divorce_rate']

# Couple age risk amplified by the age gap: +2% per unit of `age_gap_abs`.
df['survival_age_gap_risk'] = (
    df['survival_couple_age_risk'] *
    (1 + df['age_gap_abs'] * 0.02)
)

# Combined survival risk score (fixed 0.4 / 0.3 / 0.3 weights).
df['survival_combined_risk'] = (
    df['survival_couple_age_risk'] * 0.4 +
    df['survival_prior_relationship_risk'] * 0.3 +
    df['survival_age_gap_risk'] * 0.3
)

survival_features = [c for c in df.columns if c.startswith('survival_')]
print(f"Survival prior features added: {len(survival_features)}")
for feature in survival_features:
    print(f" {feature}: mean={df[feature].mean():.4f}, std={df[feature].std():.4f}")
# ============================================================
# 5. TRAIN ENHANCED MODEL & COMPARE
# ============================================================
print("\n" + "=" * 70)
print("Step 5: Training Enhanced Model & Comparing to Baseline")
print("=" * 70)

# Target vector and the negative-to-positive class ratio passed to the
# boosters as `scale_pos_weight`.
y = df['is_match'].values
n_negative = (y == 0).sum()
n_positive = (y == 1).sum()
scale_pos_weight = n_negative / n_positive

# Define feature sets: baseline columns first, then the new groups, with
# duplicates dropped while keeping first-seen order.
enhanced_feature_cols = list(dict.fromkeys(
    original_feature_cols + gottman_proxy_features + survival_features
))

print(f"\nFeature comparison:")
print(f" Original: {len(original_feature_cols)} features")
print(f" + Gottman: +{len(gottman_proxy_features)} features")
print(f" + Survival:+{len(survival_features)} features")
print(f" Enhanced: {len(enhanced_feature_cols)} features")


def _median_imputed_matrix(columns):
    """Return ``df[columns]`` as an array with NaNs replaced by column medians."""
    frame = df[columns]
    return frame.fillna(frame.median()).values


X_original = _median_imputed_matrix(original_feature_cols)
X_enhanced = _median_imputed_matrix(enhanced_feature_cols)

# Both feature sets are evaluated on the same stratified, seeded folds.
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
def _score_oof_predictions(y, preds):
    """Compute threshold-free and F1-optimal-threshold metrics for ``preds``."""
    precision_curve, recall_curve, thresholds = precision_recall_curve(y, preds)
    f1_scores = 2 * (precision_curve * recall_curve) / (precision_curve + recall_curve + 1e-10)
    # precision/recall carry one trailing point that has no matching entry in
    # `thresholds`; restrict the search to points that do, so the index can
    # never run past the end of `thresholds`.
    optimal_threshold = thresholds[np.argmax(f1_scores[:-1])]
    y_pred = (preds >= optimal_threshold).astype(int)
    return {
        'AUC-ROC': roc_auc_score(y, preds),
        'AUC-PR': average_precision_score(y, preds),
        'Brier': brier_score_loss(y, preds),
        'Accuracy': accuracy_score(y, y_pred),
        'F1': f1_score(y, y_pred),
        'Precision': precision_score(y, y_pred),
        'Recall': recall_score(y, y_pred),
        'Threshold': optimal_threshold
    }


def train_and_evaluate(X, y, label, feature_names):
    """Train XGB+LGB+CAT ensemble with 5-fold CV.

    Folds come from the module-level ``skf`` splitter and the class weight
    from the module-level ``scale_pos_weight``.

    Args:
        X: 2-D feature array, indexed by row position.
        y: 1-D array of binary targets aligned with ``X``.
        label: Name of the feature set. Accepted for call-site readability;
            not used inside the function.
        feature_names: Column names of ``X``. Not used inside the function.

    Returns:
        Tuple ``(results, oof_ens, xgb, lgb, cat)`` where ``results`` maps
        'XGBoost', 'LightGBM', 'CatBoost' and 'Ensemble' to a metrics dict,
        ``oof_ens`` holds the weighted out-of-fold ensemble probabilities,
        and the three models are the ones fitted on the LAST fold only.
    """
    n_samples = len(y)
    oof_xgb = np.zeros(n_samples)
    oof_lgb = np.zeros(n_samples)
    oof_cat = np.zeros(n_samples)

    for train_idx, val_idx in skf.split(X, y):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # XGBoost
        xgb = XGBClassifier(
            n_estimators=1500, max_depth=7, learning_rate=0.03,
            colsample_bytree=0.8, subsample=0.8, min_child_weight=3,
            gamma=0.1, reg_alpha=0.1, reg_lambda=1.0,
            scale_pos_weight=scale_pos_weight,
            use_label_encoder=False, eval_metric='auc',
            tree_method='hist', random_state=42, n_jobs=-1
        )
        xgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
        oof_xgb[val_idx] = xgb.predict_proba(X_val)[:, 1]

        # LightGBM
        # NOTE(review): LightGBM's docs tie row subsampling (`subsample`) to
        # `subsample_freq` > 0 — confirm whether bagging was intended here.
        lgb = LGBMClassifier(
            n_estimators=1500, max_depth=7, learning_rate=0.03,
            colsample_bytree=0.8, subsample=0.8, min_child_samples=10,
            reg_alpha=0.1, reg_lambda=1.0,
            scale_pos_weight=scale_pos_weight,
            random_state=42, n_jobs=-1, verbose=-1
        )
        lgb.fit(X_train, y_train, eval_set=[(X_val, y_val)])
        oof_lgb[val_idx] = lgb.predict_proba(X_val)[:, 1]

        # CatBoost
        cat = CatBoostClassifier(
            iterations=1500, depth=7, learning_rate=0.03,
            l2_leaf_reg=3.0, auto_class_weights='Balanced',
            random_seed=42, verbose=0
        )
        cat.fit(X_train, y_train, eval_set=(X_val, y_val))
        oof_cat[val_idx] = cat.predict_proba(X_val)[:, 1]

    # Ensemble: fixed-weight blend of the three out-of-fold predictions.
    oof_ens = 0.4 * oof_xgb + 0.35 * oof_lgb + 0.25 * oof_cat

    # Compute metrics for every model and for the blend.
    results = {
        name: _score_oof_predictions(y, preds)
        for name, preds in [('XGBoost', oof_xgb), ('LightGBM', oof_lgb),
                            ('CatBoost', oof_cat), ('Ensemble', oof_ens)]
    }
    return results, oof_ens, xgb, lgb, cat
# Evaluate the baseline feature set; the fitted fold models are discarded.
print("\nTraining ORIGINAL model (baseline)...")
baseline_results, baseline_preds, *_ = train_and_evaluate(
    X=X_original, y=y, label="Original", feature_names=original_feature_cols,
)

# Evaluate the enhanced feature set and keep the models it returns.
print("\nTraining ENHANCED model (+ Gottman + Survival)...")
(enhanced_results, enhanced_preds,
 final_xgb, final_lgb, final_cat) = train_and_evaluate(
    X=X_enhanced, y=y, label="Enhanced", feature_names=enhanced_feature_cols,
)
# ============================================================
# 6. IMPROVEMENT ANALYSIS
# ============================================================
print("\n" + "=" * 70)
print("Step 6: IMPROVEMENT ANALYSIS")
print("=" * 70)


def _direction_marker(delta, lower_is_better=False):
    """Return the marker for a metric change: improved, regressed or tie.

    A delta that is neither positive nor negative (zero or NaN) is always
    reported as a tie, for every metric and in both tables.
    """
    if delta > 0:
        return '❌' if lower_is_better else '✅'
    if delta < 0:
        return '✅' if lower_is_better else '❌'
    return '➖'


print("\n" + "=" * 70)
print(f"{'METRIC':<20} {'BASELINE':>12} {'ENHANCED':>12} {'DELTA':>12} {'% CHANGE':>12}")
print("=" * 70)
improvements = {}
for metric in ['AUC-ROC', 'AUC-PR', 'Brier', 'Accuracy', 'F1', 'Precision', 'Recall']:
    base_val = baseline_results['Ensemble'][metric]
    enh_val = enhanced_results['Ensemble'][metric]
    delta = enh_val - base_val
    # Relative change; reported as 0 when the baseline itself is 0.
    pct = delta / base_val * 100 if base_val != 0 else 0
    # For Brier, lower is better
    direction = _direction_marker(delta, lower_is_better=(metric == 'Brier'))
    print(f"{metric:<20} {base_val:>12.4f} {enh_val:>12.4f} {delta:>+12.4f} {pct:>+11.2f}% {direction}")
    improvements[metric] = {'baseline': base_val, 'enhanced': enh_val, 'delta': delta, 'pct_change': pct}

# Per-model breakdown
print(f"\n\nPer-model AUC-ROC comparison:")
print(f"{'Model':<12} {'Baseline':>12} {'Enhanced':>12} {'Delta':>12}")
print("-" * 50)
for model in ['XGBoost', 'LightGBM', 'CatBoost', 'Ensemble']:
    base = baseline_results[model]['AUC-ROC']
    enh = enhanced_results[model]['AUC-ROC']
    delta = enh - base
    direction = _direction_marker(delta)
    print(f"{model:<12} {base:>12.4f} {enh:>12.4f} {delta:>+12.4f} {direction}")
# ============================================================
# 7. TRAIN FINAL ENHANCED MODELS ON FULL DATA
# ============================================================
print("\n" + "=" * 70)
print("Step 7: Training Final Enhanced Models on Full Data")
print("=" * 70)

# Full enhanced feature frame with NaNs replaced by column medians.
X_full = df[enhanced_feature_cols].fillna(df[enhanced_feature_cols].median())

# Same hyperparameters as in cross-validation, but with 2000 boosting rounds.
_xgb_full_params = dict(
    n_estimators=2000, max_depth=7, learning_rate=0.03,
    colsample_bytree=0.8, subsample=0.8, min_child_weight=3,
    gamma=0.1, reg_alpha=0.1, reg_lambda=1.0,
    scale_pos_weight=scale_pos_weight,
    use_label_encoder=False, eval_metric='auc',
    tree_method='hist', random_state=42, n_jobs=-1,
)
_lgb_full_params = dict(
    n_estimators=2000, max_depth=7, learning_rate=0.03,
    colsample_bytree=0.8, subsample=0.8, min_child_samples=10,
    reg_alpha=0.1, reg_lambda=1.0,
    scale_pos_weight=scale_pos_weight,
    random_state=42, n_jobs=-1, verbose=-1,
)
_cat_full_params = dict(
    iterations=2000, depth=7, learning_rate=0.03,
    l2_leaf_reg=3.0, auto_class_weights='Balanced',
    random_seed=42, verbose=0,
)

final_xgb_full = XGBClassifier(**_xgb_full_params)
final_xgb_full.fit(X_full, y)

final_lgb_full = LGBMClassifier(**_lgb_full_params)
final_lgb_full.fit(X_full, y)

final_cat_full = CatBoostClassifier(**_cat_full_params)
final_cat_full.fit(X_full, y)

# Save enhanced models, plus the column order they were trained on.
joblib.dump(final_xgb_full, f"{OUTPUT_DIR}/enhanced_xgb.joblib")
joblib.dump(final_lgb_full, f"{OUTPUT_DIR}/enhanced_lgb.joblib")
final_cat_full.save_model(f"{OUTPUT_DIR}/enhanced_cat.cbm")
joblib.dump(enhanced_feature_cols, f"{OUTPUT_DIR}/enhanced_feature_columns.joblib")
# ============================================================
# 8. SHAP ANALYSIS ON ENHANCED MODEL
# ============================================================
print("\n" + "=" * 70)
print("Step 8: SHAP Analysis on Enhanced Model")
print("=" * 70)

# SHAP values of the full-data XGBoost model; importance of a feature is its
# mean absolute SHAP value over all rows.
explainer = shap.TreeExplainer(final_xgb_full)
shap_values = explainer.shap_values(X_full)
mean_shap = np.abs(shap_values).mean(axis=0)


def _feature_source(name):
    """Return which group a feature belongs to: gottman, survival or original."""
    if name in gottman_proxy_features:
        return 'gottman'
    if name in survival_features:
        return 'survival'
    return 'original'


shap_df = pd.DataFrame({
    'feature': enhanced_feature_cols,
    'mean_abs_shap': mean_shap,
    'source': [_feature_source(name) for name in enhanced_feature_cols],
}).sort_values('mean_abs_shap', ascending=False)

_source_markers = {'original': ' ', 'gottman': '🔴', 'survival': '🔵'}
top_features = shap_df.head(30)
print("\nTop 30 Features in Enhanced Model:")
for _, row in top_features.iterrows():
    marker = _source_markers[row['source']]
    print(f" {marker} {row['feature']:50s} SHAP={row['mean_abs_shap']:.4f} [{row['source']}]")

# New features contribution
is_new_feature = shap_df['source'] != 'original'
new_features_shap = shap_df[is_new_feature]
print(f"\nNew features in top 30: {len(top_features[top_features['source'] != 'original'])}")
for title, source_key in (('Gottman', 'gottman'), ('Survival', 'survival'), ('Original', 'original')):
    source_total = shap_df.loc[shap_df['source'] == source_key, 'mean_abs_shap'].sum()
    print(f"Total SHAP from {title} features: {source_total:.4f}")
shap_df.to_csv(f"{OUTPUT_DIR}/enhanced_shap_importance.csv", index=False)

# SHAP summary plot
fig, ax = plt.subplots(figsize=(12, 12))
shap.summary_plot(
    shap_values, X_full,
    feature_names=enhanced_feature_cols, max_display=30, show=False,
)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/figures/enhanced_shap_summary.png", dpi=150, bbox_inches='tight')
plt.close()
# ============================================================
# 9. COMPARISON VISUALIZATIONS
# ============================================================
print("\n" + "=" * 70)
print("Step 9: Comparison Visualizations")
print("=" * 70)

# ROC curves comparison: out-of-fold ensemble predictions of both runs.
fpr_base, tpr_base, _ = roc_curve(y, baseline_preds)
fpr_enh, tpr_enh, _ = roc_curve(y, enhanced_preds)
baseline_auc = baseline_results["Ensemble"]["AUC-ROC"]
enhanced_auc = enhanced_results["Ensemble"]["AUC-ROC"]

fig, ax = plt.subplots(figsize=(9, 8))
ax.plot(
    fpr_base, tpr_base,
    linestyle='--', color='#95a5a6', linewidth=2,
    label=f'Baseline Ensemble (AUC={baseline_auc:.4f})',
)
ax.plot(
    fpr_enh, tpr_enh,
    color='#e74c3c', linewidth=2.5,
    label=f'Enhanced Ensemble (AUC={enhanced_auc:.4f})',
)
# Diagonal reference line.
ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)

ax.set_title('ROC Curves: Baseline vs Enhanced Model\n(+Gottman Behavioral + Survival Priors)', fontsize=14)
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.legend(fontsize=11, loc='lower right')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/figures/roc_comparison.png", dpi=150, bbox_inches='tight')
plt.close()
# Feature source contribution bar chart: total mean-|SHAP| per feature
# group, each bar annotated with the number of features in the group.
source_shap = shap_df.groupby('source')['mean_abs_shap'].agg(['sum', 'count', 'mean'])
colors = {'original': '#3498db', 'gottman': '#e74c3c', 'survival': '#2ecc71'}
bar_colors = [colors[source_name] for source_name in source_shap.index]

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(source_shap.index, source_shap['sum'], color=bar_colors)
ax.set_ylabel('Total SHAP Importance', fontsize=12)
ax.set_title('Feature Source Contribution to Enhanced Model', fontsize=14)
for bar, feature_count in zip(bars, source_shap['count']):
    label_x = bar.get_x() + bar.get_width()/2.
    label_y = bar.get_height() + 0.01
    ax.text(label_x, label_y, f'n={int(feature_count)}', ha='center', fontsize=10)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/figures/source_contribution.png", dpi=150, bbox_inches='tight')
plt.close()
# Improvement metrics bar chart: grouped bars, baseline next to enhanced.
metrics = ['AUC-ROC', 'AUC-PR', 'Accuracy', 'F1', 'Precision', 'Recall']
baseline_vals = [baseline_results['Ensemble'][m] for m in metrics]
enhanced_vals = [enhanced_results['Ensemble'][m] for m in metrics]
x = np.arange(len(metrics))
width = 0.35
half_width = width/2

fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - half_width, baseline_vals, width, label='Baseline', color='#95a5a6', alpha=0.8)
bars2 = ax.bar(x + half_width, enhanced_vals, width, label='Enhanced', color='#e74c3c', alpha=0.8)

ax.set_title('Baseline vs Enhanced Model Metrics', fontsize=14)
ax.set_ylabel('Score', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=10)
ax.set_ylim(0.4, 1.0)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

# Add delta annotations (only metrics that increased are labelled).
for position, baseline_score, enhanced_score in zip(x, baseline_vals, enhanced_vals):
    delta = enhanced_score - baseline_score
    if not delta > 0:
        continue
    ax.annotate(
        f'+{delta:.3f}',
        xy=(position + half_width, enhanced_score),
        xytext=(0, 5), textcoords='offset points',
        ha='center', fontsize=8, color='green', fontweight='bold',
    )

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/figures/metrics_comparison.png", dpi=150, bbox_inches='tight')
plt.close()
# ============================================================
# 10. SAVE ENHANCED CONFIG
# ============================================================
ensemble_metrics = enhanced_results['Ensemble']
best_threshold = ensemble_metrics['Threshold']

# Features that belong to neither new group are the original ones.
_added_features = set(gottman_proxy_features) | set(survival_features)
_original_in_enhanced = [name for name in enhanced_feature_cols if name not in _added_features]

# (config key, results key) pairs, in the order they are written out.
_exported_metrics = (
    ('auc_roc', 'AUC-ROC'),
    ('auc_pr', 'AUC-PR'),
    ('f1', 'F1'),
    ('accuracy', 'Accuracy'),
    ('brier', 'Brier'),
)

enhanced_config = {
    'model_version': 'v2.0-enhanced',
    'weights': {'xgboost': 0.4, 'lightgbm': 0.35, 'catboost': 0.25},
    'optimal_threshold': float(best_threshold),
    'feature_columns': enhanced_feature_cols,
    'feature_sources': {
        'original': _original_in_enhanced,
        'gottman_proxy': gottman_proxy_features,
        'survival_prior': survival_features,
    },
    'metrics': {
        config_key: float(ensemble_metrics[result_key])
        for config_key, result_key in _exported_metrics
    },
    'improvements_over_baseline': improvements,
    'data_sources': {
        'primary': 'mstz/speeddating (1048 encounters)',
        'gottman_behavioral': 'andrewmvd/divorce-prediction (170 couples, Kaggle)',
        'survival_longitudinal': 'vedastro-org/15000-Famous-People-Marriage-Divorce-Info (14688 marriages)',
    }
}

with open(f"{OUTPUT_DIR}/enhanced_config.json", "w") as config_file:
    json.dump(enhanced_config, config_file, indent=2)
# ============================================================
# FINAL SUMMARY
# ============================================================
print("\n" + "=" * 70)
print("PHASE 3 — INTEGRATION COMPLETE: IMPROVEMENT SUMMARY")
print("=" * 70)
# The "Feature Count" line separates the two counts with " → "; without a
# separator they run together into a single number.
print(f"""
Model Enhancement: v1.0 (baseline) → v2.0 (enhanced)
=====================================================
Data Sources Added:
Phase 1: Gottman Behavioral Model (54 Q divorce predictors → {len(gottman_proxy_features)} proxy features)
Phase 2: Marriage Duration Survival (14,688 marriages → {len(survival_features)} prior features)
Feature Count: {len(original_feature_cols)} → {len(enhanced_feature_cols)} (+{len(enhanced_feature_cols) - len(original_feature_cols)} new features)
PERFORMANCE COMPARISON (5-Fold CV, Ensemble):
""")

# Same ensemble metrics as in Step 6, read back from `improvements`.
print(f"{'Metric':<20} {'v1.0 Baseline':>14} {'v2.0 Enhanced':>14} {'Change':>14}")
print("-" * 65)
for metric in ['AUC-ROC', 'AUC-PR', 'Brier', 'Accuracy', 'F1', 'Precision', 'Recall']:
    b = improvements[metric]['baseline']
    e = improvements[metric]['enhanced']
    d = improvements[metric]['delta']
    print(f"{metric:<20} {b:>14.4f} {e:>14.4f} {d:>+14.4f}")

print(f"""
Files Saved:
{OUTPUT_DIR}/enhanced_xgb.joblib
{OUTPUT_DIR}/enhanced_lgb.joblib
{OUTPUT_DIR}/enhanced_cat.cbm
{OUTPUT_DIR}/enhanced_config.json
{OUTPUT_DIR}/enhanced_feature_columns.joblib
{OUTPUT_DIR}/enhanced_shap_importance.csv
{OUTPUT_DIR}/figures/*.png
DONE!
""")