| |
| import numpy as np |
| import pandas as pd |
| import streamlit as st |
| import joblib |
| from pathlib import Path |
|
|
# Page-level Streamlit configuration (title, icon, layout).
st.set_page_config(
    page_title="Employee Attrition Predictor (XGBoost)",
    page_icon="🏢",
    layout="centered",
)

# The pickled artifacts are looked up in the directory containing this script.
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR.joinpath("xgb_model.pkl")
FEATURES_PATH = BASE_DIR.joinpath("feature_names.pkl")
THRESHOLD_PATH = BASE_DIR.joinpath("threshold.pkl")
|
|
| |
@st.cache_resource
def load_artifacts():
    """Load and validate the model, the feature-name list and the threshold.

    Returns:
        A ``(model, feature_names, threshold)`` tuple in which
        ``feature_names`` is a plain ``list`` and ``threshold`` is a ``float``.

    Raises:
        FileNotFoundError: If any of the three artifact files does not exist.
        ValueError: If the feature names are not a non-empty one-dimensional
            sequence, or the threshold is not a number within ``[0, 1]``.
    """
    required = (MODEL_PATH, FEATURES_PATH, THRESHOLD_PATH)
    missing = [p.name for p in required if not p.exists()]
    if missing:
        raise FileNotFoundError(
            f"Missing files: {missing}. Put them in the repo root (same folder as app.py)."
        )

    # NOTE(security): joblib.load unpickles arbitrary objects. Only ship
    # artifact files that come from a trusted training run.
    model = joblib.load(MODEL_PATH)
    feature_names = joblib.load(FEATURES_PATH)
    threshold = joblib.load(THRESHOLD_PATH)

    # Column lists are frequently pickled straight from ``X.columns`` (a
    # pandas Index) or as a numpy array, so accept those alongside list/tuple.
    is_sequence = isinstance(feature_names, (list, tuple, pd.Index, np.ndarray))
    if (
        not is_sequence
        or getattr(feature_names, "ndim", 1) != 1
        or len(feature_names) == 0
    ):
        raise ValueError("feature_names.pkl must be a non-empty list of column names.")
    # ``tolist`` turns Index/ndarray elements into plain Python objects.
    if hasattr(feature_names, "tolist"):
        feature_names = feature_names.tolist()
    else:
        feature_names = list(feature_names)

    try:
        threshold = float(threshold)
    except (TypeError, ValueError) as exc:
        raise ValueError("threshold.pkl must contain a single number.") from exc
    # The chained comparison is False for NaN as well, so NaN is rejected here.
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(
            f"threshold.pkl must be a probability in [0, 1], got {threshold!r}."
        )

    return model, feature_names, threshold
|
|
|
|
# Module-level artifacts used by every section below.
model, feature_names, threshold = load_artifacts()

st.title("🏢 Employee Attrition Predictor (XGBoost)")
st.caption("Predicts the probability that an employee will leave (Attrition=1).")

# Collapsible summary of the loaded artifacts, one st.write call per line.
with st.expander("Model info"):
    for info_line in (
        "**Model:** XGBoost (saved as `xgb_model.pkl`)",
        f"**Number of features:** {len(feature_names)}",
        f"**Decision threshold:** {threshold:.2f}",
        "Tip: Probability is used for Kaggle submissions (ROC-AUC metric).",
    ):
        st.write(info_line)
|
|
|
|
| |
def build_input_from_form(form_values: dict, columns=None) -> pd.DataFrame:
    """Create a single-row dataframe aligned to the training feature order.

    Args:
        form_values: Mapping of column name to value for one record.
        columns: Optional ordered column names to align to. When omitted,
            the module-level ``feature_names`` is used, which is the
            original behavior.

    Returns:
        A one-row DataFrame whose columns are exactly ``columns``, in that
        order. Columns absent from ``form_values`` (for example one-hot
        columns that were not selected) are filled with 0; keys that are not
        listed in ``columns`` are dropped.
    """
    if columns is None:
        columns = feature_names
    row = pd.DataFrame([form_values])
    return row.reindex(columns=list(columns), fill_value=0)
|
|
|
|
def predict_single(X_one_row: pd.DataFrame):
    """Score one aligned row with the module-level model.

    Returns:
        ``(pred, proba)`` where ``proba`` is the float taken from column 1 of
        ``model.predict_proba`` for the first row, and ``pred`` is 1 when
        ``proba`` is at or above the module-level ``threshold``, else 0.
    """
    class_probabilities = model.predict_proba(X_one_row)
    positive_proba = float(class_probabilities[:, 1][0])
    label = 1 if positive_proba >= threshold else 0
    return label, positive_proba
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Main UI: pick an input mode, then render either the single-record form or
# the batch CSV scorer, followed by the page footer.
# ---------------------------------------------------------------------------
SINGLE_MODE = "Single prediction (form)"
BATCH_MODE = "Batch prediction (CSV upload)"
BASELINE_LABEL = "(baseline / dropped category)"

mode = st.radio("Choose input method", [SINGLE_MODE, BATCH_MODE], horizontal=True)

