# soilformer/modelling/soilformer.py
# Author: Kuangdai — "Initial release of SoilFormer" (commit 6fb6c07)
# soilformer.py
# -*- coding: utf-8 -*-
import json
import os
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa
from decode_categorical import CategoricalDecoder
from decode_numeric import NumericDecoder
from embed_categorical import (
CategoricalEmbedding,
build_cat_vocab_spec_from_meta,
get_categorical_feature_names_from_meta,
save_cat_vocab_json,
)
from embed_numeric import (
NumericEmbedding,
build_numeric_vocab_spec_from_meta,
)
from embed_vision_gemma3n import Gemma3nVisionFeatureExtractor
from layer import TabularImageGQALayer
from utils import load_json, save_json, get_dtype
# ============================================================
# SoilFormer
# ============================================================
class SoilFormer(nn.Module):
    """
    Full model: embeddings -> TabularImageGQALayer stack -> decoders.

    Data flow (see forward()):
      1. Categorical and numeric inputs are embedded into a single
         tabular token stream of width ``tabular_dim``.
      2. ``layer_num_layers`` TabularImageGQALayer blocks process the
         stream, optionally attending to vision features produced by a
         pretrained Gemma3n vision extractor.
      3. Categorical / numeric decoders map the final tokens back to
         per-column logits and per-variable values, each paired with a
         log-variance head.

    Checkpointing: save_weights()/load_weights() deliberately exclude the
    pretrained frozen vision weights (vision_tower / embed_vision); those
    are always reloaded from ``vision_model_dir`` at construction time.
    """
    def __init__(self, config: Dict, device: Optional[str] = None):
        """
        Build all submodules from a flat config dict, then move the whole
        model to ``device`` with the configured dtype (default bfloat16).

        Args:
            config: hyper-parameter dict; required keys include the
                "cat_*", "numeric_*", "vision_*" and "layer_*" entries
                read below. Keys accessed with ``[]`` are mandatory
                (KeyError if absent); keys accessed with ``.get`` have
                defaults.
            device: optional device string; defaults to CUDA when
                available, else CPU.

        Raises:
            ValueError: if cat_hidden_size != numeric_hidden_size — both
                embedding families must share one tabular stream width.
        """
        super().__init__()
        # Defensive copy so later mutation by the caller cannot affect the model.
        self.config = dict(config)
        dtype = get_dtype(self.config.get("dtype", "bfloat16"))
        dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        # ---- Tabular dims
        # Categorical and numeric tokens are concatenated into one stream
        # in forward(), so their hidden sizes must match.
        cat_hidden = int(self.config["cat_hidden_size"])
        num_hidden = int(self.config["numeric_hidden_size"])
        if cat_hidden != num_hidden:
            raise ValueError("Expect cat_hidden_size == numeric_hidden_size for one tabular stream.")
        self.tabular_dim = cat_hidden
        # ---- Embeddings
        self.embed_cat = CategoricalEmbedding(
            hidden_size=cat_hidden,
            cat_vocab_json=self.config["cat_vocab_json"],
        )
        self.embed_num = NumericEmbedding(
            hidden_size=num_hidden,
            numeric_vocab_json=self.config["numeric_vocab_json"],
            middle_size=self.config.get("numeric_encode_middle_size", None),
        )
        # ---- Decoders
        # Decoders share the vocab JSONs with the embeddings so column
        # ordering stays consistent end to end.
        self.decode_cat = CategoricalDecoder(
            hidden_size=cat_hidden,
            cat_vocab_json=self.config["cat_vocab_json"],
            middle_size=self.config.get("cat_decode_middle_size", None),
            homoscedastic=self.config.get("cat_homoscedastic", True),
        )
        self.decode_num = NumericDecoder(
            hidden_size=num_hidden,
            numeric_vocab_json=self.config["numeric_vocab_json"],
            middle_size=self.config.get("numeric_decode_middle_size", None),
            homoscedastic=self.config.get("num_homoscedastic", True),
        )
        # ---- Vision
        # Pretrained tower is loaded on CPU first; the single self.to(...)
        # at the end moves it with everything else.
        self.vision_extractor = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
            model_dir=self.config["vision_model_dir"],
            map_location="cpu",
            num_output_tokens_reduced=self.config["vision_num_output_tokens_reduced"],
            num_heads_for_token_reduction=self.config["vision_num_heads_for_token_reduction"],
            reducer_bottleneck_dim=self.config["vision_reducer_bottleneck_dim"],
            reducer_project_back=self.config["vision_reducer_project_back"],
        )
        # ---- Layers
        L = int(self.config["layer_num_layers"])
        self.layers = nn.ModuleList([
            TabularImageGQALayer(
                tabular_dim=self.tabular_dim,
                vision_dim=self.vision_extractor.get_actual_hidden_dim(),
                num_query_heads=int(self.config["layer_num_query_heads"]),
                num_kv_heads=int(self.config["layer_num_kv_heads"]),
                head_dim=int(self.config["layer_head_dim"]),
                mlp_ratio=float(self.config["layer_mlp_ratio"]),
                dropout=float(self.config["layer_dropout"]),
            )
            for _ in range(L)
        ])
        # ---- Move
        # One move at the end so every submodule (incl. the vision tower)
        # ends up on the same device/dtype.
        self.to(device=dev, dtype=dtype)
    def init_weights(self, std: float = 0.02):
        """
        Re-initialize trainable submodules with the given std.

        NOTE(review): vision_extractor.init_weights presumably leaves the
        pretrained frozen tower untouched and only re-inits the reducer —
        confirm in embed_vision_gemma3n.
        """
        self.embed_cat.init_weights(std=std)
        self.embed_num.init_weights(std=std)
        self.decode_cat.init_weights(std=std)
        self.decode_num.init_weights(std=std)
        self.vision_extractor.init_weights(std=std)
        for blk in self.layers:
            blk.init_weights(std=std)
    def forward(
        self,
        cat_local_ids: torch.LongTensor,  # [B, M_cat]
        numeric_values_by_nin: Dict[int, torch.Tensor],  # {n_in: [B, V, n_in]}
        cat_valid_positions: Optional[torch.Tensor] = None,  # [B, M_cat] bool
        numeric_valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,  # {n_in: [B,V] bool}
        pixel_values: Optional[torch.Tensor] = None,  # [B, 3, H, W]
        vision_valid_positions: Optional[torch.Tensor] = None,  # [B] bool OR indices [K]
    ):
        """
        Run the full pipeline: embed -> layer stack -> decode.

        Vision is skipped entirely when ``pixel_values`` is None.

        Returns a 6-tuple:
            cat_logits_padded, cat_s, valid_class_mask  — from decode_cat
            value_by_nin, s_by_nin                      — from decode_num
            x_tab                                       — final tabular tokens [B, T_tab, H]
        """
        # ----------------------------
        # Embeddings (tabular)
        # ----------------------------
        x_cat, cat_mask = self.embed_cat(
            local_ids=cat_local_ids,
            valid_positions=cat_valid_positions,
        )
        x_num, num_mask = self.embed_num(
            values_by_nin=numeric_values_by_nin,
            valid_positions_by_nin=numeric_valid_positions_by_nin,
        )
        # Categorical tokens first, then numeric — the slicing after the
        # layer stack relies on this ordering.
        x_tab = torch.cat([x_cat, x_num], dim=1)  # [B, T_tab, H]
        B, T_tab, _ = x_tab.shape
        M_cat = x_cat.size(1)
        T_num = x_num.size(1)
        # ----------------------------
        # Tabular attention mask
        # ----------------------------
        cat_mask = cat_mask.to(device=x_tab.device, dtype=torch.long)
        num_mask = num_mask.to(device=x_tab.device, dtype=torch.long)
        # "disable_tabular_attention_mask" is a mandatory config key
        # (direct [] access); when truthy, all tokens attend everywhere.
        if self.config["disable_tabular_attention_mask"]:
            attention_mask_tab = torch.ones(B, T_tab, device=x_tab.device, dtype=torch.long)
        else:
            attention_mask_tab = torch.cat([cat_mask, num_mask], dim=1)
        if attention_mask_tab.shape != (B, T_tab):
            raise RuntimeError("Internal attention_mask_tab shape mismatch")
        # ----------------------------
        # Vision features
        # ----------------------------
        if pixel_values is None:
            vision_features = None
            vision_mask = None
        else:
            vision_features, vision_mask = self.vision_extractor(
                pixel_values=pixel_values,
                valid_positions=vision_valid_positions,
            )
            # Sanity-check extractor output against the tabular batch.
            if vision_features.shape[0] != B:
                raise ValueError("vision_features batch mismatch with tabular batch")
            if vision_mask.shape[0] != B or vision_mask.shape[1] != vision_features.shape[1]:
                raise ValueError("vision_mask shape mismatch with vision_features")
            vision_mask = vision_mask.to(
                device=attention_mask_tab.device,
                dtype=attention_mask_tab.dtype,
            )
        # ----------------------------
        # Transformer blocks
        # ----------------------------
        for blk in self.layers:  # type: TabularImageGQALayer
            x_tab = blk(
                x_tab=x_tab,
                attention_mask=attention_mask_tab,
                vision_features=vision_features,
                vision_mask=vision_mask
            )
        # ----------------------------
        # Slice outputs
        # ----------------------------
        # Undo the cat||num concatenation done above.
        x_cat_out = x_tab[:, :M_cat, :]
        x_num_out = x_tab[:, M_cat:M_cat + T_num, :]
        # ----------------------------
        # Decode
        # ----------------------------
        cat_logits_padded, cat_s, valid_class_mask = self.decode_cat(
            x_cat_out,
            return_padded=True,
        )
        value_by_nin, s_by_nin = self.decode_num(
            x_num_out
        )
        return cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab
    def _checkpoint_state_dict(self) -> Dict[str, torch.Tensor]:
        """
        State dict used for save/load.
        Excludes pretrained frozen vision weights:
          - vision_extractor.vision_tower.*
          - vision_extractor.embed_vision.*
        Keeps reducer weights if reducer exists.
        """
        full_sd = self.state_dict()
        out = {}
        for k, v in full_sd.items():
            # Prefix filter: drop only the two frozen vision sub-trees.
            if k.startswith("vision_extractor.vision_tower."):
                continue
            if k.startswith("vision_extractor.embed_vision."):
                continue
            out[k] = v
        return out
    def save_weights(self, path: str) -> None:
        """
        Save model weights needed for SoilFormer training/inference,
        excluding pretrained frozen vision weights.

        The config dict is saved alongside the weights so the checkpoint
        is self-describing.
        """
        payload = {
            "model_state_dict": self._checkpoint_state_dict(),
            "config": self.config,
        }
        torch.save(payload, path)
    def load_weights(self, path: str, map_location: str = "cpu", strict: bool = True) -> Dict[str, list]:
        """
        Load weights saved by save_weights().
        Only the checkpoint-managed subset is loaded:
          - embeddings / decoders / layers
          - vision_extractor.reducer.* (if present)
        Pretrained frozen vision weights are ignored here and are expected
        to come from vision_model_dir during model construction.

        Args:
            path: checkpoint file; either the save_weights() payload or a
                bare state dict.
            map_location: forwarded to torch.load.
            strict: raise on any missing/unexpected checkpoint-managed key.

        Returns:
            {"missing_keys": [...], "unexpected_keys": [...]}

        NOTE(review): torch.load without weights_only unpickles arbitrary
        objects — only load trusted checkpoints.
        """
        ckpt = torch.load(path, map_location=map_location)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            sd = ckpt["model_state_dict"]
        elif isinstance(ckpt, dict):
            # Bare state dict (no wrapper payload).
            sd = ckpt
        else:
            raise ValueError(f"Unsupported checkpoint format: {path}")
        expected_sd = self._checkpoint_state_dict()
        # Only keep keys that belong to the checkpoint-managed subset
        loadable_sd = {k: v for k, v in sd.items() if k in expected_sd}
        missing = sorted(set(expected_sd.keys()) - set(loadable_sd.keys()))
        unexpected = sorted(set(sd.keys()) - set(expected_sd.keys()))
        # Actually load
        # strict=False here because the full model legitimately has keys
        # (frozen vision tower) that the checkpoint never contains.
        load_info = self.load_state_dict(loadable_sd, strict=False)
        # PyTorch may still report missing keys from the full model state_dict;
        # keep only checkpoint-managed ones.
        missing_after_load = [
            k for k in load_info.missing_keys
            if k in expected_sd
        ]
        unexpected_after_load = [
            k for k in load_info.unexpected_keys
            if k in expected_sd
        ]
        # Merge both sources of mismatch info
        missing_final = sorted(set(missing) | set(missing_after_load))
        unexpected_final = sorted(set(unexpected) | set(unexpected_after_load))
        if strict and (missing_final or unexpected_final):
            raise RuntimeError(
                "Checkpoint load mismatch.\n"
                f"Missing keys: {missing_final}\n"
                f"Unexpected keys: {unexpected_final}"
            )
        return {
            "missing_keys": missing_final,
            "unexpected_keys": unexpected_final,
        }
