| import argparse |
| import ast |
| import json |
| import sys |
| from io import BytesIO |
| from pathlib import Path |
| from typing import Any, Dict, Optional, Tuple |
| from urllib.parse import urljoin |
|
|
| import numpy as np |
| import pandas as pd |
| import requests |
| import torch |
| from PIL import Image |
| from torchvision import transforms |
|
|
| |
| |
| |
# Directory that contains this script; project-relative locations hang off it.
PROJECT_ROOT = Path(__file__).resolve().parent
MODELLING_DIR = PROJECT_ROOT / "modelling"

# Put the modelling directory at the front of the import search path, at most
# once.  NOTE(review): presumably this lets modules inside ``modelling/``
# import each other by bare name -- confirm against that package.
if str(MODELLING_DIR) not in sys.path:
    sys.path.insert(0, str(MODELLING_DIR))
|
|
| from modelling.soilformer import SoilFormer |
| from modelling.utils import get_dtype, load_json |
|
|
|
|
| |
| |
| |
|
|
def load_card(path: str) -> Dict[str, Any]:
    """Read a card from a UTF-8 encoded JSON file.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded top-level JSON object.

    Raises:
        ValueError: If the top-level JSON value is not an object (dict).
    """
    with open(path, "r", encoding="utf-8") as handle:
        card = json.load(handle)

    if isinstance(card, dict):
        return card
    raise ValueError(f"Card must be a JSON object: {path}")
