Chayanat commited on
Commit
bbce220
·
verified ·
1 Parent(s): 2e02ac0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -51
app.py CHANGED
@@ -1,62 +1,90 @@
1
- # เพิ่มในส่วนโหลดโมเดลและ hyperparameters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  @st.cache_resource
3
  def load_model_resources():
4
  model_path = "model.pth"
5
  try:
6
  model, hyperparams = load_model(model_path)
7
-
8
- if model is None:
9
- st.error("ไม่สามารถโหลดโมเดลอัตโนมัติได้ จะใช้การกำหนดค่าด้วยตนเอง")
10
- # ถ้าโหลดไม่สำเร็จ ให้ใช้ค่าจากพารามิเตอร์เริ่มต้นที่คุณให้มา
11
- input_size = st.number_input("Input Size", min_value=1, max_value=100, value=10)
12
- hidden_size = st.number_input("Hidden Size", min_value=16, max_value=512, value=64)
13
- num_layers = st.number_input("Number of Layers", min_value=1, max_value=5, value=2)
14
- output_size = st.number_input("Output Size", min_value=1, max_value=100, value=1)
15
- dropout_rate = st.slider("Dropout Rate", min_value=0.0, max_value=0.2, value=0.1, step=0.01)
16
-
17
- hyperparams = {
18
- 'input_size': input_size,
19
- 'hidden_size': hidden_size,
20
- 'num_layers': num_layers,
21
- 'output_size': output_size,
22
- 'dropout_rate': dropout_rate,
23
- 'sequence_length': st.slider("Sequence Length", min_value=3, max_value=30, value=10)
24
- }
25
-
26
- # สร้างโมเดลใหม่ด้วยค่าที่กำหนดเอง
27
- model = GRUModel(
28
- input_size=hyperparams['input_size'],
29
- hidden_size=hyperparams['hidden_size'],
30
- num_layers=hyperparams['num_layers'],
31
- output_size=hyperparams['output_size'],
32
- dropout_rate=hyperparams['dropout_rate']
33
- )
34
-
35
- # โหลด scalers และ encoders
36
  numeric_scaler, label_encoders, y_scaler = load_scalers_and_encoders(model_path)
37
 
38
- # ตรวจสอว่ามีารโหลด scalers แ encoders สำเร็จหรไม่
39
- if numeric_scaler is None or label_encoders is None or y_scaler is None:
40
- st.warning("ไม่สามารถโหลด scalers และ encoders ได้ โปรดอัปโหลดไฟล์ใหม่หรือกำหนดค่าด้วยตนเอง")
41
-
42
- # ให้ผู้ใช้อัปโหลด scalers และ encoders ด้วยตนเอง (ตัวอย่าง)
43
- scaler_file = st.file_uploader("อัปโหลดไฟล์ Scaler (pickle)", type=["pkl"])
44
- if scaler_file is not None:
45
- import pickle
46
- numeric_scaler = pickle.load(scaler_file)
47
- else:
48
- from sklearn.preprocessing import MinMaxScaler
49
- numeric_scaler = MinMaxScaler()
50
- st.info("ใช้ MinMaxScaler เริ่มต้น")
51
-
52
- # สร้าง y_scaler ตัวใหม่
53
- y_scaler = MinMaxScaler()
54
-
55
- # สร้าง label_encoders ตัวใหม่
56
- from sklearn.preprocessing import LabelEncoder
57
- label_encoders = [LabelEncoder(), LabelEncoder()] # สมมติว่ามี categorical features 2 ตัว
58
 
59
  return model, hyperparams, numeric_scaler, label_encoders, y_scaler
60
  except Exception as e:
