EnYa32 committed on
Commit
e78a19b
·
verified ·
1 Parent(s): 9869fe1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +155 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,158 @@
1
- import altair as alt
2
- import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import streamlit as st
3
+ import joblib
4
+ from pathlib import Path
5
 
6
# Configure the browser tab (title and icon) and use the centered page layout.
st.set_page_config(
    page_title='Star System Classification (LightGBM)',
    page_icon='🪐',
    layout='centered',
)

# Directory containing this script; every artifact is resolved relative to it.
BASE_DIR = Path(__file__).resolve().parent

MODEL_PATH = BASE_DIR.joinpath('lightgbm_model.pkl')
FEATURES_PATH = BASE_DIR.joinpath('featurer.pkl')  # you saved it with this name
PLANET_ENCODER_PATH = BASE_DIR.joinpath('planet_encoder.pkl')
STAR_ENCODER_PATH = BASE_DIR.joinpath('star_encoder.pkl')

# --- Fixed mapping you used in training ---
# Ordinal encoding of the activity level: Low -> 0, Medium -> 1, High -> 2.
ACTIVITY_MAP = dict(zip(('Low', 'Medium', 'High'), range(3)))

# Optional: label names (edit if your competition uses different names)
# Maps the integer class id to a display name: 0 -> Habitable, ..., 3 -> Exotic.
LABEL_NAMES = dict(enumerate(('Habitable', 'Young', 'Old', 'Exotic')))
25
+
26
@st.cache_resource
def load_artifacts():
    """Load the model, feature list and both label encoders from disk.

    The four artifacts are read from MODEL_PATH, FEATURES_PATH,
    PLANET_ENCODER_PATH and STAR_ENCODER_PATH. The function is wrapped in
    ``st.cache_resource`` so the files are not re-read on every rerun.

    Returns:
        Tuple ``(model, features, le_planet, le_star)`` in that order.

    Raises:
        FileNotFoundError: If any artifact file does not exist. The message
            names the missing files and the directory that was searched.
    """
    artifact_paths = (MODEL_PATH, FEATURES_PATH, PLANET_ENCODER_PATH, STAR_ENCODER_PATH)

    missing = [path.name for path in artifact_paths if not path.exists()]
    if missing:
        # Build the message from the real paths so it cannot drift from the
        # constants (the old text pointed at "repo root" and "app.py").
        expected = '\n'.join(f'- {path.name}' for path in artifact_paths)
        raise FileNotFoundError(
            f'Missing files in {BASE_DIR}: ' + ', '.join(missing) +
            f'\n\nMake sure these files are in the same folder as {Path(__file__).name}:\n'
            + expected
        )

    # NOTE: joblib.load unpickles arbitrary Python objects; only ship trusted
    # artifact files alongside this script.
    model, features, le_planet, le_star = (joblib.load(path) for path in artifact_paths)
    return model, features, le_planet, le_star
41
+
42
def safe_transform(le, value: str, col_name: str) -> int:
    """Encode a single category value with a saved LabelEncoder.

    Args:
        le: Fitted encoder exposing ``transform`` (and usually ``classes_``).
        value: The raw category string to encode.
        col_name: Column name, used only in the error message.

    Returns:
        The integer code produced by ``le.transform`` for ``value``.

    If the encoder rejects the value with ``ValueError`` (unseen label), an
    error listing the known categories is shown and the script run is halted
    with ``st.stop()``. Any other exception is a genuine bug and propagates
    instead of being misreported as an unknown category.
    """
    try:
        return int(le.transform([value])[0])
    except ValueError:
        known = list(getattr(le, 'classes_', []))
        st.error(f'Unknown category for {col_name}: {value}. Known values: {known}')
        st.stop()
51
+
52
# Load the (cached) model, the training feature order and both label encoders.
model, FEATURES, le_planet, le_star = load_artifacts()

st.title('🪐 Star System Classification (LightGBM)')
st.write('Predict the star system type using 10 astrophysical measurements (multiclass).')