|
|
|
|
def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
    """Write ``obj`` to ``path`` as indented UTF-8 JSON ending in a newline.

    Missing parent directories are created first.  Non-ASCII characters are
    written as-is rather than escaped.

    Args:
        obj: JSON-serialisable mapping to write.
        path: Destination file; overwritten if it already exists.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with path.open("w", encoding="utf-8") as sink:
        json.dump(obj, sink, ensure_ascii=False, indent=2)
        # json.dump does not emit a trailing newline on its own.
        sink.write("\n")
|
|
|
|
def to_jsonable(x: Any) -> Any:
    """Recursively convert NumPy/PyTorch values into plain Python objects.

    * NumPy scalars become Python scalars; NumPy arrays become nested lists.
    * Tensors are detached and copied to CPU; 0-dim tensors become scalars,
      all others nested lists.
    * Dict keys are stringified and dict values converted recursively.
    * Lists and tuples both become lists with converted elements.
    * Anything else is returned unchanged.
    """
    if isinstance(x, np.generic):
        return x.item()
    if isinstance(x, np.ndarray):
        return x.tolist()

    if isinstance(x, torch.Tensor):
        host = x.detach().cpu()
        return host.item() if host.ndim == 0 else host.tolist()

    if isinstance(x, dict):
        converted = {}
        for key, value in x.items():
            converted[str(key)] = to_jsonable(value)
        return converted

    if isinstance(x, (list, tuple)):
        return [to_jsonable(item) for item in x]

    return x
|
|
|
|
| |
| |
| |
|
|
def resolve_device(device_str: str) -> torch.device:
    """Translate a device name into a ``torch.device``.

    The name is compared case-insensitively.  ``"auto"`` prefers CUDA, then
    MPS, then falls back to CPU.

    Args:
        device_str: One of ``"auto"``, ``"cuda"``, ``"mps"`` or ``"cpu"``.

    Raises:
        RuntimeError: If ``"cuda"`` or ``"mps"`` is requested explicitly but
            that backend reports itself unavailable.
        ValueError: If the name is not one of the supported values.
    """
    requested = str(device_str).lower()

    if requested == "cpu":
        return torch.device("cpu")

    if requested == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    if requested == "cuda":
        if torch.cuda.is_available():
            return torch.device("cuda")
        raise RuntimeError("--device cuda requested, but CUDA is not available")

    if requested == "mps":
        if torch.backends.mps.is_available():
            return torch.device("mps")
        raise RuntimeError("--device mps requested, but MPS is not available")

    # The message reports the lower-cased form of the requested name.
    raise ValueError(f"Unsupported device: {requested}")
|
|
|
|
def load_model(args: argparse.Namespace, config_model: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> SoilFormer:
    """Build a ``SoilFormer``, load checkpoint weights and switch to eval mode.

    The checkpoint is read from ``args.checkpoint`` onto CPU and its
    ``"model_state_dict"`` entry is loaded non-strictly.  Keys missing from
    the checkpoint are tolerated only when they start with
    ``"vision_extractor."``.

    Args:
        args: Parsed CLI arguments; only ``args.checkpoint`` is read here.
        config_model: Model configuration passed to ``SoilFormer``.
        device: Device the model is moved to after loading.
        dtype: Dtype the model is cast to after loading.

    Returns:
        The model on ``device`` with ``dtype``, in eval mode.

    Raises:
        RuntimeError: If any key outside ``vision_extractor.`` is missing.
    """
    print("[INFO] Initializing model...")
    model = SoilFormer(config=config_model, device=str(device))

    print("[INFO] Loading checkpoint...")
    # NOTE(review): torch.load unpickles the file; depending on the installed
    # torch version's ``weights_only`` default this can execute arbitrary code
    # from an untrusted checkpoint.  Only load checkpoints you trust.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    state_dict = checkpoint["model_state_dict"]
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # Anything missing outside the vision extractor is treated as fatal; at
    # most the first ten offending keys are shown.
    non_vision_missing = [
        key for key in missing if not key.startswith("vision_extractor.")
    ]
    if non_vision_missing:
        raise RuntimeError(
            f"[ERROR] Missing non-vision keys detected: {non_vision_missing[:10]}"
        )

    print(f"[INFO] Missing keys (vision only): {len(missing)}")
    print(f"[INFO] Unexpected keys: {len(unexpected)}")

    model.to(device=device, dtype=dtype)
    model.eval()
    return model
|
|
|
|
| |
| |
| |
|
|
def load_metadata(config_data: Dict[str, Any]) -> Dict[str, Any]:
    """Load the vocabularies and numeric statistics used for (de)coding cards.

    Reads three files named in ``config_data``: ``cat_vocab_path`` and
    ``numeric_vocab_path`` (via ``load_json``) and ``numeric_stats_path``
    (a CSV with ``column``, ``mean`` and ``std`` columns).

    Returns:
        A dict with:
          * ``cat_vocab`` / ``numeric_vocab``: the loaded vocabularies,
          * ``numeric_stats``: ``{column name: (mean, std)}`` where a zero
            std is replaced by 1.0,
          * ``cat_columns``: categorical column names in vocabulary order,
          * ``cat_mask_local_ids``: each column's ``mask_local_id``, aligned
            with ``cat_columns``,
          * ``id_to_label_by_col``: per column, the inverse of ``label2id``.
    """
    cat_vocab = load_json(config_data["cat_vocab_path"])
    numeric_vocab = load_json(config_data["numeric_vocab_path"])

    numeric_stats = {}
    stats_table = pd.read_csv(config_data["numeric_stats_path"])
    for _, record in stats_table.iterrows():
        center = float(record["mean"])
        scale = float(record["std"])
        # A zero std would divide by zero during normalisation; use unit scale.
        if scale == 0.0:
            scale = 1.0
        numeric_stats[str(record["column"])] = (center, scale)

    cat_columns = list(cat_vocab.keys())
    cat_mask_local_ids = []
    id_to_label_by_col = {}
    for column in cat_columns:
        column_vocab = cat_vocab[column]
        cat_mask_local_ids.append(int(column_vocab["mask_local_id"]))
        id_to_label_by_col[column] = {
            int(local_id): str(label)
            for label, local_id in column_vocab["label2id"].items()
        }

    return {
        "cat_vocab": cat_vocab,
        "numeric_vocab": numeric_vocab,
        "numeric_stats": numeric_stats,
        "cat_columns": cat_columns,
        "cat_mask_local_ids": cat_mask_local_ids,
        "id_to_label_by_col": id_to_label_by_col,
    }
|
|
|
|
| |
| |
| |
|
|
class CenterSquareCrop:
    """Callable transform that crops an image to a centred square.

    The square's side equals the shorter image dimension.  An image that is
    already square is returned as the same object, without cropping.
    """

    def __call__(self, img: Image.Image) -> Image.Image:
        width, height = img.size
        if width == height:
            return img

        side = min(width, height)
        # Only the longer dimension gets a non-zero offset; floor division
        # puts any odd leftover pixel on the right/bottom edge.
        left = (width - side) // 2
        top = (height - side) // 2
        return img.crop((left, top, left + side, top + side))
|
|
|
|
def build_image_transform(image_size: int):
    """Compose the image preprocessing pipeline.

    Steps, in order: centre square crop, resize to
    ``(image_size, image_size)``, conversion to a tensor with torchvision's
    ``ToTensor``.  No normalisation step is included.

    Args:
        image_size: Side length of the resized square image.

    Returns:
        A ``transforms.Compose`` applying the three steps.
    """
    steps = [
        CenterSquareCrop(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
|
|
|
|
def join_photo_root(photo_root: str, relative_path: str) -> str:
    """Join a photo root (local directory or http(s) URL) with a relative path.

    Trailing slashes on ``photo_root`` and leading slashes on
    ``relative_path`` are collapsed so exactly one ``/`` separates them.

    Leading slashes are stripped for URL roots as well as local ones.
    Otherwise ``urljoin`` would treat ``"/a.jpg"`` as an absolute path and
    ``"//host/a.jpg"`` as a network-path reference, silently discarding the
    root's path (or host) instead of resolving beneath it.

    Args:
        photo_root: Base directory or ``http://`` / ``https://`` URL.
        relative_path: Path of the image relative to ``photo_root``.

    Returns:
        The combined path or URL as a string.
    """
    base = photo_root.rstrip("/") + "/"
    suffix = relative_path.lstrip("/")
    if photo_root.startswith(("http://", "https://")):
        return urljoin(base, suffix)
    return base + suffix
|
|
|
|
def load_image_tensor(image_path: str, image_size: int) -> torch.Tensor:
    """Load an image from disk or over HTTP(S) and preprocess it to a tensor.

    Paths starting with ``http://`` or ``https://`` are downloaded with
    ``requests`` (timeout ``(3, 10)``); a non-success HTTP status raises via
    ``raise_for_status``.  Any other path is opened as a local file.  The
    image is converted to RGB and passed through ``build_image_transform``.

    Args:
        image_path: Local file path or http(s) URL.
        image_size: Side length handed to ``build_image_transform``.

    Returns:
        The transformed image tensor.
    """
    if image_path.startswith(("http://", "https://")):
        response = requests.get(image_path, timeout=(3, 10))
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    rgb_image = Image.open(source).convert("RGB")
    preprocess = build_image_transform(image_size)
    return preprocess(rgb_image)
|
|
|
|
| |
| |
| |
|
|
def is_masked_or_missing(value: Any) -> bool:
    """Return True when a card value is ``None`` or the empty string."""
    if value is None:
        return True
    return value == ""
|
|
|
|
def parse_numeric_card_value(value: Any, n_in: int) -> Tuple[list[float], bool]:
    """Parse one numeric card entry into ``n_in`` floats plus a validity flag.

    ``None`` and ``""`` mean "no value": a zero vector of length ``n_in`` is
    returned with the flag set to False.

    Accepted encodings of a present value:
      * ``n_in == 1``: a scalar, a numeric string, a length-1 list/tuple, or
        the Python-literal string of one (for example ``"[2.5]"``);
      * ``n_in > 1``: a list/tuple of exactly ``n_in`` numbers, or its
        Python-literal string.

    Both branches accept tuples and bracketed strings, so a single-element
    vector written as ``"[2.5]"`` or ``(2.5,)`` is parsed rather than failing
    inside ``float()``.

    Args:
        value: Raw entry from the card.
        n_in: Number of floats expected for this feature.

    Returns:
        ``(values, is_valid)`` where ``values`` has length ``n_in``.

    Raises:
        ValueError: If the value has the wrong length, is not list-like when
            a vector is required, or cannot be converted to float.
        SyntaxError: Propagated from ``ast.literal_eval`` for malformed
            literal strings.
    """
    if value is None or value == "":
        return [0.0] * n_in, False

    candidate = value
    if isinstance(value, str):
        text = value.strip()
        # Vector features are always literal-encoded; scalar features only
        # when the text is visibly bracketed, so "1.5" still goes to float().
        if n_in != 1 or text.startswith(("[", "(")):
            candidate = ast.literal_eval(text)

    if n_in == 1:
        if isinstance(candidate, (list, tuple)):
            if len(candidate) != 1:
                raise ValueError(f"Expected scalar or length-1 list for n_in=1, got {value!r}")
            return [float(candidate[0])], True
        return [float(candidate)], True

    if not isinstance(candidate, (list, tuple)):
        raise ValueError(f"Expected list-like numeric vector for n_in={n_in}, got {value!r}")
    if len(candidate) != n_in:
        raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(candidate)}")
    return [float(item) for item in candidate], True
|
|
|
|
def _encode_categorical(categorical: Dict[str, Any], meta: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map categorical card values to local ids plus a validity mask."""
    local_ids = []
    valid_flags = []
    for column, mask_id in zip(meta["cat_columns"], meta["cat_mask_local_ids"]):
        raw = categorical.get(column, "")
        if is_masked_or_missing(raw):
            # Masked or absent entries are fed to the model as the mask id.
            local_ids.append(mask_id)
            valid_flags.append(False)
            continue
        label2id = meta["cat_vocab"][column]["label2id"]
        if raw not in label2id:
            raise KeyError(f"Unknown categorical value: column={column}, value={raw!r}")
        local_ids.append(int(label2id[raw]))
        valid_flags.append(True)

    # Wrapping in a list adds a leading batch dimension of size 1.
    return (
        torch.tensor([local_ids], dtype=torch.long),
        torch.tensor([valid_flags], dtype=torch.bool),
    )


