# mosaic/base.py
# Committed by: maxxxzdn
# Initial release: Mosaic weather model (era5 + hres variants)
# Commit 5f226eb (verified)
import torch
from torch import nn
from dataset import WeatherMetadata
class WeatherModel(nn.Module):
    """Thin ``nn.Module`` wrapper that binds a forecasting network to its grid metadata.

    On construction the wrapped network is given the static fields and the
    longitude/latitude coordinates from ``weather_metadata`` (through its own
    ``initialize_static_vars`` and ``initialize_interpolation`` hooks). The
    metadata object is then kept on the wrapper. ``forward`` simply delegates
    to the wrapped network.
    """

    def __init__(self, model: nn.Module, weather_metadata: WeatherMetadata):
        """Wrap ``model`` and initialise it from ``weather_metadata``.

        Args:
            model: The forecasting network. It must provide
                ``initialize_static_vars(static_data, longitude, latitude)`` and
                ``initialize_interpolation(longitude, latitude)``.
            weather_metadata: Source of ``static_data``, ``longitude`` and
                ``latitude``. The object itself is stored as
                ``self.weather_metadata``.
        """
        super().__init__()
        # Register the network as a submodule before calling its hooks.
        self.model = model

        meta = weather_metadata
        lon = meta.longitude
        lat = meta.latitude
        network = self.model
        # Static fields first, then the interpolation set-up, both on the same grid.
        network.initialize_static_vars(meta.static_data, lon, lat)
        network.initialize_interpolation(lon, lat)

        self.weather_metadata = meta

    def forward(self, norm_state: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
        """Run the wrapped network and return its output unchanged.

        Args:
            norm_state: Input state tensor. The name suggests it is already
                normalised; shape and layout are defined by the wrapped
                model (NOTE(review): confirm).
            day_year_time: Time-encoding tensor, passed through untouched
                (NOTE(review): exact encoding not visible here).
            num_noise_samples: Integer passed straight to the wrapped model.
                Presumably the number of stochastic samples to draw
                (TODO confirm).

        Returns:
            Whatever the wrapped model returns for these three positional
            arguments.
        """
        network = self.model
        outputs = network(norm_state, day_year_time, num_noise_samples)
        return outputs