with st.expander('ℹ️ Required files in this folder', expanded=False):
    # Show the script under its real file name rather than a hard-coded 'app.py'.
    st.code(
        f'{Path(__file__).name}\n'
        'lightgbm_model.pkl\n'
        'featurer.pkl\n'
        'planet_encoder.pkl\n'
        'star_encoder.pkl\n'
        'requirements.txt'
    )

st.subheader('Enter feature values')

# --- Inputs ---
# Numeric
star_size = st.number_input('star_size', min_value=0.0, value=1.0, step=0.01)
star_brightness = st.number_input('star_brightness', min_value=0.0, value=1.2, step=0.01)
distance_from_earth = st.number_input('distance_from_earth', min_value=0.0, value=90.0, step=1.0)
star_mass = st.number_input('star_mass', min_value=0.0, value=1.3, step=0.01)
metallicity = st.number_input('metallicity', value=0.02, step=0.001, format='%.4f')

# Discrete numeric / encoded-like
galaxy_region = st.selectbox('galaxy_region', options=[0, 1, 2], index=1)
galaxy_type = st.selectbox('galaxy_type', options=[0, 1, 2], index=0)

# Categorical (original strings); the options come straight from the fitted
# encoders, so every selectable value is one the encoder knows.
star_spectral_class = st.selectbox(
    'star_spectral_class',
    options=list(le_star.classes_),
    index=0
)

planet_configuration = st.selectbox(
    'planet_configuration',
    options=list(le_planet.classes_),
    index=0
)

stellar_activity_class = st.selectbox(
    'stellar_activity_class',
    options=['Low', 'Medium', 'High'],
    index=0
)

# --- Apply same preprocessing as training ---
# LabelEncoders for the two nominal columns (planet first, then star, so the
# first error shown on failure is unchanged).
planet_code = safe_transform(le_planet, planet_configuration, 'planet_configuration')
star_code = safe_transform(le_star, star_spectral_class, 'star_spectral_class')

# --- Build the single-row feature record ---
row = {
    'star_size': float(star_size),
    'star_brightness': float(star_brightness),
    'galaxy_region': int(galaxy_region),
    'distance_from_earth': float(distance_from_earth),
    'galaxy_type': int(galaxy_type),
    'star_spectral_class': star_code,
    'planet_configuration': planet_code,
    # Mapping for activity (ordinal)
    'stellar_activity_class': ACTIVITY_MAP[stellar_activity_class],
    'star_mass': float(star_mass),
    'metallicity': float(metallicity),
}

X_input = pd.DataFrame([row])

# Ensure all expected feature columns exist before predicting.
missing_cols = [c for c in FEATURES if c not in X_input.columns]
if missing_cols:
    st.error(f'Missing columns for model: {missing_cols}')
    st.stop()

# Selecting by FEATURES both enforces the training column order and drops any
# column the model was not trained on.
X_input = X_input[FEATURES]

st.divider()

col1, col2 = st.columns(2)

with col1:
    if st.button('🔮 Predict', use_container_width=True):
        pred = model.predict(X_input)[0]
        pred_int = int(pred)
        label = LABEL_NAMES.get(pred_int, str(pred_int))
        st.success(f'Prediction: **{label}** (class {pred_int})')

with col2:
    if st.button('📊 Predict probabilities', use_container_width=True):
        if hasattr(model, 'predict_proba'):
            proba = model.predict_proba(X_input)[0]
            # Take class ids from the model when it exposes them, so the
            # probability columns stay aligned with the right labels; fall
            # back to positional ids otherwise.
            class_ids = [int(c) for c in getattr(model, 'classes_', range(len(proba)))]
            proba_df = pd.DataFrame(
                {'class': class_ids, 'probability': proba}
            ).sort_values('probability', ascending=False)
            proba_df['label'] = proba_df['class'].map(LABEL_NAMES).fillna(proba_df['class'].astype(str))
            st.dataframe(proba_df[['label', 'class', 'probability']], use_container_width=True)
        else:
            st.warning('This model does not support predict_proba().')

st.caption('Tip: If predictions look wrong, ensure the same encoders and feature order are used as during training.')