# Page residue from the file viewer this was copied from (not part of the
# program): "Spaces: Running", "File size: 3,827 Bytes", "bf9e424".
"""Validate MolForge SFT JSONL before training."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any
# Directory two levels above this file; it is put at the front of sys.path
# (only if absent) so the `models` import that follows can be resolved.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if not sys.path.count(str(PROJECT_ROOT)):
    sys.path.insert(0, str(PROJECT_ROOT))
from models import MolForgeAction
def main() -> None:
    """Validate a MolForge SFT trace JSONL file and print a JSON summary.

    Command-line arguments:
        path: JSONL file to check.
        --max-errors: stop reading once this many records have failed
            (default 20).

    Each non-blank line is parsed as JSON. The ``content`` of the last entry
    in its ``messages`` list is parsed as JSON again, passed as keyword
    arguments to ``MolForgeAction`` and then checked with
    ``validate_action_contract``. Records that pass are tallied by action
    type and by ``metadata["scenario_id"]`` ("unknown" when absent).

    The summary (path, records checked, validity flag, both tallies and the
    collected error messages) is printed to stdout as indented JSON.

    Raises:
        SystemExit: with status 1 when at least one record failed.
    """
    parser = argparse.ArgumentParser(description="Validate MolForge SFT trace JSONL.")
    parser.add_argument("path", help="Path to JSONL generated by scripts/generate_sft_traces.py")
    parser.add_argument("--max-errors", type=int, default=20)
    args = parser.parse_args()

    path = Path(args.path)
    errors: list[str] = []
    records = 0
    action_types: dict[str, int] = {}
    scenario_ids: dict[str, int] = {}

    # Open through a context manager so the handle is closed deterministically,
    # including when the loop stops early after --max-errors failures.
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue  # blank lines are not records
            records += 1
            try:
                record = json.loads(line)
                messages = record["messages"]
                assistant_content = messages[-1]["content"]
                action_dict = json.loads(assistant_content)
                action = MolForgeAction(**action_dict)
                validation_error = validate_action_contract(action)
                if validation_error:
                    raise ValueError(validation_error)
                metadata = record.get("metadata", {})
                scenario_id = metadata.get("scenario_id", "unknown")
                # Tallies are only updated once the record has fully validated.
                scenario_ids[scenario_id] = scenario_ids.get(scenario_id, 0) + 1
                action_types[action.action_type] = action_types.get(action.action_type, 0) + 1
            except Exception as exc:
                # Deliberately broad: any failure on a line (bad JSON, missing
                # keys, model construction, contract violation) is reported
                # against that line instead of aborting the whole run.
                errors.append(f"line {line_number}: {exc}")
                if len(errors) >= args.max_errors:
                    break

    summary: dict[str, Any] = {
        "path": str(path),
        "records_checked": records,
        "valid": not errors,
        "action_types": action_types,
        "scenario_ids": scenario_ids,
        "errors": errors,
    }
    print(json.dumps(summary, indent=2))
    if errors:
        raise SystemExit(1)
def validate_action_contract(action: MolForgeAction) -> str:
    """Check a parsed action against the trace contract.

    Checks run in a fixed order and the first violation wins:
    role allowed for the action type, non-blank rationale, non-empty
    evidence, non-empty expected_effects, one message per sender with a
    message type that sender may emit, and (unless the action is "defer")
    a "proposal" message from the acting role.

    Returns:
        A description of the first violation, or "" when the action passes.
    """
    kind = action.action_type
    role = action.acting_role

    # Which role is allowed to perform which kind of action.
    if kind == "run_assay":
        if role != "assay_planner":
            return "run_assay must use acting_role=assay_planner"
    elif kind in {"edit", "submit", "restart", "defer"}:
        if role != "lead_chemist":
            return f"{kind} must use acting_role=lead_chemist"

    # Required justification fields.
    if not action.rationale.strip():
        return "missing rationale"
    required_fields = (
        ("evidence", "missing evidence"),
        ("expected_effects", "missing expected_effects"),
    )
    for attribute, complaint in required_fields:
        if not getattr(action, attribute):
            return complaint

    # Message types each sender may emit; unknown senders may emit nothing.
    permitted = {
        "lead_chemist": {"proposal", "revision_request", "submission_recommendation"},
        "assay_planner": {"proposal", "approval", "rejection", "assay_request", "submission_recommendation"},
        "toxicologist": {"approval", "objection", "risk_flag", "assay_request", "rejection"},
        "process_chemist": {"approval", "objection", "risk_flag", "assay_request"},
    }
    senders_so_far = set()
    for entry in action.messages:
        who = entry.sender
        if who in senders_so_far:
            return f"duplicate message sender {who}"
        senders_so_far.add(who)
        if entry.message_type not in permitted.get(who, set()):
            return f"{who} cannot emit {entry.message_type}"

    # Only the first message sent by the acting role is considered.
    actor_proposed = False
    for entry in action.messages:
        if entry.sender == role:
            actor_proposed = entry.message_type == "proposal"
            break
    if kind != "defer" and not actor_proposed:
        return "acting_role must include a proposal message"
    return ""
# Run the validator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
# End of file.