# soilformer / inference_predict_output_card.py
# Kuangdai
# Initial release of SoilFormer
# 6fb6c07
import argparse
import ast
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urljoin
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torchvision import transforms
# The script is intended to live one level above ./modelling.
# modelling/ modules still contain some legacy absolute imports, so expose the
# modelling directory on sys.path as well.
PROJECT_ROOT = Path(__file__).resolve().parent
MODELLING_DIR = PROJECT_ROOT / "modelling"
if str(MODELLING_DIR) not in sys.path:
    # Prepend (index 0) so modules inside modelling/ are found first by the
    # legacy absolute imports mentioned above.
    sys.path.insert(0, str(MODELLING_DIR))
# These imports must run after the sys.path adjustment above, hence E402.
from modelling.soilformer import SoilFormer  # noqa: E402
from modelling.utils import get_dtype, load_json  # noqa: E402
# -----------------------------------------------------------------------------
# JSON helpers
# -----------------------------------------------------------------------------
def load_card(path: str) -> Dict[str, Any]:
    """Read a UTF-8 JSON card from ``path`` and return it.

    Raises:
        ValueError: if the top-level JSON value is not an object.
    """
    card = json.loads(Path(path).read_text(encoding="utf-8"))
    if isinstance(card, dict):
        return card
    raise ValueError(f"Card must be a JSON object: {path}")
def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
    """Write ``obj`` as UTF-8, 2-space-indented JSON followed by a newline.

    Parent directories are created as needed. The object is serialized in full
    before the file is touched, so a serialization error (e.g. a value that is
    not JSON-serializable) leaves any existing file intact instead of truncated
    or half-written. ``path`` may also be passed as a ``str``.
    """
    path = Path(path)
    text = json.dumps(obj, ensure_ascii=False, indent=2) + "\n"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")
def to_jsonable(x: Any) -> Any:
    """Recursively convert numpy / torch values into plain JSON-friendly Python.

    * numpy scalars become Python scalars and numpy arrays become nested lists;
    * torch tensors are detached and copied to CPU, then a 0-d tensor becomes a
      scalar and anything else a nested list;
    * dicts get ``str`` keys with converted values; lists and tuples become lists;
    * every other value is returned unchanged.
    """
    if isinstance(x, np.generic):
        return x.item()
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, torch.Tensor):
        host = x.detach().cpu()
        return host.item() if host.ndim == 0 else host.tolist()
    if isinstance(x, dict):
        return {str(key): to_jsonable(val) for key, val in x.items()}
    if isinstance(x, (list, tuple)):
        return [to_jsonable(item) for item in x]
    return x
# -----------------------------------------------------------------------------
# Runtime / model loading
# -----------------------------------------------------------------------------
def resolve_device(device_str: str) -> torch.device:
    """Map a ``--device`` string (case-insensitive) to a ``torch.device``.

    ``"auto"`` prefers CUDA, then MPS, then CPU. ``"cuda"`` and ``"mps"`` are
    honoured only when that backend reports itself available.

    Raises:
        RuntimeError: an explicitly requested accelerator is unavailable.
        ValueError: the string is not one of auto/cuda/mps/cpu.
    """
    requested = str(device_str).lower()

    if requested == "auto":
        if torch.cuda.is_available():
            chosen = "cuda"
        elif torch.backends.mps.is_available():
            chosen = "mps"
        else:
            chosen = "cpu"
        return torch.device(chosen)

    if requested not in ("cuda", "mps", "cpu"):
        raise ValueError(f"Unsupported device: {requested}")
    if requested == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("--device cuda requested, but CUDA is not available")
    if requested == "mps" and not torch.backends.mps.is_available():
        raise RuntimeError("--device mps requested, but MPS is not available")
    return torch.device(requested)
