import argparse
import ast
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Tuple
from urllib.parse import urljoin

import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torchvision import transforms

# The script is intended to live one level above ./modelling.
# modelling/ modules still contain some legacy absolute imports, so expose the
# modelling directory on sys.path as well.
PROJECT_ROOT = Path(__file__).resolve().parent
MODELLING_DIR = PROJECT_ROOT / "modelling"
if str(MODELLING_DIR) not in sys.path:
    sys.path.insert(0, str(MODELLING_DIR))

from modelling.soilformer import SoilFormer  # noqa: E402
from modelling.utils import get_dtype, load_json  # noqa: E402


# -----------------------------------------------------------------------------
# JSON helpers
# -----------------------------------------------------------------------------
def load_card(path: str) -> Dict[str, Any]:
    """Load a JSON card file, requiring the top level to be an object.

    Raises:
        ValueError: if the file parses but is not a JSON object.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)
    if not isinstance(obj, dict):
        raise ValueError(f"Card must be a JSON object: {path}")
    return obj


def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
    """Write *obj* as indented UTF-8 JSON to *path*, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
        f.write("\n")


def to_jsonable(x: Any) -> Any:
    """Recursively convert numpy scalars/arrays and torch tensors to plain
    JSON-serializable Python values; other types pass through unchanged."""
    if isinstance(x, np.generic):
        return x.item()
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu()
        if x.ndim == 0:
            return x.item()
        return x.tolist()
    if isinstance(x, dict):
        return {str(k): to_jsonable(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return [to_jsonable(v) for v in x]
    return x


# -----------------------------------------------------------------------------
# Runtime / model loading
# -----------------------------------------------------------------------------
def resolve_device(device_str: str) -> torch.device:
    """Map a user-supplied device string to a torch.device.

    "auto" prefers CUDA, then MPS, then CPU. Explicit "cuda"/"mps" raise
    RuntimeError when the backend is unavailable instead of silently falling
    back, so a misconfigured run fails loudly.
    """
    device_str = str(device_str).lower()
    if device_str == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    if device_str == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("--device cuda requested, but CUDA is not available")
        return torch.device("cuda")
    if device_str == "mps":
        if not torch.backends.mps.is_available():
            raise RuntimeError("--device mps requested, but MPS is not available")
        return torch.device("mps")
    if device_str == "cpu":
        return torch.device("cpu")
    raise ValueError(f"Unsupported device: {device_str}")


def load_model(args: argparse.Namespace, config_model: Dict[str, Any],
               device: torch.device, dtype: torch.dtype) -> SoilFormer:
    """Build a SoilFormer, load its checkpoint, and move it to device/dtype.

    The checkpoint is loaded non-strictly: missing keys under
    ``vision_extractor.`` are tolerated (presumably the vision backbone is
    loaded from its own pretrained weights — confirm against training code),
    while any other missing key aborts with RuntimeError.
    """
    print("[INFO] Initializing model...")
    model = SoilFormer(config=config_model, device=str(device))
    print("[INFO] Loading checkpoint...")
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    missing, unexpected = model.load_state_dict(
        checkpoint["model_state_dict"], strict=False
    )
    non_vision_missing = [k for k in missing if not k.startswith("vision_extractor.")]
    if len(non_vision_missing) > 0:
        raise RuntimeError(
            f"[ERROR] Missing non-vision keys detected: {non_vision_missing[:10]}"
        )
    print(f"[INFO] Missing keys (vision only): {len(missing)}")
    print(f"[INFO] Unexpected keys: {len(unexpected)}")
    model.to(device=device, dtype=dtype)
    model.eval()
    return model


# -----------------------------------------------------------------------------
# Metadata loading
# -----------------------------------------------------------------------------
def load_metadata(config_data: Dict[str, Any]) -> Dict[str, Any]:
    """Load vocabularies and normalization stats referenced by the data config.

    Returns a dict with:
      - cat_vocab / numeric_vocab: raw JSON vocab objects
      - numeric_stats: column -> (mean, std) for z-scoring
      - cat_columns: ordered categorical column names
      - cat_mask_local_ids: per-column local id used for masked/missing values
      - id_to_label_by_col: inverse label maps for decoding predictions
    """
    cat_vocab = load_json(config_data["cat_vocab_path"])
    numeric_vocab = load_json(config_data["numeric_vocab_path"])

    stats_df = pd.read_csv(config_data["numeric_stats_path"])
    numeric_stats = {}
    for _, row in stats_df.iterrows():
        col = row["column"]
        mean = float(row["mean"])
        std = float(row["std"])
        if std == 0.0:
            # Constant columns: use std=1 so z-scoring is a no-op instead of
            # dividing by zero.
            std = 1.0
        numeric_stats[str(col)] = (mean, std)

    cat_columns = list(cat_vocab.keys())
    cat_mask_local_ids = [int(cat_vocab[col]["mask_local_id"]) for col in cat_columns]
    id_to_label_by_col = {}
    for col in cat_columns:
        label2id = cat_vocab[col]["label2id"]
        id_to_label_by_col[col] = {int(v): str(k) for k, v in label2id.items()}

    return {
        "cat_vocab": cat_vocab,
        "numeric_vocab": numeric_vocab,
        "numeric_stats": numeric_stats,
        "cat_columns": cat_columns,
        "cat_mask_local_ids": cat_mask_local_ids,
        "id_to_label_by_col": id_to_label_by_col,
    }


# -----------------------------------------------------------------------------
# Image handling, matching loader.py behavior
# -----------------------------------------------------------------------------
class CenterSquareCrop:
    """Crop the largest centered square from a PIL image (no-op if square)."""

    def __call__(self, img: Image.Image) -> Image.Image:
        w, h = img.size
        if w == h:
            return img
        if w > h:
            left = (w - h) // 2
            return img.crop((left, 0, left + h, h))
        top = (h - w) // 2
        return img.crop((0, top, w, top + w))


def build_image_transform(image_size: int):
    """Return the PIL->tensor pipeline: center square crop, resize, ToTensor."""
    return transforms.Compose([
        CenterSquareCrop(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])


def join_photo_root(photo_root: str, relative_path: str) -> str:
    """Join a photo root (URL or filesystem prefix) with a relative path."""
    if photo_root.startswith("http://") or photo_root.startswith("https://"):
        return urljoin(photo_root.rstrip("/") + "/", relative_path)
    return photo_root.rstrip("/") + "/" + relative_path.lstrip("/")


def load_image_tensor(image_path: str, image_size: int) -> torch.Tensor:
    """Load an image from an HTTP(S) URL or local path as a [3, H, W] tensor.

    Raises on network errors / unreadable files; the caller decides whether to
    fall back to a zero image.
    """
    if image_path.startswith("http://") or image_path.startswith("https://"):
        resp = requests.get(image_path, timeout=(3, 10))
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    else:
        # convert("RGB") forces Pillow's lazy load, so the file handle can be
        # closed here instead of leaking until GC.
        with Image.open(image_path) as opened:
            img = opened.convert("RGB")
    return build_image_transform(image_size)(img)


# -----------------------------------------------------------------------------
# Tensorization from readable input card
# -----------------------------------------------------------------------------
def is_masked_or_missing(value: Any) -> bool:
    """True when a card value denotes "no value": null (masked) or "" (missing)."""
    return value is None or value == ""


def parse_numeric_card_value(value: Any, n_in: int) -> Tuple[list[float], bool]:
    """Parse one numeric card entry into an ``n_in``-length vector.

    Returns (values, is_valid). Missing values (None/"") yield zeros with
    is_valid=False. Scalars are accepted for n_in == 1; strings holding list
    literals are parsed with ast.literal_eval for n_in > 1.

    Raises:
        ValueError: on shape mismatches or non-list-like input for n_in > 1.
    """
    if value is None or value == "":
        return [0.0] * n_in, False
    if n_in == 1:
        if isinstance(value, list):
            if len(value) != 1:
                raise ValueError(f"Expected scalar or length-1 list for n_in=1, got {value!r}")
            return [float(value[0])], True
        return [float(value)], True
    if isinstance(value, str):
        parsed = ast.literal_eval(value)
    else:
        parsed = value
    if not isinstance(parsed, (list, tuple)):
        raise ValueError(f"Expected list-like numeric vector for n_in={n_in}, got {value!r}")
    if len(parsed) != n_in:
        raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(parsed)}")
    return [float(v) for v in parsed], True


def tensorize_card(
    input_card: Dict[str, Any],
    config_data: Dict[str, Any],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    """Convert a readable input card into a batch-of-1 tensor dict.

    Categorical labels become local ids (mask id when null/""), numeric values
    are z-scored and grouped by n_in, and the vision entry is loaded and
    transformed (falling back to a zero image with valid=False on failure).
    """
    categorical = input_card.get("categorical", {})
    numeric = input_card.get("numeric", {})
    vision = input_card.get("vision", {})
    if not isinstance(categorical, dict):
        raise ValueError("input_card['categorical'] must be an object")
    if not isinstance(numeric, dict):
        raise ValueError("input_card['numeric'] must be an object")
    if not isinstance(vision, dict):
        vision = {}

    # Categorical: raw label -> local id, null/"" -> mask id and invalid.
    cat_ids = []
    cat_valids = []
    for col, mask_id in zip(meta["cat_columns"], meta["cat_mask_local_ids"]):
        value = categorical.get(col, "")
        if is_masked_or_missing(value):
            cat_ids.append(mask_id)
            cat_valids.append(False)
        else:
            label2id = meta["cat_vocab"][col]["label2id"]
            if value not in label2id:
                raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
            cat_ids.append(int(label2id[value]))
            cat_valids.append(True)
    cat_local_ids = torch.tensor([cat_ids], dtype=torch.long)
    cat_valid_positions = torch.tensor([cat_valids], dtype=torch.bool)

    # Numeric: raw actual units -> z-score grouped tensors.
    numeric_values_by_nin = {}
    numeric_valid_positions_by_nin = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        values = []
        valids = []
        for feat in group["feature_names"]:
            feat = str(feat)
            raw_value = numeric.get(feat, "")
            parsed, is_valid = parse_numeric_card_value(raw_value, n_in)
            if is_valid:
                mean, std = meta["numeric_stats"][feat]
                parsed = [(v - mean) / std for v in parsed]
            values.append(parsed)
            valids.append(is_valid)
        numeric_values_by_nin[n_in] = torch.tensor([values], dtype=torch.float32)
        numeric_valid_positions_by_nin[n_in] = torch.tensor([valids], dtype=torch.bool)

    # Vision: readable card stores suffix only. Load/transform here.
    image_size = int(config_data["image_size"])
    image_path_suffix = vision.get("image_path_suffix", "")
    if image_path_suffix is None or image_path_suffix == "":
        pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
        vision_valid_positions = torch.tensor([False], dtype=torch.bool)
    else:
        image_path = join_photo_root(str(config_data["photo_root"]), str(image_path_suffix))
        try:
            image = load_image_tensor(image_path, image_size=image_size)
            pixel_values = image.unsqueeze(0)
            vision_valid_positions = torch.tensor([True], dtype=torch.bool)
        except Exception as exc:
            # Best-effort: a broken/missing image degrades to "no vision input"
            # rather than aborting the whole inference run.
            print(f"[WARN] Could not load image; using zero vision input: {exc}")
            pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
            vision_valid_positions = torch.tensor([False], dtype=torch.bool)

    return {
        "cat_local_ids": cat_local_ids,
        "cat_valid_positions": cat_valid_positions,
        "numeric_values_by_nin": numeric_values_by_nin,
        "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
        "pixel_values": pixel_values,
        "vision_valid_positions": vision_valid_positions,
    }


def _tensor_to_device(value: torch.Tensor, device: torch.device,
                      dtype: torch.dtype) -> torch.Tensor:
    """Move one tensor to *device*; only floating tensors adopt *dtype* so
    integer ids and boolean valid-masks keep their original dtypes."""
    if value.dtype.is_floating_point:
        return value.to(device=device, dtype=dtype)
    return value.to(device=device)


def move_batch_to_device(batch: Dict[str, Any], device: torch.device,
                         dtype: torch.dtype) -> Dict[str, Any]:
    """Move a (possibly one-level-nested) batch dict of tensors to device/dtype.

    Non-tensor leaves are passed through untouched.
    """
    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            out[key] = _tensor_to_device(value, device, dtype)
        elif isinstance(value, dict):
            out[key] = {
                k: _tensor_to_device(v, device, dtype) if isinstance(v, torch.Tensor) else v
                for k, v in value.items()
            }
        else:
            out[key] = value
    return out


# -----------------------------------------------------------------------------
# Decoding model outputs to readable card
# -----------------------------------------------------------------------------
def denormalize_numeric(values_z: list[float], mean: float, std: float) -> list[float]:
    """Invert z-scoring: z * std + mean, element-wise."""
    return [float(v) * float(std) + float(mean) for v in values_z]


def decode_outputs(
    cat_logits_padded: torch.Tensor,
    valid_class_mask: torch.Tensor,
    value_by_nin: Dict[int, torch.Tensor],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    """Decode model outputs for a batch of 1 into a readable output card.

    Categorical: argmax over each column's valid class slice, mapped back to
    its label. Numeric: denormalized per feature; scalar for n_in == 1, list
    otherwise.

    NOTE(review): assumes cat_logits_padded is [1, M, C_max] with columns
    ordered as meta["cat_columns"], and value_by_nin[n] is [1, V, n] ordered
    as the group's feature_names — confirm against SoilFormer.forward.
    """
    cat_logits = cat_logits_padded.detach().float().cpu()
    valid_class_mask = valid_class_mask.detach().cpu().bool()

    categorical_out = {}
    for m, col in enumerate(meta["cat_columns"]):
        cm = int(valid_class_mask[m].sum().item())
        logits = cat_logits[0, m, :cm]
        probs = torch.softmax(logits, dim=-1)
        pred_id = int(torch.argmax(probs).item())
        pred_label = meta["id_to_label_by_col"][col].get(pred_id, str(pred_id))
        categorical_out[col] = pred_label

    numeric_out = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        preds_z = value_by_nin[n_in].detach().float().cpu()[0]  # [V, n_in]
        for v_idx, feat in enumerate(group["feature_names"]):
            feat = str(feat)
            mean, std = meta["numeric_stats"][feat]
            raw_pred_values = denormalize_numeric(preds_z[v_idx].tolist(), mean, std)
            if n_in == 1:
                numeric_out[feat] = raw_pred_values[0]
            else:
                numeric_out[feat] = raw_pred_values

    return {
        "categorical": categorical_out,
        "numeric": numeric_out,
    }


# -----------------------------------------------------------------------------
# Accuracy / MAE analysis
# -----------------------------------------------------------------------------
def masked_feature_names(input_card: Dict[str, Any], section: str) -> list[str]:
    """Names in input_card[section] whose value is null (i.e. masked for eval).

    "" marks a naturally missing value and is deliberately NOT included.
    """
    values = input_card.get(section, {})
    if not isinstance(values, dict):
        return []
    return [k for k, v in values.items() if v is None]


def numeric_abs_errors(pred_value: Any, answer_value: Any) -> list[float]:
    """Element-wise absolute errors between a prediction and its answer.

    Accepts scalars, lists, or their string representations. Returns [] when
    either side is missing, unparseable as a pair, or the lengths differ —
    such fields are simply excluded from the MAE.
    """
    if answer_value is None or answer_value == "":
        return []
    if pred_value is None or pred_value == "":
        return []
    if isinstance(answer_value, str):
        s = answer_value.strip()
        if s == "":
            return []
        if s.startswith("[") and s.endswith("]"):
            answer_value = [float(x) for x in ast.literal_eval(s)]
        else:
            answer_value = float(s)
    if isinstance(pred_value, str):
        s = pred_value.strip()
        if s.startswith("[") and s.endswith("]"):
            pred_value = [float(x) for x in ast.literal_eval(s)]
        else:
            pred_value = float(s)
    if isinstance(answer_value, (list, tuple)):
        if not isinstance(pred_value, (list, tuple)):
            return []
        if len(pred_value) != len(answer_value):
            return []
        return [abs(float(p) - float(a)) for p, a in zip(pred_value, answer_value)]
    return [abs(float(pred_value) - float(answer_value))]


def evaluate_against_answer(
    input_card: Dict[str, Any],
    output_card: Dict[str, Any],
    answer_card: Dict[str, Any],
) -> Dict[str, Any]:
    """Score predictions on exactly the fields masked (null) in input_card.

    Categorical fields contribute to accuracy; numeric fields contribute
    element-wise absolute errors to an overall MAE. Fields whose answer is
    missing are skipped.
    """
    cat_masked = masked_feature_names(input_card, "categorical")
    num_masked = masked_feature_names(input_card, "numeric")

    cat_details = {}
    correct = 0
    total = 0
    for feat in cat_masked:
        answer = answer_card.get("categorical", {}).get(feat)
        pred = output_card.get("categorical", {}).get(feat)
        if answer is None or answer == "":
            continue
        is_correct = pred == answer
        cat_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "correct": bool(is_correct),
        }
        correct += int(is_correct)
        total += 1

    num_details = {}
    abs_errors_all = []
    for feat in num_masked:
        answer = answer_card.get("numeric", {}).get(feat)
        pred = output_card.get("numeric", {}).get(feat)
        errors = numeric_abs_errors(pred, answer)
        if not errors:
            continue
        mae = sum(errors) / len(errors)
        num_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "absolute_error": errors[0] if len(errors) == 1 else errors,
            "mae": mae,
        }
        abs_errors_all.extend(errors)

    return {
        "categorical": {
            "accuracy": None if total == 0 else correct / total,
            "correct": correct,
            "total": total,
            "details": cat_details,
        },
        "numeric": {
            "mae": None if len(abs_errors_all) == 0 else sum(abs_errors_all) / len(abs_errors_all),
            "count": len(abs_errors_all),
            "details": num_details,
        },
        "note": "Metrics are computed only on fields that are null in input_card. Natural missing values \"\" are ignored.",
    }


def acc_path_from_output(output: str) -> Path:
    """Derive the accuracy-report path: <output stem>__acc.json."""
    path = Path(output)
    if path.suffix == ".json":
        base = path.with_suffix("")
    else:
        base = path
    return base.with_name(base.name + "__acc.json")


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------
def main() -> None:
    """Run one inference from a readable input card and write the output card.

    Optionally evaluates against an answer card, writing metrics next to the
    output. Prints a one-line JSON status summary to stdout.
    """
    parser = argparse.ArgumentParser(description="Run SoilFormer inference from a readable input card.")
    parser.add_argument("--input_card", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--answer_card", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--config_data", type=str, default="config/config_data.json")
    parser.add_argument("--config_model", type=str, default="config/config_model.json")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"])
    args = parser.parse_args()

    config_data = load_json(args.config_data)
    config_model = load_json(args.config_model)
    dtype = get_dtype(config_model.get("dtype", "bfloat16"))
    device = resolve_device(args.device)

    meta = load_metadata(config_data)
    input_card = load_card(args.input_card)
    batch = tensorize_card(input_card=input_card, config_data=config_data, meta=meta)
    batch = move_batch_to_device(batch, device=device, dtype=dtype)

    model = load_model(args=args, config_model=config_model, device=device, dtype=dtype)

    with torch.no_grad():
        cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
            cat_local_ids=batch["cat_local_ids"],
            numeric_values_by_nin=batch["numeric_values_by_nin"],
            cat_valid_positions=batch["cat_valid_positions"],
            numeric_valid_positions_by_nin=batch["numeric_valid_positions_by_nin"],
            pixel_values=batch["pixel_values"],
            vision_valid_positions=batch["vision_valid_positions"],
        )

    output_card = decode_outputs(
        cat_logits_padded=cat_logits_padded,
        valid_class_mask=valid_class_mask,
        value_by_nin=value_by_nin,
        meta=meta,
    )
    save_json_pretty(to_jsonable(output_card), Path(args.output))

    result = {"status": "ok", "output": args.output}
    if args.answer_card:
        answer_card = load_card(args.answer_card)
        acc_card = evaluate_against_answer(
            input_card=input_card,
            output_card=output_card,
            answer_card=answer_card,
        )
        acc_path = acc_path_from_output(args.output)
        save_json_pretty(to_jsonable(acc_card), acc_path)
        result["acc_output"] = str(acc_path)

    print(json.dumps(result, ensure_ascii=False))


if __name__ == "__main__":
    main()