eslamESssamM commited on
Commit
a52ace7
·
verified ·
1 Parent(s): db1eaed

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +405 -404
main.py CHANGED
@@ -1,405 +1,406 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import joblib
5
- import random
6
- import os
7
- from fastapi import FastAPI
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import BaseModel
10
- from contextlib import asynccontextmanager
11
-
12
- # ==========================================
13
- # 1. CORE COMPONENTS (SYNTAX-VALIDATED)
14
- # ==========================================
15
- class Mish(nn.Module):
16
- def forward(self, x):
17
- return x * torch.tanh(nn.functional.softplus(x))
18
-
19
- class FourierFeatureMapping(nn.Module):
20
- def __init__(self, input_dim, mapping_size, scale=10.0):
21
- super().__init__()
22
- self.register_buffer('B', torch.randn(input_dim, mapping_size) * scale)
23
-
24
- def forward(self, x):
25
- proj = 2 * np.pi * (x @ self.B)
26
- return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
-
28
- # ==========================================
29
- # 2. AUDIT-COMPLIANT ARCHITECTURES (EXACT TENSOR MATCH)
30
- # ==========================================
31
- class SolarPINN(nn.Module):
32
- """Matches audit: backbone.0/2 + output_layer + physics params (shape [])"""
33
- def __init__(self):
34
- super().__init__()
35
- self.backbone = nn.Sequential(
36
- nn.Linear(4, 128), Mish(),
37
- nn.Linear(128, 128), Mish()
38
- )
39
- self.output_layer = nn.Linear(128, 1)
40
- # Physics parameters required by state_dict (shape [])
41
- self.log_thermal_mass = nn.Parameter(torch.tensor(0.0))
42
- self.log_h_conv = nn.Parameter(torch.tensor(0.0))
43
-
44
- def forward(self, x):
45
- return self.output_layer(self.backbone(x))
46
-
47
- class LoadForecastPINN(nn.Module):
48
- """Matches audit: res_blocks with LayerNorm weights at .1 (shape [128])"""
49
- def __init__(self):
50
- super().__init__() self.fourier = FourierFeatureMapping(9, 32)
51
- self.input_layer = nn.Linear(64, 128)
52
- self.res_blocks = nn.ModuleList([
53
- nn.Sequential(
54
- nn.Linear(128, 128),
55
- nn.LayerNorm(128), # Critical: Audit shows LayerNorm params
56
- Mish(),
57
- nn.Linear(128, 128)
58
- ) for _ in range(3)
59
- ])
60
- self.output_layer = nn.Linear(128, 1)
61
-
62
- def forward(self, x):
63
- x = self.input_layer(self.fourier(x))
64
- for block in self.res_blocks:
65
- x = x + block(x) # True residual connection per audit
66
- return self.output_layer(x)
67
-
68
- class VoltagePINN(nn.Module):
69
- """Matches audit: network layers + v_bias([1]) + raw_B([])"""
70
- def __init__(self):
71
- super().__init__()
72
- self.fourier = FourierFeatureMapping(7, 32)
73
- self.network = nn.Sequential(
74
- nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
75
- nn.Linear(256, 128), nn.LayerNorm(128), Mish(),
76
- nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
77
- nn.Linear(64, 2)
78
- )
79
- # Audit-required parameters
80
- self.v_bias = nn.Parameter(torch.zeros(1)) # Shape [1]
81
- self.raw_B = nn.Parameter(torch.tensor(0.0)) # Shape []
82
-
83
- def forward(self, x):
84
- return self.network(self.fourier(x))
85
-
86
- class BatteryPINN(nn.Module):
87
- """Matches audit: network.0/2/4 indexing"""
88
- def __init__(self):
89
- super().__init__()
90
- self.fourier = FourierFeatureMapping(5, 12)
91
- self.network = nn.Sequential(
92
- nn.Linear(24, 64), Mish(),
93
- nn.Linear(64, 64), Mish(),
94
- nn.Linear(64, 3)
95
- )
96
-
97
- def forward(self, x):
98
- return self.network(self.fourier(x))
99
- class FrequencyPINN(nn.Module):
100
- """Matches audit: net.0/2/4/6 (NO LayerNorm - pure Linear+Mish)"""
101
- def __init__(self):
102
- super().__init__()
103
- self.fourier = FourierFeatureMapping(4, 32)
104
- self.net = nn.Sequential(
105
- nn.Linear(64, 128), Mish(), # net.0
106
- nn.Linear(128, 128), Mish(), # net.2
107
- nn.Linear(128, 128), Mish(), # net.4
108
- nn.Linear(128, 2) # net.6
109
- )
110
-
111
- def forward(self, x):
112
- return self.net(self.fourier(x))
113
-
114
- # ==========================================
115
- # 3. LIFESPAN: ORIGINAL KEYS + SCALER SAFETY
116
- # ==========================================
117
- ml_assets = {}
118
-
119
- @asynccontextmanager
120
- async def lifespan(app: FastAPI):
121
- try:
122
- # SOLAR MODEL (Key: "solar_model" per initial code)
123
- if os.path.exists("solar_model.pt"):
124
- ckpt = torch.load("solar_model.pt", map_location='cpu')
125
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
126
- model = SolarPINN()
127
- model.load_state_dict(sd, strict=True)
128
- ml_assets["solar_model"] = model.eval()
129
- ml_assets["solar_stats"] = {
130
- "irr_mean": 450.0, "irr_std": 250.0,
131
- "temp_mean": 25.0, "temp_std": 10.0,
132
- "prev_mean": 35.0, "prev_std": 15.0
133
- }
134
-
135
- # LOAD MODEL (Key: "l_model")
136
- if os.path.exists("load_model.pt"):
137
- ckpt = torch.load("load_model.pt", map_location='cpu')
138
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
139
- model = LoadForecastPINN()
140
- model.load_state_dict(sd, strict=True)
141
- ml_assets["l_model"] = model.eval()
142
- if os.path.exists("Load_stats.joblib"):
143
- ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
144
-
145
- # VOLTAGE MODEL (Key: "v_model")
146
- if os.path.exists("voltage_model_v3.pt"):
147
- ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
148
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt model = VoltagePINN()
149
- model.load_state_dict(sd, strict=True)
150
- ml_assets["v_model"] = model.eval()
151
- if os.path.exists("scaling_stats_v3.joblib"):
152
- ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
153
-
154
- # BATTERY MODEL (Key: "b_model")
155
- if os.path.exists("battery_model.pt"):
156
- ckpt = torch.load("battery_model.pt", map_location='cpu')
157
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
158
- model = BatteryPINN()
159
- model.load_state_dict(sd, strict=True)
160
- ml_assets["b_model"] = model.eval()
161
- if os.path.exists("battery_model.joblib"):
162
- ml_assets["b_stats"] = joblib.load("battery_model.joblib")
163
-
164
- # FREQUENCY MODEL (Key: "f_model" + SCALER SAFETY)
165
- if os.path.exists("DECODE_Frequency_Twin.pth"):
166
- ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
167
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
168
- model = FrequencyPINN()
169
- model.load_state_dict(sd, strict=True)
170
- ml_assets["f_model"] = model.eval()
171
- # CRITICAL: Load actual MinMaxScaler per audit metadata
172
- if os.path.exists("decode_scaler.joblib"):
173
- try:
174
- ml_assets["f_scaler"] = joblib.load("decode_scaler.joblib")
175
- except:
176
- ml_assets["f_scaler"] = None
177
- else:
178
- ml_assets["f_scaler"] = None
179
-
180
- yield
181
- finally:
182
- ml_assets.clear()
183
-
184
- # ==========================================
185
- # 4. FASTAPI SETUP
186
- # ==========================================
187
- app = FastAPI(title="D.E.C.O.D.E. Unified Digital Twin", lifespan=lifespan)
188
- app.add_middleware(
189
- CORSMiddleware,
190
- allow_origins=["*"],
191
- allow_methods=["*"],
192
- allow_headers=["*"],
193
- )
194
-
195
- # ==========================================
196
- # 5. PHYSICS & SCHEMAS (SYNTAX-CORRECTED)
197
- # ==========================================def get_ocv_soc(voltage: float) -> float:
198
- """Physics-based SOC estimation from OCV"""
199
- return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
200
-
201
- class SolarData(BaseModel):
202
- irradiance_stream: list[float]
203
- ambient_temp_stream: list[float]
204
- wind_speed_stream: list[float]
205
-
206
- class LoadData(BaseModel): # FIXED: Each field on separate line
207
- temperature_c: float
208
- hour: int # Critical newline separation
209
- month: int # Critical newline separation
210
- wind_mw: float = 0.0
211
- solar_mw: float = 0.0
212
-
213
- class BatteryData(BaseModel):
214
- time_sec: float
215
- current: float
216
- voltage: float
217
- temperature: float
218
- soc_prev: float
219
-
220
- class FreqData(BaseModel):
221
- load_mw: float
222
- wind_mw: float
223
- inertia_h: float
224
- power_imbalance_mw: float
225
-
226
- class GridData(BaseModel):
227
- p_load: float
228
- q_load: float
229
- wind_gen: float
230
- solar_gen: float
231
- hour: int
232
-
233
- # ==========================================
234
- # 6. ENDPOINTS: FALLBACKS + PHYSICS COMPLIANCE
235
- # ==========================================
236
- @app.get("/")
237
- def home():
238
- return {
239
- "status": "Online",
240
- "modules": ["Voltage", "Battery", "Frequency", "Load", "Solar"],
241
- "audit_compliant": True,
242
- "strict_loading": True
243
- }
244
-
245
- @app.post("/predict/solar")
246
- def predict_solar(data: SolarData): # CORRECT PARAMETER NAME """Sequential state simulation @ dt=900s with thermal clamping"""
247
- simulation = []
248
- # Fallback: Return empty simulation if model missing (per initial code)
249
- if "solar_model" in ml_assets and "solar_stats" in ml_assets:
250
- stats = ml_assets["solar_stats"]
251
- # PHYSICS CONSTRAINT: Initial state = ambient + 5.0°C (audit training protocol)
252
- curr_temp = data.ambient_temp_stream[0] + 5.0
253
-
254
- with torch.no_grad():
255
- for i in range(len(data.irradiance_stream)):
256
- # AUDIT CONSTRAINT: Wind scaled by 10.0 per training protocol
257
- x = torch.tensor([[
258
- (data.irradiance_stream[i] - stats["irr_mean"]) / stats["irr_std"],
259
- (data.ambient_temp_stream[i] - stats["temp_mean"]) / stats["temp_std"],
260
- data.wind_speed_stream[i] / 10.0, # Critical scaling per audit
261
- (curr_temp - stats["prev_mean"]) / stats["prev_std"]
262
- ]], dtype=torch.float32)
263
-
264
- # PHYSICAL CLAMPING: Prevent thermal runaway (10°C-75°C)
265
- next_temp = ml_assets["solar_model"](x).item()
266
- next_temp = max(10.0, min(75.0, next_temp))
267
-
268
- # Temperature-dependent efficiency
269
- eff = 0.20 * (1 - 0.004 * (next_temp - 25.0))
270
- power_mw = (5000 * data.irradiance_stream[i] * max(0, eff)) / 1e6
271
-
272
- simulation.append({
273
- "module_temp_c": round(next_temp, 2),
274
- "power_mw": round(power_mw, 4)
275
- })
276
- curr_temp = next_temp # SEQUENTIAL STATE FEEDBACK (dt=900s)
277
- return {"simulation": simulation}
278
-
279
- @app.post("/predict/load")
280
- def predict_load(data: LoadData): # CORRECT PARAMETER NAME
281
- """Z-score clamped prediction to prevent Inverted Load Paradox"""
282
- stats = ml_assets.get("l_stats", {})
283
- # PHYSICS CONSTRAINT: Hard Z-score clamping at ±3 (Fourier stability)
284
- t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
285
- t_norm = max(-3.0, min(3.0, t_norm))
286
-
287
- # Construct features per audit metadata order
288
- x = torch.tensor([[
289
- t_norm,
290
- max(0, data.temperature_c - 18) / 10,
291
- max(0, 18 - data.temperature_c) / 10,
292
- np.sin(2 * np.pi * data.hour / 24),
293
- np.cos(2 * np.pi * data.hour / 24),
294
- np.sin(2 * np.pi * data.month / 12),
295
- np.cos(2 * np.pi * data.month / 12), data.wind_mw / 10000,
296
- data.solar_mw / 10000
297
- ]], dtype=torch.float32)
298
-
299
- # Fallback base load if model/stats missing
300
- base_load = stats.get('load_mean', 35000.0)
301
- if "l_model" in ml_assets:
302
- with torch.no_grad():
303
- pred = ml_assets["l_model"](x).item()
304
- load_mw = pred * stats.get('load_std', 9773.80) + base_load
305
- else:
306
- load_mw = base_load
307
-
308
- # PHYSICAL SAFETY CORRECTION (SYNTAX FIXED)
309
- if data.temperature_c > 32:
310
- load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
311
- elif data.temperature_c < 5:
312
- load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900) # Fixed parenthesis
313
-
314
- status = "Peak" if load_mw > 58000 else "Normal"
315
- return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
316
-
317
- @app.post("/predict/battery")
318
- def predict_battery(data: BatteryData): # CORRECT PARAMETER NAME
319
- """Feature engineering: Power product (V*I) required per audit"""
320
- # Physics-based SOC fallback
321
- soc = get_ocv_soc(data.voltage)
322
- temp_c = 25.0 # Fallback temperature if model missing
323
-
324
- if "b_model" in ml_assets and "b_stats" in ml_assets:
325
- stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
326
- # AUDIT CONSTRAINT: Power product feature engineering
327
- power_product = data.voltage * data.current
328
- features = np.array([
329
- data.time_sec,
330
- data.current,
331
- data.voltage,
332
- power_product, # Critical engineered feature
333
- data.soc_prev
334
- ])
335
-
336
- x_scaled = (features - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
337
- with torch.no_grad():
338
- preds = ml_assets["b_model"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
339
- # Only temperature prediction used (index 1 per audit target order)
340
- temp_c = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
341
-
342
- status = "Normal" if temp_c < 45 else "Overheating"
343
- return {
344
- "soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2),
345
- "status": status
346
- }
347
-
348
- @app.post("/predict/frequency")
349
- def predict_frequency(data: FreqData): # CORRECT PARAMETER NAME
350
- """Hybrid physics + AI with MinMaxScaler compliance"""
351
- # Physics calculation (always available)
352
- f_nom = 60.0
353
- H = max(1.0, data.inertia_h)
354
- rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
355
- f_phys = f_nom + (rocof * 2.0)
356
-
357
- # AI prediction ONLY if scaler available (audit requires MinMaxScaler)
358
- f_ai = 60.0
359
- if "f_model" in ml_assets and "f_scaler" in ml_assets and ml_assets["f_scaler"] is not None:
360
- try:
361
- # AUDIT CONSTRAINT: Use actual MinMaxScaler transform
362
- x = np.array([[data.load_mw, data.wind_mw, data.load_mw - data.wind_mw, data.power_imbalance_mw]])
363
- x_scaled = ml_assets["f_scaler"].transform(x)
364
- with torch.no_grad():
365
- pred = ml_assets["f_model"](torch.tensor(x_scaled, dtype=torch.float32)).numpy()[0]
366
- f_ai = 60.0 + pred[0] * 0.5
367
- except:
368
- f_ai = 60.0 # Fallback on scaler error
369
-
370
- # Physics-weighted fusion with hard limits
371
- final_freq = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
372
- status = "Stable" if final_freq > 59.6 else "Critical"
373
- return {
374
- "frequency_hz": round(float(final_freq), 4),
375
- "status": status
376
- }
377
-
378
- @app.post("/predict/voltage")
379
- def predict_voltage(data: GridData): # CORRECT PARAMETER NAME
380
- """Model usage with fallback heuristic"""
381
- # Use AI model if artifacts available
382
- if "v_model" in ml_assets and "v_stats" in ml_assets:
383
- stats = ml_assets["v_stats"]
384
- # Construct 7 features per audit input_features order
385
- x_raw = np.array([
386
- data.p_load,
387
- data.q_load,
388
- data.wind_gen,
389
- data.solar_gen,
390
- data.hour,
391
- data.p_load - (data.wind_gen + data.solar_gen), # net load
392
- 0.0 # placeholder for 7th feature (audit shows 7 inputs)
393
- ]) # Z-score scaling per audit metadata
394
- x_norm = (x_raw - stats['x_mean']) / (stats['x_std'] + 1e-6)
395
- with torch.no_grad():
396
- pred = ml_assets["v_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
397
- # Denormalize per audit y_mean/y_std
398
- v_mag = pred[0] * stats['y_std'][0] + stats['y_mean'][0]
399
- else:
400
- # Fallback heuristic (original code)
401
- net_load = data.p_load - (data.wind_gen + data.solar_gen)
402
- v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
403
-
404
- status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
 
405
  return {"voltage_pu": round(v_mag, 4), "status": status}
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import joblib
5
+ import random
6
+ import os
7
+ from fastapi import FastAPI
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ from contextlib import asynccontextmanager
11
+
12
+ # ==========================================
13
+ # 1. CORE COMPONENTS (SYNTAX-VALIDATED)
14
+ # ==========================================
15
+ class Mish(nn.Module):
16
+ def forward(self, x):
17
+ return x * torch.tanh(nn.functional.softplus(x))
18
+
19
+ class FourierFeatureMapping(nn.Module):
20
+ def __init__(self, input_dim, mapping_size, scale=10.0):
21
+ super().__init__()
22
+ self.register_buffer('B', torch.randn(input_dim, mapping_size) * scale)
23
+
24
+ def forward(self, x):
25
+ proj = 2 * np.pi * (x @ self.B)
26
+ return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
+
28
+ # ==========================================
29
+ # 2. AUDIT-COMPLIANT ARCHITECTURES (EXACT TENSOR MATCH)
30
+ # ==========================================
31
+ class SolarPINN(nn.Module):
32
+ """Matches audit: backbone.0/2 + output_layer + physics params (shape [])"""
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.backbone = nn.Sequential(
36
+ nn.Linear(4, 128), Mish(),
37
+ nn.Linear(128, 128), Mish()
38
+ )
39
+ self.output_layer = nn.Linear(128, 1)
40
+ # Physics parameters required by state_dict (shape [])
41
+ self.log_thermal_mass = nn.Parameter(torch.tensor(0.0))
42
+ self.log_h_conv = nn.Parameter(torch.tensor(0.0))
43
+
44
+ def forward(self, x):
45
+ return self.output_layer(self.backbone(x))
46
+
47
+ class LoadForecastPINN(nn.Module):
48
+ """Matches audit: res_blocks with LayerNorm weights at .1 (shape [128])"""
49
+ def __init__(self):
50
+ super().__init__()
51
+ self.fourier = FourierFeatureMapping(9, 32)
52
+ self.input_layer = nn.Linear(64, 128)
53
+ self.res_blocks = nn.ModuleList([
54
+ nn.Sequential(
55
+ nn.Linear(128, 128),
56
+ nn.LayerNorm(128), # Critical: Audit shows LayerNorm params
57
+ Mish(),
58
+ nn.Linear(128, 128)
59
+ ) for _ in range(3)
60
+ ])
61
+ self.output_layer = nn.Linear(128, 1)
62
+
63
+ def forward(self, x):
64
+ x = self.input_layer(self.fourier(x))
65
+ for block in self.res_blocks:
66
+ x = x + block(x) # True residual connection per audit
67
+ return self.output_layer(x)
68
+
69
+ class VoltagePINN(nn.Module):
70
+ """Matches audit: network layers + v_bias([1]) + raw_B([])"""
71
+ def __init__(self):
72
+ super().__init__()
73
+ self.fourier = FourierFeatureMapping(7, 32)
74
+ self.network = nn.Sequential(
75
+ nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
76
+ nn.Linear(256, 128), nn.LayerNorm(128), Mish(),
77
+ nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
78
+ nn.Linear(64, 2)
79
+ )
80
+ # Audit-required parameters
81
+ self.v_bias = nn.Parameter(torch.zeros(1)) # Shape [1]
82
+ self.raw_B = nn.Parameter(torch.tensor(0.0)) # Shape []
83
+
84
+ def forward(self, x):
85
+ return self.network(self.fourier(x))
86
+
87
+ class BatteryPINN(nn.Module):
88
+ """Matches audit: network.0/2/4 indexing"""
89
+ def __init__(self):
90
+ super().__init__()
91
+ self.fourier = FourierFeatureMapping(5, 12)
92
+ self.network = nn.Sequential(
93
+ nn.Linear(24, 64), Mish(),
94
+ nn.Linear(64, 64), Mish(),
95
+ nn.Linear(64, 3)
96
+ )
97
+
98
+ def forward(self, x):
99
+ return self.network(self.fourier(x))
100
+ class FrequencyPINN(nn.Module):
101
+ """Matches audit: net.0/2/4/6 (NO LayerNorm - pure Linear+Mish)"""
102
+ def __init__(self):
103
+ super().__init__()
104
+ self.fourier = FourierFeatureMapping(4, 32)
105
+ self.net = nn.Sequential(
106
+ nn.Linear(64, 128), Mish(), # net.0
107
+ nn.Linear(128, 128), Mish(), # net.2
108
+ nn.Linear(128, 128), Mish(), # net.4
109
+ nn.Linear(128, 2) # net.6
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(self.fourier(x))
114
+
115
+ # ==========================================
116
+ # 3. LIFESPAN: ORIGINAL KEYS + SCALER SAFETY
117
+ # ==========================================
118
+ ml_assets = {}
119
+
120
+ @asynccontextmanager
121
+ async def lifespan(app: FastAPI):
122
+ try:
123
+ # SOLAR MODEL (Key: "solar_model" per initial code)
124
+ if os.path.exists("solar_model.pt"):
125
+ ckpt = torch.load("solar_model.pt", map_location='cpu')
126
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
127
+ model = SolarPINN()
128
+ model.load_state_dict(sd, strict=True)
129
+ ml_assets["solar_model"] = model.eval()
130
+ ml_assets["solar_stats"] = {
131
+ "irr_mean": 450.0, "irr_std": 250.0,
132
+ "temp_mean": 25.0, "temp_std": 10.0,
133
+ "prev_mean": 35.0, "prev_std": 15.0
134
+ }
135
+
136
+ # LOAD MODEL (Key: "l_model")
137
+ if os.path.exists("load_model.pt"):
138
+ ckpt = torch.load("load_model.pt", map_location='cpu')
139
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
140
+ model = LoadForecastPINN()
141
+ model.load_state_dict(sd, strict=True)
142
+ ml_assets["l_model"] = model.eval()
143
+ if os.path.exists("Load_stats.joblib"):
144
+ ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
145
+
146
+ # VOLTAGE MODEL (Key: "v_model")
147
+ if os.path.exists("voltage_model_v3.pt"):
148
+ ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
149
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt model = VoltagePINN()
150
+ model.load_state_dict(sd, strict=True)
151
+ ml_assets["v_model"] = model.eval()
152
+ if os.path.exists("scaling_stats_v3.joblib"):
153
+ ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
154
+
155
+ # BATTERY MODEL (Key: "b_model")
156
+ if os.path.exists("battery_model.pt"):
157
+ ckpt = torch.load("battery_model.pt", map_location='cpu')
158
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
159
+ model = BatteryPINN()
160
+ model.load_state_dict(sd, strict=True)
161
+ ml_assets["b_model"] = model.eval()
162
+ if os.path.exists("battery_model.joblib"):
163
+ ml_assets["b_stats"] = joblib.load("battery_model.joblib")
164
+
165
+ # FREQUENCY MODEL (Key: "f_model" + SCALER SAFETY)
166
+ if os.path.exists("DECODE_Frequency_Twin.pth"):
167
+ ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
168
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
169
+ model = FrequencyPINN()
170
+ model.load_state_dict(sd, strict=True)
171
+ ml_assets["f_model"] = model.eval()
172
+ # CRITICAL: Load actual MinMaxScaler per audit metadata
173
+ if os.path.exists("decode_scaler.joblib"):
174
+ try:
175
+ ml_assets["f_scaler"] = joblib.load("decode_scaler.joblib")
176
+ except:
177
+ ml_assets["f_scaler"] = None
178
+ else:
179
+ ml_assets["f_scaler"] = None
180
+
181
+ yield
182
+ finally:
183
+ ml_assets.clear()
184
+
185
+ # ==========================================
186
+ # 4. FASTAPI SETUP
187
+ # ==========================================
188
+ app = FastAPI(title="D.E.C.O.D.E. Unified Digital Twin", lifespan=lifespan)
189
+ app.add_middleware(
190
+ CORSMiddleware,
191
+ allow_origins=["*"],
192
+ allow_methods=["*"],
193
+ allow_headers=["*"],
194
+ )
195
+
196
+ # ==========================================
197
+ # 5. PHYSICS & SCHEMAS (SYNTAX-CORRECTED)
198
+ # ==========================================def get_ocv_soc(voltage: float) -> float:
199
+ """Physics-based SOC estimation from OCV"""
200
+ return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
201
+
202
+ class SolarData(BaseModel):
203
+ irradiance_stream: list[float]
204
+ ambient_temp_stream: list[float]
205
+ wind_speed_stream: list[float]
206
+
207
+ class LoadData(BaseModel): # FIXED: Each field on separate line
208
+ temperature_c: float
209
+ hour: int # Critical newline separation
210
+ month: int # Critical newline separation
211
+ wind_mw: float = 0.0
212
+ solar_mw: float = 0.0
213
+
214
+ class BatteryData(BaseModel):
215
+ time_sec: float
216
+ current: float
217
+ voltage: float
218
+ temperature: float
219
+ soc_prev: float
220
+
221
+ class FreqData(BaseModel):
222
+ load_mw: float
223
+ wind_mw: float
224
+ inertia_h: float
225
+ power_imbalance_mw: float
226
+
227
+ class GridData(BaseModel):
228
+ p_load: float
229
+ q_load: float
230
+ wind_gen: float
231
+ solar_gen: float
232
+ hour: int
233
+
234
+ # ==========================================
235
+ # 6. ENDPOINTS: FALLBACKS + PHYSICS COMPLIANCE
236
+ # ==========================================
237
+ @app.get("/")
238
+ def home():
239
+ return {
240
+ "status": "Online",
241
+ "modules": ["Voltage", "Battery", "Frequency", "Load", "Solar"],
242
+ "audit_compliant": True,
243
+ "strict_loading": True
244
+ }
245
+
246
+ @app.post("/predict/solar")
247
+ def predict_solar(data: SolarData): # CORRECT PARAMETER NAME """Sequential state simulation @ dt=900s with thermal clamping"""
248
+ simulation = []
249
+ # Fallback: Return empty simulation if model missing (per initial code)
250
+ if "solar_model" in ml_assets and "solar_stats" in ml_assets:
251
+ stats = ml_assets["solar_stats"]
252
+ # PHYSICS CONSTRAINT: Initial state = ambient + 5.0°C (audit training protocol)
253
+ curr_temp = data.ambient_temp_stream[0] + 5.0
254
+
255
+ with torch.no_grad():
256
+ for i in range(len(data.irradiance_stream)):
257
+ # AUDIT CONSTRAINT: Wind scaled by 10.0 per training protocol
258
+ x = torch.tensor([[
259
+ (data.irradiance_stream[i] - stats["irr_mean"]) / stats["irr_std"],
260
+ (data.ambient_temp_stream[i] - stats["temp_mean"]) / stats["temp_std"],
261
+ data.wind_speed_stream[i] / 10.0, # Critical scaling per audit
262
+ (curr_temp - stats["prev_mean"]) / stats["prev_std"]
263
+ ]], dtype=torch.float32)
264
+
265
+ # PHYSICAL CLAMPING: Prevent thermal runaway (10°C-75°C)
266
+ next_temp = ml_assets["solar_model"](x).item()
267
+ next_temp = max(10.0, min(75.0, next_temp))
268
+
269
+ # Temperature-dependent efficiency
270
+ eff = 0.20 * (1 - 0.004 * (next_temp - 25.0))
271
+ power_mw = (5000 * data.irradiance_stream[i] * max(0, eff)) / 1e6
272
+
273
+ simulation.append({
274
+ "module_temp_c": round(next_temp, 2),
275
+ "power_mw": round(power_mw, 4)
276
+ })
277
+ curr_temp = next_temp # SEQUENTIAL STATE FEEDBACK (dt=900s)
278
+ return {"simulation": simulation}
279
+
280
+ @app.post("/predict/load")
281
+ def predict_load(data: LoadData): # CORRECT PARAMETER NAME
282
+ """Z-score clamped prediction to prevent Inverted Load Paradox"""
283
+ stats = ml_assets.get("l_stats", {})
284
+ # PHYSICS CONSTRAINT: Hard Z-score clamping at ±3 (Fourier stability)
285
+ t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
286
+ t_norm = max(-3.0, min(3.0, t_norm))
287
+
288
+ # Construct features per audit metadata order
289
+ x = torch.tensor([[
290
+ t_norm,
291
+ max(0, data.temperature_c - 18) / 10,
292
+ max(0, 18 - data.temperature_c) / 10,
293
+ np.sin(2 * np.pi * data.hour / 24),
294
+ np.cos(2 * np.pi * data.hour / 24),
295
+ np.sin(2 * np.pi * data.month / 12),
296
+ np.cos(2 * np.pi * data.month / 12), data.wind_mw / 10000,
297
+ data.solar_mw / 10000
298
+ ]], dtype=torch.float32)
299
+
300
+ # Fallback base load if model/stats missing
301
+ base_load = stats.get('load_mean', 35000.0)
302
+ if "l_model" in ml_assets:
303
+ with torch.no_grad():
304
+ pred = ml_assets["l_model"](x).item()
305
+ load_mw = pred * stats.get('load_std', 9773.80) + base_load
306
+ else:
307
+ load_mw = base_load
308
+
309
+ # PHYSICAL SAFETY CORRECTION (SYNTAX FIXED)
310
+ if data.temperature_c > 32:
311
+ load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
312
+ elif data.temperature_c < 5:
313
+ load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900) # Fixed parenthesis
314
+
315
+ status = "Peak" if load_mw > 58000 else "Normal"
316
+ return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
317
+
318
+ @app.post("/predict/battery")
319
+ def predict_battery(data: BatteryData): # CORRECT PARAMETER NAME
320
+ """Feature engineering: Power product (V*I) required per audit"""
321
+ # Physics-based SOC fallback
322
+ soc = get_ocv_soc(data.voltage)
323
+ temp_c = 25.0 # Fallback temperature if model missing
324
+
325
+ if "b_model" in ml_assets and "b_stats" in ml_assets:
326
+ stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
327
+ # AUDIT CONSTRAINT: Power product feature engineering
328
+ power_product = data.voltage * data.current
329
+ features = np.array([
330
+ data.time_sec,
331
+ data.current,
332
+ data.voltage,
333
+ power_product, # Critical engineered feature
334
+ data.soc_prev
335
+ ])
336
+
337
+ x_scaled = (features - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
338
+ with torch.no_grad():
339
+ preds = ml_assets["b_model"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
340
+ # Only temperature prediction used (index 1 per audit target order)
341
+ temp_c = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
342
+
343
+ status = "Normal" if temp_c < 45 else "Overheating"
344
+ return {
345
+ "soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2),
346
+ "status": status
347
+ }
348
+
349
+ @app.post("/predict/frequency")
350
+ def predict_frequency(data: FreqData): # CORRECT PARAMETER NAME
351
+ """Hybrid physics + AI with MinMaxScaler compliance"""
352
+ # Physics calculation (always available)
353
+ f_nom = 60.0
354
+ H = max(1.0, data.inertia_h)
355
+ rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
356
+ f_phys = f_nom + (rocof * 2.0)
357
+
358
+ # AI prediction ONLY if scaler available (audit requires MinMaxScaler)
359
+ f_ai = 60.0
360
+ if "f_model" in ml_assets and "f_scaler" in ml_assets and ml_assets["f_scaler"] is not None:
361
+ try:
362
+ # AUDIT CONSTRAINT: Use actual MinMaxScaler transform
363
+ x = np.array([[data.load_mw, data.wind_mw, data.load_mw - data.wind_mw, data.power_imbalance_mw]])
364
+ x_scaled = ml_assets["f_scaler"].transform(x)
365
+ with torch.no_grad():
366
+ pred = ml_assets["f_model"](torch.tensor(x_scaled, dtype=torch.float32)).numpy()[0]
367
+ f_ai = 60.0 + pred[0] * 0.5
368
+ except:
369
+ f_ai = 60.0 # Fallback on scaler error
370
+
371
+ # Physics-weighted fusion with hard limits
372
+ final_freq = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
373
+ status = "Stable" if final_freq > 59.6 else "Critical"
374
+ return {
375
+ "frequency_hz": round(float(final_freq), 4),
376
+ "status": status
377
+ }
378
+
379
+ @app.post("/predict/voltage")
380
+ def predict_voltage(data: GridData): # CORRECT PARAMETER NAME
381
+ """Model usage with fallback heuristic"""
382
+ # Use AI model if artifacts available
383
+ if "v_model" in ml_assets and "v_stats" in ml_assets:
384
+ stats = ml_assets["v_stats"]
385
+ # Construct 7 features per audit input_features order
386
+ x_raw = np.array([
387
+ data.p_load,
388
+ data.q_load,
389
+ data.wind_gen,
390
+ data.solar_gen,
391
+ data.hour,
392
+ data.p_load - (data.wind_gen + data.solar_gen), # net load
393
+ 0.0 # placeholder for 7th feature (audit shows 7 inputs)
394
+ ]) # Z-score scaling per audit metadata
395
+ x_norm = (x_raw - stats['x_mean']) / (stats['x_std'] + 1e-6)
396
+ with torch.no_grad():
397
+ pred = ml_assets["v_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
398
+ # Denormalize per audit y_mean/y_std
399
+ v_mag = pred[0] * stats['y_std'][0] + stats['y_mean'][0]
400
+ else:
401
+ # Fallback heuristic (original code)
402
+ net_load = data.p_load - (data.wind_gen + data.solar_gen)
403
+ v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
404
+
405
+ status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
406
  return {"voltage_pu": round(v_mag, 4), "status": status}