EnYa32 committed on
Commit
09d5f59
·
verified ·
1 Parent(s): e2eb1bd

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +185 -36
src/streamlit_app.py CHANGED
@@ -1,40 +1,189 @@
1
- import altair as alt
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
5
+ import joblib
6
+ from pathlib import Path
7
 
8
# Browser tab title, icon and page layout.
st.set_page_config(
    page_title="Employee Attrition Predictor (XGBoost)",
    page_icon="🏢",
    layout="centered",
)

# All model artifacts are looked up in the directory that holds this script.
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR.joinpath("xgb_model.pkl")
FEATURES_PATH = BASE_DIR.joinpath("feature_names.pkl")
THRESHOLD_PATH = BASE_DIR.joinpath("threshold.pkl")
14
+
15
+ # --------- Load artifacts ---------
16
@st.cache_resource
def load_artifacts():
    """Load the trained model, its feature order and the decision threshold.

    The three artifacts are read with ``joblib.load`` from ``MODEL_PATH``,
    ``FEATURES_PATH`` and ``THRESHOLD_PATH``. The function is decorated with
    ``st.cache_resource`` so Streamlit can reuse the loaded objects instead
    of unpickling them on every script rerun.

    Returns:
        tuple: ``(model, feature_names, threshold)`` where ``feature_names``
        is a plain ``list`` and ``threshold`` is a ``float``.

    Raises:
        FileNotFoundError: If any of the three artifact files is missing.
        ValueError: If the feature names are not a non-empty list-like.
    """
    required = (MODEL_PATH, FEATURES_PATH, THRESHOLD_PATH)
    missing = [p.name for p in required if not p.exists()]
    if missing:
        # Point at the directory that is actually searched (the folder of
        # this script), not at a hard-coded "repo root".
        raise FileNotFoundError(
            f"Missing files: {missing}. Put them in {BASE_DIR} (same folder as this script)."
        )

    # NOTE(review): joblib.load unpickles arbitrary objects; these files must
    # come from a trusted source.
    model = joblib.load(MODEL_PATH)
    feature_names = joblib.load(FEATURES_PATH)
    threshold = joblib.load(THRESHOLD_PATH)

    # Safety: feature names are often saved as a NumPy array or pandas Index
    # (e.g. ``X.columns``), so accept those in addition to list/tuple.
    list_like = (list, tuple, np.ndarray, pd.Index)
    if not isinstance(feature_names, list_like) or len(feature_names) == 0:
        raise ValueError("feature_names.pkl must be a non-empty list of column names.")

    return model, list(feature_names), float(threshold)
34
+
35
+
36
# Load the artifacts; the rest of the script reads these three names.
model, feature_names, threshold = load_artifacts()

st.title("🏢 Employee Attrition Predictor (XGBoost)")
st.caption("Predicts the probability that an employee will leave (Attrition=1).")

# Lines shown inside the collapsible "Model info" panel, in display order.
model_info_lines = (
    "**Model:** XGBoost (saved as `xgb_model.pkl`)",
    f"**Number of features:** {len(feature_names)}",
    f"**Decision threshold:** {threshold:.2f}",
    "Tip: Probability is used for Kaggle submissions (ROC-AUC metric).",
)
with st.expander("Model info"):
    for info_line in model_info_lines:
        st.write(info_line)
46
+
47
+
48
+ # --------- Helpers ---------
49
def build_input_from_form(form_values: dict) -> pd.DataFrame:
    """Turn one set of form values into a single-row frame for the model.

    The columns of the result are exactly ``feature_names``, in that order.
    Features that have no entry in ``form_values`` are filled with 0, and
    keys of ``form_values`` that are not in ``feature_names`` are dropped.
    """
    single_row = pd.DataFrame([form_values])
    return single_row.reindex(columns=feature_names, fill_value=0)
57
+
58
+
59
def predict_single(X_one_row: pd.DataFrame):
    """Score one aligned row and return ``(pred, proba)``.

    ``proba`` is the value in column 1 of ``model.predict_proba`` for the
    first row, as a float. ``pred`` is 1 when ``proba`` is at least the
    module-level ``threshold`` and 0 otherwise.
    """
    positive_scores = model.predict_proba(X_one_row)[:, 1]
    proba = float(positive_scores[0])
    pred = 1 if proba >= threshold else 0
    return pred, proba
63
+
64
+
65
# --------- Input mode selection ---------
# The labels are named once so the radio options and the branch test below
# cannot drift apart.
SINGLE_MODE = "Single prediction (form)"
BATCH_MODE = "Batch prediction (CSV upload)"

# Stems of the one-hot encoded categorical groups: a feature named
# "<stem>_<category>" is treated as one option of that group.
ONE_HOT_GROUPS = (
    "BusinessTravel", "Department", "EducationField",
    "Gender", "JobRole", "MaritalStatus",
)
BASELINE_LABEL = "(baseline / dropped category)"


def _read_batch_csv(uploaded):
    """Parse an uploaded CSV; return ``(frame, None)`` or ``(None, message)``."""
    try:
        frame = pd.read_csv(uploaded)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        return None, f"Could not read the uploaded CSV: {exc}"
    if frame.empty:
        return None, "The uploaded CSV contains no rows."
    return frame, None


def _summarise_columns(columns, limit=10):
    """Render at most ``limit`` column names as a comma-separated preview."""
    preview = ", ".join(str(c) for c in columns[:limit])
    return preview + (", ..." if len(columns) > limit else "")


mode = st.radio("Choose input method", [SINGLE_MODE, BATCH_MODE], horizontal=True)