61
  st.error(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
62
- return None, None, None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import time
6
+ import json
7
+ import io
8
+ import base64
9
+ import pickle
10
+ import matplotlib.pyplot as plt
11
+ from model import GRUModel, load_model, save_model_info
12
+ from utils import (init_firebase, get_data_from_firebase, save_data_to_firebase,
13
+ preprocess_data, create_sequences, load_scalers_and_encoders,
14
+ prepare_input_data, get_file_download_link,
15
+ save_scaler_to_bytes, save_encoders_to_bytes,
16
+ create_default_scaler, create_default_encoders)
17
+
18
+ # ตั้งค่าหน้าเว็บ
19
+ st.set_page_config(page_title="GRU Model for PM0.1 Prediction", layout="wide")
20
+ st.title("GRU Model for PM0.1 Prediction")
21
+
22
+ # สร้าง session state สำหรับเก็บข้อมูลระหว่าง rerun
23
+ if 'prediction_history' not in st.session_state:
24
+ st.session_state.prediction_history = []
25
+ st.session_state.timestamp_history = []
26
+ st.session_state.initialized = False
27
+ st.session_state.model_loaded = False
28
+ st.session_state.firebase_connected = False
29
+
30
+ # โหลดโมเดลและ hyperparameters
31
  @st.cache_resource
32
  def load_model_resources():
33
  model_path = "model.pth"
34
  try:
35
  model, hyperparams = load_model(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  numeric_scaler, label_encoders, y_scaler = load_scalers_and_encoders(model_path)
37
 
38
+ # บันทึข้อมูโมเดลเป็น JSON สำหรับการตรวจส
39
+ if model is not None and hyperparams is not None:
40
+ save_model_info(model, hyperparams, "model_info.json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  return model, hyperparams, numeric_scaler, label_encoders, y_scaler
43
  except Exception as e:
44
  st.error(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
45
+ return None, None, None, None, None
46
+
47
+ # ส่วนของ sidebar สำหรับการตั้งค่า
48
+ with st.sidebar:
49
+ st.header("การตั้งค่า")
50
+
51
+ # การตั้งค่า Firebase
52
+ st.subheader("Firebase Configuration")
53
+
54
+ # ใช้ secrets หรือป้อนข้อมูลโดยตรง
55
+ use_secrets = st.checkbox("ใช้ Secrets", value=True,
56
+ help="เลือกว่าจะใช้ค่า Secrets หรือป้อนข้อมูลโดยตรง")
57
+
58
+ if use_secrets:
59
+ firebase_credentials = st.secrets.get("firebase_credentials", "{}")
60
+ firebase_url = st.secrets.get("firebase_url", "https://your-project-id.firebaseio.com")
61
+ else:
62
+ firebase_credentials = st.text_area("Firebase Service Account JSON",
63
+ value="", height=100,
64
+ help="ใส่ข้อมูล JSON ของ Service Account สำหรับเชื่อมต่อกับ Firebase")
65
+
66
+ firebase_url = st.text_input("Firebase Database URL",
67
+ value="https://your-project-id.firebaseio.com",
68
+ help="URL ของ Firebase Realtime Database")
69
+
70
+ input_path = st.text_input("Firebase Input Path",
71
+ value="input_data",
72
+ help="พาธสำหรับดึงข้อมูลจาก Firebase")
73
+
74
+ output_path = st.text_input("Firebase Output Path",
75
+ value="prediction_results",
76
+ help="พาธสำหรับบันทึกผลลัพธ์ลงใน Firebase")
77
+
78
+ # การตั้งค่าการทำนาย
79
+ st.subheader("Prediction Configuration")
80
+
81
+ auto_predict = st.checkbox("Auto-predict", value=False,
82
+ help="เปิดใช้การทำนายอัตโนมัติตามระยะเวลาที่กำหนด")
83
+
84
+ if auto_predict:
85
+ predict_interval = st.number_input("Prediction Interval (seconds)",
86
+ min_value=10, max_value=3600, value=60,
87
+ help="ความถี่ในการทำนายอัตโนมัติ (วินาที)")
88
+
89
+ # โหลดโมเดลและ hyperparameters
90
+ model, hyperparams, numeric_scaler, label_