| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
try:
    from transformers import PreTrainedModel
    from transformers.modeling_outputs import ModelOutput
except Exception:
    # transformers could not be imported: provide minimal stand-ins so this
    # module can still be imported and the model constructed with plain
    # PyTorch. The broad ``except`` is kept from the original on purpose of
    # staying best-effort (a broken install may raise more than ImportError).

    class PreTrainedModel(nn.Module):
        """Minimal stand-in for ``transformers.PreTrainedModel``.

        Only covers what this module relies on: the class-level attributes,
        storing ``config`` on the instance, and the ``post_init`` hook.
        """

        config_class = None
        base_model_prefix = ""
        main_input_name = "input_ids"

        def __init__(self, config):
            super().__init__()
            self.config = config

        def post_init(self):
            """Do nothing.

            Subclasses in this module call ``self.post_init()`` at the end of
            ``__init__``; without this hook the fallback path would fail with
            ``AttributeError`` as soon as a model is constructed.
            """

    class ModelOutput(dict):
        """Minimal stand-in for ``transformers.modeling_outputs.ModelOutput``."""
|
|
| from .configuration import WeatherModelConfig |
|
|
# Single source of truth for the continuous regression targets:
# (target name, loss weight, target transform). Row order is the order in
# which the targets are iterated and exposed elsewhere in this module.
_CONTINUOUS_TARGET_TABLE = (
    ("temp", 1.0, "raw"),
    ("humidity", 1.0, "raw"),
    ("apparent", 0.8, "raw"),
    ("precip", 0.9, "log1p"),
    ("sea_level_pressure", 0.6, "raw"),
    ("surface_pressure", 0.4, "raw"),
    ("cloud_cover", 0.4, "raw"),
    ("wind", 0.6, "raw"),
    ("wind_dir_sin", 0.55, "raw"),
    ("wind_dir_cos", 0.55, "raw"),
)

# Ordered list of continuous target names.
CONTINUOUS_TARGET_ORDER = [_name for _name, _, _ in _CONTINUOUS_TARGET_TABLE]

# Per-target settings: the weight of the target's loss term and the default
# transform applied to raw target values before normalisation.
CONTINUOUS_TARGET_SPECS = {
    _name: {"loss_weight": _weight, "transform": _transform}
    for _name, _weight, _transform in _CONTINUOUS_TARGET_TABLE
}
|
|
|
|
@dataclass
class WeatherModelOutput(ModelOutput):
    """Container returned by the weather model's ``forward``.

    Every field defaults to ``None``. Field order is kept stable because
    tuple-style consumers of the output depend on it.
    """

    # Scalar training loss, or None when forward() did not compute one.
    loss: Optional[torch.Tensor] = None

    # De-normalised continuous predictions (one tensor per continuous target),
    # followed by the rain logits and the weather-class logits.
    logits: Optional[Tuple[torch.Tensor, ...]] = None

    # Layer-normalised head input; only populated when representations are
    # requested from forward().
    head_repr: Optional[torch.Tensor] = None

    # Continuous predictions keyed by target name, in normalised space.
    norm_preds: Optional[Dict[str, torch.Tensor]] = None

    # Continuous predictions keyed by target name, in raw (decoded) space.
    raw_preds: Optional[Dict[str, torch.Tensor]] = None

    # Head representation prepared for distillation (projected when the
    # model owns a distillation projection, otherwise the head input itself).
    distill_head_repr: Optional[torch.Tensor] = None
