| |
| |
|
|
| import json |
| import os |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from decode_categorical import CategoricalDecoder |
| from decode_numeric import NumericDecoder |
| from embed_categorical import ( |
| CategoricalEmbedding, |
| build_cat_vocab_spec_from_meta, |
| get_categorical_feature_names_from_meta, |
| save_cat_vocab_json, |
| ) |
| from embed_numeric import ( |
| NumericEmbedding, |
| build_numeric_vocab_spec_from_meta, |
| ) |
| from embed_vision_gemma3n import Gemma3nVisionFeatureExtractor |
| from layer import TabularImageGQALayer |
| from utils import load_json, save_json, get_dtype |
|
|
|
|
| |
| |
| |
|
|
class SoilFormer(nn.Module):
    """
    Full model: embeddings -> TabularImageGQALayer stack -> decoders.

    Data flow:
        categorical local ids -> CategoricalEmbedding ─┐
                                                       ├─ concat -> L x TabularImageGQALayer -> split
        numeric values        -> NumericEmbedding ─────┘
        pixel values          -> Gemma3nVisionFeatureExtractor
                                 (cross-attended inside each layer when present)

    The tabular output is split back into categorical / numeric token ranges
    and decoded by CategoricalDecoder / NumericDecoder.
    """

    def __init__(self, config: Dict, device: Optional[str] = None):
        """
        Build all submodules from ``config`` and move the model to the target
        device / dtype.

        Args:
            config: model configuration dict. Required keys include
                cat_hidden_size, numeric_hidden_size, cat_vocab_json,
                numeric_vocab_json, vision_model_dir, vision_* reducer options,
                and the layer_* hyperparameters. Optional keys are read with
                ``.get`` and documented inline below.
            device: optional device string; defaults to CUDA when available,
                else CPU.

        Raises:
            ValueError: if cat_hidden_size != numeric_hidden_size (the two
                embedding streams are concatenated into one tabular stream and
                must share a hidden size).
        """
        super().__init__()
        # Shallow copy so later mutation of the caller's dict cannot silently
        # change model behavior (config is also persisted in save_weights()).
        self.config = dict(config)

        dtype = get_dtype(self.config.get("dtype", "bfloat16"))
        dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # --- Shared tabular width -------------------------------------------
        cat_hidden = int(self.config["cat_hidden_size"])
        num_hidden = int(self.config["numeric_hidden_size"])
        if cat_hidden != num_hidden:
            raise ValueError("Expect cat_hidden_size == numeric_hidden_size for one tabular stream.")
        self.tabular_dim = cat_hidden

        # --- Embeddings -----------------------------------------------------
        self.embed_cat = CategoricalEmbedding(
            hidden_size=cat_hidden,
            cat_vocab_json=self.config["cat_vocab_json"],
        )
        self.embed_num = NumericEmbedding(
            hidden_size=num_hidden,
            numeric_vocab_json=self.config["numeric_vocab_json"],
            middle_size=self.config.get("numeric_encode_middle_size", None),
        )

        # --- Decoders -------------------------------------------------------
        self.decode_cat = CategoricalDecoder(
            hidden_size=cat_hidden,
            cat_vocab_json=self.config["cat_vocab_json"],
            middle_size=self.config.get("cat_decode_middle_size", None),
            homoscedastic=self.config.get("cat_homoscedastic", True),
        )
        self.decode_num = NumericDecoder(
            hidden_size=num_hidden,
            numeric_vocab_json=self.config["numeric_vocab_json"],
            middle_size=self.config.get("numeric_decode_middle_size", None),
            homoscedastic=self.config.get("num_homoscedastic", True),
        )

        # --- Vision feature extractor (pretrained, loaded on CPU first) -----
        self.vision_extractor = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
            model_dir=self.config["vision_model_dir"],
            map_location="cpu",
            num_output_tokens_reduced=self.config["vision_num_output_tokens_reduced"],
            num_heads_for_token_reduction=self.config["vision_num_heads_for_token_reduction"],
            reducer_bottleneck_dim=self.config["vision_reducer_bottleneck_dim"],
            reducer_project_back=self.config["vision_reducer_project_back"],
        )

        # --- Transformer stack ----------------------------------------------
        num_layers = int(self.config["layer_num_layers"])
        self.layers = nn.ModuleList([
            TabularImageGQALayer(
                tabular_dim=self.tabular_dim,
                vision_dim=self.vision_extractor.get_actual_hidden_dim(),
                num_query_heads=int(self.config["layer_num_query_heads"]),
                num_kv_heads=int(self.config["layer_num_kv_heads"]),
                head_dim=int(self.config["layer_head_dim"]),
                mlp_ratio=float(self.config["layer_mlp_ratio"]),
                dropout=float(self.config["layer_dropout"]),
            )
            for _ in range(num_layers)
        ])

        # Move everything (including the vision extractor) in one shot.
        self.to(device=dev, dtype=dtype)

    def init_weights(self, std: float = 0.02):
        """Re-initialize all trainable submodules with the given std."""
        self.embed_cat.init_weights(std=std)
        self.embed_num.init_weights(std=std)

        self.decode_cat.init_weights(std=std)
        self.decode_num.init_weights(std=std)

        # NOTE(review): presumably this only touches the reducer, not the
        # frozen pretrained vision tower — confirm in Gemma3nVisionFeatureExtractor.
        self.vision_extractor.init_weights(std=std)

        for blk in self.layers:
            blk.init_weights(std=std)

    def forward(
        self,
        cat_local_ids: torch.LongTensor,
        numeric_values_by_nin: Dict[int, torch.Tensor],
        cat_valid_positions: Optional[torch.Tensor] = None,
        numeric_valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,
        pixel_values: Optional[torch.Tensor] = None,
        vision_valid_positions: Optional[torch.Tensor] = None,
    ):
        """
        Run the full model.

        Args:
            cat_local_ids: per-column categorical class ids, one token per
                categorical column (shape assumed [B, M_cat] — matches the
                demo at the bottom of this file).
            numeric_values_by_nin: mapping n_in -> numeric value tensor; each
                group is embedded into its own numeric tokens.
            cat_valid_positions / numeric_valid_positions_by_nin: optional
                validity masks forwarded to the embeddings.
            pixel_values: optional image batch; when None, the layers run
                without vision cross-attention.
            vision_valid_positions: optional per-sample vision validity mask.

        Returns:
            Tuple of (cat_logits_padded, cat_s, valid_class_mask,
            value_by_nin, s_by_nin, x_tab) where x_tab is the final tabular
            hidden state [B, T_tab, D].
        """
        # --- Embed both tabular streams -------------------------------------
        x_cat, cat_mask = self.embed_cat(
            local_ids=cat_local_ids,
            valid_positions=cat_valid_positions,
        )

        x_num, num_mask = self.embed_num(
            values_by_nin=numeric_values_by_nin,
            valid_positions_by_nin=numeric_valid_positions_by_nin,
        )

        # Categorical tokens first, numeric tokens second; the decoders below
        # rely on this ordering when splitting the output.
        x_tab = torch.cat([x_cat, x_num], dim=1)

        B, T_tab, _ = x_tab.shape
        M_cat = x_cat.size(1)
        T_num = x_num.size(1)

        # Normalize mask dtypes/devices before concatenation.
        cat_mask = cat_mask.to(device=x_tab.device, dtype=torch.long)
        num_mask = num_mask.to(device=x_tab.device, dtype=torch.long)

        # Fix: read with .get(..., False) like every other optional config key,
        # instead of hard-indexing (which raised KeyError when the key was absent).
        if self.config.get("disable_tabular_attention_mask", False):
            attention_mask_tab = torch.ones(B, T_tab, device=x_tab.device, dtype=torch.long)
        else:
            attention_mask_tab = torch.cat([cat_mask, num_mask], dim=1)
        if attention_mask_tab.shape != (B, T_tab):
            raise RuntimeError("Internal attention_mask_tab shape mismatch")

        # --- Optional vision features ---------------------------------------
        if pixel_values is None:
            vision_features = None
            vision_mask = None
        else:
            vision_features, vision_mask = self.vision_extractor(
                pixel_values=pixel_values,
                valid_positions=vision_valid_positions,
            )

            if vision_features.shape[0] != B:
                raise ValueError("vision_features batch mismatch with tabular batch")

            if vision_mask.shape[0] != B or vision_mask.shape[1] != vision_features.shape[1]:
                raise ValueError("vision_mask shape mismatch with vision_features")

            vision_mask = vision_mask.to(
                device=attention_mask_tab.device,
                dtype=attention_mask_tab.dtype,
            )

        # --- Transformer stack ----------------------------------------------
        for blk in self.layers:
            x_tab = blk(
                x_tab=x_tab,
                attention_mask=attention_mask_tab,
                vision_features=vision_features,
                vision_mask=vision_mask
            )

        # --- Split back into streams and decode ------------------------------
        x_cat_out = x_tab[:, :M_cat, :]
        x_num_out = x_tab[:, M_cat:M_cat + T_num, :]

        cat_logits_padded, cat_s, valid_class_mask = self.decode_cat(
            x_cat_out,
            return_padded=True,
        )

        value_by_nin, s_by_nin = self.decode_num(
            x_num_out
        )

        return cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab

    def _checkpoint_state_dict(self) -> Dict[str, torch.Tensor]:
        """
        State dict used for save/load.

        Excludes pretrained frozen vision weights:
        - vision_extractor.vision_tower.*
        - vision_extractor.embed_vision.*

        Keeps reducer weights if reducer exists.
        """
        full_sd = self.state_dict()
        out = {}

        for k, v in full_sd.items():
            if k.startswith("vision_extractor.vision_tower."):
                continue
            if k.startswith("vision_extractor.embed_vision."):
                continue
            out[k] = v

        return out

    def save_weights(self, path: str):
        """
        Save model weights needed for SoilFormer training/inference,
        excluding pretrained frozen vision weights.

        The config dict is stored alongside the weights so a checkpoint is
        self-describing.
        """
        payload = {
            "model_state_dict": self._checkpoint_state_dict(),
            "config": self.config,
        }
        torch.save(payload, path)

    def load_weights(self, path: str, map_location: str = "cpu", strict: bool = True):
        """
        Load weights saved by save_weights().

        Only the checkpoint-managed subset is loaded:
        - embeddings / decoders / layers
        - vision_extractor.reducer.* (if present)

        Pretrained frozen vision weights are ignored here and are expected
        to come from vision_model_dir during model construction.

        Returns:
            dict with "missing_keys" / "unexpected_keys" lists (both empty
            when strict=True and the load succeeded).

        Raises:
            ValueError: unsupported checkpoint format.
            RuntimeError: key mismatch when strict=True.
        """
        # SECURITY NOTE(review): torch.load uses pickle under the hood; only
        # load checkpoints from trusted sources (consider weights_only=True
        # once the stored config payload is verified compatible).
        ckpt = torch.load(path, map_location=map_location)

        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            sd = ckpt["model_state_dict"]
        elif isinstance(ckpt, dict):
            # Raw state dict (no wrapper payload).
            sd = ckpt
        else:
            raise ValueError(f"Unsupported checkpoint format: {path}")

        expected_sd = self._checkpoint_state_dict()

        # Only feed keys we actually manage into load_state_dict.
        loadable_sd = {k: v for k, v in sd.items() if k in expected_sd}

        missing = sorted(set(expected_sd.keys()) - set(loadable_sd.keys()))
        unexpected = sorted(set(sd.keys()) - set(expected_sd.keys()))

        # strict=False here: the frozen vision keys are intentionally absent
        # from the checkpoint; strictness is re-imposed below on the managed
        # subset only.
        load_info = self.load_state_dict(loadable_sd, strict=False)

        missing_after_load = [
            k for k in load_info.missing_keys
            if k in expected_sd
        ]
        unexpected_after_load = [
            k for k in load_info.unexpected_keys
            if k in expected_sd
        ]

        missing_final = sorted(set(missing) | set(missing_after_load))
        unexpected_final = sorted(set(unexpected) | set(unexpected_after_load))

        if strict and (missing_final or unexpected_final):
            raise RuntimeError(
                "Checkpoint load mismatch.\n"
                f"Missing keys: {missing_final}\n"
                f"Unexpected keys: {unexpected_final}"
            )

        return {
            "missing_keys": missing_final,
            "unexpected_keys": unexpected_final,
        }
