# utils.py
# -*- coding: utf-8 -*-
import json
from typing import Dict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa


class GroupedMLP(nn.Module):
    """
    Batched per-variable MLP for a fixed n_in bucket.

    Input:  X [B, V, n_in]
    Output: Y [B, V, n_out]

    Per-variable weights (NOT shared across V):
      - 1-layer: W  [V, n_out, n_in], b  [V, n_out]
      - 2-layer: W1 [V, mid,   n_in], b1 [V, mid]
                 W2 [V, n_out, mid],  b2 [V, n_out]
    """

    def __init__(
        self,
        n_var: int,
        n_in: int,
        n_out: int,
        middle_size: Optional[int] = None,
        bias: bool = True,
    ):
        """
        Args:
            n_var: number of independent variables V (one weight set each).
            n_in: input feature size per variable.
            n_out: output feature size per variable.
            middle_size: if given, use a 2-layer MLP with this hidden width
                and a GELU in between; otherwise a single linear map.
            bias: whether to add per-variable bias terms.
        """
        super().__init__()
        self.n_var = int(n_var)
        self.n_in = int(n_in)
        self.n_out = int(n_out)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.bias = bias

        if self.middle_size is None:
            # Single linear map per variable.
            self.W = nn.Parameter(torch.empty(self.n_var, self.n_out, self.n_in))
            if bias:
                self.b = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b", None)
            self.W1 = self.b1 = self.W2 = self.b2 = None
        else:
            # Two-layer MLP (GELU between layers) per variable.
            mid = self.middle_size
            self.W1 = nn.Parameter(torch.empty(self.n_var, mid, self.n_in))
            self.W2 = nn.Parameter(torch.empty(self.n_var, self.n_out, mid))
            if bias:
                self.b1 = nn.Parameter(torch.empty(self.n_var, mid))
                self.b2 = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b1", None)
                self.register_parameter("b2", None)
            self.W = self.b = None

        # BUGFIX: parameters were allocated with torch.empty() and left
        # uninitialized, so a forward pass before an explicit init_weights()
        # call silently computed with garbage memory (possibly NaN/inf).
        # Initialize eagerly; callers may still re-init or load a state dict.
        self.init_weights()

    def init_weights(self, std: float = 0.02) -> None:
        """
        Initialize weights in place: normal(0, std) for weight matrices,
        zeros for biases.

        Args:
            std: standard deviation of the normal init for weight tensors.
        """
        if self.middle_size is None:
            nn.init.normal_(self.W, std=std)
            if self.bias:
                nn.init.zeros_(self.b)
        else:
            nn.init.normal_(self.W1, std=std)
            nn.init.normal_(self.W2, std=std)
            if self.bias:
                nn.init.zeros_(self.b1)
                nn.init.zeros_(self.b2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the per-variable MLP.

        Args:
            x: tensor of shape [B, V, n_in].

        Returns:
            Tensor of shape [B, V, n_out].

        Raises:
            ValueError: if x is not 3-D, or V / n_in do not match the module.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected x [B,V,n_in], got {tuple(x.shape)}")
        # NOTE: local renamed from ambiguous `I` (flake8 E741); the error
        # message below still prints it as n_in, unchanged.
        B, V, n_in = x.shape
        I = n_in  # keep the original f-string variable name intact
        if V != self.n_var or I != self.n_in:
            raise ValueError(
                f"Shape mismatch: expected V={self.n_var}, n_in={self.n_in}; got V={V}, n_in={I}"
            )

        if self.middle_size is None:
            # y[b,v,o] = sum_i x[b,v,i] * W[v,o,i]  (independent per v)
            y = torch.einsum("bvi,voi->bvo", x, self.W)
            if self.bias:
                y = y + self.b.unsqueeze(0)
            return y

        # h[b,v,m] = sum_i x[b,v,i] * W1[v,m,i]
        h = torch.einsum("bvi,vmi->bvm", x, self.W1)
        if self.bias:
            h = h + self.b1.unsqueeze(0)
        h = F.gelu(h)
        # y[b,v,o] = sum_m h[b,v,m] * W2[v,o,m]
        y = torch.einsum("bvm,vom->bvo", h, self.W2)
        if self.bias:
            y = y + self.b2.unsqueeze(0)
        return y


def get_dtype(dtype: Optional[str]) -> torch.dtype:
    """
    Map a dtype name (case-insensitive) to a torch.dtype.

    Args:
        dtype: "bfloat16", "float16", or "float32"; None defaults to bfloat16.

    Raises:
        ValueError: for any unsupported name.
    """
    dtype_str = (dtype or "bfloat16").lower()
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if dtype_str not in dtype_map:
        raise ValueError(f"Unsupported dtype={dtype}. Choose from {list(dtype_map.keys())}")
    return dtype_map[dtype_str]


def load_json(path: str):
    """Load and return the JSON document at *path* (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json(obj: Dict, path: str) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)