Harley-ml commited on
Commit
0ed3271
·
verified ·
1 Parent(s): 3c58f59

Upload 5 files

Browse files
Files changed (5) hide show
  1. __init__.py +4 -0
  2. config.json +82 -0
  3. configuration.py +52 -0
  4. model.safetensors +3 -0
  5. modeling.py +301 -0
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .configuration import WeatherSequenceConfig
2
+ from .modeling import WeatherSequenceModel, WeatherSequenceOutput
3
+
4
+ __all__ = ["WeatherModelConfig", "WeatherForcastModel", "WeatherModelOutput"]
config.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "WeatherSequenceModel"
4
+ ],
5
+ "distill_teacher_head_dim": 416,
6
+ "dropout": 0.1,
7
+ "dtype": "float32",
8
+ "encoder_type": "lstm",
9
+ "hidden_dim": 128,
10
+ "input_dim": 22,
11
+ "location_emb_dim": 32,
12
+ "model_type": "weather_sequence",
13
+ "num_layers": 3,
14
+ "num_locations": 82,
15
+ "num_predict": 12,
16
+ "num_weather_classes": 7,
17
+ "rain_pos_weight": 6.547722074664306,
18
+ "seq_len": 72,
19
+ "target_norms": {
20
+ "apparent": {
21
+ "mean": 16.420160986060196,
22
+ "std": 12.332221212388726,
23
+ "transform": "raw"
24
+ },
25
+ "cloud_cover": {
26
+ "mean": 51.71581237675868,
27
+ "std": 42.028595137718646,
28
+ "transform": "raw"
29
+ },
30
+ "humidity": {
31
+ "mean": 69.18635597409919,
32
+ "std": 21.524024234467674,
33
+ "transform": "raw"
34
+ },
35
+ "precip": {
36
+ "mean": 0.05684371705656333,
37
+ "std": 0.22550783339649325,
38
+ "transform": "log1p"
39
+ },
40
+ "sea_level_pressure": {
41
+ "mean": 1014.5679196119568,
42
+ "std": 7.459071118489876,
43
+ "transform": "raw"
44
+ },
45
+ "surface_pressure": {
46
+ "mean": 963.4371186423618,
47
+ "std": 86.39541603431283,
48
+ "transform": "raw"
49
+ },
50
+ "temp": {
51
+ "mean": 16.863813962719767,
52
+ "std": 9.92933797011761,
53
+ "transform": "raw"
54
+ },
55
+ "wind": {
56
+ "mean": 9.952284635073887,
57
+ "std": 6.676419945197847,
58
+ "transform": "raw"
59
+ },
60
+ "wind_dir_cos": {
61
+ "mean": -0.03349104536155545,
62
+ "std": 0.70513783656826,
63
+ "transform": "raw"
64
+ },
65
+ "wind_dir_sin": {
66
+ "mean": -0.002100160488024665,
67
+ "std": 0.7082757736110976,
68
+ "transform": "raw"
69
+ }
70
+ },
71
+ "transformers_version": "5.5.0",
72
+ "use_cache": false,
73
+ "weather_class_weights": [
74
+ 0.23165243864059448,
75
+ 0.1921183317899704,
76
+ 1.7011765241622925,
77
+ 0.4377932548522949,
78
+ 0.8518651723861694,
79
+ 1.6312177181243896,
80
+ 1.9541765451431274
81
+ ]
82
+ }
configuration.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ try:
6
+ from transformers import PretrainedConfig
7
+ except Exception: # pragma: no cover - lets the file import in minimal environments
8
+ class PretrainedConfig: # type: ignore
9
+ model_type = "custom"
10
+
11
+ def __init__(self, **kwargs):
12
+ for k, v in kwargs.items():
13
+ setattr(self, k, v)
14
+
15
+
16
+ class WeatherModelConfig(PretrainedConfig):
17
+
18
+ model_type = "mwm"
19
+
20
+ def __init__(
21
+ self,
22
+ input_dim: Optional[int] = None,
23
+ seq_len: int = 72,
24
+ num_predict: int = 12,
25
+ hidden_dim: int = 128,
26
+ num_layers: int = 3,
27
+ dropout: float = 0.1,
28
+ encoder_type: str = "lstm",
29
+ num_locations: int = 82,
30
+ location_emb_dim: int = 32,
31
+ num_weather_classes: int = 7,
32
+ rain_pos_weight: float = 1.0,
33
+ weather_class_weights: Optional[list[float]] = None,
34
+ target_norms: Optional[Dict[str, Dict[str, float]]] = None,
35
+ distill_teacher_head_dim: int = 416,
36
+ **kwargs: Any,
37
+ ):
38
+ super().__init__(**kwargs)
39
+ self.input_dim = input_dim
40
+ self.seq_len = seq_len
41
+ self.num_predict = num_predict
42
+ self.hidden_dim = hidden_dim
43
+ self.num_layers = num_layers
44
+ self.dropout = dropout
45
+ self.encoder_type = encoder_type
46
+ self.num_locations = num_locations
47
+ self.location_emb_dim = location_emb_dim
48
+ self.num_weather_classes = num_weather_classes
49
+ self.rain_pos_weight = rain_pos_weight
50
+ self.weather_class_weights = weather_class_weights
51
+ self.target_norms = target_norms or {}
52
+ self.distill_teacher_head_dim = int(distill_teacher_head_dim)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca32ca6acc93e3d503343366a6cc277b723d82ad1cce2815dc7fe761772a0748
3
+ size 1788776
modeling.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ try:
11
+ from transformers import PreTrainedModel
12
+ from transformers.modeling_outputs import ModelOutput
13
+ except Exception:
14
+ class PreTrainedModel(nn.Module):
15
+ config_class = None
16
+ base_model_prefix = ""
17
+ main_input_name = "input_ids"
18
+
19
+ def __init__(self, config):
20
+ super().__init__()
21
+ self.config = config
22
+
23
+ class ModelOutput(dict): # type: ignore
24
+ pass
25
+
26
+ from .configuration import WeatherModelConfig
27
+
28
+ CONTINUOUS_TARGET_ORDER = [
29
+ "temp",
30
+ "humidity",
31
+ "apparent",
32
+ "precip",
33
+ "sea_level_pressure",
34
+ "surface_pressure",
35
+ "cloud_cover",
36
+ "wind",
37
+ "wind_dir_sin",
38
+ "wind_dir_cos",
39
+ ]
40
+
41
+ CONTINUOUS_TARGET_SPECS = {
42
+ "temp": {"loss_weight": 1.0, "transform": "raw"},
43
+ "humidity": {"loss_weight": 1.0, "transform": "raw"},
44
+ "apparent": {"loss_weight": 0.8, "transform": "raw"},
45
+ "precip": {"loss_weight": 0.9, "transform": "log1p"},
46
+ "sea_level_pressure": {"loss_weight": 0.6, "transform": "raw"},
47
+ "surface_pressure": {"loss_weight": 0.4, "transform": "raw"},
48
+ "cloud_cover": {"loss_weight": 0.4, "transform": "raw"},
49
+ "wind": {"loss_weight": 0.6, "transform": "raw"},
50
+ "wind_dir_sin": {"loss_weight": 0.55, "transform": "raw"},
51
+ "wind_dir_cos": {"loss_weight": 0.55, "transform": "raw"},
52
+ }
53
+
54
+
55
+ @dataclass
56
+ class WeatherModelOutput(ModelOutput):
57
+ loss: Optional[torch.Tensor] = None
58
+ logits: Optional[Tuple[torch.Tensor, ...]] = None
59
+ head_repr: Optional[torch.Tensor] = None
60
+ norm_preds: Optional[Dict[str, torch.Tensor]] = None
61
+ raw_preds: Optional[Dict[str, torch.Tensor]] = None
62
+ distill_head_repr: Optional[torch.Tensor] = None
63
+
64
+
65
+ class WeatherForcastModel(PreTrainedModel):
66
+
67
+ config_class = WeatherModelConfig
68
+ base_model_prefix = "weather_sequence"
69
+ main_input_name = "X"
70
+
71
+ # Newer Transformers versions may create auto_map entries from these registrations.
72
+ _tied_weights_keys: list[str] = []
73
+
74
+ def __init__(self, config: WeatherModelConfig):
75
+ super().__init__(config)
76
+
77
+ self.encoder_type = str(getattr(config, "encoder_type", "lstm")).lower()
78
+ self.hidden_dim = int(config.hidden_dim)
79
+ self.seq_len = int(config.seq_len)
80
+ self.num_predict = int(config.num_predict)
81
+ self.num_weather_classes = int(config.num_weather_classes)
82
+
83
+ if config.input_dim is None:
84
+ raise ValueError("WeatherModelConfig.input_dim must be set")
85
+
86
+ self.location_embedding = nn.Embedding(max(1, int(config.num_locations)), int(config.location_emb_dim))
87
+
88
+ if config.weather_class_weights is not None:
89
+ self.register_buffer(
90
+ "weather_class_weights",
91
+ torch.tensor(config.weather_class_weights, dtype=torch.float32),
92
+ persistent=False,
93
+ )
94
+ else:
95
+ self.weather_class_weights = None
96
+
97
+ self.register_buffer(
98
+ "rain_pos_weight",
99
+ torch.tensor(float(config.rain_pos_weight), dtype=torch.float32),
100
+ persistent=False,
101
+ )
102
+
103
+ self.target_norm_meta: Dict[str, Dict[str, Any]] = {}
104
+ for name in CONTINUOUS_TARGET_ORDER:
105
+ spec = dict(config.target_norms.get(name, {}))
106
+ mean = float(spec.get("mean", 0.0))
107
+ std = max(float(spec.get("std", 1.0)), 1e-6)
108
+ transform = str(spec.get("transform", CONTINUOUS_TARGET_SPECS[name]["transform"]))
109
+ self.register_buffer(f"{name}_mean", torch.tensor(mean, dtype=torch.float32), persistent=False)
110
+ self.register_buffer(f"{name}_std", torch.tensor(std, dtype=torch.float32), persistent=False)
111
+ self.target_norm_meta[name] = {"transform": transform}
112
+
113
+ if self.encoder_type == "lstm":
114
+ self.encoder = nn.LSTM(
115
+ input_size=int(config.input_dim),
116
+ hidden_size=self.hidden_dim,
117
+ num_layers=int(config.num_layers),
118
+ batch_first=True,
119
+ dropout=float(config.dropout) if int(config.num_layers) > 1 else 0.0,
120
+ bidirectional=False,
121
+ )
122
+ elif self.encoder_type == "transformer":
123
+ self.input_proj = nn.Linear(int(config.input_dim), self.hidden_dim)
124
+ self.pos_encoding = nn.Parameter(torch.randn(1, int(config.seq_len), self.hidden_dim) * 0.1)
125
+ encoder_layer = nn.TransformerEncoderLayer(
126
+ d_model=self.hidden_dim,
127
+ nhead=4,
128
+ dropout=float(config.dropout),
129
+ batch_first=True,
130
+ )
131
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=int(config.num_layers))
132
+ else:
133
+ raise ValueError(f"Unknown encoder_type: {self.encoder_type}")
134
+
135
+ self.head_dim = self.hidden_dim + int(config.location_emb_dim)
136
+ self.head_norm = nn.LayerNorm(self.head_dim)
137
+ self.head_dropout = nn.Dropout(float(config.dropout))
138
+
139
+ self.reg_heads = nn.ModuleDict({name: nn.Linear(self.head_dim, self.num_predict) for name in CONTINUOUS_TARGET_ORDER})
140
+ self.fc_rain = nn.Linear(self.head_dim, self.num_predict)
141
+ self.fc_weather = nn.Linear(self.head_dim, self.num_predict * self.num_weather_classes)
142
+
143
+ teacher_head_dim = int(getattr(config, "distill_teacher_head_dim", 0))
144
+ if teacher_head_dim > 0 and teacher_head_dim != self.head_dim:
145
+ self.distill_proj = nn.Linear(self.head_dim, teacher_head_dim, bias=False)
146
+ else:
147
+ self.distill_proj = None
148
+
149
+ self.post_init()
150
+
151
+ @staticmethod
152
+ def _masked_mean(x: torch.Tensor) -> torch.Tensor:
153
+ mask = (x.abs().sum(dim=-1) > 0).float().unsqueeze(-1)
154
+ summed = (x * mask).sum(dim=1)
155
+ denom = mask.sum(dim=1).clamp(min=1.0)
156
+ return summed / denom
157
+
158
+ def _target_mean_std(self, name: str) -> Tuple[torch.Tensor, torch.Tensor]:
159
+ return getattr(self, f"{name}_mean"), getattr(self, f"{name}_std")
160
+
161
+ def _encode_target(self, name: str, target: torch.Tensor) -> torch.Tensor:
162
+ transform = self.target_norm_meta[name]["transform"]
163
+ target = target.to(dtype=torch.float32)
164
+ if transform == "log1p":
165
+ target = torch.log1p(torch.clamp(target, min=0.0))
166
+ mean, std = self._target_mean_std(name)
167
+ return (target - mean.to(target.device)) / std.to(target.device)
168
+
169
+ def _decode_prediction(self, name: str, pred_norm: torch.Tensor) -> torch.Tensor:
170
+ transform = self.target_norm_meta[name]["transform"]
171
+ mean, std = self._target_mean_std(name)
172
+ raw = pred_norm * std.to(pred_norm.device) + mean.to(pred_norm.device)
173
+ if transform == "log1p":
174
+ raw = torch.expm1(raw).clamp(min=0.0)
175
+ return raw
176
+
177
+ def forward(
178
+ self,
179
+ X: torch.Tensor,
180
+ location_id: Optional[torch.Tensor] = None,
181
+ temp_target: Optional[torch.Tensor] = None,
182
+ humidity_target: Optional[torch.Tensor] = None,
183
+ apparent_target: Optional[torch.Tensor] = None,
184
+ precip_target: Optional[torch.Tensor] = None,
185
+ sea_level_pressure_target: Optional[torch.Tensor] = None,
186
+ surface_pressure_target: Optional[torch.Tensor] = None,
187
+ cloud_cover_target: Optional[torch.Tensor] = None,
188
+ wind_target: Optional[torch.Tensor] = None,
189
+ wind_dir_sin_target: Optional[torch.Tensor] = None,
190
+ wind_dir_cos_target: Optional[torch.Tensor] = None,
191
+ rain_target: Optional[torch.Tensor] = None,
192
+ weather_target: Optional[torch.Tensor] = None,
193
+ return_repr: bool = False,
194
+ **kwargs: Any,
195
+ ) -> WeatherModelOutput:
196
+ if location_id is None:
197
+ location_id = torch.zeros(X.size(0), dtype=torch.long, device=X.device)
198
+
199
+ if self.encoder_type == "lstm":
200
+ _, (h, _) = self.encoder(X)
201
+ seq_repr = h[-1]
202
+ else:
203
+ z = self.input_proj(X) + self.pos_encoding[:, : X.size(1), :]
204
+ out = self.encoder(z)
205
+ seq_repr = self._masked_mean(out)
206
+
207
+ loc_emb = self.location_embedding(location_id)
208
+ head_repr = self.head_norm(torch.cat([seq_repr, loc_emb], dim=1))
209
+ h = self.head_dropout(head_repr)
210
+
211
+ raw_preds: Dict[str, torch.Tensor] = {}
212
+ norm_preds: Dict[str, torch.Tensor] = {}
213
+ for name in CONTINUOUS_TARGET_ORDER:
214
+ norm_pred = self.reg_heads[name](h)
215
+ norm_preds[name] = norm_pred
216
+ raw_preds[name] = self._decode_prediction(name, norm_pred)
217
+
218
+ rain_logit = self.fc_rain(h)
219
+ weather_logits = self.fc_weather(h).view(-1, self.num_predict, self.num_weather_classes)
220
+
221
+ loss = None
222
+ if temp_target is not None:
223
+ targets = {
224
+ "temp": temp_target,
225
+ "humidity": humidity_target,
226
+ "apparent": apparent_target,
227
+ "precip": precip_target,
228
+ "sea_level_pressure": sea_level_pressure_target,
229
+ "surface_pressure": surface_pressure_target,
230
+ "cloud_cover": cloud_cover_target,
231
+ "wind": wind_target,
232
+ "wind_dir_sin": wind_dir_sin_target,
233
+ "wind_dir_cos": wind_dir_cos_target,
234
+ }
235
+
236
+ loss_terms = []
237
+ for name, target in targets.items():
238
+ if target is None:
239
+ continue
240
+ target_norm = self._encode_target(name, target.to(h.device))
241
+ pred_norm = norm_preds[name].to(target_norm.dtype)
242
+ loss_terms.append(
243
+ F.smooth_l1_loss(pred_norm, target_norm) * float(CONTINUOUS_TARGET_SPECS[name]["loss_weight"])
244
+ )
245
+
246
+ if rain_target is not None:
247
+ rain_target = rain_target.to(rain_logit.dtype)
248
+ rain_loss = F.binary_cross_entropy_with_logits(
249
+ rain_logit,
250
+ rain_target,
251
+ pos_weight=self.rain_pos_weight.to(rain_logit.device),
252
+ )
253
+ loss_terms.append(0.7 * rain_loss)
254
+
255
+ if weather_target is not None:
256
+ weather_loss = F.cross_entropy(
257
+ weather_logits.reshape(-1, self.num_weather_classes),
258
+ weather_target.long().reshape(-1),
259
+ weight=self.weather_class_weights,
260
+ label_smoothing=0.0,
261
+ )
262
+ loss_terms.append(0.9 * weather_loss)
263
+
264
+ loss = sum(loss_terms) if loss_terms else None
265
+
266
+ logits = (
267
+ raw_preds["temp"],
268
+ raw_preds["humidity"],
269
+ raw_preds["apparent"],
270
+ raw_preds["precip"],
271
+ raw_preds["sea_level_pressure"],
272
+ raw_preds["surface_pressure"],
273
+ raw_preds["cloud_cover"],
274
+ raw_preds["wind"],
275
+ raw_preds["wind_dir_sin"],
276
+ raw_preds["wind_dir_cos"],
277
+ rain_logit,
278
+ weather_logits,
279
+ )
280
+
281
+ output = WeatherModelOutput(
282
+ loss=loss,
283
+ logits=logits,
284
+ head_repr=head_repr if return_repr else None,
285
+ norm_preds=norm_preds if return_repr else None,
286
+ raw_preds=raw_preds if return_repr else None,
287
+ distill_head_repr=(self.distill_proj(head_repr) if self.distill_proj is not None else head_repr) if return_repr else None,
288
+ )
289
+ return output
290
+
291
+
292
+ # Make the repo usable with AutoConfig/AutoModel when loaded from the Hub.
293
+ try: # pragma: no cover
294
+ WeatherModelConfig.register_for_auto_class()
295
+ except Exception:
296
+ pass
297
+
298
+ try: # pragma: no cover
299
+ WeatherForcastModel.register_for_auto_class("AutoModel")
300
+ except Exception:
301
+ pass