| """ |
| 🌪️ STORM ORACLE — Tornado Super-Predictor (training-ready, no placeholders) |
| |
| - RadarPatternExtractor: multi-scale CNN + spatial attention pooling |
| - AtmosphericConditionEncoder: per-variable MLPs -> tokens -> attention -> fused vector |
| - Heads: probability (sigmoid), EF (logits), location (reg), timing (reg), uncertainty (sigmoid) |
| - Calibration: single temperature parameter (learnable/fittable after training) |
| - ContinuousLearner: online fine-tuning with replay buffer and EMA weights |
| """ |
|
|
import random
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
| |
|
|
@dataclass
class TornadoPredictionBatch:
    """Container for one forward pass of ``TornadoSuperPredictor``.

    All outputs are BATCH TENSORS (no Python scalars).
    """
    # (B,) calibrated probability: sigmoid of the temperature-scaled logits
    tornado_probability: torch.Tensor
    # (B, 6) softmax over the 6 EF-scale classes (presumably EF0–EF5 — confirm)
    ef_scale_probs: torch.Tensor
    # (B,) argmax index into ef_scale_probs
    most_likely_ef_scale: torch.Tensor
    # (B, 2) regressed location offset
    location_offset: torch.Tensor
    # (B, 3) regressed timing values
    timing_predictions: torch.Tensor
    # (B, 4) sigmoid-squashed uncertainty scores
    uncertainty_scores: torch.Tensor
    # (B, 3) mean activations of the hook-echo / mesocyclone / velocity-couplet detectors
    radar_signatures: torch.Tensor
    # (B, 3) stacked CAPE score, shear magnitude, CAPE*shear instability index
    atmospheric_indicators: torch.Tensor
    # (B,) raw temperature-scaled logits (pre-sigmoid); None when not supplied
    logits: Optional[torch.Tensor] = None
|
|
|
|
| |
|
|
class SpatialAttentionPool(nn.Module):
    """
    Collapse a 2D feature map (B, C, H, W) into a per-image vector (B, C).

    The H*W spatial positions become a token sequence; a single learned
    query attends over them with multi-head attention, and the attended
    result is the pooled descriptor.

    NOTE(review): ``pos_embed`` has shape (1, C, 1), so the identical bias
    is added to every spatial token — it carries no positional information.
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.channels = channels
        self.pos_embed = nn.Parameter(torch.randn(1, channels, 1))
        self.query = nn.Parameter(torch.randn(1, 1, channels))
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.ln = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool (B, C, H, W) -> (B, C) via query attention over H*W tokens."""
        batch = x.size(0)
        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        seq = x.flatten(2).permute(0, 2, 1)
        # Broadcast the per-channel bias over all tokens, then normalize.
        seq = self.ln(seq + self.pos_embed.transpose(1, 2))
        # A single learned query attends over every spatial token.
        query = self.query.expand(batch, -1, -1)
        pooled, _ = self.attn(query, seq, seq)
        return pooled.squeeze(1)