# --------- Single prediction (manual form) ---------
if mode == SINGLE_MODE:
    st.subheader("Single prediction")

    # Minimal form: the user enters the main numeric features and picks one
    # option per one-hot group. Any feature that is not set defaults to 0.

    # One-hot columns of each group, kept in training order.
    groups = {
        gname: [c for c in feature_names if c.startswith(f"{gname}_")]
        for gname in ONE_HOT_GROUPS
    }
    # Everything that is not a one-hot column is treated as numeric
    # (a heuristic; good enough for building the UI).
    one_hot_prefixes = tuple(f"{gname}_" for gname in ONE_HOT_GROUPS)
    numeric_cols = [c for c in feature_names if not c.startswith(one_hot_prefixes)]

    col_left, col_right = st.columns(2)
    form_values = {}

    with col_left:
        st.markdown("**Numeric / ordinal inputs**")
        # Curated list of common HR numeric columns; only those actually
        # present in the model's features are shown.
        preferred_numeric = [
            "Age", "DistanceFromHome", "Education", "EnvironmentSatisfaction", "HourlyRate",
            "JobInvolvement", "JobLevel", "JobSatisfaction", "MonthlyIncome", "MonthlyRate",
            "NumCompaniesWorked", "PercentSalaryHike", "PerformanceRating", "RelationshipSatisfaction",
            "StockOptionLevel", "TotalWorkingYears", "TrainingTimesLastYear", "WorkLifeBalance",
            "YearsAtCompany", "YearsInCurrentRole", "YearsSinceLastPromotion", "YearsWithCurrManager",
            "OverTime",  # might be 0/1
            # engineered features (if these names were used in training)
            "tenure_ratio", "promotion_gap", "manager_stability", "income_per_level",
            "time_experience", "no_promotion", "income_experience",
        ]

        # Use the preferred list if any of it exists, else fall back to the
        # first 18 detected numeric columns.
        # NOTE(review): numeric features outside the shown list get no input
        # widget and are sent to the model as 0.
        known_features = set(feature_names)
        shown_numeric = [c for c in preferred_numeric if c in known_features] or numeric_cols[:18]

        for c in shown_numeric:
            if c == "OverTime":
                form_values[c] = st.selectbox("OverTime (0=No, 1=Yes)", [0, 1], index=0)
            else:
                # default 0; user can adjust
                form_values[c] = st.number_input(c, value=0.0, step=1.0)

    with col_right:
        st.markdown("**Categorical one-hot selections**")
        st.caption("Select one option per group. Unselected groups remain 0 (baseline category).")

        for gname, gcols in groups.items():
            if not gcols:
                continue
            # Every option of the group starts at 0; at most one is set to 1.
            form_values.update(dict.fromkeys(gcols, 0))

            # Map the human-readable label (column name without the group
            # prefix) back to its one-hot column.
            prefix_len = len(gname) + 1
            label_to_col = {c[prefix_len:]: c for c in gcols}
            choice = st.selectbox(gname, [BASELINE_LABEL, *label_to_col], index=0)
            if choice in label_to_col:
                form_values[label_to_col[choice]] = 1

    # Ensure every model feature has a value.
    for c in feature_names:
        form_values.setdefault(c, 0)

    X_one = build_input_from_form(form_values)

    if st.button("Predict", type="primary"):
        pred, proba = predict_single(X_one)

        st.metric("Attrition probability (P=1)", f"{proba:.3f}")
        if pred == 1:
            st.error(f"Prediction: Attrition = 1 (Leave) | threshold={threshold:.2f}")
        else:
            st.success(f"Prediction: Attrition = 0 (Stay) | threshold={threshold:.2f}")

        with st.expander("Show input vector (aligned features)"):
            st.dataframe(X_one)

# --------- Batch prediction (CSV upload) ---------
else:
    st.subheader("Batch prediction (CSV upload)")
    st.write("Upload a CSV that already matches the training feature format (after preprocessing/one-hot).")
    st.caption("If your CSV is raw, preprocess it the same way as in your notebook before uploading.")

    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is not None:
        df_in, problem = _read_batch_csv(uploaded)

        if problem is None:
            # Drop target if user included it
            df_in = df_in.drop(columns=["Attrition"], errors="ignore")

            present = set(df_in.columns)
            missing_cols = [c for c in feature_names if c not in present]
            if len(missing_cols) == len(feature_names):
                # Zero-filling every feature would only score an all-zero row.
                problem = (
                    "None of the expected feature columns were found in the CSV. "
                    "Preprocess the data the same way as for training before uploading."
                )

        if problem is None:
            if missing_cols:
                st.warning(
                    f"{len(missing_cols)} expected feature column(s) are missing and were "
                    f"filled with 0: {_summarise_columns(missing_cols)}"
                )

            # Align columns to training feature set
            Xb = df_in.reindex(columns=feature_names, fill_value=0)

            # Text columns (e.g. raw, un-encoded categories) cannot be scored.
            non_numeric = Xb.select_dtypes(exclude=["number", "bool"]).columns.tolist()
            if non_numeric:
                problem = (
                    "These feature columns are not numeric; encode them as in training "
                    f"before uploading: {_summarise_columns(non_numeric)}"
                )

        if problem is not None:
            st.error(problem)
        else:
            probs = model.predict_proba(Xb)[:, 1]
            preds = (probs >= threshold).astype(int)

            out = df_in.copy()
            out["Attrition_proba"] = probs
            out["Attrition_pred"] = preds

            st.success(f"Predicted {len(out)} rows.")
            st.dataframe(out.head(20))

            csv_bytes = out.to_csv(index=False).encode("utf-8")
            st.download_button("Download predictions CSV", data=csv_bytes, file_name="predictions.csv", mime="text/csv")
187
+
188
# Footer: a horizontal rule followed by a one-line credit.
footer_text = "Built with Streamlit • Model: XGBoost • Metric focus: ROC-AUC"
st.divider()
st.caption(footer_text)