#!/usr/bin/env python3
"""Build unified training corpus: HF + local + synthetic (+ optional web/DDI)."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# These imports rely on the sys.path bootstrap above, so they sit below it.
from app.dataops.ddi_api import cache_ddi_records, fetch_ddi_api_records  # noqa: E402
from app.dataops.web_fallback import scrape_with_fallback  # noqa: E402
from app.env.env_core import PolyGuardEnv  # noqa: E402
from app.knowledge.drug_catalog import DRUG_CLASSES  # noqa: E402


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build SFT/GRPO corpus from multiple data sources.")
    parser.add_argument("--profile", choices=["small", "massive"], default="small", help="Corpus size preset.")
    parser.add_argument("--with-hf", action="store_true", help="Include rows adapted from a Hugging Face dataset.")
    parser.add_argument("--with-local", action="store_true", help="Include curated local SFT examples.")
    parser.add_argument("--with-synthetic", action="store_true", help="Include rows generated from PolyGuardEnv.")
    parser.add_argument("--enable-ddi-api", action="store_true", help="Fetch and cache DDI API records.")
    parser.add_argument("--enable-web-fallback", action="store_true", help="Scrape allow-listed seed URLs.")
    return parser.parse_args()


def _load_local_sft(path: Path) -> list[dict[str, Any]]:
    """Load curated local SFT examples; tolerate a missing file or non-list payload."""
    if not path.exists():
        return []
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    return []


def _build_synthetic(count: int) -> list[dict[str, Any]]:
    """Generate planner-action-selection rows from seeded PolyGuardEnv resets."""
    env = PolyGuardEnv()
    rows: list[dict[str, Any]] = []
    schedule = ["easy", "medium", "hard"]
    for i in range(count):
        env.reset(seed=8_000 + i, difficulty=schedule[i % len(schedule)])
        obs = env._build_observation()  # noqa: SLF001 - internal observation snapshot for synthetic corpus assembly.
        candidates = [item.model_dump(mode="json") for item in obs.candidate_action_set]
        target = candidates[0]["candidate_id"] if candidates else "cand_01"
        rows.append(
            {
                "source": "synthetic",
                "task": "planner_action_selection",
                "prompt": {
                    "patient_summary": obs.patient_summary,
                    "medications": obs.medication_table,
                    "candidates": candidates,
                    "uncertainty": obs.abstention_indicators.get("uncertainty", 0.5),
                    "severe_pair_count": obs.graph_safety_summary.get("estimated_risk", 0.0),
                },
                "target_candidate_id": target,
            }
        )
    return rows


def _load_hf(max_rows: int) -> list[dict[str, Any]]:
    """Adapt rows from the tatsu-lab/alpaca dataset into the planner schema."""
    try:
        from datasets import load_dataset
    except Exception:
        # The `datasets` dependency is optional; skip HF rows when it is absent.
        return []
    records: list[dict[str, Any]] = []
    try:
        ds = load_dataset("tatsu-lab/alpaca", split="train")
        for row in ds.select(range(min(max_rows, len(ds)))):
            instruction = str(row.get("instruction", ""))
            input_text = str(row.get("input", ""))
            output_text = str(row.get("output", ""))
            records.append(
                {
                    "source": "hf_alpaca",
                    "task": "instruction_following",
                    "prompt": {
                        "instruction": instruction,
                        "input": input_text,
                        # A single REVIEW placeholder keeps instruction-following
                        # rows compatible with the planner candidate schema.
                        "candidates": [
                            {
                                "candidate_id": "cand_01",
                                "mode": "REVIEW",
                                "action_type": "REQUEST_SPECIALIST_REVIEW",
                                "estimated_safety_delta": 0.0,
                                "uncertainty_score": 0.5,
                                "legality_precheck": True,
                            }
                        ],
                    },
                    "target_candidate_id": "cand_01",
                    "target_text": output_text,
                }
            )
    except Exception:
        # Treat download or schema failures as "no HF rows" rather than aborting.
        return []
    return records


def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write rows as JSON Lines: one JSON object per line."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=True) + "\n")
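
# Illustrative JSONL line (shape only; values depend on the enabled sources):
#   {"source": "synthetic", "task": "planner_action_selection", "prompt": {...}, "target_candidate_id": "cand_01"}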


def main() -> None:
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    processed = root / "data" / "processed"
    processed.mkdir(parents=True, exist_ok=True)
    target_size = 80 if args.profile == "small" else 2000
    rows: list[dict[str, Any]] = []
    if args.with_local:
        rows.extend(_load_local_sft(processed / "sft_examples.json"))
    if args.with_synthetic:
        synth_count = min(target_size, 60 if args.profile == "small" else 1200)
        rows.extend(_build_synthetic(synth_count))
    if args.with_hf:
        hf_count = min(target_size, 40 if args.profile == "small" else 800)
        rows.extend(_load_hf(hf_count))
    if args.enable_ddi_api:
        ddi_path = processed / "ddi_api_cache.json"
        top_drugs = sorted(DRUG_CLASSES.keys())[:20]
        ddi_records = fetch_ddi_api_records(top_drugs)
        cache_ddi_records(ddi_path, ddi_records)
    if args.enable_web_fallback:
        allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
        seeds = ["https://www.who.int", "https://www.nih.gov"]
        crawled = [scrape_with_fallback(url, allow_domains) for url in seeds]
        (processed / "web_fallback_records.json").write_text(
            json.dumps(crawled, ensure_ascii=True, indent=2),
            encoding="utf-8",
        )
    if not rows:
        # Last-resort seed rows so the output files are never empty.
        rows.extend(_build_synthetic(24))
    rows = rows[:target_size] if args.profile == "small" else rows
    (processed / "training_corpus_sft.json").write_text(
        json.dumps(rows, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    _write_jsonl(processed / "training_corpus_sft.jsonl", rows)
    # GRPO prompts keep only the prompt/task fields; target labels are dropped.
    grpo_prompts = [
        {
            "prompt": row.get("prompt", {}),
            "task": row.get("task", "planner_action_selection"),
        }
        for row in rows
    ]
    _write_jsonl(processed / "training_corpus_grpo_prompts.jsonl", grpo_prompts)
    summary = {
        "status": "ok",
        "profile": args.profile,
        "rows": len(rows),
        "with_local": args.with_local,
        "with_hf": args.with_hf,
        "with_synthetic": args.with_synthetic,
        "ddi_api": args.enable_ddi_api,
        "web_fallback": args.enable_web_fallback,
    }
    (processed / "training_corpus_summary.json").write_text(
        json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("training_corpus_done")


if __name__ == "__main__":
    main()