def load_model(args: argparse.Namespace, config_model: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> SoilFormer:
    """Build SoilFormer, load checkpoint weights, and return it in eval mode on ``device``/``dtype``.

    Only ``args.checkpoint`` is read from ``args``. The checkpoint may be a
    training checkpoint dict holding ``"model_state_dict"`` or a bare state dict
    (as produced by ``torch.save(model.state_dict(), ...)``).

    Weights are loaded with ``strict=False``. Missing keys are tolerated only
    under the ``vision_extractor.`` prefix (presumably vision weights supplied by
    the model itself — confirm in SoilFormer). Unexpected keys are only counted.

    Raises:
        RuntimeError: if any non-vision key is missing from the checkpoint.
    """
    print("[INFO] Initializing model...")
    model = SoilFormer(config=config_model, device=str(device))

    print("[INFO] Loading checkpoint...")
    # NOTE(review): torch.load unpickles arbitrary objects unless weights_only=True;
    # only load trusted checkpoints. (Recent PyTorch defaults weights_only=True,
    # which may reject checkpoints that store non-tensor Python objects.)
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        # Bare state dict: the mapping itself holds the parameter tensors.
        state_dict = checkpoint

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    non_vision_missing = [k for k in missing if not k.startswith("vision_extractor.")]
    if non_vision_missing:
        raise RuntimeError(
            f"[ERROR] Missing non-vision keys detected: {non_vision_missing[:10]}"
        )

    print(f"[INFO] Missing keys (vision only): {len(missing)}")
    print(f"[INFO] Unexpected keys: {len(unexpected)}")

    model.to(device=device, dtype=dtype)
    model.eval()
    return model
# -----------------------------------------------------------------------------
# Metadata loading
# -----------------------------------------------------------------------------
def load_metadata(config_data: Dict[str, Any]) -> Dict[str, Any]:
    """Load the vocabularies and normalization statistics named in the data config.

    Reads ``cat_vocab_path`` and ``numeric_vocab_path`` via ``load_json``, and
    ``numeric_stats_path`` as a CSV with ``column``, ``mean`` and ``std`` columns.

    Returns a dict with:
        cat_vocab, numeric_vocab: the loaded vocab objects, unchanged.
        numeric_stats: ``{feature_name: (mean, std)}`` used for z-scoring.
        cat_columns: categorical column names, in ``cat_vocab`` key order.
        cat_mask_local_ids: each column's ``mask_local_id``, aligned with cat_columns.
        id_to_label_by_col: per column, the inverse of ``label2id`` (int id -> str label).
    """
    cat_vocab = load_json(config_data["cat_vocab_path"])
    numeric_vocab = load_json(config_data["numeric_vocab_path"])
    stats_df = pd.read_csv(config_data["numeric_stats_path"])

    numeric_stats: Dict[str, Tuple[float, float]] = {}
    # Column-wise zip instead of iterrows(): iterrows builds one Series per row,
    # which does not preserve per-column dtypes (e.g. integer-like names can be
    # upcast to floats and stringified as "1.0").
    for col, raw_mean, raw_std in zip(stats_df["column"], stats_df["mean"], stats_df["std"]):
        mean = float(raw_mean)
        std = float(raw_std)
        # Guard degenerate statistics (zero or non-finite, e.g. NaN) so that
        # z-scoring and de-normalization never divide by zero or spread NaN.
        if not np.isfinite(mean):
            mean = 0.0
        if not np.isfinite(std) or std == 0.0:
            std = 1.0
        numeric_stats[str(col)] = (mean, std)

    cat_columns = list(cat_vocab.keys())
    cat_mask_local_ids = [int(cat_vocab[col]["mask_local_id"]) for col in cat_columns]
    id_to_label_by_col = {
        col: {int(v): str(k) for k, v in cat_vocab[col]["label2id"].items()}
        for col in cat_columns
    }

    return {
        "cat_vocab": cat_vocab,
        "numeric_vocab": numeric_vocab,
        "numeric_stats": numeric_stats,
        "cat_columns": cat_columns,
        "cat_mask_local_ids": cat_mask_local_ids,
        "id_to_label_by_col": id_to_label_by_col,
    }
# -----------------------------------------------------------------------------
# Image handling, matching loader.py behavior
# -----------------------------------------------------------------------------
class CenterSquareCrop:
    """Crop a PIL image to its largest centred square.

    The shorter side is kept whole and the longer side is trimmed from both ends.
    Because the offset uses floor division, an odd leftover pixel is removed from
    the right/bottom edge. Square images are returned as-is (same object).
    """

    def __call__(self, img: Image.Image) -> Image.Image:
        width, height = img.size
        if width == height:
            return img
        side = min(width, height)
        left = (width - side) // 2
        top = (height - side) // 2
        return img.crop((left, top, left + side, top + side))
def build_image_transform(image_size: int):
    """Return the image pipeline: centre square crop -> resize to ``image_size`` x ``image_size`` -> tensor.

    Per this section's banner, it is meant to mirror loader.py's preprocessing.
    No normalization is applied beyond ``ToTensor``.
    """
    steps = [
        CenterSquareCrop(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
def join_photo_root(photo_root: str, relative_path: str) -> str:
    """Join a photo root (local directory or http(s) URL) with a card's image suffix.

    Leading slashes on ``relative_path`` are stripped in both branches, so the
    suffix always resolves *under* ``photo_root``. Previously, the URL branch
    passed "/a.jpg" straight to ``urljoin``, which replaces the whole root path.
    """
    suffix = relative_path.lstrip("/")
    if photo_root.startswith(("http://", "https://")):
        return urljoin(photo_root.rstrip("/") + "/", suffix)
    return photo_root.rstrip("/") + "/" + suffix
def load_image_tensor(image_path: str, image_size: int) -> torch.Tensor:
    """Load an image from a local path or http(s) URL and apply ``build_image_transform``.

    The image is converted to RGB before transforming. Remote fetches use
    ``(3, 10)`` second (connect, read) timeouts and raise on HTTP error status.
    """
    if image_path.startswith(("http://", "https://")):
        resp = requests.get(image_path, timeout=(3, 10))
        resp.raise_for_status()
        source: Any = BytesIO(resp.content)
    else:
        source = image_path
    # Image.open is lazy and keeps the underlying file open; the context manager
    # guarantees it is closed. convert() returns a new image, so `img` remains
    # valid after the source is closed.
    with Image.open(source) as opened:
        img = opened.convert("RGB")
    return build_image_transform(image_size)(img)
# -----------------------------------------------------------------------------
# Tensorization from readable input card
# -----------------------------------------------------------------------------
def is_masked_or_missing(value: Any) -> bool:
    """True for a masked card field (``None``) or a naturally missing one (``""``)."""
    if value is None:
        return True
    return value == ""
def _literal_eval_numeric(text: str) -> Any:
    """Evaluate a Python literal from a card string, reporting failures as ValueError."""
    try:
        return ast.literal_eval(text.strip())
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"Could not parse numeric literal {text!r}: {exc}") from exc


def parse_numeric_card_value(value: Any, n_in: int) -> Tuple[list[float], bool]:
    """Parse one numeric card entry into ``n_in`` floats plus a validity flag.

    ``None`` (masked) and ``""`` (naturally missing) yield ``([0.0] * n_in, False)``.
    Otherwise the value must supply exactly ``n_in`` numbers:

    * ``n_in == 1``: a scalar, a numeric string such as ``"3.5"``, a length-1
      list/tuple, or a bracketed string such as ``"[3.5]"``.
    * ``n_in > 1``: a list/tuple, or a string holding a literal sequence, parsed
      with ``ast.literal_eval`` (which evaluates literals only, never code).

    Raises:
        ValueError: wrong length, non-sequence input for ``n_in > 1``, or an
            unparseable literal string.
    """
    if value is None or value == "":
        return [0.0] * n_in, False

    if n_in == 1:
        candidate = value
        # Accept "[x]" strings too, matching how n_in > 1 strings are parsed.
        if isinstance(candidate, str) and candidate.strip().startswith("["):
            candidate = _literal_eval_numeric(candidate)
        if isinstance(candidate, (list, tuple)):
            if len(candidate) != 1:
                raise ValueError(f"Expected scalar or length-1 list for n_in=1, got {value!r}")
            return [float(candidate[0])], True
        return [float(candidate)], True

    parsed = _literal_eval_numeric(value) if isinstance(value, str) else value
    if not isinstance(parsed, (list, tuple)):
        raise ValueError(f"Expected list-like numeric vector for n_in={n_in}, got {value!r}")
    if len(parsed) != n_in:
        raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(parsed)}")
    return [float(v) for v in parsed], True
def _tensorize_categorical(
    categorical: Dict[str, Any],
    meta: Dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode readable categorical labels as local ids ``[1, M]`` plus a validity mask ``[1, M]``."""
    cat_ids = []
    cat_valids = []
    for col, mask_id in zip(meta["cat_columns"], meta["cat_mask_local_ids"]):
        value = categorical.get(col, "")
        if is_masked_or_missing(value):
            # Masked (null) and naturally missing ("") both use the mask id and are invalid.
            cat_ids.append(mask_id)
            cat_valids.append(False)
            continue
        label2id = meta["cat_vocab"][col]["label2id"]
        key = value
        if key not in label2id and not isinstance(value, str):
            # Vocab labels are presumably JSON object keys (always strings), so a
            # card value written as a JSON number (e.g. 3) is retried as "3".
            key = str(value)
        if key not in label2id:
            raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
        cat_ids.append(int(label2id[key]))
        cat_valids.append(True)
    return (
        torch.tensor([cat_ids], dtype=torch.long),
        torch.tensor([cat_valids], dtype=torch.bool),
    )


def _tensorize_numeric(
    numeric: Dict[str, Any],
    meta: Dict[str, Any],
) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]:
    """Z-score raw numeric values per feature group, keyed by the group's ``n_in``.

    Each group yields values ``[1, V, n_in]`` (float32) and validity ``[1, V]`` (bool),
    where V is the number of features in the group. The result dicts are keyed by
    ``n_in``, so a later group with the same ``n_in`` would overwrite an earlier one.
    """
    values_by_nin: Dict[int, torch.Tensor] = {}
    valid_by_nin: Dict[int, torch.Tensor] = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        values = []
        valids = []
        for feat in group["feature_names"]:
            feat = str(feat)
            parsed, is_valid = parse_numeric_card_value(numeric.get(feat, ""), n_in)
            if is_valid:
                mean, std = meta["numeric_stats"][feat]
                parsed = [(v - mean) / std for v in parsed]
            values.append(parsed)
            valids.append(is_valid)
        values_by_nin[n_in] = torch.tensor([values], dtype=torch.float32)
        valid_by_nin[n_in] = torch.tensor([valids], dtype=torch.bool)
    return values_by_nin, valid_by_nin


def _tensorize_vision(
    vision: Dict[str, Any],
    config_data: Dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load the card's image into ``[1, 3, S, S]`` pixels plus a ``[1]`` validity flag.

    The card stores only a path suffix, which is joined onto ``config_data["photo_root"]``.
    A missing suffix or any load failure yields an all-zero image marked invalid.
    """
    image_size = int(config_data["image_size"])
    image_path_suffix = vision.get("image_path_suffix", "")
    if image_path_suffix is not None and image_path_suffix != "":
        image_path = join_photo_root(str(config_data["photo_root"]), str(image_path_suffix))
        try:
            image = load_image_tensor(image_path, image_size=image_size)
        except Exception as exc:  # deliberate best-effort: fall back to "no image"
            print(f"[WARN] Could not load image; using zero vision input: {exc}")
        else:
            return image.unsqueeze(0), torch.tensor([True], dtype=torch.bool)
    pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
    return pixel_values, torch.tensor([False], dtype=torch.bool)


def tensorize_card(
    input_card: Dict[str, Any],
    config_data: Dict[str, Any],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    """Convert a readable input card into a batch-of-one model input dict.

    Card sections: ``categorical`` (label per column), ``numeric`` (raw values in
    actual units) and ``vision`` (``image_path_suffix``). In the first two, ``null``
    means masked and ``""`` means naturally missing; both become invalid positions.

    Returns a dict with keys:
        cat_local_ids, cat_valid_positions: ``[1, M]`` long / bool, M = len(meta["cat_columns"]).
        numeric_values_by_nin, numeric_valid_positions_by_nin: ``{n_in: [1, V, n_in]}`` float32 / ``{n_in: [1, V]}`` bool.
        pixel_values: ``[1, 3, S, S]``, S = ``config_data["image_size"]``.
        vision_valid_positions: ``[1]`` bool.

    Raises:
        ValueError: ``categorical`` or ``numeric`` is present but not an object,
            or a numeric value is malformed.
        KeyError: a categorical label is not in the vocabulary.
    """
    categorical = input_card.get("categorical", {})
    numeric = input_card.get("numeric", {})
    vision = input_card.get("vision", {})

    if not isinstance(categorical, dict):
        raise ValueError("input_card['categorical'] must be an object")
    if not isinstance(numeric, dict):
        raise ValueError("input_card['numeric'] must be an object")
    if not isinstance(vision, dict):
        # Vision is optional; a malformed section is treated as "no image".
        vision = {}

    cat_local_ids, cat_valid_positions = _tensorize_categorical(categorical, meta)
    numeric_values_by_nin, numeric_valid_positions_by_nin = _tensorize_numeric(numeric, meta)
    pixel_values, vision_valid_positions = _tensorize_vision(vision, config_data)

    return {
        "cat_local_ids": cat_local_ids,
        "cat_valid_positions": cat_valid_positions,
        "numeric_values_by_nin": numeric_values_by_nin,
        "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
        "pixel_values": pixel_values,
        "vision_valid_positions": vision_valid_positions,
    }
def move_batch_to_device(batch: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> Dict[str, Any]:
    """Return a copy of ``batch`` with its tensors moved to ``device``.

    Floating-point tensors are also cast to ``dtype``; integer/bool tensors keep
    their dtype. Tensors nested exactly one dict level deep are handled too.
    Any other value, including deeper nesting, is passed through unchanged.
    """

    def _place(obj: Any) -> Any:
        if not isinstance(obj, torch.Tensor):
            return obj
        if obj.dtype.is_floating_point:
            return obj.to(device=device, dtype=dtype)
        return obj.to(device=device)

    moved: Dict[str, Any] = {}
    for key, value in batch.items():
        if isinstance(value, dict):
            moved[key] = {inner_key: _place(inner) for inner_key, inner in value.items()}
        else:
            moved[key] = _place(value)
    return moved
# -----------------------------------------------------------------------------
# Decoding model outputs to readable card
# -----------------------------------------------------------------------------
def denormalize_numeric(values_z: list[float], mean: float, std: float) -> list[float]:
    """Undo z-scoring: map each ``z`` back to raw units as ``z * std + mean``."""
    scale = float(std)
    shift = float(mean)
    return [float(z) * scale + shift for z in values_z]
def decode_outputs(
    cat_logits_padded: torch.Tensor,
    valid_class_mask: torch.Tensor,
    value_by_nin: Dict[int, torch.Tensor],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    """Turn the model outputs for a single card into a readable output card.

    Args:
        cat_logits_padded: categorical logits, read as ``[0, m, class]`` for column m.
        valid_class_mask: per-column class validity, read as ``[m, class]``; assumed
            to share the class width of ``cat_logits_padded`` — confirm against the model.
        value_by_nin: z-scored numeric predictions per ``n_in``, read as ``[0, v, ...]``.
        meta: metadata as built by ``load_metadata``.

    Returns:
        ``{"categorical": {col: label or None}, "numeric": {feat: float | list[float]}}``.
        A categorical prediction falls back to ``str(id)`` when the id has no label,
        and is ``None`` when the column has no valid class at all.
    """
    cat_logits = cat_logits_padded.detach().float().cpu()
    class_mask = valid_class_mask.detach().cpu().bool()

    categorical_out: Dict[str, Any] = {}
    for m, col in enumerate(meta["cat_columns"]):
        valid = class_mask[m]
        if not bool(valid.any()):
            # Nothing selectable; argmax over an empty set is undefined.
            categorical_out[col] = None
            continue
        # Mask invalid/padded classes instead of slicing a prefix, so the argmax
        # index is the column-local class id even if valid classes are not
        # contiguous from 0. argmax(logits) == argmax(softmax(logits)).
        logits = cat_logits[0, m].masked_fill(~valid, float("-inf"))
        pred_id = int(torch.argmax(logits).item())
        categorical_out[col] = meta["id_to_label_by_col"][col].get(pred_id, str(pred_id))

    numeric_out: Dict[str, Any] = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        preds_z = value_by_nin[n_in].detach().float().cpu()[0]  # presumably [V, n_in]
        for v_idx, feat in enumerate(group["feature_names"]):
            feat = str(feat)
            mean, std = meta["numeric_stats"][feat]
            # reshape(-1) tolerates both [V, 1] and a squeezed [V] layout for n_in == 1.
            raw_pred_values = denormalize_numeric(preds_z[v_idx].reshape(-1).tolist(), mean, std)
            numeric_out[feat] = raw_pred_values[0] if n_in == 1 else raw_pred_values

    return {
        "categorical": categorical_out,
        "numeric": numeric_out,
    }
# -----------------------------------------------------------------------------
# Accuracy / MAE analysis
# -----------------------------------------------------------------------------
def masked_feature_names(input_card: Dict[str, Any], section: str) -> list[str]:
    """Names of fields explicitly masked (set to ``None``) in ``input_card[section]``.

    A missing or non-object section yields an empty list. ``""`` is not counted as masked.
    """
    section_values = input_card.get(section, {})
    if isinstance(section_values, dict):
        return [name for name, val in section_values.items() if val is None]
    return []
def _numeric_value_to_floats(value: Any) -> Optional[list[float]]:
    """Normalize a card numeric value to a list of floats, or None when missing.

    Accepts scalars, lists/tuples, and strings holding a scalar (``"3.5"``) or a
    bracketed literal list (``"[1, 2]"``). ``None``, ``""`` and whitespace-only
    strings count as missing.
    """
    if value is None:
        return None
    if isinstance(value, str):
        text = value.strip()
        if text == "":
            return None
        if text.startswith("[") and text.endswith("]"):
            return [float(x) for x in ast.literal_eval(text)]
        return [float(text)]
    if isinstance(value, (list, tuple)):
        return [float(x) for x in value]
    return [float(value)]


def numeric_abs_errors(pred_value: Any, answer_value: Any) -> list[float]:
    """Element-wise absolute errors between a prediction and an answer.

    Both sides are normalized to float lists first, so a scalar and a length-1
    list (or ``"[x]"``) compare as equals in shape. Returns ``[]`` when either side
    is missing (``None``/``""``/blank) or the lengths differ.
    """
    answers = _numeric_value_to_floats(answer_value)
    preds = _numeric_value_to_floats(pred_value)
    if answers is None or preds is None or len(answers) != len(preds):
        return []
    return [abs(p - a) for p, a in zip(preds, answers)]
def evaluate_against_answer(
    input_card: Dict[str, Any],
    output_card: Dict[str, Any],
    answer_card: Dict[str, Any],
) -> Dict[str, Any]:
    """Score predictions against an answer card on the fields masked in the input card.

    A field is scored only if it is ``null`` in ``input_card`` and the answer card
    holds a non-empty value for it (naturally missing ``""`` is skipped). Missing or
    non-object sections in the output/answer cards are treated as empty.

    Categorical: exact label match, reported as accuracy/correct/total. Labels are
    compared as text, so an answer stored as a JSON number (e.g. ``3``) matches
    the predicted label ``"3"``.
    Numeric: element-wise absolute errors from ``numeric_abs_errors``, pooled into
    one MAE across all scored elements.

    Returns:
        ``{"categorical": {...}, "numeric": {...}, "note": str}`` with per-feature details.
    """

    def _section(card: Dict[str, Any], name: str) -> Dict[str, Any]:
        values = card.get(name)
        return values if isinstance(values, dict) else {}

    answer_cat = _section(answer_card, "categorical")
    pred_cat = _section(output_card, "categorical")
    cat_details: Dict[str, Any] = {}
    correct = 0
    total = 0
    for feat in masked_feature_names(input_card, "categorical"):
        answer = answer_cat.get(feat)
        if answer is None or answer == "":
            continue
        pred = pred_cat.get(feat)
        # Predicted labels are strings; compare textually so numeric answers match.
        is_correct = pred is not None and str(pred) == str(answer)
        cat_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "correct": bool(is_correct),
        }
        correct += int(is_correct)
        total += 1

    answer_num = _section(answer_card, "numeric")
    pred_num = _section(output_card, "numeric")
    num_details: Dict[str, Any] = {}
    abs_errors_all: list[float] = []
    for feat in masked_feature_names(input_card, "numeric"):
        answer = answer_num.get(feat)
        pred = pred_num.get(feat)
        errors = numeric_abs_errors(pred, answer)
        if not errors:
            continue
        num_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "absolute_error": errors[0] if len(errors) == 1 else errors,
            "mae": sum(errors) / len(errors),
        }
        abs_errors_all.extend(errors)

    return {
        "categorical": {
            "accuracy": None if total == 0 else correct / total,
            "correct": correct,
            "total": total,
            "details": cat_details,
        },
        "numeric": {
            "mae": None if len(abs_errors_all) == 0 else sum(abs_errors_all) / len(abs_errors_all),
            "count": len(abs_errors_all),
            "details": num_details,
        },
        "note": "Metrics are computed only on fields that are null in input_card. Natural missing values \"\" are ignored.",
    }
def acc_path_from_output(output: str) -> Path:
    """Derive the metrics path from the output path: ``x.json`` -> ``x__acc.json``.

    Only a trailing ``.json`` suffix is dropped; any other name (with or without a
    suffix) simply gets ``__acc.json`` appended. The file stays in the same directory.
    """
    out_path = Path(output)
    stem_path = out_path.with_suffix("") if out_path.suffix == ".json" else out_path
    return stem_path.with_name(f"{stem_path.name}__acc.json")
# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point: readable input card in, predicted card (and optional metrics) out.

    Steps:
    1. Parse the arguments, load both configs, and pick the dtype and device.
    2. Tensorize the input card.
    3. Run the model once without gradients.
    4. Decode the outputs and write the output card.
    5. When ``--answer_card`` is given, also write the metrics to ``<output>__acc.json``.

    A one-line JSON status summary is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Run SoilFormer inference from a readable input card.")
    parser.add_argument("--input_card", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--answer_card", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--config_data", type=str, default="config/config_data.json")
    parser.add_argument("--config_model", type=str, default="config/config_model.json")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"])
    args = parser.parse_args()

    # Configuration and runtime selection.
    config_data = load_json(args.config_data)
    config_model = load_json(args.config_model)
    dtype = get_dtype(config_model.get("dtype", "bfloat16"))
    device = resolve_device(args.device)

    # Build the single-sample batch.
    meta = load_metadata(config_data)
    input_card = load_card(args.input_card)
    cpu_batch = tensorize_card(input_card=input_card, config_data=config_data, meta=meta)
    batch = move_batch_to_device(cpu_batch, device=device, dtype=dtype)

    model = load_model(args=args, config_model=config_model, device=device, dtype=dtype)
    forward_keys = (
        "cat_local_ids",
        "numeric_values_by_nin",
        "cat_valid_positions",
        "numeric_valid_positions_by_nin",
        "pixel_values",
        "vision_valid_positions",
    )
    with torch.no_grad():
        model_outputs = model(**{key: batch[key] for key in forward_keys})
        # Only logits, class mask and numeric predictions are decoded; the other outputs are unused here.
        cat_logits_padded, _, valid_class_mask, value_by_nin, _, _ = model_outputs

    output_card = decode_outputs(
        cat_logits_padded=cat_logits_padded,
        valid_class_mask=valid_class_mask,
        value_by_nin=value_by_nin,
        meta=meta,
    )
    save_json_pretty(to_jsonable(output_card), Path(args.output))
    result = {"status": "ok", "output": args.output}

    if args.answer_card:
        acc_card = evaluate_against_answer(
            input_card=input_card,
            output_card=output_card,
            answer_card=load_card(args.answer_card),
        )
        acc_path = acc_path_from_output(args.output)
        save_json_pretty(to_jsonable(acc_card), acc_path)
        result["acc_output"] = str(acc_path)

    print(json.dumps(result, ensure_ascii=False))


if __name__ == "__main__":
    main()