File size: 3,827 Bytes
bf9e424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Validate MolForge SFT JSONL before training."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any

# Directory one level above the folder that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Prepend that directory to sys.path (only if it is not already listed) so
# top-level project modules can be imported when this file is run directly.
# NOTE(review): presumably this is what makes `models` importable — confirm.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from models import MolForgeAction


def main() -> None:
    """Validate an SFT trace JSONL file and print a JSON summary to stdout.

    The file named on the command line is read line by line; blank lines
    are skipped. For every other line the ``content`` of the last entry in
    ``messages`` is parsed as JSON, used to build a ``MolForgeAction`` and
    checked with ``validate_action_contract``. Any exception raised while
    handling a line is recorded as ``"line N: <message>"``, and checking
    stops as soon as ``--max-errors`` errors have been collected.

    The printed summary holds the path, the number of records checked, a
    ``valid`` flag, per-action-type and per-scenario counts (valid records
    only) and the collected error strings.

    Raises:
        SystemExit: with status 1 when at least one line failed validation.
    """
    parser = argparse.ArgumentParser(description="Validate MolForge SFT trace JSONL.")
    parser.add_argument("path", help="Path to JSONL generated by scripts/generate_sft_traces.py")
    parser.add_argument("--max-errors", type=int, default=20)
    args = parser.parse_args()

    path = Path(args.path)
    errors: list[str] = []
    records = 0
    action_types: dict[str, int] = {}
    scenario_ids: dict[str, int] = {}

    # Open via a context manager so the handle is closed deterministically,
    # including when the loop exits early at the error limit.
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            records += 1
            try:
                record = json.loads(line)
                messages = record["messages"]
                # The action to validate is the JSON payload of the final message.
                assistant_content = messages[-1]["content"]
                action_dict = json.loads(assistant_content)
                action = MolForgeAction(**action_dict)
                validation_error = validate_action_contract(action)
                if validation_error:
                    raise ValueError(validation_error)
                metadata = record.get("metadata", {})
                scenario_id = metadata.get("scenario_id", "unknown")
                scenario_ids[scenario_id] = scenario_ids.get(scenario_id, 0) + 1
                action_types[action.action_type] = action_types.get(action.action_type, 0) + 1
            except Exception as exc:
                # Deliberately broad: any failure on a line (bad JSON, missing
                # keys, model construction, contract violation) is reported
                # against that line instead of aborting the whole run.
                errors.append(f"line {line_number}: {exc}")
                if len(errors) >= args.max_errors:
                    break

    summary: dict[str, Any] = {
        "path": str(path),
        "records_checked": records,
        "valid": not errors,
        "action_types": action_types,
        "scenario_ids": scenario_ids,
        "errors": errors,
    }
    print(json.dumps(summary, indent=2))
    if errors:
        raise SystemExit(1)


def validate_action_contract(action: MolForgeAction) -> str:
    """Check an action against the role and messaging contract.

    Returns the first violation found as a human-readable string, or an
    empty string when the action passes every check. Checks run in this
    order: acting role required by the action type, non-empty rationale,
    evidence and expected effects, per-message rules (one message per
    sender, message type allowed for that sender), and finally that the
    acting role itself sent a ``proposal`` (not required for ``defer``).
    """
    kind = action.action_type
    lead_chemist_only = frozenset(("edit", "submit", "restart", "defer"))

    # Role gate: assays belong to the planner, molecule-level moves to the chemist.
    if kind == "run_assay":
        if action.acting_role != "assay_planner":
            return "run_assay must use acting_role=assay_planner"
    elif kind in lead_chemist_only and action.acting_role != "lead_chemist":
        return f"{kind} must use acting_role=lead_chemist"

    # Required justification fields must all be non-empty.
    rationale_text = action.rationale.strip()
    if not rationale_text:
        return "missing rationale"
    if not action.evidence:
        return "missing evidence"
    if not action.expected_effects:
        return "missing expected_effects"

    permitted_by_sender = {
        "lead_chemist": {"proposal", "revision_request", "submission_recommendation"},
        "assay_planner": {"proposal", "approval", "rejection", "assay_request", "submission_recommendation"},
        "toxicologist": {"approval", "objection", "risk_flag", "assay_request", "rejection"},
        "process_chemist": {"approval", "objection", "risk_flag", "assay_request"},
    }
    nothing_permitted = frozenset()

    # Track each sender's message; a repeated sender is itself a violation,
    # so after this loop every sender maps to exactly one message.
    message_by_sender = {}
    for entry in action.messages:
        sender = entry.sender
        if sender in message_by_sender:
            return f"duplicate message sender {sender}"
        message_by_sender[sender] = entry
        if entry.message_type not in permitted_by_sender.get(sender, nothing_permitted):
            return f"{sender} cannot emit {entry.message_type}"

    # A deferral needs no proposal from the acting role.
    if kind == "defer":
        return ""
    own_message = message_by_sender.get(action.acting_role)
    if own_message is not None and own_message.message_type == "proposal":
        return ""
    return "acting_role must include a proposal message"


# Run the validator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()