Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import json | |
| class GRUModel(nn.Module): | |
| def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate): | |
| super(GRUModel, self).__init__() | |
| self.hidden_size = hidden_size | |
| self.num_layers = num_layers | |
| # ชั้น GRU | |
| self.gru = nn.GRU( | |
| input_size=input_size, | |
| hidden_size=hidden_size, | |
| num_layers=num_layers, | |
| batch_first=True, | |
| dropout=dropout_rate if num_layers > 1 else 0 | |
| ) | |
| # ชั้น Dropout | |
| self.dropout = nn.Dropout(dropout_rate) | |
| # ชั้น Fully Connected สำหรับเอาต์พุตสุดท้าย | |
| self.fc = nn.Linear(hidden_size, output_size) | |
| def forward(self, x): | |
| # กำหนดค่าเริ่มต้นของ hidden state | |
| h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) | |
| # ส่งข้อมูลผ่านชั้น GRU | |
| out, _ = self.gru(x, h0) | |
| # เลือกผลลัพธ์สุดท้ายจากลำดับ | |
| out = out[:, -1, :] | |
| # ส่งผ่านชั้น Dropout | |
| out = self.dropout(out) | |
| # ส่งผ่านชั้น Fully Connected | |
| out = self.fc(out) | |
| return out | |
| def extract_hyperparams_from_state_dict(state_dict): | |
| """ | |
| วิเคราะห์ค่า hyperparameters จาก state_dict ของโมเดล | |
| """ | |
| hyperparams = { | |
| 'input_size': None, | |
| 'hidden_size': None, | |
| 'num_layers': 1, # ค่าเริ่มต้น | |
| 'output_size': None, | |
| 'dropout_rate': 0.1 # ค่าเริ่มต้น | |
| } | |
| # ตรวจหาค่า hidden_size จาก weight ของ GRU | |
| if 'gru.weight_ih_l0' in state_dict: | |
| # รูปแบบของ weight_ih_l0 คือ [3*hidden_size, input_size] | |
| weight_shape = state_dict['gru.weight_ih_l0'].shape | |
| hyperparams['hidden_size'] = weight_shape[0] // 3 | |
| hyperparams['input_size'] = weight_shape[1] | |
| # ตรวจหาค่า output_size จาก weight ของ fully connected layer | |
| if 'fc.weight' in state_dict: | |
| # รูปแบบของ fc.weight คือ [output_size, hidden_size] | |
| fc_shape = state_dict['fc.weight'].shape | |
| hyperparams['output_size'] = fc_shape[0] | |
| # นับจำนวนชั้นของ GRU จากชื่อของ parameter | |
| layer_num = 0 | |
| while f'gru.weight_ih_l{layer_num}' in state_dict: | |
| layer_num += 1 | |
| hyperparams['num_layers'] = layer_num | |
| print(f"สกัดค่า hyperparameters จาก state_dict: {hyperparams}") | |
| return hyperparams | |
| def load_model(model_path, device='cpu'): | |
| """ | |
| โหลดโมเดล GRU จากไฟล์ .pth | |
| Args: | |
| model_path (str): พาธไปยังไฟล์โมเดล | |
| device (str): อุปกรณ์ที่ใช้ ('cpu' หรือ 'cuda') | |
| Returns: | |
| tuple: (model, hyperparams) - โมเดลและพารามิเตอร์ของโมเดล | |
| """ | |
| try: | |
| # โหลดไฟล์โมเดล | |
| checkpoint = torch.load(model_path, map_location=device) | |
| # ตรวจสอบโครงสร้างของไฟล์โมเดล | |
| model_state = None | |
| hyperparams = None | |
| # กรณีที่ 1: ไฟล์เป็น dictionary และมี model_state_dict | |
| if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: | |
| model_state = checkpoint['model_state_dict'] | |
| # ดึง hyperparameters ถ้ามี | |
| if 'hyperparameters' in checkpoint: | |
| hyperparams = checkpoint['hyperparameters'] | |
| # กรณีที่ 2: ไฟล์เป็น dictionary แต่ไม่มี model_state_dict | |
| elif isinstance(checkpoint, dict) and 'hyperparameters' in checkpoint: | |
| # สมมติว่า state_dict อยู่ในระดับบนสุด | |
| model_state = {k: v for k, v in checkpoint.items() if k != 'hyperparameters'} | |
| hyperparams = checkpoint['hyperparameters'] | |
| # กรณีที่ 3: ไฟล์เป็น state_dict โดยตรง | |
| else: | |
| model_state = checkpoint | |
| # ถ้าไม่มี hyperparams ให้สกัดจาก state_dict | |
| if hyperparams is None and model_state is not None: | |
| hyperparams = extract_hyperparams_from_state_dict(model_state) | |
| # ตรวจสอบว่ามี hyperparams ครบหรือไม่ | |
| required_params = ['input_size', 'hidden_size', 'output_size', 'num_layers', 'dropout_rate'] | |
| if not all(param in hyperparams for param in required_params): | |
| print(f"Warning: ไม่พบ hyperparameters บางตัว จะใช้ค่าเริ่มต้น") | |
| # กำหนดค่าเริ่มต้นสำหรับพารามิเตอร์ที่ขาดหายไป | |
| defaults = {'input_size': 10, 'hidden_size': 64, 'output_size': 1, 'num_layers': 2, 'dropout_rate': 0.1} | |
| for param in required_params: | |
| if param not in hyperparams: | |
| hyperparams[param] = defaults[param] | |
| # สร้างโมเดล | |
| model = GRUModel( | |
| input_size=hyperparams['input_size'], | |
| hidden_size=hyperparams['hidden_size'], | |
| num_layers=hyperparams['num_layers'], | |
| output_size=hyperparams['output_size'], | |
| dropout_rate=hyperparams['dropout_rate'] | |
| ) | |
| # โหลด state_dict เข้าไปในโมเดล | |
| model.load_state_dict(model_state) | |
| # ตั้งค่าโมเดลให้อยู่ในโหมดทำนาย | |
| model.eval() | |
| # แสดงข้อมูลโมเดล | |
| print(f"โหลดโมเดลสำเร็จ: input_size={hyperparams['input_size']}, hidden_size={hyperparams['hidden_size']}, " | |
| f"num_layers={hyperparams['num_layers']}, output_size={hyperparams['output_size']}") | |
| return model, hyperparams | |
| except Exception as e: | |
| print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}") | |
| return None, None | |
| def save_model_info(model, hyperparams, file_path): | |
| """ | |
| บันทึกข้อมูลโมเดลเป็นไฟล์ JSON | |
| """ | |
| try: | |
| model_info = { | |
| "hyperparameters": hyperparams, | |
| "structure": { | |
| "type": "GRU", | |
| "layers": [] | |
| } | |
| } | |
| # เพิ่มข้อมูลเกี่ยวกับชั้นของโมเดล | |
| model_info["structure"]["layers"].append({ | |
| "name": "GRU", | |
| "input_size": hyperparams["input_size"], | |
| "hidden_size": hyperparams["hidden_size"], | |
| "num_layers": hyperparams["num_layers"], | |
| "dropout_rate": hyperparams["dropout_rate"] | |
| }) | |
| model_info["structure"]["layers"].append({ | |
| "name": "Dropout", | |
| "rate": hyperparams["dropout_rate"] | |
| }) | |
| model_info["structure"]["layers"].append({ | |
| "name": "Linear", | |
| "in_features": hyperparams["hidden_size"], | |
| "out_features": hyperparams["output_size"] | |
| }) | |
| # บันทึกเป็นไฟล์ JSON | |
| with open(file_path, 'w') as f: | |
| json.dump(model_info, f, indent=4) | |
| return True | |
| except Exception as e: | |
| print(f"เกิดข้อผิดพลาดในการบันทึกข้อมูลโมเดล: {str(e)}") | |
| return False |