if mode == SINGLE_MODE:
    st.subheader("Single prediction")

    # One-hot encoded groups the form knows about; their columns are named
    # "<group>_<category>".
    group_names = (
        "BusinessTravel", "Department", "EducationField",
        "Gender", "JobRole", "MaritalStatus",
    )
    groups = {
        gname: [c for c in feature_names if c.startswith(f"{gname}_")]
        for gname in group_names
    }

    # Everything that is not a one-hot column is treated as numeric/ordinal.
    one_hot_prefixes = tuple(f"{gname}_" for gname in group_names)
    numeric_cols = [c for c in feature_names if not c.startswith(one_hot_prefixes)]

    # Every feature starts at 0; widgets below overwrite the ones they cover,
    # so features without a widget are sent to the model as 0.
    form_values = dict.fromkeys(feature_names, 0)

    col_left, col_right = st.columns(2)

    with col_left:
        st.markdown("**Numeric / ordinal inputs**")

        preferred_numeric = [
            "Age", "DistanceFromHome", "Education", "EnvironmentSatisfaction", "HourlyRate",
            "JobInvolvement", "JobLevel", "JobSatisfaction", "MonthlyIncome", "MonthlyRate",
            "NumCompaniesWorked", "PercentSalaryHike", "PerformanceRating", "RelationshipSatisfaction",
            "StockOptionLevel", "TotalWorkingYears", "TrainingTimesLastYear", "WorkLifeBalance",
            "YearsAtCompany", "YearsInCurrentRole", "YearsSinceLastPromotion", "YearsWithCurrManager",
            "OverTime",
            # Engineered features.
            "tenure_ratio", "promotion_gap", "manager_stability", "income_per_level",
            "time_experience", "no_promotion", "income_experience",
        ]

        # Show the preferred columns the model actually has; if none match,
        # fall back to the first 18 non-one-hot columns.
        shown_numeric = [c for c in preferred_numeric if c in feature_names] or numeric_cols[:18]

        for c in shown_numeric:
            if c == "OverTime":
                form_values[c] = st.selectbox("OverTime (0=No, 1=Yes)", [0, 1], index=0)
            else:
                form_values[c] = st.number_input(c, value=0.0, step=1.0)

    with col_right:
        st.markdown("**Categorical one-hot selections**")
        st.caption("Select one option per group. Unselected groups remain 0 (baseline category).")

        for gname, gcols in groups.items():
            if not gcols:
                continue
            prefix = f"{gname}_"
            # Map the label shown in the widget back to its one-hot column.
            label_to_col = {c[len(prefix):]: c for c in gcols}
            choice = st.selectbox(gname, [BASELINE_LABEL, *label_to_col], index=0)
            if choice != BASELINE_LABEL:
                form_values[label_to_col[choice]] = 1

    X_one = build_input_from_form(form_values)

    if st.button("Predict", type="primary"):
        pred, proba = predict_single(X_one)

        st.metric("Attrition probability (P=1)", f"{proba:.3f}")
        if pred == 1:
            st.error(f"Prediction: Attrition = 1 (Leave) | threshold={threshold:.2f}")
        else:
            st.success(f"Prediction: Attrition = 0 (Stay) | threshold={threshold:.2f}")

        # NOTE(review): the original indentation was lost; this expander is
        # assumed to belong to the prediction result. Confirm with the author.
        with st.expander("Show input vector (aligned features)"):
            st.dataframe(X_one)

else:
    st.subheader("Batch prediction (CSV upload)")
    st.write("Upload a CSV that already matches the training feature format (after preprocessing/one-hot).")
    st.caption("If your CSV is raw, preprocess it the same way as in your notebook before uploading.")

    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is not None:
        # A malformed or empty upload should produce a readable message, not
        # a traceback.
        try:
            df_in = pd.read_csv(uploaded)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            df_in = None
            st.error(f"Could not read the uploaded CSV: {exc}")

        if df_in is not None and df_in.empty:
            st.warning("The uploaded CSV contains no rows; nothing to predict.")
        elif df_in is not None:
            # The label column is not a feature; drop it if it was included.
            df_in = df_in.drop(columns=["Attrition"], errors="ignore")

            # Missing columns are still zero-filled, but the user is told,
            # because an unprocessed CSV would otherwise be scored silently.
            missing_cols = [c for c in feature_names if c not in df_in.columns]
            if missing_cols:
                preview = ", ".join(str(c) for c in missing_cols[:10])
                suffix = ", ..." if len(missing_cols) > 10 else ""
                st.warning(
                    f"{len(missing_cols)} expected feature column(s) were not found "
                    f"and were filled with 0: {preview}{suffix}"
                )

            Xb = df_in.reindex(columns=feature_names, fill_value=0)

            # Text columns mean the file was not preprocessed; report them
            # instead of handing them to the model.
            non_numeric = [
                str(col)
                for col, dtype in Xb.dtypes.items()
                if not pd.api.types.is_numeric_dtype(dtype)
            ]
            if non_numeric:
                st.error(
                    "These feature columns are not numeric; preprocess the CSV "
                    f"before uploading: {', '.join(non_numeric[:10])}"
                )
            else:
                probs = model.predict_proba(Xb)[:, 1]
                preds = (probs >= threshold).astype(int)

                out = df_in.copy()
                out["Attrition_proba"] = probs
                out["Attrition_pred"] = preds

                st.success(f"Predicted {len(out)} rows.")
                st.dataframe(out.head(20))

                csv_bytes = out.to_csv(index=False).encode("utf-8")
                st.download_button(
                    "Download predictions CSV",
                    data=csv_bytes,
                    file_name="predictions.csv",
                    mime="text/csv",
                )

st.divider()
st.caption("Built with Streamlit • Model: XGBoost • Metric focus: ROC-AUC")
|
|