"""Graph model training entry."""
from __future__ import annotations
import pickle
from pathlib import Path
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from app.models.graph.dataset import build_graph_samples
from app.models.graph.hetero_encoder import encode_regimen
def train_graph_model(
    regimens: list[list[str]],
    model_path: str | Path | None = None,
) -> dict:
    """Train the severe-alert and side-effect classifiers and pickle them.

    Samples are built from ``regimens`` via ``build_graph_samples``. Each
    sample's ``drugs`` are encoded with ``encode_regimen`` and stacked into
    one feature matrix, on which two models are fitted:

    * a ``LogisticRegression`` (``class_weight="balanced"``) on the samples'
      ``severe_alert`` values cast to ``int``;
    * a ``OneVsRestClassifier`` wrapping ``LogisticRegression`` on the
      samples' ``side_effects`` after ``MultiLabelBinarizer`` encoding.

    Both models, the fitted binarizer and the feature width are pickled
    together as a single dict.

    Args:
        regimens: Regimens (lists of strings) handed to
            ``build_graph_samples``.
        model_path: Where to write the pickle. May be a ``Path`` or a plain
            string; when omitted (or falsy) the relative path
            ``outputs/models/graph_model.pkl`` is used. Missing parent
            directories are created.

    Returns:
        ``{"num_samples": 0, "status": "no_data"}`` when no samples were
        built (nothing is trained or written); otherwise a dict with
        ``num_samples``, ``status == "trained"`` and the written
        ``model_path`` as a string.
    """
    samples = build_graph_samples(regimens)
    if not samples:
        return {"num_samples": 0, "status": "no_data"}

    # assumes encode_regimen returns equal-length 1-D vectors, so that x is
    # (num_samples, feature_dim) -- TODO confirm against hetero_encoder
    x = np.stack([encode_regimen(s.drugs) for s in samples], axis=0)
    y_severe = np.array([s.severe_alert for s in samples], dtype=int)
    y_tags = [s.side_effects for s in samples]

    # NOTE(review): scikit-learn's LogisticRegression.fit is expected to
    # raise ValueError when y_severe holds a single class; that case is not
    # guarded here and propagates to the caller -- confirm this is intended.
    severe_model = LogisticRegression(max_iter=500, class_weight="balanced")
    severe_model.fit(x, y_severe)

    mlb = MultiLabelBinarizer()
    y_tag_matrix = mlb.fit_transform(y_tags)
    side_model = OneVsRestClassifier(LogisticRegression(max_iter=500))
    side_model.fit(x, y_tag_matrix)

    artifact = {
        "severe_model": severe_model,
        "side_model": side_model,
        "mlb": mlb,
        "feature_dim": x.shape[1],
    }

    # Coerce so that callers may pass a plain string path; a falsy value
    # still selects the default location, as before.
    if model_path:
        target = Path(model_path)
    else:
        target = Path("outputs/models/graph_model.pkl")
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("wb") as f:
        pickle.dump(artifact, f)

    return {
        "num_samples": len(samples),
        "status": "trained",
        "model_path": str(target),
    }
|