def _encode_numeric(numeric: Dict[str, Any], meta: Dict[str, Any]) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]:
    """Parse and z-normalise numeric card values, grouped by ``n_in``."""
    values_by_nin = {}
    valid_by_nin = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        group_values = []
        group_valid = []
        for raw_name in group["feature_names"]:
            name = str(raw_name)
            parsed, is_valid = parse_numeric_card_value(numeric.get(name, ""), n_in)
            if is_valid:
                # Missing entries keep their zero placeholder un-normalised.
                mean, std = meta["numeric_stats"][name]
                parsed = [(item - mean) / std for item in parsed]
            group_values.append(parsed)
            group_valid.append(is_valid)
        values_by_nin[n_in] = torch.tensor([group_values], dtype=torch.float32)
        valid_by_nin[n_in] = torch.tensor([group_valid], dtype=torch.bool)
    return values_by_nin, valid_by_nin


def _encode_vision(vision: Dict[str, Any], config_data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load the card's image, or fall back to an all-zero, invalid input."""
    image_size = int(config_data["image_size"])
    suffix = vision.get("image_path_suffix", "")

    if suffix is not None and suffix != "":
        image_path = join_photo_root(str(config_data["photo_root"]), str(suffix))
        try:
            image = load_image_tensor(image_path, image_size=image_size)
            return image.unsqueeze(0), torch.tensor([True], dtype=torch.bool)
        except Exception as exc:
            # Deliberate best-effort: an unreadable image degrades to "no image".
            print(f"[WARN] Could not load image; using zero vision input: {exc}")

    return (
        torch.zeros(1, 3, image_size, image_size, dtype=torch.float32),
        torch.tensor([False], dtype=torch.bool),
    )


def tensorize_card(
    input_card: Dict[str, Any],
    config_data: Dict[str, Any],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    """Convert one input card into a batch-of-one dictionary of model inputs.

    Args:
        input_card: Card with optional ``categorical``, ``numeric`` and
            ``vision`` sections.
        config_data: Data configuration; ``image_size`` is always read and
            ``photo_root`` only when the card names an image.
        meta: Metadata as produced by ``load_metadata``.

    Returns:
        A dict with ``cat_local_ids``, ``cat_valid_positions``,
        ``numeric_values_by_nin``, ``numeric_valid_positions_by_nin``,
        ``pixel_values`` and ``vision_valid_positions``.

    Raises:
        ValueError: If the ``categorical`` or ``numeric`` section is not a
            dict, or a numeric value cannot be parsed.
        KeyError: If a categorical value is not in its column's vocabulary.
    """
    categorical = input_card.get("categorical", {})
    numeric = input_card.get("numeric", {})
    vision = input_card.get("vision", {})

    if not isinstance(categorical, dict):
        raise ValueError("input_card['categorical'] must be an object")
    if not isinstance(numeric, dict):
        raise ValueError("input_card['numeric'] must be an object")
    if not isinstance(vision, dict):
        # A malformed vision section is tolerated and treated as "no image".
        vision = {}

    cat_local_ids, cat_valid_positions = _encode_categorical(categorical, meta)
    numeric_values_by_nin, numeric_valid_positions_by_nin = _encode_numeric(numeric, meta)
    pixel_values, vision_valid_positions = _encode_vision(vision, config_data)

    return {
        "cat_local_ids": cat_local_ids,
        "cat_valid_positions": cat_valid_positions,
        "numeric_values_by_nin": numeric_values_by_nin,
        "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
        "pixel_values": pixel_values,
        "vision_valid_positions": vision_valid_positions,
    }
|
|
|
|
def move_batch_to_device(batch: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> Dict[str, Any]:
    """Return a copy of ``batch`` with its tensors moved to ``device``.

    Floating-point tensors are also cast to ``dtype``; integer and boolean
    tensors keep their dtype.  Tensors are handled at the top level and one
    level down inside dict values; every other value is passed through
    unchanged.

    Args:
        batch: Mapping of input names to tensors, dicts of tensors, or other
            values.
        device: Target device.
        dtype: Target dtype for floating-point tensors.

    Returns:
        A new dict; ``batch`` itself is not modified.
    """

    def _place(item: Any) -> Any:
        if not isinstance(item, torch.Tensor):
            return item
        if item.dtype.is_floating_point:
            return item.to(device=device, dtype=dtype)
        return item.to(device=device)

    moved = {}
    for name, entry in batch.items():
        if isinstance(entry, dict):
            moved[name] = {inner: _place(item) for inner, item in entry.items()}
        else:
            moved[name] = _place(entry)
    return moved
|
|
|
|
| |
| |
| |
|
|
def denormalize_numeric(values_z: list[float], mean: float, std: float) -> list[float]:
    """Undo z-score normalisation: return ``v * std + mean`` for each value.

    Args:
        values_z: Normalised values.
        mean: Mean used during normalisation.
        std: Standard deviation used during normalisation.

    Returns:
        A new list of de-normalised floats, in input order.
    """
    restored = []
    for z_value in values_z:
        restored.append(float(z_value) * float(std) + float(mean))
    return restored
|
|
|
|
def decode_outputs(
    cat_logits_padded: torch.Tensor,
    valid_class_mask: torch.Tensor,
    value_by_nin: Dict[int, torch.Tensor],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    """Turn raw model outputs for the first batch item into a readable card.

    Categorical: for each column, the number of valid classes is the count of
    set entries in that column's row of ``valid_class_mask``; only that many
    leading logits are considered, and the arg-max id is mapped to its label
    (falling back to the id as a string when no label is known).

    Numeric: predictions are de-normalised with the per-feature mean/std from
    ``meta["numeric_stats"]``; features with ``n_in == 1`` are emitted as a
    scalar, others as a list.

    Args:
        cat_logits_padded: Logits indexed as ``[batch, column, class]``.
        valid_class_mask: Mask indexed by column first.
        value_by_nin: Numeric predictions keyed by ``n_in``; indexed as
            ``[batch, feature]`` here.
        meta: Metadata as produced by ``load_metadata``.

    Returns:
        ``{"categorical": {...}, "numeric": {...}}``.
    """
    logits_cpu = cat_logits_padded.detach().float().cpu()
    class_mask = valid_class_mask.detach().cpu().bool()

    categorical_out = {}
    for col_idx, column in enumerate(meta["cat_columns"]):
        n_classes = int(class_mask[col_idx].sum().item())
        # Assumes the valid classes occupy the leading positions of the
        # padded class axis -- TODO confirm against the model.
        probs = torch.softmax(logits_cpu[0, col_idx, :n_classes], dim=-1)
        pred_id = int(torch.argmax(probs).item())
        labels = meta["id_to_label_by_col"][column]
        categorical_out[column] = labels.get(pred_id, str(pred_id))

    numeric_out = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        group_preds = value_by_nin[n_in].detach().float().cpu()[0]
        for feat_idx, raw_name in enumerate(group["feature_names"]):
            name = str(raw_name)
            mean, std = meta["numeric_stats"][name]
            restored = denormalize_numeric(group_preds[feat_idx].tolist(), mean, std)
            numeric_out[name] = restored[0] if n_in == 1 else restored

    return {
        "categorical": categorical_out,
        "numeric": numeric_out,
    }
|
|
|
|
| |
| |
| |
|
|
def masked_feature_names(input_card: Dict[str, Any], section: str) -> list[str]:
    """List the keys of one card section whose value is ``None``.

    Only ``None`` counts as masked; empty strings are not included.  A
    missing section, or one that is not a dict, yields an empty list.

    Args:
        input_card: The input card.
        section: Section name, for example ``"categorical"``.
    """
    section_values = input_card.get(section, {})
    if isinstance(section_values, dict):
        return [name for name, entry in section_values.items() if entry is None]
    return []
|
|
|
|
def _parse_error_operand(value: Any) -> Any:
    """Decode a string operand; return ``None`` for blank text.

    Non-strings pass through unchanged.  ``"[...]"`` text becomes a list of
    floats, any other non-blank text a single float.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if text == "":
        return None
    if text.startswith("[") and text.endswith("]"):
        return [float(item) for item in ast.literal_eval(text)]
    return float(text)


def numeric_abs_errors(pred_value: Any, answer_value: Any) -> list[float]:
    """Element-wise absolute errors between a prediction and its answer.

    Either side may be a number, a list/tuple of numbers, or the string form
    of either (``"1.5"`` or ``"[1.0, 2.0]"``).

    An empty list means "not comparable" and is returned when either side is
    ``None``, empty or whitespace-only, when one side is a vector and the
    other a scalar, or when two vectors differ in length.  The blank and
    shape checks are symmetric in prediction and answer, so a vector
    prediction against a scalar answer, or a whitespace-only prediction, is
    skipped rather than raising.

    Args:
        pred_value: Predicted value.
        answer_value: Ground-truth value.

    Returns:
        One absolute error per compared element, or ``[]``.

    Raises:
        ValueError: If a non-blank string cannot be parsed as a number.
    """
    if answer_value is None or answer_value == "":
        return []
    if pred_value is None or pred_value == "":
        return []

    answer = _parse_error_operand(answer_value)
    if answer is None:
        return []
    pred = _parse_error_operand(pred_value)
    if pred is None:
        return []

    answer_is_vector = isinstance(answer, (list, tuple))
    pred_is_vector = isinstance(pred, (list, tuple))
    if answer_is_vector != pred_is_vector:
        # Scalar-vs-vector cannot be compared element-wise in either direction.
        return []

    if answer_is_vector:
        if len(pred) != len(answer):
            return []
        return [abs(float(p) - float(a)) for p, a in zip(pred, answer)]

    return [abs(float(pred) - float(answer))]
|
|
|
|
def evaluate_against_answer(
    input_card: Dict[str, Any],
    output_card: Dict[str, Any],
    answer_card: Dict[str, Any],
) -> Dict[str, Any]:
    """Score predictions on the fields that were masked in the input card.

    Only fields whose input value is ``None`` are scored.  Categorical
    fields with a ``None`` or empty answer are skipped, as are numeric
    fields for which ``numeric_abs_errors`` returns nothing.

    Args:
        input_card: Card given to the model; decides which fields count.
        output_card: Decoded predictions.
        answer_card: Ground-truth card.

    Returns:
        A dict with a ``categorical`` section (accuracy, correct, total,
        per-field details), a ``numeric`` section (overall MAE pooled over
        all error elements, element count, per-field details) and a
        ``note``.  Accuracy and MAE are ``None`` when nothing was scored.
    """
    cat_masked = masked_feature_names(input_card, "categorical")
    num_masked = masked_feature_names(input_card, "numeric")

    cat_answers = answer_card.get("categorical", {})
    cat_predictions = output_card.get("categorical", {})
    cat_details = {}
    for feat in cat_masked:
        answer = cat_answers.get(feat)
        pred = cat_predictions.get(feat)
        if answer is None or answer == "":
            continue
        cat_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "correct": bool(pred == answer),
        }
    # Masked names are dict keys and therefore unique: one entry per scored field.
    total = len(cat_details)
    correct = sum(1 for detail in cat_details.values() if detail["correct"])

    num_answers = answer_card.get("numeric", {})
    num_predictions = output_card.get("numeric", {})
    num_details = {}
    pooled_errors = []
    for feat in num_masked:
        answer = num_answers.get(feat)
        pred = num_predictions.get(feat)
        errors = numeric_abs_errors(pred, answer)
        if not errors:
            continue
        num_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "absolute_error": errors[0] if len(errors) == 1 else errors,
            "mae": sum(errors) / len(errors),
        }
        pooled_errors.extend(errors)

    accuracy = correct / total if total else None
    overall_mae = sum(pooled_errors) / len(pooled_errors) if pooled_errors else None

    return {
        "categorical": {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "details": cat_details,
        },
        "numeric": {
            "mae": overall_mae,
            "count": len(pooled_errors),
            "details": num_details,
        },
        "note": "Metrics are computed only on fields that are null in input_card. Natural missing values \"\" are ignored.",
    }
|
|
|
|
def acc_path_from_output(output: str) -> Path:
    """Derive the metrics file path that sits next to the output file.

    A trailing ``.json`` suffix is dropped, then ``__acc.json`` is appended
    to the file name; any other suffix is kept as part of the name.

    Args:
        output: Path of the prediction output file.

    Returns:
        The sibling path, for example ``out/pred.json`` gives
        ``out/pred__acc.json``.
    """
    out_path = Path(output)
    stem_path = out_path.with_suffix("") if out_path.suffix == ".json" else out_path
    return stem_path.with_name(stem_path.name + "__acc.json")
|
|
|
|
| |
| |
| |
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference script."""
    parser = argparse.ArgumentParser(description="Run SoilFormer inference from a readable input card.")
    parser.add_argument("--input_card", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--answer_card", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--config_data", type=str, default="config/config_data.json")
    parser.add_argument("--config_model", type=str, default="config/config_model.json")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"])
    return parser


def main() -> None:
    """Run single-card inference end to end.

    Loads the configs, metadata and input card, tensorises the card, runs
    the model without gradients, writes the decoded output card to
    ``--output`` and, when ``--answer_card`` is given, writes metrics next
    to it.  A one-line JSON status is printed at the end.
    """
    args = _build_arg_parser().parse_args()

    config_data = load_json(args.config_data)
    config_model = load_json(args.config_model)
    dtype = get_dtype(config_model.get("dtype", "bfloat16"))
    device = resolve_device(args.device)

    meta = load_metadata(config_data)
    input_card = load_card(args.input_card)
    batch = move_batch_to_device(
        tensorize_card(input_card=input_card, config_data=config_data, meta=meta),
        device=device,
        dtype=dtype,
    )

    model = load_model(args=args, config_model=config_model, device=device, dtype=dtype)

    forward_kwargs = {
        "cat_local_ids": batch["cat_local_ids"],
        "numeric_values_by_nin": batch["numeric_values_by_nin"],
        "cat_valid_positions": batch["cat_valid_positions"],
        "numeric_valid_positions_by_nin": batch["numeric_valid_positions_by_nin"],
        "pixel_values": batch["pixel_values"],
        "vision_valid_positions": batch["vision_valid_positions"],
    }
    with torch.no_grad():
        # The model returns six values; only three are needed for decoding.
        cat_logits_padded, _cat_s, valid_class_mask, value_by_nin, _s_by_nin, _ = model(**forward_kwargs)

    output_card = decode_outputs(
        cat_logits_padded=cat_logits_padded,
        valid_class_mask=valid_class_mask,
        value_by_nin=value_by_nin,
        meta=meta,
    )
    save_json_pretty(to_jsonable(output_card), Path(args.output))

    result = {"status": "ok", "output": args.output}

    if args.answer_card:
        acc_card = evaluate_against_answer(
            input_card=input_card,
            output_card=output_card,
            answer_card=load_card(args.answer_card),
        )
        acc_path = acc_path_from_output(args.output)
        save_json_pretty(to_jsonable(acc_card), acc_path)
        result["acc_output"] = str(acc_path)

    print(json.dumps(result, ensure_ascii=False))


if __name__ == "__main__":
    main()
|
|