# Source: adithya9903 — "Deploy PolyGuard HF training Space" (commit fd0c71a, verified).
# NOTE(review): the three lines above were raw Hugging Face page residue that would
# break parsing before the shebang; preserved here as comments.
#!/usr/bin/env python3
"""Build required processed data artifacts for POLYGUARD-OPENENV."""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
import sys
# Repository root (two levels up from this script). Prepending it to sys.path
# makes the local `app.*` packages importable when this script is run directly
# rather than as an installed module.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from typing import Any
import pandas as pd
import yaml
from app.knowledge.ddi_knowledge import is_contraindicated_pair
from app.knowledge.drug_catalog import DRUG_CLASSES
from app.knowledge.substitution_rules import SUBSTITUTIONS
from app.knowledge.taper_rules import requires_taper
def _safe_write_json(path: Path, payload: Any) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=True) + "\n")
def _load_scenario_rows(scenario_dir: Path) -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
if not scenario_dir.exists():
return rows
for path in sorted(scenario_dir.glob("*.json")):
rows.append(json.loads(path.read_text(encoding="utf-8")))
return rows
def _build_catalog_tables() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Return (drug_rows, class_rows) derived from the local DRUG_CLASSES catalog.

    Sorting by drug name makes the assigned ``drug_NNNN`` canonical ids
    deterministic across runs.
    """
    drug_rows: list[dict[str, Any]] = []
    class_rows: list[dict[str, Any]] = []
    for idx, (drug, class_name) in enumerate(sorted(DRUG_CLASSES.items()), start=1):
        canonical_id = f"drug_{idx:04d}"
        drug_rows.append(
            {
                "canonical_id": canonical_id,
                "canonical_name": drug,
                "aliases": [drug.replace("_", " "), drug.upper()],
                "class_name": class_name,
                "source": "local_drug_catalog",
            }
        )
        class_rows.append(
            {
                "canonical_id": canonical_id,
                "class_name": class_name,
                "subclass": f"{class_name}_core",
                "source": "local_drug_catalog",
            }
        )
    return drug_rows, class_rows


def _build_interactions(drugs: list[str]) -> list[dict[str, Any]]:
    """Return one row per unordered contraindicated pair (each pair tested once)."""
    return [
        {
            "drug_a": drug_a,
            "drug_b": drug_b,
            "severity": "high",
            "interaction_type": "contraindicated",
            "source": "ddi_rules",
        }
        for i, drug_a in enumerate(drugs)
        for drug_b in drugs[i + 1 :]
        if is_contraindicated_pair(drug_a, drug_b)
    ]


def _build_graph_edges(interactions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return graph edges: class membership, contraindications, substitutions."""
    edges: list[dict[str, Any]] = []
    for drug, class_name in sorted(DRUG_CLASSES.items()):
        edges.append({"src": drug, "dst": class_name, "edge_type": "in_class", "weight": 1.0})
    for row in interactions:
        # Contraindication is symmetric, so emit the edge in both directions.
        edges.append({"src": row["drug_a"], "dst": row["drug_b"], "edge_type": "contraindicated_with", "weight": 1.0})
        edges.append({"src": row["drug_b"], "dst": row["drug_a"], "edge_type": "contraindicated_with", "weight": 1.0})
    for src, replacements in SUBSTITUTIONS.items():
        for dst in replacements:
            edges.append({"src": src, "dst": dst, "edge_type": "substitute_for", "weight": 0.8})
    return edges


def _load_json_if_exists(path: Path) -> list[dict[str, Any]]:
    """Load a JSON array from *path*, or return [] when the file is absent."""
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    return []


