"""Graph model training entry.""" from __future__ import annotations import pickle from pathlib import Path import numpy as np from sklearn.linear_model import LogisticRegression from sklearn.multiclass import OneVsRestClassifier from sklearn.preprocessing import MultiLabelBinarizer from app.models.graph.dataset import build_graph_samples from app.models.graph.hetero_encoder import encode_regimen def train_graph_model(regimens: list[list[str]], model_path: Path | None = None) -> dict: samples = build_graph_samples(regimens) if not samples: return {"num_samples": 0, "status": "no_data"} x = np.stack([encode_regimen(s.drugs) for s in samples], axis=0) y_severe = np.array([s.severe_alert for s in samples], dtype=int) y_tags = [s.side_effects for s in samples] severe_model = LogisticRegression(max_iter=500, class_weight="balanced") severe_model.fit(x, y_severe) mlb = MultiLabelBinarizer() y_tag_matrix = mlb.fit_transform(y_tags) side_model = OneVsRestClassifier(LogisticRegression(max_iter=500)) side_model.fit(x, y_tag_matrix) artifact = { "severe_model": severe_model, "side_model": side_model, "mlb": mlb, "feature_dim": x.shape[1], } target = model_path or Path("outputs/models/graph_model.pkl") target.parent.mkdir(parents=True, exist_ok=True) with target.open("wb") as f: pickle.dump(artifact, f) return {"num_samples": len(samples), "status": "trained", "model_path": str(target)}