#!/usr/bin/env python3
"""Build unified training corpus: HF + local + synthetic (+ optional web/DDI)."""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any

# Make the project root importable when this script is run directly
# (assumes this file lives one directory below the repository root).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.dataops.ddi_api import cache_ddi_records, fetch_ddi_api_records
from app.dataops.web_fallback import scrape_with_fallback
from app.env.env_core import PolyGuardEnv
from app.knowledge.drug_catalog import DRUG_CLASSES

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build SFT/GRPO corpus from multiple data sources.")
    parser.add_argument("--profile", choices=["small", "massive"], default="small")
    parser.add_argument("--with-hf", action="store_true")
    parser.add_argument("--with-local", action="store_true")
    parser.add_argument("--with-synthetic", action="store_true")
    parser.add_argument("--enable-ddi-api", action="store_true")
    parser.add_argument("--enable-web-fallback", action="store_true")
    return parser.parse_args()

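# Example invocations (illustrative only; the script's actual file name and working
# directory are not fixed here, so adjust the path to wherever this file lives):
#   python build_training_corpus.py --profile small --with-local --with-synthetic
#   python build_training_corpus.py --profile massive --with-hf --with-synthetic --enable-ddi-api
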
def _load_local_sft(path: Path) -> list[dict[str, Any]]:
    """Load locally curated SFT examples; return [] if the file is missing or not a list."""
    if not path.exists():
        return []
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    return []

def _build_synthetic(count: int) -> list[dict[str, Any]]:
    """Generate synthetic planner-selection rows from PolyGuardEnv resets across difficulties."""
    env = PolyGuardEnv()
    rows: list[dict[str, Any]] = []
    schedule = ["easy", "medium", "hard"]
    for i in range(count):
        env.reset(seed=8_000 + i, difficulty=schedule[i % len(schedule)])
        obs = env._build_observation()  # noqa: SLF001 - internal observation snapshot for synthetic corpus assembly.
        candidates = [item.model_dump(mode="json") for item in obs.candidate_action_set]
        target = candidates[0]["candidate_id"] if candidates else "cand_01"
        rows.append(
            {
                "source": "synthetic",
                "task": "planner_action_selection",
                "prompt": {
                    "patient_summary": obs.patient_summary,
                    "medications": obs.medication_table,
                    "candidates": candidates,
                    "uncertainty": obs.abstention_indicators.get("uncertainty", 0.5),
                    "severe_pair_count": obs.graph_safety_summary.get("estimated_risk", 0.0),
                },
                "target_candidate_id": target,
            }
        )
    return rows

def _load_hf(max_rows: int) -> list[dict[str, Any]]:
    """Pull instruction-following rows from tatsu-lab/alpaca; return [] if datasets is unavailable."""
    try:
        from datasets import load_dataset
    except Exception:
        return []
    records: list[dict[str, Any]] = []
    try:
        ds = load_dataset("tatsu-lab/alpaca", split="train")
        for row in ds.select(range(min(max_rows, len(ds)))):
            instruction = str(row.get("instruction", ""))
            input_text = str(row.get("input", ""))
            output_text = str(row.get("output", ""))
            records.append(
                {
                    "source": "hf_alpaca",
                    "task": "instruction_following",
                    "prompt": {
                        "instruction": instruction,
                        "input": input_text,
                        "candidates": [
                            {
                                "candidate_id": "cand_01",
                                "mode": "REVIEW",
                                "action_type": "REQUEST_SPECIALIST_REVIEW",
                                "estimated_safety_delta": 0.0,
                                "uncertainty_score": 0.5,
                                "legality_precheck": True,
                            }
                        ],
                    },
                    "target_candidate_id": "cand_01",
                    "target_text": output_text,
                }
            )
    except Exception:
        return []
    return records

def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write rows as JSONL: one JSON object per line."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=True) + "\n")

def main() -> None:
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    processed = root / "data" / "processed"
    processed.mkdir(parents=True, exist_ok=True)
    target_size = 80 if args.profile == "small" else 2000
    rows: list[dict[str, Any]] = []
    if args.with_local:
        rows.extend(_load_local_sft(processed / "sft_examples.json"))
    if args.with_synthetic:
        synth_count = min(target_size, 60 if args.profile == "small" else 1200)
        rows.extend(_build_synthetic(synth_count))
    if args.with_hf:
        hf_count = min(target_size, 40 if args.profile == "small" else 800)
        rows.extend(_load_hf(hf_count))
    if args.enable_ddi_api:
        ddi_path = processed / "ddi_api_cache.json"
        top_drugs = sorted(DRUG_CLASSES.keys())[:20]
        ddi_records = fetch_ddi_api_records(top_drugs)
        cache_ddi_records(ddi_path, ddi_records)
    if args.enable_web_fallback:
        allow_domains = ["who.int", "nih.gov", "fda.gov", "ema.europa.eu"]
        seeds = ["https://www.who.int", "https://www.nih.gov"]
        crawled = [scrape_with_fallback(url, allow_domains) for url in seeds]
        (processed / "web_fallback_records.json").write_text(
            json.dumps(crawled, ensure_ascii=True, indent=2),
            encoding="utf-8",
        )
    if not rows:
        # Last resort: seed the corpus with a small synthetic batch.
        rows.extend(_build_synthetic(24))
    rows = rows[:target_size] if args.profile == "small" else rows
    (processed / "training_corpus_sft.json").write_text(
        json.dumps(rows, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    _write_jsonl(processed / "training_corpus_sft.jsonl", rows)
    grpo_prompts = [
        {
            "prompt": row.get("prompt", {}),
            "task": row.get("task", "planner_action_selection"),
        }
        for row in rows
    ]
    _write_jsonl(processed / "training_corpus_grpo_prompts.jsonl", grpo_prompts)
    summary = {
        "status": "ok",
        "profile": args.profile,
        "rows": len(rows),
        "with_local": args.with_local,
        "with_hf": args.with_hf,
        "with_synthetic": args.with_synthetic,
        "ddi_api": args.enable_ddi_api,
        "web_fallback": args.enable_web_fallback,
    }
    (processed / "training_corpus_summary.json").write_text(
        json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("training_corpus_done")


if __name__ == "__main__":
    main()
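
# For reference, the artifacts a full run leaves under data/processed/ (as written by main() above):
#   training_corpus_sft.json / training_corpus_sft.jsonl  - assembled SFT rows
#   training_corpus_grpo_prompts.jsonl                     - prompt-only rows for GRPO sampling
#   training_corpus_summary.json                           - run metadata (profile, row count, source flags)
#   ddi_api_cache.json                                     - only with --enable-ddi-api
#   web_fallback_records.json                              - only with --enable-web-fallback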