import torch
from torch import nn

from dataset import WeatherMetadata


class WeatherModel(nn.Module):
    """``nn.Module`` wrapper that binds a forecasting network to grid metadata.

    At construction time the wrapped network is handed
    ``weather_metadata.static_data`` together with the longitude/latitude
    coordinates from ``weather_metadata``, first via
    ``initialize_static_vars`` and then via ``initialize_interpolation``.
    ``forward`` delegates straight to the wrapped network.
    """

    def __init__(self, model: nn.Module, weather_metadata: WeatherMetadata):
        """Wrap ``model`` and configure it from ``weather_metadata``.

        Args:
            model: The network to wrap. It must expose
                ``initialize_static_vars(static_data, longitude, latitude)``
                and ``initialize_interpolation(longitude, latitude)``.
            weather_metadata: Source of ``static_data``, ``longitude`` and
                ``latitude``. It is also stored on the wrapper as
                ``self.weather_metadata``.
        """
        super().__init__()
        # Assign first so nn.Module registers the network as a submodule;
        # the configuration calls below then run on that registered object.
        self.model = model
        net = self.model

        net.initialize_static_vars(
            weather_metadata.static_data,
            weather_metadata.longitude,
            weather_metadata.latitude,
        )
        net.initialize_interpolation(
            weather_metadata.longitude,
            weather_metadata.latitude,
        )

        self.weather_metadata = weather_metadata

    def forward(
        self,
        norm_state: torch.Tensor,
        day_year_time: torch.Tensor,
        num_noise_samples: int,
    ):
        """Run the wrapped network and return its output unchanged.

        Args:
            norm_state: State tensor forwarded as-is. The name suggests an
                already-normalized state; confirm this against the caller.
            day_year_time: Time-encoding tensor forwarded as-is. Its exact
                layout is not visible here.
            num_noise_samples: Count forwarded as-is to the wrapped network.

        Returns:
            Whatever ``self.model(norm_state, day_year_time, num_noise_samples)``
            returns.
        """
        wrapped = self.model
        return wrapped(norm_state, day_year_time, num_noise_samples)