"""Validate MolForge SFT JSONL before training.""" from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from models import MolForgeAction def main() -> None: parser = argparse.ArgumentParser(description="Validate MolForge SFT trace JSONL.") parser.add_argument("path", help="Path to JSONL generated by scripts/generate_sft_traces.py") parser.add_argument("--max-errors", type=int, default=20) args = parser.parse_args() path = Path(args.path) errors: list[str] = [] records = 0 action_types: dict[str, int] = {} scenario_ids: dict[str, int] = {} for line_number, line in enumerate(path.open(encoding="utf-8"), start=1): if not line.strip(): continue records += 1 try: record = json.loads(line) messages = record["messages"] assistant_content = messages[-1]["content"] action_dict = json.loads(assistant_content) action = MolForgeAction(**action_dict) validation_error = validate_action_contract(action) if validation_error: raise ValueError(validation_error) metadata = record.get("metadata", {}) scenario_id = metadata.get("scenario_id", "unknown") scenario_ids[scenario_id] = scenario_ids.get(scenario_id, 0) + 1 action_types[action.action_type] = action_types.get(action.action_type, 0) + 1 except Exception as exc: errors.append(f"line {line_number}: {exc}") if len(errors) >= args.max_errors: break summary: dict[str, Any] = { "path": str(path), "records_checked": records, "valid": not errors, "action_types": action_types, "scenario_ids": scenario_ids, "errors": errors, } print(json.dumps(summary, indent=2)) if errors: raise SystemExit(1) def validate_action_contract(action: MolForgeAction) -> str: if action.action_type == "run_assay" and action.acting_role != "assay_planner": return "run_assay must use acting_role=assay_planner" if action.action_type in {"edit", "submit", "restart", "defer"} and action.acting_role != "lead_chemist": return f"{action.action_type} must use acting_role=lead_chemist" if not action.rationale.strip(): return "missing rationale" if not action.evidence: return "missing evidence" if not action.expected_effects: return "missing expected_effects" allowed_message_types = { "lead_chemist": {"proposal", "revision_request", "submission_recommendation"}, "assay_planner": {"proposal", "approval", "rejection", "assay_request", "submission_recommendation"}, "toxicologist": {"approval", "objection", "risk_flag", "assay_request", "rejection"}, "process_chemist": {"approval", "objection", "risk_flag", "assay_request"}, } seen_senders = set() for message in action.messages: if message.sender in seen_senders: return f"duplicate message sender {message.sender}" seen_senders.add(message.sender) if message.message_type not in allowed_message_types.get(message.sender, set()): return f"{message.sender} cannot emit {message.message_type}" actor_message = next((message for message in action.messages if message.sender == action.acting_role), None) if action.action_type != "defer" and (actor_message is None or actor_message.message_type != "proposal"): return "acting_role must include a proposal message" return "" if __name__ == "__main__": main()