# DistilHweh-446k / modeling.py
# Harley-ml's picture
# Upload 5 files
# 0ed3271 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from transformers import PreTrainedModel
    from transformers.modeling_outputs import ModelOutput
except Exception:
    # Any import-time failure (not only ImportError) selects the minimal
    # stand-ins below, so the module can still be imported without a
    # working `transformers` installation.

    class PreTrainedModel(nn.Module):
        """Minimal stand-in for `transformers.PreTrainedModel`.

        Stores the config on the instance and exposes the class attributes
        and the `post_init` hook that model classes in this file rely on.
        """

        config_class = None
        base_model_prefix = ""
        main_input_name = "input_ids"

        def __init__(self, config):
            super().__init__()
            self.config = config

        def post_init(self) -> None:
            """Do nothing.

            Model constructors in this file call `self.post_init()`; without
            this hook the fallback class would raise AttributeError there.
            """

    class ModelOutput(dict):  # type: ignore
        """Plain-dict stand-in used as the base class for model outputs."""
from .configuration import WeatherModelConfig
# Names of the continuous (regression) targets, in the canonical order used
# when iterating over regression heads and predictions.
CONTINUOUS_TARGET_ORDER = [
    # thermodynamic quantities
    "temp", "humidity", "apparent",
    # precipitation amount
    "precip",
    # pressure readings
    "sea_level_pressure", "surface_pressure",
    # sky and wind
    "cloud_cover", "wind",
    # wind direction, expressed as a sine/cosine pair
    "wind_dir_sin", "wind_dir_cos",
]
# Per-target defaults: "loss_weight" scales that target's regression loss and
# "transform" names the value transform applied before normalisation
# ("raw" = none, "log1p" = log(1 + x)).
CONTINUOUS_TARGET_SPECS = dict(
    temp=dict(loss_weight=1.0, transform="raw"),
    humidity=dict(loss_weight=1.0, transform="raw"),
    apparent=dict(loss_weight=0.8, transform="raw"),
    precip=dict(loss_weight=0.9, transform="log1p"),
    sea_level_pressure=dict(loss_weight=0.6, transform="raw"),
    surface_pressure=dict(loss_weight=0.4, transform="raw"),
    cloud_cover=dict(loss_weight=0.4, transform="raw"),
    wind=dict(loss_weight=0.6, transform="raw"),
    wind_dir_sin=dict(loss_weight=0.55, transform="raw"),
    wind_dir_cos=dict(loss_weight=0.55, transform="raw"),
)
@dataclass
class WeatherModelOutput(ModelOutput):
    """Container returned by the weather model's forward pass.

    Every field defaults to ``None``; the forward pass fills in whichever
    fields apply to the call.
    """

    # Combined training loss, when one was computed.
    loss: Optional[torch.Tensor] = None

    # Tuple of prediction tensors produced by the model heads.
    logits: Optional[Tuple[torch.Tensor, ...]] = None

    # Representation fed to the prediction heads.
    head_repr: Optional[torch.Tensor] = None

    # Per-target predictions in normalised space, keyed by target name.
    norm_preds: Optional[Dict[str, torch.Tensor]] = None

    # Per-target predictions decoded back to raw values, keyed by target name.
    raw_preds: Optional[Dict[str, torch.Tensor]] = None

    # Head representation as exposed for distillation.
    distill_head_repr: Optional[torch.Tensor] = None
