import json
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedMLP(nn.Module):
| """ |
| Batched per-variable MLP for a fixed n_in bucket. |
| |
| Input: X [B, V, n_in] |
| Output: Y [B, V, n_out] |
| |
| Per-variable weights (NOT shared across V): |
| - 1-layer: W [V, n_out, n_in], b [V, n_out] |
| - 2-layer: W1 [V, mid, n_in], b1 [V, mid] |
| W2 [V, n_out, mid], b2 [V, n_out] |
| """ |

    def __init__(
        self,
        n_var: int,
        n_in: int,
        n_out: int,
        middle_size: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()

        self.n_var = int(n_var)
        self.n_in = int(n_in)
        self.n_out = int(n_out)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.bias = bias

        if self.middle_size is None:
            # Single linear layer: one [n_out, n_in] weight matrix per variable.
            self.W = nn.Parameter(torch.empty(self.n_var, self.n_out, self.n_in))

            if bias:
                self.b = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b", None)

            self.W1 = self.b1 = self.W2 = self.b2 = None

        else:
            mid = self.middle_size

            # Two linear layers per variable, with a GELU nonlinearity in between.
            self.W1 = nn.Parameter(torch.empty(self.n_var, mid, self.n_in))
            self.W2 = nn.Parameter(torch.empty(self.n_var, self.n_out, mid))

            if bias:
                self.b1 = nn.Parameter(torch.empty(self.n_var, mid))
                self.b2 = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b1", None)
                self.register_parameter("b2", None)

            self.W = self.b = None

    def init_weights(self, std: float = 0.02) -> None:
        """
        Initialize weights from N(0, std^2) and biases with zeros.

        Parameters are allocated with torch.empty in __init__, so call this
        (or load a state dict) before using the module.
        """
        if self.middle_size is None:
            nn.init.normal_(self.W, std=std)
            if self.bias:
                nn.init.zeros_(self.b)
        else:
            nn.init.normal_(self.W1, std=std)
            nn.init.normal_(self.W2, std=std)

            if self.bias:
                nn.init.zeros_(self.b1)
                nn.init.zeros_(self.b2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3:
            raise ValueError(f"Expected x [B, V, n_in], got {tuple(x.shape)}")

        B, V, I = x.shape

        if V != self.n_var or I != self.n_in:
            raise ValueError(
                f"Shape mismatch: expected V={self.n_var}, n_in={self.n_in}; got V={V}, n_in={I}"
            )

        if self.middle_size is None:
            # Single layer: [B, V, n_in] x [V, n_out, n_in] -> [B, V, n_out].
            y = torch.einsum("bvi,voi->bvo", x, self.W)
            if self.bias:
                y = y + self.b.unsqueeze(0)
            return y

        # First layer: [B, V, n_in] x [V, mid, n_in] -> [B, V, mid].
        h = torch.einsum("bvi,vmi->bvm", x, self.W1)
        if self.bias:
            h = h + self.b1.unsqueeze(0)

        h = F.gelu(h)

        # Second layer: [B, V, mid] x [V, n_out, mid] -> [B, V, n_out].
        y = torch.einsum("bvm,vom->bvo", h, self.W2)
        if self.bias:
            y = y + self.b2.unsqueeze(0)

        return y
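

# Usage sketch for GroupedMLP (the values below are illustrative, not taken
# from any particular config): each of the V variables gets its own MLP
# applied to its n_in features, so weights are not shared across variables.
#
#     mlp = GroupedMLP(n_var=8, n_in=16, n_out=4, middle_size=32)
#     mlp.init_weights(std=0.02)
#     x = torch.randn(2, 8, 16)   # [B=2, V=8, n_in=16]
#     y = mlp(x)                  # -> [2, 8, 4]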


def get_dtype(dtype: Optional[str]) -> torch.dtype:
    """Map a dtype name ('bfloat16', 'float16', 'float32') to the corresponding torch.dtype."""
    dtype_str = (dtype or "bfloat16").lower()
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if dtype_str not in dtype_map:
        raise ValueError(f"Unsupported dtype={dtype}. Choose from {list(dtype_map.keys())}")
    return dtype_map[dtype_str]
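

# For example, get_dtype("float16") returns torch.float16, and get_dtype(None)
# falls back to torch.bfloat16.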


def load_json(path: str):
    """Load and return the JSON object stored at path."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json(obj: Dict, path: str) -> None:
    """Write obj to path as UTF-8 JSON with 2-space indentation."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)