# soilformer.py # -*- coding: utf-8 -*- import json import os from pathlib import Path from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # noqa from decode_categorical import CategoricalDecoder from decode_numeric import NumericDecoder from embed_categorical import ( CategoricalEmbedding, build_cat_vocab_spec_from_meta, get_categorical_feature_names_from_meta, save_cat_vocab_json, ) from embed_numeric import ( NumericEmbedding, build_numeric_vocab_spec_from_meta, ) from embed_vision_gemma3n import Gemma3nVisionFeatureExtractor from layer import TabularImageGQALayer from utils import load_json, save_json, get_dtype # ============================================================ # SoilFormer # ============================================================ class SoilFormer(nn.Module): """ Full model: embeddings -> TabularImageGQALayer stack -> decoders. """ def __init__(self, config: Dict, device: Optional[str] = None): super().__init__() self.config = dict(config) dtype = get_dtype(self.config.get("dtype", "bfloat16")) dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) # ---- Tabular dims cat_hidden = int(self.config["cat_hidden_size"]) num_hidden = int(self.config["numeric_hidden_size"]) if cat_hidden != num_hidden: raise ValueError("Expect cat_hidden_size == numeric_hidden_size for one tabular stream.") self.tabular_dim = cat_hidden # ---- Embeddings self.embed_cat = CategoricalEmbedding( hidden_size=cat_hidden, cat_vocab_json=self.config["cat_vocab_json"], ) self.embed_num = NumericEmbedding( hidden_size=num_hidden, numeric_vocab_json=self.config["numeric_vocab_json"], middle_size=self.config.get("numeric_encode_middle_size", None), ) # ---- Decoders self.decode_cat = CategoricalDecoder( hidden_size=cat_hidden, cat_vocab_json=self.config["cat_vocab_json"], middle_size=self.config.get("cat_decode_middle_size", None), homoscedastic=self.config.get("cat_homoscedastic", True), ) self.decode_num = NumericDecoder( hidden_size=num_hidden, numeric_vocab_json=self.config["numeric_vocab_json"], middle_size=self.config.get("numeric_decode_middle_size", None), homoscedastic=self.config.get("num_homoscedastic", True), ) # ---- Vision self.vision_extractor = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir( model_dir=self.config["vision_model_dir"], map_location="cpu", num_output_tokens_reduced=self.config["vision_num_output_tokens_reduced"], num_heads_for_token_reduction=self.config["vision_num_heads_for_token_reduction"], reducer_bottleneck_dim=self.config["vision_reducer_bottleneck_dim"], reducer_project_back=self.config["vision_reducer_project_back"], ) # ---- Layers L = int(self.config["layer_num_layers"]) self.layers = nn.ModuleList([ TabularImageGQALayer( tabular_dim=self.tabular_dim, vision_dim=self.vision_extractor.get_actual_hidden_dim(), num_query_heads=int(self.config["layer_num_query_heads"]), num_kv_heads=int(self.config["layer_num_kv_heads"]), head_dim=int(self.config["layer_head_dim"]), mlp_ratio=float(self.config["layer_mlp_ratio"]), dropout=float(self.config["layer_dropout"]), ) for _ in range(L) ]) # ---- Move self.to(device=dev, dtype=dtype) def init_weights(self, std: float = 0.02): self.embed_cat.init_weights(std=std) self.embed_num.init_weights(std=std) self.decode_cat.init_weights(std=std) self.decode_num.init_weights(std=std) self.vision_extractor.init_weights(std=std) for blk in self.layers: blk.init_weights(std=std) def forward( self, cat_local_ids: torch.LongTensor, # [B, M_cat] numeric_values_by_nin: Dict[int, torch.Tensor], # {n_in: [B, V, n_in]} cat_valid_positions: Optional[torch.Tensor] = None, # [B, M_cat] bool numeric_valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None, # {n_in: [B,V] bool} pixel_values: Optional[torch.Tensor] = None, # [B, 3, H, W] vision_valid_positions: Optional[torch.Tensor] = None, # [B] bool OR indices [K] ): # ---------------------------- # Embeddings (tabular) # ---------------------------- x_cat, cat_mask = self.embed_cat( local_ids=cat_local_ids, valid_positions=cat_valid_positions, ) x_num, num_mask = self.embed_num( values_by_nin=numeric_values_by_nin, valid_positions_by_nin=numeric_valid_positions_by_nin, ) x_tab = torch.cat([x_cat, x_num], dim=1) # [B, T_tab, H] B, T_tab, _ = x_tab.shape M_cat = x_cat.size(1) T_num = x_num.size(1) # ---------------------------- # Tabular attention mask # ---------------------------- cat_mask = cat_mask.to(device=x_tab.device, dtype=torch.long) num_mask = num_mask.to(device=x_tab.device, dtype=torch.long) if self.config["disable_tabular_attention_mask"]: attention_mask_tab = torch.ones(B, T_tab, device=x_tab.device, dtype=torch.long) else: attention_mask_tab = torch.cat([cat_mask, num_mask], dim=1) if attention_mask_tab.shape != (B, T_tab): raise RuntimeError("Internal attention_mask_tab shape mismatch") # ---------------------------- # Vision features # ---------------------------- if pixel_values is None: vision_features = None vision_mask = None else: vision_features, vision_mask = self.vision_extractor( pixel_values=pixel_values, valid_positions=vision_valid_positions, ) if vision_features.shape[0] != B: raise ValueError("vision_features batch mismatch with tabular batch") if vision_mask.shape[0] != B or vision_mask.shape[1] != vision_features.shape[1]: raise ValueError("vision_mask shape mismatch with vision_features") vision_mask = vision_mask.to( device=attention_mask_tab.device, dtype=attention_mask_tab.dtype, ) # ---------------------------- # Transformer blocks # ---------------------------- for blk in self.layers: # type: TabularImageGQALayer x_tab = blk( x_tab=x_tab, attention_mask=attention_mask_tab, vision_features=vision_features, vision_mask=vision_mask ) # ---------------------------- # Slice outputs # ---------------------------- x_cat_out = x_tab[:, :M_cat, :] x_num_out = x_tab[:, M_cat:M_cat + T_num, :] # ---------------------------- # Decode # ---------------------------- cat_logits_padded, cat_s, valid_class_mask = self.decode_cat( x_cat_out, return_padded=True, ) value_by_nin, s_by_nin = self.decode_num( x_num_out ) return cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab def _checkpoint_state_dict(self) -> Dict[str, torch.Tensor]: """ State dict used for save/load. Excludes pretrained frozen vision weights: - vision_extractor.vision_tower.* - vision_extractor.embed_vision.* Keeps reducer weights if reducer exists. """ full_sd = self.state_dict() out = {} for k, v in full_sd.items(): if k.startswith("vision_extractor.vision_tower."): continue if k.startswith("vision_extractor.embed_vision."): continue out[k] = v return out def save_weights(self, path: str): """ Save model weights needed for SoilFormer training/inference, excluding pretrained frozen vision weights. """ payload = { "model_state_dict": self._checkpoint_state_dict(), "config": self.config, } torch.save(payload, path) def load_weights(self, path: str, map_location: str = "cpu", strict: bool = True): """ Load weights saved by save_weights(). Only the checkpoint-managed subset is loaded: - embeddings / decoders / layers - vision_extractor.reducer.* (if present) Pretrained frozen vision weights are ignored here and are expected to come from vision_model_dir during model construction. """ ckpt = torch.load(path, map_location=map_location) if isinstance(ckpt, dict) and "model_state_dict" in ckpt: sd = ckpt["model_state_dict"] elif isinstance(ckpt, dict): sd = ckpt else: raise ValueError(f"Unsupported checkpoint format: {path}") expected_sd = self._checkpoint_state_dict() # Only keep keys that belong to the checkpoint-managed subset loadable_sd = {k: v for k, v in sd.items() if k in expected_sd} missing = sorted(set(expected_sd.keys()) - set(loadable_sd.keys())) unexpected = sorted(set(sd.keys()) - set(expected_sd.keys())) # Actually load load_info = self.load_state_dict(loadable_sd, strict=False) # PyTorch may still report missing keys from the full model state_dict; # keep only checkpoint-managed ones. missing_after_load = [ k for k in load_info.missing_keys if k in expected_sd ] unexpected_after_load = [ k for k in load_info.unexpected_keys if k in expected_sd ] # Merge both sources of mismatch info missing_final = sorted(set(missing) | set(missing_after_load)) unexpected_final = sorted(set(unexpected) | set(unexpected_after_load)) if strict and (missing_final or unexpected_final): raise RuntimeError( "Checkpoint load mismatch.\n" f"Missing keys: {missing_final}\n" f"Unexpected keys: {unexpected_final}" ) return { "missing_keys": missing_final, "unexpected_keys": unexpected_final, } def loss_function( x_cat: torch.Tensor, # [B,M,Cmax] padded logits s_cat: torch.Tensor, # [B,M] log-variance y_cat: torch.Tensor, # [B,M] class index loss_mask_cat: torch.Tensor, # [B,M] 0/1 valid_class_mask: torch.Tensor, # [M,Cmax] bool x_num: Dict[int, torch.Tensor], # {n_in: [B,V,n_in]} s_num: Dict[int, torch.Tensor], # {n_in: [B,V]} y_num: Dict[int, torch.Tensor], # {n_in: [B,V,n_in]} loss_mask_num: Dict[int, torch.Tensor], # {n_in: [B,V]} 0/1 cat_temperature: float = 1.0, reduction: str = "mean", # "mean" or "sum" eps: float = 1e-12, cat_s_bound: Optional[float] = None, num_s_bound: Optional[float] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Strict loss for SoilFormer. Categorical: - Uses per-column CE over the valid class range only. - Does NOT rely on padded logits values. - s_cat[b,m] = log sigma^2 for categorical column m. Numeric: - Per-variable MSE averaged over n_in dimensions. - s_num[n_in][b,v] = log sigma^2 for numeric variable v. Optional soft bound: If cat_s_bound or num_s_bound is not None, apply s <- bound * tanh(s / bound) before using s in heteroscedastic weighting. Returns: total_loss: scalar (float32) stats: dict with cat_loss, num_loss, cat_base, num_base, counts... """ def _soft_bound_logvar(s_: torch.Tensor, bound: Optional[float]) -> torch.Tensor: if bound is None: return s_ b = float(bound) if b <= 0: # Turn off weighting by signalling a non-positive bound return torch.zeros_like(s_) return b * torch.tanh(s_ / b) # --------------------------------------------------- # 1) Categorical loss (strict per-column CE) # --------------------------------------------------- if x_cat.dim() != 3: raise ValueError(f"x_cat must be [B,M,Cmax], got {tuple(x_cat.shape)}") B, M, Cmax = x_cat.shape if s_cat.shape != (B, M): raise ValueError(f"s_cat must be [B,M]=({B},{M}), got {tuple(s_cat.shape)}") if y_cat.shape != (B, M): raise ValueError(f"y_cat must be [B,M]=({B},{M}), got {tuple(y_cat.shape)}") if loss_mask_cat.shape != (B, M): raise ValueError(f"loss_mask_cat must be [B,M]=({B},{M}), got {tuple(loss_mask_cat.shape)}") if valid_class_mask.shape != (M, Cmax): raise ValueError( f"valid_class_mask must be [M,Cmax]=({M},{Cmax}), got {tuple(valid_class_mask.shape)}" ) x_cat_f = x_cat.float() s_cat_f = _soft_bound_logvar(s_cat.float(), cat_s_bound) y_cat_l = y_cat.long() mcat = loss_mask_cat.float() valid_class_mask = valid_class_mask.to(device=x_cat.device, dtype=torch.bool) if cat_temperature != 1.0: x_cat_f = x_cat_f / float(cat_temperature) cat_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32) cat_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32) cat_correct_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32) # denominator = number of actively supervised categorical cells cat_denom = mcat.sum().clamp_min(float(eps)) for m in range(M): cm = int(valid_class_mask[m].sum().item()) # real class count for column m if cm <= 0: raise ValueError(f"Column {m} has no valid classes") logits_m = x_cat_f[:, m, :cm] # [B, C_m] target_m = y_cat_l[:, m] # [B] s_m = s_cat_f[:, m] # [B] mask_m = mcat[:, m] # [B] active = mask_m > 0 if active.any(): tgt_active = target_m[active] if (tgt_active < 0).any() or (tgt_active >= cm).any(): raise ValueError(f"y_cat contains invalid class id for categorical column {m}") target_m_safe = target_m.clone() target_m_safe[~active] = 0 ce_m = F.cross_entropy( logits_m, target_m_safe, reduction="none", ) # [B], float32 # --------------------------------------------------- # accuracy (only count active positions) # --------------------------------------------------- pred_m = logits_m.argmax(dim=-1) # [B] correct_m = (pred_m == target_m_safe) & active # [B] cat_correct_acc = cat_correct_acc + correct_m.float().sum() # heteroscedastic weighting: exp(-s) * CE + s L_m = torch.exp(-s_m) * ce_m + s_m # [B] cat_loss_acc = cat_loss_acc + (L_m * mask_m).sum() cat_base_acc = cat_base_acc + (ce_m * mask_m).sum() if reduction == "mean": cat_loss = cat_loss_acc / cat_denom cat_base = cat_base_acc / cat_denom elif reduction == "sum": cat_loss = cat_loss_acc cat_base = cat_base_acc else: raise ValueError(f"Unsupported reduction: {reduction}") cat_acc = cat_correct_acc / cat_denom # --------------------------------------------------- # 2) Numeric loss (per-variable heteroscedastic MSE) # --------------------------------------------------- num_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32) num_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32) num_denom_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32) for n_in, x in x_num.items(): if n_in not in y_num or n_in not in s_num or n_in not in loss_mask_num: raise KeyError(f"Missing key n_in={n_in} in y_num/s_num/loss_mask_num") y = y_num[n_in] s = s_num[n_in] m = loss_mask_num[n_in] if x.shape != y.shape: raise ValueError( f"x_num[{n_in}] and y_num[{n_in}] shape mismatch: " f"{tuple(x.shape)} vs {tuple(y.shape)}" ) if x.dim() != 3: raise ValueError(f"x_num[{n_in}] must be [B,V,n_in], got {tuple(x.shape)}") Bb, V, Nin = x.shape if Nin != n_in: raise ValueError(f"x_num[{n_in}] last dim mismatch: got {Nin}, expected {n_in}") if s.shape != (Bb, V): raise ValueError(f"s_num[{n_in}] must be [B,V], got {tuple(s.shape)}") if m.shape != (Bb, V): raise ValueError(f"loss_mask_num[{n_in}] must be [B,V], got {tuple(m.shape)}") x_f = x.float() y_f = y.float() s_f = _soft_bound_logvar(s.float(), num_s_bound) m_f = m.float() # base numeric loss per variable: mean over n_in dims mse = (x_f - y_f).pow(2).mean(dim=-1) # [B,V] # heteroscedastic weighting: exp(-s) * mse + s L = torch.exp(-s_f) * mse + s_f # [B,V] num_loss_acc = num_loss_acc + (L * m_f).sum() num_base_acc = num_base_acc + (mse * m_f).sum() num_denom_acc = num_denom_acc + m_f.sum() num_denom = num_denom_acc.clamp_min(float(eps)) if reduction == "mean": num_loss = num_loss_acc / num_denom num_base = num_base_acc / num_denom elif reduction == "sum": num_loss = num_loss_acc num_base = num_base_acc else: raise ValueError(f"Unsupported reduction: {reduction}") # --------------------------------------------------- # 3) Total # --------------------------------------------------- total = cat_loss + num_loss stats = { "total": total.detach(), "cat_loss": cat_loss.detach(), "num_loss": num_loss.detach(), "cat_base": cat_base.detach(), "num_base": num_base.detach(), "cat_count": cat_denom.detach(), "num_count": num_denom.detach(), "cat_acc": cat_acc.detach(), } return total, stats # ============================================================ # DEMO # ============================================================ def _demo_main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--config_json", type=str, default="config/config_model.json") parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--with_vision", action="store_true") args = parser.parse_args() cfg = load_json(args.config_json) print("===== Loaded config =====") print(json.dumps(cfg, ensure_ascii=False, indent=2)) # -------------------------------------------------- # Ensure vocab files exist # -------------------------------------------------- tabular_meta = load_json(cfg["tabular_meta"]) if not os.path.isfile(cfg["cat_vocab_json"]): cat_names = get_categorical_feature_names_from_meta(tabular_meta) vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names) Path(cfg["cat_vocab_json"]).parent.mkdir(parents=True, exist_ok=True) save_cat_vocab_json(vocab, cfg["cat_vocab_json"]) print(f"[demo] Built cat_vocab_json at {cfg['cat_vocab_json']}") if not os.path.isfile(cfg["numeric_vocab_json"]): spec = build_numeric_vocab_spec_from_meta(tabular_meta) Path(cfg["numeric_vocab_json"]).parent.mkdir(parents=True, exist_ok=True) save_json(spec, cfg["numeric_vocab_json"]) print(f"[demo] Built numeric_vocab_json at {cfg['numeric_vocab_json']}") # -------------------------------------------------- # Build model # -------------------------------------------------- model = SoilFormer(cfg) model.init_weights() model.eval() device = next(model.parameters()).device dtype = next(model.parameters()).dtype B = args.batch_size # -------------------------------------------------- # Build dummy categorical inputs # -------------------------------------------------- cat_spec = load_json(cfg["cat_vocab_json"]) cat_items = sorted(cat_spec.items(), key=lambda x: x[1]["col_id"]) M_cat = len(cat_items) cat_local_ids = torch.zeros(B, M_cat, dtype=torch.long, device=device) cat_valid_positions = torch.ones(B, M_cat, dtype=torch.bool, device=device) # -------------------------------------------------- # Build dummy numeric inputs # -------------------------------------------------- num_spec = load_json(cfg["numeric_vocab_json"]) numeric_values_by_nin: Dict[int, torch.Tensor] = {} numeric_valid_positions_by_nin: Dict[int, torch.Tensor] = {} for g in num_spec["groups"]: n_in = int(g["n_in"]) V = len(g["feature_names"]) numeric_values_by_nin[n_in] = torch.randn(B, V, n_in, device=device, dtype=dtype) numeric_valid_positions_by_nin[n_in] = torch.ones(B, V, dtype=torch.bool, device=device) # -------------------------------------------------- # Build dummy vision inputs # -------------------------------------------------- if args.with_vision: pixel_values = torch.randn(B, 3, 224, 224, device=device, dtype=dtype) vision_valid_positions = torch.ones(B, dtype=torch.bool, device=device) else: pixel_values = None vision_valid_positions = None # -------------------------------------------------- # Vision debug # -------------------------------------------------- print("\n===== Vision debug =====") if pixel_values is None: print("pixel_values: None") print("vision_features: None") print("vision_mask: None") else: print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device) with torch.no_grad(): vision_features, vision_mask = model.vision_extractor.forward( pixel_values=pixel_values, valid_positions=vision_valid_positions, ) print("vision_features:", tuple(vision_features.shape), vision_features.dtype, vision_features.device) print("vision_mask:", tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device) # -------------------------------------------------- # Forward # -------------------------------------------------- with torch.no_grad(): cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab = model.forward( cat_local_ids=cat_local_ids, # noqa numeric_values_by_nin=numeric_values_by_nin, cat_valid_positions=cat_valid_positions, numeric_valid_positions_by_nin=numeric_valid_positions_by_nin, pixel_values=pixel_values, vision_valid_positions=vision_valid_positions, ) print("\n===== SoilFormer demo =====") print("cat_local_ids:", tuple(cat_local_ids.shape)) print("cat_valid_positions:", tuple(cat_valid_positions.shape)) print("numeric_values_by_nin:", {k: tuple(v.shape) for k, v in numeric_values_by_nin.items()}) print("numeric_valid_positions_by_nin:", {k: tuple(v.shape) for k, v in numeric_valid_positions_by_nin.items()}) print("x_tab_final:", tuple(x_tab.shape), x_tab.dtype, x_tab.device) print("Categorical outputs:") print("cat_logits_padded:", tuple(cat_logits_padded.shape), cat_logits_padded.dtype, cat_logits_padded.device) print("cat_s:", tuple(cat_s.shape), cat_s.dtype, cat_s.device) print("Numeric decoded values:", {k: tuple(v.shape) for k, v in value_by_nin.items()}) print("Numeric decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()}) # -------------------------------------------------- # Loss debug # -------------------------------------------------- print("\n===== Loss debug =====") if cat_logits_padded.dim() != 3: raise RuntimeError(f"cat_logits_padded must be [B,M,Cmax], got {tuple(cat_logits_padded.shape)}") B_logits, M_cat2, Cmax2 = cat_logits_padded.shape if cat_s.shape != (B_logits, M_cat2): raise RuntimeError(f"cat_s shape mismatch: got {tuple(cat_s.shape)} expected {(B_logits, M_cat2)}") # Build dummy categorical targets within valid class ranges num_classes = [int(s["num_classes"]) for _, s in cat_items] if len(num_classes) != M_cat2: raise RuntimeError("M_cat mismatch between vocab and model output") y_cat = torch.zeros(B_logits, M_cat2, dtype=torch.long, device=device) for m, cm in enumerate(num_classes): y_cat[:, m] = torch.randint(low=0, high=cm, size=(B_logits,), device=device) mask_cat = torch.ones(B_logits, M_cat2, dtype=torch.long, device=device) # Build dummy numeric targets and masks y_num = { n_in: torch.randn_like(x_pred) for n_in, x_pred in value_by_nin.items() } mask_num = { n_in: torch.ones(x_pred.size(0), x_pred.size(1), dtype=torch.long, device=x_pred.device) for n_in, x_pred in value_by_nin.items() } total_loss, stats = loss_function( x_cat=cat_logits_padded, s_cat=cat_s, y_cat=y_cat, loss_mask_cat=mask_cat, x_num=value_by_nin, s_num=s_by_nin, y_num=y_num, loss_mask_num=mask_num, reduction="mean", valid_class_mask=valid_class_mask ) print("total_loss:", float(total_loss)) print("stats:", {k: float(v) for k, v in stats.items()}) if not torch.isfinite(total_loss): raise RuntimeError("Loss is not finite!") if __name__ == "__main__": _demo_main()