| """Graph model training entry.""" |
|
|
| from __future__ import annotations |
|
|
| import pickle |
| from pathlib import Path |
|
|
| import numpy as np |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.multiclass import OneVsRestClassifier |
| from sklearn.preprocessing import MultiLabelBinarizer |
|
|
| from app.models.graph.dataset import build_graph_samples |
| from app.models.graph.hetero_encoder import encode_regimen |
|
|
|
|
def train_graph_model(regimens: list[list[str]], model_path: Path | str | None = None) -> dict:
    """Train the severe-alert and side-effect classifiers and pickle them.

    Samples are built from ``regimens`` with ``build_graph_samples``. Each
    sample's ``drugs`` are featurised with ``encode_regimen`` and stacked
    into one feature matrix. Two models are fit on that matrix:

    * a class-balanced binary ``LogisticRegression`` on ``severe_alert``;
    * a one-vs-rest ``LogisticRegression`` over the multi-label
      ``side_effects`` tags, which are binarised with ``MultiLabelBinarizer``.

    The fitted models, the binarizer and the feature dimension are pickled
    together as one dict artifact.

    Args:
        regimens: Raw regimens (lists of strings), forwarded unchanged to
            ``build_graph_samples``.
        model_path: Destination of the pickled artifact (``Path`` or ``str``).
            Falls back to ``outputs/models/graph_model.pkl``, resolved
            relative to the current working directory.

    Returns:
        A status dict with ``num_samples`` and ``status``:

        * ``"trained"``: the dict also contains ``model_path``.
        * ``"no_data"``: no samples were built.
        * ``"insufficient_labels"``: the dict also contains ``reason``; the
          labels cannot support fitting.

        Nothing is written to disk in the latter two cases.
    """
    samples = build_graph_samples(regimens)
    if not samples:
        return {"num_samples": 0, "status": "no_data"}

    x = np.stack([encode_regimen(s.drugs) for s in samples], axis=0)
    y_severe = np.array([s.severe_alert for s in samples], dtype=int)
    mlb = MultiLabelBinarizer()
    y_tag_matrix = mlb.fit_transform([s.side_effects for s in samples])

    # Validate the labels before fitting anything. LogisticRegression raises
    # on a single-class target, and OneVsRestClassifier cannot fit a target
    # with zero label columns. Report these like the "no_data" case instead
    # of crashing halfway through training.
    if np.unique(y_severe).size < 2:
        return {
            "num_samples": len(samples),
            "status": "insufficient_labels",
            "reason": "severe_alert contains a single class",
        }
    if y_tag_matrix.shape[1] == 0:
        return {
            "num_samples": len(samples),
            "status": "insufficient_labels",
            "reason": "no side_effects labels present",
        }

    severe_model = LogisticRegression(max_iter=500, class_weight="balanced")
    severe_model.fit(x, y_severe)

    side_model = OneVsRestClassifier(LogisticRegression(max_iter=500))
    side_model.fit(x, y_tag_matrix)

    artifact = {
        "severe_model": severe_model,
        "side_model": side_model,
        "mlb": mlb,
        "feature_dim": x.shape[1],
    }

    # Path(...) also accepts str paths. A falsy value (None or "") keeps the
    # original fallback to the default location.
    target = Path(model_path) if model_path else Path("outputs/models/graph_model.pkl")
    target.parent.mkdir(parents=True, exist_ok=True)

    # Dump to a sibling temp file, then rename it over the target. An
    # interrupted dump therefore never replaces a good model with a
    # truncated pickle. Path.replace overwrites atomically within one
    # filesystem, and the temp file sits in the same directory.
    tmp = target.with_name(target.name + ".tmp")
    try:
        with tmp.open("wb") as f:
            pickle.dump(artifact, f)
        tmp.replace(target)
    except BaseException:
        if tmp.exists():
            tmp.unlink()
        raise
    return {"num_samples": len(samples), "status": "trained", "model_path": str(target)}
|
|