"""Build unified training corpus: HF + local + synthetic (+ optional web/DDI)."""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any

# Make the repo root importable before the app.* imports below.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.dataops.ddi_api import cache_ddi_records, fetch_ddi_api_records
from app.dataops.web_fallback import scrape_with_fallback
from app.env.env_core import PolyGuardEnv
from app.knowledge.drug_catalog import DRUG_CLASSES


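# Example invocations (illustrative sketch only; the filename is a placeholder,
# and the flags correspond to parse_args() below):
#   python <this_script>.py --profile small --with-local --with-synthetic
#   python <this_script>.py --profile massive --with-synthetic --with-hf --enable-ddi-api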
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build SFT/GRPO corpus from multiple data sources.")
    parser.add_argument("--profile", choices=["small", "massive"], default="small")
    parser.add_argument("--with-hf", action="store_true")
    parser.add_argument("--with-local", action="store_true")
    parser.add_argument("--with-synthetic", action="store_true")
    parser.add_argument("--enable-ddi-api", action="store_true")
    parser.add_argument("--enable-web-fallback", action="store_true")
    return parser.parse_args()


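# The local source is expected to be a JSON file containing a list of dict rows;
# a missing file or any other payload shape falls through to an empty list.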
def _load_local_sft(path: Path) -> list[dict[str, Any]]:
    if not path.exists():
        return []
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    return []


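# Synthetic rows come from resetting PolyGuardEnv over an easy/medium/hard schedule.
# Note that this reaches into the env's internal _build_observation() helper and
# assumes the observation exposes Pydantic-style candidate actions (model_dump)
# plus the summary fields referenced below.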
def _build_synthetic(count: int) -> list[dict[str, Any]]:
    env = PolyGuardEnv()
    rows: list[dict[str, Any]] = []
    schedule = ["easy", "medium", "hard"]
    for i in range(count):
        env.reset(seed=8_000 + i, difficulty=schedule[i % len(schedule)])
        obs = env._build_observation()
        candidates = [item.model_dump(mode="json") for item in obs.candidate_action_set]
        target = candidates[0]["candidate_id"] if candidates else "cand_01"
        rows.append(
            {
                "source": "synthetic",
                "task": "planner_action_selection",
                "prompt": {
                    "patient_summary": obs.patient_summary,
                    "medications": obs.medication_table,
                    "candidates": candidates,
                    "uncertainty": obs.abstention_indicators.get("uncertainty", 0.5),
                    "severe_pair_count": obs.graph_safety_summary.get("estimated_risk", 0.0),
                },
                "target_candidate_id": target,
            }
        )
    return rows


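# HF ingestion is optional: it needs the `datasets` package and network access to
# pull tatsu-lab/alpaca, and any failure degrades silently to an empty record list.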
def _load_hf(max_rows: int) -> list[dict[str, Any]]:
    try:
        from datasets import load_dataset
    except Exception:
        return []

    records: list[dict[str, Any]] = []
    try:
        ds = load_dataset("tatsu-lab/alpaca", split="train")
        for row in ds.select(range(min(max_rows, len(ds)))):
            instruction = str(row.get("instruction", ""))
            input_text = str(row.get("input", ""))
            output_text = str(row.get("output", ""))
            records.append(
                {
                    "source": "hf_alpaca",
                    "task": "instruction_following",
                    "prompt": {
                        "instruction": instruction,
                        "input": input_text,
                        "candidates": [
                            {
                                "candidate_id": "cand_01",
                                "mode": "REVIEW",
                                "action_type": "REQUEST_SPECIALIST_REVIEW",
                                "estimated_safety_delta": 0.0,
                                "uncertainty_score": 0.5,
                                "legality_precheck": True,
                            }
                        ],
                    },
                    "target_candidate_id": "cand_01",
                    "target_text": output_text,
                }
            )
    except Exception:
        return []
    return records


def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=True) + "\n")


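# main() writes everything under data/processed/: training_corpus_sft.json / .jsonl,
# training_corpus_grpo_prompts.jsonl, training_corpus_summary.json, plus optional
# ddi_api_cache.json and web_fallback_records.json when the corresponding flags are set.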
def main() -> None:
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    processed = root / "data" / "processed"
    processed.mkdir(parents=True, exist_ok=True)

    target_size = 80 if args.profile == "small" else 2000
    rows: list[dict[str, Any]] = []

    if args.with_local:
        rows.extend(_load_local_sft(processed / "sft_examples.json"))

    if args.with_synthetic:
        synth_count = min(target_size, 60 if args.profile == "small" else 1200)
        rows.extend(_build_synthetic(synth_count))

    if args.with_hf:
        hf_count = min(target_size, 40 if args.profile == "small" else 800)
        rows.extend(_load_hf(hf_count))

    if args.enable_ddi_api:
        ddi_path = processed / "ddi_api_cache.json"
        top_drugs = list(sorted(DRUG_CLASSES.keys()))[:20]
        ddi_records = fetch_ddi_api_records(top_drugs)
        cache_ddi_records(ddi_path, ddi_records)

    if args.enable_web_fallback:
        allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
        seeds = ["https://www.who.int", "https://www.nih.gov"]
        crawled = [scrape_with_fallback(url, allow_domains) for url in seeds]
        (processed / "web_fallback_records.json").write_text(
            json.dumps(crawled, ensure_ascii=True, indent=2),
            encoding="utf-8",
        )

    if not rows:
        # No sources selected: fall back to a small synthetic batch so the corpus is never empty.
        rows.extend(_build_synthetic(24))

    rows = rows[:target_size] if args.profile == "small" else rows

    (processed / "training_corpus_sft.json").write_text(
        json.dumps(rows, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
    _write_jsonl(processed / "training_corpus_sft.jsonl", rows)

    grpo_prompts = [
        {
            "prompt": row.get("prompt", {}),
            "task": row.get("task", "planner_action_selection"),
        }
        for row in rows
    ]
    _write_jsonl(processed / "training_corpus_grpo_prompts.jsonl", grpo_prompts)

    summary = {
        "status": "ok",
        "profile": args.profile,
        "rows": len(rows),
        "with_local": args.with_local,
        "with_hf": args.with_hf,
        "with_synthetic": args.with_synthetic,
        "ddi_api": args.enable_ddi_api,
        "web_fallback": args.enable_web_fallback,
    }
    (processed / "training_corpus_summary.json").write_text(
        json.dumps(summary, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
    print("training_corpus_done")


if __name__ == "__main__":
    main()