#!/usr/bin/env python3 """Train graph model placeholder.""" from __future__ import annotations import json from pathlib import Path from app.common.types import PatientProfile from app.models.graph.train import train_graph_model def _load_regimens(root: Path) -> list[list[str]]: regimens: list[list[str]] = [] for difficulty in ["easy", "medium", "hard"]: scenario_dir = root / "data" / "scenarios" / difficulty if not scenario_dir.exists(): continue for path in sorted(scenario_dir.glob("*.json"))[:60]: payload = json.loads(path.read_text(encoding="utf-8")) patient = PatientProfile.model_validate(payload) regimens.append([m.drug for m in patient.medications]) return regimens def main() -> None: root = Path(__file__).resolve().parents[1] regimens = _load_regimens(root) if not regimens: regimens = [["warfarin_like", "nsaid_like"], ["metformin_like", "statin_like"]] result = train_graph_model(regimens, model_path=root / "outputs" / "models" / "graph_model.pkl") out = root / "outputs" / "reports" out.mkdir(parents=True, exist_ok=True) (out / "graph_train.json").write_text(json.dumps(result, ensure_ascii=True, indent=2), encoding="utf-8") print("graph_model_trained") if __name__ == "__main__": main()