# Hosting-page header captured with the file (not part of the script):
# Sairamr46's picture
# Upload predict.py with huggingface_hub
# f359029 verified
"""Inference script for price increase churn prediction."""
import pandas as pd
import numpy as np
import joblib
from features import SubscriptionFeatureEngineer
def predict_churn_risk(customer_data: pd.DataFrame,
                       model_path: str = 'price_increase_churn_model.pkl',
                       price_increase_pct: float = 0.15,
                       *,
                       high_risk_threshold: float = 0.7,
                       medium_risk_threshold: float = 0.4,
                       cancel_threshold: float = 0.5) -> pd.DataFrame:
    """
    Predict cancellation risk for customers after a price increase.

    Args:
        customer_data: DataFrame with columns matching your subscription data.
            A copy is handed to the feature engineer, so the caller's frame
            is not touched by this function.
        model_path: Path to saved model (read with ``joblib.load``).
        price_increase_pct: Percentage price increase (e.g., 0.15 for 15%).
        high_risk_threshold: Probabilities at or above this value are
            labelled 'High Risk'.
        medium_risk_threshold: Probabilities at or above this value (and
            below ``high_risk_threshold``) are labelled 'Medium Risk';
            everything else is 'Low Risk'.
        cancel_threshold: Probabilities at or above this value set
            ``predicted_cancel_90d`` to True.

    Returns:
        DataFrame with columns ``customer_id``, ``churn_probability``,
        ``risk_tier`` and ``predicted_cancel_90d``, sorted by descending
        churn probability. ``customer_id`` comes from the 'Customer ID'
        column when present, otherwise from the input index. A zero-row
        input yields a zero-row frame with the same columns and the model
        is not loaded.

    Raises:
        ValueError: If ``medium_risk_threshold`` exceeds
            ``high_risk_threshold``.
    """
    if medium_risk_threshold > high_risk_threshold:
        raise ValueError(
            "medium_risk_threshold must not exceed high_risk_threshold "
            f"(got {medium_risk_threshold} > {high_risk_threshold})"
        )

    if len(customer_data) == 0:
        # Nothing to score: skip feature engineering and model loading and
        # fall through to build an empty, correctly-shaped result.
        churn_prob = np.empty(0, dtype=float)
    else:
        # Engineer features on a copy so the caller's frame stays intact.
        engineer = SubscriptionFeatureEngineer(price_increase_pct=price_increase_pct)
        features = engineer.transform(customer_data.copy())

        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only pass model paths that come from a trusted source.
        pipeline = joblib.load(model_path)

        # Column 1 of predict_proba is used as the churn probability
        # (presumably the positive class -- confirm against the training labels).
        churn_prob = pipeline.predict_proba(features)[:, 1]

    # Vectorized tiering; conditions are evaluated in order, so the first
    # matching threshold wins and NaN probabilities fall to 'Low Risk'.
    risk_tier = np.select(
        [churn_prob >= high_risk_threshold, churn_prob >= medium_risk_threshold],
        ['High Risk', 'Medium Risk'],
        default='Low Risk',
    )

    results = pd.DataFrame({
        'customer_id': customer_data.get('Customer ID', customer_data.index),
        'churn_probability': churn_prob,
        'risk_tier': risk_tier,
        'predicted_cancel_90d': churn_prob >= cancel_threshold,
    })
    return results.sort_values('churn_probability', ascending=False)
def batch_predict_example(
        model_path: str = '/app/price_increase_model/price_increase_churn_model.pkl',
        price_increase_pct: float = 0.15) -> pd.DataFrame:
    """Example usage with sample data.

    Builds a three-customer sample frame, scores it with
    ``predict_churn_risk`` and prints the resulting table.

    Args:
        model_path: Path to the saved model. Defaults to the original
            deployment location; pass a different path to run the example
            elsewhere.
        price_increase_pct: Percentage price increase forwarded to
            ``predict_churn_risk`` (e.g., 0.15 for 15%).

    Returns:
        The predictions DataFrame returned by ``predict_churn_risk``.
    """
    # Three illustrative customers: a new month-to-month subscriber, a
    # mid-tenure one-year contract and a long-tenure two-year contract.
    sample = pd.DataFrame({
        'Customer ID': ['C001', 'C002', 'C003'],
        'Tenure in Months': [3, 36, 72],
        'Contract': ['Month-to-Month', 'One Year', 'Two Year'],
        'Monthly Charge': [95.0, 65.0, 45.0],
        'Total Charges': [285.0, 2340.0, 3240.0],
        'Total Revenue': [300.0, 2500.0, 3500.0],
        'CLTV': [2000, 4500, 6000],
        'Age': [25, 45, 55],
        'Satisfaction Score': [1, 3, 5],
        'Churn Score': [85, 50, 20],
        'Senior Citizen': [0, 0, 1],
        'Dependents': [0, 1, 0],
        'Married': [0, 1, 1],
        'Partner': [0, 1, 1],
        'Phone Service': [1, 1, 1],
        'Internet Service': [1, 1, 0],
        'Online Security': [0, 1, 1],
        'Online Backup': [0, 1, 0],
        'Device Protection Plan': [0, 1, 1],
        'Premium Tech Support': [0, 0, 1],
        'Streaming TV': [1, 0, 0],
        'Streaming Movies': [1, 0, 0],
        'Streaming Music': [0, 0, 0],
        'Multiple Lines': [1, 0, 0],
        'Unlimited Data': [1, 0, 0],
        'Number of Dependents': [0, 2, 0],
        'Number of Referrals': [0, 2, 5],
        'Avg Monthly GB Download': [50, 20, 5],
        'Avg Monthly Long Distance Charges': [30.0, 10.0, 5.0],
        'Total Long Distance Charges': [90.0, 360.0, 360.0],
        'Total Extra Data Charges': [10, 0, 0],
        'Total Refunds': [0.0, 0.0, 0.0],
        'Population': [50000, 30000, 10000],
        'Quarter': ['Q3', 'Q3', 'Q3'],
        'Offer': ['None', 'Offer A', 'None'],
        'Internet Type': ['Fiber Optic', 'DSL', 'None'],
        'Gender': ['Male', 'Female', 'Male'],
        'Payment Method': ['Bank Withdrawal', 'Credit Card', 'Bank Withdrawal'],
        'Paperless Billing': [1, 0, 1],
        'Under 30': [1, 0, 0],
        'Referred a Friend': [0, 1, 1],
        'City': ['CityA', 'CityB', 'CityC'],
        'State': ['CA', 'NY', 'TX'],
        'Zip Code': ['90210', '10001', '77001'],
        'Country': ['United States'] * 3,
        'Lat Long': ['0,0', '0,0', '0,0'],
        'Latitude': [0.0, 0.0, 0.0],
        'Longitude': [0.0, 0.0, 0.0],
    })

    results = predict_churn_risk(sample,
                                 model_path=model_path,
                                 price_increase_pct=price_increase_pct)

    print("\nSample Predictions:")
    # index=False hides the row index, which only reflects the sort order.
    print(results.to_string(index=False))
    return results
# Run the demo only when this file is executed as a script, so importing
# the module for predict_churn_risk has no side effects.
if __name__ == "__main__":
    batch_predict_example()