File size: 3,785 Bytes
6fb6c07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | # utils.py
# -*- coding: utf-8 -*-
import json
from typing import Dict
from typing import Optional
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa
class GroupedMLP(nn.Module):
    """
    Batched per-variable MLP for a fixed n_in bucket.

    Input:  X [B, V, n_in]
    Output: Y [B, V, n_out]

    Per-variable weights (NOT shared across V):
      - 1-layer: W  [V, n_out, n_in], b  [V, n_out]
      - 2-layer: W1 [V, mid, n_in],   b1 [V, mid]
                 W2 [V, n_out, mid],  b2 [V, n_out]
    The 2-layer variant applies GELU between the two projections.

    Parameters are initialized at construction time via ``init_weights()``
    (normal weights with std=0.02, zero biases). Call ``init_weights(std=...)``
    again to re-draw them with a different scale.

    Args:
        n_var: Number of variables V; each gets its own weights.
        n_in: Input feature size per variable.
        n_out: Output feature size per variable.
        middle_size: Hidden size of the 2-layer variant; ``None`` selects
            the single-layer variant.
        bias: Whether to add learnable per-variable biases. Stored as a
            plain flag in ``self.bias`` (unlike ``nn.Linear.bias``, which
            is the tensor itself).
    """

    def __init__(
        self,
        n_var: int,
        n_in: int,
        n_out: int,
        middle_size: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        self.n_var = int(n_var)
        self.n_in = int(n_in)
        self.n_out = int(n_out)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.bias = bias
        if self.middle_size is None:
            self.W = nn.Parameter(torch.empty(self.n_var, self.n_out, self.n_in))
            if bias:
                self.b = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b", None)
            # Unused 2-layer slots are plain None attributes, so both variants
            # expose the same attribute names.
            self.W1 = self.b1 = self.W2 = self.b2 = None
        else:
            mid = self.middle_size
            self.W1 = nn.Parameter(torch.empty(self.n_var, mid, self.n_in))
            self.W2 = nn.Parameter(torch.empty(self.n_var, self.n_out, mid))
            if bias:
                self.b1 = nn.Parameter(torch.empty(self.n_var, mid))
                self.b2 = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b1", None)
                self.register_parameter("b2", None)
            self.W = self.b = None
        # torch.empty returns uninitialized memory (arbitrary values, possibly
        # NaN/Inf). Initialize right away so a freshly built module is usable.
        self.init_weights()

    def init_weights(self, std: float = 0.02) -> None:
        """
        (Re-)initialize all parameters in place.

        Weights are drawn from a normal distribution with mean 0 and the
        given ``std``. Biases (when enabled) are set to zero.

        Args:
            std: Standard deviation of the weight initialization.
        """
        if self.middle_size is None:
            nn.init.normal_(self.W, std=std)
            if self.bias:
                nn.init.zeros_(self.b)
        else:
            nn.init.normal_(self.W1, std=std)
            nn.init.normal_(self.W2, std=std)
            if self.bias:
                nn.init.zeros_(self.b1)
                nn.init.zeros_(self.b2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply each variable's own MLP to its slice of ``x``.

        Args:
            x: Tensor of shape [B, V, n_in] with V == n_var.

        Returns:
            Tensor of shape [B, V, n_out].

        Raises:
            ValueError: If ``x`` is not 3-D, or its V / n_in dimensions do
                not match the module configuration.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected x [B,V,n_in], got {tuple(x.shape)}")
        B, V, I = x.shape
        if V != self.n_var or I != self.n_in:
            raise ValueError(
                f"Shape mismatch: expected V={self.n_var}, n_in={self.n_in}; got V={V}, n_in={I}"
            )
        if self.middle_size is None:
            # y[b, v, o] = sum_i x[b, v, i] * W[v, o, i]
            y = torch.einsum("bvi,voi->bvo", x, self.W)
            if self.bias:
                y = y + self.b.unsqueeze(0)  # broadcast [V, n_out] over B
            return y
        h = torch.einsum("bvi,vmi->bvm", x, self.W1)
        if self.bias:
            h = h + self.b1.unsqueeze(0)
        h = F.gelu(h)
        y = torch.einsum("bvm,vom->bvo", h, self.W2)
        if self.bias:
            y = y + self.b2.unsqueeze(0)
        return y

    def extra_repr(self) -> str:
        """Summarize the configuration for ``repr(module)``."""
        return (
            f"n_var={self.n_var}, n_in={self.n_in}, n_out={self.n_out}, "
            f"middle_size={self.middle_size}, bias={self.bias}"
        )
def get_dtype(dtype: Optional[Union[str, torch.dtype]]) -> torch.dtype:
    """
    Resolve a dtype specification to one of the supported torch dtypes.

    Args:
        dtype: One of the following:
            - A name such as ``"bfloat16"``, ``"float16"`` or ``"float32"``.
              Matching is case-insensitive and ignores surrounding whitespace.
              An optional ``"torch."`` prefix and the short aliases ``"bf16"``,
              ``"fp16"``/``"half"`` and ``"fp32"``/``"float"`` are also accepted.
            - A ``torch.dtype`` from that same supported set.
            - ``None`` or an empty string, which selects the default
              ``torch.bfloat16``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``dtype`` does not denote a supported dtype.
    """
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    aliases = {
        "bf16": "bfloat16",
        "fp16": "float16",
        "half": "float16",
        "fp32": "float32",
        "float": "float32",
    }
    error = ValueError(f"Unsupported dtype={dtype}. Choose from {list(dtype_map.keys())}")
    # Already a torch.dtype: pass it through if supported.
    if isinstance(dtype, torch.dtype):
        if dtype in dtype_map.values():
            return dtype
        raise error
    if not dtype:
        return torch.bfloat16
    # Non-string values previously crashed on .lower() with AttributeError.
    if not isinstance(dtype, str):
        raise error
    dtype_str = dtype.strip().lower()
    # Accept the str(torch.dtype) form, e.g. "torch.float16".
    if dtype_str.startswith("torch."):
        dtype_str = dtype_str[len("torch."):]
    dtype_str = aliases.get(dtype_str, dtype_str)
    if dtype_str not in dtype_map:
        raise error
    return dtype_map[dtype_str]
def load_json(path: str):
    """Read the UTF-8 encoded JSON document at ``path`` and return the decoded object."""
    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def save_json(obj: Dict, path: str) -> None:
    """
    Write ``obj`` to ``path`` as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), with an
    indent of 2.

    The object is serialized fully before the file is opened. Opening with
    mode ``"w"`` truncates the file, so a serialization error (e.g. a
    non-JSON-serializable value) raises without truncating or partially
    overwriting an existing file at ``path``.

    Raises:
        TypeError / ValueError: Propagated from ``json.dumps`` when ``obj``
            cannot be serialized; the target file is left untouched.
    """
    text = json.dumps(obj, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
|