Spaces:
Running
Running
| """Graph model training entry.""" | |
| from __future__ import annotations | |
| import pickle | |
| from pathlib import Path | |
| import numpy as np | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.multiclass import OneVsRestClassifier | |
| from sklearn.preprocessing import MultiLabelBinarizer | |
| from app.models.graph.dataset import build_graph_samples | |
| from app.models.graph.hetero_encoder import encode_regimen | |
def train_graph_model(regimens: list[list[str]], model_path: Path | None = None) -> dict:
    """Train the severe-alert and side-effect classifiers and pickle them.

    Samples are built from ``regimens`` with ``build_graph_samples``, and each
    sample's ``drugs`` is turned into a feature vector by ``encode_regimen``.
    Two models are fit on the stacked feature matrix:

    * ``severe_model`` -- a class-balanced ``LogisticRegression`` on the
      integer-cast ``severe_alert`` label of each sample.
    * ``side_model`` -- a one-vs-rest ``LogisticRegression`` over each
      sample's ``side_effects`` tags, binarized by a ``MultiLabelBinarizer``
      (stored as ``mlb`` so predictions can be mapped back to tag names).

    Both models, the binarizer and the feature dimension are pickled together
    as a single artifact dict.

    Args:
        regimens: Input regimens, passed unchanged to ``build_graph_samples``.
        model_path: Destination of the pickled artifact. Defaults to
            ``outputs/models/graph_model.pkl`` (relative to the current
            working directory). Missing parent directories are created.

    Returns:
        A summary dict with ``num_samples`` and ``status``:

        * ``"no_data"`` -- no samples were built; nothing is written.
        * ``"insufficient_labels"`` -- the severe-alert label has only one
          class, or no sample carries any side-effect tag, so the
          classifiers cannot be fit; nothing is written.
        * ``"trained"`` -- the artifact was written; the dict also carries
          ``model_path`` with its location as a string.

    Note:
        The artifact is a pickle; only load it from a trusted location.
    """
    samples = build_graph_samples(regimens)
    if not samples:
        return {"num_samples": 0, "status": "no_data"}

    x = np.stack([encode_regimen(s.drugs) for s in samples], axis=0)
    y_severe = np.array([s.severe_alert for s in samples], dtype=int)
    y_tags = [s.side_effects for s in samples]

    mlb = MultiLabelBinarizer()
    y_tag_matrix = mlb.fit_transform(y_tags)

    # LogisticRegression refuses to fit a single-class target, and
    # OneVsRestClassifier rejects a tag matrix with zero columns. Report such
    # a degenerate dataset through the status dict (like the empty case)
    # instead of surfacing a low-level sklearn ValueError.
    if np.unique(y_severe).size < 2 or y_tag_matrix.shape[1] == 0:
        return {"num_samples": len(samples), "status": "insufficient_labels"}

    # class_weight="balanced" reweights classes inversely to their frequency.
    severe_model = LogisticRegression(max_iter=500, class_weight="balanced")
    severe_model.fit(x, y_severe)

    side_model = OneVsRestClassifier(LogisticRegression(max_iter=500))
    side_model.fit(x, y_tag_matrix)

    artifact = {
        "severe_model": severe_model,
        "side_model": side_model,
        "mlb": mlb,
        "feature_dim": x.shape[1],
    }

    target = model_path or Path("outputs/models/graph_model.pkl")
    target.parent.mkdir(parents=True, exist_ok=True)

    # Dump into a sibling temp file and swap it in afterwards, so a failure
    # mid-dump never leaves a truncated pickle where a previous good model was.
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with tmp_path.open("wb") as f:
            pickle.dump(artifact, f)
        tmp_path.replace(target)
    finally:
        # Only still present if dumping or replacing failed.
        if tmp_path.exists():
            tmp_path.unlink()

    return {"num_samples": len(samples), "status": "trained", "model_path": str(target)}