def loss_function(
    x_cat: torch.Tensor,  # [B,M,Cmax] padded logits
    s_cat: torch.Tensor,  # [B,M] log-variance
    y_cat: torch.Tensor,  # [B,M] class index
    loss_mask_cat: torch.Tensor,  # [B,M] 0/1
    valid_class_mask: torch.Tensor,  # [M,Cmax] bool
    x_num: Dict[int, torch.Tensor],  # {n_in: [B,V,n_in]}
    s_num: Dict[int, torch.Tensor],  # {n_in: [B,V]}
    y_num: Dict[int, torch.Tensor],  # {n_in: [B,V,n_in]}
    loss_mask_num: Dict[int, torch.Tensor],  # {n_in: [B,V]} 0/1
    cat_temperature: float = 1.0,
    reduction: str = "mean",  # "mean" or "sum"
    eps: float = 1e-12,
    cat_s_bound: Optional[float] = None,
    num_s_bound: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Strict loss for SoilFormer.

    Categorical:
        - Per-column cross-entropy over the valid class range only;
          padded logit positions are never read.
        - s_cat[b,m] = log sigma^2 for categorical column m.
    Numeric:
        - Per-variable MSE averaged over the n_in dimensions.
        - s_num[n_in][b,v] = log sigma^2 for numeric variable v.
    Heteroscedastic weighting (both streams):
        L = exp(-s) * base + s
    Optional soft bound:
        If cat_s_bound or num_s_bound is not None, apply
        s <- bound * tanh(s / bound) before weighting; a non-positive
        bound disables the weighting entirely (s forced to 0).

    Returns:
        total_loss: scalar float32 tensor (cat_loss + num_loss)
        stats: dict of detached scalars: total, cat_loss, num_loss,
            cat_base, num_base, cat_count, num_count, cat_acc

    Raises:
        ValueError: on shape mismatches, invalid class ids, a column with
            no valid classes, a non-prefix valid_class_mask, or an
            unsupported reduction.
        KeyError: when a numeric n_in key is missing from y/s/mask dicts.
    """
    # Fail fast on a bad reduction instead of after the categorical pass
    # (same exception type and message as before, just earlier).
    if reduction not in ("mean", "sum"):
        raise ValueError(f"Unsupported reduction: {reduction}")

    def _soft_bound_logvar(s_: torch.Tensor, bound: Optional[float]) -> torch.Tensor:
        # Smoothly confine the log-variance to (-bound, bound).
        if bound is None:
            return s_
        b = float(bound)
        if b <= 0:
            # Turn off weighting by signalling a non-positive bound
            return torch.zeros_like(s_)
        return b * torch.tanh(s_ / b)

    # ---------------------------------------------------
    # 1) Categorical loss (strict per-column CE)
    # ---------------------------------------------------
    if x_cat.dim() != 3:
        raise ValueError(f"x_cat must be [B,M,Cmax], got {tuple(x_cat.shape)}")
    B, M, Cmax = x_cat.shape
    if s_cat.shape != (B, M):
        raise ValueError(f"s_cat must be [B,M]=({B},{M}), got {tuple(s_cat.shape)}")
    if y_cat.shape != (B, M):
        raise ValueError(f"y_cat must be [B,M]=({B},{M}), got {tuple(y_cat.shape)}")
    if loss_mask_cat.shape != (B, M):
        raise ValueError(f"loss_mask_cat must be [B,M]=({B},{M}), got {tuple(loss_mask_cat.shape)}")
    if valid_class_mask.shape != (M, Cmax):
        raise ValueError(
            f"valid_class_mask must be [M,Cmax]=({M},{Cmax}), got {tuple(valid_class_mask.shape)}"
        )
    x_cat_f = x_cat.float()
    s_cat_f = _soft_bound_logvar(s_cat.float(), cat_s_bound)
    y_cat_l = y_cat.long()
    mcat = loss_mask_cat.float()
    valid_class_mask = valid_class_mask.to(device=x_cat.device, dtype=torch.bool)
    if cat_temperature != 1.0:
        x_cat_f = x_cat_f / float(cat_temperature)

    # Per-column class counts, hoisted out of the loop: one device sync
    # (.tolist()) instead of M separate .item() calls. Also verify that
    # each column's valid classes form a contiguous prefix [0, C_m) —
    # the logits slicing below silently assumes this.
    counts = valid_class_mask.sum(dim=1)  # [M]
    prefix_mask = (
        torch.arange(Cmax, device=valid_class_mask.device).unsqueeze(0)
        < counts.unsqueeze(1)
    )
    if not torch.equal(prefix_mask, valid_class_mask):
        raise ValueError("valid_class_mask must mark a contiguous prefix of classes per column")
    class_counts = [int(c) for c in counts.tolist()]

    cat_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    cat_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    cat_correct_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    # denominator = number of actively supervised categorical cells
    cat_denom = mcat.sum().clamp_min(float(eps))
    for m in range(M):
        cm = class_counts[m]  # real class count for column m
        if cm <= 0:
            raise ValueError(f"Column {m} has no valid classes")
        logits_m = x_cat_f[:, m, :cm]  # [B, C_m]
        target_m = y_cat_l[:, m]  # [B]
        s_m = s_cat_f[:, m]  # [B]
        mask_m = mcat[:, m]  # [B]
        active = mask_m > 0
        if active.any():
            tgt_active = target_m[active]
            if (tgt_active < 0).any() or (tgt_active >= cm).any():
                raise ValueError(f"y_cat contains invalid class id for categorical column {m}")
            # Clamp inactive targets to a legal id so cross_entropy never
            # indexes out of range; their contribution is masked out below.
            target_m_safe = target_m.clone()
            target_m_safe[~active] = 0
            ce_m = F.cross_entropy(
                logits_m,
                target_m_safe,
                reduction="none",
            )  # [B], float32
            # ---------------------------------------------------
            # accuracy (only count active positions)
            # ---------------------------------------------------
            pred_m = logits_m.argmax(dim=-1)  # [B]
            correct_m = (pred_m == target_m_safe) & active  # [B]
            cat_correct_acc = cat_correct_acc + correct_m.float().sum()
            # heteroscedastic weighting: exp(-s) * CE + s
            L_m = torch.exp(-s_m) * ce_m + s_m  # [B]
            cat_loss_acc = cat_loss_acc + (L_m * mask_m).sum()
            cat_base_acc = cat_base_acc + (ce_m * mask_m).sum()
    if reduction == "mean":
        cat_loss = cat_loss_acc / cat_denom
        cat_base = cat_base_acc / cat_denom
    else:  # "sum" (validated above)
        cat_loss = cat_loss_acc
        cat_base = cat_base_acc
    # Accuracy is always normalized by the active-cell count, regardless
    # of the reduction chosen for the losses.
    cat_acc = cat_correct_acc / cat_denom
    # ---------------------------------------------------
    # 2) Numeric loss (per-variable heteroscedastic MSE)
    # ---------------------------------------------------
    num_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    num_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    num_denom_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    for n_in, x in x_num.items():
        if n_in not in y_num or n_in not in s_num or n_in not in loss_mask_num:
            raise KeyError(f"Missing key n_in={n_in} in y_num/s_num/loss_mask_num")
        y = y_num[n_in]
        s = s_num[n_in]
        m = loss_mask_num[n_in]
        if x.shape != y.shape:
            raise ValueError(
                f"x_num[{n_in}] and y_num[{n_in}] shape mismatch: "
                f"{tuple(x.shape)} vs {tuple(y.shape)}"
            )
        if x.dim() != 3:
            raise ValueError(f"x_num[{n_in}] must be [B,V,n_in], got {tuple(x.shape)}")
        Bb, V, Nin = x.shape
        if Nin != n_in:
            raise ValueError(f"x_num[{n_in}] last dim mismatch: got {Nin}, expected {n_in}")
        if s.shape != (Bb, V):
            raise ValueError(f"s_num[{n_in}] must be [B,V], got {tuple(s.shape)}")
        if m.shape != (Bb, V):
            raise ValueError(f"loss_mask_num[{n_in}] must be [B,V], got {tuple(m.shape)}")
        x_f = x.float()
        y_f = y.float()
        s_f = _soft_bound_logvar(s.float(), num_s_bound)
        m_f = m.float()
        # base numeric loss per variable: mean over n_in dims
        mse = (x_f - y_f).pow(2).mean(dim=-1)  # [B,V]
        # heteroscedastic weighting: exp(-s) * mse + s
        L = torch.exp(-s_f) * mse + s_f  # [B,V]
        num_loss_acc = num_loss_acc + (L * m_f).sum()
        num_base_acc = num_base_acc + (mse * m_f).sum()
        num_denom_acc = num_denom_acc + m_f.sum()
    num_denom = num_denom_acc.clamp_min(float(eps))
    if reduction == "mean":
        num_loss = num_loss_acc / num_denom
        num_base = num_base_acc / num_denom
    else:  # "sum" (validated above)
        num_loss = num_loss_acc
        num_base = num_base_acc
    # ---------------------------------------------------
    # 3) Total
    # ---------------------------------------------------
    total = cat_loss + num_loss
    stats = {
        "total": total.detach(),
        "cat_loss": cat_loss.detach(),
        "num_loss": num_loss.detach(),
        "cat_base": cat_base.detach(),
        "num_base": num_base.detach(),
        "cat_count": cat_denom.detach(),
        "num_count": num_denom.detach(),
        "cat_acc": cat_acc.detach(),
    }
    return total, stats
# ============================================================
# DEMO
# ============================================================
def _demo_main():
    """
    Smoke test: build a SoilFormer from a config JSON, run one dummy
    forward pass (optionally with dummy vision input), then evaluate
    loss_function on random targets, printing shapes and stats.

    Side effects: builds the categorical / numeric vocab JSON files if
    they do not already exist at the paths named in the config.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_json", type=str, default="config/config_model.json")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--with_vision", action="store_true")
    args = parser.parse_args()
    cfg = load_json(args.config_json)
    print("===== Loaded config =====")
    print(json.dumps(cfg, ensure_ascii=False, indent=2))
    # --------------------------------------------------
    # Ensure vocab files exist
    # --------------------------------------------------
    # Vocab specs are derived from the tabular metadata on first run and
    # cached as JSON for subsequent runs.
    tabular_meta = load_json(cfg["tabular_meta"])
    if not os.path.isfile(cfg["cat_vocab_json"]):
        cat_names = get_categorical_feature_names_from_meta(tabular_meta)
        vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
        Path(cfg["cat_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
        save_cat_vocab_json(vocab, cfg["cat_vocab_json"])
        print(f"[demo] Built cat_vocab_json at {cfg['cat_vocab_json']}")
    if not os.path.isfile(cfg["numeric_vocab_json"]):
        spec = build_numeric_vocab_spec_from_meta(tabular_meta)
        Path(cfg["numeric_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
        save_json(spec, cfg["numeric_vocab_json"])
        print(f"[demo] Built numeric_vocab_json at {cfg['numeric_vocab_json']}")
    # --------------------------------------------------
    # Build model
    # --------------------------------------------------
    model = SoilFormer(cfg)
    model.init_weights()
    model.eval()
    # Read device/dtype back from the model rather than the config.
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    B = args.batch_size
    # --------------------------------------------------
    # Build dummy categorical inputs
    # --------------------------------------------------
    cat_spec = load_json(cfg["cat_vocab_json"])
    # Order columns by col_id so positions match the model's layout.
    cat_items = sorted(cat_spec.items(), key=lambda x: x[1]["col_id"])
    M_cat = len(cat_items)
    cat_local_ids = torch.zeros(B, M_cat, dtype=torch.long, device=device)
    cat_valid_positions = torch.ones(B, M_cat, dtype=torch.bool, device=device)
    # --------------------------------------------------
    # Build dummy numeric inputs
    # --------------------------------------------------
    num_spec = load_json(cfg["numeric_vocab_json"])
    numeric_values_by_nin: Dict[int, torch.Tensor] = {}
    numeric_valid_positions_by_nin: Dict[int, torch.Tensor] = {}
    for g in num_spec["groups"]:
        n_in = int(g["n_in"])
        V = len(g["feature_names"])
        numeric_values_by_nin[n_in] = torch.randn(B, V, n_in, device=device, dtype=dtype)
        numeric_valid_positions_by_nin[n_in] = torch.ones(B, V, dtype=torch.bool, device=device)
    # --------------------------------------------------
    # Build dummy vision inputs
    # --------------------------------------------------
    if args.with_vision:
        pixel_values = torch.randn(B, 3, 224, 224, device=device, dtype=dtype)
        vision_valid_positions = torch.ones(B, dtype=torch.bool, device=device)
    else:
        pixel_values = None
        vision_valid_positions = None
    # --------------------------------------------------
    # Vision debug
    # --------------------------------------------------
    print("\n===== Vision debug =====")
    if pixel_values is None:
        print("pixel_values: None")
        print("vision_features: None")
        print("vision_mask: None")
    else:
        print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
        # Extra standalone extractor call, just to print the shapes.
        with torch.no_grad():
            vision_features, vision_mask = model.vision_extractor.forward(
                pixel_values=pixel_values,
                valid_positions=vision_valid_positions,
            )
        print("vision_features:", tuple(vision_features.shape), vision_features.dtype, vision_features.device)
        print("vision_mask:", tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device)
    # --------------------------------------------------
    # Forward
    # --------------------------------------------------
    with torch.no_grad():
        cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab = model.forward(
            cat_local_ids=cat_local_ids,  # noqa
            numeric_values_by_nin=numeric_values_by_nin,
            cat_valid_positions=cat_valid_positions,
            numeric_valid_positions_by_nin=numeric_valid_positions_by_nin,
            pixel_values=pixel_values,
            vision_valid_positions=vision_valid_positions,
        )
    print("\n===== SoilFormer demo =====")
    print("cat_local_ids:", tuple(cat_local_ids.shape))
    print("cat_valid_positions:", tuple(cat_valid_positions.shape))
    print("numeric_values_by_nin:", {k: tuple(v.shape) for k, v in numeric_values_by_nin.items()})
    print("numeric_valid_positions_by_nin:", {k: tuple(v.shape) for k, v in numeric_valid_positions_by_nin.items()})
    print("x_tab_final:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
    print("Categorical outputs:")
    print("cat_logits_padded:", tuple(cat_logits_padded.shape), cat_logits_padded.dtype, cat_logits_padded.device)
    print("cat_s:", tuple(cat_s.shape), cat_s.dtype, cat_s.device)
    print("Numeric decoded values:", {k: tuple(v.shape) for k, v in value_by_nin.items()})
    print("Numeric decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
    # --------------------------------------------------
    # Loss debug
    # --------------------------------------------------
    print("\n===== Loss debug =====")
    if cat_logits_padded.dim() != 3:
        raise RuntimeError(f"cat_logits_padded must be [B,M,Cmax], got {tuple(cat_logits_padded.shape)}")
    B_logits, M_cat2, Cmax2 = cat_logits_padded.shape
    if cat_s.shape != (B_logits, M_cat2):
        raise RuntimeError(f"cat_s shape mismatch: got {tuple(cat_s.shape)} expected {(B_logits, M_cat2)}")
    # Build dummy categorical targets within valid class ranges
    num_classes = [int(s["num_classes"]) for _, s in cat_items]
    if len(num_classes) != M_cat2:
        raise RuntimeError("M_cat mismatch between vocab and model output")
    y_cat = torch.zeros(B_logits, M_cat2, dtype=torch.long, device=device)
    for m, cm in enumerate(num_classes):
        # Random target per column, guaranteed inside [0, cm).
        y_cat[:, m] = torch.randint(low=0, high=cm, size=(B_logits,), device=device)
    mask_cat = torch.ones(B_logits, M_cat2, dtype=torch.long, device=device)
    # Build dummy numeric targets and masks
    y_num = {
        n_in: torch.randn_like(x_pred)
        for n_in, x_pred in value_by_nin.items()
    }
    mask_num = {
        n_in: torch.ones(x_pred.size(0), x_pred.size(1), dtype=torch.long, device=x_pred.device)
        for n_in, x_pred in value_by_nin.items()
    }
    total_loss, stats = loss_function(
        x_cat=cat_logits_padded,
        s_cat=cat_s,
        y_cat=y_cat,
        loss_mask_cat=mask_cat,
        x_num=value_by_nin,
        s_num=s_by_nin,
        y_num=y_num,
        loss_mask_num=mask_num,
        reduction="mean",
        valid_class_mask=valid_class_mask
    )
    print("total_loss:", float(total_loss))
    print("stats:", {k: float(v) for k, v in stats.items()})
    if not torch.isfinite(total_loss):
        raise RuntimeError("Loss is not finite!")
# Run the smoke-test demo only when executed as a script, not on import.
if __name__ == "__main__":
    _demo_main()