def main() -> None:
    """Build all processed data artifacts for POLYGUARD-OPENENV.

    Writes parquet/yaml/jsonl artifacts under ``data/processed`` and
    ``data/scenarios``, a markdown report under ``docs``, and a run summary
    under ``data/artifacts``.
    """
    # Reuse the module-level ROOT instead of recomputing parents[1] here.
    root = ROOT
    processed_dir = root / "data" / "processed"
    processed_dir.mkdir(parents=True, exist_ok=True)
    artifacts_dir = root / "data" / "artifacts"
    artifacts_dir.mkdir(parents=True, exist_ok=True)

    drug_rows, class_rows = _build_catalog_tables()
    drugs = sorted(DRUG_CLASSES)
    interactions = _build_interactions(drugs)

    burden_rules = {
        "version": "1.0",
        "formula": "burden = med_count/12 + high_risk_count*0.04",
        "high_risk_classes": ["sedative", "anticoagulant", "analgesic"],
    }
    taper_rows: list[dict[str, Any]] = []
    for drug in drugs:
        # Call requires_taper once per drug (the original evaluated it twice).
        needs_taper = requires_taper(drug)
        taper_rows.append(
            {"drug": drug, "requires_taper": needs_taper, "default_taper_days": 14 if needs_taper else 0}
        )
    taper_rules = {"rules": taper_rows, "source": "taper_rules"}
    substitution_rules = {"rules": SUBSTITUTIONS, "source": "substitution_rules"}

    retrieval_index_file = root / "data" / "retrieval_index" / "index.json"
    retrieval_rows = _load_json_if_exists(retrieval_index_file)
    retrieval_corpus = [
        {
            "doc_id": row.get("id"),
            "path": row.get("path"),
            "text": row.get("text"),
            "source": "retrieval_index",
        }
        for row in retrieval_rows
    ]

    graph_edges = _build_graph_edges(interactions)

    synthetic_rows = _load_json_if_exists(root / "data" / "synthetic" / "synthetic_patients.json")
    easy_rows = _load_scenario_rows(root / "data" / "scenarios" / "easy")
    medium_rows = _load_scenario_rows(root / "data" / "scenarios" / "medium")
    hard_rows = _load_scenario_rows(root / "data" / "scenarios" / "hard")

    # Tabular artifacts.
    pd.DataFrame(drug_rows).to_parquet(processed_dir / "normalized_drugs.parquet", index=False)
    pd.DataFrame(class_rows).to_parquet(processed_dir / "drug_classes.parquet", index=False)
    pd.DataFrame(interactions).to_parquet(processed_dir / "interactions.parquet", index=False)
    pd.DataFrame(graph_edges).to_parquet(processed_dir / "graph_edges.parquet", index=False)
    pd.DataFrame(synthetic_rows).to_parquet(processed_dir / "patients_synthetic.parquet", index=False)

    # Rule artifacts (YAML) and line-oriented corpora (JSONL).
    (processed_dir / "burden_rules.yaml").write_text(yaml.safe_dump(burden_rules, sort_keys=False), encoding="utf-8")
    (processed_dir / "taper_rules.yaml").write_text(yaml.safe_dump(taper_rules, sort_keys=False), encoding="utf-8")
    (processed_dir / "substitution_rules.yaml").write_text(yaml.safe_dump(substitution_rules, sort_keys=False), encoding="utf-8")
    _write_jsonl(processed_dir / "retrieval_corpus.jsonl", retrieval_corpus)
    _write_jsonl(root / "data" / "scenarios" / "scenarios_easy.jsonl", easy_rows)
    _write_jsonl(root / "data" / "scenarios" / "scenarios_medium.jsonl", medium_rows)
    _write_jsonl(root / "data" / "scenarios" / "scenarios_hard.jsonl", hard_rows)

    # Column-level schema reference for consumers of the parquet files.
    feature_dictionary = {
        "normalized_drugs": ["canonical_id", "canonical_name", "aliases", "class_name", "source"],
        "drug_classes": ["canonical_id", "class_name", "subclass", "source"],
        "interactions": ["drug_a", "drug_b", "severity", "interaction_type", "source"],
        "graph_edges": ["src", "dst", "edge_type", "weight"],
        "patients_synthetic": [
            "patient_id",
            "age",
            "sex",
            "comorbidities",
            "medications",
            "labs",
            "vitals",
            "specialist_conflicts",
            "prior_ade_history",
            "frailty_score",
            "adherence_estimate",
        ],
    }
    _safe_write_json(processed_dir / "feature_dictionary.json", feature_dictionary)

    # Provenance: where each artifact came from and how many rows were built.
    provenance_manifest = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "policy": {
            "core_sources_live_required": ["canonical_vocab", "interactions"],
            "secondary_sources_fallback": True,
            "weak_signal_labels_marked": True,
        },
        "inputs": {
            "drug_catalog": "app/knowledge/drug_catalog.py",
            "ddi_rules": "app/knowledge/ddi_knowledge.py",
            "substitutions": "app/knowledge/substitution_rules.py",
            "taper_rules": "app/knowledge/taper_rules.py",
            "retrieval_index": str(retrieval_index_file),
        },
        "counts": {
            "normalized_drugs": len(drug_rows),
            "interactions": len(interactions),
            "retrieval_docs": len(retrieval_corpus),
            "scenario_easy": len(easy_rows),
            "scenario_medium": len(medium_rows),
            "scenario_hard": len(hard_rows),
            "patients_synthetic": len(synthetic_rows),
        },
    }
    _safe_write_json(processed_dir / "provenance_manifest.json", provenance_manifest)

    dataset_report = f"""# Dataset Report
## Summary
- Normalized drugs: {len(drug_rows)}
- Drug classes: {len(class_rows)}
- Interactions: {len(interactions)}
- Graph edges: {len(graph_edges)}
- Synthetic patients: {len(synthetic_rows)}
- Scenarios (easy/medium/hard): {len(easy_rows)}/{len(medium_rows)}/{len(hard_rows)}
- Retrieval corpus documents: {len(retrieval_corpus)}
## Source Policy
- Core vocabulary/interactions are treated as core sources.
- Secondary sources are allowed fallback with explicit provenance.
- Weak/noisy safety signals are labeled as such in provenance metadata.
## Artifacts
Artifacts are stored under `data/processed`, `data/scenarios`, and `data/artifacts`.
"""
    (root / "docs" / "dataset_report.md").write_text(dataset_report, encoding="utf-8")

    summary = {
        "status": "ok",
        "processed_dir": str(processed_dir),
        "docs_report": str(root / "docs" / "dataset_report.md"),
    }
    _safe_write_json(artifacts_dir / "bootstrap_data_summary.json", summary)
    print("bootstrap_data_done")


if __name__ == "__main__":
    main()