File size: 1,349 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
"""Train graph model placeholder."""

from __future__ import annotations

import json
from pathlib import Path

from app.common.types import PatientProfile
from app.models.graph.train import train_graph_model


def _load_regimens(root: Path) -> list[list[str]]:
    regimens: list[list[str]] = []
    for difficulty in ["easy", "medium", "hard"]:
        scenario_dir = root / "data" / "scenarios" / difficulty
        if not scenario_dir.exists():
            continue
        for path in sorted(scenario_dir.glob("*.json"))[:60]:
            payload = json.loads(path.read_text(encoding="utf-8"))
            patient = PatientProfile.model_validate(payload)
            regimens.append([m.drug for m in patient.medications])
    return regimens


def main() -> None:
    """Train the graph model and write a JSON training report.

    Scenario regimens are loaded from the project tree. When none are found,
    a small built-in pair of regimens is used instead. The model file path
    handed to ``train_graph_model`` is ``outputs/models/graph_model.pkl``.
    Whatever the trainer returns is serialized to
    ``outputs/reports/graph_train.json``. The reports directory is created if
    needed.
    """
    project_root = Path(__file__).resolve().parents[1]
    outputs_dir = project_root / "outputs"

    # An empty list is falsy, so the toy regimens kick in only when no
    # scenario files were loaded.
    training_regimens = _load_regimens(project_root) or [
        ["warfarin_like", "nsaid_like"],
        ["metformin_like", "statin_like"],
    ]
    summary = train_graph_model(
        training_regimens,
        model_path=outputs_dir / "models" / "graph_model.pkl",
    )

    reports_dir = outputs_dir / "reports"
    reports_dir.mkdir(parents=True, exist_ok=True)
    report_text = json.dumps(summary, ensure_ascii=True, indent=2)
    (reports_dir / "graph_train.json").write_text(report_text, encoding="utf-8")
    print("graph_model_trained")


if __name__ == "__main__":
    main()