|
|
|
|
class RadarPatternExtractor(nn.Module):
    """
    Convolutional radar backbone with signature-specific detector branches
    and attention-based spatial pooling.

    ``input_channels`` may vary (e.g. 3*T channels for T stacked time steps).
    Returns a dict with:
        combined_features:   (B, 1024) fused descriptor
        signature_strengths: (B, 3) mean activations of the hook-echo,
                             mesocyclone and velocity-couplet detectors
    """

    def __init__(self, input_channels: int = 3):
        super().__init__()
        # Widening conv backbone; BatchNorm stabilizes the deepest stage.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)

        # Detector branches for individual tornado radar signatures.
        self.hook_echo_detector = nn.Conv2d(512, 64, kernel_size=3, padding=1)
        self.mesocyclone_detector = nn.Conv2d(512, 64, kernel_size=5, padding=2)
        self.velocity_couplet_detector = nn.Conv2d(512, 64, kernel_size=3, padding=1)

        # Attention pooling of the backbone feature map.
        self.pool = SpatialAttentionPool(512, num_heads=8)

        # Fuse pooled backbone + detector averages into a 1024-d vector.
        self.proj = nn.Sequential(
            nn.Linear(512 + 64 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, radar_data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Extract fused radar features and per-signature strengths."""
        feats = radar_data
        # Three ReLU + 2x2 max-pool downsampling stages...
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = F.max_pool2d(F.relu(conv(feats)), 2)
        # ...then the deepest stage, normalized after its activation.
        feats = self.bn4(F.relu(self.conv4(feats)))

        # Global-average each detector's activation map -> (B, 64) vectors.
        detectors = (self.hook_echo_detector, self.mesocyclone_detector, self.velocity_couplet_detector)
        sig_vecs = [F.relu(det(feats)).mean(dim=(2, 3)) for det in detectors]

        pooled = self.pool(feats)
        combined = self.proj(torch.cat([pooled, *sig_vecs], dim=1))

        # Scalar strength per signature: mean over each 64-d detector vector.
        strengths = torch.stack([v.mean(dim=1) for v in sig_vecs], dim=1)

        return {
            "combined_features": combined,
            "signature_strengths": strengths,
        }
|
|
|
|
class AtmosphericConditionEncoder(nn.Module):
    """
    Encode environmental parameters with per-variable MLPs, project each to a
    64-d token, mix the tokens with multi-head self-attention, and fuse them
    into a single 256-d vector.

    Expected ``atmo`` keys (any may be omitted; a missing variable defaults
    to zeros whose batch size is inferred from the supplied tensors):
        cape (B, 1), wind_shear (B, 4), helicity (B, 2),
        temperature (B, 3), dewpoint (B, 2), pressure (B, 1)
    """

    # (dict key, feature dimension) for every expected variable.
    _SPECS = (
        ("cape", 1),
        ("wind_shear", 4),
        ("helicity", 2),
        ("temperature", 3),
        ("dewpoint", 2),
        ("pressure", 1),
    )

    def __init__(self):
        super().__init__()
        # Per-variable encoders (input widths match _SPECS above).
        self.enc_cape = nn.Linear(1, 32)
        self.enc_shear = nn.Linear(4, 64)
        self.enc_helicity = nn.Linear(2, 32)
        self.enc_temp = nn.Linear(3, 32)
        self.enc_dewpoint = nn.Linear(2, 32)
        self.enc_pressure = nn.Linear(1, 16)

        # Project every encoded variable to a common 64-d token width
        # (shear is already 64-d, hence the Identity).
        self.to_64 = nn.ModuleDict({
            "cape": nn.Linear(32, 64),
            "shear": nn.Identity(),
            "helicity": nn.Linear(32, 64),
            "temp": nn.Linear(32, 64),
            "dewpoint": nn.Linear(32, 64),
            "pressure": nn.Linear(16, 64),
        })
        self.ln = nn.LayerNorm(64)
        self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

        # Fuse the six attended tokens into one 256-d feature vector.
        self.fuse = nn.Sequential(
            nn.Linear(64 * 6, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, atmo: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return fused atmospheric features plus raw convective indicators.

        Fix over the original: zero defaults for missing keys now use the
        batch size of the supplied tensors. The old fixed (1, d) defaults
        made ``torch.stack`` fail whenever a key was missing and B > 1.
        """
        device = next(self.parameters()).device

        def as_2d(t: torch.Tensor, d: int) -> torch.Tensor:
            # Accept (B, d) directly, or reshape anything else (e.g. (d,)).
            return t if t.ndim == 2 else t.view(-1, d)

        # Infer the batch size from the first supplied variable so that
        # zero-filled defaults line up with the real inputs.
        batch = 1
        for key, dim in self._SPECS:
            t = atmo.get(key)
            if t is not None:
                batch = as_2d(t, dim).size(0)
                break

        def fetch(key: str, dim: int) -> torch.Tensor:
            t = atmo.get(key)
            if t is None:
                return torch.zeros(batch, dim, device=device)
            return as_2d(t, dim)

        cape = fetch("cape", 1)
        shear = fetch("wind_shear", 4)
        hel = fetch("helicity", 2)
        temp = fetch("temperature", 3)
        dew = fetch("dewpoint", 2)
        pres = fetch("pressure", 1)

        cape_e = F.relu(self.enc_cape(cape))
        shear_e = F.relu(self.enc_shear(shear))
        hel_e = F.relu(self.enc_helicity(hel))
        temp_e = F.relu(self.enc_temp(temp))
        dew_e = F.relu(self.enc_dewpoint(dew))
        pres_e = F.relu(self.enc_pressure(pres))

        # One normalized 64-d token per variable: (B, 6, 64).
        tokens = torch.stack([
            self.ln(self.to_64["cape"](cape_e)),
            self.ln(self.to_64["shear"](shear_e)),
            self.ln(self.to_64["helicity"](hel_e)),
            self.ln(self.to_64["temp"](temp_e)),
            self.ln(self.to_64["dewpoint"](dew_e)),
            self.ln(self.to_64["pressure"](pres_e)),
        ], dim=1)

        attn_out, _ = self.attn(tokens, tokens, tokens)
        fused = self.fuse(attn_out.reshape(attn_out.size(0), -1))

        # Simple physical indicators: |shear| and a CAPE*shear instability proxy.
        shear_mag = torch.linalg.vector_norm(shear, dim=-1)
        instability = cape.squeeze(-1) * shear_mag

        return {
            "atmospheric_features": fused,
            "cape_score": cape.squeeze(-1),
            "shear_magnitude": shear_mag,
            "instability_index": instability,
        }
|
|
|
|
| |
|
|
class TornadoSuperPredictor(nn.Module):
    """
    End-to-end tornado predictor: radar CNN features and atmospheric
    encoder features feed five task heads (probability, EF class, location,
    timing, uncertainty). A learnable temperature (parameterized in log
    space, so always positive) calibrates the probability logits.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        self.radar_extractor = RadarPatternExtractor(input_channels=in_channels)
        self.atmo_encoder = AtmosphericConditionEncoder()

        # 1024 radar features + 256 atmospheric features.
        fused_dim = 1024 + 256

        # Task heads (layer layout produced by _mlp).
        self.prob_head = self._mlp([fused_dim, 512, 256, 1], dropout=0.4)
        self.ef_head = self._mlp([fused_dim, 512, 6], dropout=0.4)
        self.loc_head = self._mlp([fused_dim, 512, 2], dropout=0.4)
        self.time_head = self._mlp([fused_dim, 512, 3], dropout=0.4)
        self.unc_head = self._mlp([fused_dim, 256, 4])

        # Calibration temperature in log space (exp keeps it > 0).
        self.log_temperature = nn.Parameter(torch.zeros(()))

        self._init_weights()

    @staticmethod
    def _mlp(dims: List[int], dropout: Optional[float] = None) -> nn.Sequential:
        """Linear stack with ReLU between layers and optional dropout
        immediately after the first activation only."""
        layers: List[nn.Module] = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
                if dropout is not None and i == 0:
                    layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Xavier for Linear, Kaiming (fan-out/ReLU) for Conv2d; zero biases."""
        for mod in self.modules():
            if isinstance(mod, nn.Linear):
                nn.init.xavier_uniform_(mod.weight)
            elif isinstance(mod, nn.Conv2d):
                nn.init.kaiming_uniform_(mod.weight, mode="fan_out", nonlinearity="relu")
            else:
                continue
            if mod.bias is not None:
                nn.init.zeros_(mod.bias)

    @property
    def temperature(self) -> torch.Tensor:
        """Calibration temperature (always positive)."""
        return self.log_temperature.exp()

    def forward(self, radar_x: torch.Tensor, atmo: Dict[str, torch.Tensor]) -> TornadoPredictionBatch:
        """Run both encoders, fuse their features, and evaluate every head."""
        radar_out = self.radar_extractor(radar_x)
        atmo_out = self.atmo_encoder(atmo)

        fused = torch.cat(
            [radar_out["combined_features"], atmo_out["atmospheric_features"]], dim=1
        )

        # Temperature-scaled logits -> calibrated probability.
        logits = self.prob_head(fused).squeeze(-1) / self.temperature.clamp_min(1e-6)
        probability = logits.sigmoid()

        ef_probs = self.ef_head(fused).softmax(dim=-1)

        return TornadoPredictionBatch(
            tornado_probability=probability,
            ef_scale_probs=ef_probs,
            most_likely_ef_scale=ef_probs.argmax(dim=-1),
            location_offset=self.loc_head(fused),
            timing_predictions=self.time_head(fused),
            uncertainty_scores=self.unc_head(fused).sigmoid(),
            radar_signatures=radar_out["signature_strengths"],
            atmospheric_indicators=torch.stack(
                [atmo_out["cape_score"], atmo_out["shear_magnitude"], atmo_out["instability_index"]],
                dim=1,
            ),
            logits=logits,
        )
|
|
|
|
| |
|
|
class ContinuousLearner(nn.Module):
    """
    Online-learning wrapper around :class:`TornadoSuperPredictor` adding:
      - an AdamW optimizer with BCE (optionally pos-weighted) or focal loss
      - an EMA "shadow" copy of the weights for stable inference
      - a bounded replay buffer to mitigate catastrophic forgetting
    """

    def __init__(
        self,
        model: TornadoSuperPredictor,
        lr: float = 1e-4,
        wd: float = 1e-4,
        use_focal: bool = False,
        pos_weight: Optional[float] = None,
        ema_decay: float = 0.999,
        replay_capacity: int = 2048,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            model: predictor to train online (parameters optimized in place).
            lr / wd: AdamW learning rate and weight decay.
            use_focal: use focal loss instead of BCE-with-logits.
            pos_weight: positive-class weight for the BCE path (ignored by focal).
            ema_decay: shadow-weight decay per step (closer to 1 = slower EMA).
            replay_capacity: max stored batches; <= 0 disables replay.
            device: defaults to the model's parameter device.
        """
        super().__init__()
        self.model = model
        self.device = device or next(model.parameters()).device
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=wd)
        self.use_focal = use_focal
        self.pos_weight = None if pos_weight is None else torch.tensor(pos_weight, device=self.device)
        self.ema_decay = ema_decay

        # EMA shadow copy of every entry in the model's state_dict.
        self.shadow = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
        self.replay_capacity = replay_capacity
        # deque(maxlen=...) evicts the oldest batch in O(1) when full
        # (the original list + pop(0) was O(n) per step).
        self._replay = deque(maxlen=replay_capacity if replay_capacity > 0 else 0)

    def _bce_loss(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Binary cross-entropy on logits, with optional positive-class weight."""
        if self.pos_weight is not None:
            return F.binary_cross_entropy_with_logits(logits, y.float(), pos_weight=self.pos_weight)
        return F.binary_cross_entropy_with_logits(logits, y.float())

    def _focal_loss(self, logits: torch.Tensor, y: torch.Tensor, gamma: float = 2.0, alpha: float = 0.5) -> torch.Tensor:
        """Binary focal loss (Lin et al.): down-weights easy examples by (1-pt)^gamma."""
        y = y.float()  # match _bce_loss: accept integer targets too
        p = torch.sigmoid(logits)
        pt = p * y + (1 - p) * (1 - y)
        w = (1 - pt).pow(gamma)
        at = alpha * y + (1 - alpha) * (1 - y)
        loss = -(y * torch.log(p.clamp_min(1e-9)) + (1 - y) * torch.log((1 - p).clamp_min(1e-9))) * w * at
        return loss.mean()

    @torch.no_grad()
    def _update_ema(self):
        """Blend live weights into the shadow copy.

        Fix over the original: only floating-point tensors are EMA-averaged.
        Integer state (e.g. BatchNorm's int64 ``num_batches_tracked`` buffer)
        cannot be fractionally scaled in place — the old unconditional
        ``mul_``/``add_`` raised a dtype error on it — so it is copied verbatim.
        """
        for k, v in self.model.state_dict().items():
            shadow = self.shadow[k]
            if torch.is_floating_point(shadow):
                shadow.mul_(self.ema_decay).add_(v, alpha=(1.0 - self.ema_decay))
            else:
                shadow.copy_(v)

    def train_step(self, radar_x: torch.Tensor, atmo: Dict[str, torch.Tensor], y: torch.Tensor) -> Dict[str, float]:
        """One gradient step on a fresh batch; updates EMA and replay buffer.

        Returns:
            {"loss": scalar loss, "avg_prob": mean predicted probability}.
        """
        self.model.train()
        out = self.model(radar_x, atmo)

        loss = self._focal_loss(out.logits, y) if self.use_focal else self._bce_loss(out.logits, y)

        self.opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.opt.step()
        self._update_ema()

        # Keep a CPU copy for replay; the deque drops the oldest at capacity.
        if self.replay_capacity > 0:
            with torch.no_grad():
                self._replay.append((
                    radar_x.detach().cpu(),
                    {k: v.detach().cpu() for k, v in atmo.items()},
                    y.detach().cpu(),
                ))

        with torch.no_grad():
            avg_prob = out.tornado_probability.mean().item()
        return {"loss": float(loss.item()), "avg_prob": avg_prob}

    @torch.no_grad()
    def ema_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a detached copy of the EMA shadow weights."""
        return {k: v.clone() for k, v in self.shadow.items()}

    @torch.no_grad()
    def load_ema_weights(self):
        """Overwrite the live model with the EMA shadow weights."""
        self.model.load_state_dict(self.ema_state_dict())

    def replay_step(self, batch_size: int = 16) -> Optional[Dict[str, float]]:
        """Run one training step on a random replay mini-batch.

        Returns None when the buffer is empty; otherwise the train_step metrics.
        """
        if not self._replay:
            return None
        picks = random.sample(range(len(self._replay)), k=min(batch_size, len(self._replay)))
        samples = [self._replay[i] for i in picks]
        xs = torch.cat([s[0] for s in samples], dim=0).to(self.device)
        ys = torch.cat([s[2] for s in samples], dim=0).to(self.device)
        # NOTE(review): assumes every stored batch carries the same atmo keys
        # as the first sampled one — verify if key sets can vary between steps.
        atmo: Dict[str, torch.Tensor] = {}
        for key in samples[0][1]:
            atmo[key] = torch.cat([s[1][key] for s in samples], dim=0).to(self.device)
        return self.train_step(xs, atmo, ys)
|
|