| """Metadata dataclasses for weather forecasting inference.""" | |
| import torch | |
| from dataclasses import dataclass | |
@dataclass
class NormalizationStats:
    """Normalization statistics for state variables.

    Holds four tensors: a mean/std pair for the state and a mean/std pair
    for the residual. The class is a dataclass so that it can be built from
    keyword arguments, which is exactly what :meth:`to` relies on.
    """

    # NOTE(review): tensor shapes/dtypes are not visible here — presumably
    # one entry per state variable; confirm against the code that builds them.
    state_mean: torch.Tensor
    state_std: torch.Tensor
    # NOTE(review): "residual" presumably refers to step-to-step state
    # differences — confirm with the training/inference code.
    residual_mean: torch.Tensor
    residual_std: torch.Tensor

    def to(self, device) -> 'NormalizationStats':
        """Return a new ``NormalizationStats`` with every tensor on ``device``.

        Args:
            device: Anything accepted by ``torch.Tensor.to`` (for example a
                ``torch.device`` or a device string such as ``"cpu"``).

        Returns:
            A new instance whose four fields are the results of calling
            ``.to(device)`` on the corresponding fields of ``self``. The
            current instance is not modified.
        """
        return NormalizationStats(
            state_mean=self.state_mean.to(device),
            state_std=self.state_std.to(device),
            residual_mean=self.residual_mean.to(device),
            residual_std=self.residual_std.to(device),
        )
@dataclass
class WeatherMetadata:
    """Metadata for the weather dataset.

    A plain record bundling variable names, grid coordinates, static fields
    and normalization statistics. It is declared as a dataclass so instances
    can be constructed directly from these fields by keyword.
    """

    # Names of the variables and of the static variables, respectively.
    variables: list[str]
    static_variables: list[str]
    # NOTE(review): presumably grid coordinates; units and shape are not
    # visible in this file — confirm against the dataset loader.
    longitude: torch.Tensor
    latitude: torch.Tensor
    # NOTE(review): presumably the values for ``static_variables`` — confirm
    # the layout against the code that builds this object.
    static_data: torch.Tensor
    # NOTE(review): meaning and units are not established here — confirm.
    day_year_delta: torch.Tensor
    # Normalization statistics associated with this dataset.
    norm_stats: NormalizationStats