File size: 907 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Tabular inference."""

from __future__ import annotations

import pickle
from pathlib import Path

from app.common.types import PatientProfile
from app.models.tabular.features import build_tabular_features
from app.models.tabular.risk_heads import predict_risk_heads


_DEFAULT_MODEL_PATH = Path("outputs/models/tabular_risk.pkl")


def infer_tabular_risk(
    patient: PatientProfile,
    model_path: str | Path = _DEFAULT_MODEL_PATH,
) -> dict[str, float]:
    """Predict per-target risk values for a patient from tabular features.

    Builds the feature mapping with ``build_tabular_features`` and scores it
    with the pickled model artifact at ``model_path`` when that file exists.
    If the file is absent, or the artifact has no model or no target keys,
    the result comes from ``predict_risk_heads`` instead.

    Args:
        patient: Patient profile to score.
        model_path: Location of the pickled artifact. Relative paths resolve
            against the current working directory. Defaults to
            ``outputs/models/tabular_risk.pkl``.

    Returns:
        Mapping from target name (stringified) to predicted value (float).

    Note:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        point ``model_path`` at artifacts produced by a trusted pipeline.
    """
    features = build_tabular_features(patient)
    path = Path(model_path)
    if not path.exists():
        return predict_risk_heads(features)

    # SECURITY: unpickling is code execution; the artifact must be trusted.
    with path.open("rb") as f:
        artifact = pickle.load(f)

    model = artifact.get("model")
    target_keys = list(artifact.get("target_keys", []))
    # Without a model (None.predict would crash) or without target names
    # (zip would silently yield {}), a learned prediction is impossible, so use
    # the same fallback as the missing-file case.
    if model is None or not target_keys:
        return predict_risk_heads(features)

    # Column order must match training; absent features are filled with 0.0.
    feature_keys = artifact.get("feature_keys", list(features.keys()))
    x = [[float(features.get(k, 0.0)) for k in feature_keys]]
    preds = model.predict(x)[0]

    # A single-output model yields a scalar for one row (e.g. a NumPy float or
    # 0-d array). Wrap it so it pairs with the single target key.
    try:
        values = list(preds)
    except TypeError:
        values = [preds]

    return {str(k): float(v) for k, v in zip(target_keys, values)}