| |
| """Train graph model placeholder.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| from app.common.types import PatientProfile |
| from app.models.graph.train import train_graph_model |
|
|
|
|
| def _load_regimens(root: Path) -> list[list[str]]: |
| regimens: list[list[str]] = [] |
| for difficulty in ["easy", "medium", "hard"]: |
| scenario_dir = root / "data" / "scenarios" / difficulty |
| if not scenario_dir.exists(): |
| continue |
| for path in sorted(scenario_dir.glob("*.json"))[:60]: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| patient = PatientProfile.model_validate(payload) |
| regimens.append([m.drug for m in patient.medications]) |
| return regimens |
|
|
|
|
def main() -> None:
    """Train the graph model on scenario regimens and write a JSON report.

    The project root is taken to be the parent of this script's directory.
    If ``_load_regimens`` finds no scenarios, a small built-in pair of
    regimens is used so training still runs.

    Side effects:
        * creates ``outputs/models`` and passes
          ``outputs/models/graph_model.pkl`` to ``train_graph_model`` as
          ``model_path`` (presumably where it saves the model — confirm);
        * writes the value returned by ``train_graph_model`` (which must be
          JSON-serializable) to ``outputs/reports/graph_train.json``;
        * prints ``graph_model_trained`` on success.
    """
    root = Path(__file__).resolve().parents[1]
    outputs_dir = root / "outputs"

    regimens = _load_regimens(root)
    if not regimens:
        # Fallback toy regimens so the pipeline can run without scenario data.
        regimens = [["warfarin_like", "nsaid_like"], ["metformin_like", "statin_like"]]

    model_path = outputs_dir / "models" / "graph_model.pkl"
    # Ensure the model directory exists, mirroring the reports directory below;
    # this is a no-op if train_graph_model also creates it.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    result = train_graph_model(regimens, model_path=model_path)

    report_dir = outputs_dir / "reports"
    report_dir.mkdir(parents=True, exist_ok=True)
    (report_dir / "graph_train.json").write_text(
        json.dumps(result, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
    print("graph_model_trained")
|
|
|
|
# Run training only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
|
|