File size: 1,734 Bytes
0ed3271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations

from typing import Any, Dict, Optional

try:
    from transformers import PretrainedConfig
except Exception:  # pragma: no cover - lets the file import in minimal environments
    class PretrainedConfig:  # type: ignore
        """Bare-bones substitute used when ``transformers`` fails to import.

        It only records whatever keyword arguments it receives as instance
        attributes; nothing else from the real base class is emulated here.
        """

        model_type = "custom"

        def __init__(self, **kwargs):
            # Copy each keyword argument onto the instance under its own name.
            for attr_name in kwargs:
                setattr(self, attr_name, kwargs[attr_name])


class WeatherModelConfig(PretrainedConfig):
    """Configuration object registered under the ``"mwm"`` model type.

    Each named constructor argument is kept on the instance under the same
    name. All of them are stored exactly as given, except that:

    * a falsy ``target_norms`` (``None`` or an empty mapping) is replaced by
      a new empty ``dict``;
    * ``distill_teacher_head_dim`` is converted with ``int()``.

    Any additional keyword arguments are handed to
    ``PretrainedConfig.__init__`` untouched.
    """

    model_type = "mwm"

    def __init__(
        self,
        input_dim: Optional[int] = None,
        seq_len: int = 72,
        num_predict: int = 12,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.1,
        encoder_type: str = "lstm",
        num_locations: int = 82,
        location_emb_dim: int = 32,
        num_weather_classes: int = 7,
        rain_pos_weight: float = 1.0,
        weather_class_weights: Optional[list[float]] = None,
        target_norms: Optional[Dict[str, Dict[str, float]]] = None,
        distill_teacher_head_dim: int = 416,
        **kwargs: Any,
    ):
        # The base class consumes the extra keyword arguments before any of
        # the fields below are written.
        super().__init__(**kwargs)

        # Stored verbatim, in the same order as the signature.
        self.input_dim, self.seq_len, self.num_predict = input_dim, seq_len, num_predict
        self.hidden_dim, self.num_layers, self.dropout = hidden_dim, num_layers, dropout
        self.encoder_type = encoder_type
        self.num_locations, self.location_emb_dim = num_locations, location_emb_dim
        self.num_weather_classes = num_weather_classes
        self.rain_pos_weight = rain_pos_weight
        self.weather_class_weights = weather_class_weights

        # A falsy value (None or an empty mapping) becomes a fresh empty dict.
        self.target_norms = target_norms if target_norms else {}

        # Coerced with int(); a value int() cannot convert raises here.
        self.distill_teacher_head_dim = int(distill_teacher_head_dim)