File size: 730 Bytes
5f226eb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import torch
from torch import nn
from dataset import WeatherMetadata
class WeatherModel(nn.Module):
    """Thin ``nn.Module`` wrapper around an inner weather forecasting model.

    At construction time the wrapped model is prepared from
    ``weather_metadata``:

    - its ``initialize_static_vars`` hook receives the static data plus the
      longitude and latitude;
    - its ``initialize_interpolation`` hook receives the longitude and
      latitude.

    The metadata object is then kept on the wrapper as ``weather_metadata``.
    ``forward`` simply delegates to the wrapped model.
    """

    def __init__(self, model: nn.Module, weather_metadata: WeatherMetadata):
        """Register ``model`` as a submodule and run its setup hooks.

        Args:
            model: the inner network. It must provide
                ``initialize_static_vars`` and ``initialize_interpolation``;
                these are not part of the ``nn.Module`` API.
            weather_metadata: object exposing ``static_data``, ``longitude``
                and ``latitude`` attributes.
        """
        super().__init__()
        # Register the inner network first so it is the first submodule.
        self.model = model

        # Read the metadata fields once, in the order the hooks consume them.
        static_fields = weather_metadata.static_data
        lon = weather_metadata.longitude
        lat = weather_metadata.latitude

        # Static variables are set up before the interpolation step.
        model.initialize_static_vars(static_fields, lon, lat)
        model.initialize_interpolation(lon, lat)

        self.weather_metadata = weather_metadata

    def forward(self, norm_state: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
        """Delegate a forward pass to the wrapped model.

        Args:
            norm_state: state tensor, forwarded unchanged. The name suggests
                it is already normalised; shape and normalisation are not
                checked here.
            day_year_time: time-encoding tensor, forwarded unchanged.
            num_noise_samples: forwarded as-is. Presumably the number of
                stochastic samples to draw; confirm against the inner model.

        Returns:
            Whatever the wrapped model returns for these arguments.
        """
        inner = self.model
        return inner(norm_state, day_year_time, num_noise_samples)
|