| """Tabular inference.""" |
|
|
| from __future__ import annotations |
|
|
| import pickle |
| from pathlib import Path |
|
|
| from app.common.types import PatientProfile |
| from app.models.tabular.features import build_tabular_features |
| from app.models.tabular.risk_heads import predict_risk_heads |
|
|
|
|
def infer_tabular_risk(
    patient: PatientProfile,
    model_path: str | Path = Path("outputs/models/tabular_risk.pkl"),
) -> dict[str, float]:
    """Predict per-target risk scores for a patient from tabular features.

    If a pickled model artifact exists at ``model_path`` it is loaded and used
    for prediction; otherwise the function falls back to
    ``predict_risk_heads`` on the same features.

    Args:
        patient: Patient whose tabular features are scored.
        model_path: Location of the pickled model artifact. Defaults to
            ``outputs/models/tabular_risk.pkl``; a relative path is resolved
            against the current working directory.

    Returns:
        Mapping of target name (from the artifact's ``"target_keys"``) to the
        predicted score as a float, or the result of ``predict_risk_heads``
        when no artifact file exists.

    Raises:
        ValueError: If the artifact has no ``"model"`` entry, or if the number
            of model outputs does not match the number of ``"target_keys"``.
    """
    features = build_tabular_features(patient)
    path = Path(model_path)
    if not path.exists():
        # No trained artifact on disk: use the built-in risk heads instead.
        return predict_risk_heads(features)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point model_path at artifacts produced by a trusted pipeline.
    with path.open("rb") as f:
        artifact = pickle.load(f)

    model = artifact.get("model")
    if model is None:
        raise ValueError(f"Model artifact {path} has no 'model' entry")

    # If the artifact did not record a feature order, use the builder's order.
    feature_keys = artifact.get("feature_keys", list(features.keys()))
    target_keys = artifact.get("target_keys", [])

    # Features the artifact expects but the builder did not produce become 0.0.
    x = [[float(features.get(k, 0.0)) for k in feature_keys]]
    row = model.predict(x)[0]

    # Single-output models yield a scalar for the first row; normalise to a
    # list so it pairs with target_keys like a multi-output row does.
    try:
        values = list(row)
    except TypeError:
        values = [row]

    # zip() would silently truncate (or return {} when target_keys is absent),
    # hiding a malformed artifact; fail loudly instead.
    if len(values) != len(target_keys):
        raise ValueError(
            f"Model artifact {path} produced {len(values)} outputs but lists "
            f"{len(target_keys)} target_keys"
        )
    return {str(k): float(v) for k, v in zip(target_keys, values)}
|
|