Chayanat commited on
Commit
a52baac
·
verified ·
1 Parent(s): d017c89

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +117 -54
model.py CHANGED
@@ -1,5 +1,7 @@
1
  import torch
2
  import torch.nn as nn
 
 
3
 
4
  class GRUModel(nn.Module):
5
  def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate):
@@ -40,39 +42,90 @@ class GRUModel(nn.Module):
40
 
41
  return out
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def load_model(model_path, device='cpu'):
44
  """
45
- โหลดโมเดล GRU จากไฟล์ .pth ทั้งแบบที่มีแค่ state_dict และแบบที่มีข้อมูล hyperparameters
 
 
 
 
 
 
 
46
  """
47
  try:
48
  # โหลดไฟล์โมเดล
49
  checkpoint = torch.load(model_path, map_location=device)
50
 
51
  # ตรวจสอบโครงสร้างของไฟล์โมเดล
52
- if isinstance(checkpoint, dict):
53
- # ตรวจสอบว่ามี model_state_dict หรือไม่
54
- if 'model_state_dict' in checkpoint:
55
- model_state = checkpoint['model_state_dict']
56
- # ตรวจสอบว่ามี hyperparameters หรือไม่
57
- if 'hyperparameters' in checkpoint:
58
- hyperparams = checkpoint['hyperparameters']
59
- else:
60
- # กรณีไม่มี hyperparameters แต่มี model_state_dict
61
- print("ไม่พบ hyperparameters ในไฟล์โมเดล จะใช้ค่าที่ดึงจาก state_dict แทน")
62
- hyperparams = extract_hyperparams_from_state_dict(model_state)
63
- elif 'hyperparameters' in checkpoint:
64
- # กรณีมี hyperparameters แต่ไม่มี model_state_dict
65
- model_state = checkpoint
66
  hyperparams = checkpoint['hyperparameters']
67
- else:
68
- # กรณีเป็น state_dict เปลาๆ
69
- model_state = checkpoint
70
- hyperparams = extract_hyperparams_from_state_dict(model_state)
 
 
 
 
71
  else:
72
- # กรณีไฟล์ไม่ได้เป็น dict
73
  model_state = checkpoint
 
 
 
74
  hyperparams = extract_hyperparams_from_state_dict(model_state)
75
-
 
 
 
 
 
 
 
 
 
 
76
  # สร้างโมเดล
77
  model = GRUModel(
78
  input_size=hyperparams['input_size'],
@@ -88,44 +141,54 @@ def load_model(model_path, device='cpu'):
88
  # ตั้งค่าโมเดลให้อยู่ในโหมดทำนาย
89
  model.eval()
90
 
 
 
 
 
91
  return model, hyperparams
92
 
93
  except Exception as e:
94
  print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
95
- # ในกรณีที่มีปัญหา ให้ส่งค่า None กลับไป
96
  return None, None
97
 
98
- def extract_hyperparams_from_state_dict(state_dict):
99
  """
100
- วิเคราะห์ค่า hyperparameters จา state_dict ขอโมเดล
101
  """
102
- hyperparams = {
103
- 'input_size': None,
104
- 'hidden_size': None,
105
- 'num_layers': 1, # ค่าเริ่มต้น
106
- 'output_size': None,
107
- 'dropout_rate': 0.0 # ค่าเริ่มต้น
108
- }
109
-
110
- # ตรวจหาค่า hidden_size จาก weight ของ GRU
111
- if 'gru.weight_ih_l0' in state_dict:
112
- # รูปแบบของ weight_ih_l0 คือ [3*hidden_size, input_size]
113
- weight_shape = state_dict['gru.weight_ih_l0'].shape
114
- hyperparams['hidden_size'] = weight_shape[0] // 3
115
- hyperparams['input_size'] = weight_shape[1]
116
-
117
- # ตรวจหาค่า output_size จาก weight ของ fully connected layer
118
- if 'fc.weight' in state_dict:
119
- # รูปแบบของ fc.weight คือ [output_size, hidden_size]
120
- fc_shape = state_dict['fc.weight'].shape
121
- hyperparams['output_size'] = fc_shape[0]
122
-
123
- # นับจำนวนชั้นของ GRU จากชื่อของ parameter
124
- layer_num = 0
125
- while f'gru.weight_ih_l{layer_num}' in state_dict:
126
- layer_num += 1
127
- hyperparams['num_layers'] = layer_num
128
-
129
- print(f"สกัดค่า hyperparameters จาก state_dict: {hyperparams}")
130
- return hyperparams
131
-
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import numpy as np
4
+ import json
5
 
6
  class GRUModel(nn.Module):
7
  def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate):
 
42
 
43
  return out
44
 
45
+ def extract_hyperparams_from_state_dict(state_dict):
46
+ """
47
+ วิเคราะห์ค่า hyperparameters จาก state_dict ของโมเดล
48
+ """
49
+ hyperparams = {
50
+ 'input_size': None,
51
+ 'hidden_size': None,
52
+ 'num_layers': 1, # ค่าเริ่มต้น
53
+ 'output_size': None,
54
+ 'dropout_rate': 0.1 # ค่าเริ่มต้น
55
+ }
56
+
57
+ # ตรวจหาค่า hidden_size จาก weight ของ GRU
58
+ if 'gru.weight_ih_l0' in state_dict:
59
+ # รูปแบบของ weight_ih_l0 คือ [3*hidden_size, input_size]
60
+ weight_shape = state_dict['gru.weight_ih_l0'].shape
61
+ hyperparams['hidden_size'] = weight_shape[0] // 3
62
+ hyperparams['input_size'] = weight_shape[1]
63
+
64
+ # ตรวจหาค่า output_size จาก weight ของ fully connected layer
65
+ if 'fc.weight' in state_dict:
66
+ # รูปแบบของ fc.weight คือ [output_size, hidden_size]
67
+ fc_shape = state_dict['fc.weight'].shape
68
+ hyperparams['output_size'] = fc_shape[0]
69
+
70
+ # นับจำนวนชั้นของ GRU จากชื่อของ parameter
71
+ layer_num = 0
72
+ while f'gru.weight_ih_l{layer_num}' in state_dict:
73
+ layer_num += 1
74
+ hyperparams['num_layers'] = layer_num
75
+
76
+ print(f"สกัดค่า hyperparameters จาก state_dict: {hyperparams}")
77
+ return hyperparams
78
+
79
  def load_model(model_path, device='cpu'):
