Chayanat commited on
Commit
2e02ac0
·
verified ·
1 Parent(s): a52baac

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +293 -0
utils.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ import json
5
+ import time
6
+ import firebase_admin
7
+ from firebase_admin import credentials, db
8
+ from sklearn.preprocessing import MinMaxScaler, LabelEncoder
9
+ import pickle
10
+ import io
11
+ import base64
12
+
13
+ def prepare_input_data(data, sequence_length, input_size):
14
+ """
15
+ เตรียมข้อมูลนำเข้าให้อยู่ในรูปแบบที่เหมาะสมสำหรับโมเดล GRU
16
+
17
+ Args:
18
+ data (numpy.ndarray): ข้อมูลนำเข้า
19
+ sequence_length (int): ความยาวของลำดับเวลา
20
+ input_size (int): จำนวนคุณลักษณะนำเข้า
21
+
22
+ Returns:
23
+ tensor: ข้อมูลในรูปแบบ [batch_size, sequence_length, input_size]
24
+ """
25
+ # ตรวจสอบรูปร่างของข้อมูล
26
+ if len(data.shape) == 1: # ถ้าเป็น 1D array
27
+ # สมมติว่ามี input_size features ในแต่ละ timestep
28
+ data = data.reshape(-1, input_size)
29
+
30
+ # ตรวจสอบว่ามีข้อมูลพอสำหรับ sequence_length หรือไม่
31
+ if data.shape[0] < sequence_length:
32
+ # ถ้าไม่พอ ให้เพิ่มข้อมูลโดยการทำซ้ำข้อมูลแรก
33
+ repeats_needed = sequence_length - data.shape[0]
34
+ first_row = np.tile(data[0:1], (repeats_needed, 1))
35
+ data = np.vstack([first_row, data])
36
+
37
+ # ถ้ามีข้อมูลมากกว่า sequence_length ให้ใช้แค่ sequence_length ล่าสุด
38
+ if data.shape[0] > sequence_length:
39
+ data = data[-sequence_length:]
40
+
41
+ # เพิ่มมิติ batch_size (=1)
42
+ data = data.reshape(1, sequence_length, -1)
43
+
44
+ return torch.FloatTensor(data)
45
+
46
+ def create_sequences(data, seq_length):
47
+ """
48
+ สร้างลำดับ (sequences) จากข้อมูล
49
+
50
+ Args:
51
+ data (numpy.ndarray): ข้อมูลต้นฉบับ
52
+ seq_length (int): ความยาวของลำดับเวลา
53
+
54
+ Returns:
55
+ numpy.ndarray: ข้อมูลในรูปแบบลำดับเวลา
56
+ """
57
+ xs = []
58
+ for i in range(len(data) - seq_length + 1):
59
+ x = data[i:(i + seq_length)]
60
+ xs.append(x)
61
+ return np.array(xs)
62
+
63
+ def init_firebase(credentials_json, database_url):
64
+ """
65
+ เริ่มต้นการเชื่อมต่อกับ Firebase
66
+
67
+ Args:
68
+ credentials_json (str): ข้อมูล JSON ของ Firebase credentials
69
+ database_url (str): URL ของ Firebase Realtime Database
70
+
71
+ Returns:
72
+ bool: True ถ้าเชื่อมต่อสำเร็จ, False ถ้าไม่สำเร็จ
73
+ """
74
+ if not firebase_admin._apps:
75
+ try:
76
+ # แปลง JSON string เป็น dictionary
77
+ cred_dict = json.loads(credentials_json)
78
+ cred = credentials.Certificate(cred_dict)
79
+ firebase_admin.initialize_app(cred, {
80
+ 'databaseURL': database_url
81
+ })
82
+ return True
83
+ except Exception as e:
84
+ print(f"เกิดข้อผิดพลาดในการเชื่อมต่อกับ Firebase: {str(e)}")
85
+ return False
86
+ return True
87
+
88
+ def get_data_from_firebase(ref_path='input_data'):
89
+ """
90
+ ดึงข้อมูลจาก Firebase Realtime Database
91
+
92
+ Args:
93
+ ref_path (str): พาธสำหรับดึงข้อมูลจาก Firebase
94
+
95
+ Returns:
96
+ dict/list: ข้อมูลที่ดึงมาจาก Firebase
97
+ """
98
+ try:
99
+ ref = db.reference(ref_path)
100
+ data = ref.get()
101
+ return data
102
+ except Exception as e:
103
+ print(f"เกิดข้อผิดพลาดในการดึงข้อมูลจาก Firebase: {str(e)}")
104
+ return None
105
+
106
+ def save_data_to_firebase(data, ref_path='prediction_results'):
107
+ """
108
+ บันทึกข้อมูลลงใน Firebase Realtime Database
109
+
110
+ Args:
111
+ data (dict/list): ข้อมูลที่ต้องการบันทึก
112
+ ref_path (str): พาธสำหรับบันทึกข้อมูลลงใน Firebase
113
+
114
+ Returns:
115
+ bool: True ถ้าบันทึกสำเร็จ, False ถ้าไม่สำเร็จ
116
+ """
117
+ try:
118
+ ref = db.reference(ref_path)
119
+ ref.set(data)
120
+ return True
121
+ except Exception as e:
122
+ print(f"เกิดข้อผิดพลาดในการบันทึกข้อมูลลงใน Firebase: {str(e)}")
123
+ return False
124
+
125
+ def load_scalers_and_encoders(model_path):
126
+ """
127
+ โหลด scalers และ encoders จากไฟล์โมเดล
128
+
129
+ Args:
130
+ model_path (str): พาธไปยังไฟล์โมเดล
131
+
132
+ Returns:
133
+ tuple: (numeric_scaler, label_encoders, y_scaler)
134
+ """
135
+ try:
136
+ checkpoint = torch.load(model_path, map_location='cpu')
137
+
138
+ # ตรวจสอบแต่ละกรณี
139
+ numeric_scaler = None
140
+ label_encoders = None
141
+ y_scaler = None
142
+
143
+ if isinstance(checkpoint, dict):
144
+ # กรณีที่มี key โดยตรง
145
+ numeric_scaler = checkpoint.get('numeric_scaler')
146
+ label_encoders = checkpoint.get('label_encoders')
147
+ y_scaler = checkpoint.get('y_scaler')
148
+
149
+ # กรณีที่เก็บไว้ใน key อื่น
150
+ if numeric_scaler is None and 'scalers' in checkpoint:
151
+ numeric_scaler = checkpoint['scalers'].get('numeric_scaler')
152
+
153
+ if y_scaler is None and 'scalers' in checkpoint:
154
+ y_scaler = checkpoint['scalers'].get('y_scaler')
155
+
156
+ if label_encoders is None and 'encoders' in checkpoint:
157
+ label_encoders = checkpoint['encoders'].get('label_encoders')
158
+
159
+ return numeric_scaler, label_encoders, y_scaler
160
+
161
+ except Exception as e:
162
+ print(f"เกิดข้อผิดพลาดในการโหลด scalers และ encoders: {str(e)}")
163
+ return None, None, None
164
+
165
+ def create_default_scaler():
166
+ """
167
+ สร้าง MinMaxScaler เริ่มต้น
168
+ """
169
+ scaler = MinMaxScaler(feature_range=(0, 1))
170
+ # กำหนดค่า min และ max เริ่มต้น
171
+ scaler.min_ = np.zeros(1)
172
+ scaler.scale_ = np.ones(1)
173
+ scaler.data_min_ = np.zeros(1)
174
+ scaler.data_max_ = np.ones(1)
175
+ scaler.data_range_ = np.ones(1)
176
+ scaler.n_samples_seen_ = 1
177
+ return scaler
178
+
179
+ def create_default_encoders(n_categories=2):
180
+ """
181
+ สร้าง LabelEncoder เริ่มต้น
182
+ """
183
+ encoders = []
184
+ for i in range(n_categories):
185
+ le = LabelEncoder()
186
+ # กำหนดค่าเริ่มต้น
187
+ le.classes_ = np.array(['class0', 'class1'])
188
+ encoders.append(le)
189
+ return encoders
190
+
191
+ def preprocess_data(data, numeric_features, categorical_features, numeric_scaler, label_encoders):
192
+ """
193
+ ประมวลผลข้อมูลก่อนการทำนาย
194
+
195
+ Args:
196
+ data (dict/list): ข้อมูลนำเข้า
197
+ numeric_features (list): รายชื่อคุณลักษณะตัวเลข
198
+ categorical_features (list): รายชื่อคุณลักษณะเชิงกลุ่ม
199
+ numeric_scaler (MinMaxScaler): scaler สำหรับข้อมูลตัวเลข
200
+ label_encoders (list): encoders สำหรับข้อมูลเชิงกลุ่ม
201
+
202
+ Returns:
203
+ numpy.ndarray: ข้อมูลที่ผ่านการประมวลผลแล้ว
204
+ """
205
+ try:
206
+ # ตรวจสอบรูปแบบข้อมูล
207
+ if isinstance(data, list) and all(isinstance(item, dict) for item in data):
208
+ # กรณีที่ข้อมูลเป็นลิสต์ของ dict (หลาย timestep)
209
+ X_numeric = np.array([[item[feature] for feature in numeric_features] for item in data])
210
+ X_categorical = np.array([[item[feature] for feature in categorical_features] for item in data])
211
+ elif isinstance(data, dict):
212
+ # กรณีที่ข้อมูลเป็น dict เดียว (single timestep)
213
+ X_numeric = np.array([[data[feature] for feature in numeric_features]])
214
+ X_categorical = np.array([[data[feature] for feature in categorical_features]])
215
+ else:
216
+ raise ValueError("รูปแบบข้อมูลไม่ถูกต้อง ต้องเป็น dict หรือ list ของ dict")
217
+
218
+ # ตรวจสอบ scaler และ encoders
219
+ if numeric_scaler is None:
220
+ print("Warning: ไม่พบ numeric_scaler จะสร้างใหม่")
221
+ numeric_scaler = create_default_scaler()
222
+
223
+ if label_encoders is None or len(label_encoders) != len(categorical_features):
224
+ print("Warning: label_encoders ไม่ถูกต้อง จะสร้างใหม่")
225
+ label_encoders = create_default_encoders(len(categorical_features))
226
+
227
+ # ปรับสเกลข้อมูลตัวเลข
228
+ X_numeric_scaled = numeric_scaler.transform(X_numeric)
229
+
230
+ # Encode ข้อมูลเชิงกลุ่ม
231
+ X_categorical_encoded = []
232
+ for i, encoder in enumerate(label_encoders):
233
+ try:
234
+ # พยายาม transform ข้อมูล
235
+ encoded_col = encoder.transform(X_categorical[:, i])
236
+ except (ValueError, IndexError) as e:
237
+ # ถ้าเกิดข้อผิดพลาด (เช่น พบค่าที่ไม่เคยเห็น)
238
+ print(f"Warning: เกิดข้อผิดพลาดในการ encode คุณลักษณะที่ {i}: {str(e)}")
239
+ print(f"จะใช้ค่า 0 แทน")
240
+ # ใช้ค่า 0 แทน
241
+ encoded_col = np.zeros(X_categorical.shape[0], dtype=np.int64)
242
+
243
+ X_categorical_encoded.append(encoded_col)
244
+
245
+ # รวมข้อมูล
246
+ X_categorical_encoded = np.column_stack(X_categorical_encoded) if X_categorical_encoded else np.array([])
247
+
248
+ if X_categorical_encoded.size > 0:
249
+ # ถ้ามีข้อมูลเชิงกลุ่ม ให้รวมกับข้อมูลตัวเลข
250
+ X_encoded = np.concatenate([X_numeric_scaled, X_categorical_encoded], axis=1)
251
+ else:
252
+ # ถ้าไม่มีข้อมูลเชิงกลุ่ม ใช้เฉพาะข้อมูลตัวเลข
253
+ X_encoded = X_numeric_scaled
254
+
255
+ return X_encoded
256
+
257
+ except Exception as e:
258
+ print(f"เกิดข้อผิดพลาดในการประมวลผลข้อมูล: {str(e)}")
259
+ raise e
260
+
261
+ def get_file_download_link(data, filename, text="Download File"):
262
+ """
263
+ สร้างลิงก์สำหรับดาวน์โหลดไฟล์
264
+
265
+ Args:
266
+ data: ข้อมูลที่ต้องการให้ดาวน์โหลด
267
+ filename (str): ชื่อไฟล์
268
+ text (str): ข้อความที่แสดงบนลิงก์
269
+
270
+ Returns:
271
+ str: HTML ลิงก์สำหรับดาวน์โหลด
272
+ """
273
+ b64 = base64.b64encode(data).decode()
274
+ href = f'<a href="data:application/octet-stream;base64,{b64}" download="{filename}">{text}</a>'
275
+ return href
276
+
277
+ def save_scaler_to_bytes(scaler):
278
+ """
279
+ แปลง scaler เป็น bytes สำหรับดาวน์โหลด
280
+ """
281
+ bytes_io = io.BytesIO()
282
+ pickle.dump(scaler, bytes_io)
283
+ bytes_io.seek(0)
284
+ return bytes_io.read()
285
+
286
+ def save_encoders_to_bytes(encoders):
287
+ """
288
+ แปลง encoders เป็น bytes สำหรับดาวน์โหลด
289
+ """
290
+ bytes_io = io.BytesIO()
291
+ pickle.dump(encoders, bytes_io)
292
+ bytes_io.seek(0)
293
+ return bytes_io.read()