class WeatherForcastModel(PreTrainedModel):
    """Sequence encoder with per-target weather forecasting heads.

    An LSTM or Transformer encoder summarises the input sequence ``X`` into a
    single vector, which is concatenated with a learned location embedding,
    layer-normalised and fed to:

    * one linear regression head per name in ``CONTINUOUS_TARGET_ORDER``
      (each emitting ``num_predict`` values in normalised space),
    * a rain head (``num_predict`` logits), and
    * a weather-class head (``num_predict * num_weather_classes`` logits).

    The class name spelling is kept as-is because it is the public identifier.
    """

    config_class = WeatherModelConfig
    base_model_prefix = "weather_sequence"
    main_input_name = "X"
    # Newer Transformers versions may create auto_map entries from these registrations.
    # This model declares no tied weights.
    _tied_weights_keys: list[str] = []

    def __init__(self, config: WeatherModelConfig):
        """Build the encoder, the location embedding and all prediction heads.

        Args:
            config: Model configuration. ``input_dim`` must be set.

        Raises:
            ValueError: If ``config.input_dim`` is ``None`` or
                ``config.encoder_type`` is neither ``"lstm"`` nor
                ``"transformer"``.
        """
        super().__init__(config)
        self.encoder_type = str(getattr(config, "encoder_type", "lstm")).lower()
        self.hidden_dim = int(config.hidden_dim)
        self.seq_len = int(config.seq_len)
        self.num_predict = int(config.num_predict)
        self.num_weather_classes = int(config.num_weather_classes)
        if config.input_dim is None:
            raise ValueError("WeatherModelConfig.input_dim must be set")

        # At least one embedding row so that the default location id 0 is valid.
        self.location_embedding = nn.Embedding(
            max(1, int(config.num_locations)), int(config.location_emb_dim)
        )

        # Loss weights are derived from the config, so they are kept as
        # non-persistent buffers (not written to the state dict).
        if config.weather_class_weights is not None:
            self.register_buffer(
                "weather_class_weights",
                torch.tensor(config.weather_class_weights, dtype=torch.float32),
                persistent=False,
            )
        else:
            self.weather_class_weights = None
        self.register_buffer(
            "rain_pos_weight",
            torch.tensor(float(config.rain_pos_weight), dtype=torch.float32),
            persistent=False,
        )

        # Per-target normalisation statistics. A missing/None `target_norms`
        # or a missing entry falls back to mean 0, std 1 and the default
        # transform from CONTINUOUS_TARGET_SPECS.
        target_norms = getattr(config, "target_norms", None) or {}
        self.target_norm_meta: Dict[str, Dict[str, Any]] = {}
        for name in CONTINUOUS_TARGET_ORDER:
            spec = dict(target_norms.get(name) or {})
            mean = float(spec.get("mean", 0.0))
            std = max(float(spec.get("std", 1.0)), 1e-6)  # guard against division by zero
            transform = str(spec.get("transform", CONTINUOUS_TARGET_SPECS[name]["transform"]))
            self.register_buffer(f"{name}_mean", torch.tensor(mean, dtype=torch.float32), persistent=False)
            self.register_buffer(f"{name}_std", torch.tensor(std, dtype=torch.float32), persistent=False)
            self.target_norm_meta[name] = {"transform": transform}

        if self.encoder_type == "lstm":
            self.encoder = nn.LSTM(
                input_size=int(config.input_dim),
                hidden_size=self.hidden_dim,
                num_layers=int(config.num_layers),
                batch_first=True,
                # Inter-layer dropout is only used with more than one layer.
                dropout=float(config.dropout) if int(config.num_layers) > 1 else 0.0,
                bidirectional=False,
            )
        elif self.encoder_type == "transformer":
            self.input_proj = nn.Linear(int(config.input_dim), self.hidden_dim)
            # Learned positional encoding, one vector per position up to seq_len.
            self.pos_encoding = nn.Parameter(torch.randn(1, int(config.seq_len), self.hidden_dim) * 0.1)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.hidden_dim,
                nhead=4,
                dropout=float(config.dropout),
                batch_first=True,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=int(config.num_layers))
        else:
            raise ValueError(f"Unknown encoder_type: {self.encoder_type}")

        self.head_dim = self.hidden_dim + int(config.location_emb_dim)
        self.head_norm = nn.LayerNorm(self.head_dim)
        self.head_dropout = nn.Dropout(float(config.dropout))
        self.reg_heads = nn.ModuleDict(
            {name: nn.Linear(self.head_dim, self.num_predict) for name in CONTINUOUS_TARGET_ORDER}
        )
        self.fc_rain = nn.Linear(self.head_dim, self.num_predict)
        self.fc_weather = nn.Linear(self.head_dim, self.num_predict * self.num_weather_classes)

        # Optional projection of the head representation to a teacher's head
        # width; only created when the configured width differs from ours.
        teacher_head_dim = int(getattr(config, "distill_teacher_head_dim", 0))
        if teacher_head_dim > 0 and teacher_head_dim != self.head_dim:
            self.distill_proj = nn.Linear(self.head_dim, teacher_head_dim, bias=False)
        else:
            self.distill_proj = None
        self.post_init()

    @staticmethod
    def _masked_mean(x: torch.Tensor) -> torch.Tensor:
        """Average ``x`` over dim 1, ignoring positions whose last-dim vector is all zero.

        NOTE(review): ``forward`` applies this to the encoder *output*, where
        exactly-zero vectors are unlikely, so this presumably reduces to a
        plain mean in practice. Left unchanged because trained weights depend
        on the current behaviour — confirm whether masking on the input was
        intended.
        """
        mask = (x.abs().sum(dim=-1) > 0).float().unsqueeze(-1)
        summed = (x * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1.0)  # avoid 0/0 when every position is masked
        return summed / denom

    def _target_mean_std(self, name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(mean, std)`` buffers registered for target ``name``."""
        return getattr(self, f"{name}_mean"), getattr(self, f"{name}_std")

    def _encode_target(self, name: str, target: torch.Tensor) -> torch.Tensor:
        """Map a raw target to normalised space (optional log1p, then z-score)."""
        transform = self.target_norm_meta[name]["transform"]
        target = target.to(dtype=torch.float32)
        if transform == "log1p":
            # Negative values are clamped to zero before the log transform.
            target = torch.log1p(torch.clamp(target, min=0.0))
        mean, std = self._target_mean_std(name)
        return (target - mean.to(target.device)) / std.to(target.device)

    def _decode_prediction(self, name: str, pred_norm: torch.Tensor) -> torch.Tensor:
        """Invert ``_encode_target``: de-normalise and undo the log1p transform if any."""
        transform = self.target_norm_meta[name]["transform"]
        mean, std = self._target_mean_std(name)
        raw = pred_norm * std.to(pred_norm.device) + mean.to(pred_norm.device)
        if transform == "log1p":
            raw = torch.expm1(raw).clamp(min=0.0)
        return raw

    def _encode_sequence(self, X: torch.Tensor) -> torch.Tensor:
        """Summarise the input sequence into one vector per batch element."""
        if self.encoder_type == "lstm":
            # Final hidden state of the last LSTM layer.
            _, (hidden, _) = self.encoder(X)
            return hidden[-1]
        projected = self.input_proj(X) + self.pos_encoding[:, : X.size(1), :]
        return self._masked_mean(self.encoder(projected))

    def _compute_loss(
        self,
        norm_preds: Dict[str, torch.Tensor],
        rain_logit: torch.Tensor,
        weather_logits: torch.Tensor,
        reg_targets: Dict[str, Optional[torch.Tensor]],
        rain_target: Optional[torch.Tensor],
        weather_target: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Sum the weighted losses of every supplied target; ``None`` if none was supplied."""
        loss_terms = []

        # Regression targets: smooth-L1 in normalised space, weighted per target.
        for name, target in reg_targets.items():
            if target is None:
                continue
            pred_norm = norm_preds[name]
            target_norm = self._encode_target(name, target.to(pred_norm.device))
            pred_norm = pred_norm.to(target_norm.dtype)
            weight = float(CONTINUOUS_TARGET_SPECS[name]["loss_weight"])
            loss_terms.append(F.smooth_l1_loss(pred_norm, target_norm) * weight)

        if rain_target is not None:
            # Move to the logits' device as well as dtype, matching how the
            # regression targets are handled above.
            rain_target = rain_target.to(device=rain_logit.device, dtype=rain_logit.dtype)
            rain_loss = F.binary_cross_entropy_with_logits(
                rain_logit,
                rain_target,
                pos_weight=self.rain_pos_weight.to(rain_logit.device),
            )
            loss_terms.append(0.7 * rain_loss)

        if weather_target is not None:
            class_weights = self.weather_class_weights
            if class_weights is not None:
                class_weights = class_weights.to(weather_logits.device)
            weather_loss = F.cross_entropy(
                weather_logits.reshape(-1, self.num_weather_classes),
                weather_target.to(weather_logits.device).long().reshape(-1),
                weight=class_weights,
                label_smoothing=0.0,
            )
            loss_terms.append(0.9 * weather_loss)

        return sum(loss_terms) if loss_terms else None

    def forward(
        self,
        X: torch.Tensor,
        location_id: Optional[torch.Tensor] = None,
        temp_target: Optional[torch.Tensor] = None,
        humidity_target: Optional[torch.Tensor] = None,
        apparent_target: Optional[torch.Tensor] = None,
        precip_target: Optional[torch.Tensor] = None,
        sea_level_pressure_target: Optional[torch.Tensor] = None,
        surface_pressure_target: Optional[torch.Tensor] = None,
        cloud_cover_target: Optional[torch.Tensor] = None,
        wind_target: Optional[torch.Tensor] = None,
        wind_dir_sin_target: Optional[torch.Tensor] = None,
        wind_dir_cos_target: Optional[torch.Tensor] = None,
        rain_target: Optional[torch.Tensor] = None,
        weather_target: Optional[torch.Tensor] = None,
        return_repr: bool = False,
        **kwargs: Any,
    ) -> WeatherModelOutput:
        """Run the model and optionally compute the training loss.

        Args:
            X: Input sequence, batch-first (``X.size(0)`` is the batch size,
                ``X.size(1)`` the sequence length).
            location_id: Integer location index per batch element; defaults to
                all zeros when omitted.
            temp_target ... wind_dir_cos_target: Optional raw regression
                targets. Each is normalised with ``_encode_target`` and
                compared to the matching head with a smooth-L1 loss.
            rain_target: Optional target for the rain head (binary
                cross-entropy with logits).
            weather_target: Optional class indices for the weather head
                (cross-entropy).
            return_repr: When true, also return ``head_repr``, ``norm_preds``,
                ``raw_preds`` and ``distill_head_repr``.
            **kwargs: Accepted and ignored.

        Returns:
            WeatherModelOutput whose ``logits`` is a tuple of the decoded
            regression predictions in ``CONTINUOUS_TARGET_ORDER`` followed by
            the rain logits and the weather logits (reshaped to
            ``(-1, num_predict, num_weather_classes)``). ``loss`` is the
            weighted sum over every target that was supplied, or ``None`` when
            no target was supplied.
        """
        if location_id is None:
            location_id = torch.zeros(X.size(0), dtype=torch.long, device=X.device)

        seq_repr = self._encode_sequence(X)
        loc_emb = self.location_embedding(location_id)
        head_repr = self.head_norm(torch.cat([seq_repr, loc_emb], dim=1))
        h = self.head_dropout(head_repr)

        norm_preds: Dict[str, torch.Tensor] = {}
        raw_preds: Dict[str, torch.Tensor] = {}
        for name in CONTINUOUS_TARGET_ORDER:
            norm_pred = self.reg_heads[name](h)
            norm_preds[name] = norm_pred
            raw_preds[name] = self._decode_prediction(name, norm_pred)
        rain_logit = self.fc_rain(h)
        weather_logits = self.fc_weather(h).view(-1, self.num_predict, self.num_weather_classes)

        # The loss is computed from whichever targets are present; it does not
        # require `temp_target` specifically, so e.g. a rain-only batch still
        # yields a loss.
        reg_targets = {
            "temp": temp_target,
            "humidity": humidity_target,
            "apparent": apparent_target,
            "precip": precip_target,
            "sea_level_pressure": sea_level_pressure_target,
            "surface_pressure": surface_pressure_target,
            "cloud_cover": cloud_cover_target,
            "wind": wind_target,
            "wind_dir_sin": wind_dir_sin_target,
            "wind_dir_cos": wind_dir_cos_target,
        }
        loss = self._compute_loss(
            norm_preds, rain_logit, weather_logits, reg_targets, rain_target, weather_target
        )

        logits = tuple(raw_preds[name] for name in CONTINUOUS_TARGET_ORDER) + (rain_logit, weather_logits)

        distill_head_repr = None
        if return_repr:
            # Project to the teacher's head width only when a projection exists.
            distill_head_repr = self.distill_proj(head_repr) if self.distill_proj is not None else head_repr

        return WeatherModelOutput(
            loss=loss,
            logits=logits,
            head_repr=head_repr if return_repr else None,
            norm_preds=norm_preds if return_repr else None,
            raw_preds=raw_preds if return_repr else None,
            distill_head_repr=distill_head_repr,
        )
# Make the repo usable with AutoConfig/AutoModel when loaded from the Hub.
def _try_register_for_auto_class(cls: Any, *auto_class: str) -> None:  # pragma: no cover
    """Call ``cls.register_for_auto_class(*auto_class)``, ignoring any failure.

    Registration is best-effort: a class without that method, or a
    registration that raises, must not prevent this module from importing.
    """
    try:
        cls.register_for_auto_class(*auto_class)
    except Exception:
        pass


_try_register_for_auto_class(WeatherModelConfig)
_try_register_for_auto_class(WeatherForcastModel, "AutoModel")