# soilformer/modelling/utils.py
# Author: Kuangdai
# Initial release of SoilFormer (commit 6fb6c07)
# utils.py
# -*- coding: utf-8 -*-
import json
from typing import Dict
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa
class GroupedMLP(nn.Module):
    """
    Batched per-variable MLP for a fixed n_in bucket.

    Applies an independent (non-shared) 1- or 2-layer MLP to each of the
    V variables in a single batched ``einsum`` rather than V separate
    ``nn.Linear`` calls.

    Input:  X [B, V, n_in]
    Output: Y [B, V, n_out]

    Per-variable weights (NOT shared across V):
        - 1-layer: W  [V, n_out, n_in], b  [V, n_out]
        - 2-layer: W1 [V, mid,   n_in], b1 [V, mid]
                   W2 [V, n_out, mid],  b2 [V, n_out]
    """

    def __init__(
        self,
        n_var: int,
        n_in: int,
        n_out: int,
        middle_size: Optional[int] = None,
        bias: bool = True,
    ):
        """
        Args:
            n_var: number of independent variables V.
            n_in: per-variable input feature size.
            n_out: per-variable output feature size.
            middle_size: hidden width; None selects the 1-layer variant,
                an int selects the 2-layer (GELU) variant.
            bias: whether to add per-variable bias terms.
        """
        super().__init__()
        self.n_var = int(n_var)
        self.n_in = int(n_in)
        self.n_out = int(n_out)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.bias = bias
        if self.middle_size is None:
            # Single linear map per variable.
            self.W = nn.Parameter(torch.empty(self.n_var, self.n_out, self.n_in))
            if bias:
                self.b = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b", None)
            self.W1 = self.b1 = self.W2 = self.b2 = None
        else:
            # Two linear maps per variable with a GELU in between.
            mid = self.middle_size
            self.W1 = nn.Parameter(torch.empty(self.n_var, mid, self.n_in))
            self.W2 = nn.Parameter(torch.empty(self.n_var, self.n_out, mid))
            if bias:
                self.b1 = nn.Parameter(torch.empty(self.n_var, mid))
                self.b2 = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b1", None)
                self.register_parameter("b2", None)
            self.W = self.b = None
        # FIX: the original left the parameters as uninitialized torch.empty
        # storage (arbitrary bytes, possibly NaN/inf) unless the caller
        # remembered to invoke init_weights(). Initialize eagerly so a freshly
        # constructed module is immediately usable; callers may still re-call
        # init_weights() or load a state_dict afterwards.
        self.init_weights()

    def init_weights(self, std: float = 0.02) -> None:
        """
        Initialize weights manually: N(0, std) for weight matrices,
        zeros for biases.
        """
        if self.middle_size is None:
            nn.init.normal_(self.W, std=std)
            if self.bias:
                nn.init.zeros_(self.b)
        else:
            nn.init.normal_(self.W1, std=std)
            nn.init.normal_(self.W2, std=std)
            if self.bias:
                nn.init.zeros_(self.b1)
                nn.init.zeros_(self.b2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the per-variable MLP.

        Args:
            x: input tensor of shape [B, V, n_in].

        Returns:
            Output tensor of shape [B, V, n_out].

        Raises:
            ValueError: if ``x`` is not 3-D or V / n_in do not match the
                sizes this module was constructed with.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected x [B,V,n_in], got {tuple(x.shape)}")
        B, V, n_in = x.shape
        if V != self.n_var or n_in != self.n_in:
            raise ValueError(
                f"Shape mismatch: expected V={self.n_var}, n_in={self.n_in}; got V={V}, n_in={n_in}"
            )
        if self.middle_size is None:
            # y[b,v,o] = sum_i x[b,v,i] * W[v,o,i]  (independent map per v)
            y = torch.einsum("bvi,voi->bvo", x, self.W)
            if self.bias:
                y = y + self.b.unsqueeze(0)
            return y
        h = torch.einsum("bvi,vmi->bvm", x, self.W1)
        if self.bias:
            h = h + self.b1.unsqueeze(0)
        h = F.gelu(h)
        y = torch.einsum("bvm,vom->bvo", h, self.W2)
        if self.bias:
            y = y + self.b2.unsqueeze(0)
        return y
def get_dtype(dtype: Optional[str]) -> torch.dtype:
    """
    Resolve a dtype name to a ``torch.dtype``.

    Args:
        dtype: one of "bfloat16", "float16", "float32" (case-insensitive);
            None falls back to "bfloat16".

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: if the name is not one of the supported choices.
    """
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    key = "bfloat16" if dtype is None else dtype
    key = key.lower()
    if key in dtype_map:
        return dtype_map[key]
    raise ValueError(f"Unsupported dtype={dtype}. Choose from {list(dtype_map.keys())}")
def load_json(path: str):
    """Read the UTF-8 JSON file at *path* and return the parsed object."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
def save_json(obj: Dict, path: str) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON (non-ASCII kept verbatim)."""
    text = json.dumps(obj, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(text)