File size: 730 Bytes
5f226eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torch import nn
from dataset import WeatherMetadata


class WeatherModel(nn.Module):
    """Wrap a forecasting network and prime it with grid metadata.

    Construction hands the wrapped network the static fields and the
    longitude/latitude coordinates stored on ``weather_metadata`` (through its
    ``initialize_static_vars`` and ``initialize_interpolation`` hooks).
    Afterwards, ``forward`` is a pure pass-through to that network.
    """

    def __init__(self, model: nn.Module, weather_metadata: WeatherMetadata):
        super().__init__()
        # Assigning through nn.Module.__setattr__ registers ``model`` as a
        # child module, so its parameters are owned by this wrapper.
        self.model = model

        meta = weather_metadata
        # NOTE(review): the wrapped network must provide both hooks below;
        # plain nn.Module does not define them.
        model.initialize_static_vars(
            meta.static_data,
            meta.longitude,
            meta.latitude,
        )
        model.initialize_interpolation(meta.longitude, meta.latitude)

        # Kept so callers can reach the grid description from the wrapper.
        self.weather_metadata = meta

    def forward(self, norm_state: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
        """Run the wrapped network and return its output unchanged.

        Args:
            norm_state: Input state tensor. Presumably already normalized, as
                the name suggests; confirm against the caller.
            day_year_time: Time-encoding tensor forwarded as-is.
            num_noise_samples: Forwarded as-is to the wrapped network.
        """
        inner = self.model
        prediction = inner(norm_state, day_year_time, num_noise_samples)
        return prediction