EnYa32 commited on
Commit
461d207
·
verified ·
1 Parent(s): 3cc52db

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +169 -36
src/streamlit_app.py CHANGED
@@ -1,40 +1,173 @@
1
- import altair as alt
 
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import joblib
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
6
+ from pathlib import Path
7
 
8
+ # -------------------------
9
+ # Page config
10
+ # -------------------------
11
+ st.set_page_config(
12
+ page_title='Sales Forecast (LightGBM)',
13
+ page_icon='📈',
14
+ layout='centered'
15
+ )
16
+
17
+ st.title('📈 Sales Forecast (LightGBM)')
18
+ st.write('Predict **num_sold** using a trained LightGBM model + saved encoders and preprocessing.')
19
+
20
+ BASE_DIR = Path(__file__).resolve().parent
21
+
22
+ MODEL_PATH = BASE_DIR / 'model_lgbm.pkl'
23
+ FEATURES_PATH = BASE_DIR / 'feature_names.pkl'
24
+ ENCODERS_PATH = BASE_DIR / 'encoders.pkl'
25
+ FILLMAP_PATH = BASE_DIR / 'fill_map.pkl'
26
+ META_PATH = BASE_DIR / 'meta.json'
27
+
28
+
29
+ @st.cache_resource
30
+ def load_assets():
31
+ if not MODEL_PATH.exists():
32
+ raise FileNotFoundError(f'Missing {MODEL_PATH.name} (put it next to app.py).')
33
+ if not FEATURES_PATH.exists():
34
+ raise FileNotFoundError(f'Missing {FEATURES_PATH.name} (put it next to app.py).')
35
+ if not ENCODERS_PATH.exists():
36
+ raise FileNotFoundError(f'Missing {ENCODERS_PATH.name} (put it next to app.py).')
37
+ if not FILLMAP_PATH.exists():
38
+ raise FileNotFoundError(f'Missing {FILLMAP_PATH.name} (put it next to app.py).')
39
+
40
+ model = joblib.load(MODEL_PATH)
41
+ features = joblib.load(FEATURES_PATH)
42
+ encoders = joblib.load(ENCODERS_PATH)
43
+ fill_map = joblib.load(FILLMAP_PATH)
44
+
45
+ meta = None
46
+ if META_PATH.exists():
47
+ with open(META_PATH, 'r') as f:
48
+ meta = json.load(f)
49
+
50
+ return model, features, encoders, fill_map, meta
51
+
52
+
53
+ model, FEATURES, encoders, fill_map, meta = load_assets()
54
+
55
+ with st.expander('ℹ️ Model info'):
56
+ if meta:
57
+ st.write(meta)
58
+ else:
59
+ st.write('No meta.json found.')
60
+
61
+
62
+ # -------------------------
63
+ # Helpers
64
+ # -------------------------
65
+ def make_date_features(date_value: pd.Timestamp) -> dict:
66
+ # date_value is a Timestamp
67
+ year = int(date_value.year)
68
+ month = int(date_value.month)
69
+ week = int(date_value.isocalendar().week)
70
+ dayofweek = int(date_value.dayofweek) # Monday=0
71
+ is_weekend = int(dayofweek >= 5)
72
+ dayofyear = int(date_value.dayofyear)
73
+
74
+ return {
75
+ 'year': year,
76
+ 'month': month,
77
+ 'week': week,
78
+ 'dayofweek': dayofweek,
79
+ 'is_weekend': is_weekend,
80
+ 'dayofyear': dayofyear
81
+ }
82
+
83
+
84
+ def safe_encode(col_name: str, value: str) -> int:
85
+ # If unseen label appears, fall back to the most frequent label (index 0) or safe default.
86
+ le = encoders.get(col_name)
87
+ if le is None:
88
+ return 0
89
+
90
+ classes = set(le.classes_.astype(str))
91
+ v = str(value)
92
+
93
+ if v in classes:
94
+ return int(le.transform([v])[0])
95
+
96
+ # fallback: use first known class
97
+ return int(le.transform([str(le.classes_[0])])[0])
98
+
99
+
100
+ # -------------------------
101
+ # UI Inputs
102
+ # -------------------------
103
+ st.subheader('🧾 Input')
104
+
105
+ date_in = st.date_input('Date', value=pd.to_datetime('2019-01-01'))
106
+ country_in = st.text_input('Country', value='Finland')
107
+ store_in = st.text_input('Store', value='KaggleMart')
108
+ product_in = st.text_input('Product', value='Kaggle Mug')
109
+
110
+ st.markdown('---')
111
+ st.subheader('⏳ Lag features')
112
+
113
+ use_manual_lags = st.checkbox('Enter lag values manually (recommended if you know them)', value=False)
114
+
115
+ default_lag_364 = float(fill_map.get('lag_364', 0.0))
116
+ default_lag_365 = float(fill_map.get('lag_365', 0.0))
117
+ default_lag_371 = float(fill_map.get('lag_371', 0.0))
118
+
119
+ if use_manual_lags:
120
+ lag_364 = st.number_input('lag_364', value=default_lag_364)
121
+ lag_365 = st.number_input('lag_365', value=default_lag_365)
122
+ lag_371 = st.number_input('lag_371', value=default_lag_371)
123
+ else:
124
+ st.write('Using default lag values (from training medians):')
125
+ st.write({
126
+ 'lag_364': default_lag_364,
127
+ 'lag_365': default_lag_365,
128
+ 'lag_371': default_lag_371
129
+ })
130
+ lag_364, lag_365, lag_371 = default_lag_364, default_lag_365, default_lag_371
131
+
132
+
133
+ # -------------------------
134
+ # Predict
135
+ # -------------------------
136
+ if st.button('Predict'):
137
+ date_ts = pd.to_datetime(date_in)
138
+
139
+ d_feats = make_date_features(date_ts)
140
+
141
+ row = {}
142
+ row.update(d_feats)
143
+
144
+ row['lag_364'] = float(lag_364)
145
+ row['lag_365'] = float(lag_365)
146
+ row['lag_371'] = float(lag_371)
147
+
148
+ row['country'] = safe_encode('country', country_in)
149
+ row['store'] = safe_encode('store', store_in)
150
+ row['product'] = safe_encode('product', product_in)
151
+
152
+ X = pd.DataFrame([row])
153
+
154
+ # Ensure all FEATURES exist and order is correct
155
+ for c in FEATURES:
156
+ if c not in X.columns:
157
+ # numeric fallback from fill_map
158
+ X[c] = fill_map.get(c, 0.0)
159
+
160
+ X = X[FEATURES].copy()
161
+
162
+ # Fill numeric NaNs just in case
163
+ for c in X.columns:
164
+ if X[c].isna().any():
165
+ X[c] = X[c].fillna(fill_map.get(c, 0.0))
166
+
167
+ pred = model.predict(X)[0]
168
+
169
+ st.success(f'✅ Predicted num_sold: **{pred:.2f}**')
170
+ st.caption('Note: Lag features heavily influence the forecast. If you can compute real lags, the prediction will be more accurate.')
171
+
172
+ with st.expander('Show model input vector'):
173
+ st.dataframe(X)