|
|
|
|
def loss_function(
    x_cat: torch.Tensor,
    s_cat: torch.Tensor,
    y_cat: torch.Tensor,
    loss_mask_cat: torch.Tensor,
    valid_class_mask: torch.Tensor,
    x_num: Dict[int, torch.Tensor],
    s_num: Dict[int, torch.Tensor],
    y_num: Dict[int, torch.Tensor],
    loss_mask_num: Dict[int, torch.Tensor],
    cat_temperature: float = 1.0,
    reduction: str = "mean",
    eps: float = 1e-12,
    cat_s_bound: Optional[float] = None,
    num_s_bound: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Strict heteroscedastic loss for SoilFormer.

    Categorical head:
        Per-column cross-entropy restricted to each column's valid class
        prefix (padded logit slots are never read). s_cat[b, m] is treated
        as log sigma^2 and weights the CE term as exp(-s)*CE + s.

    Numeric head:
        Per-variable MSE averaged over the n_in dimension, weighted the
        same way by s_num[n_in][b, v].

    When cat_s_bound / num_s_bound is given, the log-variance is softly
    clipped via s <- bound * tanh(s / bound) before weighting.

    Returns:
        (total_loss, stats) where total_loss is a float32 scalar and stats
        holds detached diagnostics (losses, unweighted bases, counts,
        categorical accuracy).
    """

    def _clip_logvar(raw: torch.Tensor, bound: Optional[float]) -> torch.Tensor:
        # Soft tanh clip of log-variance; a non-positive bound pins it to 0.
        if bound is None:
            return raw
        b = float(bound)
        if b <= 0:
            return torch.zeros_like(raw)
        return b * torch.tanh(raw / b)

    # ------------------------------------------------------------------
    # Categorical: shape validation
    # ------------------------------------------------------------------
    if x_cat.dim() != 3:
        raise ValueError(f"x_cat must be [B,M,Cmax], got {tuple(x_cat.shape)}")

    B, M, Cmax = x_cat.shape

    for name, tensor in (("s_cat", s_cat), ("y_cat", y_cat), ("loss_mask_cat", loss_mask_cat)):
        if tensor.shape != (B, M):
            raise ValueError(f"{name} must be [B,M]=({B},{M}), got {tuple(tensor.shape)}")
    if valid_class_mask.shape != (M, Cmax):
        raise ValueError(
            f"valid_class_mask must be [M,Cmax]=({M},{Cmax}), got {tuple(valid_class_mask.shape)}"
        )

    logits_all = x_cat.float()
    logvar_cat = _clip_logvar(s_cat.float(), cat_s_bound)
    targets_all = y_cat.long()
    w_cat = loss_mask_cat.float()
    valid_class_mask = valid_class_mask.to(device=x_cat.device, dtype=torch.bool)

    if cat_temperature != 1.0:
        # Temperature scaling; argmax (accuracy) is unaffected by this.
        logits_all = logits_all / float(cat_temperature)

    cat_loss_sum = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    cat_base_sum = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    cat_hits = torch.zeros((), device=x_cat.device, dtype=torch.float32)

    # Shared denominator for mean reduction and accuracy (clamped to avoid 0/0).
    cat_denom = w_cat.sum().clamp_min(float(eps))

    for col in range(M):
        # Valid classes are assumed to occupy the prefix [0, n_classes) of
        # the padded class axis — TODO confirm against CategoricalDecoder.
        n_classes = int(valid_class_mask[col].sum().item())
        if n_classes <= 0:
            raise ValueError(f"Column {col} has no valid classes")

        col_logits = logits_all[:, col, :n_classes]
        col_target = targets_all[:, col]
        col_logvar = logvar_cat[:, col]
        col_w = w_cat[:, col]

        is_active = col_w > 0
        if is_active.any():
            live = col_target[is_active]
            if ((live < 0) | (live >= n_classes)).any():
                raise ValueError(f"y_cat contains invalid class id for categorical column {col}")

        # Inactive rows get a dummy class 0 so cross_entropy never indexes
        # out of range; their contribution is masked out below anyway.
        safe_target = torch.where(is_active, col_target, torch.zeros_like(col_target))

        ce = F.cross_entropy(
            col_logits,
            safe_target,
            reduction="none",
        )

        hits = (col_logits.argmax(dim=-1) == safe_target) & is_active
        cat_hits = cat_hits + hits.float().sum()

        # Heteroscedastic weighting: exp(-s) * CE + s.
        weighted = torch.exp(-col_logvar) * ce + col_logvar

        cat_loss_sum = cat_loss_sum + (weighted * col_w).sum()
        cat_base_sum = cat_base_sum + (ce * col_w).sum()

    if reduction == "mean":
        cat_loss = cat_loss_sum / cat_denom
        cat_base = cat_base_sum / cat_denom
    elif reduction == "sum":
        cat_loss = cat_loss_sum
        cat_base = cat_base_sum
    else:
        raise ValueError(f"Unsupported reduction: {reduction}")
    cat_acc = cat_hits / cat_denom

    # ------------------------------------------------------------------
    # Numeric groups
    # ------------------------------------------------------------------
    num_loss_sum = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    num_base_sum = torch.zeros((), device=x_cat.device, dtype=torch.float32)
    num_count = torch.zeros((), device=x_cat.device, dtype=torch.float32)

    for n_in, pred in x_num.items():
        if n_in not in y_num or n_in not in s_num or n_in not in loss_mask_num:
            raise KeyError(f"Missing key n_in={n_in} in y_num/s_num/loss_mask_num")

        target = y_num[n_in]
        logvar = s_num[n_in]
        w = loss_mask_num[n_in]

        if pred.shape != target.shape:
            raise ValueError(
                f"x_num[{n_in}] and y_num[{n_in}] shape mismatch: "
                f"{tuple(pred.shape)} vs {tuple(target.shape)}"
            )
        if pred.dim() != 3:
            raise ValueError(f"x_num[{n_in}] must be [B,V,n_in], got {tuple(pred.shape)}")

        Bb, V, Nin = pred.shape
        if Nin != n_in:
            raise ValueError(f"x_num[{n_in}] last dim mismatch: got {Nin}, expected {n_in}")
        if logvar.shape != (Bb, V):
            raise ValueError(f"s_num[{n_in}] must be [B,V], got {tuple(logvar.shape)}")
        if w.shape != (Bb, V):
            raise ValueError(f"loss_mask_num[{n_in}] must be [B,V], got {tuple(w.shape)}")

        lv = _clip_logvar(logvar.float(), num_s_bound)
        w_f = w.float()

        # Mean squared error over the n_in axis -> [B, V].
        sq_err = (pred.float() - target.float()).pow(2).mean(dim=-1)

        # Same heteroscedastic weighting as the categorical head.
        weighted = torch.exp(-lv) * sq_err + lv

        num_loss_sum = num_loss_sum + (weighted * w_f).sum()
        num_base_sum = num_base_sum + (sq_err * w_f).sum()
        num_count = num_count + w_f.sum()

    num_denom = num_count.clamp_min(float(eps))

    if reduction == "mean":
        num_loss = num_loss_sum / num_denom
        num_base = num_base_sum / num_denom
    elif reduction == "sum":
        num_loss = num_loss_sum
        num_base = num_base_sum
    else:
        raise ValueError(f"Unsupported reduction: {reduction}")

    # ------------------------------------------------------------------
    # Combine and report
    # ------------------------------------------------------------------
    total = cat_loss + num_loss

    stats = {
        "total": total.detach(),
        "cat_loss": cat_loss.detach(),
        "num_loss": num_loss.detach(),
        "cat_base": cat_base.detach(),
        "num_base": num_base.detach(),
        "cat_count": cat_denom.detach(),
        "num_count": num_denom.detach(),
        "cat_acc": cat_acc.detach(),
    }
    return total, stats
|
|
|
|
| |
| |
| |
|
|
def _demo_main():
    """
    Smoke-test driver: build a SoilFormer from a JSON config, run one dummy
    forward pass (optionally with random vision input), and evaluate the
    loss on random targets.

    CLI:
        --config_json  path to the model config JSON
        --batch_size   dummy batch size
        --with_vision  also feed random pixel values through the vision path

    Raises:
        RuntimeError: on shape mismatches between config/vocab and model
            outputs, or if the resulting loss is not finite.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_json", type=str, default="config/config_model.json")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--with_vision", action="store_true")
    args = parser.parse_args()

    cfg = load_json(args.config_json)

    print("===== Loaded config =====")
    print(json.dumps(cfg, ensure_ascii=False, indent=2))

    # ------------------------------------------------------------------
    # Build vocab specs on first run (derived from the tabular metadata).
    # ------------------------------------------------------------------
    tabular_meta = load_json(cfg["tabular_meta"])

    if not os.path.isfile(cfg["cat_vocab_json"]):
        cat_names = get_categorical_feature_names_from_meta(tabular_meta)
        vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
        Path(cfg["cat_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
        save_cat_vocab_json(vocab, cfg["cat_vocab_json"])
        print(f"[demo] Built cat_vocab_json at {cfg['cat_vocab_json']}")

    if not os.path.isfile(cfg["numeric_vocab_json"]):
        spec = build_numeric_vocab_spec_from_meta(tabular_meta)
        Path(cfg["numeric_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
        save_json(spec, cfg["numeric_vocab_json"])
        print(f"[demo] Built numeric_vocab_json at {cfg['numeric_vocab_json']}")

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    model = SoilFormer(cfg)
    model.init_weights()
    model.eval()

    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    B = args.batch_size

    # ------------------------------------------------------------------
    # Dummy categorical inputs (all class 0, all positions valid)
    # ------------------------------------------------------------------
    cat_spec = load_json(cfg["cat_vocab_json"])
    cat_items = sorted(cat_spec.items(), key=lambda x: x[1]["col_id"])
    M_cat = len(cat_items)

    cat_local_ids = torch.zeros(B, M_cat, dtype=torch.long, device=device)
    cat_valid_positions = torch.ones(B, M_cat, dtype=torch.bool, device=device)

    # ------------------------------------------------------------------
    # Dummy numeric inputs, one tensor per n_in group
    # ------------------------------------------------------------------
    num_spec = load_json(cfg["numeric_vocab_json"])

    numeric_values_by_nin: Dict[int, torch.Tensor] = {}
    numeric_valid_positions_by_nin: Dict[int, torch.Tensor] = {}

    for g in num_spec["groups"]:
        n_in = int(g["n_in"])
        V = len(g["feature_names"])

        numeric_values_by_nin[n_in] = torch.randn(B, V, n_in, device=device, dtype=dtype)
        numeric_valid_positions_by_nin[n_in] = torch.ones(B, V, dtype=torch.bool, device=device)

    # ------------------------------------------------------------------
    # Optional dummy vision input
    # ------------------------------------------------------------------
    if args.with_vision:
        pixel_values = torch.randn(B, 3, 224, 224, device=device, dtype=dtype)
        vision_valid_positions = torch.ones(B, dtype=torch.bool, device=device)
    else:
        pixel_values = None
        vision_valid_positions = None

    # ------------------------------------------------------------------
    # Vision path debug
    # ------------------------------------------------------------------
    print("\n===== Vision debug =====")
    if pixel_values is None:
        print("pixel_values: None")
        print("vision_features: None")
        print("vision_mask: None")
    else:
        print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
        with torch.no_grad():
            # Fix: call the module (__call__) rather than .forward() so
            # registered forward hooks are honored.
            vision_features, vision_mask = model.vision_extractor(
                pixel_values=pixel_values,
                valid_positions=vision_valid_positions,
            )
        print("vision_features:", tuple(vision_features.shape), vision_features.dtype, vision_features.device)
        print("vision_mask:", tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device)

    # ------------------------------------------------------------------
    # Full forward pass
    # ------------------------------------------------------------------
    with torch.no_grad():
        # Fix: model(...) instead of model.forward(...) — same reason as above.
        cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab = model(
            cat_local_ids=cat_local_ids,
            numeric_values_by_nin=numeric_values_by_nin,
            cat_valid_positions=cat_valid_positions,
            numeric_valid_positions_by_nin=numeric_valid_positions_by_nin,
            pixel_values=pixel_values,
            vision_valid_positions=vision_valid_positions,
        )

    print("\n===== SoilFormer demo =====")
    print("cat_local_ids:", tuple(cat_local_ids.shape))
    print("cat_valid_positions:", tuple(cat_valid_positions.shape))
    print("numeric_values_by_nin:", {k: tuple(v.shape) for k, v in numeric_values_by_nin.items()})
    print("numeric_valid_positions_by_nin:", {k: tuple(v.shape) for k, v in numeric_valid_positions_by_nin.items()})
    print("x_tab_final:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)

    print("Categorical outputs:")
    print("cat_logits_padded:", tuple(cat_logits_padded.shape), cat_logits_padded.dtype, cat_logits_padded.device)
    print("cat_s:", tuple(cat_s.shape), cat_s.dtype, cat_s.device)

    print("Numeric decoded values:", {k: tuple(v.shape) for k, v in value_by_nin.items()})
    print("Numeric decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})

    # ------------------------------------------------------------------
    # Loss on random targets
    # ------------------------------------------------------------------
    print("\n===== Loss debug =====")

    if cat_logits_padded.dim() != 3:
        raise RuntimeError(f"cat_logits_padded must be [B,M,Cmax], got {tuple(cat_logits_padded.shape)}")

    B_logits, M_cat2, Cmax2 = cat_logits_padded.shape
    if cat_s.shape != (B_logits, M_cat2):
        raise RuntimeError(f"cat_s shape mismatch: got {tuple(cat_s.shape)} expected {(B_logits, M_cat2)}")

    # Random but always-valid class targets per column.
    num_classes = [int(s["num_classes"]) for _, s in cat_items]
    if len(num_classes) != M_cat2:
        raise RuntimeError("M_cat mismatch between vocab and model output")

    y_cat = torch.zeros(B_logits, M_cat2, dtype=torch.long, device=device)
    for m, cm in enumerate(num_classes):
        y_cat[:, m] = torch.randint(low=0, high=cm, size=(B_logits,), device=device)

    mask_cat = torch.ones(B_logits, M_cat2, dtype=torch.long, device=device)

    # Random numeric targets, all positions active.
    y_num = {
        n_in: torch.randn_like(x_pred)
        for n_in, x_pred in value_by_nin.items()
    }

    mask_num = {
        n_in: torch.ones(x_pred.size(0), x_pred.size(1), dtype=torch.long, device=x_pred.device)
        for n_in, x_pred in value_by_nin.items()
    }

    total_loss, stats = loss_function(
        x_cat=cat_logits_padded,
        s_cat=cat_s,
        y_cat=y_cat,
        loss_mask_cat=mask_cat,
        x_num=value_by_nin,
        s_num=s_by_nin,
        y_num=y_num,
        loss_mask_num=mask_num,
        reduction="mean",
        valid_class_mask=valid_class_mask
    )

    print("total_loss:", float(total_loss))
    print("stats:", {k: float(v) for k, v in stats.items()})

    if not torch.isfinite(total_loss):
        raise RuntimeError("Loss is not finite!")
|
|