File size: 1,523 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Graph model training entry."""

from __future__ import annotations

import pickle
from pathlib import Path

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

from app.models.graph.dataset import build_graph_samples
from app.models.graph.hetero_encoder import encode_regimen


def train_graph_model(
    regimens: list[list[str]],
    model_path: str | Path | None = None,
) -> dict:
    """Train the severe-alert and side-effect classifiers and pickle them.

    Samples are built from ``regimens`` with ``build_graph_samples`` and each
    sample's ``drugs`` are converted to a feature vector by
    ``encode_regimen``. Two models are fitted on the same feature matrix:

    * a class-balanced ``LogisticRegression`` on each sample's
      ``severe_alert`` value (cast to ``int``);
    * a ``OneVsRestClassifier`` wrapping ``LogisticRegression`` on the
      indicator matrix that ``MultiLabelBinarizer`` derives from each
      sample's ``side_effects``.

    Both models, the fitted binarizer and the feature width are stored in a
    single pickled dict with keys ``severe_model``, ``side_model``, ``mlb``
    and ``feature_dim``.

    Args:
        regimens: Drug regimens, forwarded unchanged to
            ``build_graph_samples``.
        model_path: Where to write the pickled artifact. Either a ``str`` or
            a ``Path`` is accepted. ``None`` (or any other falsy value)
            selects ``outputs/models/graph_model.pkl``, resolved against the
            current working directory.

    Returns:
        ``{"num_samples": 0, "status": "no_data"}`` when no samples could be
        built (nothing is written in that case). Otherwise a dict with
        ``num_samples``, ``status`` set to ``"trained"`` and ``model_path``
        holding the artifact location as a string.

    Raises:
        ValueError: Propagated from scikit-learn when fitting fails.
            NOTE(review): ``LogisticRegression.fit`` is expected to reject a
            target containing a single class, so a dataset in which every
            sample shares the same ``severe_alert`` value presumably ends up
            here -- confirm and decide whether callers need a softer status.
        OSError: Propagated if the target directory or file cannot be
            created or written.
    """
    samples = build_graph_samples(regimens)
    if not samples:
        return {"num_samples": 0, "status": "no_data"}

    # One row per sample; every encoded vector must share the same length
    # for np.stack to succeed.
    x = np.stack([encode_regimen(s.drugs) for s in samples], axis=0)
    y_severe = np.array([s.severe_alert for s in samples], dtype=int)
    y_tags = [s.side_effects for s in samples]

    # class_weight="balanced" reweights classes by inverse frequency.
    severe_model = LogisticRegression(max_iter=500, class_weight="balanced")
    severe_model.fit(x, y_severe)

    # The fitted binarizer is persisted with the models so that predictions
    # can be mapped back to side-effect labels at inference time.
    mlb = MultiLabelBinarizer()
    y_tag_matrix = mlb.fit_transform(y_tags)
    side_model = OneVsRestClassifier(LogisticRegression(max_iter=500))
    side_model.fit(x, y_tag_matrix)

    artifact = {
        "severe_model": severe_model,
        "side_model": side_model,
        "mlb": mlb,
        "feature_dim": x.shape[1],
    }

    # Coerce to Path so that plain string paths work too; a falsy value
    # keeps selecting the default location, exactly as before.
    target = Path(model_path) if model_path else Path("outputs/models/graph_model.pkl")
    target.parent.mkdir(parents=True, exist_ok=True)
    # NOTE: the artifact is a pickle; it should only ever be loaded from a
    # trusted location.
    with target.open("wb") as f:
        pickle.dump(artifact, f)
    return {"num_samples": len(samples), "status": "trained", "model_path": str(target)}