| """Tabular model training placeholder.""" |
|
|
| from __future__ import annotations |
|
|
| import pickle |
| from pathlib import Path |
|
|
| import numpy as np |
| from sklearn.ensemble import RandomForestRegressor |
| from sklearn.multioutput import MultiOutputRegressor |
|
|
| from app.common.enums import Difficulty |
| from app.models.tabular.features import build_tabular_features |
| from app.models.tabular.risk_heads import predict_risk_heads |
| from app.simulator.patient_generator import generate_patient_profile |
|
|
|
|
# Risk-head outputs the regressor is trained to reproduce. The list order
# fixes the column order of the target matrix and is stored in the saved
# artifact under "target_keys".
TARGET_KEYS = ["ade_proxy", "hospitalization_proxy", "falls_proxy",
               "destabilization_proxy", "burden_proxy"]
|
|
|
|
def train_tabular_model(
    dataset_size: int,
    output_path: str | Path = "outputs/models/tabular_risk.pkl",
) -> dict[str, float | str]:
    """Train a multi-output random forest on simulated patients and pickle it.

    Patients are generated deterministically with seeds ``3000 + i`` and are
    split into EASY / MEDIUM / HARD tiers by which third of the dataset their
    index falls in. Inputs come from ``build_tabular_features``. Regression
    targets are the ``TARGET_KEYS`` outputs of ``predict_risk_heads`` on those
    same features, so the model learns to reproduce the risk heads.

    Args:
        dataset_size: Number of synthetic patients to generate; must be >= 1.
        output_path: Destination of the pickled artifact. Missing parent
            directories are created. Defaults to the historical location.

    Returns:
        A summary dict with ``dataset_size`` (as a float), ``status``,
        ``train_mae`` (in-sample mean absolute error, rounded to 4 decimals)
        and ``model_path``.

    Raises:
        ValueError: If ``dataset_size`` is less than 1, or if a generated
            patient yields a different feature key set than the first one.
    """
    if dataset_size < 1:
        # An empty design matrix would otherwise fail later inside ``fit``
        # with an opaque shape error.
        raise ValueError(f"dataset_size must be >= 1, got {dataset_size}")

    feature_keys: list[str] = []
    x_rows: list[list[float]] = []
    y_rows: list[list[float]] = []
    for i in range(dataset_size):
        seed = 3000 + i
        patient = generate_patient_profile(
            seed=seed,
            difficulty=_difficulty_for_index(i, dataset_size),
        )
        features = build_tabular_features(patient)
        if i == 0:
            # Fix the column order once, from an actual training sample, and
            # reuse it for every row and for the saved artifact. Inference can
            # then rebuild feature vectors in exactly the trained order.
            feature_keys = list(features)
        elif set(features) != set(feature_keys):
            raise ValueError(
                f"patient seed={seed} produced feature keys that differ from "
                "the first sample's; cannot build an aligned feature matrix"
            )
        targets = predict_risk_heads(features)
        # Index by key rather than relying on dict iteration order, so columns
        # stay aligned even if a feature dict is built in a different order.
        x_rows.append([features[k] for k in feature_keys])
        y_rows.append([targets[k] for k in TARGET_KEYS])

    x = np.array(x_rows, dtype=float)
    y = np.array(y_rows, dtype=float)
    model = MultiOutputRegressor(
        RandomForestRegressor(n_estimators=80, random_state=42)
    )
    model.fit(x, y)
    # Evaluated on the training data itself: an optimistic fit error, not an
    # estimate of generalization performance.
    predictions = model.predict(x)
    mae = float(np.mean(np.abs(predictions - y)))

    artifact = {
        "model": model,
        "feature_keys": feature_keys,
        "target_keys": TARGET_KEYS,
    }
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Pickle executes code on load; only load this artifact from trusted paths.
    with path.open("wb") as f:
        pickle.dump(artifact, f)
    return {
        "dataset_size": float(dataset_size),
        "status": "trained",
        "train_mae": round(mae, 4),
        "model_path": str(path),
    }


def _difficulty_for_index(index: int, dataset_size: int) -> Difficulty:
    """Return the difficulty tier for sample ``index`` out of ``dataset_size``.

    Indices below ``dataset_size // 3`` are EASY, those below
    ``(dataset_size * 2) // 3`` are MEDIUM, and the rest are HARD. For very
    small datasets the lower tiers may be empty (``dataset_size=1`` is HARD).
    """
    if index < dataset_size // 3:
        return Difficulty.EASY
    if index < (dataset_size * 2) // 3:
        return Difficulty.MEDIUM
    return Difficulty.HARD
|
|