|
|
|
|
class WeatherForcastModel(PreTrainedModel):
    """Multi-horizon weather forecaster: sequence encoder plus per-target heads.

    An LSTM or Transformer encoder (``config.encoder_type``) summarises the
    input window into one vector per sample. That vector is concatenated with
    a learned location embedding, layer-normalised, and fed to:

    * one linear regression head per name in ``CONTINUOUS_TARGET_ORDER``,
      each predicting ``num_predict`` values in normalised target space;
    * ``fc_rain``, producing ``num_predict`` rain logits;
    * ``fc_weather``, producing ``num_predict * num_weather_classes`` logits.

    Target normalisation statistics, weather class weights and the rain
    positive weight are read from the config and registered as non-persistent
    buffers, i.e. they are not part of the state dict.
    """

    config_class = WeatherModelConfig
    base_model_prefix = "weather_sequence"
    main_input_name = "X"

    _tied_weights_keys: list[str] = []

    # Weights of the two classification terms in the total loss.
    _RAIN_LOSS_WEIGHT = 0.7
    _WEATHER_LOSS_WEIGHT = 0.9

    def __init__(self, config: WeatherModelConfig):
        """Build the encoder, heads and buffers described by ``config``.

        Raises:
            ValueError: if ``config.input_dim`` is ``None`` or
                ``config.encoder_type`` is neither ``"lstm"`` nor
                ``"transformer"``.
        """
        super().__init__(config)

        self.encoder_type = str(getattr(config, "encoder_type", "lstm")).lower()
        self.hidden_dim = int(config.hidden_dim)
        self.seq_len = int(config.seq_len)
        self.num_predict = int(config.num_predict)
        self.num_weather_classes = int(config.num_weather_classes)

        if config.input_dim is None:
            raise ValueError("WeatherModelConfig.input_dim must be set")

        input_dim = int(config.input_dim)
        num_layers = int(config.num_layers)
        dropout = float(config.dropout)
        location_emb_dim = int(config.location_emb_dim)

        # NOTE: sub-modules are created in the same order as before so that
        # seeded random initialisation yields the same parameters.
        # At least one embedding row so the default location id 0 is valid.
        self.location_embedding = nn.Embedding(max(1, int(config.num_locations)), location_emb_dim)

        if config.weather_class_weights is not None:
            self.register_buffer(
                "weather_class_weights",
                torch.tensor(config.weather_class_weights, dtype=torch.float32),
                persistent=False,
            )
        else:
            self.weather_class_weights = None

        self.register_buffer(
            "rain_pos_weight",
            torch.tensor(float(config.rain_pos_weight), dtype=torch.float32),
            persistent=False,
        )

        # Per-target normalisation: mean/std live in buffers (so they follow
        # the module across devices); the transform name is plain metadata.
        # A missing/None ``target_norms`` (or entry) falls back to the
        # identity normalisation instead of crashing.
        target_norms = getattr(config, "target_norms", None) or {}
        self.target_norm_meta: Dict[str, Dict[str, Any]] = {}
        for name in CONTINUOUS_TARGET_ORDER:
            spec = dict(target_norms.get(name) or {})
            mean = float(spec.get("mean", 0.0))
            # Floor the std to avoid dividing by (almost) zero when encoding.
            std = max(float(spec.get("std", 1.0)), 1e-6)
            transform = str(spec.get("transform", CONTINUOUS_TARGET_SPECS[name]["transform"]))
            self.register_buffer(f"{name}_mean", torch.tensor(mean, dtype=torch.float32), persistent=False)
            self.register_buffer(f"{name}_std", torch.tensor(std, dtype=torch.float32), persistent=False)
            self.target_norm_meta[name] = {"transform": transform}

        if self.encoder_type == "lstm":
            self.encoder = nn.LSTM(
                input_size=input_dim,
                hidden_size=self.hidden_dim,
                num_layers=num_layers,
                batch_first=True,
                # nn.LSTM only applies dropout between stacked layers.
                dropout=dropout if num_layers > 1 else 0.0,
                bidirectional=False,
            )
        elif self.encoder_type == "transformer":
            self.input_proj = nn.Linear(input_dim, self.hidden_dim)
            # Learned positional encoding covering ``seq_len`` time steps.
            self.pos_encoding = nn.Parameter(torch.randn(1, self.seq_len, self.hidden_dim) * 0.1)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.hidden_dim,
                nhead=4,
                dropout=dropout,
                batch_first=True,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        else:
            raise ValueError(f"Unknown encoder_type: {self.encoder_type}")

        self.head_dim = self.hidden_dim + location_emb_dim
        self.head_norm = nn.LayerNorm(self.head_dim)
        self.head_dropout = nn.Dropout(dropout)

        self.reg_heads = nn.ModuleDict(
            {name: nn.Linear(self.head_dim, self.num_predict) for name in CONTINUOUS_TARGET_ORDER}
        )
        self.fc_rain = nn.Linear(self.head_dim, self.num_predict)
        self.fc_weather = nn.Linear(self.head_dim, self.num_predict * self.num_weather_classes)

        # Optional projection into a teacher's head space; only needed when
        # the teacher's head width differs from this model's.
        teacher_head_dim = int(getattr(config, "distill_teacher_head_dim", 0))
        if teacher_head_dim > 0 and teacher_head_dim != self.head_dim:
            self.distill_proj = nn.Linear(self.head_dim, teacher_head_dim, bias=False)
        else:
            self.distill_proj = None

        self.post_init()

    @staticmethod
    def _masked_mean(x: torch.Tensor) -> torch.Tensor:
        """Average ``x`` over dim 1, ignoring positions whose last-dim row is all zeros.

        Samples with no non-zero position yield a zero vector (the
        denominator is clamped to 1).
        """
        mask = (x.abs().sum(dim=-1) > 0).float().unsqueeze(-1)
        summed = (x * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1.0)
        return summed / denom

    def _target_mean_std(self, name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(mean, std)`` buffers registered for target ``name``."""
        return getattr(self, f"{name}_mean"), getattr(self, f"{name}_std")

    def _encode_target(self, name: str, target: torch.Tensor) -> torch.Tensor:
        """Map a raw target to normalised space: optional ``log1p``, then z-score."""
        transform = self.target_norm_meta[name]["transform"]
        target = target.to(dtype=torch.float32)
        if transform == "log1p":
            # Negative values are clamped to 0 before the log.
            target = torch.log1p(torch.clamp(target, min=0.0))
        mean, std = self._target_mean_std(name)
        return (target - mean.to(target.device)) / std.to(target.device)

    def _decode_prediction(self, name: str, pred_norm: torch.Tensor) -> torch.Tensor:
        """Invert ``_encode_target``: undo the z-score, then the ``log1p`` if any."""
        transform = self.target_norm_meta[name]["transform"]
        mean, std = self._target_mean_std(name)
        raw = pred_norm * std.to(pred_norm.device) + mean.to(pred_norm.device)
        if transform == "log1p":
            # Decoded values of log1p targets are kept non-negative.
            raw = torch.expm1(raw).clamp(min=0.0)
        return raw

    def _encode_sequence(self, X: torch.Tensor) -> torch.Tensor:
        """Summarise the batch-first input window into one vector per sample."""
        if self.encoder_type == "lstm":
            # Final hidden state of the last (top) LSTM layer.
            _, (hidden, _) = self.encoder(X)
            return hidden[-1]

        projected = self.input_proj(X) + self.pos_encoding[:, : X.size(1), :]
        encoded = self.encoder(projected)
        # NOTE(review): the mask inside _masked_mean is derived from the
        # encoder *output*, which is unlikely to contain exact all-zero rows,
        # so this probably behaves as a plain mean over time. If zero-padded
        # steps of X are meant to be ignored, the mask should come from X.
        # Left unchanged to keep outputs of existing checkpoints stable.
        return self._masked_mean(encoded)

    def _compute_loss(
        self,
        norm_preds: Dict[str, torch.Tensor],
        rain_logit: torch.Tensor,
        weather_logits: torch.Tensor,
        continuous_targets: Dict[str, Optional[torch.Tensor]],
        rain_target: Optional[torch.Tensor],
        weather_target: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Return the weighted sum of all loss terms whose target is given.

        Returns ``None`` when no target at all was supplied.
        """
        loss_terms = []

        # Regression terms: smooth-L1 in normalised target space.
        for name, target in continuous_targets.items():
            if target is None:
                continue
            pred_norm = norm_preds[name]
            target_norm = self._encode_target(name, target.to(pred_norm.device))
            term = F.smooth_l1_loss(pred_norm.to(target_norm.dtype), target_norm)
            loss_terms.append(term * float(CONTINUOUS_TARGET_SPECS[name]["loss_weight"]))

        if rain_target is not None:
            # Align device as well as dtype, like the continuous targets.
            rain_target = rain_target.to(device=rain_logit.device, dtype=rain_logit.dtype)
            rain_loss = F.binary_cross_entropy_with_logits(
                rain_logit,
                rain_target,
                pos_weight=self.rain_pos_weight.to(rain_logit.device),
            )
            loss_terms.append(self._RAIN_LOSS_WEIGHT * rain_loss)

        if weather_target is not None:
            class_weights = self.weather_class_weights
            if class_weights is not None:
                class_weights = class_weights.to(weather_logits.device)
            weather_loss = F.cross_entropy(
                weather_logits.reshape(-1, self.num_weather_classes),
                weather_target.to(weather_logits.device).long().reshape(-1),
                weight=class_weights,
                label_smoothing=0.0,
            )
            loss_terms.append(self._WEATHER_LOSS_WEIGHT * weather_loss)

        return sum(loss_terms) if loss_terms else None

    def forward(
        self,
        X: torch.Tensor,
        location_id: Optional[torch.Tensor] = None,
        temp_target: Optional[torch.Tensor] = None,
        humidity_target: Optional[torch.Tensor] = None,
        apparent_target: Optional[torch.Tensor] = None,
        precip_target: Optional[torch.Tensor] = None,
        sea_level_pressure_target: Optional[torch.Tensor] = None,
        surface_pressure_target: Optional[torch.Tensor] = None,
        cloud_cover_target: Optional[torch.Tensor] = None,
        wind_target: Optional[torch.Tensor] = None,
        wind_dir_sin_target: Optional[torch.Tensor] = None,
        wind_dir_cos_target: Optional[torch.Tensor] = None,
        rain_target: Optional[torch.Tensor] = None,
        weather_target: Optional[torch.Tensor] = None,
        return_repr: bool = False,
        **kwargs: Any,
    ) -> WeatherModelOutput:
        """Run the model and, when targets are given, compute the training loss.

        Args:
            X: batch-first input window, ``(batch, time, input_dim)``.
            location_id: one embedding index per sample; defaults to all zeros.
            *_target: optional raw (un-normalised) regression targets, one per
                continuous variable. Assumed to match the corresponding
                prediction's shape ``(batch, num_predict)`` — TODO confirm
                against the data pipeline.
            rain_target: optional binary rain labels matching the rain logits.
            weather_target: optional weather class indices, one per predicted
                step.
            return_repr: also return the head representation, the prediction
                dicts and the distillation representation.
            **kwargs: accepted and ignored.

        Returns:
            ``WeatherModelOutput`` whose ``logits`` tuple holds the decoded
            continuous predictions in ``CONTINUOUS_TARGET_ORDER``, followed by
            the rain logits and the ``(batch, num_predict, num_weather_classes)``
            weather logits. ``loss`` is the weighted sum of the terms for every
            supplied target, or ``None`` when no target is supplied. (Loss
            computation used to require ``temp_target``; any target now
            suffices.)
        """
        if location_id is None:
            location_id = torch.zeros(X.size(0), dtype=torch.long, device=X.device)

        seq_repr = self._encode_sequence(X)
        loc_emb = self.location_embedding(location_id)
        head_repr = self.head_norm(torch.cat([seq_repr, loc_emb], dim=1))
        # Dropout only affects what the heads see; ``head_repr`` itself is
        # returned un-dropped.
        head_in = self.head_dropout(head_repr)

        norm_preds: Dict[str, torch.Tensor] = {}
        raw_preds: Dict[str, torch.Tensor] = {}
        for name in CONTINUOUS_TARGET_ORDER:
            norm_pred = self.reg_heads[name](head_in)
            norm_preds[name] = norm_pred
            raw_preds[name] = self._decode_prediction(name, norm_pred)

        rain_logit = self.fc_rain(head_in)
        weather_logits = self.fc_weather(head_in).view(-1, self.num_predict, self.num_weather_classes)

        continuous_targets = {
            "temp": temp_target,
            "humidity": humidity_target,
            "apparent": apparent_target,
            "precip": precip_target,
            "sea_level_pressure": sea_level_pressure_target,
            "surface_pressure": surface_pressure_target,
            "cloud_cover": cloud_cover_target,
            "wind": wind_target,
            "wind_dir_sin": wind_dir_sin_target,
            "wind_dir_cos": wind_dir_cos_target,
        }
        loss = self._compute_loss(
            norm_preds,
            rain_logit,
            weather_logits,
            continuous_targets,
            rain_target,
            weather_target,
        )

        logits = tuple(raw_preds[name] for name in CONTINUOUS_TARGET_ORDER) + (rain_logit, weather_logits)

        if not return_repr:
            return WeatherModelOutput(loss=loss, logits=logits)

        if self.distill_proj is not None:
            distill_head_repr = self.distill_proj(head_repr)
        else:
            distill_head_repr = head_repr

        return WeatherModelOutput(
            loss=loss,
            logits=logits,
            head_repr=head_repr,
            norm_preds=norm_preds,
            raw_preds=raw_preds,
            distill_head_repr=distill_head_repr,
        )
|
|
|
|
| |
def _register_auto_class_best_effort(cls, *auto_class):
    """Call ``cls.register_for_auto_class(*auto_class)``, ignoring any failure.

    Registration is optional: any exception raised by the call (including
    the method not existing on ``cls``) is swallowed so importing this
    module never fails because of it.
    """
    try:
        cls.register_for_auto_class(*auto_class)
    except Exception:
        pass


_register_auto_class_best_effort(WeatherModelConfig)
_register_auto_class_best_effort(WeatherForcastModel, "AutoModel")
|
|