#!/usr/bin/env python3
"""Build unified training corpus: HF + local + synthetic (+ optional web/DDI)."""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any

# Make the repository root importable so the `app` package resolves when this
# script is run directly from the scripts/ directory.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.dataops.ddi_api import cache_ddi_records, fetch_ddi_api_records
from app.dataops.web_fallback import scrape_with_fallback
from app.env.env_core import PolyGuardEnv
from app.knowledge.drug_catalog import DRUG_CLASSES


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build SFT/GRPO corpus from multiple data sources.")
    parser.add_argument("--profile", choices=["small", "massive"], default="small")
    parser.add_argument("--with-hf", action="store_true")
    parser.add_argument("--with-local", action="store_true")
    parser.add_argument("--with-synthetic", action="store_true")
    parser.add_argument("--enable-ddi-api", action="store_true")
    parser.add_argument("--enable-web-fallback", action="store_true")
    return parser.parse_args()


def _load_local_sft(path: Path) -> list[dict[str, Any]]:
    """Load locally curated SFT examples, tolerating a missing or non-list file."""
    if not path.exists():
        return []
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    return []


def _build_synthetic(count: int) -> list[dict[str, Any]]:
    """Generate planner-action-selection rows by sampling the PolyGuard environment."""
    env = PolyGuardEnv()
    rows: list[dict[str, Any]] = []
    schedule = ["easy", "medium", "hard"]
    for i in range(count):
        env.reset(seed=8_000 + i, difficulty=schedule[i % len(schedule)])
        obs = env._build_observation()  # noqa: SLF001 - internal observation snapshot for synthetic corpus assembly.
        candidates = [item.model_dump(mode="json") for item in obs.candidate_action_set]
        target = candidates[0]["candidate_id"] if candidates else "cand_01"
        rows.append(
            {
                "source": "synthetic",
                "task": "planner_action_selection",
                "prompt": {
                    "patient_summary": obs.patient_summary,
                    "medications": obs.medication_table,
                    "candidates": candidates,
                    "uncertainty": obs.abstention_indicators.get("uncertainty", 0.5),
                    "severe_pair_count": obs.graph_safety_summary.get("estimated_risk", 0.0),
                },
                "target_candidate_id": target,
            }
        )
    return rows


def _load_hf(max_rows: int) -> list[dict[str, Any]]:
    """Pull instruction-following rows from Hugging Face; return [] if offline or unavailable."""
    try:
        from datasets import load_dataset
    except Exception:
        return []
    records: list[dict[str, Any]] = []
    try:
        ds = load_dataset("tatsu-lab/alpaca", split="train")
        for row in ds.select(range(min(max_rows, len(ds)))):
            instruction = str(row.get("instruction", ""))
            input_text = str(row.get("input", ""))
            output_text = str(row.get("output", ""))
            records.append(
                {
                    "source": "hf_alpaca",
                    "task": "instruction_following",
                    "prompt": {
                        "instruction": instruction,
                        "input": input_text,
                        "candidates": [
                            {
                                "candidate_id": "cand_01",
                                "mode": "REVIEW",
                                "action_type": "REQUEST_SPECIALIST_REVIEW",
                                "estimated_safety_delta": 0.0,
                                "uncertainty_score": 0.5,
                                "legality_precheck": True,
                            }
                        ],
                    },
                    "target_candidate_id": "cand_01",
                    "target_text": output_text,
                }
            )
    except Exception:
        return []
    return records


def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=True) + "\n")
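# For reference, a sketch of one emitted corpus row. Field names are taken from
# the builders above; the values shown are illustrative placeholders, not real
# output:
#
#   {
#     "source": "synthetic",
#     "task": "planner_action_selection",
#     "prompt": {
#       "patient_summary": "...",
#       "medications": [...],
#       "candidates": [{"candidate_id": "cand_01", ...}],
#       "uncertainty": 0.5,
#       "severe_pair_count": 0.0
#     },
#     "target_candidate_id": "cand_01"
#   }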
def main() -> None:
    args = parse_args()
    processed = ROOT / "data" / "processed"
    processed.mkdir(parents=True, exist_ok=True)
    target_size = 80 if args.profile == "small" else 2000

    rows: list[dict[str, Any]] = []
    if args.with_local:
        rows.extend(_load_local_sft(processed / "sft_examples.json"))
    if args.with_synthetic:
        synth_count = min(target_size, 60 if args.profile == "small" else 1200)
        rows.extend(_build_synthetic(synth_count))
    if args.with_hf:
        hf_count = min(target_size, 40 if args.profile == "small" else 800)
        rows.extend(_load_hf(hf_count))

    if args.enable_ddi_api:
        ddi_path = processed / "ddi_api_cache.json"
        top_drugs = sorted(DRUG_CLASSES)[:20]
        ddi_records = fetch_ddi_api_records(top_drugs)
        cache_ddi_records(ddi_path, ddi_records)

    if args.enable_web_fallback:
        allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
        seeds = ["https://www.who.int", "https://www.nih.gov"]
        crawled = [scrape_with_fallback(url, allow_domains) for url in seeds]
        (processed / "web_fallback_records.json").write_text(
            json.dumps(crawled, ensure_ascii=True, indent=2),
            encoding="utf-8",
        )

    if not rows:
        # Last-resort generated seed rows so downstream training never sees an empty corpus.
        rows.extend(_build_synthetic(24))
    rows = rows[:target_size] if args.profile == "small" else rows

    (processed / "training_corpus_sft.json").write_text(
        json.dumps(rows, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    _write_jsonl(processed / "training_corpus_sft.jsonl", rows)

    # GRPO consumes prompts only, so strip the targets here.
    grpo_prompts = [
        {
            "prompt": row.get("prompt", {}),
            "task": row.get("task", "planner_action_selection"),
        }
        for row in rows
    ]
    _write_jsonl(processed / "training_corpus_grpo_prompts.jsonl", grpo_prompts)

    summary = {
        "status": "ok",
        "profile": args.profile,
        "rows": len(rows),
        "with_local": args.with_local,
        "with_hf": args.with_hf,
        "with_synthetic": args.with_synthetic,
        "ddi_api": args.enable_ddi_api,
        "web_fallback": args.enable_web_fallback,
    }
    (processed / "training_corpus_summary.json").write_text(
        json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("training_corpus_done")


if __name__ == "__main__":
    main()
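# Example invocations (the script path is an assumption; flag names match
# parse_args above):
#
#   python scripts/build_training_corpus.py --profile small --with-local --with-synthetic
#   python scripts/build_training_corpus.py --profile massive --with-hf --with-synthetic \
#       --enable-ddi-api --enable-web-fallback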