File size: 949 Bytes
5f226eb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | """Metadata dataclasses for weather forecasting inference."""
import torch
from dataclasses import dataclass
@dataclass(frozen=True)
class NormalizationStats:
    """Normalization statistics for state variables.

    Holds the mean/std tensors for the model state and a second mean/std
    pair for residuals. Instances are frozen: ``to`` returns a new object
    instead of modifying this one.

    NOTE(review): avoid ``==`` between instances. The dataclass-generated
    ``__eq__`` compares the tensor fields with ``==``, which is element-wise;
    for distinct multi-element tensors the truth value is ambiguous and
    raises ``RuntimeError``. Compare individual fields with ``torch.equal``.
    """

    # Statistics for normalizing the state. Shapes are not fixed by this
    # class -- presumably broadcastable per-variable stats; TODO confirm.
    state_mean: torch.Tensor
    state_std: torch.Tensor
    # Statistics for residuals (presumably the model's predicted state
    # increments) -- confirm against the code that consumes them.
    residual_mean: torch.Tensor
    residual_std: torch.Tensor

    def to(self, *args, **kwargs) -> 'NormalizationStats':
        """Return a copy with every tensor converted via ``Tensor.to``.

        All positional and keyword arguments are forwarded unchanged to
        ``torch.Tensor.to``. Therefore ``stats.to(device)``,
        ``stats.to(device=...)``, ``stats.to(dtype)``,
        ``stats.to(device, dtype)`` and ``stats.to(..., non_blocking=True)``
        all behave as they would for a single tensor.

        Returns:
            A new ``NormalizationStats``; ``self`` is not modified.
        """
        return NormalizationStats(
            state_mean=self.state_mean.to(*args, **kwargs),
            state_std=self.state_std.to(*args, **kwargs),
            residual_mean=self.residual_mean.to(*args, **kwargs),
            residual_std=self.residual_std.to(*args, **kwargs),
        )
@dataclass(frozen=True)
class WeatherMetadata:
    """Metadata for the weather dataset.

    Groups the variable names, coordinate tensors, static data and
    normalization statistics needed at inference time.

    NOTE(review): ``frozen=True`` only blocks attribute reassignment. The
    ``list`` fields can still be mutated in place. Because the generated
    ``__hash__`` hashes every field, ``hash()`` on an instance raises
    ``TypeError`` (lists are unhashable).
    """

    # Names of the (non-static) variables.
    variables: list[str]
    # Names of the static variables; presumably these label the entries of
    # ``static_data`` -- TODO confirm.
    static_variables: list[str]
    # Grid coordinates; units and shapes are not fixed here (presumably
    # degrees) -- confirm with the data loader.
    longitude: torch.Tensor
    latitude: torch.Tensor
    # Static fields associated with ``static_variables``; layout unverified.
    static_data: torch.Tensor
    # Meaning is not determinable from this file (name suggests a
    # day-of-year offset) -- TODO document.
    day_year_delta: torch.Tensor
    norm_stats: NormalizationStats

    def to(self, device) -> 'WeatherMetadata':
        """Return a copy with every tensor (and ``norm_stats``) moved to ``device``.

        Args:
            device: Forwarded to ``torch.Tensor.to`` for each tensor field
                and to ``NormalizationStats.to`` for ``norm_stats``.

        Returns:
            A new ``WeatherMetadata``; ``self`` is not modified. The name
            lists are shallow-copied so the two instances do not share
            mutable list objects.
        """
        return WeatherMetadata(
            variables=list(self.variables),
            static_variables=list(self.static_variables),
            longitude=self.longitude.to(device),
            latitude=self.latitude.to(device),
            static_data=self.static_data.to(device),
            day_year_delta=self.day_year_delta.to(device),
            norm_stats=self.norm_stats.to(device),
        )
|