Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Train graph model placeholder.""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from app.common.types import PatientProfile | |
| from app.models.graph.train import train_graph_model | |
| def _load_regimens(root: Path) -> list[list[str]]: | |
| regimens: list[list[str]] = [] | |
| for difficulty in ["easy", "medium", "hard"]: | |
| scenario_dir = root / "data" / "scenarios" / difficulty | |
| if not scenario_dir.exists(): | |
| continue | |
| for path in sorted(scenario_dir.glob("*.json"))[:60]: | |
| payload = json.loads(path.read_text(encoding="utf-8")) | |
| patient = PatientProfile.model_validate(payload) | |
| regimens.append([m.drug for m in patient.medications]) | |
| return regimens | |
def main() -> None:
    """Train the graph model on scenario regimens and write a JSON report.

    The project root is taken as the parent of this script's directory.
    Regimens come from ``_load_regimens``. If none are found, a small
    built-in pair of regimens is used so the pipeline still runs.

    ``train_graph_model`` is called with ``outputs/models/graph_model.pkl`` as
    its model path. The value it returns is dumped to
    ``outputs/reports/graph_train.json``; it is assumed to be
    JSON-serializable, so confirm against ``train_graph_model``. The function
    prints ``graph_model_trained`` when done.
    """
    root = Path(__file__).resolve().parents[1]
    model_path = root / "outputs" / "models" / "graph_model.pkl"
    report_dir = root / "outputs" / "reports"

    regimens = _load_regimens(root)
    if not regimens:
        # Fallback so a checkout without scenario data still trains something.
        regimens = [["warfarin_like", "nsaid_like"], ["metformin_like", "statin_like"]]

    # Create the model directory up front, mirroring the report directory
    # below. Whether train_graph_model creates it itself is not visible here,
    # and mkdir with exist_ok=True is harmless if it already exists.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    result = train_graph_model(regimens, model_path=model_path)

    report_dir.mkdir(parents=True, exist_ok=True)
    (report_dir / "graph_train.json").write_text(
        json.dumps(result, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("graph_model_trained")
if __name__ == "__main__":
    # Script entry point: train only when executed directly, never on import.
    main()