Spaces:
Running
Running
| """Tabular model training placeholder.""" | |
| from __future__ import annotations | |
| import pickle | |
| from pathlib import Path | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestRegressor | |
| from sklearn.multioutput import MultiOutputRegressor | |
| from app.common.enums import Difficulty | |
| from app.models.tabular.features import build_tabular_features | |
| from app.models.tabular.risk_heads import predict_risk_heads | |
| from app.simulator.patient_generator import generate_patient_profile | |
# Names of the risk heads the model regresses. The order matters: target rows
# are assembled by iterating this list, so it fixes the column order of the
# target matrix and of the model's predictions.
TARGET_KEYS: list[str] = [
    "ade_proxy", "hospitalization_proxy", "falls_proxy",
    "destabilization_proxy", "burden_proxy",
]
def train_tabular_model(dataset_size: int) -> dict[str, float | str]:
    """Fit the tabular risk regressor on generated patients and pickle it.

    ``dataset_size`` patient profiles are generated with seeds
    ``3000 .. 3000 + dataset_size - 1``. Indices are split by position into
    EASY, MEDIUM and HARD difficulty (first third, second third, remainder).
    Each profile is turned into a feature row with ``build_tabular_features``
    and labelled with the output of ``predict_risk_heads`` on those features.
    A ``MultiOutputRegressor`` wrapping a ``RandomForestRegressor`` is then
    fitted to reproduce those labels.

    The fitted model is written, together with the feature and target key
    order, to ``outputs/models/tabular_risk.pkl`` (relative to the current
    working directory; parent directories are created if missing).

    Args:
        dataset_size: Number of patient profiles to generate. Must be > 0.

    Returns:
        A summary dict with ``dataset_size`` (as a float), ``status``
        (always ``"trained"``), ``train_mae`` (mean absolute error over all
        targets on the training rows themselves, rounded to 4 decimals) and
        ``model_path`` (the pickle path as a string).

    Raises:
        ValueError: If ``dataset_size`` is not positive, or if a generated
            feature row does not have the same keys as the first row.
    """
    # Fail early with a clear message instead of an opaque shape error from
    # scikit-learn when the training matrix would be empty.
    if dataset_size <= 0:
        raise ValueError(
            f"dataset_size must be a positive integer, got {dataset_size!r}"
        )

    # Loop-invariant boundaries of the EASY / MEDIUM / HARD index ranges.
    easy_cutoff = dataset_size // 3
    medium_cutoff = (dataset_size * 2) // 3

    feature_keys: list[str] = []
    x_rows: list[list[float]] = []
    y_rows: list[list[float]] = []
    for i in range(dataset_size):
        if i < easy_cutoff:
            difficulty = Difficulty.EASY
        elif i < medium_cutoff:
            difficulty = Difficulty.MEDIUM
        else:
            difficulty = Difficulty.HARD

        patient = generate_patient_profile(seed=3000 + i, difficulty=difficulty)
        features = build_tabular_features(patient)
        targets = predict_risk_heads(features)

        if not feature_keys:
            # The first row defines the column order that is both trained on
            # and stored in the artifact, so the two can never drift apart.
            feature_keys = list(features)
        elif features.keys() != set(feature_keys):
            raise ValueError(
                f"feature keys for sample {i} differ from the first sample"
            )

        # Index by key rather than relying on dict iteration order.
        x_rows.append([features[key] for key in feature_keys])
        y_rows.append([targets[key] for key in TARGET_KEYS])

    x = np.array(x_rows, dtype=float)
    y = np.array(y_rows, dtype=float)

    model = MultiOutputRegressor(
        RandomForestRegressor(n_estimators=80, random_state=42)
    )
    model.fit(x, y)

    # Error on the training rows themselves; there is no held-out split here.
    predictions = model.predict(x)
    mae = float(np.mean(np.abs(predictions - y)))

    artifact = {
        "model": model,
        "feature_keys": feature_keys,
        "target_keys": TARGET_KEYS,
    }
    path = Path("outputs/models/tabular_risk.pkl")
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(artifact, f)

    return {
        "dataset_size": float(dataset_size),
        "status": "trained",
        "train_mae": round(mae, 4),
        "model_path": str(path),
    }