Spaces:
Running
Running
| """Validate MolForge SFT JSONL before training.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
# Put the parent of the directory containing this file at the front of
# sys.path so that `from models import ...` resolves when the file is
# executed directly as a script. Skipped if the entry is already present.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
del _project_root_entry  # temporary helper; not part of the module's namespace
| from models import MolForgeAction | |
def _validate_record(line: str) -> tuple[str, str]:
    """Parse one JSONL line and return ``(scenario_id, action_type)`` if it is valid.

    Raises on any structural, schema, or contract violation; the caller turns
    the exception text into a per-line error message.
    """
    record = json.loads(line)
    messages = record["messages"]
    # The final message is treated as the assistant turn whose content is the action JSON.
    assistant_content = messages[-1]["content"]
    action = MolForgeAction(**json.loads(assistant_content))
    validation_error = validate_action_contract(action)
    if validation_error:
        raise ValueError(validation_error)
    # A JSON null "metadata" is treated the same as a missing one.
    metadata = record.get("metadata") or {}
    return metadata.get("scenario_id", "unknown"), action.action_type


def main() -> None:
    """Validate every record in an SFT JSONL file and print a JSON summary.

    Each non-blank line must be a JSON object whose ``messages`` list ends with
    a message whose ``content`` is a JSON-encoded ``MolForgeAction`` that also
    passes ``validate_action_contract``. Per-line failures are collected (at
    most ``--max-errors`` of them, after which checking stops) instead of
    aborting the run. The summary reports record counts, per-action-type and
    per-scenario tallies of valid records, and the collected errors.

    Errors opening or decoding the input file are not caught and propagate.

    Raises:
        SystemExit: with status 1 if any record failed validation.
    """
    parser = argparse.ArgumentParser(description="Validate MolForge SFT trace JSONL.")
    parser.add_argument("path", help="Path to JSONL generated by scripts/generate_sft_traces.py")
    parser.add_argument("--max-errors", type=int, default=20)
    args = parser.parse_args()

    path = Path(args.path)
    errors: list[str] = []
    records = 0
    action_types: dict[str, int] = {}
    scenario_ids: dict[str, int] = {}

    # `with` guarantees the handle is closed, including on the early --max-errors exit.
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            records += 1
            try:
                scenario_id, action_type = _validate_record(line)
                scenario_ids[scenario_id] = scenario_ids.get(scenario_id, 0) + 1
                action_types[action_type] = action_types.get(action_type, 0) + 1
            except Exception as exc:  # any failure kind becomes a per-line report
                errors.append(f"line {line_number}: {exc}")
                if len(errors) >= args.max_errors:
                    break

    summary: dict[str, Any] = {
        "path": str(path),
        "records_checked": records,
        "valid": not errors,
        "action_types": action_types,
        "scenario_ids": scenario_ids,
        "errors": errors,
    }
    print(json.dumps(summary, indent=2))
    if errors:
        raise SystemExit(1)
# Message types each sender role may emit. Built once at import time instead
# of on every validate_action_contract call.
_ALLOWED_MESSAGE_TYPES: dict[str, frozenset[str]] = {
    "lead_chemist": frozenset({"proposal", "revision_request", "submission_recommendation"}),
    "assay_planner": frozenset({"proposal", "approval", "rejection", "assay_request", "submission_recommendation"}),
    "toxicologist": frozenset({"approval", "objection", "risk_flag", "assay_request", "rejection"}),
    "process_chemist": frozenset({"approval", "objection", "risk_flag", "assay_request"}),
}

# Action types that must be taken with acting_role == "lead_chemist".
_LEAD_CHEMIST_ACTIONS = frozenset({"edit", "submit", "restart", "defer"})


def validate_action_contract(action: MolForgeAction) -> str:
    """Check the role/message contract of a parsed action.

    Args:
        action: The parsed action. Only ``action_type``, ``acting_role``,
            ``rationale``, ``evidence``, ``expected_effects`` and ``messages``
            (each with ``sender`` and ``message_type``) are read.

    Returns:
        ``""`` when the action satisfies the contract, otherwise a
        human-readable description of the first violation found.
    """
    if action.action_type == "run_assay" and action.acting_role != "assay_planner":
        return "run_assay must use acting_role=assay_planner"
    if action.action_type in _LEAD_CHEMIST_ACTIONS and action.acting_role != "lead_chemist":
        return f"{action.action_type} must use acting_role=lead_chemist"
    if not action.rationale.strip():
        return "missing rationale"
    if not action.evidence:
        return "missing evidence"
    if not action.expected_effects:
        return "missing expected_effects"

    seen_senders: set[str] = set()
    for message in action.messages:
        if message.sender in seen_senders:
            return f"duplicate message sender {message.sender}"
        seen_senders.add(message.sender)
        allowed = _ALLOWED_MESSAGE_TYPES.get(message.sender)
        if allowed is None:
            # Distinguish an unrecognised role from a known role using the wrong type.
            return f"unknown message sender {message.sender}"
        if message.message_type not in allowed:
            return f"{message.sender} cannot emit {message.message_type}"

    # Senders are unique at this point, so there is at most one message from the actor.
    actor_message = next((m for m in action.messages if m.sender == action.acting_role), None)
    # Every action except "defer" must be backed by a proposal from the acting role.
    if action.action_type != "defer" and (actor_message is None or actor_message.message_type != "proposal"):
        return "acting_role must include a proposal message"
    return ""
if __name__ == "__main__":
    # Runs only when executed as a script; importing the module does not start validation.
    main()