80
  """
81
+ โหลดโมเดล GRU จากไฟล์ .pth
82
+
83
+ Args:
84
+ model_path (str): พาธไปยังไฟล์โมเดล
85
+ device (str): อุปกรณ์ที่ใช้ ('cpu' หรือ 'cuda')
86
+
87
+ Returns:
88
+ tuple: (model, hyperparams) - โมเดลและพารามิเตอร์ของโมเดล
89
  """
90
  try:
91
  # โหลดไฟล์โมเดล
92
  checkpoint = torch.load(model_path, map_location=device)
93
 
94
  # ตรวจสอบโครงสร้างของไฟล์โมเดล
95
+ model_state = None
96
+ hyperparams = None
97
+
98
+ # กรณีที่ 1: ไฟล์เป็น dictionary และมี model_state_dict
99
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
100
+ model_state = checkpoint['model_state_dict']
101
+ # ดึง hyperparameters ถ้ามี
102
+ if 'hyperparameters' in checkpoint:
 
 
 
 
 
 
103
  hyperparams = checkpoint['hyperparameters']
104
+
105
+ # กรณีที่ 2: ไฟล์เป็น dictionary แตไม่มี model_state_dict
106
+ elif isinstance(checkpoint, dict) and 'hyperparameters' in checkpoint:
107
+ # สมมติว่า state_dict อยู่ในระดับบนสุด
108
+ model_state = {k: v for k, v in checkpoint.items() if k != 'hyperparameters'}
109
+ hyperparams = checkpoint['hyperparameters']
110
+
111
+ # กรณีที่ 3: ไฟล์เป็น state_dict โดยตรง
112
  else:
 
113
  model_state = checkpoint
114
+
115
+ # ถ้าไม่มี hyperparams ใ��้สกัดจาก state_dict
116
+ if hyperparams is None and model_state is not None:
117
  hyperparams = extract_hyperparams_from_state_dict(model_state)
118
+
119
+ # ตรวจสอบว่ามี hyperparams ครบหรือไม่
120
+ required_params = ['input_size', 'hidden_size', 'output_size', 'num_layers', 'dropout_rate']
121
+ if not all(param in hyperparams for param in required_params):
122
+ print(f"Warning: ไม่พบ hyperparameters บางตัว จะใช้ค่าเริ่มต้น")
123
+ # กำหนดค่าเริ่มต้นสำหรับพารามิเตอร์ที่ขาดหายไป
124
+ defaults = {'input_size': 10, 'hidden_size': 64, 'output_size': 1, 'num_layers': 2, 'dropout_rate': 0.1}
125
+ for param in required_params:
126
+ if param not in hyperparams:
127
+ hyperparams[param] = defaults[param]
128
+
129
  # สร้างโมเดล
130
  model = GRUModel(
131
  input_size=hyperparams['input_size'],
 
141
  # ตั้งค่าโมเดลให้อยู่ในโหมดทำนาย
142
  model.eval()
143
 
144
+ # แสดงข้อมูลโมเดล
145
+ print(f"โหลดโมเดลสำเร็จ: input_size={hyperparams['input_size']}, hidden_size={hyperparams['hidden_size']}, "
146
+ f"num_layers={hyperparams['num_layers']}, output_size={hyperparams['output_size']}")
147
+
148
  return model, hyperparams
149
 
150
  except Exception as e:
151
  print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
 
152
  return None, None
153
 
154
+ def save_model_info(model, hyperparams, file_path):
155
  """
156
+ บันทึกขมูลโมเดลเป็นไฟล์ JSON
157
  """
158
+ try:
159
+ model_info = {
160
+ "hyperparameters": hyperparams,
161
+ "structure": {
162
+ "type": "GRU",
163
+ "layers": []
164
+ }
165
+ }
166
+
167
+ # เพิ่มข้อมูลเกี่ยวกับชั้นของโมเดล
168
+ model_info["structure"]["layers"].append({
169
+ "name": "GRU",
170
+ "input_size": hyperparams["input_size"],
171
+ "hidden_size": hyperparams["hidden_size"],
172
+ "num_layers": hyperparams["num_layers"],
173
+ "dropout_rate": hyperparams["dropout_rate"]
174
+ })
175
+
176
+ model_info["structure"]["layers"].append({
177
+ "name": "Dropout",
178
+ "rate": hyperparams["dropout_rate"]
179
+ })
180
+
181
+ model_info["structure"]["layers"].append({
182
+ "name": "Linear",
183
+ "in_features": hyperparams["hidden_size"],
184
+ "out_features": hyperparams["output_size"]
185
+ })
186
+
187
+ # บันทึกเป็นไฟล์ JSON
188
+ with open(file_path, 'w') as f:
189
+ json.dump(model_info, f, indent=4)
190
+
191
+ return True
192
+ except Exception as e:
193
+ print(f"เกิดข้อผิดพลาดในการบันทึกข้อมูลโมเดล: {str(e)}")
194
+ return False